
Commit c299228

Add header file to define cooperative matrices in SPIR-V (#6720)
This PR introduces the HLSL standard header files. They will become a way to add new extensions to HLSL without having to modify the compiler. This first example is fairly complex and demonstrates how to do a few different things:

1. How to create a simple class interface that allows the compiler to naturally enforce the validation rules. In this case, we might be stricter than the SPIR-V validation, but I believe this is still usable.
2. How to create a builtin that can be expanded by the SPIR-V backend. The OpCooperativeMatrixLengthKHR instruction does not have an interface that is natural in a language like HLSL. However, we can define a function that is, and have the backend make small adjustments. These cases should be avoided as much as possible.
1 parent 88fcc1b commit c299228

19 files changed: +1427 -4 lines changed
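
A minimal usage sketch of the interface described above (hypothetical shader, not part of the commit: the include path, the buffer declarations, the 16x16 dimensions, and the enumerant spellings vk::ScopeSubgroup and vk::CooperativeMatrixLayoutRowMajorKHR are assumptions; 16-bit types require compiling with -enable-16bit-types and a SPIR-V 1.6 target):

// Hypothetical compute shader sketch; the names called out above are assumptions.
#include "vk/khr/cooperative_matrix.h"

StructuredBuffer<float16_t> inputA;   // rows x K elements, row-major
StructuredBuffer<float16_t> inputB;   // K x columns elements, row-major
RWStructuredBuffer<float16_t> output; // rows x columns elements, row-major

using MatA = vk::khr::CooperativeMatrixA<float16_t, vk::ScopeSubgroup, 16, 16>;
using MatB = vk::khr::CooperativeMatrixB<float16_t, vk::ScopeSubgroup, 16, 16>;
using MatC =
    vk::khr::CooperativeMatrixAccumulator<float16_t, vk::ScopeSubgroup, 16, 16>;

[numthreads(32, 1, 1)]
void main() {
  // Load the operands starting at element 0 with a row stride of 16 elements.
  MatA a = MatA::Load<vk::CooperativeMatrixLayoutRowMajorKHR>(inputA, 0, 16);
  MatB b = MatB::Load<vk::CooperativeMatrixLayoutRowMajorKHR>(inputB, 0, 16);

  // Accumulate into a zero-initialized matrix and write the result back.
  MatC acc = MatC::Splat(float16_t(0));
  acc = vk::khr::cooperativeMatrixMultiplyAdd(a, b, acc);
  acc.Store<vk::CooperativeMatrixLayoutRowMajorKHR>(output, 0, 16);
}
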
Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@
// Copyright (c) 2024 Google LLC
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_
#define _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_

#if __SPIRV_MAJOR_VERSION__ == 1 && __SPIRV_MINOR_VERSION__ < 6
#error "CooperativeMatrix requires a minimum of SPIR-V 1.6"
#endif

#include "vk/spirv.h"

namespace vk {
namespace khr {

// The base cooperative matrix class. The template arguments correspond to the
// operands in the OpTypeCooperativeMatrixKHR instruction.
template <typename ComponentType, Scope scope, uint rows, uint columns,
          CooperativeMatrixUse use>
class CooperativeMatrix {
  template <class NewComponentType>
  CooperativeMatrix<NewComponentType, scope, rows, columns, use> cast();

  // Apply OpSNegate or OpFNegate, depending on ComponentType, in an
  // element-by-element manner.
  CooperativeMatrix negate();

  // Apply OpIAdd or OpFAdd, depending on ComponentType, in an
  // element-by-element manner.
  CooperativeMatrix operator+(CooperativeMatrix other);

  // Apply OpISub or OpFSub, depending on ComponentType, in an
  // element-by-element manner.
  CooperativeMatrix operator-(CooperativeMatrix other);

  // Apply OpIMul or OpFMul, depending on ComponentType, in an
  // element-by-element manner.
  CooperativeMatrix operator*(CooperativeMatrix other);

  // Apply OpSDiv, OpUDiv, or OpFDiv, depending on ComponentType, in an
  // element-by-element manner.
  CooperativeMatrix operator/(CooperativeMatrix other);

  // Apply OpMatrixTimesScalar, scaling every element by the given scalar.
  CooperativeMatrix operator*(ComponentType scalar);

  // Stores the cooperative matrix using OpCooperativeMatrixStoreKHR to
  // data using the given memory layout, stride, and memory access operands.
  // `NonPrivatePointer` and `MakePointerAvailable` with the Workgroup scope
  // will be added to the memory access operands to make the memory coherent.
  //
  // This function uses a SPIR-V pointer because HLSL does not allow groupshared
  // memory objects to be passed by reference. The pointer is a workaround for
  // that restriction.
  //
  // The layout and stride will be passed to the SPIR-V instruction as is. The
  // precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  void Store(WorkgroupSpirvPointer<Type> data, uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access
  // operands.
  template <CooperativeMatrixLayout layout, class Type>
  void Store(WorkgroupSpirvPointer<Type> data, uint32_t stride) {
    Store<MemoryAccessMaskNone, layout>(data, stride);
  }

  // Stores the cooperative matrix using OpCooperativeMatrixStoreKHR to
  // data[index] using the given memory layout, stride, and memory access
  // operands. The layout and stride will be passed to the SPIR-V instruction as
  // is. The precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  void Store(RWStructuredBuffer<Type> data, uint32_t index, uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access
  // operands.
  template <CooperativeMatrixLayout layout, class Type>
  void Store(RWStructuredBuffer<Type> data, uint32_t index, uint32_t stride) {
    Store<MemoryAccessMaskNone, layout>(data, index, stride);
  }

  // Stores the cooperative matrix using OpCooperativeMatrixStoreKHR to
  // data[index] using the given memory layout, stride, and memory access
  // operands. `NonPrivatePointer` and `MakePointerAvailable` with the
  // QueueFamily scope will be added to the memory access operands to make the
  // memory coherent.
  //
  // The layout and stride will be passed to the SPIR-V instruction as is. The
  // precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  void CoherentStore(globallycoherent RWStructuredBuffer<Type> data,
                     uint32_t index, uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access operands
  // template argument.
  template <CooperativeMatrixLayout layout, class Type>
  void CoherentStore(globallycoherent RWStructuredBuffer<Type> data,
                     uint32_t index, uint32_t stride) {
    CoherentStore<MemoryAccessMaskNone, layout>(data, index, stride);
  }
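
  // Example (sketch): assuming vk::ScopeSubgroup and
  // vk::CooperativeMatrixLayoutRowMajorKHR are the enumerant spellings provided
  // by vk/spirv.h, a coherent store of a 16x16 accumulator into a
  // globallycoherent buffer could look like this:
  //
  //   globallycoherent RWStructuredBuffer<int> results;
  //
  //   using Acc = CooperativeMatrixAccumulator<int, vk::ScopeSubgroup, 16, 16>;
  //   Acc acc = Acc::Splat(0);
  //   acc.CoherentStore<vk::CooperativeMatrixLayoutRowMajorKHR>(results, 0, 16);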

  // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
  // data using the given memory layout, stride, and memory access operands.
  // `NonPrivatePointer` and `MakePointerVisible` with the Workgroup scope
  // will be added to the memory access operands to make the memory coherent.
  //
  // This function uses a SPIR-V pointer because HLSL does not allow groupshared
  // memory objects to be passed by reference. The pointer is a workaround for
  // that restriction.
  //
  // The layout and stride will be passed to the SPIR-V instruction as is. The
  // precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  static CooperativeMatrix Load(WorkgroupSpirvPointer<Type> data,
                                uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access
  // operands.
  template <CooperativeMatrixLayout layout, class Type>
  static CooperativeMatrix Load(WorkgroupSpirvPointer<Type> data,
                                uint32_t stride) {
    return Load<MemoryAccessMaskNone, layout>(data, stride);
  }
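
  // Example (sketch): the helper used below to obtain a WorkgroupSpirvPointer
  // from a groupshared array is hypothetical (named GetWorkgroupPointer only
  // for illustration; the real utility is presumably provided by vk/spirv.h),
  // and the enumerant spellings are assumed:
  //
  //   groupshared int shmem[16 * 16];
  //
  //   using MatA = CooperativeMatrixA<int, vk::ScopeSubgroup, 16, 16>;
  //   vk::WorkgroupSpirvPointer<int> p = GetWorkgroupPointer(shmem);
  //   MatA a = MatA::Load<vk::CooperativeMatrixLayoutRowMajorKHR>(p, 16);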

  // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
  // data[index] using the given memory layout, stride, and memory access
  // operands.
  //
  // The layout and stride will be passed to the SPIR-V instruction as is. The
  // precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  static CooperativeMatrix Load(RWStructuredBuffer<Type> data, uint32_t index,
                                uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access
  // operands.
  template <CooperativeMatrixLayout layout, class Type>
  static CooperativeMatrix Load(RWStructuredBuffer<Type> data, uint32_t index,
                                uint32_t stride) {
    return Load<MemoryAccessMaskNone, layout>(data, index, stride);
  }

  // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
  // data[index] using the given memory layout, stride, and memory access
  // operands. `NonPrivatePointer` and `MakePointerVisible` with the QueueFamily
  // scope will be added to the memory access operands to make the memory
  // coherent.
  //
  // The layout and stride will be passed to the SPIR-V instruction as is. The
  // precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  static CooperativeMatrix
  CoherentLoad(globallycoherent RWStructuredBuffer<Type> data, uint32_t index,
               uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access operands
  // template argument.
  template <CooperativeMatrixLayout layout, class Type>
  static CooperativeMatrix
  CoherentLoad(globallycoherent RWStructuredBuffer<Type> data, uint32_t index,
               uint32_t stride) {
    return CoherentLoad<MemoryAccessMaskNone, layout>(data, index, stride);
  }

  // Loads a cooperative matrix using OpCooperativeMatrixLoadKHR from
  // data[index] using the given memory layout, stride, and memory access
  // operands. No memory access bits are added to the operands; since the memory
  // is read-only, none should be needed.
  //
  // The layout and stride will be passed to the SPIR-V instruction as is. The
  // precise meaning can be found in the specification for
  // SPV_KHR_cooperative_matrix.
  template <uint32_t memoryAccessOperands, CooperativeMatrixLayout layout,
            class Type>
  static CooperativeMatrix Load(StructuredBuffer<Type> data, uint32_t index,
                                uint32_t stride);

  // Same as above, but uses MemoryAccessMaskNone for the memory access
  // operands.
  template <CooperativeMatrixLayout layout, class Type>
  static CooperativeMatrix Load(StructuredBuffer<Type> data, uint32_t index,
                                uint32_t stride) {
    return Load<MemoryAccessMaskNone, layout>(data, index, stride);
  }

  // Constructs a cooperative matrix with all values initialized to v. Note that
  // all threads in scope must have the same value for v.
  static CooperativeMatrix Splat(ComponentType v);

  // Returns the result of OpCooperativeMatrixLengthKHR on the current type.
  static uint32_t GetLength();

  // Functions to access the elements of the cooperative matrix. The index must
  // be less than GetLength().
  void Set(ComponentType value, uint32_t index);
  ComponentType Get(uint32_t index);
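
  // Example (sketch): incrementing every element this invocation owns using
  // GetLength, Get, and Set (the scope enumerant spelling is assumed):
  //
  //   using Acc = CooperativeMatrixAccumulator<float, vk::ScopeSubgroup, 16, 16>;
  //   Acc acc = Acc::Splat(0.0);
  //   for (uint32_t i = 0; i < Acc::GetLength(); i++)
  //     acc.Set(acc.Get(i) + 1.0, i);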

  // True for signed component types: 0 - 1 is negative for signed (and
  // floating-point) types, but wraps to a large positive value for unsigned
  // integers.
  static const bool hasSignedIntegerComponentType =
      (ComponentType(0) - ComponentType(1) < ComponentType(0));

  // clang-format off
  using SpirvMatrixType = vk::SpirvOpaqueType<
      /* OpTypeCooperativeMatrixKHR */ 4456, ComponentType,
      vk::integral_constant<uint, scope>, vk::integral_constant<uint, rows>,
      vk::integral_constant<uint, columns>, vk::integral_constant<uint, use> >;

  [[vk::ext_extension("SPV_KHR_cooperative_matrix")]]
  [[vk::ext_capability(/* CooperativeMatrixKHRCapability */ 6022)]]
  [[vk::ext_capability(/* VulkanMemoryModel */ 5345)]]
  SpirvMatrixType _matrix;
  // clang-format on
};

// Cooperative matrix that can be used in the "a" position of a multiply add
// instruction (r = (a * b) + c).
template <typename ComponentType, Scope scope, uint rows, uint columns>
using CooperativeMatrixA =
    CooperativeMatrix<ComponentType, scope, rows, columns,
                      CooperativeMatrixUseMatrixAKHR>;

// Cooperative matrix that can be used in the "b" position of a multiply add
// instruction (r = (a * b) + c).
template <typename ComponentType, Scope scope, uint rows, uint columns>
using CooperativeMatrixB =
    CooperativeMatrix<ComponentType, scope, rows, columns,
                      CooperativeMatrixUseMatrixBKHR>;

// Cooperative matrix that can be used in the "r" and "c" positions of a
// multiply add instruction (r = (a * b) + c).
template <typename ComponentType, Scope scope, uint rows, uint columns>
using CooperativeMatrixAccumulator =
    CooperativeMatrix<ComponentType, scope, rows, columns,
                      CooperativeMatrixUseMatrixAccumulatorKHR>;

// Returns the result of OpCooperativeMatrixMulAddKHR when applied to a, b, and
// c. The cooperative matrix operands are inferred, with the
// SaturatingAccumulationKHR bit not set.
template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
cooperativeMatrixMultiplyAdd(
    CooperativeMatrixA<ComponentType, scope, rows, K> a,
    CooperativeMatrixB<ComponentType, scope, K, columns> b,
    CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c);

// Returns the result of OpCooperativeMatrixMulAddKHR when applied to a, b, and
// c. The cooperative matrix operands are inferred, with the
// SaturatingAccumulationKHR bit set.
template <typename ComponentType, Scope scope, uint rows, uint columns, uint K>
CooperativeMatrixAccumulator<ComponentType, scope, rows, columns>
cooperativeMatrixSaturatingMultiplyAdd(
    CooperativeMatrixA<ComponentType, scope, rows, K> a,
    CooperativeMatrixB<ComponentType, scope, K, columns> b,
    CooperativeMatrixAccumulator<ComponentType, scope, rows, columns> c);
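
// Example (sketch): the K dimension is shared between the A and B operands and
// is inferred from them. With assumed enumerant spellings and 16-bit types
// enabled, a 16x8 by 8x16 multiply-add producing a 16x16 accumulator could be
// written as
//
//   using MatA = CooperativeMatrixA<float16_t, vk::ScopeSubgroup, 16, 8>;
//   using MatB = CooperativeMatrixB<float16_t, vk::ScopeSubgroup, 8, 16>;
//   using MatC =
//       CooperativeMatrixAccumulator<float16_t, vk::ScopeSubgroup, 16, 16>;
//   MatC acc = cooperativeMatrixMultiplyAdd(MatA::Splat(float16_t(1)),
//                                           MatB::Splat(float16_t(2)),
//                                           MatC::Splat(float16_t(0)));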

} // namespace khr
} // namespace vk

#include "cooperative_matrix.impl"
#endif // _HLSL_VK_KHR_COOPERATIVE_MATRIX_H_

0 commit comments
