Skip to content

Commit 53bdcfd

Browse files
[Substrait] Add support for SI1 and nested tuple types.
1 parent 9035088 commit 53bdcfd

File tree

5 files changed

+157
-16
lines changed

5 files changed

+157
-16
lines changed

include/structured/Dialect/Substrait/IR/SubstraitTypes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Substrait_Type<string name, string typeMnemonic, list<Trait> traits = []>
2424
// TODO(ingomueller): Add the other low-hanging fruits here.
2525
def Substrait_AtomicTypes {
2626
list<Type> types = [
27+
SI1, // Boolean
2728
SI32 // I32
2829
];
2930
}

lib/Target/SubstraitPB/Export.cpp

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,51 @@ FailureOr<std::unique_ptr<Rel>> exportOperation(RelOpInterface op);
4949

5050
FailureOr<std::unique_ptr<proto::Type>> exportType(Location loc,
5151
mlir::Type mlirType) {
52-
// TODO(ingomueller): Support other types.
53-
auto si32 = IntegerType::get(mlirType.getContext(), 32, IntegerType::Signed);
54-
if (mlirType != si32)
55-
return emitError(loc) << "could not export unsupported type " << mlirType;
52+
MLIRContext *context = mlirType.getContext();
53+
54+
// Handle SI1.
55+
auto si1 = IntegerType::get(context, 1, IntegerType::Signed);
56+
if (mlirType == si1) {
57+
// TODO(ingomueller): support other nullability modes.
58+
auto i1Type = std::make_unique<proto::Type::Boolean>();
59+
i1Type->set_nullability(
60+
Type_Nullability::Type_Nullability_NULLABILITY_REQUIRED);
61+
62+
auto type = std::make_unique<proto::Type>();
63+
type->set_allocated_bool_(i1Type.release());
64+
return std::move(type);
65+
}
5666

57-
// TODO(ingomueller): support other nullability modes.
58-
auto i32Type = std::make_unique<proto::Type::I32>();
59-
i32Type->set_nullability(
60-
Type_Nullability::Type_Nullability_NULLABILITY_REQUIRED);
67+
// Handle SI32.
68+
auto si32 = IntegerType::get(context, 32, IntegerType::Signed);
69+
if (mlirType == si32) {
70+
// TODO(ingomueller): support other nullability modes.
71+
auto i32Type = std::make_unique<proto::Type::I32>();
72+
i32Type->set_nullability(
73+
Type_Nullability::Type_Nullability_NULLABILITY_REQUIRED);
74+
75+
auto type = std::make_unique<proto::Type>();
76+
type->set_allocated_i32(i32Type.release());
77+
return std::move(type);
78+
}
6179

62-
auto type = std::make_unique<proto::Type>();
63-
type->set_allocated_i32(i32Type.release());
80+
if (auto tupleType = llvm::dyn_cast<TupleType>(mlirType)) {
81+
auto structType = std::make_unique<proto::Type::Struct>();
82+
for (mlir::Type fieldType : tupleType.getTypes()) {
83+
// Convert field type recursively.
84+
FailureOr<std::unique_ptr<proto::Type>> type = exportType(loc, fieldType);
85+
if (failed(type))
86+
return failure();
87+
*structType->add_types() = *type.value();
88+
}
89+
90+
auto type = std::make_unique<proto::Type>();
91+
type->set_allocated_struct_(structType.release());
92+
return std::move(type);
93+
}
6494

65-
return std::move(type);
95+
// TODO(ingomueller): Support other types.
96+
return emitError(loc) << "could not export unsupported type " << mlirType;
6697
}
6798

6899
FailureOr<std::unique_ptr<Rel>> exportOperation(CrossOp op) {

lib/Target/SubstraitPB/Import.cpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,36 @@ DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface)
5353

