Tutorial 3: Intrinsics

This tutorial demonstrates how to use AMDGPU hardware intrinsics to improve GEMM kernels.

Using Matrix Fused Multiply-Add

Modern GPUs include specialized matrix hardware, such as tensor cores on NVIDIA GPUs and matrix cores on AMD GPUs. On AMD GPUs, MFMA instructions let a wave compute a matrix multiply cooperatively. Each lane contributes operand fragments and receives accumulator fragments in the hardware layout described by the AMD matrix instruction calculator.

This example still computes a full 128 x 128 matrix multiplication, C = A @ B^T. Each wave computes one 32 x 32 output tile, and the grid covers the full matrix with 4 x 4 wave tiles. The kernel uses shared memory to move between logical matrix layouts and the lane layouts expected by mfma_32x32x8_bf16_f32.

The calculator reports that v_mfma_f32_32x32x8_bf16 consumes two 32-bit registers for A and two for B, so each lane contributes four BF16 values per operand. Lanes 0..31 provide one half of the K = 8 slice, and lanes 32..63 provide the other half. To keep the memory path contiguous, this kernel stages K = 16 at a time: each lane loads one 16-byte vector (eight BF16 values) from A and one from B, and stores both to shared memory. The shared-memory read returns that same 16-byte vector, which is split into two 8-byte fragments, one for each of two MFMA instructions.
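The staging scheme can be checked on the host with plain Python. The sketch below is not kernel code: it only models, under the lane layout described above, which K indices of one 16-wide slab each lane feeds into the two MFMA calls. The helper names `k_indices_for_lane` and `k_coverage_for_row` are made up for this illustration.

```python
# Host-side model of the K = 16 staging scheme (illustrative only).
WAVE_SIZE = 64
VEC_ELEMS = 8    # one 16-byte vector holds eight BF16 values
FRAG_ELEMS = 4   # one MFMA operand fragment holds four BF16 values


def k_indices_for_lane(lane):
    """Return the K indices a lane supplies to MFMA call 0 and call 1."""
    lane_group = lane >> 5              # 0 for lanes 0..31, 1 for lanes 32..63
    first_k = lane_group * VEC_ELEMS    # start of this lane's 16-byte vector
    vector = list(range(first_k, first_k + VEC_ELEMS))
    return vector[:FRAG_ELEMS], vector[FRAG_ELEMS:]


def k_coverage_for_row(row):
    """Collect the K indices fed to each MFMA call for one tile row."""
    calls = [[], []]
    for lane in range(WAVE_SIZE):
        if lane & 31 != row:
            continue  # this lane loads a different tile row
        frag0, frag1 = k_indices_for_lane(lane)
        calls[0].extend(frag0)
        calls[1].extend(frag1)
    return calls


calls = k_coverage_for_row(row=5)
print("MFMA call 0 consumes K indices:", calls[0])
print("MFMA call 1 consumes K indices:", calls[1])
```

In this model each call sees eight distinct K indices, and the two calls together cover the 16-wide slab exactly once. The partial products are accumulated in a different order than a left-to-right sum, so results may differ from a reference only by floating-point rounding.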

import avelang
import avelang.language as al

@avelang.jit
def gemm_mfma_128_bf16_mfma(
    A: al.Tensor((128, 128), al.bf16),
    B: al.Tensor((128, 128), al.bf16),
    C: al.Tensor((128, 128), al.f32),
):
    TILE_M = 32
    TILE_N = 32
    TILE_K = 16
    K_TILES = 8

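    lane = al.thread_id(0)
    lane_col = lane & 31    # tile row of A or B that this lane loads, 0..31
    lane_group = lane >> 5  # wave half, 0 or 1; selects the 16-byte K vector

    # Grid y walks tile rows of C and grid x walks tile columns.
    block_m = al.block_id(1) * TILE_M
    block_n = al.block_id(0) * TILE_N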

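    # View each global row as 16-byte vectors of four packed i32 words.
    A_vec = al.view(A, al.Tensor((128, 16, 4), al.i32))
    B_vec = al.view(B, al.Tensor((128, 16, 4), al.i32))
    C_vec = al.view(C, al.Tensor((128, 32, 4), al.i32))

    # One 16-byte vector per lane for A and for B: 64 rows of four i32 words.
    a_smem = al.make_shared((TILE_M * (TILE_K >> 3), TILE_K >> 2), al.i32)
    b_smem = al.make_shared((TILE_N * (TILE_K >> 3), TILE_K >> 2), al.i32)

    # Staging tile for the 32 x 32 f32 result, plus a vectorized view of it.
    c_smem = al.make_shared((32, 32), al.f32)
    c_smem_vec = al.view(c_smem, al.Tensor((32, 8, 4), al.i32))
    # Each lane holds 16 of the tile's 1024 accumulator values.
    acc = al.full((16,), 0.0, al.f32)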

    for kt in al.range(K_TILES):
        k_vec = kt * 2 + lane_group

        a_smem[lane] = A_vec[block_m + lane_col, k_vec]
        b_smem[lane] = B_vec[block_n + lane_col, k_vec]

        al.syncthreads()

        a_words = a_smem[lane]
        b_words = b_smem[lane]
        a_frag = al.view(a_words, al.Tensor((2, 4, 1), al.bf16))
        b_frag = al.view(b_words, al.Tensor((2, 4, 1), al.bf16))

        acc = al.amdgpu.mfma_32x32x8_bf16_f32(b_frag[0], a_frag[0], acc)
        acc = al.amdgpu.mfma_32x32x8_bf16_f32(b_frag[1], a_frag[1], acc)

        al.syncthreads()

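    # Scatter accumulators into the staging tile. Because the MFMA operands are
    # swapped, lane_col selects the C row and row_offset the column within it.
    for r in al.range(16):
        row_offset = ((r >> 2) << 3) + lane_group * 4 + (r & 3)
        c_smem[lane_col, row_offset] = acc[r]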

    al.syncthreads()

    # Two lanes share each tile row; each lane copies four 16-byte vectors to C.
    store_row = lane >> 1
    store_vec_base = (lane & 1) * 4

    for v in al.range(4):
        C_vec[block_m + store_row, (block_n >> 2) + store_vec_base + v] = (
            c_smem_vec[store_row, store_vec_base + v]
        )

Launch one wave per 32 x 32 tile:

import torch

M = 128
N = 128
K = 128

A = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
B = torch.randn((N, K), dtype=torch.bfloat16, device="cuda")
C = torch.empty((M, N), dtype=torch.float32, device="cuda")

gemm_mfma_128_bf16_mfma[lambda: ((4, 4, 1), (64, 1, 1))](A, B, C)
expected = A.to(torch.float32) @ B.to(torch.float32).T
torch.testing.assert_close(C.cpu(), expected.cpu(), rtol=1e-2, atol=1e-2)

The input tiles are written contiguously as packed i32 vectors. Each lane reads back one 16-byte vector from shared memory, splits it into two 8-byte fragments, and issues two MFMA instructions. The MFMA operands are swapped, with B passed first, so each lane's accumulator fragment falls within a single row of C, in runs of four contiguous columns, which matches row-major C. The final step copies the tile from shared memory to C one 16-byte vector (four contiguous f32 values) at a time.
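The accumulator scatter can be sanity-checked on the host. The sketch below reuses the `row_offset` expression from the kernel and counts how often each element of the 32 x 32 tile is written; `acc_position` is a hypothetical helper name, and the code models only the index arithmetic, not the MFMA instruction itself.

```python
# Host-side check of the accumulator-to-tile mapping used by the kernel.
TILE = 32
WAVE_SIZE = 64
ACC_REGS = 16


def acc_position(lane, r):
    """Map (lane, accumulator register) to (row, column) in the C tile."""
    lane_col = lane & 31
    lane_group = lane >> 5
    # Same expression as `row_offset` in the kernel's scatter loop; with the
    # swapped operands it indexes a column of the staging tile.
    col = ((r >> 2) << 3) + lane_group * 4 + (r & 3)
    return lane_col, col


hits = [[0] * TILE for _ in range(TILE)]
for lane in range(WAVE_SIZE):
    for r in range(ACC_REGS):
        row, col = acc_position(lane, r)
        hits[row][col] += 1

covered_once = all(count == 1 for row in hits for count in row)
print("every tile element written exactly once:", covered_once)
print("lane 0 columns: ", [acc_position(0, r)[1] for r in range(ACC_REGS)])
print("lane 32 columns:", [acc_position(32, r)[1] for r in range(ACC_REGS)])
```

Lanes 0 and 32 both write row 0 of the tile, and their column lists interleave in groups of four, which is why the staged tile can then be copied out as 16-byte vectors.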

Runtime Shapes

So far, the MFMA kernel used static 128 x 128 tensors. For a reusable GEMM entry point, accept raw pointers plus runtime m, n, and k, then build typed tensor views inside the kernel.

al.Pointer(dtype) marks a pointer argument. al.make_tensor(ptr, dtype, layout) gives that pointer a tensor view, and al.make_layout() can use runtime dimensions.

BLOCK_M, BLOCK_N, and BLOCK_K describe the logical tile computed by one program instance. In this kernel they are 32, 32, and 16: one wave computes one 32 x 32 output tile and advances through K in chunks of 16 BF16 values. Keeping them as al.constexpr makes the tile shape explicit and lets larger kernels build bigger program tiles by composing multiple MFMA tiles while preserving static shared-memory shapes and loop bounds. The kernel body below still hard-codes the lane mapping of a single MFMA tile, so 32, 32, and 16 are the only values it supports as written. The packed shared-memory buffers are sized from these values: BLOCK_M or BLOCK_N rows, BLOCK_K / 8 MFMA lane groups, and BLOCK_K / 4 packed i32 words per vector.

For now, assume the runtime shapes are aligned to the tile and vector shape: m is a multiple of BLOCK_M, n is a multiple of BLOCK_N, and k is a multiple of BLOCK_K. That keeps every MFMA tile and 16-byte vectorized load in bounds.
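A host-side check can enforce the alignment assumption before launching. The sketch below is one possible way to do that in plain Python. The helper names `packed_tile_shapes` and `check_aligned` are hypothetical and not part of avelang; the shape formulas are copied from the sizing rules described above.

```python
def packed_tile_shapes(block_m, block_n, block_k):
    """Shared-memory shapes implied by the BLOCK_* constants."""
    lane_groups = block_k >> 3      # BLOCK_K / 8 MFMA lane groups
    words_per_vec = block_k >> 2    # BLOCK_K / 4 packed i32 words per vector
    return {
        "a_smem": (block_m * lane_groups, words_per_vec),
        "b_smem": (block_n * lane_groups, words_per_vec),
        "c_smem": (block_m, block_n),
    }


def check_aligned(m, n, k, block_m, block_n, block_k):
    """Return the launch grid, or raise ValueError for unaligned shapes."""
    for name, size, block in (("m", m, block_m), ("n", n, block_n), ("k", k, block_k)):
        if size <= 0 or size % block != 0:
            raise ValueError(f"{name}={size} is not a positive multiple of {block}")
    # Grid in (x, y, z) order: x walks tile columns (N), y walks tile rows (M).
    return (n // block_n, m // block_m, 1)


print(packed_tile_shapes(32, 32, 16))
print(check_aligned(128, 128, 128, 32, 32, 16))
```

Rejecting unaligned shapes on the host keeps the kernel free of edge guards, at the cost of requiring callers to pad their inputs.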

BF16_BYTES = 2  # bytes per BF16 element
F32_BYTES = 4   # bytes per f32 element

@avelang.jit
def gemm_runtime_shape(
    A_ptr: al.Pointer(al.bf16),
    B_ptr: al.Pointer(al.bf16),
    C_ptr: al.Pointer(al.f32),
    m: al.u32,
    n: al.u32,
    k: al.u32,
    BLOCK_M: al.constexpr,
    BLOCK_N: al.constexpr,
    BLOCK_K: al.constexpr,
):
    A_bf16 = al.make_tensor(A_ptr, al.bf16, al.make_layout((m, k), (k, 1)))
    B_bf16 = al.make_tensor(B_ptr, al.bf16, al.make_layout((n, k), (k, 1)))
    C = al.make_tensor(C_ptr, al.f32, al.make_layout((m, n), (n, 1)))

    k_vecs = k >> 3
    packed_row_stride = k >> 1

    A_vec = al.view(
        A_bf16,
        al.i32,
        al.make_layout((m, k_vecs, 4), (packed_row_stride, 4, 1)),
    )
    B_vec = al.view(
        B_bf16,
        al.i32,
        al.make_layout((n, k_vecs, 4), (packed_row_stride, 4, 1)),
    )
    C_vec = al.view(
        C,
        al.i32,
        al.make_layout((m, n >> 2, 4), (n, 4, 1)),
    )

    lane = al.thread_id(0)
    lane_col = lane & 31
    lane_group = lane >> 5

    block_m = al.block_id(1) * BLOCK_M
    block_n = al.block_id(0) * BLOCK_N

    a_smem = al.make_shared((BLOCK_M * (BLOCK_K >> 3), BLOCK_K >> 2), al.i32)
    b_smem = al.make_shared((BLOCK_N * (BLOCK_K >> 3), BLOCK_K >> 2), al.i32)
    c_smem = al.make_shared((BLOCK_M, BLOCK_N), al.f32)
    c_smem_vec = al.view(
        c_smem,
        al.i32,
        al.make_layout((BLOCK_M, BLOCK_N >> 2, 4), (BLOCK_N, 4, 1)),
    )
    acc = al.full((16,), 0.0, al.f32)

    for kt in al.range(k // BLOCK_K):
        k_vec = kt * 2 + lane_group

        a_smem[lane] = A_vec[block_m + lane_col, k_vec]
        b_smem[lane] = B_vec[block_n + lane_col, k_vec]

        al.syncthreads()

        a_words = a_smem[lane]
        b_words = b_smem[lane]
        a_frag = al.view(a_words, al.Tensor((2, 4, 1), al.bf16))
        b_frag = al.view(b_words, al.Tensor((2, 4, 1), al.bf16))

        acc = al.amdgpu.mfma_32x32x8_bf16_f32(b_frag[0], a_frag[0], acc)
        acc = al.amdgpu.mfma_32x32x8_bf16_f32(b_frag[1], a_frag[1], acc)

        al.syncthreads()

    for r in al.range(16):
        row_offset = ((r >> 2) << 3) + lane_group * 4 + (r & 3)
        c_smem[lane_col, row_offset] = acc[r]

    al.syncthreads()

    store_row = lane >> 1
    store_vec_base = (lane & 1) * (BLOCK_N >> 3)

    for v in al.range(BLOCK_N >> 3):
        C_vec[block_m + store_row, (block_n >> 2) + store_vec_base + v] = (
            c_smem_vec[store_row, store_vec_base + v]
        )

The launch grid now comes from the runtime shape:

m = 128
n = 128
k = 128
# A and B are reused from the previous launch; only C is reallocated.
C = torch.empty((m, n), dtype=torch.float32, device="cuda")

grid_x = n // 32
grid_y = m // 32

gemm_runtime_shape[lambda: ((grid_x, grid_y, 1), (64, 1, 1))](
    A, B, C, m, n, k, 32, 32, 16
)

expected = A.to(torch.float32) @ B.to(torch.float32).T
torch.testing.assert_close(C.cpu(), expected.cpu(), rtol=1e-2, atol=1e-2)

This is the same MFMA kernel as before, but the row-major tensor layouts come from runtime dimensions. The aligned-shape assumption removes edge guards; the next section shows how raw-buffer loads handle those edge cases while keeping vectorized memory movement.
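The strides passed to the packed views can be checked with a small host-side model. The sketch below treats a layout as a plain (shape, strides) pair and verifies that the packed i32 view addresses the same memory as the row-major BF16 view. `layout_offset` and `packed_view` are illustrative names, and the model makes no claim about how avelang implements layouts internally.

```python
def layout_offset(index, strides):
    """Linear offset of a multi-dimensional index under the given strides."""
    return sum(i * s for i, s in zip(index, strides))


def packed_view(k):
    """Vector count and strides of the packed i32 view for rows of k BF16 values."""
    k_vecs = k >> 3              # 16-byte vectors per row (eight BF16 each)
    packed_row_stride = k >> 1   # i32 words per row (two BF16 each)
    return k_vecs, (packed_row_stride, 4, 1)


k = 128
k_vecs, strides = packed_view(k)

mismatches = 0
for row in range(4):
    for col in range(k):
        # Offset of the BF16 element in the row-major (m, k) layout.
        bf16_offset = layout_offset((row, col), (k, 1))
        # i32 word that holds it, addressed through the packed view.
        word_index = layout_offset((row, col // 8, (col % 8) // 2), strides)
        if word_index != bf16_offset // 2:
            mismatches += 1

print("vectors per row:", k_vecs)
print("mismatching elements:", mismatches)
```

The check relies on k being a multiple of 8, which the aligned-shape assumption guarantees.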

Hardware-Guarded Buffer Loads

Dynamic shapes introduce edge tiles. With a ceil-divided grid, the final tile in M or N may ask some lanes to load rows outside A or B. AMDGPU raw-buffer operations attach a byte range to a resource descriptor: out-of-range loads return zero and out-of-range stores are dropped, so the same descriptor form guards both loads and stores.
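The range-checking behavior can be pictured with a toy model in plain Python. This is a sketch of the semantics only, not of the hardware or of any avelang API: `model_load_x4` and `model_store_x4` are hypothetical names, and the model assumes each 32-bit word of a four-word access is range-checked on its own.

```python
# Toy model of range-checked raw-buffer accesses (not real hardware code).
WORD_BYTES = 4


def model_load_x4(memory, base, num_bytes, byte_offset):
    """Load four 32-bit words; words outside [0, num_bytes) read as zero."""
    words = []
    for i in range(4):
        offset = byte_offset + i * WORD_BYTES
        if 0 <= offset and offset + WORD_BYTES <= num_bytes:
            start = base + offset
            words.append(int.from_bytes(memory[start:start + WORD_BYTES], "little"))
        else:
            words.append(0)  # assumption: an out-of-range word reads as zero
    return words


def model_store_x4(memory, base, num_bytes, byte_offset, words):
    """Store four 32-bit words; words outside [0, num_bytes) are dropped."""
    for i, word in enumerate(words):
        offset = byte_offset + i * WORD_BYTES
        if 0 <= offset and offset + WORD_BYTES <= num_bytes:
            start = base + offset
            memory[start:start + WORD_BYTES] = word.to_bytes(WORD_BYTES, "little")


memory = bytearray(range(64))
BASE, RANGE = 16, 24   # the descriptor covers bytes 16..39 of `memory`

print(model_load_x4(memory, BASE, RANGE, 0))    # fully in range
print(model_load_x4(memory, BASE, RANGE, 16))   # last two words out of range
print(model_load_x4(memory, BASE, RANGE, 32))   # fully out of range

model_store_x4(memory, BASE, RANGE, 16, [0xFFFFFFFF] * 4)
print(memory[32:48].hex())   # only the first eight bytes are overwritten
```

The point of the model is that the guard lives in the descriptor's byte range rather than in per-lane branches, so every lane can issue the same instruction.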

The descriptor base must be uniform across the wave. The kernel below creates block-uniform descriptors for A and B, then creates a uniform row descriptor for each C row during writeback. Per-lane row, column, and K positions become byte offsets into those descriptors. It still assumes K and the row length of C are 16-byte aligned, so each vectorized access belongs to one row.

@avelang.jit
def gemm_runtime_shape_guarded_loads(
    A_ptr: al.Pointer(al.bf16),
    B_ptr: al.Pointer(al.bf16),
    C_ptr: al.Pointer(al.f32),
    m: al.u32,
    n: al.u32,
    k: al.u32,
    BLOCK_M: al.constexpr,
    BLOCK_N: al.constexpr,
    BLOCK_K: al.constexpr,
):
    lane = al.thread_id(0)
    lane_col = lane & 31
    lane_group = lane >> 5

    block_m = al.block_id(1) * BLOCK_M
    block_n = al.block_id(0) * BLOCK_N

    A_flat = al.make_tensor(A_ptr, al.bf16, al.make_layout((m * k,), (1,)))
    B_flat = al.make_tensor(B_ptr, al.bf16, al.make_layout((n * k,), (1,)))
    C_flat = al.make_tensor(C_ptr, al.f32, al.make_layout((m * n,), (1,)))

    a_rows = m - block_m
    b_rows = n - block_n
    c_rows = m - block_m
    c_cols = n - block_n

    if a_rows > BLOCK_M:
        a_rows = BLOCK_M

    if b_rows > BLOCK_N:
        b_rows = BLOCK_N

    if c_rows > BLOCK_M:
        c_rows = BLOCK_M

    if c_cols > BLOCK_N:
        c_cols = BLOCK_N

    A_block = al.subview(A_flat, (block_m * k,), (a_rows * k,), (1,))
    B_block = al.subview(B_flat, (block_n * k,), (b_rows * k,), (1,))

    A_rsrc = al.amdgpu.make_rsrc(A_block, a_rows * k * BF16_BYTES)
    B_rsrc = al.amdgpu.make_rsrc(B_block, b_rows * k * BF16_BYTES)

    a_smem = al.make_shared((BLOCK_M * (BLOCK_K >> 3), BLOCK_K >> 2), al.i32)
    b_smem = al.make_shared((BLOCK_N * (BLOCK_K >> 3), BLOCK_K >> 2), al.i32)
    c_smem = al.make_shared((BLOCK_M, BLOCK_N), al.f32)
    c_smem_vec = al.view(
        c_smem,
        al.i32,
        al.make_layout((BLOCK_M, BLOCK_N >> 2, 4), (BLOCK_N, 4, 1)),
    )
    acc = al.full((16,), 0.0, al.f32)

    zero = al.convert(0, al.i32)

    for kt in al.range(k // BLOCK_K):
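        # First K index of this lane's 16-byte vector within the current slab.
        k_base = kt * BLOCK_K + lane_group * (BLOCK_K >> 1)
        # Byte offset from the block's first row. Rows past the edge fall
        # outside the descriptor range and load as zeros.
        load_offset = al.convert((lane_col * k + k_base) * BF16_BYTES, al.i32)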

        a_smem[lane] = al.amdgpu.raw_buffer_load_x4(A_rsrc, zero, load_offset, 0)
        b_smem[lane] = al.amdgpu.raw_buffer_load_x4(B_rsrc, zero, load_offset, 0)

        al.syncthreads()

        a_words = a_smem[lane]
        b_words = b_smem[lane]
        a_frag = al.view(a_words, al.Tensor((2, 4, 1), al.bf16))
        b_frag = al.view(b_words, al.Tensor((2, 4, 1), al.bf16))

        acc = al.amdgpu.mfma_32x32x8_bf16_f32(b_frag[0], a_frag[0], acc)
        acc = al.amdgpu.mfma_32x32x8_bf16_f32(b_frag[1], a_frag[1], acc)

        al.syncthreads()

    for r in al.range(16):
        row_offset = ((r >> 2) << 3) + lane_group * 4 + (r & 3)
        c_smem[lane_col, row_offset] = acc[r]

    al.syncthreads()

    store_vec = lane
    store_offset = al.convert(store_vec * 4 * F32_BYTES, al.i32)

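    for row in al.range(BLOCK_M):
        # Rows past the matrix edge get a zero-byte range, so every store to
        # them is out of range.
        c_row_range = al.convert(0, al.u32)

        if row < c_rows:
            c_row_range = c_cols * F32_BYTES

        # Uniform descriptor for this row's slice of C.
        C_row = al.subview(C_flat, ((block_m + row) * n + block_n,), (c_cols,), (1,))
        C_rsrc = al.amdgpu.make_rsrc(C_row, c_row_range)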

        if store_vec < (BLOCK_N >> 2):
            al.amdgpu.raw_buffer_store_x4(
                c_smem_vec[row, store_vec],
                C_rsrc,
                zero,
                store_offset,
                0,
            )

The input descriptors start at the first A and B rows owned by the block. Their ranges cover only the valid rows in the edge tile, so lanes assigned past m or n read zeros. For C, the writeback loop creates a uniform row descriptor; its range removes the edge-row and edge-column guards for raw_buffer_store_x4. The remaining lane predicate only selects the eight lanes that own vector stores for a row.

Launch with ceil-divided output tiles:

m = 117
n = 121
k = 128

A = torch.randn((m, k), dtype=torch.bfloat16, device="cuda")
B = torch.randn((n, k), dtype=torch.bfloat16, device="cuda")
C = torch.empty((m, n), dtype=torch.float32, device="cuda")

grid_x = (n + 31) // 32
grid_y = (m + 31) // 32

gemm_runtime_shape_guarded_loads[lambda: ((grid_x, grid_y, 1), (64, 1, 1))](
    A, B, C, m, n, k, 32, 32, 16
)

expected = A.to(torch.float32) @ B.to(torch.float32).T
torch.testing.assert_close(C.cpu(), expected.cpu(), rtol=1e-2, atol=1e-2)
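For the 117 x 121 launch above, the descriptor ranges can be worked out on the host. The sketch below mirrors the clamping arithmetic from the kernel in plain Python; `ceil_div` and `tile_extents` are hypothetical helper names introduced only for this illustration.

```python
BF16_BYTES = 2
F32_BYTES = 4


def ceil_div(a, b):
    """Integer division rounded up."""
    return (a + b - 1) // b


def tile_extents(m, n, k, block_m, block_n, block_id_x, block_id_y):
    """Valid rows and descriptor byte ranges for one output tile."""
    block_m_start = block_id_y * block_m
    block_n_start = block_id_x * block_n
    a_rows = min(m - block_m_start, block_m)   # also the valid rows of C
    b_rows = min(n - block_n_start, block_n)   # also the valid columns of C
    return {
        "a_rows": a_rows,
        "b_rows": b_rows,
        "a_range_bytes": a_rows * k * BF16_BYTES,
        "b_range_bytes": b_rows * k * BF16_BYTES,
        "c_row_range_bytes": b_rows * F32_BYTES,
    }


m, n, k = 117, 121, 128
grid = (ceil_div(n, 32), ceil_div(m, 32), 1)
print("grid:", grid)
print("interior tile:", tile_extents(m, n, k, 32, 32, 0, 0))
print("corner tile:  ", tile_extents(m, n, k, 32, 32, 3, 3))
```

In the corner tile only 21 rows of A and 25 rows of B are valid, so the A descriptor spans 5376 bytes and any lane whose row offset starts at or beyond that range loads zeros.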