Skip to content

Commit a561506

Browse files
vsemenov368MrSidims0x12CC
authored
[Backport to 11] Implement SPV_KHR_bfloat16 extension (#3252)
The extension add translation from LLVM's bfloat type to OpTypeFloat %width% 16 %fp encoding% BFloat16KHR Mangling follows LLVM's rules for the type. Spec PR: KhronosGroup/SPIRV-Registry#323 --------- Signed-off-by: Sidorov, Dmitry <[email protected]> Co-authored-by: Aziz, Michael <[email protected]> Signed-off-by: Sidorov, Dmitry <[email protected]> Co-authored-by: Dmitry Sidorov <[email protected]> Co-authored-by: Aziz, Michael <[email protected]> Signed-off-by: Sidorov, Dmitry <[email protected]> Co-authored-by: Dmitry Sidorov <[email protected]> Co-authored-by: Aziz, Michael <[email protected]>
1 parent 65d9ca7 commit a561506

File tree

17 files changed

+211
-35
lines changed

17 files changed

+211
-35
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,4 @@ EXT(SPV_INTEL_hw_thread_queries)
4949
EXT(SPV_EXT_relaxed_printf_string_address_space)
5050
EXT(SPV_INTEL_global_variable_decorations)
5151
EXT(SPV_INTEL_maximum_registers)
52+
EXT(SPV_KHR_bfloat16)

lib/SPIRV/Mangler/ManglingUtils.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ static const char *PrimitiveNames[PRIMITIVE_NUM] = {
2828
"half",
2929
"float",
3030
"double",
31+
"__bf16",
3132
"void",
3233
"...",
3334
"image1d_ro_t",
@@ -89,8 +90,7 @@ static const char *PrimitiveNames[PRIMITIVE_NUM] = {
8990
"intel_sub_group_avc_ime_result_single_reference_streamout_t",
9091
"intel_sub_group_avc_ime_result_dual_reference_streamout_t",
9192
"intel_sub_group_avc_ime_result_single_reference_streamin_t",
92-
"intel_sub_group_avc_ime_result_dual_reference_streamin_t"
93-
};
93+
"intel_sub_group_avc_ime_result_dual_reference_streamin_t"};
9494

9595
const char *MangledTypes[PRIMITIVE_NUM] = {
9696
"b", // BOOL
@@ -105,6 +105,7 @@ const char *MangledTypes[PRIMITIVE_NUM] = {
105105
"Dh", // HALF
106106
"f", // FLOAT
107107
"d", // DOUBLE
108+
"u6__bf16", // __BF16
108109
"v", // VOID
109110
"z", // VarArg
110111
"14ocl_image1d_ro", // PRIMITIVE_IMAGE1D_RO_T
@@ -157,21 +158,21 @@ const char *MangledTypes[PRIMITIVE_NUM] = {
157158
"i", // PRIMITIVE_MEMORY_ORDER
158159
"i", // PRIMITIVE_MEMORY_SCOPE
159160
#else
160-
"12memory_order", // PRIMITIVE_MEMORY_ORDER
161-
"12memory_scope", // PRIMITIVE_MEMORY_SCOPE
161+
"12memory_order", // PRIMITIVE_MEMORY_ORDER
162+
"12memory_scope", // PRIMITIVE_MEMORY_SCOPE
162163
#endif
163164
"37ocl_intel_sub_group_avc_mce_payload_t", // PRIMITIVE_SUB_GROUP_AVC_MCE_PAYLOAD_T
164165
"37ocl_intel_sub_group_avc_ime_payload_t", // PRIMITIVE_SUB_GROUP_AVC_IME_PAYLOAD_T
165166
"37ocl_intel_sub_group_avc_ref_payload_t", // PRIMITIVE_SUB_GROUP_AVC_REF_PAYLOAD_T
166167
"37ocl_intel_sub_group_avc_sic_payload_t", // PRIMITIVE_SUB_GROUP_AVC_SIC_PAYLOAD_T
167-
"36ocl_intel_sub_group_avc_mce_result_t", // PRIMITIVE_SUB_GROUP_AVC_MCE_RESULT_T
168-
"36ocl_intel_sub_group_avc_ime_result_t", // PRIMITIVE_SUB_GROUP_AVC_IME_RESULT_T
169-
"36ocl_intel_sub_group_avc_ref_result_t", // PRIMITIVE_SUB_GROUP_AVC_REF_RESULT_T
170-
"36ocl_intel_sub_group_avc_sic_result_t", // PRIMITIVE_SUB_GROUP_AVC_REF_RESULT_T
171-
"63ocl_intel_sub_group_avc_ime_result_single_reference_streamout_t", // PRIMITIVE_SUB_GROUP_AVC_IME_SINGLE_REF_STREAMOUT_T
172-
"61ocl_intel_sub_group_avc_ime_result_dual_reference_streamout_t", // PRIMITIVE_SUB_GROUP_AVC_IME_DUAL_REF_STREAMOUT_T
173-
"55ocl_intel_sub_group_avc_ime_single_reference_streamin_t", // PRIMITIVE_SUB_GROUP_AVC_IME_SINGLE_REF_STREAMIN_T
174-
"53ocl_intel_sub_group_avc_ime_dual_reference_streamin_t" // PRIMITIVE_SUB_GROUP_AVC_IME_DUAL_REF_STREAMIN_T
168+
"36ocl_intel_sub_group_avc_mce_result_t", // PRIMITIVE_SUB_GROUP_AVC_MCE_RESULT_T
169+
"36ocl_intel_sub_group_avc_ime_result_t", // PRIMITIVE_SUB_GROUP_AVC_IME_RESULT_T
170+
"36ocl_intel_sub_group_avc_ref_result_t", // PRIMITIVE_SUB_GROUP_AVC_REF_RESULT_T
171+
"36ocl_intel_sub_group_avc_sic_result_t", // PRIMITIVE_SUB_GROUP_AVC_REF_RESULT_T
172+
"63ocl_intel_sub_group_avc_ime_result_single_reference_streamout_t", // PRIMITIVE_SUB_GROUP_AVC_IME_SINGLE_REF_STREAMOUT_T
173+
"61ocl_intel_sub_group_avc_ime_result_dual_reference_streamout_t", // PRIMITIVE_SUB_GROUP_AVC_IME_DUAL_REF_STREAMOUT_T
174+
"55ocl_intel_sub_group_avc_ime_single_reference_streamin_t", // PRIMITIVE_SUB_GROUP_AVC_IME_SINGLE_REF_STREAMIN_T
175+
"53ocl_intel_sub_group_avc_ime_dual_reference_streamin_t" // PRIMITIVE_SUB_GROUP_AVC_IME_DUAL_REF_STREAMIN_T
175176
};
176177

177178
const char *ReadableAttribute[ATTR_NUM] = {
@@ -197,6 +198,7 @@ static const SPIRversion PrimitiveSupportedVersions[PRIMITIVE_NUM] = {
197198
SPIR12, // HALF
198199
SPIR12, // FLOAT
199200
SPIR12, // DOUBLE
201+
SPIR12, // __BF16
200202
SPIR12, // VOID
201203
SPIR12, // VarArg
202204
SPIR12, // PRIMITIVE_IMAGE1D_RO_T

lib/SPIRV/Mangler/ParameterType.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ enum TypePrimitiveEnum {
4545
PRIMITIVE_HALF,
4646
PRIMITIVE_FLOAT,
4747
PRIMITIVE_DOUBLE,
48+
PRIMITIVE_BFLOAT,
4849
PRIMITIVE_VOID,
4950
PRIMITIVE_VAR_ARG,
5051
PRIMITIVE_STRUCT_FIRST,

lib/SPIRV/SPIRVReader.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,8 @@ Value *SPIRVToLLVM::mapFunction(SPIRVFunction *BF, Function *F) {
265265
Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
266266
switch (T->getFloatBitWidth()) {
267267
case 16:
268+
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
269+
return Type::getBFloatTy(*Context);
268270
return Type::getHalfTy(*Context);
269271
case 32:
270272
return Type::getFloatTy(*Context);
@@ -467,6 +469,9 @@ std::string SPIRVToLLVM::transTypeToOCLTypeName(SPIRVType *T, bool IsSigned) {
467469
case OpTypeFloat:
468470
switch (T->getFloatBitWidth()) {
469471
case 16:
472+
if (static_cast<SPIRVTypeFloat *>(T)->getFloatingPointEncoding() ==
473+
FPEncodingBFloat16KHR)
474+
return "bfloat16";
470475
return "half";
471476
case 32:
472477
return "float";
@@ -1299,7 +1304,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
12991304
const llvm::fltSemantics *FS = nullptr;
13001305
switch (BT->getFloatBitWidth()) {
13011306
case 16:
1302-
FS = &APFloat::IEEEhalf();
1307+
FS =
1308+
(BT->isTypeFloat(16, FPEncodingBFloat16KHR) ? &APFloat::BFloat()
1309+
: &APFloat::IEEEhalf());
13031310
break;
13041311
case 32:
13051312
FS = &APFloat::IEEEsingle();

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,8 @@ static SPIR::RefParamType transTypeDesc(Type *Ty,
10661066
return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_FLOAT));
10671067
if (Ty->isDoubleTy())
10681068
return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_DOUBLE));
1069+
if (Ty->isBFloatTy())
1070+
return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_BFLOAT));
10691071
if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
10701072
return SPIR::RefParamType(new SPIR::VectorType(
10711073
transTypeDesc(VecTy->getElementType(), Info), VecTy->getNumElements()));

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,16 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
326326
}
327327
}
328328

329+
if (T->isBFloatTy()) {
330+
BM->getErrorLog().checkError(
331+
BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_bfloat16),
332+
SPIRVEC_RequiresExtension,
333+
"SPV_KHR_bfloat16\n"
334+
"NOTE: LLVM module contains bfloat type, translation of which "
335+
"requires this extension");
336+
return mapType(T, BM->addFloatType(16, FPEncodingBFloat16KHR));
337+
}
338+
329339
if (T->isFloatingPointTy())
330340
return mapType(T, BM->addFloatType(T->getPrimitiveSizeInBits()));
331341

lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
197197
{CapabilitySubgroupAvcMotionEstimationINTEL});
198198
ADD_VEC_INIT(CapabilitySubgroupAvcMotionEstimationChromaINTEL,
199199
{CapabilitySubgroupAvcMotionEstimationIntraINTEL});
200+
ADD_VEC_INIT(CapabilityBFloat16DotProductKHR, {CapabilityBFloat16TypeKHR});
200201
}
201202

202203
template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {
@@ -432,7 +433,6 @@ template <> inline void SPIRVMap<Decoration, SPIRVCapVec>::init() {
432433
{internal::CapabilityMemoryAccessAliasingINTEL});
433434
ADD_VEC_INIT(internal::DecorationNoAliasINTEL,
434435
{internal::CapabilityMemoryAccessAliasingINTEL});
435-
436436
ADD_VEC_INIT(internal::DecorationHostAccessINTEL,
437437
{internal::CapabilityGlobalVariableDecorationsINTEL});
438438
ADD_VEC_INIT(internal::DecorationInitModeINTEL,

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,23 @@ class SPIRVBinary : public SPIRVInstTemplateBase {
652652
assert(0 && "Invalid op code!");
653653
}
654654
}
655+
SPIRVWord getRequiredSPIRVVersion() const override {
656+
if (isBinaryOpCode(OpCode))
657+
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_4);
658+
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_0);
659+
}
660+
SPIRVCapVec getRequiredCapability() const override {
661+
if (OpCode == OpDot) {
662+
const SPIRVType *OpTy = getValueType(Ops[0]);
663+
if (OpTy && OpTy->isTypeVector()) {
664+
OpTy = OpTy->getVectorComponentType();
665+
if (OpTy && OpTy->isTypeFloat(16, FPEncodingBFloat16KHR)) {
666+
return getVec(CapabilityBFloat16DotProductKHR);
667+
}
668+
}
669+
}
670+
return SPIRVInstruction::getRequiredCapability();
671+
}
655672
};
656673

