Skip to content

[Backport 3.10]Update parameter/constant op handling for performance (#335) #337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conan/qasm/conandata.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
sources:
hash: "f6d695fd9f18462e65f6290d05ccb4ccb371b288"
hash: "ec7731bf645240a597cd9ebb2c395b114f155ed2"
requirements:
- "gmp/6.3.0"
- "mpfr/4.1.0"
Expand Down
2 changes: 1 addition & 1 deletion conan/qasm/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class QasmConan(ConanFile):
name = "qasm"
version = "0.3.2"
version = "0.3.3"
url = "https://github.com/openqasm/qe-qasm.git"
settings = "os", "compiler", "build_type", "arch"
options = {"shared": [True, False], "examples": [True, False]}
Expand Down
2 changes: 1 addition & 1 deletion conandata.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ requirements:
- pybind11/2.11.1
- clang-tools-extra/17.0.5-0@
- llvm/17.0.5-0@
- qasm/0.3.2@qss/stable
- qasm/0.3.3@qss/stable
2 changes: 1 addition & 1 deletion include/Conversion/QUIRToPulse/QUIRToPulse.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ struct QUIRToPulsePass
mlir::func::FuncOp &mainFunc);
// map of the hashed location of quir angle/duration ops to their converted
// pulse ops
std::unordered_map<std::string, mlir::Value>
std::unordered_map<Operation *, mlir::Value>
classicalQUIROpLocToConvertedPulseOpMap;

// port name to Port_CreateOp map
Expand Down
14 changes: 8 additions & 6 deletions include/Dialect/QUIR/Transforms/ExtractCircuits.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
#include "Utils/SymbolCacheAnalysis.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/SmallVector.h"

#include <set>
#include <unordered_map>

namespace mlir::quir {
Expand All @@ -49,14 +49,14 @@ struct ExtractCircuitsPass
OpBuilder circuitBuilder);
OpBuilder startCircuit(mlir::Location location, OpBuilder topLevelBuilder);
void endCircuit(mlir::Operation *firstOp, mlir::Operation *lastOp,
OpBuilder topLevelBuilder, OpBuilder circuitBuilder,
llvm::SmallVector<Operation *> &eraseList);
void addToCircuit(mlir::Operation *currentOp, OpBuilder circuitBuilder,
llvm::SmallVector<Operation *> &eraseList);
OpBuilder topLevelBuilder, OpBuilder circuitBuilder);
void addToCircuit(mlir::Operation *currentOp, OpBuilder circuitBuilder);

uint64_t circuitCount = 0;
qssc::utils::SymbolCacheAnalysis *symbolCache{nullptr};

mlir::quir::CircuitOp currentCircuitOp = nullptr;
mlir::IRMapping currentCircuitMapper;
mlir::quir::CallCircuitOp newCallCircuitOp;

llvm::SmallVector<Type> inputTypes;
Expand All @@ -68,6 +68,8 @@ struct ExtractCircuitsPass

std::unordered_map<Operation *, uint32_t> circuitOperands;
llvm::SmallVector<OpResult> originalResults;
std::set<Operation *> eraseConstSet;
std::set<Operation *> eraseOpSet;

}; // struct ExtractCircuitsPass
} // namespace mlir::quir
Expand Down
3 changes: 3 additions & 0 deletions include/Frontend/OpenQASM3/QUIRGenQASM3Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,9 @@ class QUIRGenQASM3Visitor : public BaseQASM3Visitor {
mlir::Type getQUIRTypeFromDeclaration(const QASM::ASTDeclarationNode *);

bool enableParametersWarningEmitted = false;

/// Cached dummy value for error handling
mlir::Value voidValue;
};

} // namespace qssc::frontend::openqasm3
Expand Down
2 changes: 1 addition & 1 deletion include/Frontend/OpenQASM3/QUIRVariableBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class QUIRVariableBuilder {

mlir::Value generateParameterLoad(mlir::Location location,
llvm::StringRef variableName,
mlir::Value assignedValue);
double initialValue);

