Skip to content

Commit ff5d386

Browse files
committed
Fix WASM C ABI for raw unions
1 parent 8e782d9 commit ff5d386

File tree

3 files changed

+134
-15
lines changed

3 files changed

+134
-15
lines changed

src/llvm_abi.cpp

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,7 +1313,7 @@ namespace lbAbiWasm {
13131313
registers/arguments if possible rather than by pointer.
13141314
*/
13151315
gb_internal Array<lbArgType> compute_arg_types(LLVMContextRef c, LLVMTypeRef *arg_types, unsigned arg_count, ProcCallingConvention calling_convention, Type *original_type);
1316-
gb_internal LB_ABI_COMPUTE_RETURN_TYPE(compute_return_type);
1316+
gb_internal lbArgType compute_return_type(lbFunctionType *ft, LLVMContextRef c, LLVMTypeRef return_type, bool return_is_defined, bool return_is_tuple, Type* original_type);
13171317

13181318
enum {MAX_DIRECT_STRUCT_SIZE = 32};
13191319

@@ -1323,7 +1323,9 @@ namespace lbAbiWasm {
13231323
ft->ctx = c;
13241324
ft->calling_convention = calling_convention;
13251325
ft->args = compute_arg_types(c, arg_types, arg_count, calling_convention, original_type);
1326-
ft->ret = compute_return_type(ft, c, return_type, return_is_defined, return_is_tuple);
1326+
1327+
GB_ASSERT(original_type->kind == Type_Proc);
1328+
ft->ret = compute_return_type(ft, c, return_type, return_is_defined, return_is_tuple, original_type->Proc.results);
13271329
return ft;
13281330
}
13291331

@@ -1359,7 +1361,7 @@ namespace lbAbiWasm {
13591361
return false;
13601362
}
13611363

1362-
gb_internal bool type_can_be_direct(LLVMTypeRef type, ProcCallingConvention calling_convention) {
1364+
gb_internal bool type_can_be_direct(LLVMTypeRef type, Type *original_type, ProcCallingConvention calling_convention) {
13631365
LLVMTypeKind kind = LLVMGetTypeKind(type);
13641366
i64 sz = lb_sizeof(type);
13651367
if (sz == 0) {
@@ -1372,9 +1374,21 @@ namespace lbAbiWasm {
13721374
return false;
13731375
} else if (kind == LLVMStructTypeKind) {
13741376
unsigned count = LLVMCountStructElementTypes(type);
1377+
1378+
// NOTE(laytan): raw unions are always structs with 1 field in LLVM, need to check our own def.
1379+
Type *bt = base_type(original_type);
1380+
if (bt->kind == Type_Struct && bt->Struct.is_raw_union) {
1381+
count = bt->Struct.fields.count;
1382+
}
1383+
13751384
if (count == 1) {
1376-
return type_can_be_direct(LLVMStructGetTypeAtIndex(type, 0), calling_convention);
1385+
return type_can_be_direct(
1386+
LLVMStructGetTypeAtIndex(type, 0),
1387+
type_internal_index(original_type, 0),
1388+
calling_convention
1389+
);
13771390
}
1391+
13781392
} else if (is_basic_register_type(type)) {
13791393
return true;
13801394
}
@@ -1398,23 +1412,23 @@ namespace lbAbiWasm {
13981412
return false;
13991413
}
14001414

1401-
gb_internal lbArgType is_struct(LLVMContextRef c, LLVMTypeRef type, ProcCallingConvention calling_convention) {
1415+
gb_internal lbArgType is_struct(LLVMContextRef c, LLVMTypeRef type, Type *original_type, ProcCallingConvention calling_convention) {
14021416
LLVMTypeKind kind = LLVMGetTypeKind(type);
14031417
GB_ASSERT(kind == LLVMArrayTypeKind || kind == LLVMStructTypeKind);
14041418

14051419
i64 sz = lb_sizeof(type);
14061420
if (sz == 0) {
14071421
return lb_arg_type_ignore(type);
14081422
}
1409-
if (type_can_be_direct(type, calling_convention)) {
1423+
if (type_can_be_direct(type, original_type, calling_convention)) {
14101424
return lb_arg_type_direct(type);
14111425
}
14121426
return lb_arg_type_indirect(type, nullptr);
14131427
}
14141428

1415-
gb_internal lbArgType pseudo_slice(LLVMContextRef c, LLVMTypeRef type, ProcCallingConvention calling_convention) {
1429+
gb_internal lbArgType pseudo_slice(LLVMContextRef c, LLVMTypeRef type, Type *original_type, ProcCallingConvention calling_convention) {
14161430
if (build_context.metrics.ptr_size < build_context.metrics.int_size &&
1417-
type_can_be_direct(type, calling_convention)) {
1431+
type_can_be_direct(type, original_type, calling_convention)) {
14181432
LLVMTypeRef types[2] = {
14191433
LLVMStructGetTypeAtIndex(type, 0),
14201434
// ignore padding
@@ -1423,7 +1437,7 @@ namespace lbAbiWasm {
14231437
LLVMTypeRef new_type = LLVMStructTypeInContext(c, types, gb_count_of(types), false);
14241438
return lb_arg_type_direct(type, new_type, nullptr, nullptr);
14251439
} else {
1426-
return is_struct(c, type, calling_convention);
1440+
return is_struct(c, type, original_type, calling_convention);
14271441
}
14281442
}
14291443

@@ -1444,9 +1458,9 @@ namespace lbAbiWasm {
14441458
LLVMTypeKind kind = LLVMGetTypeKind(t);
14451459
if (kind == LLVMStructTypeKind || kind == LLVMArrayTypeKind) {
14461460
if (is_type_slice(ptype) || is_type_string(ptype)) {
1447-
args[i] = pseudo_slice(c, t, calling_convention);
1461+
args[i] = pseudo_slice(c, t, ptype, calling_convention);
14481462
} else {
1449-
args[i] = is_struct(c, t, calling_convention);
1463+
args[i] = is_struct(c, t, ptype, calling_convention);
14501464
}
14511465
} else {
14521466
args[i] = non_struct(c, t, false);
@@ -1455,11 +1469,11 @@ namespace lbAbiWasm {
14551469
return args;
14561470
}
14571471

1458-
gb_internal LB_ABI_COMPUTE_RETURN_TYPE(compute_return_type) {
1472+
gb_internal lbArgType compute_return_type(lbFunctionType *ft, LLVMContextRef c, LLVMTypeRef return_type, bool return_is_defined, bool return_is_tuple, Type* original_type) {
14591473
if (!return_is_defined) {
14601474
return lb_arg_type_direct(LLVMVoidTypeInContext(c));
14611475
} else if (lb_is_type_kind(return_type, LLVMStructTypeKind) || lb_is_type_kind(return_type, LLVMArrayTypeKind)) {
1462-
if (type_can_be_direct(return_type, ft->calling_convention)) {
1476+
if (type_can_be_direct(return_type, original_type, ft->calling_convention)) {
14631477
return lb_arg_type_direct(return_type);
14641478
} else if (ft->calling_convention != ProcCC_CDecl) {
14651479
i64 sz = lb_sizeof(return_type);
@@ -1471,7 +1485,36 @@ namespace lbAbiWasm {
14711485
}
14721486
}
14731487

1474-
LB_ABI_MODIFY_RETURN_IF_TUPLE_MACRO();
1488+
// Multiple returns.
1489+
if (return_is_tuple) { \
1490+
lbArgType return_arg = {};
1491+
if (lb_is_type_kind(return_type, LLVMStructTypeKind)) {
1492+
unsigned field_count = LLVMCountStructElementTypes(return_type);
1493+
if (field_count > 1) {
1494+
ft->original_arg_count = ft->args.count;
1495+
ft->multiple_return_original_type = return_type;
1496+
1497+
for (unsigned i = 0; i < field_count-1; i++) {
1498+
LLVMTypeRef field_type = LLVMStructGetTypeAtIndex(return_type, i);
1499+
LLVMTypeRef field_pointer_type = LLVMPointerType(field_type, 0);
1500+
lbArgType ret_partial = lb_arg_type_direct(field_pointer_type);
1501+
array_add(&ft->args, ret_partial);
1502+
}
1503+
1504+
LLVMTypeRef new_return_type = LLVMStructGetTypeAtIndex(return_type, field_count-1);
1505+
return_arg = compute_return_type(
1506+
ft,
1507+
c,
1508+
LLVMStructGetTypeAtIndex(return_type, field_count-1),
1509+
true, false,
1510+
type_internal_index(original_type, field_count-1)
1511+
);
1512+
}
1513+
}
1514+
if (return_arg.type != nullptr) {
1515+
return return_arg;
1516+
}
1517+
}
14751518

14761519
LLVMAttributeRef attr = lb_create_enum_attribute_with_type(c, "sret", return_type);
14771520
return lb_arg_type_indirect(return_type, attr);

src/llvm_backend_general.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2206,7 +2206,7 @@ gb_internal LLVMTypeRef lb_type_internal(lbModule *m, Type *type) {
22062206
field_count = 3;
22072207
}
22082208
LLVMTypeRef *fields = gb_alloc_array(permanent_allocator(), LLVMTypeRef, field_count);
2209-
fields[0] = LLVMPointerType(lb_type(m, type->Pointer.elem), 0);
2209+
fields[0] = LLVMPointerType(lb_type(m, type->SoaPointer.elem), 0);
22102210
if (bigger_int) {
22112211
fields[1] = lb_type_padding_filler(m, build_context.ptr_size, build_context.ptr_size);
22122212
fields[2] = LLVMIntTypeInContext(ctx, 8*cast(unsigned)build_context.int_size);

src/types.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4614,6 +4614,82 @@ gb_internal Type *alloc_type_proc_from_types(Type **param_types, unsigned param_
46144614
// return type;
46154615
// }
46164616

4617+
// Index a type that is internally a struct or array.
4618+
gb_internal Type *type_internal_index(Type *t, isize index) {
4619+
Type *bt = base_type(t);
4620+
if (bt == nullptr) {
4621+
return nullptr;
4622+
}
4623+
4624+
switch (bt->kind) {
4625+
case Type_Basic:
4626+
{
4627+
switch (bt->Basic.kind) {
4628+
case Basic_complex32: return t_f16;
4629+
case Basic_complex64: return t_f32;
4630+
case Basic_complex128: return t_f64;
4631+
case Basic_quaternion64: return t_f16;
4632+
case Basic_quaternion128: return t_f32;
4633+
case Basic_quaternion256: return t_f64;
4634+
case Basic_string:
4635+
{
4636+
GB_ASSERT(index == 0 || index == 1);
4637+
return index == 0 ? t_u8_ptr : t_int;
4638+
}
4639+
case Basic_any:
4640+
{
4641+
GB_ASSERT(index == 0 || index == 1);
4642+
return index == 0 ? t_rawptr : t_typeid;
4643+
}
4644+
}
4645+
}
4646+
break;
4647+
4648+
case Type_Array: return bt->Array.elem;
4649+
case Type_EnumeratedArray: return bt->EnumeratedArray.elem;
4650+
case Type_SimdVector: return bt->SimdVector.elem;
4651+
case Type_Slice:
4652+
{
4653+
GB_ASSERT(index == 0 || index == 1);
4654+
return index == 0 ? t_rawptr : t_typeid;
4655+
}
4656+
case Type_DynamicArray:
4657+
{
4658+
switch (index) {
4659+
case 0: return t_rawptr;
4660+
case 1: return t_int;
4661+
case 2: return t_int;
4662+
case 3: return t_allocator;
4663+
default: GB_PANIC("invalid raw dynamic array index");
4664+
};
4665+
}
4666+
case Type_Struct:
4667+
return get_struct_field_type(bt, index);
4668+
case Type_Union:
4669+
if (index < bt->Union.variants.count) {
4670+
return bt->Union.variants[index];
4671+
}
4672+
return union_tag_type(bt);
4673+
case Type_Tuple:
4674+
return bt->Tuple.variables[index]->type;
4675+
case Type_Matrix:
4676+
return bt->Matrix.elem;
4677+
case Type_SoaPointer:
4678+
{
4679+
GB_ASSERT(index == 0 || index == 1);
4680+
return index == 0 ? t_rawptr : t_int;
4681+
}
4682+
case Type_Map:
4683+
return type_internal_index(bt->Map.debug_metadata_type, index);
4684+
case Type_BitField:
4685+
return type_internal_index(bt->BitField.backing_type, index);
4686+
case Type_Generic:
4687+
return type_internal_index(bt->Generic.specialized, index);
4688+
};
4689+
4690+
GB_PANIC("Unhandled type %s", type_to_string(bt));
4691+
};
4692+
46174693
gb_internal gbString write_type_to_string(gbString str, Type *type, bool shorthand=false, bool allow_polymorphic=false) {
46184694
if (type == nullptr) {
46194695
return gb_string_appendc(str, "<no type>");

0 commit comments

Comments
 (0)