Skip to content

Commit

Permalink
Custom extensions to enable lowering for non-resource based functions (
Browse files Browse the repository at this point in the history
…#5579)

A downstream consumer requires the use of custom function lowering using
a json string as well as methods. For instance, a direct memory load not
based on a resource.
  • Loading branch information
bfavela authored Aug 23, 2023
1 parent f49bb3c commit ce8268a
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 65 deletions.
2 changes: 2 additions & 0 deletions include/dxc/HLSL/HLOperationLowerExtension.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ namespace hlsl {
Pack, // Convert the vector arguments into structs.
Resource, // Convert return value to resource return and explode vectors.
Dxil, // Convert call to a dxil intrinsic.
Custom, // Custom lowering based on flexible json string.
};

// Create the lowering using the given strategy and custom codegen helper.
Expand Down Expand Up @@ -86,5 +87,6 @@ namespace hlsl {
llvm::Value *Resource(llvm::CallInst *CI);
llvm::Value *Dxil(llvm::CallInst *CI);
llvm::Value *CustomResource(llvm::CallInst *CI);
llvm::Value *Custom(llvm::CallInst *CI);
};
}
178 changes: 118 additions & 60 deletions lib/HLSL/HLOperationLowerExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
case 'p': return Strategy::Pack;
case 'm': return Strategy::Resource;
case 'd': return Strategy::Dxil;
case 'c': return Strategy::Custom;
default: break;
}
return Strategy::Unknown;
Expand All @@ -63,6 +64,7 @@ llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
case Strategy::Pack: return "p";
case Strategy::Resource: return "m"; // m for resource method
case Strategy::Dxil: return "d";
case Strategy::Custom: return "c";
default: break;
}
return "?";
Expand Down Expand Up @@ -91,6 +93,7 @@ llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
case Strategy::Pack: return Pack(CI);
case Strategy::Resource: return Resource(CI);
case Strategy::Dxil: return Dxil(CI);
case Strategy::Custom: return Custom(CI);
default: break;
}
return Unknown(CI);
Expand Down Expand Up @@ -373,6 +376,51 @@ Value *ExtensionLowering::Replicate(CallInst *CI) {
return replicate.Generate();
}

///////////////////////////////////////////////////////////////////////////////
// Helper functions
static VectorType* ConvertStructTypeToVectorType(Type* structTy) {
assert(structTy->isStructTy());
return VectorType::get(structTy->getStructElementType(0), structTy->getStructNumElements());
}

static Value* PackStructIntoVector(IRBuilder<>& builder, Value* strukt) {
Type* vecTy = ConvertStructTypeToVectorType(strukt->getType());
Value* packed = UndefValue::get(vecTy);

unsigned numElements = vecTy->getVectorNumElements();
for (unsigned i = 0; i < numElements; ++i) {
Value* element = builder.CreateExtractValue(strukt, i);
packed = builder.CreateInsertElement(packed, element, i);
}

return packed;
}

static StructType* ConvertVectorTypeToStructType(Type* vecTy) {
assert(vecTy->isVectorTy());
Type* elementTy = vecTy->getVectorElementType();
unsigned numElements = vecTy->getVectorNumElements();
SmallVector<Type*, 4> elements;
for (unsigned i = 0; i < numElements; ++i)
elements.push_back(elementTy);

return StructType::get(vecTy->getContext(), elements);
}


static Value* PackVectorIntoStruct(IRBuilder<>& builder, Value* vec) {
StructType* structTy = ConvertVectorTypeToStructType(vec->getType());
Value* packed = UndefValue::get(structTy);

unsigned numElements = structTy->getStructNumElements();
for (unsigned i = 0; i < numElements; ++i) {
Value* element = builder.CreateExtractElement(vec, i);
packed = builder.CreateInsertValue(packed, element, { i });
}

return packed;
}

///////////////////////////////////////////////////////////////////////////////
// Packed Lowering.
class PackCall {
Expand All @@ -389,17 +437,6 @@ class PackCall {
Value *result = CreateCall(args);
return UnpackResult(result);
}

static StructType *ConvertVectorTypeToStructType(Type *vecTy) {
assert(vecTy->isVectorTy());
Type *elementTy = vecTy->getVectorElementType();
unsigned numElements = vecTy->getVectorNumElements();
SmallVector<Type *, 4> elements;
for (unsigned i = 0; i < numElements; ++i)
elements.push_back(elementTy);

return StructType::get(vecTy->getContext(), elements);
}

private:
CallInst *m_CI;
Expand All @@ -425,37 +462,6 @@ class PackCall {
}
return result;
}

