Interactive tiled matrix multiplication
Matrix multiplication is the first GPU kernel most people write, and also the first one that teaches you the difference between “technically correct” and “worth running.” The naive version is trivially parallel but heavily memory-bound: every thread reads its row of A and column of B straight from global DRAM, and adjacent threads re-fetch almost all the same values independently. The GPU spends most of its time waiting on loads rather than doing arithmetic.
Shared memory is a small on-chip scratchpad (tens of KB per SM) that all threads in a block can read and write at close-to-register speed. It doesn’t make DRAM faster, but it lets a block cooperatively stage a slice of the inputs into the scratchpad once and then reuse it many times, turning what were N independent slow loads into one cooperative slow load plus many fast ones. Shared-memory tiling is the natural way to apply that idea to matrix multiplication, and it offers good insight into one of the central aspects of GPU programming: managing the memory hierarchy.
This post walks from the naive kernel to a tiled one, deriving why tiling helps and what the indexing looks like at every step. All diagrams are designed to be interactive.
Throughout, we assume square matrices of size N, with a tile size TILE_DIM that divides N cleanly. We also assume the block size equals the tile size: a TILE_DIM × TILE_DIM block of threads, one thread per output element of a tile. Those assumptions make the indexing pleasant to read; dropping them means adding boundary checks and per-thread loops, and obscures the core idea.
This post was created with the aim of visualising the memory accesses that occur in tiling, and is by no means meant as a complete introductory guide to GPU programming. As such it assumes some familiarity with GPU programming, specifically in CUDA. I refer anyone looking for a comprehensive guide to this amazing post by Aleksa Gordic.
The naive kernel
For C = AB with N × N matrices, launch N × N threads, one per output element:
__global__ void matmul_naive(const float* A, const float* B, float* C, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= N || col >= N) return;

    float acc = 0.0f;
    for (int k = 0; k < N; ++k) {
        acc += A[row * N + k] * B[k * N + col];
    }
    C[row * N + col] = acc;
}
Each thread independently streams one row of A and one column of B out of global memory to compute with. Pick an output cell below and watch which cells in memory it touches:
Why this is bad: arithmetic intensity
Let’s analyse what is happening in terms of memory operations versus computation. Per thread:
- Loads: N floats from A + N floats from B = 2N floats = 8N bytes.
- Compute: N fused multiply-adds = 2N FLOPs.
So the kernel does 2N / 8N = 0.25 FLOPs per byte loaded. That’s the arithmetic intensity. An A100 is rated at ~19.5 TFLOPs/s of fp32 against ~1.5 TB/s of HBM bandwidth, so its roofline breakeven is ~13 FLOPs/byte. At 0.25, we’re running the kernel at roughly 0.25 / 13 ≈ 2% of peak compute. The GPU is doing almost nothing but waiting for loads.
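Those numbers are easy to sanity-check in a few lines (the peak figures are the A100 spec values quoted above, used purely for illustration):

```python
# Arithmetic intensity of the naive kernel vs. the A100 roofline breakeven.

BYTES_PER_FLOAT = 4

def naive_intensity(N):
    flops = 2 * N                            # N fused multiply-adds per output element
    bytes_loaded = 2 * N * BYTES_PER_FLOAT   # N floats of A + N floats of B
    return flops / bytes_loaded              # FLOPs per byte

peak_flops = 19.5e12   # A100 fp32 throughput, FLOPs/s
peak_bw = 1.5e12       # A100 HBM bandwidth, bytes/s
breakeven = peak_flops / peak_bw  # ~13 FLOPs/byte

print(naive_intensity(4096))              # 0.25, independent of N
print(naive_intensity(4096) / breakeven)  # ~0.019: about 2% of peak compute
```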
So what problem are we solving exactly? Well, every thread loads its row of A and its column of B from scratch. Thread (row, col) and thread (row, col+1) both read the entire row of A, independently, out of slow DRAM.
Grouping threads: output tiling
The constraint we’re working with is that shared memory is small and only visible within a single thread block. So for sharing to buy us anything, we have to partition the work so that threads which need overlapping inputs end up in the same block.
Matrix multiplication makes this easy: threads computing adjacent outputs of C use overlapping inputs, so grouping them by output region is the natural partition. We pick a size TILE_DIM for how coarse that grouping is, meaning each block is responsible for a TILE_DIM × TILE_DIM region of C. That region depends on TILE_DIM rows of A and TILE_DIM columns of B in total, which is still too much to hold in shared memory all at once, but we’ll see shortly that we can walk it as a sequence of TILE_DIM × TILE_DIM chunks, which do fit, and which every thread in the block can reuse.
Step one is to rearrange the launch so that threads computing a common output tile live in the same block:
No loads have been saved yet; we’ve only relabelled which thread belongs to which block. But now threads in a block share an L1-speed scratchpad (shared memory), and we can use it.
Walking the K axis in tiles
A block’s output tile of C needs a full strip of A (all TILE_DIM rows, N columns wide) and a full strip of B (N rows, TILE_DIM columns). Instead of loading those strips whole, chop them into TILE_DIM × TILE_DIM sub-tiles along the K axis and process one pair at a time.
On each iteration the block:
- Cooperatively loads one TILE_DIM × TILE_DIM tile of A into shared memory.
- Cooperatively loads one TILE_DIM × TILE_DIM tile of B into shared memory.
- Synchronises so everyone sees the loaded values.
- Each thread adds TILE_DIM multiply-adds to its own running accumulator, reading out of shared memory.
- Synchronises again before overwriting the tiles on the next iteration.
Cooperative loading means each of the TILE_DIM × TILE_DIM threads in the block copies exactly one element of A and one element of B from global memory into shared memory. The thread at threadIdx=(y, x) is responsible for the slot (y, x) of both shared tiles. On tile iteration t, that translates to:
As[threadIdx.y][threadIdx.x] = A[row * N + (t * TILE_DIM + threadIdx.x)];
Bs[threadIdx.y][threadIdx.x] = B[(t * TILE_DIM + threadIdx.y) * N + col];
For A, the row is row = blockIdx.y * TILE_DIM + threadIdx.y, since this thread is always working on the same row of C. The column t * TILE_DIM + threadIdx.x slides right as t increases: tile 0 covers columns 0..TILE_DIM-1, tile 1 covers TILE_DIM..2*TILE_DIM-1, and so on. B is the dual: column col is fixed, and the row t * TILE_DIM + threadIdx.y slides down.
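One way to convince yourself the sliding indices cover everything is to enumerate them for a toy problem. In this sketch (N, TILE_DIM, and the loop variables are stand-ins for the kernel’s values), the union of columns of A staged by a block across all iterations is exactly 0..N-1, and likewise for the rows of B:

```python
N, TILE_DIM = 8, 4

# Columns of A that a block stages into shared memory, over all tile
# iterations t and all thread x-indices tx (one element per thread):
a_cols = sorted({t * TILE_DIM + tx
                 for t in range(N // TILE_DIM)
                 for tx in range(TILE_DIM)})
assert a_cols == list(range(N))  # the block sweeps the full K axis of A

# Rows of B staged, by the same argument on threadIdx.y:
b_rows = sorted({t * TILE_DIM + ty
                 for t in range(N // TILE_DIM)
                 for ty in range(TILE_DIM)})
assert b_rows == list(range(N))  # and the full K axis of B
```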
A quick note on the [threadIdx.y][threadIdx.x] convention: threadIdx.x is the fastest-varying index across a warp (32 consecutive threads differ only in their .x), so putting it in the innermost slot means a warp reads 32 contiguous addresses in a single coalesced memory transaction rather than 32 scattered ones.
Click any element of C below to pick a thread, and step through the tile iterations to see exactly which indices get read and written:
Notice the reuse: a single TILE_DIM × TILE_DIM tile of A loaded into shared memory is read by all TILE_DIM × TILE_DIM threads in the block, i.e. TILE_DIM times more often than in the naive kernel. Same for B. That’s where the speedup comes from.
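To put a number on that reuse, count the global-memory floats a block needs to produce one output tile under each scheme (a toy count that ignores caches; the sizes are arbitrary):

```python
N, TILE_DIM = 1024, 16
threads = TILE_DIM * TILE_DIM  # threads per block, one per output element

# Naive: every thread independently streams a full row of A and column of B.
naive_loads = threads * 2 * N

# Tiled: each of the N / TILE_DIM iterations cooperatively loads one
# TILE_DIM x TILE_DIM tile of A and one of B, one element per thread.
tiled_loads = (N // TILE_DIM) * 2 * threads

assert naive_loads // tiled_loads == TILE_DIM  # a TILE_DIM-fold reduction
```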
The new arithmetic intensity
Per block, per k-tile iteration:
- Loads: 2 · TILE_DIM · TILE_DIM floats = 8 · TILE_DIM² bytes.
- Compute: TILE_DIM · TILE_DIM · TILE_DIM · 2 = 2 · TILE_DIM³ FLOPs (each of the TILE_DIM² threads does TILE_DIM FMAs).
Intensity: 2 · TILE_DIM³ / (8 · TILE_DIM²) = TILE_DIM / 4 FLOPs/byte.
For TILE_DIM = 16 that’s 4 FLOPs/byte: a 16× improvement over the naive kernel, enough to put the kernel in a much more compute-bound regime. For TILE_DIM = 32, 8 FLOPs/byte, approaching the A100 roofline.
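The same bookkeeping as a function, matching the formula above (`tiled_intensity` is a name introduced here for illustration):

```python
def tiled_intensity(tile_dim):
    # Per block, per k-tile iteration:
    bytes_loaded = 2 * tile_dim * tile_dim * 4  # one tile of A + one of B, 4 bytes/float
    flops = 2 * tile_dim ** 3                   # TILE_DIM^2 threads x TILE_DIM FMAs each
    return flops / bytes_loaded                 # = tile_dim / 4

assert tiled_intensity(16) == 4.0   # 16x the naive kernel's 0.25
assert tiled_intensity(32) == 8.0
```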
Putting it together
Here is the full kernel:
__global__ void matmul_tiled(const float* A, const float* B, float* C, int N) {
    __shared__ float As[TILE_DIM][TILE_DIM];
    __shared__ float Bs[TILE_DIM][TILE_DIM];

    int row = blockIdx.y * TILE_DIM + threadIdx.y;
    int col = blockIdx.x * TILE_DIM + threadIdx.x;

    float acc = 0.0f;
    for (int t = 0; t < N / TILE_DIM; ++t) {
        As[threadIdx.y][threadIdx.x] = A[row * N + (t * TILE_DIM + threadIdx.x)];
        Bs[threadIdx.y][threadIdx.x] = B[(t * TILE_DIM + threadIdx.y) * N + col];
        __syncthreads();

        for (int k = 0; k < TILE_DIM; ++k) {
            acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        }
        __syncthreads();
    }
    C[row * N + col] = acc;
}
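As a correctness check, the kernel’s loop structure can be mirrored in NumPy (a reference sketch, not the CUDA code itself; `matmul_tiled_ref` is a name introduced here, and the slice bounds are the kernel’s indexing expressions translated directly). The result matches a plain matrix product:

```python
import numpy as np

def matmul_tiled_ref(A, B, tile):
    """NumPy mirror of the tiled kernel: block loops, k-tile loop, tile FMA."""
    N = A.shape[0]
    C = np.zeros((N, N), dtype=A.dtype)
    for by in range(N // tile):            # blockIdx.y
        for bx in range(N // tile):        # blockIdx.x
            acc = np.zeros((tile, tile), dtype=A.dtype)
            for t in range(N // tile):     # the k-tile loop
                As = A[by*tile:(by+1)*tile, t*tile:(t+1)*tile]
                Bs = B[t*tile:(t+1)*tile, bx*tile:(bx+1)*tile]
                acc += As @ Bs             # the inner TILE_DIM FMA loop
            C[by*tile:(by+1)*tile, bx*tile:(bx+1)*tile] = acc
    return C

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 8)).astype(np.float32)
B = rng.standard_normal((8, 8)).astype(np.float32)
assert np.allclose(matmul_tiled_ref(A, B, 4), A @ B, atol=1e-4)
```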
To make the accumulation concrete, here’s a fixed 6 × 6 scenario with actual integer values for A and B. Pick any output block in C and step through the three k-tile iterations: the highlighted 2 × 2 tiles of A and B are the ones loaded into shared memory on that iteration, and the partial-sum slots below fill in as the running total builds up to each cell’s final value.
Where to go from here
TILE_DIM / 4 FLOPs/byte is a huge win over naive, but real matmul kernels go further:
- Register tiling. Each thread computes not one but TM × TN output elements, accumulating in registers. Reuse shifts from shared memory to registers, the fastest storage on the chip.
- Warp-level tiling. Groups of 32 threads cooperate on a larger sub-tile, using the warp shuffle network for exchange.
- Double buffering. Load the next k-tile while computing on the current one, overlapping DRAM latency with compute.
- Tensor cores. On Volta and later, dedicated matmul units operate on 16×16 or 16×8×16 tiles of fp16/bf16 with dramatically higher throughput. cuBLAS and CUTLASS wrap this with much more subtlety than we’ve covered here.