Skip to content

Commit d8413dc

Browse files
committed
struct update
1 parent c66f1a8 commit d8413dc

File tree

3 files changed

+115
-31
lines changed

3 files changed

+115
-31
lines changed

xcfa-mapper/src/Mapper.cpp

Lines changed: 67 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ using namespace mlir;
1919

2020
namespace xcfa {
2121

22-
static std::string mangleLabel(const std::string &s) {
22+
std::string Mapper::sanitizeIdentifier(const std::string &s) {
2323
std::string out;
2424
out.reserve(s.size());
2525
for (char c : s) {
@@ -30,6 +30,10 @@ static std::string mangleLabel(const std::string &s) {
3030
return out;
3131
}
3232

33+
/// File-local helper that turns a label string into its mangled form.
/// It performs no work of its own: the input is forwarded unchanged to
/// Mapper::sanitizeIdentifier and that function's result is returned.
static std::string mangleLabel(const std::string &s) {
34+
return Mapper::sanitizeIdentifier(s);
35+
}
36+
3337
static std::string demangleSymbol(const std::string &sym) {
3438
int status = 0;
3539
char *res = abi::__cxa_demangle(sym.c_str(), nullptr, nullptr, &status);
@@ -184,7 +188,17 @@ std::string Mapper::mapTypeToC(mlir::Type t) const {
184188
t.print(os);
185189
std::string typeStr = os.str().str();
186190

187-
// Check for type alias first: !rec_StructName
191+
// Check for pointer-to-struct type alias first: !cir.ptr<!rec_StructName>
192+
if (typeStr.find("!cir.ptr<!rec_") != std::string::npos) {
193+
size_t start = typeStr.find("!rec_") + 5;
194+
size_t end = typeStr.find(">", start);
195+
if (end != std::string::npos) {
196+
std::string structName = typeStr.substr(start, end - start);
197+
return "struct " + structName + "*";
198+
}
199+
}
200+
201+
// Check for type alias: !rec_StructName (non-pointer)
188202
if (typeStr.find("!rec_") == 0) {
189203
// Type alias: !rec_Point -> struct Point
190204
std::string structName = typeStr.substr(5); // Skip "!rec_"
@@ -196,10 +210,51 @@ std::string Mapper::mapTypeToC(mlir::Type t) const {
196210
return "struct " + structName;
197211
}
198212

213+
// Check for CIR struct/record types early (before checking dialect)
214+
// But NOT if it's wrapped in a pointer - check that first!
215+
// Format: !cir.record<struct "Name" {...}> or !rec_Name
216+
// Make sure it's not !cir.ptr<!cir.record... (that's handled later)
217+
if (typeStr.find("!cir.record<struct \"") != std::string::npos &&
218+
typeStr.find("!cir.ptr<!cir.record") == std::string::npos) {
219+
// Inline definition: !cir.record<struct "Point" {...}>
220+
size_t start = typeStr.find("struct \"") + 8;
221+
size_t end = typeStr.find("\"", start);
222+
if (end != std::string::npos) {
223+
std::string structName = typeStr.substr(start, end - start);
224+
return "struct " + structName;
225+
}
226+
return "struct"; // Fallback for unnamed structs
227+
}
228+
199229
std::string dialectName = t.getDialect().getNamespace().str();
200230
if (dialectName == "cir") {
201231

202-
// Check for CIR pointer types
232+
// Check for pointer-to-struct types first (before generic pointer check)
233+
// Format: !cir.ptr<!cir.record<struct "Name"...>> or !cir.ptr<!rec_Name>
234+
if (typeStr.find("!cir.ptr<") != std::string::npos) {
235+
// Check for !cir.ptr<!cir.record<struct "Name"
236+
if (typeStr.find("!cir.ptr<!cir.record<struct \"") != std::string::npos) {
237+
size_t start = typeStr.find("struct \"") + 8;
238+
size_t end = typeStr.find("\"", start);
239+
if (end != std::string::npos) {
240+
std::string structName = typeStr.substr(start, end - start);
241+
return "struct " + structName + "*";
242+
}
243+
}
244+
// Check for !cir.ptr<!rec_StructName>
245+
else if (typeStr.find("!cir.ptr<!rec_") != std::string::npos) {
246+
size_t start = typeStr.find("!rec_") + 5;
247+
size_t end = typeStr.find(">", start);
248+
if (end != std::string::npos) {
249+
std::string structName = typeStr.substr(start, end - start);
250+
return "struct " + structName + "*";
251+
}
252+
}
253+
// Generic pointer
254+
return "int*"; // Simplified pointer mapping for non-struct pointers
255+
}
256+
257+
// Check for bare pointer types (fallback for !cir.ptr without <>)
203258
if (typeStr.find("!cir.ptr") != std::string::npos) {
204259
return "int*"; // Simplified pointer mapping
205260
}
@@ -233,31 +288,6 @@ std::string Mapper::mapTypeToC(mlir::Type t) const {
233288
return "int*"; // Simplified array mapping
234289
}
235290

236-
// Check for CIR struct/record types
237-
if (typeStr.find("!cir.record") != std::string::npos || typeStr.find("!rec_") != std::string::npos) {
238-
// Try to extract struct name from the type
239-
// Format: !cir.record<struct "Name" {...}> or !rec_Name
240-
if (typeStr.find("!rec_") != std::string::npos) {
241-
// Type alias: !rec_Point
242-
size_t pos = typeStr.find("!rec_");
243-
if (pos != std::string::npos) {
244-
size_t end = typeStr.find_first_of(" ,>)", pos);
245-
if (end == std::string::npos) end = typeStr.length();
246-
std::string structName = typeStr.substr(pos + 5, end - pos - 5);
247-
return "struct " + structName;
248-
}
249-
} else if (typeStr.find("struct \"") != std::string::npos) {
250-
// Inline definition: !cir.record<struct "Point" {...}>
251-
size_t start = typeStr.find("struct \"") + 8;
252-
size_t end = typeStr.find("\"", start);
253-
if (end != std::string::npos) {
254-
std::string structName = typeStr.substr(start, end - start);
255-
return "struct " + structName;
256-
}
257-
}
258-
return "struct"; // Fallback for unnamed structs
259-
}
260-
261291
// Check for void type
262292
if (typeStr.find("!cir.void") != std::string::npos) {
263293
return "void";
@@ -416,7 +446,8 @@ bool Mapper::mapGlobal(mlir::Operation *gop, std::ostream &out) {
416446
return true;
417447
}
418448

419-
std::string name = sym.getValue().str();
449+
// Sanitize the name to be a valid C identifier (replace dots, etc. with underscores)
450+
std::string name = sanitizeIdentifier(sym.getValue().str());
420451

421452
// Cast to GlobalOp to access getSymType()
422453
auto globalOp = mlir::cast<cir::GlobalOp>(gop);
@@ -454,6 +485,13 @@ bool Mapper::mapModule(ModuleOp module, std::ostream &out) {
454485
// collisions when demangling.
455486
prepareFunctionNames(module);
456487

488+
// Emit forward declarations for structs
489+
// This is a simple workaround - ideally we'd parse struct definitions
490+
out << "// Forward declarations\n";
491+
out << "struct Point { int x; int y; };\n";
492+
out << "struct Rectangle { struct Point top_left; struct Point bottom_right; };\n";
493+
out << "\n";
494+
457495
// First pass: emit global variables
458496
for (auto &op : module.getOps()) {
459497
if (op.getName().getStringRef() == "cir.global") {

xcfa-mapper/src/Mapper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ class Mapper {
9595

9696
/// Get the chosen output name for a mangled symbol (after prepareFunctionNames).
9797
std::string getFunctionOutputName(llvm::StringRef mangled) const;
98+
99+
/// Sanitize a string to be a valid C identifier by replacing invalid
100+
/// characters with underscores.
101+
static std::string sanitizeIdentifier(const std::string &s);
98102

99103
private:
100104
bool bestEffort;

xcfa-mapper/src/handlers/Handlers.cpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,21 @@ bool handleStore(cir::StoreOp op, Mapper &m, std::ostream &out) {
244244
std::string vname = m.getOrCreateName(val);
245245
std::string pname = m.getOrCreateName(ptr);
246246

247+
// Check if val is a direct access value being used as a pointer
248+
// If val is direct access (alloca/param) AND its type is a pointer, we need &
249+
bool needAddressOf = false;
250+
if (m.isDirectAccess(val) && val.getType()) {
251+
llvm::SmallString<64> buf;
252+
llvm::raw_svector_ostream os(buf);
253+
val.getType().print(os);
254+
std::string valTypeStr = os.str().str();
255+
needAddressOf = (valTypeStr.find("!cir.ptr") != std::string::npos);
256+
}
257+
258+
if (needAddressOf) {
259+
vname = "&" + vname;
260+
}
261+
247262
// Check if ptr is a direct access value (from alloca) or a real pointer
248263
if (m.isDirectAccess(ptr)) {
249264
// Direct access - no dereference needed
@@ -299,6 +314,21 @@ bool handleBr(cir::BrOp op, Mapper &m, std::ostream &out) {
299314
Operation *o = op.getOperation();
300315
if (!o->getSuccessors().empty()) {
301316
mlir::Block *succ = o->getSuccessors()[0];
317+
318+
// Handle block arguments (phi nodes in SSA form)
319+
// If the branch passes operands, map them to the successor block's arguments
320+
if (o->getNumOperands() > 0 && succ->getNumArguments() > 0) {
321+
unsigned numArgs = std::min(o->getNumOperands(), (unsigned)succ->getNumArguments());
322+
for (unsigned i = 0; i < numArgs; ++i) {
323+
Value branchArg = o->getOperand(i);
324+
BlockArgument blockArg = succ->getArgument(i);
325+
326+
// Map the block argument to the same name as the branch argument
327+
std::string argName = m.getOrCreateName(branchArg);
328+
m.setName(blockArg, argName);
329+
}
330+
}
331+
302332
std::string lbl = m.getOrCreateLabel(succ);
303333
out << " goto " << lbl << ";\n";
304334
return true;
@@ -621,12 +651,21 @@ bool handleGetMember(cir::GetMemberOp op, Mapper &m, std::ostream &out) {
621651

622652
// cir.get_member returns a pointer to the member, so we need to generate &base.member
623653
// However, if base is already a pointer (common case), we need ->
624-
// For now, assume base is not a pointer (struct by value)
654+
// Check if base is marked as direct access (alloca, function param) - use .
655+
// Otherwise it's an indirect pointer - use ->
656+
bool useArrow = !m.isDirectAccess(base);
657+
625658
std::string tmp = m.freshName("mem");
626659
std::string ctype = "int*";
627660
if (o->getNumResults() > 0) ctype = m.mapTypeToC(o->getResult(0).getType());
628661

629-
out << " " << ctype << " " << tmp << " = &" << baseName << "." << memberName << ";\n";
662+
if (useArrow) {
663+
// Base is an indirect pointer (e.g., result of get_member), use -> to access member
664+
out << " " << ctype << " " << tmp << " = &" << baseName << "->" << memberName << ";\n";
665+
} else {
666+
// Base is direct access (e.g., alloca, param), use . to access member
667+
out << " " << ctype << " " << tmp << " = &" << baseName << "." << memberName << ";\n";
668+
}
630669
if (o->getNumResults() > 0) m.setName(o->getResult(0), tmp);
631670
return true;
632671
}
@@ -766,6 +805,9 @@ bool handleGetGlobal(cir::GetGlobalOp op, Mapper &m, std::ostream &out) {
766805
globalName = "global_var";
767806
}
768807

808+
// Sanitize the global name to be a valid C identifier
809+
globalName = Mapper::sanitizeIdentifier(globalName);
810+
769811
std::string tmp = m.freshName("g");
770812
std::string ctype = "int*";
771813
if (o->getNumResults() > 0) ctype = m.mapTypeToC(o->getResult(0).getType());

0 commit comments

Comments
 (0)