1#![allow(unused)]
4
5use num_complex::Complex64;
6use rayon::prelude::*;
7
8use crate::channel::{Channel, OutputBuffer};
9use crate::pauli_sum::PauliSum;
10use crate::truncation::TruncationPolicy;
11
/// Term-count threshold exported by this module (4096 = 2^12).
///
/// NOTE(review): not referenced by any code visible in this module; presumably callers use
/// it to decide when a sum is "small" enough to skip parallel work — confirm at call sites.
pub const SMALL_SUM_THRESHOLD: usize = 1 << 12;
15
16pub fn apply_layer<const W: usize, C, T>(
26 input: &PauliSum<W>,
27 channel: &C,
28 policy: &T,
29) -> PauliSum<W>
30where
31 C: Channel<W> + ?Sized,
32 T: TruncationPolicy<W> + ?Sized,
33{
34 apply_layer_inner(input, channel, policy, false)
35}
36
37pub fn apply_layer_adjoint<const W: usize, C, T>(
41 input: &PauliSum<W>,
42 channel: &C,
43 policy: &T,
44) -> PauliSum<W>
45where
46 C: Channel<W> + ?Sized,
47 T: TruncationPolicy<W> + ?Sized,
48{
49 apply_layer_inner(input, channel, policy, true)
50}
51
52fn apply_layer_inner<const W: usize, C, T>(
53 input: &PauliSum<W>,
54 channel: &C,
55 policy: &T,
56 adjoint: bool,
57) -> PauliSum<W>
58where
59 C: Channel<W> + ?Sized,
60 T: TruncationPolicy<W> + ?Sized,
61{
62 let n_in = input.len();
63 let mf = channel.max_fanout();
64 let cap = n_in * mf;
65 let mut out_x: Vec<[u64; W]> = vec![[0u64; W]; cap];
66 let mut out_z: Vec<[u64; W]> = vec![[0u64; W]; cap];
67 let mut out_coeff: Vec<Complex64> = vec![Complex64::new(0.0, 0.0); cap];
68 let len = if adjoint {
69 scan_phase_adjoint(input, channel, &mut out_x, &mut out_z, &mut out_coeff)
70 } else {
71 scan_phase(input, channel, &mut out_x, &mut out_z, &mut out_coeff)
72 };
73 sort_phase(&mut out_x, &mut out_z, &mut out_coeff, len);
74 let result = merge_phase::<W, T>(&out_x, &out_z, &out_coeff, len, input.num_qubits(), policy);
75 #[cfg(debug_assertions)]
76 result.assert_invariants();
77 result
78}
79
80fn scan_phase_with<const W: usize, C, F>(
102 input: &PauliSum<W>,
103 channel: &C,
104 out_x: &mut [[u64; W]],
105 out_z: &mut [[u64; W]],
106 out_coeff: &mut [Complex64],
107 apply_fn: F,
108) -> usize
109where
110 C: Channel<W> + ?Sized,
111 F: Fn(&C, &[u64; W], &[u64; W], Complex64, &mut OutputBuffer<'_, W>) + Sync,
112{
113 let mf = channel.max_fanout();
114 let n_in = input.len();
115 debug_assert!(out_x.len() >= n_in * mf);
116 debug_assert_eq!(out_x.len(), out_z.len());
117 debug_assert_eq!(out_x.len(), out_coeff.len());
118 if n_in == 0 || mf == 0 {
119 return 0;
120 }
121 let cap = n_in * mf;
122 let lens: Vec<usize> = out_x[..cap]
123 .par_chunks_mut(mf)
124 .zip(out_z[..cap].par_chunks_mut(mf))
125 .zip(out_coeff[..cap].par_chunks_mut(mf))
126 .enumerate()
127 .map(|(i, ((sx, sz), sc))| {
128 let mut local_len = 0usize;
129 let mut buf = OutputBuffer::<W> {
130 x: sx,
131 z: sz,
132 coeff: sc,
133 len: &mut local_len,
134 };
135 apply_fn(channel, &input.x[i], &input.z[i], input.coeff[i], &mut buf);
136 local_len
137 })
138 .collect();
139 compact_in_place(out_x, out_z, out_coeff, &lens, mf)
140}
141
142fn compact_in_place<const W: usize>(
147 out_x: &mut [[u64; W]],
148 out_z: &mut [[u64; W]],
149 out_coeff: &mut [Complex64],
150 lens: &[usize],
151 mf: usize,
152) -> usize {
153 let mut write = 0usize;
154 for (i, &len_i) in lens.iter().enumerate() {
155 if len_i == 0 {
156 continue;
157 }
158 let src = i * mf;
159 debug_assert!(write <= src);
160 if write != src {
161 out_x.copy_within(src..src + len_i, write);
162 out_z.copy_within(src..src + len_i, write);
163 out_coeff.copy_within(src..src + len_i, write);
164 }
165 write += len_i;
166 }
167 write
168}
169
170pub(crate) fn scan_phase<const W: usize, C: Channel<W> + ?Sized>(
172 input: &PauliSum<W>,
173 channel: &C,
174 out_x: &mut [[u64; W]],
175 out_z: &mut [[u64; W]],
176 out_coeff: &mut [Complex64],
177) -> usize {
178 scan_phase_with(
179 input,
180 channel,
181 out_x,
182 out_z,
183 out_coeff,
184 |c, x, z, co, out| c.apply(x, z, co, out),
185 )
186}
187
188pub(crate) fn scan_phase_adjoint<const W: usize, C: Channel<W> + ?Sized>(
190 input: &PauliSum<W>,
191 channel: &C,
192 out_x: &mut [[u64; W]],
193 out_z: &mut [[u64; W]],
194 out_coeff: &mut [Complex64],
195) -> usize {
196 scan_phase_with(
197 input,
198 channel,
199 out_x,
200 out_z,
201 out_coeff,
202 |c, x, z, co, out| c.apply_adjoint(x, z, co, out),
203 )
204}
205
206pub(crate) fn sort_phase<const W: usize>(
216 out_x: &mut [[u64; W]],
217 out_z: &mut [[u64; W]],
218 out_coeff: &mut [Complex64],
219 len: usize,
220) {
221 debug_assert!(out_x.len() >= len);
222 debug_assert_eq!(out_x.len(), out_z.len());
223 debug_assert_eq!(out_x.len(), out_coeff.len());
224 if len < 2 {
225 return;
226 }
227 let mut perm: Vec<usize> = (0..len).collect();
228 perm.sort_by(|&a, &b| {
231 out_x[a]
232 .cmp(&out_x[b])
233 .then_with(|| out_z[a].cmp(&out_z[b]))
234 });
235 let new_x: Vec<[u64; W]> = perm.iter().map(|&i| out_x[i]).collect();
236 let new_z: Vec<[u64; W]> = perm.iter().map(|&i| out_z[i]).collect();
237 let new_c: Vec<Complex64> = perm.iter().map(|&i| out_coeff[i]).collect();
238 out_x[..len].copy_from_slice(&new_x);
239 out_z[..len].copy_from_slice(&new_z);
240 out_coeff[..len].copy_from_slice(&new_c);
241}
242
243const SMALL_MERGE_THRESHOLD: usize = 1024;
246
247type ChunkOutput<const W: usize> = (Vec<[u64; W]>, Vec<[u64; W]>, Vec<Complex64>);
250
251pub(crate) fn merge_phase<const W: usize, T: TruncationPolicy<W> + ?Sized>(
269 sorted_x: &[[u64; W]],
270 sorted_z: &[[u64; W]],
271 sorted_coeff: &[Complex64],
272 len: usize,
273 num_qubits: usize,
274 policy: &T,
275) -> PauliSum<W> {
276 let nchunks = if len < SMALL_MERGE_THRESHOLD {
277 1
278 } else {
279 rayon::current_num_threads().max(1)
280 };
281 merge_phase_with_nchunks::<W, T>(
282 sorted_x,
283 sorted_z,
284 sorted_coeff,
285 len,
286 num_qubits,
287 policy,
288 nchunks,
289 )
290}
291
292pub(crate) fn merge_phase_with_nchunks<const W: usize, T: TruncationPolicy<W> + ?Sized>(
296 sorted_x: &[[u64; W]],
297 sorted_z: &[[u64; W]],
298 sorted_coeff: &[Complex64],
299 len: usize,
300 num_qubits: usize,
301 policy: &T,
302 nchunks: usize,
303) -> PauliSum<W> {
304 debug_assert!(sorted_x.len() >= len);
305 debug_assert_eq!(sorted_x.len(), sorted_z.len());
306 debug_assert_eq!(sorted_x.len(), sorted_coeff.len());
307 if len == 0 {
308 return PauliSum::<W> {
309 x: Vec::new(),
310 z: Vec::new(),
311 coeff: Vec::new(),
312 num_qubits,
313 };
314 }
315 let bounds = align_chunk_boundaries(sorted_x, sorted_z, len, nchunks.max(1));
316 let chunk_results: Vec<ChunkOutput<W>> = bounds
317 .par_iter()
318 .map(|&(start, end)| {
319 merge_chunk::<W, T>(sorted_x, sorted_z, sorted_coeff, start, end, policy)
320 })
321 .collect();
322 let total: usize = chunk_results.iter().map(|(cx, _, _)| cx.len()).sum();
323 let mut x = Vec::with_capacity(total);
324 let mut z = Vec::with_capacity(total);
325 let mut coeff = Vec::with_capacity(total);
326 for (cx, cz, cc) in chunk_results {
327 x.extend(cx);
328 z.extend(cz);
329 coeff.extend(cc);
330 }
331 PauliSum::<W> {
332 x,
333 z,
334 coeff,
335 num_qubits,
336 }
337}
338
339fn merge_chunk<const W: usize, T: TruncationPolicy<W> + ?Sized>(
346 sorted_x: &[[u64; W]],
347 sorted_z: &[[u64; W]],
348 sorted_coeff: &[Complex64],
349 start: usize,
350 end: usize,
351 policy: &T,
352) -> ChunkOutput<W> {
353 let zero = Complex64::new(0.0, 0.0);
354 let mut x: Vec<[u64; W]> = Vec::new();
355 let mut z: Vec<[u64; W]> = Vec::new();
356 let mut coeff: Vec<Complex64> = Vec::new();
357 let mut i = start;
358 while i < end {
359 let key_x = sorted_x[i];
360 let key_z = sorted_z[i];
361 let mut acc = sorted_coeff[i];
362 let mut j = i + 1;
363 while j < end && sorted_x[j] == key_x && sorted_z[j] == key_z {
364 acc += sorted_coeff[j];
365 j += 1;
366 }
367 debug_assert!(
368 i == 0 || (sorted_x[i - 1], sorted_z[i - 1]) <= (key_x, key_z),
369 "merge_chunk: scratch is not sorted at index {}",
370 i,
371 );
372 if acc != zero && policy.keep_term(&key_x, &key_z, acc) {
373 x.push(key_x);
374 z.push(key_z);
375 coeff.push(acc);
376 }
377 i = j;
378 }
379 (x, z, coeff)
380}
381
/// Splits `0..len` into at most `nchunks` contiguous, non-empty half-open ranges whose
/// boundaries never fall inside a run of equal `(x, z)` keys.
///
/// Each interior cut starts at the even split `len * k / nchunks`. It is then pushed
/// forward while the entry at the cut has the same key as the entry before it. Cuts only
/// move forward within a run, so they stay non-decreasing. Removing duplicates therefore
/// leaves strictly increasing boundaries. Returns an empty vector for `len == 0` and
/// `[(0, len)]` when `nchunks <= 1`.
fn align_chunk_boundaries<const W: usize>(
    sorted_x: &[[u64; W]],
    sorted_z: &[[u64; W]],
    len: usize,
    nchunks: usize,
) -> Vec<(usize, usize)> {
    match (len, nchunks) {
        (0, _) => Vec::new(),
        (_, 0) | (_, 1) => vec![(0, len)],
        _ => {
            let continues_run =
                |t: usize| sorted_x[t] == sorted_x[t - 1] && sorted_z[t] == sorted_z[t - 1];
            let interior = (1..nchunks).map(|k| {
                let mut cut = (len * k) / nchunks;
                while cut > 0 && cut < len && continues_run(cut) {
                    cut += 1;
                }
                cut
            });
            let mut cuts: Vec<usize> = std::iter::once(0)
                .chain(interior)
                .chain(std::iter::once(len))
                .collect();
            cuts.dedup();
            cuts.windows(2).map(|pair| (pair[0], pair[1])).collect()
        }
    }
}
413
// Unit tests for the scan / sort / merge phases and the end-to-end `apply_layer` driver.
//
// Compiled only when `debug_assertions` is also on. Presumably this is because
// `PauliSum::assert_invariants`, which many tests call, is itself gated on
// `debug_assertions` (`apply_layer_inner` guards its own call the same way).
// NOTE(review): confirm.
#[cfg(all(test, debug_assertions))]
mod tests {
    use super::*;
    use crate::channel::{Clifford1Q, IdentityChannel, PauliRotation};
    use crate::pauli_string::PauliString;
    use crate::truncation::CoefficientThreshold;

    /// Absolute tolerance for coefficients that come out of trigonometric expressions.
    const TOL: f64 = 1e-12;

    /// Allocates zero-filled `(x, z, coeff)` scratch buffers of length `n`.
    #[allow(clippy::type_complexity)]
    fn alloc_bufs<const W: usize>(n: usize) -> (Vec<[u64; W]>, Vec<[u64; W]>, Vec<Complex64>) {
        (
            vec![[0u64; W]; n],
            vec![[0u64; W]; n],
            vec![Complex64::new(0.0, 0.0); n],
        )
    }

    /// True when `a` and `b` differ by at most `tol` in complex modulus.
    fn approx_eq(a: Complex64, b: Complex64, tol: f64) -> bool {
        (a - b).norm() <= tol
    }

    /// The identity channel copies every input term into the scratch unchanged, in input order.
    #[test]
    fn scan_identity_passes_through_w1() {
        let input = PauliSum::<1>::from_strings(&[
            ("X", Complex64::new(1.0, 0.0)),
            ("Z", Complex64::new(2.0, 0.0)),
        ]);
        let id = IdentityChannel::new();
        let cap = input.len() * <IdentityChannel as Channel<1>>::max_fanout(&id);
        let (mut bx, mut bz, mut bc) = alloc_bufs::<1>(cap);
        let total = scan_phase(&input, &id, &mut bx, &mut bz, &mut bc);
        assert_eq!(total, 2);
        for i in 0..total {
            assert_eq!(bx[i], input.x()[i]);
            assert_eq!(bz[i], input.z()[i]);
            assert_eq!(bc[i], input.coeff()[i]);
        }
    }

    /// Hadamard maps Z -> X and X -> Z, keeping the coefficients.
    /// The expected order (Z's image first) assumes `from_strings` stores terms sorted by
    /// `(x, z)`, so Z (x=0, z=1) precedes X (x=1, z=0). NOTE(review): confirm.
    #[test]
    fn scan_clifford1q_h_w1() {
        let input = PauliSum::<1>::from_strings(&[
            ("Z", Complex64::new(3.0, 0.0)),
            ("X", Complex64::new(5.0, 0.0)),
        ]);
        let h = Clifford1Q::h(0);
        let cap = input.len() * <Clifford1Q as Channel<1>>::max_fanout(&h);
        let (mut bx, mut bz, mut bc) = alloc_bufs::<1>(cap);
        let total = scan_phase(&input, &h, &mut bx, &mut bz, &mut bc);
        assert_eq!(total, 2);
        // Input Z (coeff 3) becomes X.
        assert_eq!(bx[0], PauliString::<1>::x(0).x);
        assert_eq!(bz[0], PauliString::<1>::x(0).z);
        assert_eq!(bc[0], Complex64::new(3.0, 0.0));
        // Input X (coeff 5) becomes Z.
        assert_eq!(bx[1], PauliString::<1>::z(0).x);
        assert_eq!(bz[1], PauliString::<1>::z(0).z);
        assert_eq!(bc[1], Complex64::new(5.0, 0.0));
    }

    /// A Z-generated rotation leaves Z alone (one output) but splits X into cos(theta)·X and
    /// sin(theta)·Y (two outputs). The three outputs must be packed densely at the front of
    /// the scratch, which exercises `compact_in_place` with uneven per-term fan-out.
    #[test]
    fn scan_pauli_rotation_packs_variable_fanout() {
        let p = PauliString::<1>::z(0);
        let theta = std::f64::consts::FRAC_PI_3;
        let rot = PauliRotation::<1> {
            support: vec![0],
            gen_x: p.x,
            gen_z: p.z,
            theta,
        };
        let input = PauliSum::<1>::from_strings(&[
            ("X", Complex64::new(1.0, 0.0)),
            ("Z", Complex64::new(2.0, 0.0)),
        ]);
        let cap = input.len() * <PauliRotation<1> as Channel<1>>::max_fanout(&rot);
        let (mut bx, mut bz, mut bc) = alloc_bufs::<1>(cap);
        let total = scan_phase(&input, &rot, &mut bx, &mut bz, &mut bc);
        assert_eq!(total, 3);
        // Z (first input, assuming sorted storage) passes through unchanged.
        assert_eq!(bx[0], PauliString::<1>::z(0).x);
        assert_eq!(bz[0], PauliString::<1>::z(0).z);
        assert!(approx_eq(bc[0], Complex64::new(2.0, 0.0), TOL));
        // X keeps a cos(theta) component ...
        assert_eq!(bx[1], PauliString::<1>::x(0).x);
        assert_eq!(bz[1], PauliString::<1>::x(0).z);
        assert!(approx_eq(bc[1], Complex64::new(theta.cos(), 0.0), TOL));
        // ... and gains a sin(theta) Y component.
        assert_eq!(bx[2], PauliString::<1>::y(0).x);
        assert_eq!(bz[2], PauliString::<1>::y(0).z);
        assert!(approx_eq(bc[2], Complex64::new(theta.sin(), 0.0), TOL));
    }

    /// Qubit 64 lives in bit 0 of word 1 when W = 2; H on it must turn X_64 into Z_64
    /// without touching word 0.
    #[test]
    fn scan_w2_word_boundary() {
        let input = PauliSum::<2> {
            x: vec![[0u64, 1u64]],
            z: vec![[0u64; 2]],
            coeff: vec![Complex64::new(1.5, 0.0)],
            num_qubits: 65,
        };
        let h = Clifford1Q::h(64);
        let cap = input.len() * <Clifford1Q as Channel<2>>::max_fanout(&h);
        let (mut bx, mut bz, mut bc) = alloc_bufs::<2>(cap);
        let total = scan_phase(&input, &h, &mut bx, &mut bz, &mut bc);
        assert_eq!(total, 1);
        assert_eq!(bx[0], [0u64, 0u64]);
        assert_eq!(bz[0], [0u64, 1u64]);
        assert_eq!(bc[0], Complex64::new(1.5, 0.0));
    }

    /// An empty input needs no scratch at all and produces zero entries.
    #[test]
    fn scan_empty_input() {
        let input = PauliSum::<1>::empty(4);
        let id = IdentityChannel::new();
        let mut bx: Vec<[u64; 1]> = vec![];
        let mut bz: Vec<[u64; 1]> = vec![];
        let mut bc: Vec<Complex64> = vec![];
        let total = scan_phase(&input, &id, &mut bx, &mut bz, &mut bc);
        assert_eq!(total, 0);
    }

    /// `apply_layer` must give bit-identical results in 1-, 2- and 4-thread rayon pools.
    /// With 16 terms the merge stays below `SMALL_MERGE_THRESHOLD` (single chunk), so
    /// this mainly exercises the parallel scan phase.
    #[test]
    fn scan_determinism_across_thread_counts() {
        let labels = ["I", "X", "Y", "Z"];
        // All 16 two-character labels, each with a distinct complex coefficient.
        let mut owned: Vec<(String, Complex64)> = Vec::new();
        let mut k: i64 = 0;
        for a in &labels {
            for b in &labels {
                k += 1;
                owned.push((
                    format!("{}{}", a, b),
                    Complex64::new(k as f64 * 0.13, k as f64 * 0.07),
                ));
            }
        }
        let strings: Vec<(&str, Complex64)> = owned.iter().map(|(s, c)| (s.as_str(), *c)).collect();
        let input = PauliSum::<1>::from_strings(&strings);
        let h = Clifford1Q::h(0);
        let policy = AlwaysKeep;

        // Runs the full layer inside a dedicated pool with `n` threads.
        let run = |n: usize| -> PauliSum<1> {
            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(n)
                .build()
                .unwrap();
            pool.install(|| apply_layer(&input, &h, &policy))
        };
        let r1 = run(1);
        let r2 = run(2);
        let r4 = run(4);

        assert_eq!(r1.x(), r2.x());
        assert_eq!(r1.z(), r2.z());
        assert_eq!(r1.coeff(), r2.coeff());
        assert_eq!(r1.x(), r4.x());
        assert_eq!(r1.z(), r4.z());
        assert_eq!(r1.coeff(), r4.coeff());
    }

    /// Keys sort by `x` first, then `z`; coefficients travel with their keys.
    #[test]
    fn sort_phase_orders_by_lex_key() {
        let mut x: Vec<[u64; 1]> = vec![[1], [0], [0]];
        let mut z: Vec<[u64; 1]> = vec![[0], [0], [1]];
        let mut c: Vec<Complex64> = vec![
            Complex64::new(7.0, 0.0), // key (1, 0)
            Complex64::new(8.0, 0.0), // key (0, 0)
            Complex64::new(9.0, 0.0), // key (0, 1)
        ];
        sort_phase(&mut x, &mut z, &mut c, 3);
        assert_eq!(x[0], [0]);
        assert_eq!(z[0], [0]);
        assert_eq!(c[0], Complex64::new(8.0, 0.0));
        assert_eq!(x[1], [0]);
        assert_eq!(z[1], [1]);
        assert_eq!(c[1], Complex64::new(9.0, 0.0));
        assert_eq!(x[2], [1]);
        assert_eq!(z[2], [0]);
        assert_eq!(c[2], Complex64::new(7.0, 0.0));
    }

    /// Already-sorted input is left exactly as it was.
    #[test]
    fn sort_phase_preserves_already_sorted() {
        let mut x: Vec<[u64; 1]> = vec![[0], [0], [1]];
        let mut z: Vec<[u64; 1]> = vec![[0], [1], [0]];
        let mut c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
        ];
        sort_phase(&mut x, &mut z, &mut c, 3);
        assert_eq!(x, vec![[0u64], [0u64], [1u64]]);
        assert_eq!(z, vec![[0u64], [1u64], [0u64]]);
        assert_eq!(
            c,
            vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(2.0, 0.0),
                Complex64::new(3.0, 0.0),
            ]
        );
    }

    /// Entries with equal keys keep their original relative order (1, 2, 3), which fixes the
    /// summation order used later by the merge phase.
    #[test]
    fn sort_phase_is_stable_on_equal_keys() {
        let mut x: Vec<[u64; 1]> = vec![[0], [1], [0], [0]];
        let mut z: Vec<[u64; 1]> = vec![[1], [0], [1], [1]];
        let mut c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),  // key (0, 1), first
            Complex64::new(99.0, 0.0), // key (1, 0)
            Complex64::new(2.0, 0.0),  // key (0, 1), second
            Complex64::new(3.0, 0.0),  // key (0, 1), third
        ];
        sort_phase(&mut x, &mut z, &mut c, 4);
        assert_eq!(x[0], [0]);
        assert_eq!(z[0], [1]);
        assert_eq!(c[0], Complex64::new(1.0, 0.0));
        assert_eq!(x[1], [0]);
        assert_eq!(z[1], [1]);
        assert_eq!(c[1], Complex64::new(2.0, 0.0));
        assert_eq!(x[2], [0]);
        assert_eq!(z[2], [1]);
        assert_eq!(c[2], Complex64::new(3.0, 0.0));
        assert_eq!(x[3], [1]);
        assert_eq!(z[3], [0]);
        assert_eq!(c[3], Complex64::new(99.0, 0.0));
    }

    /// Arrays compare element-wise from word 0, so `[0, 99] < [1, 0]` even though word 1 of
    /// the first is larger.
    #[test]
    fn sort_phase_w2_cross_word_priority() {
        let mut x: Vec<[u64; 2]> = vec![[1, 0], [0, 99]];
        let mut z: Vec<[u64; 2]> = vec![[0, 0], [0, 0]];
        let mut c: Vec<Complex64> = vec![
            Complex64::new(11.0, 0.0), // x = [1, 0]
            Complex64::new(22.0, 0.0), // x = [0, 99]
        ];
        sort_phase(&mut x, &mut z, &mut c, 2);
        assert_eq!(x[0], [0, 99]);
        assert_eq!(c[0], Complex64::new(22.0, 0.0));
        assert_eq!(x[1], [1, 0]);
        assert_eq!(c[1], Complex64::new(11.0, 0.0));
    }

    /// `len` of 0 or 1 returns early without touching the buffers.
    #[test]
    fn sort_phase_len_lt_2_is_noop() {
        let mut x: Vec<[u64; 1]> = vec![[5]];
        let mut z: Vec<[u64; 1]> = vec![[7]];
        let mut c: Vec<Complex64> = vec![Complex64::new(1.0, 2.0)];
        sort_phase(&mut x, &mut z, &mut c, 1);
        assert_eq!(x[0], [5]);
        assert_eq!(z[0], [7]);
        assert_eq!(c[0], Complex64::new(1.0, 2.0));

        let mut empty_x: Vec<[u64; 1]> = vec![];
        let mut empty_z: Vec<[u64; 1]> = vec![];
        let mut empty_c: Vec<Complex64> = vec![];
        sort_phase(&mut empty_x, &mut empty_z, &mut empty_c, 0);
        assert!(empty_x.is_empty());
    }

    /// Policy that relies entirely on `TruncationPolicy`'s default methods. The name suggests
    /// those defaults keep every term. NOTE(review): confirm against the trait definition.
    struct AlwaysKeep;
    impl<const W: usize> TruncationPolicy<W> for AlwaysKeep {}

    /// `len == 0` yields an empty sum that still carries the requested qubit count.
    #[test]
    fn merge_phase_empty_input() {
        let x: Vec<[u64; 1]> = vec![];
        let z: Vec<[u64; 1]> = vec![];
        let c: Vec<Complex64> = vec![];
        let out = merge_phase::<1, _>(&x, &z, &c, 0, 4, &AlwaysKeep);
        assert!(out.is_empty());
        assert_eq!(out.num_qubits(), 4);
        out.assert_invariants();
    }

    /// With no duplicate keys, every entry is emitted unchanged and in order.
    #[test]
    fn merge_phase_distinct_keys_pass_through() {
        let x: Vec<[u64; 1]> = vec![[0], [0], [1]];
        let z: Vec<[u64; 1]> = vec![[0], [1], [0]];
        let c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
        ];
        let out = merge_phase::<1, _>(&x, &z, &c, 3, 1, &AlwaysKeep);
        assert_eq!(out.len(), 3);
        assert_eq!(out.x(), &[[0u64], [0u64], [1u64]]);
        assert_eq!(out.z(), &[[0u64], [1u64], [0u64]]);
        assert_eq!(
            out.coeff(),
            &[
                Complex64::new(1.0, 0.0),
                Complex64::new(2.0, 0.0),
                Complex64::new(3.0, 0.0),
            ]
        );
        out.assert_invariants();
    }

    /// A run of three equal keys (1 + 2 + 3) collapses into one term with coefficient 6.
    #[test]
    fn merge_phase_combines_adjacent_duplicates() {
        let x: Vec<[u64; 1]> = vec![[0], [0], [0], [1]];
        let z: Vec<[u64; 1]> = vec![[1], [1], [1], [0]];
        let c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(7.0, 0.0),
        ];
        let out = merge_phase::<1, _>(&x, &z, &c, 4, 1, &AlwaysKeep);
        assert_eq!(out.len(), 2);
        assert_eq!(out.x()[0], [0]);
        assert_eq!(out.z()[0], [1]);
        assert_eq!(out.coeff()[0], Complex64::new(6.0, 0.0));
        assert_eq!(out.x()[1], [1]);
        assert_eq!(out.z()[1], [0]);
        assert_eq!(out.coeff()[1], Complex64::new(7.0, 0.0));
        out.assert_invariants();
    }

    /// A run that cancels to exactly zero (1 + -1) is dropped even under `AlwaysKeep`.
    #[test]
    fn merge_phase_drops_exact_zero_runs() {
        let x: Vec<[u64; 1]> = vec![[0], [0], [1]];
        let z: Vec<[u64; 1]> = vec![[1], [1], [0]];
        let c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(5.0, 0.0),
        ];
        let out = merge_phase::<1, _>(&x, &z, &c, 3, 1, &AlwaysKeep);
        assert_eq!(out.len(), 1);
        assert_eq!(out.x()[0], [1]);
        assert_eq!(out.z()[0], [0]);
        assert_eq!(out.coeff()[0], Complex64::new(5.0, 0.0));
        out.assert_invariants();
    }

    /// `CoefficientThreshold(1e-6)` rejects a term whose coefficient is 1e-9.
    #[test]
    fn merge_phase_drops_below_threshold() {
        let x: Vec<[u64; 1]> = vec![[0], [1]];
        let z: Vec<[u64; 1]> = vec![[1], [0]];
        let c: Vec<Complex64> = vec![Complex64::new(0.5, 0.0), Complex64::new(1e-9, 0.0)];
        let out = merge_phase::<1, _>(&x, &z, &c, 2, 1, &CoefficientThreshold(1e-6));
        assert_eq!(out.len(), 1);
        assert_eq!(out.x()[0], [0]);
        assert_eq!(out.z()[0], [1]);
        assert!(approx_eq(out.coeff()[0], Complex64::new(0.5, 0.0), TOL));
        out.assert_invariants();
    }

    /// The threshold sees the run's sum (about 1e-7), not the individual entries (about 0.5
    /// each), so the whole run is dropped.
    #[test]
    fn merge_phase_threshold_applied_after_summation() {
        let x: Vec<[u64; 1]> = vec![[0], [0]];
        let z: Vec<[u64; 1]> = vec![[1], [1]];
        let c: Vec<Complex64> = vec![Complex64::new(0.5, 0.0), Complex64::new(-0.4999999, 0.0)];
        let out = merge_phase::<1, _>(&x, &z, &c, 2, 1, &CoefficientThreshold(1e-6));
        assert_eq!(out.len(), 0);
        out.assert_invariants();
    }

    /// A zero threshold keeps every non-zero term.
    #[test]
    fn merge_phase_zero_threshold_keeps_everything() {
        let x: Vec<[u64; 1]> = vec![[0], [0], [1]];
        let z: Vec<[u64; 1]> = vec![[0], [1], [0]];
        let c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
        ];
        let out = merge_phase::<1, _>(&x, &z, &c, 3, 1, &CoefficientThreshold(0.0));
        assert_eq!(out.len(), 3);
        out.assert_invariants();
    }

    /// Only the prefix `0..len` is merged; the `99` junk past `len` must not leak out.
    #[test]
    fn merge_phase_ignores_trailing_junk() {
        let x: Vec<[u64; 1]> = vec![[0], [99], [99]];
        let z: Vec<[u64; 1]> = vec![[1], [99], [99]];
        let c: Vec<Complex64> = vec![
            Complex64::new(2.0, 0.0),
            Complex64::new(99.0, 0.0),
            Complex64::new(99.0, 0.0),
        ];
        let out = merge_phase::<1, _>(&x, &z, &c, 1, 1, &AlwaysKeep);
        assert_eq!(out.len(), 1);
        assert_eq!(out.x()[0], [0]);
        assert_eq!(out.z()[0], [1]);
        assert_eq!(out.coeff()[0], Complex64::new(2.0, 0.0));
    }

    /// With 2 chunks the even split lands at index 4, inside the (0, 1) run at 2..7. The
    /// boundary must be moved to 7 so the run is summed once (5 × 0.5 = 2.5) rather than
    /// emitted twice.
    #[test]
    fn merge_phase_run_spans_chunk_boundary() {
        let x: Vec<[u64; 1]> = vec![[0], [0], [0], [0], [0], [0], [0], [1], [1]];
        let z: Vec<[u64; 1]> = vec![[0], [0], [1], [1], [1], [1], [1], [0], [0]];
        let c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.5, 0.0),
            Complex64::new(0.5, 0.0),
            Complex64::new(0.5, 0.0),
            Complex64::new(0.5, 0.0),
            Complex64::new(0.5, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(2.0, 0.0),
        ];
        let out = merge_phase_with_nchunks::<1, _>(&x, &z, &c, 9, 1, &AlwaysKeep, 2);
        assert_eq!(out.len(), 3);
        assert_eq!(out.x()[0], [0]);
        assert_eq!(out.z()[0], [0]);
        assert_eq!(out.coeff()[0], Complex64::new(2.0, 0.0));
        assert_eq!(out.x()[1], [0]);
        assert_eq!(out.z()[1], [1]);
        assert_eq!(out.coeff()[1], Complex64::new(2.5, 0.0));
        assert_eq!(out.x()[2], [1]);
        assert_eq!(out.z()[2], [0]);
        assert_eq!(out.coeff()[2], Complex64::new(4.0, 0.0));
        out.assert_invariants();
    }

    /// The even split (index 3) already falls on a run start, so the boundary stays put.
    #[test]
    fn merge_phase_aligned_boundary_no_shift() {
        let x: Vec<[u64; 1]> = vec![[0], [0], [0], [1], [1], [1]];
        let z: Vec<[u64; 1]> = vec![[0], [0], [0], [0], [0], [0]];
        let c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, 0.0),
            Complex64::new(5.0, 0.0),
            Complex64::new(6.0, 0.0),
        ];
        let out = merge_phase_with_nchunks::<1, _>(&x, &z, &c, 6, 1, &AlwaysKeep, 2);
        assert_eq!(out.len(), 2);
        assert_eq!(out.x()[0], [0]);
        assert_eq!(out.z()[0], [0]);
        assert_eq!(out.coeff()[0], Complex64::new(6.0, 0.0));
        assert_eq!(out.x()[1], [1]);
        assert_eq!(out.z()[1], [0]);
        assert_eq!(out.coeff()[1], Complex64::new(15.0, 0.0));
        out.assert_invariants();
    }

    /// When every entry shares one key, all interior boundaries slide to `len` and collapse,
    /// leaving a single chunk.
    #[test]
    fn merge_phase_all_same_key_with_nchunks() {
        let x: Vec<[u64; 1]> = vec![[0], [0], [0], [0], [0]];
        let z: Vec<[u64; 1]> = vec![[1], [1], [1], [1], [1]];
        let c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(1.0, 0.0),
        ];
        let out = merge_phase_with_nchunks::<1, _>(&x, &z, &c, 5, 1, &AlwaysKeep, 4);
        assert_eq!(out.len(), 1);
        assert_eq!(out.x()[0], [0]);
        assert_eq!(out.z()[0], [1]);
        assert_eq!(out.coeff()[0], Complex64::new(5.0, 0.0));
        out.assert_invariants();
    }
}