/// Generate code for declaring an array (at the builder's current insertion
/// point).
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/QUIRToPulse/LoadPulseCals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ void LoadPulseCalsPass::loadPulseCals(CallCircuitOp callCircuitOp,
LLVM_DEBUG(llvm::dbgs() << "no pulse cal loading needed for " << op);
assert((!op->hasTrait<mlir::quir::UnitaryOp>() and
!op->hasTrait<mlir::quir::CPTPOp>()) &&
"unkown operation");
"unknown operation");
}
});
}
Expand Down
99 changes: 62 additions & 37 deletions lib/Conversion/QUIRToPulse/QUIRToPulse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ void QUIRToPulsePass::runOnOperation() {
moduleOp->walk([&](CallCircuitOp callCircOp) {
if (isa<CircuitOp>(callCircOp->getParentOp()))
return;

auto convertedPulseCallSequenceOp =
convertCircuitToSequence(callCircOp, mainFunc, moduleOp);

if (!callCircOp->use_empty())
callCircOp->replaceAllUsesWith(convertedPulseCallSequenceOp);
callCircOp->erase();
Expand Down Expand Up @@ -229,8 +231,9 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
auto *newDelayCyclesOp = builder.clone(*quirOp);
newDelayCyclesOp->moveAfter(callCircuitOp);
} else
assert(((isa<quir::ConstantOp>(quirOp) or isa<quir::ReturnOp>(quirOp) or
isa<quir::CircuitOp>(quirOp))) &&
assert(((isa<quir::ConstantOp>(quirOp) ||
isa<qcs::ParameterLoadOp>(quirOp) ||
isa<quir::ReturnOp>(quirOp) || isa<quir::CircuitOp>(quirOp))) &&
"quir op is not allowed in this pass.");
});

Expand All @@ -251,6 +254,7 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
convertedPulseSequenceOp,
convertedPulseSequenceOpArgs);
convertedPulseCallSequenceOp->moveAfter(callCircuitOp);

return convertedPulseCallSequenceOp;
}

Expand Down Expand Up @@ -286,7 +290,7 @@ void QUIRToPulsePass::processCircuitArgs(
} else if (argumentType.isa<mlir::quir::QubitType>()) {
auto *qubitOp = callCircuitOp.getOperand(cnt).getDefiningOp();
} else
llvm_unreachable("unkown circuit argument.");
llvm_unreachable("unknown circuit argument.");
}
}

Expand Down Expand Up @@ -339,7 +343,7 @@ void QUIRToPulsePass::processPulseCalArgs(
} else if (argumentType.isa<FloatType>()) {
assert(argAttr[index].dyn_cast<StringAttr>().getValue().str() ==
"angle" &&
"unkown argument.");
"unknown argument.");
assert(angleOperands.size() && "no angle operand found.");
auto nextAngle = angleOperands.front();
LLVM_DEBUG(llvm::dbgs() << "angle argument ");
Expand All @@ -350,7 +354,7 @@ void QUIRToPulsePass::processPulseCalArgs(
} else if (argumentType.isa<IntegerType>()) {
assert(argAttr[index].dyn_cast<StringAttr>().getValue().str() ==
"duration" &&
"unkown argument.");
"unknown argument.");
assert(durationOperands.size() && "no duration operand found.");
auto nextDuration = durationOperands.front();
LLVM_DEBUG(llvm::dbgs() << "duration argument ");
Expand All @@ -359,7 +363,7 @@ void QUIRToPulsePass::processPulseCalArgs(
pulseCalSequenceArgs, builder);
durationOperands.pop();
} else
llvm_unreachable("unkown argument type.");
llvm_unreachable("unknown argument type.");
}
}

Expand All @@ -379,12 +383,13 @@ void QUIRToPulsePass::getQUIROpClassicalOperands(
}

for (auto operand : classicalOperands)
if (operand.getType().isa<mlir::quir::AngleType>())
if (operand.getType().isa<mlir::quir::AngleType>() ||
operand.getType().isa<FloatType>())
angleOperands.push(operand);
else if (operand.getType().isa<mlir::quir::DurationType>())
durationOperands.push(operand);
else
llvm_unreachable("unkown operand.");
llvm_unreachable("unknown operand.");
}

void QUIRToPulsePass::processMixFrameOpArg(
Expand Down Expand Up @@ -463,21 +468,38 @@ void QUIRToPulsePass::processAngleArg(Value nextAngleOperand,
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp
.getArguments()[circuitArgToConvertedSequenceArgMap[circNum]]);
} else {
auto angleOp = nextAngleOperand.getDefiningOp<mlir::quir::ConstantOp>();
std::string const angleLocHash =
std::to_string(mlir::hash_value(angleOp->getLoc()));
if (classicalQUIROpLocToConvertedPulseOpMap.find(angleLocHash) ==
} else if (auto angleOp =
nextAngleOperand.getDefiningOp<mlir::quir::ConstantOp>()) {
auto *op = angleOp.getOperation();
if (classicalQUIROpLocToConvertedPulseOpMap.find(op) ==
classicalQUIROpLocToConvertedPulseOpMap.end()) {
double const angleVal =
angleOp.getAngleValueFromConstant().convertToDouble();
auto f64Angle = entryBuilder.create<mlir::arith::ConstantOp>(
angleOp.getLoc(), entryBuilder.getFloatAttr(entryBuilder.getF64Type(),
llvm::APFloat(angleVal)));
classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = f64Angle;
classicalQUIROpLocToConvertedPulseOpMap[op] = f64Angle;
}
pulseCalSequenceArgs.push_back(
classicalQUIROpLocToConvertedPulseOpMap[angleLocHash]);
pulseCalSequenceArgs.push_back(classicalQUIROpLocToConvertedPulseOpMap[op]);
} else if (auto paramOp =
nextAngleOperand.getDefiningOp<mlir::qcs::ParameterLoadOp>()) {
auto *op = paramOp.getOperation();
if (classicalQUIROpLocToConvertedPulseOpMap.find(op) ==
classicalQUIROpLocToConvertedPulseOpMap.end()) {

auto newParam = entryBuilder.create<qcs::ParameterLoadOp>(
paramOp->getLoc(), entryBuilder.getF64Type(),
paramOp.getParameterName());
if (paramOp->hasAttr("initialValue")) {
auto initAttr = paramOp->getAttr("initialValue").dyn_cast<FloatAttr>();
if (initAttr)
newParam->setAttr("initialValue", initAttr);
}

classicalQUIROpLocToConvertedPulseOpMap[op] = newParam;
}

pulseCalSequenceArgs.push_back(classicalQUIROpLocToConvertedPulseOpMap[op]);
}
}