5454
static mlir::FailureOr<mlir::Type> importType(MLIRContext *context,
5555
const proto::Type &type) {
56-
// TODO(ingomueller): Support more types.
57-
if (!type.has_i32()) {
56+
57+
proto::Type::KindCase kind_case = type.kind_case();
58+
switch (kind_case) {
59+
case proto::Type::kBool: {
60+
return IntegerType::get(context, 1, IntegerType::Signed);
61+
}
62+
case proto::Type::kI32: {
63+
return IntegerType::get(context, 32, IntegerType::Signed);
64+
}
65+
case proto::Type::kStruct: {
66+
const proto::Type::Struct &structType = type.struct_();
67+
llvm::SmallVector<mlir::Type> fieldTypes;
68+
fieldTypes.reserve(structType.types_size());
69+
for (const proto::Type &fieldType : structType.types()) {
70+
FailureOr<mlir::Type> mlirFieldType = importType(context, fieldType);
71+
if (failed(mlirFieldType))
72+
return failure();
73+
fieldTypes.push_back(mlirFieldType.value());
74+
}
75+
return TupleType::get(context, fieldTypes);
76+
}
77+
// TODO(ingomueller): Support more types.
78+
default: {
5879
auto loc = UnknownLoc::get(context);
5980
const pb::FieldDescriptor *desc =
60-
proto::Type::GetDescriptor()->FindFieldByNumber(type.kind_case());
81+
proto::Type::GetDescriptor()->FindFieldByNumber(kind_case);
6182
return emitError(loc) << "could not import unsupported type "
6283
<< desc->name();
6384
}
64-
return IntegerType::get(context, 32, IntegerType::Signed);
85+
}
6586
}
6687

6788
static mlir::FailureOr<CrossOp> importCross(ImplicitLocOpBuilder builder,

test/Target/SubstraitPB/Export/types.mlir

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
// RUN: structured-translate -substrait-to-protobuf %s \
55
// RUN: --split-input-file --output-split-marker="# -----" \
66
// RUN: | structured-translate -protobuf-to-substrait \
7-
// RUN: --split-input-file="# -----" \
7+
// RUN: --split-input-file="# -----" --output-split-marker="// ""-----" \
88
// RUN: | structured-translate -substrait-to-protobuf \
9+
// RUN: --split-input-file --output-split-marker="# -----" \
910
// RUN: | FileCheck %s
1011

1112
// CHECK-LABEL: relations {
@@ -30,3 +31,39 @@ substrait.plan version 0 : 42 : 1 {
3031
yield %0 : tuple<si32>
3132
}
3233
}
34+
35+
// -----
36+
37+
// CHECK-LABEL: relations {
38+
// CHECK-NEXT: rel {
39+
// CHECK-NEXT: read {
40+
// CHECK: base_schema {
41+
// CHECK-NEXT: names: "a"
42+
// CHECK-NEXT: names: "b"
43+
// CHECK-NEXT: names: "c"
44+
// CHECK-NEXT: struct {
45+
// CHECK-NEXT: types {
46+
// CHECK-NEXT: bool {
47+
// CHECK-NEXT: nullability: NULLABILITY_REQUIRED
48+
// CHECK-NEXT: }
49+
// CHECK-NEXT: }
50+
// CHECK-NEXT: types {
51+
// CHECK-NEXT: struct {
52+
// CHECK-NEXT: types {
53+
// CHECK-NEXT: bool {
54+
// CHECK-NEXT: nullability: NULLABILITY_REQUIRED
55+
// CHECK-NEXT: }
56+
// CHECK-NEXT: }
57+
// CHECK-NEXT: }
58+
// CHECK-NEXT: }
59+
// CHECK-NEXT: nullability: NULLABILITY_REQUIRED
60+
// CHECK-NEXT: }
61+
// CHECK-NEXT: }
62+
// CHECK-NEXT: named_table {
63+
64+
substrait.plan version 0 : 42 : 1 {
65+
relation {
66+
%0 = named_table @t1 as ["a", "b", "c"] : tuple<si1, tuple<si1>>
67+
yield %0 : tuple<si1, tuple<si1>>
68+
}
69+
}

test/Target/SubstraitPB/Import/types.textpb

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
# RUN: structured-translate -protobuf-to-substrait %s \
2+
# RUN: --split-input-file="# ""-----" \
23
# RUN: | FileCheck %s
34

45
# RUN: structured-translate -protobuf-to-substrait %s \
6+
# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \
57
# RUN: | structured-translate -substrait-to-protobuf \
8+
# RUN: --split-input-file --output-split-marker="# ""-----" \
69
# RUN: | structured-translate -protobuf-to-substrait \
10+
# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \
711
# RUN: | FileCheck %s
812

913
# CHECK: substrait.plan
@@ -39,3 +43,50 @@ version {
3943
minor_number: 42
4044
patch_number: 1
4145
}
46+
47+
# -----
48+
49+
# CHECK: substrait.plan
50+
# CHECK-NEXT: relation
51+
# CHECK-NEXT: named_table
52+
# CHECK-SAME: : tuple<si1, tuple<si1>>
53+
54+
relations {
55+
rel {
56+
read {
57+
common {
58+
direct {
59+
}
60+
}
61+
base_schema {
62+
names: "a"
63+
names: "b"
64+
names: "c"
65+
struct {
66+
types {
67+
bool {
68+
nullability: NULLABILITY_REQUIRED
69+
}
70+
}
71+
types {
72+
struct {
73+
types {
74+
bool {
75+
nullability: NULLABILITY_REQUIRED
76+
}
77+
}
78+
}
79+
}
80+
nullability: NULLABILITY_REQUIRED
81+
}
82+
}
83+
named_table {
84+
names: "t1"
85+
}
86+
}
87+
}
88+
}
89+
version {
90+
minor_number: 42
91+
patch_number: 1
92+
}

0 commit comments

Comments
 (0)