1use super::{Channel, OutputBuffer};
4use crate::pauli_string::PauliString;
5use crate::phase::Phase;
6use num_complex::Complex64;
7
/// A Pauli rotation channel whose generator `P` is stored as `W` 64-bit words
/// of X bits and `W` words of Z bits, together with a rotation angle `theta`.
///
/// When applied through its `Channel` implementation to a weighted Pauli
/// string `Q`, the input is passed through unchanged if it commutes with `P`;
/// otherwise it is split into a `cos(theta)` term on `Q` and a `sin(theta)`
/// term on the product of `Q` with `P`.
#[derive(Debug, Clone, PartialEq)]
pub struct PauliRotation<const W: usize> {
    /// Qubit indices reported by `Channel::support`. The rotation arithmetic
    /// itself never reads this field; presumably it should list the qubits on
    /// which the generator acts non-trivially — TODO confirm with callers.
    pub support: Vec<u32>,
    /// X-bit words of the generator Pauli string.
    pub gen_x: [u64; W],
    /// Z-bit words of the generator Pauli string.
    pub gen_z: [u64; W],
    /// Rotation angle in radians used by `apply`; `apply_adjoint` uses `-theta`.
    pub theta: f64,
}
39
40impl<const W: usize> PauliRotation<W> {
41 #[inline]
45 fn apply_with_theta(
46 &self,
47 theta: f64,
48 input_x: &[u64; W],
49 input_z: &[u64; W],
50 coeff: Complex64,
51 out: &mut OutputBuffer<'_, W>,
52 ) {
53 let input = PauliString::<W> {
54 x: *input_x,
55 z: *input_z,
56 };
57 let gen = PauliString::<W> {
58 x: self.gen_x,
59 z: self.gen_z,
60 };
61
62 if input.commutes_with(&gen) {
63 out.push(*input_x, *input_z, coeff);
64 return;
65 }
66
67 let cos_t = theta.cos();
68 let sin_t = theta.sin();
69
70 out.push(*input_x, *input_z, coeff * cos_t);
72
73 let mut prod = input;
77 let phase = prod.mul_assign(&gen);
78 let total_phase = Phase::I + phase;
79 out.push(prod.x, prod.z, total_phase.apply(coeff) * sin_t);
80 }
81}
82
83impl<const W: usize> Channel<W> for PauliRotation<W> {
84 #[inline]
85 fn max_fanout(&self) -> usize {
86 2
87 }
88
89 #[inline]
90 fn support(&self) -> &[u32] {
91 &self.support
92 }
93
94 #[inline]
95 fn apply(
96 &self,
97 input_x: &[u64; W],
98 input_z: &[u64; W],
99 coeff: Complex64,
100 out: &mut OutputBuffer<'_, W>,
101 ) {
102 self.apply_with_theta(self.theta, input_x, input_z, coeff, out);
103 }
104
105 #[inline]
106 fn apply_adjoint(
107 &self,
108 input_x: &[u64; W],
109 input_z: &[u64; W],
110 coeff: Complex64,
111 out: &mut OutputBuffer<'_, W>,
112 ) {
113 self.apply_with_theta(-self.theta, input_x, input_z, coeff, out);
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_PI_2, FRAC_PI_3, FRAC_PI_4, PI};

    const TOL: f64 = 1e-12;

    /// One emitted output term: `(x words, z words, coefficient)`.
    type Term<const W: usize> = ([u64; W], [u64; W], Complex64);

    fn approx_eq(a: Complex64, b: Complex64, tol: f64) -> bool {
        (a - b).norm() <= tol
    }

    /// Builds a rotation about `generator` by `theta` with support `[qubit]`.
    fn rotation<const W: usize>(
        qubit: u32,
        generator: &PauliString<W>,
        theta: f64,
    ) -> PauliRotation<W> {
        PauliRotation {
            support: vec![qubit],
            gen_x: generator.x,
            gen_z: generator.z,
            theta,
        }
    }

    /// Applies `rot` (or its adjoint when `adjoint` is true) to `q` with
    /// coefficient `c`, using fresh buffers sized by `max_fanout`, and returns
    /// the written terms in push order.
    fn run<const W: usize>(
        rot: &PauliRotation<W>,
        q: &PauliString<W>,
        c: Complex64,
        adjoint: bool,
    ) -> Vec<Term<W>> {
        let cap = rot.max_fanout();
        let mut bx = vec![[0u64; W]; cap];
        let mut bz = vec![[0u64; W]; cap];
        let mut bc = vec![Complex64::new(0.0, 0.0); cap];
        let mut len = 0usize;
        {
            let mut buf = OutputBuffer::<W> {
                x: &mut bx,
                z: &mut bz,
                coeff: &mut bc,
                len: &mut len,
            };
            if adjoint {
                rot.apply_adjoint(&q.x, &q.z, c, &mut buf);
            } else {
                rot.apply(&q.x, &q.z, c, &mut buf);
            }
        }
        (0..len).map(|i| (bx[i], bz[i], bc[i])).collect()
    }

    /// Asserts that `term` is the Pauli string `p` with a coefficient within
    /// `TOL` of `expected`.
    fn assert_term<const W: usize>(term: &Term<W>, p: &PauliString<W>, expected: Complex64) {
        assert_eq!(term.0, p.x);
        assert_eq!(term.1, p.z);
        assert!(
            approx_eq(term.2, expected, TOL),
            "coefficient {:?} differs from expected {:?}",
            term.2,
            expected
        );
    }

    /// theta = 0 on an anticommuting input still emits two terms; the sine
    /// term has a zero coefficient.
    #[test]
    fn theta_zero_anticommuting_w1() {
        let q = PauliString::<1>::x(0);
        let rot = rotation(0, &PauliString::<1>::z(0), 0.0);
        let c = Complex64::new(2.0, 3.0);
        let terms = run(&rot, &q, c, false);
        assert_eq!(terms.len(), 2);
        assert_term(&terms[0], &q, c);
        assert_term(&terms[1], &PauliString::<1>::y(0), Complex64::new(0.0, 0.0));
    }

    /// theta = 0 on a commuting input passes it through exactly.
    #[test]
    fn theta_zero_commuting_w1() {
        let q = PauliString::<1>::z(0);
        let rot = rotation(0, &PauliString::<1>::z(0), 0.0);
        let c = Complex64::new(2.0, 3.0);
        assert_eq!(run(&rot, &q, c, false), vec![(q.x, q.z, c)]);
    }

    /// A Z rotation by pi maps X to -X (plus a numerically-zero Y term).
    #[test]
    fn pi_z_flips_x_to_minus_x_w1() {
        let q = PauliString::<1>::x(0);
        let rot = rotation(0, &PauliString::<1>::z(0), PI);
        let terms = run(&rot, &q, Complex64::new(1.0, 0.0), false);
        assert_eq!(terms.len(), 2);
        assert_term(&terms[0], &q, Complex64::new(-1.0, 0.0));
        assert_term(&terms[1], &PauliString::<1>::y(0), Complex64::new(0.0, 0.0));
    }

    /// Inputs commuting with the generator yield exactly one unchanged term.
    #[test]
    fn commuting_case_is_fanout_one_w1() {
        let q = PauliString::<1>::identity();
        let rot = rotation(0, &PauliString::<1>::z(0), FRAC_PI_4);
        let c = Complex64::new(0.5, 0.25);
        assert_eq!(run(&rot, &q, c, false), vec![(q.x, q.z, c)]);
    }

    /// Anticommuting inputs yield a cosine term and a sine term.
    #[test]
    fn anticommuting_case_is_fanout_two_w1() {
        let q = PauliString::<1>::x(0);
        let theta = FRAC_PI_3;
        let rot = rotation(0, &PauliString::<1>::z(0), theta);
        let terms = run(&rot, &q, Complex64::new(1.0, 0.0), false);
        assert_eq!(terms.len(), 2);
        assert_term(&terms[0], &q, Complex64::new(theta.cos(), 0.0));
        assert_term(&terms[1], &PauliString::<1>::y(0), Complex64::new(theta.sin(), 0.0));
    }

    /// An X rotation by pi/2 maps Z to -Y.
    #[test]
    fn pi_over_two_x_rotates_z_to_minus_y_w1() {
        let q = PauliString::<1>::z(0);
        let rot = rotation(0, &PauliString::<1>::x(0), FRAC_PI_2);
        let terms = run(&rot, &q, Complex64::new(1.0, 0.0), false);
        assert_eq!(terms.len(), 2);
        assert_term(&terms[0], &q, Complex64::new(0.0, 0.0));
        assert_term(&terms[1], &PauliString::<1>::y(0), Complex64::new(-1.0, 0.0));
    }

    /// The phase returned by the Pauli product is folded into the coefficient:
    /// a Z rotation by pi/2 maps Y to -X.
    #[test]
    fn phase_from_mul_assign_is_folded_w1() {
        let q = PauliString::<1>::y(0);
        let rot = rotation(0, &PauliString::<1>::z(0), FRAC_PI_2);
        let terms = run(&rot, &q, Complex64::new(1.0, 0.0), false);
        assert_eq!(terms.len(), 2);
        assert_term(&terms[0], &q, Complex64::new(0.0, 0.0));
        assert_term(&terms[1], &PauliString::<1>::x(0), Complex64::new(-1.0, 0.0));
    }

    /// Operators living in different words commute, so the input passes through.
    #[test]
    fn multi_word_disjoint_support_commutes_w2() {
        let q = PauliString::<2>::x(0);
        let rot = rotation(64, &PauliString::<2>::z(64), FRAC_PI_4);
        let c = Complex64::new(1.0, 0.0);
        assert_eq!(run(&rot, &q, c, false), vec![(q.x, q.z, c)]);
    }

    /// Anticommutation in the second word is detected and word 0 stays clear.
    #[test]
    fn multi_word_anticommute_in_word_1_w2() {
        let q = PauliString::<2>::x(64);
        let theta = FRAC_PI_3;
        let rot = rotation(64, &PauliString::<2>::z(64), theta);
        let terms = run(&rot, &q, Complex64::new(1.0, 0.0), false);
        assert_eq!(terms.len(), 2);
        assert_term(&terms[0], &q, Complex64::new(theta.cos(), 0.0));
        assert_term(&terms[1], &PauliString::<2>::y(64), Complex64::new(theta.sin(), 0.0));
        assert_eq!(terms[1].0[0], 0u64);
        assert_eq!(terms[1].1[0], 0u64);
    }

    /// The adjoint uses `-theta`: the cosine term is unchanged and the sine
    /// term flips sign.
    #[test]
    fn adjoint_negates_sine_term_w1() {
        let q = PauliString::<1>::x(0);
        let theta = FRAC_PI_3;
        let rot = rotation(0, &PauliString::<1>::z(0), theta);
        let terms = run(&rot, &q, Complex64::new(1.0, 0.0), true);
        assert_eq!(terms.len(), 2);
        assert_term(&terms[0], &q, Complex64::new(theta.cos(), 0.0));
        assert_term(&terms[1], &PauliString::<1>::y(0), Complex64::new(-theta.sin(), 0.0));
    }

    /// The adjoint also passes commuting inputs through exactly.
    #[test]
    fn adjoint_commuting_is_passthrough_w1() {
        let q = PauliString::<1>::z(0);
        let rot = rotation(0, &PauliString::<1>::z(0), FRAC_PI_3);
        let c = Complex64::new(0.5, -1.5);
        assert_eq!(run(&rot, &q, c, true), vec![(q.x, q.z, c)]);
    }

    /// Applying the adjoint to every term produced by `apply` recovers the
    /// original input (up to floating-point error).
    #[test]
    fn adjoint_undoes_apply_w1() {
        let q = PauliString::<1>::x(0);
        let y = PauliString::<1>::y(0);
        let rot = rotation(0, &PauliString::<1>::z(0), FRAC_PI_3);

        let mut acc: Vec<Term<1>> = Vec::new();
        for (x, z, coeff) in run(&rot, &q, Complex64::new(1.0, 0.0), false) {
            let p = PauliString::<1> { x, z };
            for (ox, oz, oc) in run(&rot, &p, coeff, true) {
                match acc.iter_mut().find(|t| t.0 == ox && t.1 == oz) {
                    Some(t) => t.2 += oc,
                    None => acc.push((ox, oz, oc)),
                }
            }
        }

        let coeff_of = |p: &PauliString<1>| {
            acc.iter()
                .find(|t| t.0 == p.x && t.1 == p.z)
                .map_or(Complex64::new(0.0, 0.0), |t| t.2)
        };
        assert!(approx_eq(coeff_of(&q), Complex64::new(1.0, 0.0), TOL));
        assert!(approx_eq(coeff_of(&y), Complex64::new(0.0, 0.0), TOL));
        assert!(acc
            .iter()
            .all(|t| (t.0 == q.x && t.1 == q.z) || (t.0 == y.x && t.1 == y.z)));
    }

    /// `max_fanout` and `support` report the channel's static metadata.
    #[test]
    fn channel_metadata_w1() {
        let rot = rotation(5, &PauliString::<1>::z(5), FRAC_PI_4);
        assert_eq!(rot.max_fanout(), 2);
        assert_eq!(rot.support(), &[5u32][..]);
    }

    /// Caller-owned buffers can be reused across calls without reallocating.
    #[test]
    fn reuse_buffer_across_calls() {
        let cap = 2;
        let mut bx: Vec<[u64; 1]> = vec![[0u64; 1]; cap];
        let mut bz: Vec<[u64; 1]> = vec![[0u64; 1]; cap];
        let mut bc: Vec<Complex64> = vec![Complex64::new(0.0, 0.0); cap];
        let rot = rotation(0, &PauliString::<1>::z(0), FRAC_PI_3);
        let q = PauliString::<1>::x(0);
        for _ in 0..3 {
            let mut len = 0usize;
            let mut buf = OutputBuffer::<1> {
                x: &mut bx,
                z: &mut bz,
                coeff: &mut bc,
                len: &mut len,
            };
            rot.apply(&q.x, &q.z, Complex64::new(1.0, 0.0), &mut buf);
            assert_eq!(*buf.len, 2);
        }
        assert_eq!(bx.capacity(), cap);
        assert_eq!(bz.capacity(), cap);
        assert_eq!(bc.capacity(), cap);
    }
}