Skip to content

Commit 89c3655

Browse files
authored
Merge pull request #13 from BQSKit/residual-fix
Fixes, new gate, and 0.4.1
2 parents 3aaacf7 + 189416b commit 89c3655

File tree

6 files changed

+147
-26
lines changed

6 files changed

+147
-26
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ build-backend = "maturin"
88

99
[project]
1010
name = "bqskitrs"
11-
version = "0.4.0"
11+
version = "0.4.1"
1212
maintainers = [
1313
{name = "Ethan Smith", email = "[email protected]"},
1414
{name = "Ed Younis", email = "[email protected]"},

src/ir/gates/mod.rs

+7
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub enum Gate {
3838
CRX(CRXGate),
3939
CRY(CRYGate),
4040
CRZ(CRZGate),
41+
RZSubGate(RZSubGate),
4142
VariableUnitary(VariableUnitaryGate),
4243
Dynamic(Arc<dyn DynGate + Send + Sync>),
4344
}
@@ -59,6 +60,7 @@ impl Unitary for Gate {
5960
Gate::CRX(_) => 1,
6061
Gate::CRY(_) => 1,
6162
Gate::CRZ(_) => 1,
63+
Gate::RZSubGate(_) => 1,
6264
Gate::VariableUnitary(v) => v.num_params(),
6365
Gate::Dynamic(d) => d.num_params(),
6466
}
@@ -80,6 +82,7 @@ impl Unitary for Gate {
8082
Gate::CRX(x) => x.get_utry(params, const_gates),
8183
Gate::CRY(y) => y.get_utry(params, const_gates),
8284
Gate::CRZ(z) => z.get_utry(params, const_gates),
85+
Gate::RZSubGate(z) => z.get_utry(params, const_gates),
8386
Gate::VariableUnitary(v) => v.get_utry(params, const_gates),
8487
Gate::Dynamic(d) => d.get_utry(params, const_gates),
8588
}
@@ -103,6 +106,7 @@ impl Gradient for Gate {
103106
Gate::CRX(x) => x.get_grad(params, const_gates),
104107
Gate::CRY(y) => y.get_grad(params, const_gates),
105108
Gate::CRZ(z) => z.get_grad(params, const_gates),
109+
Gate::RZSubGate(z) => z.get_grad(params, const_gates),
106110
Gate::VariableUnitary(v) => v.get_grad(params, const_gates),
107111
Gate::Dynamic(d) => d.get_grad(params, const_gates),
108112
}
@@ -128,6 +132,7 @@ impl Gradient for Gate {
128132
Gate::CRX(x) => x.get_utry_and_grad(params, const_gates),
129133
Gate::CRY(y) => y.get_utry_and_grad(params, const_gates),
130134
Gate::CRZ(z) => z.get_utry_and_grad(params, const_gates),
135+
Gate::RZSubGate(z) => z.get_utry_and_grad(params, const_gates),
131136
Gate::VariableUnitary(v) => v.get_utry_and_grad(params, const_gates),
132137
Gate::Dynamic(d) => d.get_utry_and_grad(params, const_gates),
133138
}
@@ -151,6 +156,7 @@ impl Size for Gate {
151156
Gate::CRX(_) => 2,
152157
Gate::CRY(_) => 2,
153158
Gate::CRZ(_) => 2,
159+
Gate::RZSubGate(_) => 1,
154160
Gate::VariableUnitary(v) => v.num_qudits(),
155161
Gate::Dynamic(d) => d.num_qudits(),
156162
}
@@ -174,6 +180,7 @@ impl Optimize for Gate {
174180
Gate::CRX(x) => x.optimize(env_matrix),
175181
Gate::CRY(y) => y.optimize(env_matrix),
176182
Gate::CRZ(z) => z.optimize(env_matrix),
183+
Gate::RZSubGate(z) => z.optimize(env_matrix),
177184
Gate::VariableUnitary(v) => v.optimize(env_matrix),
178185
Gate::Dynamic(d) => d.optimize(env_matrix),
179186
}

src/ir/gates/parameterized/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ mod u2;
1212
mod u3;
1313
mod u8;
1414
mod variable;
15+
mod rzsub;
1516

1617
pub use self::u8::U8Gate;
1718
pub use crx::CRXGate;
@@ -26,4 +27,5 @@ pub use rzz::RZZGate;
2627
pub use u1::U1Gate;
2728
pub use u2::U2Gate;
2829
pub use u3::U3Gate;
30+
pub use rzsub::RZSubGate;
2931
pub use variable::VariableUnitaryGate;

src/ir/gates/parameterized/rzsub.rs

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
use std::f64::consts::PI;
2+
3+
use crate::i;
4+
use crate::ir::gates::utils::{rot_z, rot_z_jac};
5+
use crate::ir::gates::{Gradient, Size};
6+
use crate::ir::gates::{Optimize, Unitary};
7+
8+
use ndarray::{Array2, Array3, ArrayViewMut2};
9+
use ndarray_linalg::c64;
10+
11+
/// Arbitrary Y rotation single qubit gate
12+
#[derive(Copy, Clone, Debug, PartialEq, Default)]
13+
pub struct RZSubGate{
14+
radix: usize,
15+
level1: usize,
16+
level2: usize,
17+
}
18+
19+
impl RZSubGate {
20+
pub fn new(radix: usize, level1: usize, level2: usize) -> Self {
21+
RZSubGate {
22+
radix: radix,
23+
level1: level1,
24+
level2: level2,
25+
}
26+
}
27+
}
28+
29+
impl Unitary for RZSubGate {
30+
fn num_params(&self) -> usize {
31+
1
32+
}
33+
34+
fn get_utry(&self, params: &[f64], _constant_gates: &[Array2<c64>]) -> Array2<c64> {
35+
let pexp = i!(0.5 * params[0]).exp();
36+
let nexp = i!(-0.5 * params[0]).exp();
37+
38+
let mut unitary = Array2::eye(self.radix);
39+
unitary[[self.level1, self.level1]] = nexp;
40+
unitary[[self.level2, self.level2]] = pexp;
41+
unitary
42+
}
43+
}
44+
45+
impl Gradient for RZSubGate {
46+
fn get_grad(&self, params: &[f64], _const_gates: &[Array2<c64>]) -> Array3<c64> {
47+
let dpexp = i!(0.5) * i!(0.5 * params[0]).exp();
48+
let dnexp = i!(-0.5) * i!(-0.5 * params[0]).exp();
49+
50+
let mut grad = Array3::zeros((1, self.radix, self.radix));
51+
grad[[0, self.level1, self.level1]] = dnexp;
52+
grad[[0, self.level2, self.level2]] = dpexp;
53+
grad
54+
}
55+
56+
fn get_utry_and_grad(
57+
&self,
58+
params: &[f64],
59+
_const_gates: &[Array2<c64>],
60+
) -> (Array2<c64>, Array3<c64>) {
61+
let pexp = i!(0.5 * params[0]).exp();
62+
let nexp = i!(-0.5 * params[0]).exp();
63+
let dpexp = i!(0.5) * pexp;
64+
let dnexp = i!(-0.5) * nexp;
65+
66+
let mut unitary = Array2::eye(self.radix);
67+
unitary[[self.level1, self.level1]] = nexp;
68+
unitary[[self.level2, self.level2]] = pexp;
69+
70+
let mut grad = Array3::zeros((1, self.radix, self.radix));
71+
grad[[0, self.level1, self.level1]] = dnexp;
72+
grad[[0, self.level2, self.level2]] = dpexp;
73+
74+
(unitary, grad)
75+
}
76+
}
77+
78+
impl Size for RZSubGate {
79+
fn num_qudits(&self) -> usize {
80+
1
81+
}
82+
}
83+
84+
impl Optimize for RZSubGate {
85+
fn optimize(&self, env_matrix: ArrayViewMut2<c64>) -> Vec<f64> {
86+
unimplemented!()
87+
}
88+
}

src/python/circuit.rs

+43-22
Original file line numberDiff line numberDiff line change
@@ -36,37 +36,58 @@ fn pygate_to_native(pygate: &PyAny, constant_gates: &mut Vec<Array2<c64>>) -> Py
3636
"U2Gate" => Ok(U2Gate::new().into()),
3737
"U3Gate" => Ok(U3Gate::new().into()),
3838
"U8Gate" => Ok(U8Gate::new().into()),
39+
"EmbeddedGate" => {
40+
let egate = pygate.getattr("gate")?;
41+
let egate_cls = egate.getattr("__class__")?;
42+
let egate_dunder_name = egate_cls.getattr("__name__")?;
43+
let egate_name = egate_dunder_name.extract::<&str>()?;
44+
45+
if egate_name == "RZGate" {
46+
let level_maps = pygate.getattr("level_maps")?.extract::<Vec<Vec<usize>>>()?;
47+
let level1 = level_maps[0][0];
48+
let level2 = level_maps[0][1];
49+
let radix = pygate.getattr("dim")?.extract::<usize>()?;
50+
Ok(RZSubGate::new(radix, level1, level2).into())
51+
} else {
52+
extract_dynamic_gate(pygate, constant_gates, name)
53+
// TODO: Generalize
54+
}
55+
},
3956
"VariableUnitaryGate" => {
4057
let size = pygate.getattr("num_qudits")?.extract::<usize>()?;
4158
let radixes = pygate.getattr("radixes")?.extract::<Vec<usize>>()?;
4259
Ok(VariableUnitaryGate::new(size, radixes).into())
43-
}
60+
},
4461
_ => {
45-
if pygate.getattr("num_params")?.extract::<usize>()? == 0 {
46-
let args: Vec<f64> = vec![];
47-
let pyobj = pygate.call_method("get_unitary", (args,), None)?;
48-
let pymat = pyobj.getattr("numpy")?.extract::<&PyArray2<c64>>()?;
49-
let mat = pymat.to_owned_array();
50-
let gate_size = pygate.getattr("num_qudits")?.extract::<usize>()?;
51-
let index = constant_gates.len();
52-
constant_gates.push(mat);
53-
Ok(ConstantGate::new(index, gate_size).into())
54-
} else if pygate.hasattr("get_unitary")?
55-
&& ((pygate.hasattr("get_grad")? && pygate.hasattr("get_unitary_and_grad")?)
56-
|| pygate.hasattr("optimize")?)
57-
{
58-
let dynamic: Arc<dyn DynGate + Send + Sync> = Arc::new(PyGate::new(pygate.into()));
59-
Ok(Gate::Dynamic(dynamic))
60-
} else {
61-
Err(exceptions::PyValueError::new_err(format!(
62-
"Gate {} does not implement the necessary methods for optimization.",
63-
name
64-
)))
65-
}
62+
extract_dynamic_gate(pygate, constant_gates, name)
6663
}
6764
}
6865
}
6966