Expand All @@ -501,25 +523,23 @@ void QUIRToPulsePass::processDurationArg(
TimeUnits::dt &&
"this pass only accepts durations with dt unit");

if (classicalQUIROpLocToConvertedPulseOpMap.find(durLocHash) ==
auto *op = durationOp.getOperation();
if (classicalQUIROpLocToConvertedPulseOpMap.find(op) ==
classicalQUIROpLocToConvertedPulseOpMap.end()) {
auto dur64 = entryBuilder.create<mlir::arith::ConstantOp>(
durationOp.getLoc(),
entryBuilder.getIntegerAttr(entryBuilder.getI64Type(),
uint64_t(durVal)));
classicalQUIROpLocToConvertedPulseOpMap[durLocHash] = dur64;
classicalQUIROpLocToConvertedPulseOpMap[op] = dur64;
}
pulseCalSequenceArgs.push_back(
classicalQUIROpLocToConvertedPulseOpMap[durLocHash]);
pulseCalSequenceArgs.push_back(classicalQUIROpLocToConvertedPulseOpMap[op]);
}
}

mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp,
mlir::OpBuilder &builder) {
assert(angleOp && "angle op is null");
std::string const angleLocHash =
std::to_string(mlir::hash_value(angleOp->getLoc()));
if (classicalQUIROpLocToConvertedPulseOpMap.find(angleLocHash) ==
if (classicalQUIROpLocToConvertedPulseOpMap.find(angleOp) ==
classicalQUIROpLocToConvertedPulseOpMap.end()) {
if (auto castOp = dyn_cast<quir::ConstantOp>(angleOp)) {
double const angleVal =
Expand All @@ -528,41 +548,46 @@ mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp,
castOp->getLoc(),
builder.getFloatAttr(builder.getF64Type(), llvm::APFloat(angleVal)));
f64Angle->moveAfter(castOp);
classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = f64Angle;
classicalQUIROpLocToConvertedPulseOpMap[angleOp] = f64Angle;
} else if (auto castOp = dyn_cast<qcs::ParameterLoadOp>(angleOp)) {
auto angleCastedOp = builder.create<oq3::CastOp>(
castOp->getLoc(), builder.getF64Type(), castOp.getRes());
angleCastedOp->moveAfter(castOp);
classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = angleCastedOp;
// Just convert to an f64 directly
auto newParam = builder.create<qcs::ParameterLoadOp>(
angleOp->getLoc(), builder.getF64Type(), castOp.getParameterName());
if (castOp->hasAttr("initialValue")) {
auto initAttr = castOp->getAttr("initialValue").dyn_cast<FloatAttr>();
if (initAttr)
newParam->setAttr("initialValue", initAttr);
}
newParam->moveAfter(castOp);

classicalQUIROpLocToConvertedPulseOpMap[angleOp] = newParam;
} else if (auto castOp = dyn_cast<oq3::CastOp>(angleOp)) {
auto castOpArg = castOp.getArg();
if (auto paramCastOp =
dyn_cast<qcs::ParameterLoadOp>(castOpArg.getDefiningOp())) {
auto angleCastedOp = builder.create<oq3::CastOp>(
paramCastOp->getLoc(), builder.getF64Type(), paramCastOp.getRes());
angleCastedOp->moveAfter(paramCastOp);
classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = angleCastedOp;
classicalQUIROpLocToConvertedPulseOpMap[angleOp] = angleCastedOp;
} else if (auto constOp =
dyn_cast<arith::ConstantOp>(castOpArg.getDefiningOp())) {
// if cast from float64 then use directly
assert(constOp.getType() == builder.getF64Type() &&
"expected angle type to be float 64");
classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = constOp;
classicalQUIROpLocToConvertedPulseOpMap[angleOp] = constOp;
} else
llvm_unreachable("castOp arg unknown");
} else
llvm_unreachable("angleOp unknown");
}
return classicalQUIROpLocToConvertedPulseOpMap[angleLocHash];
return classicalQUIROpLocToConvertedPulseOpMap[angleOp];
}

