@@ -286,6 +286,8 @@ std::string Mapper::mapTypeToC(mlir::Type t) const {
286286 else if (elementTypeStr == " !u32i" ) return " unsigned int*" ;
287287 else if (elementTypeStr == " !s64i" ) return " long long*" ;
288288 else if (elementTypeStr == " !u64i" ) return " unsigned long long*" ;
289+ else if (elementTypeStr.find (" !cir.float" ) != std::string::npos) return " float*" ;
290+ else if (elementTypeStr.find (" !cir.double" ) != std::string::npos) return " double*" ;
289291 else if (elementTypeStr.find (" !cir.int<s, 8>" ) != std::string::npos) return " char*" ;
290292 else if (elementTypeStr.find (" !cir.int<u, 8>" ) != std::string::npos) return " unsigned char*" ;
291293 else if (elementTypeStr.find (" !cir.int<s, 16>" ) != std::string::npos) return " short*" ;
@@ -635,7 +637,8 @@ bool Mapper::mapModule(ModuleOp module, std::ostream &out) {
635637 bool isArray = false ; // whether field is an array
636638 std::vector<std::string> dims; // array dimensions outer->inner
637639 };
638- std::map<std::string, std::vector<FieldInfo>> structFields; // structName -> fields
640+ std::map<std::string, std::vector<FieldInfo>> structFields; // recordName -> fields
641+ std::map<std::string, bool > isUnionContainer; // recordName -> isUnion
639642
640643 // Use static flag to only emit struct definitions once (for top-level module)
641644 static bool structsEmitted = false ;
@@ -650,7 +653,7 @@ bool Mapper::mapModule(ModuleOp module, std::ostream &out) {
650653 }
651654 }
652655
653- // Collect struct/field info from cir.get_member
656+ // Collect struct/union field info from cir.get_member
654657 if (auto gm = llvm::dyn_cast<cir::GetMemberOp>(genericOp)) {
655658 // Base is pointer to struct; extract struct name from base type string
656659 mlir::Type baseType = gm.getOperation ()->getOperand (0 ).getType ();
@@ -665,6 +668,11 @@ bool Mapper::mapModule(ModuleOp module, std::ostream &out) {
665668 size_t start = baseTypeStr.find (" !rec_" ) + 5 ;
666669 size_t end = baseTypeStr.find (" >" , start);
667670 if (end != std::string::npos) structName = baseTypeStr.substr (start, end - start);
671+ } else if (baseTypeStr.find (" !cir.ptr<!cir.record<union " ) != std::string::npos) {
672+ size_t start = baseTypeStr.find (" union \" " ) + 7 ;
673+ size_t end = baseTypeStr.find (" \" " , start);
674+ if (end != std::string::npos) structName = baseTypeStr.substr (start, end - start);
675+ if (!structName.empty ()) isUnionContainer[structName] = true ;
668676 }
669677 if (structName.empty ()) return ; // Not a struct base
670678
@@ -793,6 +801,11 @@ bool Mapper::mapModule(ModuleOp module, std::ostream &out) {
793801 }
794802 info.name = fname;
795803
804+ // Heuristic: if we see a field named 'value' on a record, it's likely a union
805+ if (fname == " value" ) {
806+ isUnionContainer[structName] = true ;
807+ }
808+
796809 // Avoid duplicate entries for the same field name
797810 auto &vec = structFields[structName];
798811 bool exists = false ;
@@ -846,7 +859,10 @@ bool Mapper::mapModule(ModuleOp module, std::ostream &out) {
846859 if (!structsEmitted && !order.empty ()) {
847860 out << " // Struct definitions (auto-parsed)\n " ;
848861 for (auto &sname : order) {
849- out << " struct " << sname << " { " ;
862+ bool isU = false ;
863+ auto itk = isUnionContainer.find (sname);
864+ if (itk != isUnionContainer.end ()) isU = itk->second ;
865+ out << (isU ? " union " : " struct " ) << sname << " { " ;
850866 auto &vec = structFields[sname];
851867 for (size_t i = 0 ; i < vec.size (); ++i) {
852868 const FieldInfo &fi = vec[i];
0 commit comments