Skip to content

Commit

Permalink
[flang] Fix lowering of host associated cray pointee symbols (#86121)
Browse files Browse the repository at this point in the history
Cray pointee symbols can be host associated from a module or host
procedure while the related cray pointer is not explicitly associated.
This caused the "not yet implemented: lowering symbol to HLFIR" to fire
when lowering a reference to the cray pointee and fetching the cray
pointer.

This patch:
- Ensures cray pointers are always instantiated when instantiating a
cray pointee.
- Fix internal procedure lowering to deal with cray pointee host
association like it does for pointers (the lowering strategy for cray
pointee is to create a pointer that is updated with the cray pointer
value before being fetched).

This should fix the bug reported in
#85420.
  • Loading branch information
jeanPerier authored Mar 22, 2024
1 parent 465ea0b commit de7a50f
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 53 deletions.
6 changes: 3 additions & 3 deletions flang/include/flang/Lower/ConvertVariable.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
fir::FortranVariableFlagsEnum::None,
bool force = false);

/// For the given Cray pointee symbol return the corresponding
/// Cray pointer symbol. Assert if the pointer symbol cannot be found.
Fortran::semantics::SymbolRef getCrayPointer(Fortran::semantics::SymbolRef sym);
/// Given the Fortran type of a Cray pointee, return the fir.box type used to
/// track the cray pointee as Fortran pointer.
mlir::Type getCrayPointeeBoxType(mlir::Type);

} // namespace lower
} // namespace Fortran
Expand Down
3 changes: 3 additions & 0 deletions flang/include/flang/Semantics/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ const Symbol *FindExternallyVisibleObject(
// specific procedure of the same name, return it instead.
const Symbol &BypassGeneric(const Symbol &);

// Given a cray pointee symbol, returns the related cray pointer symbol.
const Symbol &GetCrayPointer(const Symbol &crayPointee);

using SomeExpr = evaluate::Expr<evaluate::SomeType>;

bool ExprHasTypeCategory(
Expand Down
5 changes: 3 additions & 2 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3995,11 +3995,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
sym->Rank() == 0) {
// get the corresponding Cray pointer

auto ptrSym = Fortran::lower::getCrayPointer(*sym);
const Fortran::semantics::Symbol &ptrSym =
Fortran::semantics::GetCrayPointer(*sym);
fir::ExtendedValue ptr =
getSymbolExtendedValue(ptrSym, nullptr);
mlir::Value ptrVal = fir::getBase(ptr);
mlir::Type ptrTy = genType(*ptrSym);
mlir::Type ptrTy = genType(ptrSym);

fir::ExtendedValue pte =
getSymbolExtendedValue(*sym, nullptr);
Expand Down
10 changes: 6 additions & 4 deletions flang/lib/Lower/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,8 @@ class ScalarExprLowering {
addr);
} else if (sym->test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
// get the corresponding Cray pointer
auto ptrSym = Fortran::lower::getCrayPointer(sym);
Fortran::semantics::SymbolRef ptrSym{
Fortran::semantics::GetCrayPointer(sym)};
ExtValue ptr = gen(ptrSym);
mlir::Value ptrVal = fir::getBase(ptr);
mlir::Type ptrTy = converter.genType(*ptrSym);
Expand Down Expand Up @@ -1537,8 +1538,8 @@ class ScalarExprLowering {
auto baseSym = getFirstSym(aref);
if (baseSym.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
// get the corresponding Cray pointer
auto ptrSym = Fortran::lower::getCrayPointer(baseSym);

Fortran::semantics::SymbolRef ptrSym{
Fortran::semantics::GetCrayPointer(baseSym)};
fir::ExtendedValue ptr = gen(ptrSym);
mlir::Value ptrVal = fir::getBase(ptr);
mlir::Type ptrTy = ptrVal.getType();
Expand Down Expand Up @@ -6946,7 +6947,8 @@ class ArrayExprLowering {
ComponentPath &components) {
mlir::Value ptrVal = nullptr;
if (x.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
auto ptrSym = Fortran::lower::getCrayPointer(x);
Fortran::semantics::SymbolRef ptrSym{
Fortran::semantics::GetCrayPointer(x)};
ExtValue ptr = converter.getSymbolExtendedValue(ptrSym);
ptrVal = fir::getBase(ptr);
}
Expand Down
9 changes: 8 additions & 1 deletion flang/lib/Lower/ConvertExprToHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ class HlfirDesignatorBuilder {
// value of the Cray pointer variable.
fir::FirOpBuilder &builder = getBuilder();
fir::FortranVariableOpInterface ptrVar =
gen(Fortran::lower::getCrayPointer(symbolRef));
gen(Fortran::semantics::GetCrayPointer(symbolRef));
mlir::Value ptrAddr = ptrVar.getBase();

// Reinterpret the reference to a Cray pointer so that
Expand All @@ -306,9 +306,16 @@ class HlfirDesignatorBuilder {
}
return *varDef;
}
llvm::errs() << *symbolRef << "\n";
TODO(getLoc(), "lowering symbol to HLFIR");
}

fir::FortranVariableOpInterface
gen(const Fortran::semantics::Symbol &symbol) {
Fortran::evaluate::SymbolRef symref{symbol};
return gen(symref);
}

fir::FortranVariableOpInterface
gen(const Fortran::evaluate::Component &component) {
if (Fortran::semantics::IsAllocatableOrPointer(component.GetLastSymbol()))
Expand Down
48 changes: 19 additions & 29 deletions flang/lib/Lower/ConvertVariable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1554,6 +1554,11 @@ fir::FortranVariableFlagsAttr Fortran::lower::translateSymbolAttributes(
mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym,
fir::FortranVariableFlagsEnum extraFlags) {
fir::FortranVariableFlagsEnum flags = extraFlags;
if (sym.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
// CrayPointee are represented as pointers.
flags = flags | fir::FortranVariableFlagsEnum::pointer;
return fir::FortranVariableFlagsAttr::get(mlirContext, flags);
}
const auto &attrs = sym.attrs();
if (attrs.test(Fortran::semantics::Attr::ALLOCATABLE))
flags = flags | fir::FortranVariableFlagsEnum::allocatable;
Expand Down Expand Up @@ -1615,8 +1620,6 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
(!Fortran::semantics::IsProcedure(sym) ||
Fortran::semantics::IsPointer(sym)) &&
!sym.detailsIf<Fortran::semantics::CommonBlockDetails>()) {
bool isCrayPointee =
sym.test(Fortran::semantics::Symbol::Flag::CrayPointee);
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
const mlir::Location loc = genLocation(converter, sym);
mlir::Value shapeOrShift;
Expand All @@ -1636,31 +1639,21 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
Fortran::lower::translateSymbolCUDADataAttribute(builder.getContext(),
sym);

if (isCrayPointee) {
mlir::Type baseType =
hlfir::getFortranElementOrSequenceType(base.getType());
if (auto seqType = mlir::dyn_cast<fir::SequenceType>(baseType)) {
// The pointer box's sequence type must be with unknown shape.
llvm::SmallVector<int64_t> shape(seqType.getDimension(),
fir::SequenceType::getUnknownExtent());
baseType = fir::SequenceType::get(shape, seqType.getEleTy());
}
fir::BoxType ptrBoxType =
fir::BoxType::get(fir::PointerType::get(baseType));
if (sym.test(Fortran::semantics::Symbol::Flag::CrayPointee)) {
mlir::Type ptrBoxType =
Fortran::lower::getCrayPointeeBoxType(base.getType());
mlir::Value boxAlloc = builder.createTemporary(loc, ptrBoxType);

// Declare a local pointer variable.
attributes = fir::FortranVariableFlagsAttr::get(
builder.getContext(), fir::FortranVariableFlagsEnum::pointer);
auto newBase = builder.create<hlfir::DeclareOp>(
loc, boxAlloc, name, /*shape=*/nullptr, lenParams, attributes);
mlir::Value nullAddr =
builder.createNullConstant(loc, ptrBoxType.getEleTy());
mlir::Value nullAddr = builder.createNullConstant(
loc, llvm::cast<fir::BaseBoxType>(ptrBoxType).getEleTy());

// If the element type is known-length character, then
// EmboxOp does not need the length parameters.
if (auto charType = mlir::dyn_cast<fir::CharacterType>(
fir::unwrapSequenceType(baseType)))
hlfir::getFortranElementType(base.getType())))
if (!charType.hasDynamicLen())
lenParams.clear();

Expand Down Expand Up @@ -2346,16 +2339,13 @@ void Fortran::lower::createRuntimeTypeInfoGlobal(
defineGlobal(converter, var, globalName, linkage);
}

Fortran::semantics::SymbolRef
Fortran::lower::getCrayPointer(Fortran::semantics::SymbolRef sym) {
assert(!sym->GetUltimate().owner().crayPointers().empty() &&
"empty Cray pointer/pointee map");
for (const auto &[pointee, pointer] :
sym->GetUltimate().owner().crayPointers()) {
if (pointee == sym->name()) {
Fortran::semantics::SymbolRef v{pointer.get()};
return v;
}
mlir::Type Fortran::lower::getCrayPointeeBoxType(mlir::Type fortranType) {
mlir::Type baseType = hlfir::getFortranElementOrSequenceType(fortranType);
if (auto seqType = mlir::dyn_cast<fir::SequenceType>(baseType)) {
// The pointer box's sequence type must be with unknown shape.
llvm::SmallVector<int64_t> shape(seqType.getDimension(),
fir::SequenceType::getUnknownExtent());
baseType = fir::SequenceType::get(shape, seqType.getEleTy());
}
llvm_unreachable("corresponding Cray pointer cannot be found");
return fir::BoxType::get(fir::PointerType::get(baseType));
}
9 changes: 7 additions & 2 deletions flang/lib/Lower/HostAssociations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,11 @@ class CapturedAllocatableAndPointer
public:
static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
const Fortran::semantics::Symbol &sym) {
return fir::ReferenceType::get(converter.genType(sym));
mlir::Type baseType = converter.genType(sym);
if (sym.GetUltimate().test(Fortran::semantics::Symbol::Flag::CrayPointee))
return fir::ReferenceType::get(
Fortran::lower::getCrayPointeeBoxType(baseType));
return fir::ReferenceType::get(baseType);
}
static void instantiateHostTuple(const InstantiateHostTuple &args,
Fortran::lower::AbstractConverter &converter,
Expand Down Expand Up @@ -507,7 +511,8 @@ walkCaptureCategories(T visitor, Fortran::lower::AbstractConverter &converter,
if (Fortran::semantics::IsProcedure(sym))
return CapturedProcedure::visit(visitor, converter, sym, ba);
ba.analyze(sym);
if (Fortran::semantics::IsAllocatableOrPointer(sym))
if (Fortran::semantics::IsAllocatableOrPointer(sym) ||
sym.GetUltimate().test(Fortran::semantics::Symbol::Flag::CrayPointee))
return CapturedAllocatableAndPointer::visit(visitor, converter, sym, ba);
if (ba.isArray())
return CapturedArrays::visit(visitor, converter, sym, ba);
Expand Down
9 changes: 9 additions & 0 deletions flang/lib/Lower/PFTBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1594,6 +1594,11 @@ struct SymbolDependenceAnalysis {
if (!s->has<semantics::DerivedTypeDetails>())
depth = std::max(analyze(s) + 1, depth);
}

// Make sure cray pointer is instantiated even if it is not visible.
if (ultimate.test(Fortran::semantics::Symbol::Flag::CrayPointee))
depth = std::max(
analyze(Fortran::semantics::GetCrayPointer(ultimate)) + 1, depth);
adjustSize(depth + 1);
bool global = lower::symbolIsGlobal(sym);
layeredVarList[depth].emplace_back(sym, global, depth);
Expand Down Expand Up @@ -2002,6 +2007,10 @@ struct SymbolVisitor {
}
}
}
// - CrayPointer needs to be available whenever a CrayPointee is used.
if (symbol.GetUltimate().test(
Fortran::semantics::Symbol::Flag::CrayPointee))
visitSymbol(Fortran::semantics::GetCrayPointer(symbol));
}

template <typename A>
Expand Down
12 changes: 12 additions & 0 deletions flang/lib/Semantics/tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,18 @@ const Symbol &BypassGeneric(const Symbol &symbol) {
return symbol;
}

const Symbol &GetCrayPointer(const Symbol &crayPointee) {
const Symbol *found{nullptr};
for (const auto &[pointee, pointer] :
crayPointee.GetUltimate().owner().crayPointers()) {
if (pointee == crayPointee.name()) {
found = &pointer.get();
break;
}
}
return DEREF(found);
}

bool ExprHasTypeCategory(
const SomeExpr &expr, const common::TypeCategory &type) {
auto dynamicType{expr.GetType()};
Expand Down
Loading

0 comments on commit de7a50f

Please sign in to comment.