mlir::Value QUIRToPulsePass::convertDurationToI64(
mlir::quir::CallCircuitOp &callCircuitOp, Operation *durationOp, uint &cnt,
mlir::OpBuilder &builder, mlir::func::FuncOp &mainFunc) {
assert(durationOp && "duration op is null");
std::string const durLocHash =
std::to_string(mlir::hash_value(durationOp->getLoc()));
if (classicalQUIROpLocToConvertedPulseOpMap.find(durLocHash) ==
if (classicalQUIROpLocToConvertedPulseOpMap.find(durationOp) ==
classicalQUIROpLocToConvertedPulseOpMap.end()) {
if (auto castOp = dyn_cast<quir::ConstantOp>(durationOp)) {
auto durVal =
Expand All @@ -575,11 +600,11 @@ mlir::Value QUIRToPulsePass::convertDurationToI64(
castOp->getLoc(),
builder.getIntegerAttr(builder.getI64Type(), uint64_t(durVal)));
I64Dur->moveAfter(castOp);
classicalQUIROpLocToConvertedPulseOpMap[durLocHash] = I64Dur;
classicalQUIROpLocToConvertedPulseOpMap[durationOp] = I64Dur;
} else
llvm_unreachable("unkown duration op");
llvm_unreachable("unknown duration op");
}
return classicalQUIROpLocToConvertedPulseOpMap[durLocHash];
return classicalQUIROpLocToConvertedPulseOpMap[durationOp];
}

mlir::pulse::Port_CreateOp
Expand Down
6 changes: 4 additions & 2 deletions lib/Dialect/Pulse/IR/PulseOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "Dialect/Pulse/IR/PulseOps.h"

#include "Dialect/Pulse/IR/PulseTraits.h"
#include "Dialect/QCS/IR/QCSOps.h"
#include "Dialect/QUIR/IR/QUIROps.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -356,8 +357,9 @@ LogicalResult verifyClassical_(SequenceOp op) {
mlir::Operation *classicalOp = nullptr;
WalkResult const result = op->walk([&](Operation *subOp) {
if (isa<mlir::arith::ConstantOp>(subOp) || isa<quir::ConstantOp>(subOp) ||
isa<CallSequenceOp>(subOp) || isa<pulse::ReturnOp>(subOp) ||
isa<SequenceOp>(subOp) || isa<mlir::complex::CreateOp>(subOp) ||
isa<qcs::ParameterLoadOp>(subOp) || isa<CallSequenceOp>(subOp) ||
isa<pulse::ReturnOp>(subOp) || isa<SequenceOp>(subOp) ||
isa<mlir::complex::CreateOp>(subOp) ||
subOp->hasTrait<mlir::pulse::SequenceAllowed>() ||
subOp->hasTrait<mlir::pulse::SequenceRequired>())
return WalkResult::advance();
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Pulse/Transforms/Scheduling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ void QuantumCircuitPulseSchedulingPass::scheduleAlap(
opEnd = quantumCircuitSequenceOpBlock->rend();
opIt != opEnd; ++opIt) {
auto &op = *opIt;

if (auto quantumGateCallSequenceOp =
dyn_cast<mlir::pulse::CallSequenceOp>(op)) {
// find quantum gate SequenceOp
Expand Down
5 changes: 3 additions & 2 deletions lib/Dialect/QUIR/IR/QUIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,9 @@ LogicalResult verifyClassical_(CircuitOp op) {
mlir::Operation *classicalOp = nullptr;
WalkResult const result = op->walk([&](Operation *subOp) {
if (isa<mlir::arith::ConstantOp>(subOp) || isa<quir::ConstantOp>(subOp) ||
isa<CallCircuitOp>(subOp) || isa<quir::ReturnOp>(subOp) ||
isa<CircuitOp>(subOp) || subOp->hasTrait<mlir::quir::UnitaryOp>() ||
isa<qcs::ParameterLoadOp>(subOp) || isa<CallCircuitOp>(subOp) ||
isa<quir::ReturnOp>(subOp) || isa<CircuitOp>(subOp) ||
subOp->hasTrait<mlir::quir::UnitaryOp>() ||
subOp->hasTrait<mlir::quir::CPTPOp>())
return WalkResult::advance();
classicalOp = subOp;
Expand Down
Loading
Loading