static VectorType *ConvertStructTypeToVectorType(Type *structTy) {
assert(structTy->isStructTy());
return VectorType::get(structTy->getStructElementType(0), structTy->getStructNumElements());
}

static Value *PackVectorIntoStruct(IRBuilder<> &builder, Value *vec) {
StructType *structTy = ConvertVectorTypeToStructType(vec->getType());
Value *packed = UndefValue::get(structTy);

unsigned numElements = structTy->getStructNumElements();
for (unsigned i = 0; i < numElements; ++i) {
Value *element = builder.CreateExtractElement(vec, i);
packed = builder.CreateInsertValue(packed, element, { i });
}

return packed;
}

static Value *PackStructIntoVector(IRBuilder<> &builder, Value *strukt) {
Type *vecTy = ConvertStructTypeToVectorType(strukt->getType());
Value *packed = UndefValue::get(vecTy);

unsigned numElements = vecTy->getVectorNumElements();
for (unsigned i = 0; i < numElements; ++i) {
Value *element = builder.CreateExtractValue(strukt, i);
packed = builder.CreateInsertElement(packed, element, i);
}

return packed;
}
};

class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
Expand All @@ -468,7 +474,7 @@ class PackedFunctionTypeTranslator : public FunctionTypeTranslator {

Type *TranslateIfVector(Type *ty) {
if (ty->isVectorTy())
ty = PackCall::ConvertVectorTypeToStructType(ty);
ty = ConvertVectorTypeToStructType(ty);
return ty;
}
};
Expand Down Expand Up @@ -713,10 +719,30 @@ Value *ExtensionLowering::Resource(CallInst *CI) {
// dxil: @MyTextureOp(17, handle, a.x, a.y, undef, c.x, c.y)
//
//
class CustomResourceLowering
class CustomLowering
{
public:
CustomResourceLowering(StringRef LoweringInfo, CallInst *CI, HLResourceLookup &ResourceLookup)
CustomLowering(StringRef LoweringInfo, CallInst* CI)
{
// Parse lowering info json format.
std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap =
ParseLoweringInfo(LoweringInfo, CI->getContext());

// Find the default lowering kind
std::vector<DxilArgInfo> *pArgInfo = nullptr;
if (LoweringInfoMap.count(m_DefaultInfoName))
{
pArgInfo = &LoweringInfoMap.at(m_DefaultInfoName);
}
else
{
ThrowExtensionError("Unable to find lowering info for custom function");
}
// Don't explode vectors for custom functions
GenerateLoweredArgs(CI, *pArgInfo);
}

CustomLowering(StringRef LoweringInfo, CallInst *CI, HLResourceLookup &ResourceLookup)
{
// Parse lowering info json format.
std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap =
Expand All @@ -732,15 +758,14 @@ class CustomResourceLowering
std::string Name(pName);

// Select lowering info to use based on resource kind.
const char *DefaultInfoName = "default";
std::vector<DxilArgInfo> *pArgInfo = nullptr;
if (LoweringInfoMap.count(Name))
{
pArgInfo = &LoweringInfoMap.at(Name);
}
else if (LoweringInfoMap.count(DefaultInfoName))
else if (LoweringInfoMap.count(m_DefaultInfoName))
{
pArgInfo = &LoweringInfoMap.at(DefaultInfoName);
pArgInfo = &LoweringInfoMap.at(m_DefaultInfoName);
}
else
{
Expand Down Expand Up @@ -775,6 +800,7 @@ class CustomResourceLowering
{"?half", Type::getHalfTy(Ctx)},
{"?i8", Type::getInt8Ty(Ctx)},
{"?i16", Type::getInt16Ty(Ctx)},
{"?i1", Type::getInt1Ty(Ctx)},
};
DXASSERT(m_OptionalTypes.empty(), "Init should only be called once");
m_OptionalTypes.clear();
Expand Down Expand Up @@ -965,6 +991,13 @@ class CustomResourceLowering
}
}
}
else
{
// If the vector isn't exploded, use structs for DXIL Intrinsics
if (Arg->getType()->isVectorTy()) {
Arg = PackVectorIntoStruct(builder, Arg);
}
}