657674
template <Op OC>

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ class SPIRVModuleImpl : public SPIRVModule {
228228
template <class T> T *addType(T *Ty);
229229
SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVConstant *) override;
230230
SPIRVTypeBool *addBoolType() override;
231-
SPIRVTypeFloat *addFloatType(unsigned BitWidth) override;
231+
SPIRVTypeFloat *addFloatType(unsigned BitWidth,
232+
unsigned FloatingPointEncoding) override;
232233
SPIRVTypeFunction *addFunctionType(SPIRVType *,
233234
const std::vector<SPIRVType *> &) override;
234235
SPIRVTypeInt *addIntegerType(unsigned BitWidth) override;
@@ -521,6 +522,8 @@ class SPIRVModuleImpl : public SPIRVModule {
521522
SPIRVCapMap CapMap;
522523
SPIRVUnknownStructFieldMap UnknownStructFieldMap;
523524
std::map<unsigned, SPIRVTypeInt *> IntTypeMap;
525+
SmallDenseMap<std::pair<unsigned, unsigned>, SPIRVTypeFloat *, 4>
526+
FloatTypeMap;
524527
std::map<unsigned, SPIRVConstant *> LiteralMap;
525528
std::vector<SPIRVExtInst *> DebugInstVec;
526529
std::vector<SPIRVModuleProcessed *> ModuleProcessedVec;
@@ -851,9 +854,15 @@ SPIRVTypeInt *SPIRVModuleImpl::addIntegerType(unsigned BitWidth) {
851854
return addType(Ty);
852855
}
853856

854-
SPIRVTypeFloat *SPIRVModuleImpl::addFloatType(unsigned BitWidth) {
855-
SPIRVTypeFloat *T = addType(new SPIRVTypeFloat(this, getId(), BitWidth));
856-
return T;
857+
SPIRVTypeFloat *SPIRVModuleImpl::addFloatType(unsigned BitWidth,
858+
unsigned FloatingPointEncoding) {
859+
auto Desc = std::make_pair(BitWidth, FloatingPointEncoding);
860+
auto Loc = FloatTypeMap.find(Desc);
861+
if (Loc != FloatTypeMap.end())
862+
return Loc->second;
863+
auto *Ty = new SPIRVTypeFloat(this, getId(), BitWidth, FloatingPointEncoding);
864+
FloatTypeMap[Desc] = Ty;
865+
return addType(Ty);
857866
}
858867

859868
SPIRVTypePointer *

lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ class SPIRVModule {
223223
// Type creation functions
224224
virtual SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVConstant *) = 0;
225225
virtual SPIRVTypeBool *addBoolType() = 0;
226-
virtual SPIRVTypeFloat *addFloatType(unsigned) = 0;
226+
virtual SPIRVTypeFloat *addFloatType(unsigned, unsigned = FPEncodingMax) = 0;
227227
virtual SPIRVTypeFunction *
228228
addFunctionType(SPIRVType *, const std::vector<SPIRVType *> &) = 0;
229229
virtual SPIRVTypeImage *addImageType(SPIRVType *,

0 commit comments

Comments
 (0)