Skip to content

Commit c66f1a8

Browse files
committed
Bugfixes
1 parent c5fd4ac commit c66f1a8

File tree

3 files changed

+273
-26
lines changed

3 files changed

+273
-26
lines changed

xcfa-mapper/src/Mapper.cpp

Lines changed: 129 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <mlir/IR/Attributes.h>
66
#include <mlir/IR/Types.h>
77

8+
#include <clang/CIR/Dialect/IR/CIRDialect.h>
9+
810
#include <cxxabi.h>
911
#include <cstdlib>
1012

@@ -126,6 +128,14 @@ void Mapper::setName(mlir::Value v, const std::string &name) {
126128
valueNames[v] = name;
127129
}
128130

131+
void Mapper::markAsDirectAccess(mlir::Value v) {
132+
directAccessValues.insert(v);
133+
}
134+
135+
bool Mapper::isDirectAccess(mlir::Value v) const {
136+
return directAccessValues.count(v) > 0;
137+
}
138+
129139
std::string Mapper::mapTypeToC(mlir::Type t) const {
130140
// Handle MLIR built-in integer types
131141
if (auto it = mlir::dyn_cast<mlir::IntegerType>(t)) {
@@ -166,15 +176,28 @@ std::string Mapper::mapTypeToC(mlir::Type t) const {
166176
}
167177

168178
// Handle CIR-specific types (need to check dialect)
179+
// Note: Type aliases like !rec_Point may not have a dialect, so check string first
180+
181+
// Try to extract type name from the type string representation
182+
llvm::SmallString<64> buf;
183+
llvm::raw_svector_ostream os(buf);
184+
t.print(os);
185+
std::string typeStr = os.str().str();
186+
187+
// Check for type alias first: !rec_StructName
188+
if (typeStr.find("!rec_") == 0) {
189+
// Type alias: !rec_Point -> struct Point
190+
std::string structName = typeStr.substr(5); // Skip "!rec_"
191+
// Remove any trailing characters that aren't part of the name
192+
size_t end = structName.find_first_of(" ,>)");
193+
if (end != std::string::npos) {
194+
structName = structName.substr(0, end);
195+
}
196+
return "struct " + structName;
197+
}
198+
169199
std::string dialectName = t.getDialect().getNamespace().str();
170200
if (dialectName == "cir") {
171-
std::string typeName = t.getAsOpaquePointer() ? "" : "";
172-
173-
// Try to extract type name from the type string representation
174-
llvm::SmallString<64> buf;
175-
llvm::raw_svector_ostream os(buf);
176-
t.print(os);
177-
std::string typeStr = os.str().str();
178201

179202
// Check for CIR pointer types
180203
if (typeStr.find("!cir.ptr") != std::string::npos) {
@@ -211,8 +234,28 @@ std::string Mapper::mapTypeToC(mlir::Type t) const {
211234
}
212235

213236
// Check for CIR struct/record types
214-
if (typeStr.find("!cir.record") != std::string::npos) {
215-
return "struct"; // Incomplete, but shows intent
237+
if (typeStr.find("!cir.record") != std::string::npos || typeStr.find("!rec_") != std::string::npos) {
238+
// Try to extract struct name from the type
239+
// Format: !cir.record<struct "Name" {...}> or !rec_Name
240+
if (typeStr.find("!rec_") != std::string::npos) {
241+
// Type alias: !rec_Point
242+
size_t pos = typeStr.find("!rec_");
243+
if (pos != std::string::npos) {
244+
size_t end = typeStr.find_first_of(" ,>)", pos);
245+
if (end == std::string::npos) end = typeStr.length();
246+
std::string structName = typeStr.substr(pos + 5, end - pos - 5);
247+
return "struct " + structName;
248+
}
249+
} else if (typeStr.find("struct \"") != std::string::npos) {
250+
// Inline definition: !cir.record<struct "Point" {...}>
251+
size_t start = typeStr.find("struct \"") + 8;
252+
size_t end = typeStr.find("\"", start);
253+
if (end != std::string::npos) {
254+
std::string structName = typeStr.substr(start, end - start);
255+
return "struct " + structName;
256+
}
257+
}
258+
return "struct"; // Fallback for unnamed structs
216259
}
217260

218261
// Check for void type
@@ -286,7 +329,29 @@ bool Mapper::mapFunc(mlir::Operation *fop, std::ostream &out) {
286329
out << "// function: " << sym.getValue().str() << "\n";
287330
// Use the chosen output name (demangled when unique, otherwise mangled).
288331
std::string outName = getFunctionOutputName(sym.getValue().str());
289-
out << retType << " " << outName << "()";
332+
333+
// Extract function parameters from the entry block arguments
334+
std::string params = "";
335+
if (fop->getNumRegions() > 0 && !fop->getRegion(0).empty()) {
336+
Block &entryBlock = fop->getRegion(0).front();
337+
bool first = true;
338+
for (BlockArgument arg : entryBlock.getArguments()) {
339+
if (!first) params += ", ";
340+
first = false;
341+
342+
// Map parameter type to C type
343+
std::string paramType = mapTypeToC(arg.getType());
344+
345+
// Generate parameter name
346+
std::string paramName = freshName("v");
347+
setName(arg, paramName);
348+
markAsDirectAccess(arg); // Function parameters are direct access like alloca
349+
350+
params += paramType + " " + paramName;
351+
}
352+
}
353+
354+
out << retType << " " << outName << "(" << params << ")";
290355

291356
// If there is no region/body then emit a declaration (prototype).
292357
if (fop->getNumRegions() == 0 || fop->getRegion(0).empty()) {
@@ -339,11 +404,64 @@ bool Mapper::mapFunc(mlir::Operation *fop, std::ostream &out) {
339404
return true;
340405
}
341406

407+
bool Mapper::mapGlobal(mlir::Operation *gop, std::ostream &out) {
408+
// Extract global variable name
409+
auto sym = gop->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
410+
if (!sym) {
411+
if (!isBestEffort()) {
412+
llvm::errs() << "Global op missing symbol name\n";
413+
return false;
414+
}
415+
out << "// Global variable with missing name\n";
416+
return true;
417+
}
418+
419+
std::string name = sym.getValue().str();
420+
421+
// Cast to GlobalOp to access getSymType()
422+
auto globalOp = mlir::cast<cir::GlobalOp>(gop);
423+
mlir::Type symType = globalOp.getSymType();
424+
425+
// Get type string
426+
std::string typeStr;
427+
llvm::raw_string_ostream rso(typeStr);
428+
symType.print(rso);
429+
rso.flush();
430+
431+
std::string ctype = mapTypeToC(symType);
432+
433+
// Check if it's an array type
434+
if (typeStr.find("!cir.array<") != std::string::npos) {
435+
// Extract array size from type string like "!cir.array<!s32i x 5>"
436+
size_t xPos = typeStr.find(" x ");
437+
size_t endPos = typeStr.find(">", xPos);
438+
439+
if (xPos != std::string::npos && endPos != std::string::npos) {
440+
std::string sizeStr = typeStr.substr(xPos + 3, endPos - xPos - 3);
441+
out << ctype << " " << name << "[" << sizeStr << "];\n";
442+
return true;
443+
}
444+
}
445+
446+
// For non-array types
447+
out << ctype << " " << name << ";\n";
448+
return true;
449+
}
450+
342451
bool Mapper::mapModule(ModuleOp module, std::ostream &out) {
343452
// Prepare function names (demangle where possible and unique) before
344453
// emitting any function declarations/definitions so we can avoid name
345454
// collisions when demangling.
346455
prepareFunctionNames(module);
456+
457+
// First pass: emit global variables
458+
for (auto &op : module.getOps()) {
459+
if (op.getName().getStringRef() == "cir.global") {
460+
if (!mapGlobal(&op, out) && !isBestEffort()) return false;
461+
}
462+
}
463+
464+
// Second pass: emit functions
347465
for (auto &op : module.getOps()) {
348466
if (llvm::isa<mlir::ModuleOp>(op)) {
349467
mlir::ModuleOp inner = mlir::cast<mlir::ModuleOp>(op);
@@ -353,9 +471,8 @@ bool Mapper::mapModule(ModuleOp module, std::ostream &out) {
353471

354472
if (op.getName().getStringRef() == "cir.func") {
355473
if (!mapFunc(&op, out) && !isBestEffort()) return false;
356-
} else {
357-
out << "// top-level op: " << op.getName().getStringRef().str() << " -- not mapped yet\n";
358474
}
475+
// Skip cir.global (already processed)
359476
}
360477

361478
return true;

xcfa-mapper/src/Mapper.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <functional>
1111

1212
#include <llvm/ADT/DenseMap.h>
13+
#include <llvm/ADT/DenseSet.h>
1314
#include <llvm/ADT/StringRef.h>
1415

1516
namespace mlir {
@@ -61,6 +62,9 @@ class Mapper {
6162
/// Map a single function. Returns true on success, false on unrecoverable error.
6263
bool mapFunc(mlir::Operation *fop, std::ostream &out);
6364

65+
/// Map a global variable. Returns true on success, false on unrecoverable error.
66+
bool mapGlobal(mlir::Operation *gop, std::ostream &out);
67+
6468
/// Map an MLIR module to a C program written to `out`.
6569
bool mapModule(mlir::ModuleOp module, std::ostream &out);
6670

@@ -71,6 +75,12 @@ class Mapper {
7175
std::string getOrCreateName(mlir::Value v);
7276
/// Force-set the C identifier for a Value.
7377
void setName(mlir::Value v, const std::string &name);
78+
79+
/// Mark a value as being a "direct access" pointer (from alloca).
80+
/// These don't need dereferencing in load/store.
81+
void markAsDirectAccess(mlir::Value v);
82+
/// Check if a value is marked as direct access.
83+
bool isDirectAccess(mlir::Value v) const;
7484

7585
/// Map an MLIR type to a C type string. This provides a central place for
7686
/// converting common MLIR types (e.g., integer widths and floats) into
@@ -91,6 +101,7 @@ class Mapper {
91101
std::unordered_map<std::string, std::unique_ptr<OpHandler>> handlers;
92102
llvm::DenseMap<mlir::Value, std::string> valueNames;
93103
llvm::DenseMap<mlir::Block *, std::string> blockLabels;
104+
llvm::DenseSet<mlir::Value> directAccessValues;
94105
unsigned counter;
95106
// Mapping from original (mangled) symbol -> chosen output name
96107
std::unordered_map<std::string, std::string> functionOutputNames;

0 commit comments

Comments
 (0)