m_LoweredArgs.push_back(Arg);
}
Expand All @@ -984,27 +1017,28 @@ class CustomResourceLowering

std::vector<Value *> m_LoweredArgs;
SmallVector<OptionalTypeSpec, 5> m_OptionalTypes;
const char* m_DefaultInfoName = "default";
};

// Boilerplate to reuse exising logic as much as possible.
// We just want to overload GetFunctionType here.
class CustomResourceFunctionTranslator : public FunctionTranslator {
class CustomFunctionTranslator : public FunctionTranslator {
public:
static Function *GetLoweredFunction(
const CustomResourceLowering &CustomLowering,
ResourceFunctionTypeTranslator &typeTranslator,
const CustomLowering &CustomLowering,
FunctionTypeTranslator &typeTranslator,
CallInst *CI,
ExtensionLowering &lower
)
{
CustomResourceFunctionTranslator T(CustomLowering, typeTranslator, lower);
CustomFunctionTranslator T(CustomLowering, typeTranslator, lower);
return T.FunctionTranslator::GetLoweredFunction(CI);
}

private:
CustomResourceFunctionTranslator(
const CustomResourceLowering &CustomLowering,
ResourceFunctionTypeTranslator &typeTranslator,
CustomFunctionTranslator(
const CustomLowering &CustomLowering,
FunctionTypeTranslator &typeTranslator,
ExtensionLowering &lower
)
: FunctionTranslator(typeTranslator, lower)
Expand All @@ -1023,15 +1057,15 @@ class CustomResourceFunctionTranslator : public FunctionTranslator {
}

private:
const CustomResourceLowering &m_CustomLowering;
const CustomLowering &m_CustomLowering;
};

// Boilerplate to reuse exising logic as much as possible.
// We just want to overload Generate here.
class CustomResourceMethodCall : public ResourceMethodCall
{
public:
CustomResourceMethodCall(CallInst *CI, const CustomResourceLowering &CustomLowering)
CustomResourceMethodCall(CallInst *CI, const CustomLowering &CustomLowering)
: ResourceMethodCall(CI)
, m_CustomLowering(CustomLowering)
{}
Expand All @@ -1043,14 +1077,14 @@ class CustomResourceMethodCall : public ResourceMethodCall
}

private:
const CustomResourceLowering &m_CustomLowering;
const CustomLowering &m_CustomLowering;
};

// Support custom lowering logic for resource functions.
Value *ExtensionLowering::CustomResource(CallInst *CI) {
CustomResourceLowering CustomLowering(m_extraStrategyInfo, CI, m_hlResourceLookup);
CustomLowering CustomLowering(m_extraStrategyInfo, CI, m_hlResourceLookup);
ResourceFunctionTypeTranslator ResourceTypeTranslator(m_hlslOp);
Function *ResourceFunction = CustomResourceFunctionTranslator::GetLoweredFunction(
Function *ResourceFunction = CustomFunctionTranslator::GetLoweredFunction(
CustomLowering,
ResourceTypeTranslator,
CI,
Expand All @@ -1064,6 +1098,30 @@ Value *ExtensionLowering::CustomResource(CallInst *CI) {
return Result;
}

// Support custom lowering logic for arbitrary functions.
Value *ExtensionLowering::Custom(CallInst *CI) {
CustomLowering CustomLowering(m_extraStrategyInfo, CI);
PackedFunctionTypeTranslator TypeTranslator;
Function *CustomFunction = CustomFunctionTranslator::GetLoweredFunction(
CustomLowering,
TypeTranslator,
CI,
*this
);
if (!CustomFunction)
return NoTranslation(CI);

IRBuilder<> builder(CI);
Value* result = builder.CreateCall(CustomFunction, CustomLowering.GetLoweredArgs());

// Arbitrary functions will expect vectors, not structs
if (CustomFunction->getReturnType()->isStructTy()) {
return PackStructIntoVector(builder, result);
}

return result;
}

///////////////////////////////////////////////////////////////////////////////
// Dxil Lowering.

Expand Down
Loading

0 comments on commit ce8268a

Please sign in to comment.