diff --git a/src/Simulation/Simulators/ResourcesEstimator/ResourcesEstimatorWithAdditionalPrimitiveOperations.cs b/src/Simulation/Simulators/ResourcesEstimator/ResourcesEstimatorWithAdditionalPrimitiveOperations.cs new file mode 100644 index 00000000000..8b36cf4b21b --- /dev/null +++ b/src/Simulation/Simulators/ResourcesEstimator/ResourcesEstimatorWithAdditionalPrimitiveOperations.cs @@ -0,0 +1,92 @@ +#nullable enable + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.Quantum.Simulation.Core; +using Microsoft.Quantum.Simulation.QCTraceSimulatorRuntime; +using Microsoft.Quantum.Simulation.Simulators; +using Microsoft.Quantum.Simulation.Simulators.QCTraceSimulators; + +namespace Simulator +{ + public class ResourcesEstimatorWithAdditionalPrimitiveOperations : ResourcesEstimator + { + public ResourcesEstimatorWithAdditionalPrimitiveOperations() : this(ResourcesEstimator.RecommendedConfig()) + { + } + + public ResourcesEstimatorWithAdditionalPrimitiveOperations(QCTraceSimulatorConfiguration config) : base(WithoutPrimitiveOperationsCounter(config)) + { + } + + private static QCTraceSimulatorConfiguration WithoutPrimitiveOperationsCounter(QCTraceSimulatorConfiguration config) + { + config.UsePrimitiveOperationsCounter = false; + return config; + } + + protected virtual IDictionary>? AdditionalOperations { get; } + + protected override void InitializeQCTracerCoreListeners(IList listeners) + { + base.InitializeQCTracerCoreListeners(listeners); + + // add custom primitive operations listener + var primitiveOperationsIdToNames = new Dictionary(); + Utils.FillDictionaryForEnumNames(primitiveOperationsIdToNames); + + var operationNameToId = new Dictionary(); + + if (AdditionalOperations != null) + { + foreach (var name in AdditionalOperations.Keys) + { + var id = primitiveOperationsIdToNames.Count; + operationNameToId[name] = id; + primitiveOperationsIdToNames.Add(id, name); + } + } + + var cfg = new PrimitiveOperationsCounterConfiguration { primitiveOperationsNames = primitiveOperationsIdToNames.Values.ToArray() }; + var operationsCounter = new PrimitiveOperationsCounter(cfg); + tCoreConfig.Listeners.Add(operationsCounter); + + if (AdditionalOperations != null) + { + var compare = new AssignableTypeComparer(); + this.OnOperationStart += (callable, data) => { + var unwrapped = callable.UnwrapCallable(); + foreach (var (name, types) in AdditionalOperations) + { + if (types.Contains(unwrapped.GetType(), compare)) + { + var adjName = $"Adjoint{name}"; + + var key = (callable.Variant == OperationFunctor.Adjoint || callable.Variant == OperationFunctor.ControlledAdjoint) && AdditionalOperations.ContainsKey(adjName) + ? adjName + : name; + + operationsCounter.OnPrimitiveOperation(operationNameToId[key], new object[] { }, 0.0); + break; + } + } + }; + } + } + + private class AssignableTypeComparer : IEqualityComparer + { + public bool Equals([AllowNull] Type x, [AllowNull] Type y) + { + return x != null && x.IsAssignableFrom(y); + } + + public int GetHashCode([DisallowNull] Type obj) + { + return obj.GetHashCode(); + } + } + } +} diff --git a/src/Simulation/Simulators/ResourcesEstimator/RuntimeCounter.cs b/src/Simulation/Simulators/ResourcesEstimator/RuntimeCounter.cs new file mode 100644 index 00000000000..060dc9967f8 --- /dev/null +++ b/src/Simulation/Simulators/ResourcesEstimator/RuntimeCounter.cs @@ -0,0 +1,62 @@ +#nullable enable + +using System.Collections.Generic; +using System.Diagnostics; +using Microsoft.Quantum.Simulation.Core; +using Microsoft.Quantum.Simulation.QCTraceSimulatorRuntime; + +namespace Simulator +{ + public class RuntimeCounter : IQCTraceSimulatorListener, ICallGraphStatistics + { + public RuntimeCounter() + { + AddToCallStack(CallGraphEdge.CallGraphRootHashed, OperationFunctor.Body); + stats = new StatisticsCollector( + new [] { "Runtime" }, + StatisticsCollector.DefaultStatistics() + ); + } + + public bool NeedsTracingDataInQubits => false; + + public object? NewTracingData(long qubitId) => null; + + public void OnAllocate(object[] qubitsTraceData) {} + + public void OnRelease(object[] qubitsTraceData) {} + + public void OnBorrow(object[] qubitsTraceData, long newQubitsAllocated) {} + + public void OnReturn(object[] qubitsTraceData, long qubitReleased) {} + + public void OnOperationStart(HashedString name, OperationFunctor variant, object[] qubitsTraceData) + { + AddToCallStack(name, variant); + operationCallStack.Peek().Watch.Start(); + } + + public void OnOperationEnd(object[] returnedQubitsTraceData) + { + var record = operationCallStack.Pop(); + record.Watch.Stop(); + Debug.Assert(operationCallStack.Count != 0, "Operation call stack must never get empty"); + stats.AddSample(new CallGraphEdge(record.OperationName, operationCallStack.Peek().OperationName, record.FunctorSpecialization, operationCallStack.Peek().FunctorSpecialization), new [] { (double)record.Watch.ElapsedMilliseconds }); + } + + public void OnPrimitiveOperation(int id, object[] qubitsTraceData, double primitiveOperationDuration) {} + + public IStatisticCollectorResults Results { get => stats as IStatisticCollectorResults; } + + private record OperationCallRecord(HashedString OperationName, OperationFunctor FunctorSpecialization) + { + public Stopwatch Watch { get; } = new(); + } + + private readonly Stack operationCallStack = new Stack(); + private readonly StatisticsCollector stats; + + private void AddToCallStack(HashedString operationName, OperationFunctor functorSpecialization) => + operationCallStack.Push(new OperationCallRecord(operationName, functorSpecialization)); + } +} diff --git a/src/Simulation/Simulators/ResourcesEstimator/Simulator.cs b/src/Simulation/Simulators/ResourcesEstimator/Simulator.cs new file mode 100755 index 00000000000..19afe36413d --- /dev/null +++ b/src/Simulation/Simulators/ResourcesEstimator/Simulator.cs @@ -0,0 +1,87 @@ +using System; +using System.Collections.Generic; +using System.Data; +using Microsoft.Quantum.Simulation.Simulators; +using Microsoft.Quantum.Simulation.Core; +using Microsoft.Quantum.Simulation.QCTraceSimulatorRuntime; +using Microsoft.Quantum.Simulation.Simulators.QCTraceSimulators; + +// using System.IO; +// using System.Threading.Tasks; + +namespace Simulator +{ + public class AdvancedSimulator : ResourcesEstimatorWithAdditionalPrimitiveOperations + { + // public override Task Run(I args) + // { + // var result = base.Run(args).Result; + // var name = typeof(T).Name; + // File.WriteAllText($"{name}.txt", ToTSV()); + // return Task.Run(() => result); + // } + + protected override IDictionary> AdditionalOperations { get; } = + new Dictionary> { + ["CCZ"] = new [] { typeof(Microsoft.Quantum.Simulation.Simulators.QCTraceSimulators.Circuits.CCZ) }, + ["And"] = new [] { typeof(Microsoft.Quantum.Canon.ApplyAnd), typeof(Microsoft.Quantum.Canon.ApplyLowDepthAnd) }, + ["AdjointAnd"] = Array.Empty() + }; + + protected override void InitializeQCTracerCoreListeners(IList listeners) + { + base.InitializeQCTracerCoreListeners(listeners); + tCoreConfig.Listeners.Add(new RuntimeCounter()); + } + + // CCNOT(a, b, c); + // T(a); + // T(b); + + // Original QDK ResEst. -> 9 Ts + // New QDK ResEst. -> 1 CCZ, 2 Ts + + public override DataTable Data + { + get + { + var data = base.Data; + + var androw = data.Rows.Find("And"); + var adjandrow = data.Rows.Find("AdjointAnd"); + var cczrow = data.Rows.Find("CCZ"); + var trow = data.Rows.Find("T"); + + // Update T count + trow["Sum"] = (double)trow["Sum"] - 4 * (double)androw["Sum"] - 7 * (double)cczrow["Sum"]; + trow["Max"] = (double)trow["Max"] - 4 * (double)androw["Max"] - 7 * (double)cczrow["Max"]; + + // TODO: update CNOT, QubitClifford, and Measure as well + + return data; + } + } + + #region Direct access to counts + public long CNOT => (long)(double)Data!.Rows!.Find("CNOT")![1]; + public long QubitClifford => (long)(double)Data!.Rows!.Find("QubitClifford")![1]; + public long T => (long)(double)Data!.Rows!.Find("T")![1]; + public long Measure => (long)(double)Data!.Rows!.Find("Measure")![1]; + public long QubitCount => (long)(double)Data!.Rows!.Find("QubitCount")![1]; + public long Depth => (long)(double)Data!.Rows!.Find("Depth")![1]; + public long CCZ => (long)(double)Data!.Rows!.Find("CCZ")![1]; + public long And => (long)(double)Data!.Rows!.Find("And")![1]; + public long AdjointAnd => (long)(double)Data!.Rows!.Find("AdjointAnd")![1]; + #endregion + + public override O Execute(I args) + { + var result = base.Execute(args); + Console.WriteLine(""); + Console.WriteLine("---BEGIN TABLE---"); + Console.WriteLine(ToTSV()); + Console.WriteLine("---END TABLE---"); + return result; + } + } +}