Introduction
This is the first in a series of posts about programming against the GPU. This is also, really, just an attempt to crystallise my understanding of the subject matter, but maybe some readers will find these notes helpful.
Note that, at least for the time being, the scope of these posts is limited to NVIDIA hardware and the CUDA runtime (and its wider ecosystem).
One of my favourite textbooks is Computer Networking: A Top-Down Approach. It teaches networking by starting with the application layer (HTTP, FTP, SMTP, etc.) and then moving down the network stack. I found this à la carte style (i.e. go only as deep as you like) particularly effective, and I’ll try my best to emulate it in this series. To that end, in this first post we’ll write some simple programs that execute on a GPU, without worrying too much about how such programs are compiled, how the host and device interact, and so on. There’ll be plenty of time to dig into that as we go.
In this post we’ll stick to CUDA C++, but later posts may explore other programming languages and environments.
With all that in mind, let’s write our first CUDA kernels.
CUDA concepts
Before we dive into writing some code, we should cover a few key concepts.
A CUDA kernel is simply a function that is executed on the GPU. A kernel is written once and executed many times in parallel on the device[1], once by each thread in the launch. In CUDA C++, a kernel is written much like any other function, but it must return void and be declared with the __global__ specifier. For example, the function
__global__ void kernel_function(const int *input, int input_length, int *output) {
// kernel body
}
is declared as a kernel, and takes a pointer to some input integers, their length, and a pointer to some memory to write its output to.
The CUDA programming model also provides a hierarchy of threads, blocks and grids:
- Thread: the basic unit of execution. Each thread runs the same kernel function, and has a unique ID within its block.
- Block: a group of threads that can cooperate with each other through shared memory. All threads in a block are guaranteed to run on the same Streaming Multiprocessor[2] (SM). Each block has its own unique ID within its grid.
- Grid: a group of blocks that execute the same kernel function. Blocks in a grid can be scheduled on any SM and must be able to run independently, in any order, in parallel or in series. Each grid runs on a single device.
CUDA makes several built-in variables available to each thread, which can be used to determine both the thread’s unique global index and its index within its block. The three we need here are threadIdx, blockIdx, and blockDim:
- threadIdx: this built-in variable is a three-component vector (threadIdx.x, threadIdx.y, and threadIdx.z) that provides the unique ID for each thread within a block. The thread IDs are zero-based, which means they start from 0. If a block of threads is one-dimensional, you only need to use threadIdx.x. If it’s two-dimensional, you can use threadIdx.x and threadIdx.y, and so on.
- blockIdx: similar to threadIdx, blockIdx is also a three-component vector (blockIdx.x, blockIdx.y, and blockIdx.z) providing the unique ID for each block within a grid. The block IDs are also zero-based.
- blockDim: this is a three-component vector (blockDim.x, blockDim.y, and blockDim.z) containing the number of threads in a block along each dimension. For example, if you have a block size of 128 threads, and you’ve organized it as a 4x32 two-dimensional block, blockDim.x would be 4 and blockDim.y would be 32.
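The CUDA C++ Programming Guide specifies how a multi-dimensional thread index maps to a single thread ID within its block: for a block of size (Dx, Dy, Dz), thread (x, y, z) has ID x + y*Dx + z*Dx*Dy, so x varies fastest. The host-only C++ sketch below models that rule. The Vec3 struct and linearThreadIndex function are hypothetical stand-ins for illustration, not part of any CUDA API.

```cpp
#include <cassert>

// Hypothetical host-side stand-in for CUDA's built-in three-component
// vector types (threadIdx, blockDim, ...); not part of any CUDA API.
struct Vec3 {
    unsigned int x, y, z;
};

// Maps a thread's (x, y, z) position within its block to a single ID,
// with x varying fastest, then y, then z.
unsigned int linearThreadIndex(Vec3 idx, Vec3 dim) {
    // Each component must lie inside the block's extent.
    assert(idx.x < dim.x && idx.y < dim.y && idx.z < dim.z);
    return idx.x + idx.y * dim.x + idx.z * dim.x * dim.y;
}
```

For the 4x32 block above, the last thread, (3, 31, 0), gets ID 3 + 31 * 4 = 127, the last of the block’s 128 threads.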
When launching a kernel, at run time, you specify the number of blocks in the grid and the number of threads in each block. We go into more detail below, but for example, the kernel launch
kernel_function<<<16, 1024>>>( /* ... params ... */ );
results in kernel_function executing across a grid of 16 blocks, each having 1024 threads. We could also pass dim3 values instead of integers, to describe two- or three-dimensional grids and blocks; the plain integers used here are implicitly converted to one-dimensional dim3 values, with the y and z components set to 1.
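Because integer division rounds down, a grid size computed as n / blockSize leaves the last n % blockSize elements without a thread whenever n is not a multiple of the block size. A common idiom is to round up instead and have each thread check its index against n. A minimal host-side sketch of the rounding, where blocksFor is a hypothetical helper name:

```cpp
#include <cassert>

// Hypothetical helper: the number of blocks of blockSize threads needed to
// give at least one thread to each of n elements, i.e. ceil(n / blockSize)
// computed with integer arithmetic.
int blocksFor(int n, int blockSize) {
    assert(n >= 0 && blockSize > 0);
    return (n + blockSize - 1) / blockSize;
}
```

For example, 1000 elements with 256 threads per block need 4 blocks (1024 threads); the 24 surplus threads must then be filtered out by a bounds check inside the kernel.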
Finally, we will also need to manage memory access and synchronisation between host and device. The CUDA SDK provides this functionality, as we’ll see below.
The scalar product
We’ll use the example of computing the scalar product of two floating point vectors. Recall that, for two vectors $x = (x_i)$ and $y = (y_i)$ of some $n$-dimensional vector space, their scalar (or dot) product $x \cdot y$ is the sum of the pairwise products of each vector’s components: $$ x \cdot y = \sum_{i=0}^{n - 1} x_i y_i $$
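Before moving to the GPU, it is worth having a sequential reference to check results against. A minimal sketch in plain C++ that follows the formula term by term; accumulating in double is a choice made here to keep rounding error small over long vectors, and dotProduct is an illustrative name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sequential reference for x . y = sum_{i=0}^{n-1} x_i * y_i.
// Products and the running sum are kept in double (a choice made for this
// sketch) so that rounding error stays small for long vectors.
double dotProduct(const std::vector<float>& x, const std::vector<float>& y) {
    assert(x.size() == y.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) {
        sum += static_cast<double>(x[i]) * static_cast<double>(y[i]);
    }
    return sum;
}
```

With the inputs used later in this post (x_i = 1 and y_i = 2 for all 2^20 components), the result should be 2 * 2^20 = 2097152.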
The scalar product is a worthwhile place to start for two reasons:
- Scalar products are everywhere. Matrix multiplications are just lots of scalar products, and machine learning is just lots of matrix multiplications.
- It is straightforward to implement, but not so trivial that we can’t learn anything from it.
We’ll actually write two kernels – one to compute the pairwise products, and a second to compute their sum.
A quick caveat: what follows is likely not the most efficient way to do this. If you’re serious about computing scalar products on a GPU, you should be using something like cuBLAS[3]!
The first kernel looks like this:
// CUDA kernel to compute the pairwise product of vector elements
__global__ void pairwiseProducts(const float* input_x, const float* input_y, float *output, int n) {
// Get the global thread ID
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
// Check if the thread ID is within the range of n (the input length)
if (idx < n) {
// Compute the product of the corresponding elements of input_x and input_y
output[idx] = input_x[idx] * input_y[idx];
}
}
The first line computes the global thread ID from the built-in variables described above. In this case our data is one-dimensional and, as we will see later, we’ll launch the kernel with a one-dimensional configuration, so we only care about the .x attributes of each. Let’s break it down a bit further:
- threadIdx.x: this is the index of the executing thread within its block. There may be another thread in another block with the same index, so this is not globally unique.
- blockIdx.x * blockDim.x: blockIdx.x is the index of the block in which the thread is executing, and blockDim.x is the number of threads in each block. Since threadIdx.x always lies between 0 and blockDim.x - 1, offsetting it by blockIdx.x * blockDim.x gives each block its own non-overlapping range of indices, which is what makes the sum globally unique. For example, with 256 threads per block, thread 3 of block 2 gets index 3 + 2 * 256 = 515.
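We next check whether the thread has any work to do: the grid may contain more threads than there are elements, so threads with idx >= n simply do nothing. Otherwise, the thread sets the output at index idx to the product of the inputs at the same index. Because each thread writes to a distinct index, there are no race conditions to worry about.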
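To see how the index calculation and the bounds check interact, the kernel’s logic can be emulated on the host by looping over blocks and threads one after another. This is a hypothetical CPU sketch (emulatePairwiseProducts is an illustrative name), not how the GPU schedules work: there, the threads run in parallel and in no guaranteed order.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side emulation of launching pairwiseProducts with
// numBlocks blocks of blockDim threads. The GPU runs these threads in
// parallel and in no guaranteed order; here they simply run in a loop.
std::vector<float> emulatePairwiseProducts(const std::vector<float>& x,
                                           const std::vector<float>& y,
                                           unsigned int numBlocks,
                                           unsigned int blockDim) {
    assert(x.size() == y.size());
    const unsigned int n = static_cast<unsigned int>(x.size());
    std::vector<float> output(n, 0.0f);  // elements no thread touches stay 0
    for (unsigned int blockIdx = 0; blockIdx < numBlocks; ++blockIdx) {
        for (unsigned int threadIdx = 0; threadIdx < blockDim; ++threadIdx) {
            // Same global index calculation as the kernel.
            unsigned int idx = threadIdx + blockIdx * blockDim;
            // Same bounds check: surplus threads do nothing.
            if (idx < n) {
                output[idx] = x[idx] * y[idx];
            }
        }
    }
    return output;
}
```

With 1000 elements, 4 blocks of 256 threads cover every index and the 24 surplus threads do nothing. With too few threads (say one block of 2 threads for 3 elements), the trailing element is simply never written, which is why the grid size has to be rounded up.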
Our second kernel is more interesting. Once we have computed our pairwise products, we need to sum them. Recall that the threads in a block have access to shared memory; we’ll make use of that feature here. We’ll proceed as follows:
- Within a given block, each thread copies one input element (addressed by its global thread index) into shared memory.
- Wait for all threads in the block to finish copying.
- Consider a window over all of the block’s shared memory. Add the right half of the window to the left half, element by element, then halve the window size and repeat until a single element remains.
- The first element of the shared memory then contains the partial sum for the thread block.
Since each block only produces its own partial sum, we may need to launch this second kernel several times, feeding each pass’s output into the next, until the results of all blocks have been reduced to a single number.
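It helps to pin down what a single launch of this kernel should produce: one partial sum per block, where block b covers global indices b * blockDim up to (b + 1) * blockDim - 1 and out-of-range slots count as zero. The sequential reference below (referenceBlockSums is a hypothetical name) computes exactly those values without emulating the in-block tree reduction, so it can serve as a check on a kernel’s output. Feeding its result back into itself mirrors launching the kernel again.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sequential reference for ONE launch of parallelSum: block b
// sums the inputs at global indices [b * blockDim, (b + 1) * blockDim), and
// indices past the end of the input contribute 0, as in the kernel.
// This does not emulate the in-block tree reduction; it only computes the
// values that one launch should write to its output array.
std::vector<float> referenceBlockSums(const std::vector<float>& input,
                                      std::size_t blockDim) {
    assert(blockDim > 0);
    const std::size_t numBlocks = (input.size() + blockDim - 1) / blockDim;
    std::vector<float> partialSums(numBlocks, 0.0f);
    for (std::size_t i = 0; i < input.size(); ++i) {
        partialSums[i / blockDim] += input[i];  // element i belongs to block i / blockDim
    }
    return partialSums;
}
```

For 1000 ones and a block size of 256, one pass gives {256, 256, 256, 232} and a second pass gives {1000}. For general floating-point data, the GPU adds values in a different order, so its results may differ from this reference in the last few bits.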
// CUDA kernel for parallel reduction
__global__ void parallelSum(const float* input, float* output, int n) {
// Define an external shared array accessible by all threads in a block
extern __shared__ float sdata[];
// Get the global and local thread ID
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
unsigned int tid = threadIdx.x;
// Load data from global to shared memory
sdata[tid] = (idx < n) ? input[idx] : 0;
// Sync all threads in the block
__syncthreads();
// Do reduction in shared memory
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
sdata[tid] += sdata[tid + s];
}
// Make sure all additions at the current stage are done
__syncthreads();
}
// Write result of this block to global memory
if (tid == 0) output[blockIdx.x] = sdata[0];
}
Similar to the first kernel, we use the expression threadIdx.x + blockIdx.x * blockDim.x to compute a thread’s global index. There’s quite a bit more going on here, though. For starters, the kernel begins with the line
extern __shared__ float sdata[];
This declares the block’s shared memory. It needs a bit of unpacking, reading from right to left:
- float sdata[]: an array of floating point numbers.
- __shared__: this keyword tells the CUDA compiler that the array should be placed in the block’s shared memory on the device.
- extern: since we do not know the size of the shared memory array at compile time, we leave it unsized and let its size be fixed at launch time. Specifically, a kernel launch configuration can take an optional third argument: the number of bytes (a size_t) of dynamic shared memory to allocate per block (see below).
Once the thread indices are established, each thread copies its element of the input into shared memory, substituting zero if its global index is past the end of the input:
// Load data from global to shared memory
sdata[tid] = (idx < n) ? input[idx] : 0;
// Sync all threads in the block
__syncthreads();
We call __syncthreads(), a barrier across all threads in the block, to ensure no thread proceeds past this point until every thread has written its element to shared memory.
Once the shared memory is populated, the sum occurs:
// Do reduction in shared memory
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
sdata[tid] += sdata[tid + s];
}
// Make sure all additions at the current stage are done
__syncthreads();
}
// Write result of this block to global memory
if (tid == 0) output[blockIdx.x] = sdata[0];
At the start of the loop we consider a window over the entire shared memory, and mark its midpoint as s. Threads whose index falls in the right half of this window have no more work to do at this point, while threads in the left half perform the additions: each working thread adds the element s places to its right into its own slot, sdata[tid]. The __syncthreads() at the end of each iteration ensures every addition at one stage is finished before the next stage reads the results. We then halve the window (s >>= 1) and repeat until the block’s sum sits at index zero, at which point the thread with index zero copies it out to global memory. Note that this halving scheme assumes blockDim.x is a power of two, which holds for the block size of 256 used below.
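The loop can also be emulated on the host for a single block, with a std::vector standing in for shared memory and an inner loop standing in for the threads of each stage. Running a stage’s additions one after another gives the same result as running them in parallel, because at each stage the working threads read only from the right half of the window and write only to the left half. This is an illustrative sketch (emulateBlockReduction is a hypothetical name). It also shows what goes wrong when the block size is not a power of two: with 6 threads, the windows have midpoints 3 and then 1, so the value at index 2 is never added in.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side emulation of parallelSum's shared-memory loop for
// a single block. sdata stands in for the block's shared array (one slot
// per thread) and the inner loop stands in for the threads of one stage.
float emulateBlockReduction(std::vector<float> sdata) {
    assert(!sdata.empty());
    const unsigned int blockDim = static_cast<unsigned int>(sdata.size());
    for (unsigned int s = blockDim / 2; s > 0; s >>= 1) {
        // Threads with tid < s add in the element s places to their right.
        for (unsigned int tid = 0; tid < s; ++tid) {
            sdata[tid] += sdata[tid + s];
        }
        // On the GPU, __syncthreads() separates one stage from the next here.
    }
    return sdata[0];  // thread 0 would write this to output[blockIdx.x]
}
```

Summing 256 ones gives 256 as expected, but summing six ones this way yields 4 rather than 6.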
Putting it all together
At this point, we’ve written the kernels which will do our heavy lifting. We still need to call these from the host. Let’s put this all together in a complete program. It does the following:
- Declares a block size of 256.
- Creates two vectors of length 1048576 (2^20), with every entry of the first set to 1.0 and every entry of the second set to 2.0, and copies them to the GPU.
- Allocates additional memory on the GPU to store the output of their pairwise component products.
- Launches the first kernel, pairwiseProducts, to compute and store these products.
- Repeatedly calls the second kernel, parallelSum, with each pass reducing the number of elements by a factor of BLOCK_SIZE (one partial sum per block). Each iteration also allocates memory on the device to store its intermediate results.
- Copies the result from the device back to the host and prints it out.
- Cleans up any allocated memory on both host and device.
#include <cstdlib>
#include <iostream>
#include <cuda_runtime.h>
// ...
// CUDA kernels defined above
// ...
int main() {
const int N = 1<<20;
size_t size = N * sizeof(float);
const int BLOCK_SIZE = 256;
float *h_x, *h_y;
h_x = (float*)malloc(size);
h_y = (float*)malloc(size);
for (int i = 0; i < N; i++) {
h_x[i] = 1.0;
h_y[i] = 2.0;
}
float *d_x, *d_y, *d_z;
cudaMalloc(&d_x, size);
cudaMalloc(&d_y, size);
cudaMalloc(&d_z, size);
cudaMemcpy(d_x, h_x, size, cudaMemcpyHostToDevice);
cudaMemcpy(d_y, h_y, size, cudaMemcpyHostToDevice);
// Compute the product of the two vectors.
pairwiseProducts<<<N/BLOCK_SIZE, BLOCK_SIZE>>>(d_x, d_y, d_z, N);
cudaDeviceSynchronize();
// Compute the sum of the products with parallel reduction.
int numElements = N;
float* d_in = d_z;
float* d_out;
while(numElements > 1) {
int numBlocks = (numElements + BLOCK_SIZE - 1) / BLOCK_SIZE;
cudaMalloc(&d_out, numBlocks*sizeof(float));
parallelSum<<<numBlocks, BLOCK_SIZE, BLOCK_SIZE*sizeof(float)>>>(d_in, d_out, numElements);
cudaDeviceSynchronize();
if (d_in != d_z) { // Don't free the original input array.
cudaFree(d_in);
}
d_in = d_out;
numElements = numBlocks;
}
float h_out;
cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
std::cout << "The dot product is: " << h_out << std::endl;
cudaFree(d_x);
cudaFree(d_y);
cudaFree(d_z);
cudaFree(d_out);
free(h_x);
free(h_y);
return 0;
}
I won’t cover everything line by line, but let’s look a bit at memory management and kernel execution.
Memory management
The CUDA runtime API provides functions for managing memory on the device and moving data between host and device. A full survey is beyond the scope of this article, but we use some of its basic functionality here. For example, the following code allocates space for three floating point arrays on the device with cudaMalloc, and copies the two inputs from the host with cudaMemcpy, using the kind cudaMemcpyHostToDevice (copying from device to host is later done with cudaMemcpyDeviceToHost).
float *d_x, *d_y, *d_z;
cudaMalloc(&d_x, size);
cudaMalloc(&d_y, size);
cudaMalloc(&d_z, size);
cudaMemcpy(d_x, h_x, size, cudaMemcpyHostToDevice);
cudaMemcpy(d_y, h_y, size, cudaMemcpyHostToDevice);
Similarly, just as host memory allocated with malloc is released with free, device memory is released with cudaFree once it’s no longer needed:
cudaFree(d_x);
cudaFree(d_y);
cudaFree(d_z);
cudaFree(d_out);
Kernel execution
We execute our kernels in two places in the above program. First, to compute the pairwise products, with the following invocation:
pairwiseProducts<<<N/BLOCK_SIZE, BLOCK_SIZE>>>(d_x, d_y, d_z, N);
cudaDeviceSynchronize();
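The expression <<<N/BLOCK_SIZE, BLOCK_SIZE>>> is the execution configuration for the kernel. Here we are declaring a grid of 4096 blocks (2^20 / 256), each having 256 threads. Plain integer division works because N is an exact multiple of BLOCK_SIZE; otherwise we would need to round up, as the reduction loop below does with (numElements + BLOCK_SIZE - 1) / BLOCK_SIZE. The CUDA runtime itself is responsible for orchestrating how and when these blocks execute. Launching a kernel is asynchronous and returns immediately, but we want to wait for this computation to finish, so we call cudaDeviceSynchronize() to wait. Strictly speaking, kernels launched into the same (default) stream already run in order, so this call mostly serves to make the sequencing explicit.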
Next, we have the summation itself:
// Compute the sum of the products with parallel reduction.
int numElements = N;
float* d_in = d_z;
float* d_out;
while(numElements > 1) {
int numBlocks = (numElements + BLOCK_SIZE - 1) / BLOCK_SIZE;
cudaMalloc(&d_out, numBlocks*sizeof(float));
parallelSum<<<numBlocks, BLOCK_SIZE, BLOCK_SIZE*sizeof(float)>>>(d_in, d_out, numElements);
cudaDeviceSynchronize();
if (d_in != d_z) { // Don't free the original input array.
cudaFree(d_in);
}
d_in = d_out;
numElements = numBlocks;
}
Each iteration of this loop launches the parallelSum kernel with the configuration <<<numBlocks, BLOCK_SIZE, BLOCK_SIZE*sizeof(float)>>>. The first two parameters are what we saw above: the grid size and block size, respectively. The third parameter tells the CUDA runtime to allocate BLOCK_SIZE*sizeof(float) bytes of dynamic shared memory to each block, which is what gives the extern sdata array its size. After the first iteration, d_out contains 4096 partial sums. After the second iteration it contains 16, and in the third iteration the final sum is computed, ready to be copied back to the host.
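Those element counts follow from the while loop in main, which replaces numElements by one partial sum per block on each pass. A host-only sketch of just that bookkeeping (reductionPasses is a hypothetical name; it launches nothing):

```cpp
#include <cassert>
#include <vector>

// Hypothetical bookkeeping-only version of the reduction loop in main: each
// pass replaces numElements values with one partial sum per block, and the
// returned vector records the element count after every pass.
std::vector<int> reductionPasses(int numElements, int blockSize) {
    assert(numElements > 0 && blockSize > 1);  // blockSize 1 would never shrink
    std::vector<int> counts;
    while (numElements > 1) {
        numElements = (numElements + blockSize - 1) / blockSize;  // numBlocks
        counts.push_back(numElements);
    }
    return counts;
}
```

For N = 2^20 and a block size of 256 this gives 4096, 16, and then 1, i.e. three launches of parallelSum.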
Wrapping up
We’ve now seen how to use some simple features of the CUDA SDK and runtime to implement a parallel scalar product on a GPU. In the next article, we’ll see how to benchmark this code against similar code executed entirely on the host and, in fairness, how it falls short of an optimised implementation such as cuBLAS.
[1] In the context of these posts, and generally in CUDA parlance, “device” refers to the physical hardware that executes a kernel (e.g. my 2080 Super), and “host” refers to the machine which calls it.
[2] https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#hardware-implementation