mxnet-dev mailing list archives

From Xi Wang <>
Subject [apache/incubator-mxnet] [RFC] A faster version of Gamma sampling on GPU. (#15928)
Date Fri, 16 Aug 2019 07:51:56 GMT
### Description

Sampling from a Gamma distribution requires rejection sampling, which is what the implementation
of `mxnet.ndarray.random.gamma()` uses. However, the current implementation has two main
drawbacks:

1. The random numbers used in the rejection sampling (N(0,1) and U(0,1)) are generated inside
the kernel using the CUDA device API. Also, although every batch of threads has its own RNG,
samples are actually generated serially within each batch of threads.

2. Rejection sampling is implemented as an unbounded while loop inside the kernel, which
may hurt performance on GPU.

To solve the problems above, I wrote a new version of Gamma sampling on GPU, inspired by
this blog post: [](url)

### Implementation details

My implementation differs from the current version in the following aspects: 

1. Instead of generating samples in the kernel, we generate them in advance using the host API,
which allows us to fill a buffer with random samples directly.

2. Redundant samples are generated to replace the while loop. Suppose we are going to generate
a Gamma tensor of size **(N,)**. Then N x (M + 1) standard normal N(0,1) samples and N x (M + 1)
uniform U(0,1) samples are generated before entering the kernel, where M is a predefined constant.
For each element, we generate M proposed Gamma random variables and then select the first accepted
one as the output. The one extra sample is required when \alpha is less than one.

3. If all M proposed samples are rejected for some elements (which are then marked as
-1), we simply refill the random buffer and perform another round of rejection sampling,
**but only for the elements that failed the last round**.
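
The three steps above can be illustrated with a small CPU-side sketch. This is not the CUDA code from the proposal; it is a NumPy approximation written under several assumptions: the function name `gamma_buffered`, the `max_rounds` safety cap, and the scalar `alpha` with unit scale are all hypothetical choices made for illustration. The acceptance test is the standard Marsaglia and Tsang test for shape >= 1, and for shape < 1 the sketch uses their usual trick of sampling with shape + 1 and multiplying by U^(1/alpha), which is where the extra uniform column is consumed.

```python
import numpy as np

def gamma_buffered(alpha, size, M=1, rng=None, max_rounds=100):
    """Sketch: draw `size` Gamma(alpha, 1) samples from pre-generated buffers.

    Hypothetical helper for illustration only; names and defaults are assumptions.
    """
    rng = np.random.default_rng() if rng is None else rng
    boost = alpha < 1.0                      # shape < 1 needs the extra uniform
    a = alpha + 1.0 if boost else alpha
    d = a - 1.0 / 3.0
    c = 1.0 / np.sqrt(9.0 * d)

    out = np.full(size, -1.0)                # -1 marks "no accepted sample yet"
    pending = np.arange(size)                # indices that still need a sample

    for _ in range(max_rounds):
        if pending.size == 0:
            break
        n = pending.size
        # Fill the buffers up front, N x (M + 1) each, as in the description.
        # This sketch only uses the extra column of the uniform buffer.
        normals = rng.standard_normal((n, M + 1))
        uniforms = 1.0 - rng.random((n, M + 1))   # values in (0, 1]

        x = normals[:, :M]
        u = uniforms[:, :M]
        v = (1.0 + c * x) ** 3
        safe_v = np.where(v > 0, v, 1.0)     # avoid log of non-positive values
        accept = (v > 0) & (
            np.log(u) < 0.5 * x * x + d - d * v + d * np.log(safe_v)
        )

        has_accept = accept.any(axis=1)
        first = accept.argmax(axis=1)        # index of the first accepted proposal
        cand = d * v[np.arange(n), first]
        if boost:
            cand = cand * uniforms[:, M] ** (1.0 / alpha)

        out[pending[has_accept]] = cand[has_accept]
        pending = pending[~has_accept]       # retry only the failed elements

    if pending.size:
        raise RuntimeError("some elements were never accepted")
    return out
```

Each pass of the loop corresponds to one buffer fill plus one kernel launch in the proposed design, and the shrinking `pending` index array plays the role of the elements marked -1.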

Here's part of the implementation:

In my experiment, I set M to 1 (i.e. no redundant samples are generated), as the adopted
policy (Marsaglia and Tsang's method) has a rather high acceptance rate of around 98%.
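
The acceptance rate can be checked empirically with a short Monte Carlo estimate. The sketch below is an illustration rather than part of the proposal: the function name `mt_acceptance_rate` and its defaults are hypothetical, it assumes shape >= 1, and it applies a single Marsaglia and Tsang acceptance test to each of `n` independent proposals.

```python
import numpy as np

def mt_acceptance_rate(alpha, n=1_000_000, seed=0):
    """Estimate the per-proposal acceptance rate of the Marsaglia-Tsang test.

    Hypothetical helper for illustration; assumes alpha >= 1.
    """
    rng = np.random.default_rng(seed)
    d = alpha - 1.0 / 3.0
    c = 1.0 / np.sqrt(9.0 * d)
    x = rng.standard_normal(n)
    u = 1.0 - rng.random(n)                  # values in (0, 1]
    v = (1.0 + c * x) ** 3
    safe_v = np.where(v > 0, v, 1.0)         # avoid log of non-positive values
    accept = (v > 0) & (
        np.log(u) < 0.5 * x * x + d - d * v + d * np.log(safe_v)
    )
    return accept.mean()
```

If the per-proposal acceptance rate is p, the chance that all M proposals for one element are rejected is (1 - p)^M. With p around 0.98 and M = 1, roughly 2% of the elements would need a second round, which is why a small M can be sufficient.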

The profiling result is listed below:

| Size | native numpy | ndarray on GPU | my implementation |
|------|--------------|----------------|-------------------|
| 10e2 | <0.1ms       | 3~5ms          | 0.5~0.7ms         |
| 10e4 | 0.76ms       | 7.6~7.8ms      | 0.72~0.76ms       |
| 10e6 | 70ms         | 12~13ms        | 3.1ms             |
| 10e8 | 7200ms       | 1600~1700ms    | 150~160ms         |


The new version is currently under development on the `numpy` branch. It is also designed to
support broadcastable parameters.
