Just-In-Time Compiled CUDA Kernel

Overview

Just-in-Time (JIT) compilation defers compilation of code to run time instead of performing it ahead of time (AoT), which lets an application adapt its code to conditions that are known only at run time. Here I use NVRTC, NVIDIA’s runtime compilation library, to JIT-compile a simple CUDA kernel for vector addition.

Why Use JIT in CUDA?

  • On-the-fly generation of CUDA kernels based on configuration at run-time
  • Optimization of a kernel for the specific GPU architecture it runs on
  • Testing without the overhead of recompiling the whole program after every change
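To make the first point concrete, here is a minimal sketch of generating kernel source from run-time configuration. The helper makeElementwiseKernel and its parameters are hypothetical and not part of the article's code; the string it returns could be passed to NVRTC in the same way as the hand-written kernel in the example that follows.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper (not from the article's code): builds the source of an
// element-wise kernel whose name, element type, and binary operator are
// chosen at run time.
std::string makeElementwiseKernel(const std::string& name,
                                  const std::string& type,
                                  const std::string& op) {
    assert(!name.empty() && !type.empty() && !op.empty());
    std::string src;
    src += "extern \"C\"\n";
    src += "__global__ void " + name + "(const " + type + "* A, const " +
           type + "* B, " + type + "* C, int N) {\n";
    src += "    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n";
    src += "    if (idx < N) {\n";
    src += "        C[idx] = A[idx] " + op + " B[idx];\n";
    src += "    }\n";
    src += "}\n";
    return src;
}
```

For example, makeElementwiseKernel("vectorMul", "double", "*") yields a double-precision element-wise multiply without touching the host program's source.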

Example: Vector Addition

Code is available at vector_add_JIT.cu. Major components of the code are described below.

Kernel code as string

NVRTC takes the CUDA kernel source as a string. Here is a simple kernel that adds two vectors and stores the result in a third vector.

const char* cudaKernelCode = R"(
extern "C"
__global__ void vectorAdd(const float* A, const float* B, float* C, int N) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < N) {
        C[idx] = A[idx] + B[idx];
    }
}
)";

Compilation

Compile the kernel using NVRTC. Here, options is a vector that holds the compilation flags passed as command-line arguments at run time.

    // Use NVRTC to compile the kernel
    nvrtcProgram prog;
    NVRTC_SAFE_CALL(nvrtcCreateProgram(&prog, cudaKernelCode, "vectorAdd.cu", 0, NULL, NULL));

    // Compile the kernel
    nvrtcResult compileResult = nvrtcCompileProgram(prog, options.size(), options.data());
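The article does not show how options is filled, so the following is only a sketch of one way to do it. The helper names collectFlags and toOptionPointers are hypothetical. It assumes every argument after the program name is meant as an NVRTC flag, and it relies on nvrtcCompileProgram accepting the options as an array of C strings.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: treat every argument after the program name as a flag.
std::vector<std::string> collectFlags(int argc, const char* const* argv) {
    assert(argc == 0 || argv != nullptr);
    std::vector<std::string> flags;
    for (int i = 1; i < argc; ++i) {
        flags.emplace_back(argv[i]);
    }
    return flags;
}

// Hypothetical helper: build the array of C strings expected by NVRTC.
// The pointers stay valid only while `flags` is alive and unmodified.
std::vector<const char*> toOptionPointers(const std::vector<std::string>& flags) {
    std::vector<const char*> options;
    options.reserve(flags.size());
    for (const std::string& f : flags) {
        options.push_back(f.c_str());
    }
    return options;
}
```

Keeping the std::string storage separate from the pointer array avoids dangling pointers if the flags are ever built or edited programmatically rather than taken straight from argv.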

Generation of PTX

Generate PTX code. Optionally, write it to a file for analysis.

    // Obtain the PTX from the program
    size_t ptxSize;
    NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize));
    std::vector<char> ptx(ptxSize);
    NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data()));

    //...

    // ptxSize includes the trailing '\0'; leave it out of the text file
    ptxFile.write(ptx.data(), ptx.size() - 1);

Execution

Load the compiled kernel and execute it on the GPU.

    // Load the PTX and get the kernel handle
    CUmodule module;
    CUfunction kernel;
    CU_SAFE_CALL(cuModuleLoadData(&module, ptx.data()));
    CU_SAFE_CALL(cuModuleGetFunction(&kernel, module, "vectorAdd"));

    // Launch the kernel
    void* args[] = { &d_A, &d_B, &d_C, &N };
    int threadsPerBlock = 256;
    int blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock;
    CU_SAFE_CALL(cuLaunchKernel(kernel, blocksPerGrid, 1, 1, threadsPerBlock, 1, 1, 0, 0, args, 0));
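The launch configuration above rounds the block count up so that all N elements are covered, and the idx < N guard in the kernel handles the leftover threads in the last block. Below is a small host-only sketch of that arithmetic, together with a CPU reference check that could be run once the result has been copied back to the host. The names blocksFor and matchesReference are hypothetical, and the tolerance is an assumed value.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Ceiling division: the smallest number of blocks that covers n elements.
int blocksFor(int n, int threadsPerBlock) {
    assert(n >= 0 && threadsPerBlock > 0);
    return (n + threadsPerBlock - 1) / threadsPerBlock;
}

// Hypothetical CPU check: true if every C[i] equals A[i] + B[i] within tol.
bool matchesReference(const std::vector<float>& A,
                      const std::vector<float>& B,
                      const std::vector<float>& C,
                      float tol = 1e-6f) {
    if (A.size() != B.size() || A.size() != C.size()) {
        return false;
    }
    for (std::size_t i = 0; i < A.size(); ++i) {
        if (std::fabs(C[i] - (A[i] + B[i])) > tol) {
            return false;
        }
    }
    return true;
}
```

With 256 threads per block, N = 1000 needs 4 blocks (1024 threads, of which the last 24 do nothing), while N = 1025 needs 5.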

Compile and Run

nvcc vectorAdd_JIT.cu -o solver.x -lnvrtc -lcuda -std=c++14

Analysis

Let’s run the code with and without the --use_fast_math flag and compare the PTX generated at run time. This flag trades accuracy for speed by replacing certain floating-point operations with faster versions that may be less accurate.

Figure: output of vimdiff on the PTX generated without (left) and with (right) the --use_fast_math flag.

The comparison shows that the PTX differs at the instruction where the two floating-point values in registers %f2 and %f1 are added and the result is stored in register %f3: the version compiled with --use_fast_math carries the ftz qualifier.

The ftz qualifier stands for Flush To Zero, which affects how the instruction handles subnormal numbers, that is, numbers very close to zero. When the result of a computation is too small in magnitude to be represented as a normal floating-point number (given the constraints of the exponent and mantissa), it can be represented as a subnormal number. In a subnormal number, the leading bit of the mantissa is zero instead of one, which allows numbers closer to zero to be represented than would otherwise be possible. This comes at the cost of precision.

The ftz mode ensures subnormal numbers are treated as zero, which avoids the performance penalty that handling subnormal (denormalized) numbers may otherwise incur.
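The effect can be imitated on the host. The sketch below is an emulation for illustration only, not how the GPU implements it: it replaces subnormal inputs and results with a zero of the same sign, which is how I understand the PTX documentation to describe add.ftz.f32. The function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Replace a subnormal value with a zero of the same sign; other values pass
// through unchanged.
float flushToZero(float x) {
    return std::fpclassify(x) == FP_SUBNORMAL ? std::copysign(0.0f, x) : x;
}

// Host-side emulation of a flush-to-zero add: subnormal inputs and a
// subnormal result are both flushed.
float addFtz(float a, float b) {
    float r = flushToZero(flushToZero(a) + flushToZero(b));
    assert(std::fpclassify(r) != FP_SUBNORMAL);  // postcondition
    return r;
}
```

For instance, with m the smallest normal float (2^-126), the exact sum 1.5m + (-m) = 0.5m is subnormal, so addFtz returns zero. An ordinary IEEE add on a host with subnormals enabled would return 0.5m.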