67+
fn extract_dynamic_gate(pygate: &PyAny, constant_gates: &mut Vec<Array2<c64>>, name: &str) -> Result<Gate, PyErr> {
68+
if pygate.getattr("num_params")?.extract::<usize>()? == 0 {
69+
let args: Vec<f64> = vec![];
70+
let pyobj = pygate.call_method("get_unitary", (args,), None)?;
71+
let pymat = pyobj.getattr("numpy")?.extract::<&PyArray2<c64>>()?;
72+
let mat = pymat.to_owned_array();
73+
let gate_size = pygate.getattr("num_qudits")?.extract::<usize>()?;
74+
let index = constant_gates.len();
75+
constant_gates.push(mat);
76+
Ok(ConstantGate::new(index, gate_size).into())
77+
} else if pygate.hasattr("get_unitary")?
78+
&& ((pygate.hasattr("get_grad")? && pygate.hasattr("get_unitary_and_grad")?)
79+
|| pygate.hasattr("optimize")?)
80+
{
81+
let dynamic: Arc<dyn DynGate + Send + Sync> = Arc::new(PyGate::new(pygate.into()));
82+
Ok(Gate::Dynamic(dynamic))
83+
} else {
84+
Err(exceptions::PyValueError::new_err(format!(
85+
"Gate {} does not implement the necessary methods for optimization.",
86+
name
87+
)))
88+
}
89+
}
90+
7091
impl<'source> FromPyObject<'source> for Circuit {
7192
fn extract(ob: &'source PyAny) -> PyResult<Self> {
7293
let gil = Python::acquire_gil();

src/python/minimizers/residual_fn.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,14 @@ impl ResidualFn for PyResidualFn {
5252
let py = gil.python();
5353
let parameters = PyArray1::from_slice(py, params);
5454
let args = PyTuple::new(py, &[parameters]);
55-
match self.cost_fn.call_method1(py, "get_cost", args) {
55+
match self.cost_fn.call_method1(py, "get_residuals", args) {
5656
Ok(val) => val
5757
.extract::<Vec<f64>>(py)
58-
.expect("Return type of get_cost was not a sequence of floats."),
59-
Err(..) => panic!("Failed to call 'get_cost' on passed ResidualFunction."), // TODO: make a Python exception?
58+
.expect("Return type of get_residuals was not a sequence of floats."),
59+
Err(e) => {
60+
println!("{:?}, {:?}, {:?}", e.get_type(py), e.value(py), e.traceback(py));
61+
panic!("Failed to call 'get_residuals' on passed ResidualFunction."); // TODO: make a Python exception?
62+
},
6063
}
6164
}
6265
}

0 commit comments

Comments
 (0)