
AMD matrix cores

Matrix multiplication is a fundamental aspect of linear algebra and a ubiquitous computation within High Performance Computing (HPC) applications. Since the introduction of AMD’s CDNA architecture, General Matrix Multiplication (GEMM) computations have been hardware-accelerated through Matrix Core Processing Units. Matrix Core accelerated GEMM kernels lie at the heart of BLAS libraries like rocBLAS, but they can also be programmed directly by developers. Applications that are throughput bound by GEMM computation can achieve additional speedups by utilizing Matrix Cores.

AMD’s Matrix Core technology supports a full range of mixed-precision operations, making it possible to work with large models and to improve the performance of memory-bound operations across AI and machine learning workloads. The various numerical formats have uses in different applications. Examples include the use of 8-bit integers (INT8) for ML inference, 32-bit floating point (FP32) data for ML training and HPC applications, 16-bit floating point (FP16) data for graphics workloads, and 16-bit brain float (BF16) data for ML training with fewer convergence issues.

To learn more about the theoretical speedups achievable by using matrix cores compared to SIMD Vector Units, please refer to the tables below. The tables list the performance of the Vector (i.e. Fused Multiply-Add or FMA) and Matrix core units of the previous generation (MI100) and current generation (MI250X) of CDNA Accelerators.

Matrix Core Performance for MI100 and MI250X:

Data format    MI100 Flops/Clock/CU    MI250X Flops/Clock/CU
FP64           N/A                     256
FP32           256                     256
FP16           1024                    1024
BF16           512                     1024
INT8           1024                    1024

Vector (FMA) Unit Performance for MI100 and MI250X:

Data format    MI100 Flops/Clock/CU    MI250X Flops/Clock/CU
FP64           64                      128
FP32           128                     128

Matrix Core speedup compared to Vector Unit performance for MI100 and MI250X. Note that MI250X also supports packed FP32 instructions, which double the FP32 vector throughput:

Data format    MI100 Matrix/Vector Speedup    MI250X Matrix/Vector Speedup
FP64           N/A                            2x
FP32           2x                             2x
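
To translate these per-CU figures into a peak device throughput, multiply by the number of CUs and the peak engine clock. For example, using the MI250X’s publicly listed 220 CUs and 1.7 GHz peak engine clock (figures taken from AMD’s product specifications rather than the tables above), the FP64 Matrix Core peak is approximately 256 Flops/Clock/CU \times 220 CUs \times 1.7\times10^{9} clocks/s \approx 95.7 TFLOPs, consistent with the advertised double-precision matrix peak of the accelerator.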

Using AMD matrix cores

The Matrix Fused Multiply Add (MFMA) instructions in AMD CDNA GPUs operate on a per-wavefront basis, rather than on a per-lane (thread) basis: entries of the input and output matrices are distributed over the lanes of the wavefront’s vector registers.

AMD Matrix Cores can be leveraged in several ways. At a high level, one can use libraries such as rocBLAS or rocWMMA to do matrix operations on the GPU. For instance, rocBLAS may choose to use MFMA instructions if this is advantageous for the computations at hand. For a closer-to-the-metal approach, one can choose to

  • write GPU kernels entirely in assembly (which can be somewhat challenging and impractical)

  • sprinkle HIP kernels with inline assembly (not recommended, since the compiler does not look at the semantics of the inlined instructions, and may not take care of data hazards such as the mandatory number of cycles before using the results of the MFMA instructions)

  • use compiler intrinsics: these represent the assembly instructions in such a way that the compiler knows about the semantics and requirements.

The coding examples in this post use some of the available compiler intrinsics for the MFMA instructions and show how to map entries of the input and output matrices to lanes of vector registers of the wavefront. All the examples use a single wavefront to compute a small matrix multiplication. The examples are not intended to show how to achieve high performance out of the MFMA operations.

MFMA compiler intrinsic syntax

Consider the following multiplication MFMA operation where all the operands A, B, C, and D are matrices:

D = A B + C

To perform the MFMA operation on AMD GPUs, LLVM has built-in intrinsic functions. Recall that these intrinsic functions are executed on a wavefront-wide basis, and pieces of the input and output matrices are loaded into registers of each lane in the wavefront. The syntax of an MFMA compiler intrinsic is shown below; a short worked example of this naming convention follows the parameter list.

d = __builtin_amdgcn_mfma_CDfmt_MxNxKABfmt(a, b, c, cbsz, abid, blgp)

where,

  • CDfmt is the data format of the C & D matrices

  • ABfmt is the data format of the A & B matrices

  • M, N and K are matrix dimensions:

    • mA[M][K] Source A matrix

    • mB[K][N] Source B matrix

    • mC[M][N] Accumulation input matrix C

    • mD[M][N] Accumulation result matrix D

  • a is the set of vector registers storing values from source matrix A

  • b is the set of vector registers storing values from source matrix B

  • c is the set of vector registers storing values from accumulation input matrix C

  • d is the set of vector registers storing values of accumulation result matrix D

  • cbsz, the Control Broadcast Size modifier, is used to change which input values are fed to the matrix cores and is supported only by instructions that have multiple input blocks for the A matrix. Setting cbsz informs the instruction to broadcast values of one chosen input block to 2^{cbsz} other neighboring blocks in A. The input block picked for broadcasting is determined by the abid parameter. The default value of 0 results in no broadcast of values. As an example, for a 16-block A matrix, setting cbsz=1 would result in blocks 0 and 1 receiving the same input values, blocks 2 and 3 receiving the same input values, blocks 4 and 5 receiving the same input values, etc.

  • abid, the A-matrix Broadcast Identifier, is supported by instructions that have multiple input blocks for the A matrix. It is used with cbsz and indicates which input block is selected for broadcast to other neighboring blocks in the A matrix. As an example, for a 16-block A matrix, setting cbsz=2 and abid=1 will result in the values from block 1 to be broadcast to blocks 0-3, values from block 5 to be broadcast to blocks 4-7, values from block 9 to be broadcast to blocks 8-11, etc.

  • blgp, the B-matrix Lane Group Pattern modifier, allows a constrained set of swizzling operations on B matrix data between lanes. For the instructions that support this modifier, the following values are supported:

    • blgp=0 normal matrix layout for B

    • blgp=1 the B matrix data from lanes 0-31 is also broadcast into lanes 32-63

    • blgp=2 the B matrix data from lanes 32-63 is broadcast into lanes 0-31

    • blgp=3 the B matrix data from all lanes is rotated down by 16 (e.g., lane 0’s data is put into lane 48, lane 16’s data is put into lane 0)

    • blgp=4 the B matrix data from lanes 0-15 is broadcast into lanes 16-31, 32-47, and 48-63

    • blgp=5 the B matrix data from lanes 16-31 is broadcast into lanes 0-15, 32-47, and 48-63

    • blgp=6 the B matrix data from lanes 32-47 is broadcast into lanes 0-15, 16-31, and 48-63

    • blgp=7 the B matrix data from lanes 48-63 is broadcast into lanes 0-15, 16-31, and 32-47
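
As a concrete illustration of the naming convention, consider __builtin_amdgcn_mfma_f32_32x32x8f16, one of the FP16 instructions listed in the table below: CDfmt = f32 (C and D are FP32), M = N = 32, K = 8, and ABfmt = f16 (A and B are FP16). The short snippet below (purely illustrative constants, not part of any API) works out how many matrix elements each lane of a 64-wide wavefront contributes or owns for this instruction.

// Decoding __builtin_amdgcn_mfma_f32_32x32x8f16 with the naming convention above:
//   CDfmt = f32 -> C and D hold FP32 values
//   MxNxK = 32x32x8
//   ABfmt = f16 -> A and B hold FP16 values
// With 1 block (see the CDNA2 table below) and a 64-lane wavefront:
constexpr int a_elems_per_lane = (32 * 8 * 1) / 64;   // 4 FP16 values of A per lane
constexpr int b_elems_per_lane = (8 * 32 * 1) / 64;   // 4 FP16 values of B per lane
constexpr int d_elems_per_lane = (32 * 32 * 1) / 64;  // 16 FP32 values of C and D per lane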

The matrix dimensions and number of blocks supported on CDNA2 GPUs are listed in the table below.

A/B Data Format    C/D Data Format    M     N     K     Blocks    Cycles    Flops/cycle/CU
FP32               FP32               32    32    2     1         64        256
FP32               FP32               32    32    1     2         64        256
FP32               FP32               16    16    4     1         32        256
FP32               FP32               16    16    1     4         32        256
FP32               FP32               4     4     1     16        8         256
FP16               FP32               32    32    8     1         64        1024
FP16               FP32               32    32    4     2         64        1024
FP16               FP32               16    16    16    1         32        1024
FP16               FP32               16    16    4     4         32        1024
FP16               FP32               4     4     4     16        8         1024
INT8               INT32              32    32    8     1         64        1024
INT8               INT32              32    32    4     2         64        1024
INT8               INT32              16    16    16    1         32        1024
INT8               INT32              16    16    4     4         32        1024
INT8               INT32              4     4     4     16        8         1024
BF16               FP32               32    32    8     1         64        1024
BF16               FP32               32    32    4     2         64        1024
BF16               FP32               16    16    16    1         32        1024
BF16               FP32               16    16    4     4         32        1024
BF16               FP32               4     4     4     16        8         1024
BF16               FP32               32    32    4     1         64        512
BF16               FP32               32    32    2     2         64        512
BF16               FP32               16    16    8     1         32        512
BF16               FP32               16    16    2     4         32        512
BF16               FP32               4     4     2     16        8         512
FP64               FP64               16    16    4     1         32        256
FP64               FP64               4     4     4     4         16        128

A complete list of all instructions supported by the CDNA2 architecture can be found in the AMD Instinct MI200 Instruction Set Architecture Reference Guide. AMD’s Matrix Instruction Calculator tool can generate further information, such as the computational throughput and register usage, for MFMA instructions on AMD Radeon™ and AMD Instinct™ accelerators.

Example 1 – V_MFMA_F32_16x16x4F32

Consider the matrix multiplication operation D = A B where M=N=16 and K=4 and the elements are of type FP32. For simplicity, assume that the input C matrix contains zeroes. We will demonstrate the use of the intrinsic function __builtin_amdgcn_mfma_f32_16x16x4f32, which computes the sum of four outer products in one invocation. This function operates on a single block of matrices.

The input matrices, A and B, have dimensions 16\times4 and 4\times16 respectively, and the matrices C and D have 16\times16 elements. It is convenient to map a 16\times4 thread block to the elements of both input matrices. Here, the thread block has one wavefront with 16 threads in the x dimension and 4 threads in the y dimension. We store the matrices in row-major format, so the element in row i and column j of a matrix with ld columns is at linear index j + i*ld. Using this representation, a thread at position x, y in the block loads entry A[x][y] and B[y][x]. The output matrix has 16\times16 elements, so each thread has 4 elements to store, as illustrated in the figure and code snippet below.

The following two figures show 1) the shape and size of the A and B inputs; and 2) how the elements of A and B map to lanes in the register owned by the wavefront.

[Figures: the shapes of the A and B inputs, and the mapping of their elements to lanes of the wavefront’s registers]

The following two figures show 1) the shape and size of the output matrix D; and 2) how the elements of D map to lanes in the registers owned by the wavefront:

[Figures: the shape of the output matrix D, and the mapping of its elements to lanes of the wavefront’s registers]

An example kernel performing this MFMA operation is given below.


#define M 16
#define N 16
#define K 4
 
 
__global__ void sgemm_16x16x4(const float *A, const float *B, float *D)
{
  using float4 = __attribute__( (__vector_size__(K * sizeof(float)) )) float;
  float4 dmn = {0};
 
  // Thread (x, y) loads A[x][y] and B[y][x] (both stored row-major)
  int mk = threadIdx.y + K * threadIdx.x;
  int kn = threadIdx.x + N * threadIdx.y;
 
  float amk = A[mk];
  float bkn = B[kn];
  dmn = __builtin_amdgcn_mfma_f32_16x16x4f32(amk, bkn, dmn, 0, 0, 0);
 
  // Lane (x, y) owns the four output entries D[4*y + i][x], i = 0..3
  for (int i = 0; i < 4; ++i) {
    const int idx = threadIdx.x + i * N + threadIdx.y * 4 * N;
    D[idx] = dmn[i];
  }
}

This kernel is launched as follows.


dim3 grid (1, 1, 1);
dim3 block(16, 4, 1);
 
sgemm_16x16x4 <<< grid, block >>> (d_A, d_B, d_D);

As previously noted, the input C matrix is assumed to contain zeroes.
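
To run and check this example end to end, a minimal host-side harness along the following lines can be used. This is a sketch: error checking is omitted, the initialization values are arbitrary, the kernel above is assumed to be in the same translation unit, and the code must be compiled for a CDNA GPU (e.g., gfx90a) that supports this MFMA instruction.

#include <hip/hip_runtime.h>
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
  std::vector<float> h_A(M * K), h_B(K * N), h_D(M * N), h_ref(M * N, 0.0f);
  for (int i = 0; i < M * K; ++i) h_A[i] = static_cast<float>(i % 7);
  for (int i = 0; i < K * N; ++i) h_B[i] = static_cast<float>(i % 5);

  // CPU reference: D = A * B with C = 0, all matrices stored row-major
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        h_ref[n + m * N] += h_A[k + m * K] * h_B[n + k * N];

  float *d_A, *d_B, *d_D;
  hipMalloc((void**)&d_A, M * K * sizeof(float));
  hipMalloc((void**)&d_B, K * N * sizeof(float));
  hipMalloc((void**)&d_D, M * N * sizeof(float));
  hipMemcpy(d_A, h_A.data(), M * K * sizeof(float), hipMemcpyHostToDevice);
  hipMemcpy(d_B, h_B.data(), K * N * sizeof(float), hipMemcpyHostToDevice);

  dim3 grid (1, 1, 1);
  dim3 block(16, 4, 1);
  sgemm_16x16x4<<<grid, block>>>(d_A, d_B, d_D);
  hipMemcpy(h_D.data(), d_D, M * N * sizeof(float), hipMemcpyDeviceToHost);

  // Report the largest deviation from the CPU reference
  float max_err = 0.0f;
  for (int i = 0; i < M * N; ++i)
    max_err = std::max(max_err, std::abs(h_D[i] - h_ref[i]));
  std::printf("max abs error: %g\n", max_err);

  hipFree(d_A); hipFree(d_B); hipFree(d_D);
  return 0;
}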

Example 2 – V_MFMA_F32_16x16x1F32

Consider the case of multiplying matrices of dimensions M=N=16 and K=1 using the compiler intrinsic __builtin_amdgcn_mfma_f32_16x16x1f32. In this case, the input values could be held just by 16 lanes of the wavefront. In fact, this instruction could simultaneously multiply 4 such matrices thereby having each lane hold values from one of those 4 matrices.

We can re-use the figure from the previous example to illustrate the data layout for this operation too. The input A is not a 16 \times 4 matrix in this case but four 16 \times 1 matrices; however, the way they are laid out, and the elements owned by each lane in the wavefront, are the same. The “columns” of A are now distinct 16 \times 1 matrices. The input B is similar.

[Figures: the data layout of the four 16 \times 1 A matrices and four 1 \times 16 B matrices across lanes of the wavefront]

The output of a given matrix multiplication has exactly the same data layout as in the previous example. The difference is that now there are four separate outputs, one for each multiplication.

The kernel below shows an example of this multiplication for a packed batch of 4 matrices of size M=N=16 and K=1.


#define M 16
#define N 16
#define K 1
 
 
__global__ void sgemm_16x16x1(const float *A, const float *B, float *D)
{
  using float16 = __attribute__( (__vector_size__(16 * sizeof(float)) )) float;
  float16 dmnl = {0};
 
  // Thread (x, y) loads element x of the y-th 16x1 A matrix and
  // element x of the y-th 1x16 B matrix (the 4 matrices are packed contiguously)
  int mkl = K * threadIdx.x + M * K * threadIdx.y;
  int knl = threadIdx.x + N * K * threadIdx.y;
 
  float amkl = A[mkl];
  float bknl = B[knl];
 
  dmnl = __builtin_amdgcn_mfma_f32_16x16x1f32(amkl, bknl, dmnl, 0, 0, 0);
 
  // Each lane owns 4 entries of each of the 4 output matrices; the entries of
  // output matrix l are taken to sit in registers 4*l .. 4*l+3 of dmnl
  for (int l = 0; l < 4; ++l) {
    for (int i = 0; i < 4; ++i) {
      const int idx = threadIdx.x + i * N + threadIdx.y * 4 * N + l * M * N;
      D[idx] = dmnl[4 * l + i];
    }
    }
  }
}

This kernel is launched using:


dim3 grid (1, 1, 1);
dim3 block(16, 4, 1);
 
sgemm_16x16x1 <<< grid, block >>> (d_A, d_B, d_D);
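
Because the four matrices are packed one after another in memory, a straightforward CPU reference for this batched product can be written as follows. This sketch assumes the same packed layout the kernel uses: matrix l of A starts at offset l * M * K, and similarly for B and D.

// CPU reference for the packed batch: D_l = A_l * B_l for l = 0..3,
// with each 16x1 A, 1x16 B, and 16x16 D matrix stored contiguously, one after another
void sgemm_16x16x1_ref(const float *A, const float *B, float *D)
{
  for (int l = 0; l < 4; ++l) {
    const float *Al = A + l * M * K;   // l-th 16x1 matrix
    const float *Bl = B + l * K * N;   // l-th 1x16 matrix
    float       *Dl = D + l * M * N;   // l-th 16x16 result
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k)
          acc += Al[k + m * K] * Bl[n + k * N];
        Dl[n + m * N] = acc;
      }
  }
}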

Example 3 – V_MFMA_F64_4x4x4F64

Consider the V_MFMA_F64_4x4x4F64 instruction, which computes the MFMA of four independent blocks of matrices of size 4\times4. The operation performed is Z_N = W_N X_N + Y_N where W_N, X_N, Y_N and Z_N are all matrices of size 4\times4 elements and N = 0,1,2,3.

The two figures below show 1) the size and shapes of the four components of the input arguments A and B and 2) how those components map to lanes in the register owned by the wavefront. The instruction takes A, B, and C as arguments and returns D, so each argument and the output hold four 4\times4 matrices.

[Figures: the sizes and shapes of the four blocks of the A and B inputs, and the mapping of those blocks to lanes of the wavefront’s registers]

The layout for output D and input C is the same as the layout for input B.
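
The post’s examples do not include a kernel for this instruction, so the sketch below shows one possible way to drive it from HIP. The builtin __builtin_amdgcn_mfma_f64_4x4x4f64 takes one FP64 value of A, B, and C per lane and returns one FP64 value of D per lane. The lane-to-element mapping used here is an assumption that mirrors the pattern of the earlier examples (block b lives in lanes 16*b to 16*b+15, the row index of A varies fastest, and the column index of B and D varies fastest); verify the exact mapping with the AMD Matrix Instruction Calculator before relying on it.

#define M 4
#define N 4
#define K 4

// Multiplies four independent 4x4 FP64 matrix pairs with a single wavefront (64 lanes).
// Block b of A is assumed to start at offset b*M*K (row-major), and similarly for B and D.
__global__ void dgemm_4x4x4(const double *A, const double *B, double *D)
{
  const int lane = threadIdx.x;   // 0..63, one wavefront
  const int b    = lane / 16;     // which of the 4 blocks this lane works on
  const int t    = lane % 16;     // position within the block

  // Assumed mapping: row of A varies fastest across lanes,
  // column of B and D varies fastest (matching the figures above)
  const int a_row = t % M, a_col = t / M;
  const int b_col = t % N, b_row = t / N;
  const int d_col = t % N, d_row = t / N;

  double amk = A[b * M * K + a_row * K + a_col];
  double bkn = B[b * K * N + b_row * N + b_col];
  double dmn = 0.0;   // C is taken to be zero, as in the earlier examples

  dmn = __builtin_amdgcn_mfma_f64_4x4x4f64(amk, bkn, dmn, 0, 0, 0);

  D[b * M * N + d_row * N + d_col] = dmn;
}

A single wavefront suffices to launch this kernel:

dim3 grid (1, 1, 1);
dim3 block(64, 1, 1);

dgemm_4x4x4 <<< grid, block >>> (d_A, d_B, d_D);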

A note on rocWMMA

We have presented only three examples of leveraging AMD Matrix cores using compiler intrinsics. More examples can be found here. Note that the builtin functions may change in the future, so it may be better to use AMD’s rocWMMA C++ library instead for accelerating mixed-precision MFMA operations. The rocWMMA API facilitates breaking down matrix multiply-accumulate problems into fragments and using them in block-wise operations that are distributed in parallel across wavefronts. The API is a header-only library of GPU device code, so matrix core acceleration can be compiled directly into your kernel device code and benefit from compiler optimization during kernel assembly generation. More details are available in the rocWMMA repo.
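
To give a flavor of the rocWMMA API, the fragment-based kernel below sketches a single 16x16x16 mixed-precision multiply-accumulate performed by one wavefront. It is modeled on the simple GEMM sample in the rocWMMA repository; the exact type names, supported block sizes, and layout tags should be checked against the version of rocWMMA you are using.

#include <rocwmma/rocwmma.hpp>

using rocwmma::float16_t;
using rocwmma::float32_t;

constexpr int WMMA_M = 16;
constexpr int WMMA_N = 16;
constexpr int WMMA_K = 16;

// One wavefront computes D = A * B + C for a single 16x16 tile (FP16 inputs, FP32 accumulation)
__global__ void wmma_gemm_tile(const float16_t *A, const float16_t *B,
                               const float32_t *C, float32_t *D)
{
  // Per-wavefront fragments holding the distributed tiles
  rocwmma::fragment<rocwmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, float16_t, rocwmma::row_major> fragA;
  rocwmma::fragment<rocwmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, float16_t, rocwmma::row_major> fragB;
  rocwmma::fragment<rocwmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float32_t> fragC;
  rocwmma::fragment<rocwmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float32_t> fragD;

  // Load the input tiles (leading dimensions equal the tile sizes in this toy example)
  rocwmma::load_matrix_sync(fragA, A, WMMA_K);
  rocwmma::load_matrix_sync(fragB, B, WMMA_N);
  rocwmma::load_matrix_sync(fragC, C, WMMA_N, rocwmma::mem_row_major);

  // D = A * B + C
  rocwmma::mma_sync(fragD, fragA, fragB, fragC);

  // Write the result tile back to memory
  rocwmma::store_matrix_sync(D, fragD, WMMA_N, rocwmma::mem_row_major);
}

This kernel would be launched with a single wavefront (a one-dimensional block of 64 work-items) for the one-tile case.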

A note on the AMD Matrix Instruction Calculator tool

For those who are curious about how various MFMA instructions perform on AMD Radeon™ and AMD Instinct™ accelerators and would like to understand the mapping between matrix elements and hardware registers, we direct you to the AMD Matrix Instruction Calculator tool. This powerful tool can be used to describe WMMA instructions as well as MFMA ISA-level instructions for a given architecture. We welcome issues and feedback from the community.


If you have any questions or comments, please reach out to us on GitHub Discussions


Gina Sitaraman

Gina Sitaraman is a Senior Member of Technical Staff (SMTS) Software System Design Engineer in the Data Center GPU Software Solutions group. She obtained her PhD in Computer Science from the University of Texas at Dallas. She has over a decade of experience in the seismic data processing field developing and optimizing pre-processing, migration and post processing applications using hybrid MPI + OpenMP on CPU clusters and using CUDA or OpenCL on GPUs. She spends her time at AMD solving optimization challenges in scientific applications running on large-scale HPC clusters.

Noel Chalmers

Noel Chalmers is a Senior Member of Technical Staff (SMTS) in the Data Center GPU Software Solutions Group at AMD. Noel is the lead developer of the rocHPL benchmark, AMD's optimized implementation of the well-known LINPACK benchmark which is responsible for achieving over 1 Exaflop of performance on the Frontier supercomputer at ORNL. Noel obtained a PhD in Applied Mathematics from the University of Waterloo, where he studied convergence and stability properties of the discontinuous Galerkin finite element method on hyperbolic systems. Noel's research interests include GPU acceleration of high-order continuous and discontinuous finite element methods, and large-scale geometric and algebraic multigrid methods.

Nicholas Malaya

Nicholas Malaya is an AMD Fellow with an emphasis in software development, algorithms, and high performance computing. He is AMD's technical lead for exascale application performance, and is focused on ensuring workloads run efficiently on the world's largest supercomputers. Nick's research interests include HPC, computational fluid dynamics, Bayesian inference, and ML/AI. He received his PhD from the University of Texas. Before that, he double majored in Physics and Mathematics at Georgetown University where he received the Treado medal. In his copious spare time he enjoys motorcycles, long distance running, wine, and spending time with his wife and children.

Damon McDougall

Damon McDougall is a Principal Engineer in the Data Center GPU Software Solutions Group at AMD. He obtained a PhD in Mathematics from the University of Warwick and spent six years as a researcher in the Oden Institute for Computational Engineering and Sciences and at the Texas Advanced Computing Center. At AMD, Damon spends his time as part of the Frontier Center of Excellence team optimizing scientific codes for AMD's CPUs and GPUs at exascale. His professional interests include high-performance computing, large-scale systems, uncertainty quantification, statistical computing, and scientific software development.

Ossian O'Reilly

Ossian O'Reilly is a Member of Technical Staff (MTS) Software System Design Engineer in the Data Center GPU Software Solutions group at AMD. He works on porting and optimizing scientific computing and engineering applications for AMD GPUs. He holds a PhD in Geophysics from Stanford University and in Computational Mathematics from Linköping University, Sweden. His PhD research focused on high order numerical methods for seismic wave propagation containing frictional interfaces and fluid-filled cracks with applications to earthquake and volcano seismology, and the oil and gas industry. As a postdoc, he worked on numerical method development and analysis for seismic wave propagation with topography and implementing GPU stencil kernels targeting the OLCF Summit supercomputer. Some of Ossian's technical interests include high order numerical methods for partial differential equations, stencil-based and matrix-free methods, as well as GPU kernel development and optimization.

Rene Van Oostrum

René van Oostrum is a Principal Member of Technical Staff (PMTS) Software Development Engineer in AMD Research. He obtained a PhD in Computer Science from Utrecht University and has a background in design and analysis of algorithms. In the past decade, he focused on implementation and performance tuning of GPU codes. He currently works on performance analysis tools for HPC workloads on AMD GPUs.

Joseph Greathouse

Joseph Greathouse is a Fellow in AMD's AI GPU Software group who focuses on the performance and architecture of AMD's Instinct accelerators and ROCm software stack. After earning his Ph.D. in computer science and engineering from the University of Michigan, Ann Arbor, he joined AMD Research and has been at the company ever since. Throughout all of these positions, his work has covered various aspects of hardware & software interactions. This includes architecting hardware features to accelerate software development, creating hardware-optimized libraries, and co-designing hardware and software to optimize for power.
