@@ -100,8 +100,10 @@ void QUIRToPulsePass::runOnOperation() {
100
100
moduleOp->walk ([&](CallCircuitOp callCircOp) {
101
101
if (isa<CircuitOp>(callCircOp->getParentOp ()))
102
102
return ;
103
+
103
104
auto convertedPulseCallSequenceOp =
104
105
convertCircuitToSequence (callCircOp, mainFunc, moduleOp);
106
+
105
107
if (!callCircOp->use_empty ())
106
108
callCircOp->replaceAllUsesWith (convertedPulseCallSequenceOp);
107
109
callCircOp->erase ();
@@ -229,8 +231,9 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
229
231
auto *newDelayCyclesOp = builder.clone (*quirOp);
230
232
newDelayCyclesOp->moveAfter (callCircuitOp);
231
233
} else
232
- assert (((isa<quir::ConstantOp>(quirOp) or isa<quir::ReturnOp>(quirOp) or
233
- isa<quir::CircuitOp>(quirOp))) &&
234
+ assert (((isa<quir::ConstantOp>(quirOp) ||
235
+ isa<qcs::ParameterLoadOp>(quirOp) ||
236
+ isa<quir::ReturnOp>(quirOp) || isa<quir::CircuitOp>(quirOp))) &&
234
237
" quir op is not allowed in this pass." );
235
238
});
236
239
@@ -251,6 +254,7 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
251
254
convertedPulseSequenceOp,
252
255
convertedPulseSequenceOpArgs);
253
256
convertedPulseCallSequenceOp->moveAfter (callCircuitOp);
257
+
254
258
return convertedPulseCallSequenceOp;
255
259
}
256
260
@@ -286,7 +290,7 @@ void QUIRToPulsePass::processCircuitArgs(
286
290
} else if (argumentType.isa <mlir::quir::QubitType>()) {
287
291
auto *qubitOp = callCircuitOp.getOperand (cnt).getDefiningOp ();
288
292
} else
289
- llvm_unreachable (" unkown circuit argument." );
293
+ llvm_unreachable (" unknown circuit argument." );
290
294
}
291
295
}
292
296
@@ -339,7 +343,7 @@ void QUIRToPulsePass::processPulseCalArgs(
339
343
} else if (argumentType.isa <FloatType>()) {
340
344
assert (argAttr[index].dyn_cast <StringAttr>().getValue ().str () ==
341
345
" angle" &&
342
- " unkown argument." );
346
+ " unknown argument." );
343
347
assert (angleOperands.size () && " no angle operand found." );
344
348
auto nextAngle = angleOperands.front ();
345
349
LLVM_DEBUG (llvm::dbgs () << " angle argument " );
@@ -350,7 +354,7 @@ void QUIRToPulsePass::processPulseCalArgs(
350
354
} else if (argumentType.isa <IntegerType>()) {
351
355
assert (argAttr[index].dyn_cast <StringAttr>().getValue ().str () ==
352
356
" duration" &&
353
- " unkown argument." );
357
+ " unknown argument." );
354
358
assert (durationOperands.size () && " no duration operand found." );
355
359
auto nextDuration = durationOperands.front ();
356
360
LLVM_DEBUG (llvm::dbgs () << " duration argument " );
@@ -359,7 +363,7 @@ void QUIRToPulsePass::processPulseCalArgs(
359
363
pulseCalSequenceArgs, builder);
360
364
durationOperands.pop ();
361
365
} else
362
- llvm_unreachable (" unkown argument type." );
366
+ llvm_unreachable (" unknown argument type." );
363
367
}
364
368
}
365
369
@@ -379,12 +383,13 @@ void QUIRToPulsePass::getQUIROpClassicalOperands(
379
383
}
380
384
381
385
for (auto operand : classicalOperands)
382
- if (operand.getType ().isa <mlir::quir::AngleType>())
386
+ if (operand.getType ().isa <mlir::quir::AngleType>() ||
387
+ operand.getType ().isa <FloatType>())
383
388
angleOperands.push (operand);
384
389
else if (operand.getType ().isa <mlir::quir::DurationType>())
385
390
durationOperands.push (operand);
386
391
else
387
- llvm_unreachable (" unkown operand." );
392
+ llvm_unreachable (" unknown operand." );
388
393
}
389
394
390
395
void QUIRToPulsePass::processMixFrameOpArg (
@@ -463,21 +468,38 @@ void QUIRToPulsePass::processAngleArg(Value nextAngleOperand,
463
468
pulseCalSequenceArgs.push_back (
464
469
convertedPulseSequenceOp
465
470
.getArguments ()[circuitArgToConvertedSequenceArgMap[circNum]]);
466
- } else {
467
- auto angleOp = nextAngleOperand.getDefiningOp <mlir::quir::ConstantOp>();
468
- std::string const angleLocHash =
469
- std::to_string (mlir::hash_value (angleOp->getLoc ()));
470
- if (classicalQUIROpLocToConvertedPulseOpMap.find (angleLocHash) ==
471
+ } else if (auto angleOp =
472
+ nextAngleOperand.getDefiningOp <mlir::quir::ConstantOp>()) {
473
+ auto *op = angleOp.getOperation ();
474
+ if (classicalQUIROpLocToConvertedPulseOpMap.find (op) ==
471
475
classicalQUIROpLocToConvertedPulseOpMap.end ()) {
472
476
double const angleVal =
473
477
angleOp.getAngleValueFromConstant ().convertToDouble ();
474
478
auto f64Angle = entryBuilder.create <mlir::arith::ConstantOp>(
475
479
angleOp.getLoc (), entryBuilder.getFloatAttr (entryBuilder.getF64Type (),
476
480
llvm::APFloat (angleVal)));
477
- classicalQUIROpLocToConvertedPulseOpMap[angleLocHash ] = f64Angle;
481
+ classicalQUIROpLocToConvertedPulseOpMap[op ] = f64Angle;
478
482
}
479
- pulseCalSequenceArgs.push_back (
480
- classicalQUIROpLocToConvertedPulseOpMap[angleLocHash]);
483
+ pulseCalSequenceArgs.push_back (classicalQUIROpLocToConvertedPulseOpMap[op]);
484
+ } else if (auto paramOp =
485
+ nextAngleOperand.getDefiningOp <mlir::qcs::ParameterLoadOp>()) {
486
+ auto *op = paramOp.getOperation ();
487
+ if (classicalQUIROpLocToConvertedPulseOpMap.find (op) ==
488
+ classicalQUIROpLocToConvertedPulseOpMap.end ()) {
489
+
490
+ auto newParam = entryBuilder.create <qcs::ParameterLoadOp>(
491
+ paramOp->getLoc (), entryBuilder.getF64Type (),
492
+ paramOp.getParameterName ());
493
+ if (paramOp->hasAttr (" initialValue" )) {
494
+ auto initAttr = paramOp->getAttr (" initialValue" ).dyn_cast <FloatAttr>();
495
+ if (initAttr)
496
+ newParam->setAttr (" initialValue" , initAttr);
497
+ }
498
+
499
+ classicalQUIROpLocToConvertedPulseOpMap[op] = newParam;
500
+ }
501
+
502
+ pulseCalSequenceArgs.push_back (classicalQUIROpLocToConvertedPulseOpMap[op]);
481
503
}
482
504
}
483
505
@@ -501,25 +523,23 @@ void QUIRToPulsePass::processDurationArg(
501
523
TimeUnits::dt &&
502
524
" this pass only accepts durations with dt unit" );
503
525
504
- if (classicalQUIROpLocToConvertedPulseOpMap.find (durLocHash) ==
526
+ auto *op = durationOp.getOperation ();
527
+ if (classicalQUIROpLocToConvertedPulseOpMap.find (op) ==
505
528
classicalQUIROpLocToConvertedPulseOpMap.end ()) {
506
529
auto dur64 = entryBuilder.create <mlir::arith::ConstantOp>(
507
530
durationOp.getLoc (),
508
531
entryBuilder.getIntegerAttr (entryBuilder.getI64Type (),
509
532
uint64_t (durVal)));
510
- classicalQUIROpLocToConvertedPulseOpMap[durLocHash ] = dur64;
533
+ classicalQUIROpLocToConvertedPulseOpMap[op ] = dur64;
511
534
}
512
- pulseCalSequenceArgs.push_back (
513
- classicalQUIROpLocToConvertedPulseOpMap[durLocHash]);
535
+ pulseCalSequenceArgs.push_back (classicalQUIROpLocToConvertedPulseOpMap[op]);
514
536
}
515
537
}
516
538
517
539
mlir::Value QUIRToPulsePass::convertAngleToF64 (Operation *angleOp,
518
540
mlir::OpBuilder &builder) {
519
541
assert (angleOp && " angle op is null" );
520
- std::string const angleLocHash =
521
- std::to_string (mlir::hash_value (angleOp->getLoc ()));
522
- if (classicalQUIROpLocToConvertedPulseOpMap.find (angleLocHash) ==
542
+ if (classicalQUIROpLocToConvertedPulseOpMap.find (angleOp) ==
523
543
classicalQUIROpLocToConvertedPulseOpMap.end ()) {
524
544
if (auto castOp = dyn_cast<quir::ConstantOp>(angleOp)) {
525
545
double const angleVal =
@@ -528,41 +548,46 @@ mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp,
528
548
castOp->getLoc (),
529
549
builder.getFloatAttr (builder.getF64Type (), llvm::APFloat (angleVal)));
530
550
f64Angle->moveAfter (castOp);
531
- classicalQUIROpLocToConvertedPulseOpMap[angleLocHash ] = f64Angle;
551
+ classicalQUIROpLocToConvertedPulseOpMap[angleOp ] = f64Angle;
532
552
} else if (auto castOp = dyn_cast<qcs::ParameterLoadOp>(angleOp)) {
533
- auto angleCastedOp = builder.create <oq3::CastOp>(
534
- castOp->getLoc (), builder.getF64Type (), castOp.getRes ());
535
- angleCastedOp->moveAfter (castOp);
536
- classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = angleCastedOp;
553
+ // Just convert to an f64 directly
554
+ auto newParam = builder.create <qcs::ParameterLoadOp>(
555
+ angleOp->getLoc (), builder.getF64Type (), castOp.getParameterName ());
556
+ if (castOp->hasAttr (" initialValue" )) {
557
+ auto initAttr = castOp->getAttr (" initialValue" ).dyn_cast <FloatAttr>();
558
+ if (initAttr)
559
+ newParam->setAttr (" initialValue" , initAttr);
560
+ }
561
+ newParam->moveAfter (castOp);
562
+
563
+ classicalQUIROpLocToConvertedPulseOpMap[angleOp] = newParam;
537
564
} else if (auto castOp = dyn_cast<oq3::CastOp>(angleOp)) {
538
565
auto castOpArg = castOp.getArg ();
539
566
if (auto paramCastOp =
540
567
dyn_cast<qcs::ParameterLoadOp>(castOpArg.getDefiningOp ())) {
541
568
auto angleCastedOp = builder.create <oq3::CastOp>(
542
569
paramCastOp->getLoc (), builder.getF64Type (), paramCastOp.getRes ());
543
570
angleCastedOp->moveAfter (paramCastOp);
544
- classicalQUIROpLocToConvertedPulseOpMap[angleLocHash ] = angleCastedOp;
571
+ classicalQUIROpLocToConvertedPulseOpMap[angleOp ] = angleCastedOp;
545
572
} else if (auto constOp =
546
573
dyn_cast<arith::ConstantOp>(castOpArg.getDefiningOp ())) {
547
574
// if cast from float64 then use directly
548
575
assert (constOp.getType () == builder.getF64Type () &&
549
576
" expected angle type to be float 64" );
550
- classicalQUIROpLocToConvertedPulseOpMap[angleLocHash ] = constOp;
577
+ classicalQUIROpLocToConvertedPulseOpMap[angleOp ] = constOp;
551
578
} else
552
579
llvm_unreachable (" castOp arg unknown" );
553
580
} else
554
581
llvm_unreachable (" angleOp unknown" );
555
582
}
556
- return classicalQUIROpLocToConvertedPulseOpMap[angleLocHash ];
583
+ return classicalQUIROpLocToConvertedPulseOpMap[angleOp ];
557
584
}
558
585
559
586
mlir::Value QUIRToPulsePass::convertDurationToI64 (
560
587
mlir::quir::CallCircuitOp &callCircuitOp, Operation *durationOp, uint &cnt,
561
588
mlir::OpBuilder &builder, mlir::func::FuncOp &mainFunc) {
562
589
assert (durationOp && " duration op is null" );
563
- std::string const durLocHash =
564
- std::to_string (mlir::hash_value (durationOp->getLoc ()));
565
- if (classicalQUIROpLocToConvertedPulseOpMap.find (durLocHash) ==
590
+ if (classicalQUIROpLocToConvertedPulseOpMap.find (durationOp) ==
566
591
classicalQUIROpLocToConvertedPulseOpMap.end ()) {
567
592
if (auto castOp = dyn_cast<quir::ConstantOp>(durationOp)) {
568
593
auto durVal =
@@ -575,11 +600,11 @@ mlir::Value QUIRToPulsePass::convertDurationToI64(
575
600
castOp->getLoc (),
576
601
builder.getIntegerAttr (builder.getI64Type (), uint64_t (durVal)));
577
602
I64Dur->moveAfter (castOp);
578
- classicalQUIROpLocToConvertedPulseOpMap[durLocHash ] = I64Dur;
603
+ classicalQUIROpLocToConvertedPulseOpMap[durationOp ] = I64Dur;
579
604
} else
580
- llvm_unreachable (" unkown duration op" );
605
+ llvm_unreachable (" unknown duration op" );
581
606
}
582
- return classicalQUIROpLocToConvertedPulseOpMap[durLocHash ];
607
+ return classicalQUIROpLocToConvertedPulseOpMap[durationOp ];
583
608
}
584
609
585
610
mlir::pulse::Port_CreateOp
0 commit comments