// paulistrings/engine/sort_merge.rs

1//! Sort-merge propagation pipeline: scan → bucket → merge. See §5.
2
3#![allow(unused)]
4
5use num_complex::Complex64;
6use rayon::prelude::*;
7
8use crate::channel::{Channel, OutputBuffer};
9use crate::pauli_sum::PauliSum;
10use crate::truncation::TruncationPolicy;
11
/// Empirical threshold below which a hashmap-based fast path beats sort-merge
/// (§8.3). Subject to benchmarking.
///
/// NOTE(review): not referenced anywhere in this module — presumably consumed
/// by a dispatcher elsewhere in the crate; confirm before removing.
pub const SMALL_SUM_THRESHOLD: usize = 4096;
15
16/// Apply a single channel to a `PauliSum`, producing the next layer.
17///
18/// Implements the three-phase pipeline:
19///   1. **Scan** — `n_in × MAX_FANOUT` data-parallel channel applications.
20///   2. **Sort** — stable lex sort of the populated prefix (Phase 6
21///      placeholder for the bucket-scatter optimization deferred to
22///      Phase 11; see slice plan §6.2).
23///   3. **Merge** — segmented reduction; `keep_term` integration arrives
24///      in Phase 7.
25pub fn apply_layer<const W: usize, C, T>(
26    input: &PauliSum<W>,
27    channel: &C,
28    policy: &T,
29) -> PauliSum<W>
30where
31    C: Channel<W> + ?Sized,
32    T: TruncationPolicy<W> + ?Sized,
33{
34    apply_layer_inner(input, channel, policy, /* adjoint = */ false)
35}
36
37/// Apply a channel's adjoint to a `PauliSum`. Used by `propagate` in
38/// `Direction::Heisenberg` mode; structurally identical to `apply_layer`
39/// but routes through `Channel::apply_adjoint`.
40pub fn apply_layer_adjoint<const W: usize, C, T>(
41    input: &PauliSum<W>,
42    channel: &C,
43    policy: &T,
44) -> PauliSum<W>
45where
46    C: Channel<W> + ?Sized,
47    T: TruncationPolicy<W> + ?Sized,
48{
49    apply_layer_inner(input, channel, policy, /* adjoint = */ true)
50}
51
52fn apply_layer_inner<const W: usize, C, T>(
53    input: &PauliSum<W>,
54    channel: &C,
55    policy: &T,
56    adjoint: bool,
57) -> PauliSum<W>
58where
59    C: Channel<W> + ?Sized,
60    T: TruncationPolicy<W> + ?Sized,
61{
62    let n_in = input.len();
63    let mf = channel.max_fanout();
64    let cap = n_in * mf;
65    let mut out_x: Vec<[u64; W]> = vec![[0u64; W]; cap];
66    let mut out_z: Vec<[u64; W]> = vec![[0u64; W]; cap];
67    let mut out_coeff: Vec<Complex64> = vec![Complex64::new(0.0, 0.0); cap];
68    let len = if adjoint {
69        scan_phase_adjoint(input, channel, &mut out_x, &mut out_z, &mut out_coeff)
70    } else {
71        scan_phase(input, channel, &mut out_x, &mut out_z, &mut out_coeff)
72    };
73    sort_phase(&mut out_x, &mut out_z, &mut out_coeff, len);
74    let result = merge_phase::<W, T>(&out_x, &out_z, &out_coeff, len, input.num_qubits(), policy);
75    #[cfg(debug_assertions)]
76    result.assert_invariants();
77    result
78}
79
80/// Phase 1 of `apply_layer`: walk the input `PauliSum` and write each input
81/// term's outputs into a flat scratch buffer.
82///
83/// The caller pre-sizes the three slices to at least
84/// `input.len() * channel.max_fanout()`. Each input `i` is assigned the
85/// disjoint slot `[i*mf, (i+1)*mf)` in the buffer; per-input fills run in
86/// parallel via `rayon::par_chunks_mut`, after which a sequential compaction
87/// pass packs the populated prefixes contiguously. The slot layout depends
88/// only on `(i, mf)` — not on the rayon thread count or scheduling — so the
89/// final byte layout is deterministic across thread pool sizes (slice 8.1).
90///
91/// The function returns the actual number of output terms written (which may
92/// be less than `n_in * max_fanout` when the channel produces variable
93/// per-input fanout, e.g. `PauliRotation`).
94///
95/// `apply_fn` selects between forward (`Channel::apply`) and Heisenberg
96/// (`Channel::apply_adjoint`); the two scan-phase callers differ only in
97/// that one method call. It is `Fn + Sync` so the rayon worker threads can
98/// share it.
99///
100/// Output is *not* sorted — that's slice 6.2's job.
101fn scan_phase_with<const W: usize, C, F>(
102    input: &PauliSum<W>,
103    channel: &C,
104    out_x: &mut [[u64; W]],
105    out_z: &mut [[u64; W]],
106    out_coeff: &mut [Complex64],
107    apply_fn: F,
108) -> usize
109where
110    C: Channel<W> + ?Sized,
111    F: Fn(&C, &[u64; W], &[u64; W], Complex64, &mut OutputBuffer<'_, W>) + Sync,
112{
113    let mf = channel.max_fanout();
114    let n_in = input.len();
115    debug_assert!(out_x.len() >= n_in * mf);
116    debug_assert_eq!(out_x.len(), out_z.len());
117    debug_assert_eq!(out_x.len(), out_coeff.len());
118    if n_in == 0 || mf == 0 {
119        return 0;
120    }
121    let cap = n_in * mf;
122    let lens: Vec<usize> = out_x[..cap]
123        .par_chunks_mut(mf)
124        .zip(out_z[..cap].par_chunks_mut(mf))
125        .zip(out_coeff[..cap].par_chunks_mut(mf))
126        .enumerate()
127        .map(|(i, ((sx, sz), sc))| {
128            let mut local_len = 0usize;
129            let mut buf = OutputBuffer::<W> {
130                x: sx,
131                z: sz,
132                coeff: sc,
133                len: &mut local_len,
134            };
135            apply_fn(channel, &input.x[i], &input.z[i], input.coeff[i], &mut buf);
136            local_len
137        })
138        .collect();
139    compact_in_place(out_x, out_z, out_coeff, &lens, mf)
140}
141
142/// Pack the per-input populated prefixes `[i*mf, i*mf + lens[i])` into a
143/// contiguous prefix `[0..total)` of the three SoA buffers. Source and
144/// destination overlap, but the source index is always `>=` the destination,
145/// so `slice::copy_within` (memmove semantics) is correct.
146fn compact_in_place<const W: usize>(
147    out_x: &mut [[u64; W]],
148    out_z: &mut [[u64; W]],
149    out_coeff: &mut [Complex64],
150    lens: &[usize],
151    mf: usize,
152) -> usize {
153    let mut write = 0usize;
154    for (i, &len_i) in lens.iter().enumerate() {
155        if len_i == 0 {
156            continue;
157        }
158        let src = i * mf;
159        debug_assert!(write <= src);
160        if write != src {
161            out_x.copy_within(src..src + len_i, write);
162            out_z.copy_within(src..src + len_i, write);
163            out_coeff.copy_within(src..src + len_i, write);
164        }
165        write += len_i;
166    }
167    write
168}
169
170/// Forward scan: dispatches each input through `Channel::apply`.
171pub(crate) fn scan_phase<const W: usize, C: Channel<W> + ?Sized>(
172    input: &PauliSum<W>,
173    channel: &C,
174    out_x: &mut [[u64; W]],
175    out_z: &mut [[u64; W]],
176    out_coeff: &mut [Complex64],
177) -> usize {
178    scan_phase_with(
179        input,
180        channel,
181        out_x,
182        out_z,
183        out_coeff,
184        |c, x, z, co, out| c.apply(x, z, co, out),
185    )
186}
187
188/// Heisenberg scan: dispatches each input through `Channel::apply_adjoint`.
189pub(crate) fn scan_phase_adjoint<const W: usize, C: Channel<W> + ?Sized>(
190    input: &PauliSum<W>,
191    channel: &C,
192    out_x: &mut [[u64; W]],
193    out_z: &mut [[u64; W]],
194    out_coeff: &mut [Complex64],
195) -> usize {
196    scan_phase_with(
197        input,
198        channel,
199        out_x,
200        out_z,
201        out_coeff,
202        |c, x, z, co, out| c.apply_adjoint(x, z, co, out),
203    )
204}
205
206/// Phase 2: stably sort the populated prefix `[0..len)` of the scratch
207/// buffers by `(x, z)` lex order — the same key the `PauliString` `Ord`
208/// impl uses (`x[0]..x[W-1]` then `z[0]..z[W-1]`).
209///
210/// Stability matters once truncation policies depend on insertion order
211/// (Phase 7), and matches the design doc's "within-bucket relative order
212/// inherited from input" contract (§5). The implementation is `O(n log n)`
213/// instead of the design-doc's `O(n)` bucket scatter; the bucket
214/// optimization is deferred to Phase 11 with profile data.
215pub(crate) fn sort_phase<const W: usize>(
216    out_x: &mut [[u64; W]],
217    out_z: &mut [[u64; W]],
218    out_coeff: &mut [Complex64],
219    len: usize,
220) {
221    debug_assert!(out_x.len() >= len);
222    debug_assert_eq!(out_x.len(), out_z.len());
223    debug_assert_eq!(out_x.len(), out_coeff.len());
224    if len < 2 {
225        return;
226    }
227    let mut perm: Vec<usize> = (0..len).collect();
228    // `[u64; W]`'s built-in `Ord` is lex over array elements, identical to
229    // `PauliString::cmp`'s loop body.
230    perm.sort_by(|&a, &b| {
231        out_x[a]
232            .cmp(&out_x[b])
233            .then_with(|| out_z[a].cmp(&out_z[b]))
234    });
235    let new_x: Vec<[u64; W]> = perm.iter().map(|&i| out_x[i]).collect();
236    let new_z: Vec<[u64; W]> = perm.iter().map(|&i| out_z[i]).collect();
237    let new_c: Vec<Complex64> = perm.iter().map(|&i| out_coeff[i]).collect();
238    out_x[..len].copy_from_slice(&new_x);
239    out_z[..len].copy_from_slice(&new_z);
240    out_coeff[..len].copy_from_slice(&new_c);
241}
242
/// Empirical threshold below which the parallel merge's overhead dominates.
/// Below this, `merge_phase` collapses to a single sequential chunk.
///
/// NOTE(review): like `SMALL_SUM_THRESHOLD`, this cutoff is unbenchmarked —
/// tune once profiles exist.
const SMALL_MERGE_THRESHOLD: usize = 1024;

/// SoA triple emitted by a per-chunk merge — the same shape as `PauliSum`'s
/// internal storage, deferred into a `Vec` until concatenation.
type ChunkOutput<const W: usize> = (Vec<[u64; W]>, Vec<[u64; W]>, Vec<Complex64>);
250
251/// Phase 3: segmented reduction over the sorted scratch into a fresh
252/// `PauliSum`.
253///
254/// The input slices are the populated prefix `[0..len)` of the sort_phase
255/// output; they must be sorted by `(x, z)` (`debug_assert`-checked). Adjacent
256/// runs of equal keys have their coefficients summed; runs whose summed
257/// coefficient is exactly `0+0i` are dropped, and `policy.keep_term` is
258/// consulted on the *summed* coefficient — terms it rejects are dropped here
259/// rather than in a post-pass (slice 7.1).
260///
261/// The reduction is parallelized via chunked segment-aware merging
262/// (slice 8.2 / design doc §9): the populated prefix is partitioned into
263/// `rayon::current_num_threads()` chunks whose boundaries are advanced
264/// forward to the next run break, so every run is fully contained in exactly
265/// one chunk. Each chunk runs the same per-run reduction independently;
266/// the boundary "reconciliation pass" is the alignment step itself, so the
267/// chunk results are concatenated without further merging.
268pub(crate) fn merge_phase<const W: usize, T: TruncationPolicy<W> + ?Sized>(
269    sorted_x: &[[u64; W]],
270    sorted_z: &[[u64; W]],
271    sorted_coeff: &[Complex64],
272    len: usize,
273    num_qubits: usize,
274    policy: &T,
275) -> PauliSum<W> {
276    let nchunks = if len < SMALL_MERGE_THRESHOLD {
277        1
278    } else {
279        rayon::current_num_threads().max(1)
280    };
281    merge_phase_with_nchunks::<W, T>(
282        sorted_x,
283        sorted_z,
284        sorted_coeff,
285        len,
286        num_qubits,
287        policy,
288        nchunks,
289    )
290}
291
292/// `merge_phase` with an explicit chunk count. Public to the crate so tests
293/// can pin `nchunks` and force runs to straddle boundaries; `merge_phase`
294/// itself derives `nchunks` from `rayon::current_num_threads()`.
295pub(crate) fn merge_phase_with_nchunks<const W: usize, T: TruncationPolicy<W> + ?Sized>(
296    sorted_x: &[[u64; W]],
297    sorted_z: &[[u64; W]],
298    sorted_coeff: &[Complex64],
299    len: usize,
300    num_qubits: usize,
301    policy: &T,
302    nchunks: usize,
303) -> PauliSum<W> {
304    debug_assert!(sorted_x.len() >= len);
305    debug_assert_eq!(sorted_x.len(), sorted_z.len());
306    debug_assert_eq!(sorted_x.len(), sorted_coeff.len());
307    if len == 0 {
308        return PauliSum::<W> {
309            x: Vec::new(),
310            z: Vec::new(),
311            coeff: Vec::new(),
312            num_qubits,
313        };
314    }
315    let bounds = align_chunk_boundaries(sorted_x, sorted_z, len, nchunks.max(1));
316    let chunk_results: Vec<ChunkOutput<W>> = bounds
317        .par_iter()
318        .map(|&(start, end)| {
319            merge_chunk::<W, T>(sorted_x, sorted_z, sorted_coeff, start, end, policy)
320        })
321        .collect();
322    let total: usize = chunk_results.iter().map(|(cx, _, _)| cx.len()).sum();
323    let mut x = Vec::with_capacity(total);
324    let mut z = Vec::with_capacity(total);
325    let mut coeff = Vec::with_capacity(total);
326    for (cx, cz, cc) in chunk_results {
327        x.extend(cx);
328        z.extend(cz);
329        coeff.extend(cc);
330    }
331    PauliSum::<W> {
332        x,
333        z,
334        coeff,
335        num_qubits,
336    }
337}
338
339/// Run the segmented reduction on the sub-range `[start..end)` of the sorted
340/// scratch. Caller must ensure runs are not split: `(sorted_x[start-1],
341/// sorted_z[start-1]) != (sorted_x[start], sorted_z[start])` whenever
342/// `start > 0`, and similarly at `end`. With that invariant, the chunk's
343/// reduction is identical to the sequential merge on the same sub-range —
344/// `keep_term` and the zero-drop check operate on fully-summed coefficients.
345fn merge_chunk<const W: usize, T: TruncationPolicy<W> + ?Sized>(
346    sorted_x: &[[u64; W]],
347    sorted_z: &[[u64; W]],
348    sorted_coeff: &[Complex64],
349    start: usize,
350    end: usize,
351    policy: &T,
352) -> ChunkOutput<W> {
353    let zero = Complex64::new(0.0, 0.0);
354    let mut x: Vec<[u64; W]> = Vec::new();
355    let mut z: Vec<[u64; W]> = Vec::new();
356    let mut coeff: Vec<Complex64> = Vec::new();
357    let mut i = start;
358    while i < end {
359        let key_x = sorted_x[i];
360        let key_z = sorted_z[i];
361        let mut acc = sorted_coeff[i];
362        let mut j = i + 1;
363        while j < end && sorted_x[j] == key_x && sorted_z[j] == key_z {
364            acc += sorted_coeff[j];
365            j += 1;
366        }
367        debug_assert!(
368            i == 0 || (sorted_x[i - 1], sorted_z[i - 1]) <= (key_x, key_z),
369            "merge_chunk: scratch is not sorted at index {}",
370            i,
371        );
372        if acc != zero && policy.keep_term(&key_x, &key_z, acc) {
373            x.push(key_x);
374            z.push(key_z);
375            coeff.push(acc);
376        }
377        i = j;
378    }
379    (x, z, coeff)
380}
381
/// Split `[0..len)` into at most `nchunks` non-empty sub-ranges whose
/// interior cuts land on run breaks. Each candidate cut `len * k / nchunks`
/// is pushed forward (possibly all the way to `len`) until the key at `cut`
/// differs from the key at `cut - 1`. Cuts that collide are deduplicated, so
/// fewer than `nchunks` ranges may come back.
fn align_chunk_boundaries<const W: usize>(
    sorted_x: &[[u64; W]],
    sorted_z: &[[u64; W]],
    len: usize,
    nchunks: usize,
) -> Vec<(usize, usize)> {
    if len == 0 {
        return Vec::new();
    }
    if nchunks <= 1 {
        return vec![(0, len)];
    }
    // A cut sits mid-run exactly when its key equals its left neighbor's.
    let mid_run = |t: usize| sorted_x[t] == sorted_x[t - 1] && sorted_z[t] == sorted_z[t - 1];
    let mut cuts: Vec<usize> = Vec::with_capacity(nchunks + 1);
    cuts.push(0);
    for k in 1..nchunks {
        let mut cut = (len * k) / nchunks;
        while cut > 0 && cut < len && mid_run(cut) {
            cut += 1;
        }
        cuts.push(cut);
    }
    cuts.push(len);
    cuts.dedup();
    cuts.iter()
        .zip(cuts.iter().skip(1))
        .map(|(&lo, &hi)| (lo, hi))
        .collect()
}
413
414#[cfg(all(test, debug_assertions))]
415mod tests {
416    use super::*;
417    use crate::channel::{Clifford1Q, IdentityChannel, PauliRotation};
418    use crate::pauli_string::PauliString;
419    use crate::truncation::CoefficientThreshold;
420
    /// Absolute tolerance for complex-coefficient comparisons in these tests.
    const TOL: f64 = 1e-12;
422
423    #[allow(clippy::type_complexity)]
424    fn alloc_bufs<const W: usize>(n: usize) -> (Vec<[u64; W]>, Vec<[u64; W]>, Vec<Complex64>) {
425        (
426            vec![[0u64; W]; n],
427            vec![[0u64; W]; n],
428            vec![Complex64::new(0.0, 0.0); n],
429        )
430    }
431
432    fn approx_eq(a: Complex64, b: Complex64, tol: f64) -> bool {
433        (a - b).norm() <= tol
434    }
435
    /// `IdentityChannel` writes input through unchanged, in input order.
    #[test]
    fn scan_identity_passes_through_w1() {
        let input = PauliSum::<1>::from_strings(&[
            ("X", Complex64::new(1.0, 0.0)),
            ("Z", Complex64::new(2.0, 0.0)),
        ]);
        let id = IdentityChannel::new();
        // UFCS names the `Channel<1>` impl explicitly to pin `W = 1`.
        let cap = input.len() * <IdentityChannel as Channel<1>>::max_fanout(&id);
        let (mut bx, mut bz, mut bc) = alloc_bufs::<1>(cap);
        let total = scan_phase(&input, &id, &mut bx, &mut bz, &mut bc);
        assert_eq!(total, 2);
        // Fanout 1 ⇒ output slot i mirrors input term i exactly.
        for i in 0..total {
            assert_eq!(bx[i], input.x()[i]);
            assert_eq!(bz[i], input.z()[i]);
            assert_eq!(bc[i], input.coeff()[i]);
        }
    }
454
    /// `H` conjugates `Z → X` and `X → Z` (with phase +1), preserving coeffs.
    /// Output is in input order; sort happens in slice 6.2.
    #[test]
    fn scan_clifford1q_h_w1() {
        let input = PauliSum::<1>::from_strings(&[
            ("Z", Complex64::new(3.0, 0.0)),
            ("X", Complex64::new(5.0, 0.0)),
        ]);
        let h = Clifford1Q::h(0);
        let cap = input.len() * <Clifford1Q as Channel<1>>::max_fanout(&h);
        let (mut bx, mut bz, mut bc) = alloc_bufs::<1>(cap);
        let total = scan_phase(&input, &h, &mut bx, &mut bz, &mut bc);
        assert_eq!(total, 2);
        // Input order: from_strings sorts lex by (x, z) — Z (x=0,z=1) sorts
        // before X (x=1,z=0). So scan output[0] = H·Z = X, output[1] = H·X = Z.
        // Coefficients ride along with their strings: H contributes no phase.
        assert_eq!(bx[0], PauliString::<1>::x(0).x);
        assert_eq!(bz[0], PauliString::<1>::x(0).z);
        assert_eq!(bc[0], Complex64::new(3.0, 0.0));
        assert_eq!(bx[1], PauliString::<1>::z(0).x);
        assert_eq!(bz[1], PauliString::<1>::z(0).z);
        assert_eq!(bc[1], Complex64::new(5.0, 0.0));
    }
477
    /// `PauliRotation` has `MAX_FANOUT = 2` but emits a single term when the
    /// input commutes with the generator. The scan packs outputs contiguously
    /// — no gaps left by a 1-output input.
    #[test]
    fn scan_pauli_rotation_packs_variable_fanout() {
        // Rotation around Z. Input "X" anticommutes (fanout 2); "Z" commutes
        // (fanout 1). Total = 3 outputs from 2 inputs.
        let p = PauliString::<1>::z(0);
        let theta = std::f64::consts::FRAC_PI_3;
        let rot = PauliRotation::<1> {
            support: vec![0],
            gen_x: p.x,
            gen_z: p.z,
            theta,
        };
        let input = PauliSum::<1>::from_strings(&[
            ("X", Complex64::new(1.0, 0.0)),
            ("Z", Complex64::new(2.0, 0.0)),
        ]);
        let cap = input.len() * <PauliRotation<1> as Channel<1>>::max_fanout(&rot);
        let (mut bx, mut bz, mut bc) = alloc_bufs::<1>(cap);
        let total = scan_phase(&input, &rot, &mut bx, &mut bz, &mut bc);
        // 1 (commuting) + 2 (anticommuting) — and compaction left no hole.
        assert_eq!(total, 3);
        // from_strings sorts: Z (x=0,z=1) < X (x=1,z=0). So input[0] = Z,
        // input[1] = X.
        // Output[0]: rot(Z) = Z (commutes, fanout 1) with coeff 2.
        assert_eq!(bx[0], PauliString::<1>::z(0).x);
        assert_eq!(bz[0], PauliString::<1>::z(0).z);
        assert!(approx_eq(bc[0], Complex64::new(2.0, 0.0), TOL));
        // Output[1]: cos·X with coeff 1.
        assert_eq!(bx[1], PauliString::<1>::x(0).x);
        assert_eq!(bz[1], PauliString::<1>::x(0).z);
        assert!(approx_eq(bc[1], Complex64::new(theta.cos(), 0.0), TOL));
        // Output[2]: sin·Y. Working: input.coeff = 1 and the X·Z product's
        // accumulated phase (Phase::I + mul_phase 3) wraps to Phase::ONE,
        // so bc[2] = 1·sin(θ).
        assert_eq!(bx[2], PauliString::<1>::y(0).x);
        assert_eq!(bz[2], PauliString::<1>::y(0).z);
        assert!(approx_eq(bc[2], Complex64::new(theta.sin(), 0.0), TOL));
    }
518
    /// Multi-word: input on qubit 64 (word 1), `H` flips X↔Z within word 1.
    #[test]
    fn scan_w2_word_boundary() {
        // Built via struct literal: qubit index 64 needs the second u64 word,
        // so W = 2 and the bit lands at word 1, bit 0.
        let input = PauliSum::<2> {
            x: vec![[0u64, 1u64]], // X on qubit 64
            z: vec![[0u64; 2]],
            coeff: vec![Complex64::new(1.5, 0.0)],
            num_qubits: 65,
        };
        let h = Clifford1Q::h(64);
        let cap = input.len() * <Clifford1Q as Channel<2>>::max_fanout(&h);
        let (mut bx, mut bz, mut bc) = alloc_bufs::<2>(cap);
        let total = scan_phase(&input, &h, &mut bx, &mut bz, &mut bc);
        assert_eq!(total, 1);
        // H·X = Z, on qubit 64 (word 1, bit 0).
        assert_eq!(bx[0], [0u64, 0u64]);
        assert_eq!(bz[0], [0u64, 1u64]);
        assert_eq!(bc[0], Complex64::new(1.5, 0.0));
    }
538
    /// Empty input → zero outputs; doesn't read the buffer.
    #[test]
    fn scan_empty_input() {
        let input = PauliSum::<1>::empty(4);
        let id = IdentityChannel::new();
        // Zero-capacity buffers are legal: the scan's `n_in == 0` early-out
        // fires before any slot is touched.
        let mut bx: Vec<[u64; 1]> = vec![];
        let mut bz: Vec<[u64; 1]> = vec![];
        let mut bc: Vec<Complex64> = vec![];
        let total = scan_phase(&input, &id, &mut bx, &mut bz, &mut bc);
        assert_eq!(total, 0);
    }
550
    /// Slice 8.1: same input under different rayon thread-pool sizes must
    /// produce byte-identical output through `apply_layer`. Per-input slot
    /// bounds depend only on `(i, mf)`, so the scan output is deterministic;
    /// the merge phase processes runs in input order regardless of thread
    /// count, so floating-point summation is bit-stable too.
    #[test]
    fn scan_determinism_across_thread_counts() {
        // 16 distinct 2-qubit terms, fanout-1 channel (`H` on qubit 0).
        // Two qubits fit in one 64-bit word, hence W = 1.
        let labels = ["I", "X", "Y", "Z"];
        let mut owned: Vec<(String, Complex64)> = Vec::new();
        let mut k: i64 = 0;
        for a in &labels {
            for b in &labels {
                // k starts at 1 inside the loop, so no coefficient is zero.
                k += 1;
                owned.push((
                    format!("{}{}", a, b),
                    Complex64::new(k as f64 * 0.13, k as f64 * 0.07),
                ));
            }
        }
        let strings: Vec<(&str, Complex64)> = owned.iter().map(|(s, c)| (s.as_str(), *c)).collect();
        let input = PauliSum::<1>::from_strings(&strings);
        let h = Clifford1Q::h(0);
        let policy = AlwaysKeep;

        // Run the full pipeline inside a dedicated pool of `n` workers.
        let run = |n: usize| -> PauliSum<1> {
            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads(n)
                .build()
                .unwrap();
            pool.install(|| apply_layer(&input, &h, &policy))
        };
        let r1 = run(1);
        let r2 = run(2);
        let r4 = run(4);

        // Exact equality — including coefficients — not approx_eq: the whole
        // point is bit-stability across pool sizes.
        assert_eq!(r1.x(), r2.x());
        assert_eq!(r1.z(), r2.z());
        assert_eq!(r1.coeff(), r2.coeff());
        assert_eq!(r1.x(), r4.x());
        assert_eq!(r1.z(), r4.z());
        assert_eq!(r1.coeff(), r4.coeff());
    }
594
    /// Hand-built unsorted scratch becomes sorted; coeffs follow their keys.
    /// Lex on `(x, z)`: I < Z < X < Y per word (since X has x=1>0 and Z has z=1
    /// only, x[0] dominates).
    #[test]
    fn sort_phase_orders_by_lex_key() {
        // Three terms in non-sorted order: X, I, Z. Expected sort: I, Z, X.
        // Coefficients 7/8/9 act as tags to track which row went where.
        let mut x: Vec<[u64; 1]> = vec![[1], [0], [0]];
        let mut z: Vec<[u64; 1]> = vec![[0], [0], [1]];
        let mut c: Vec<Complex64> = vec![
            Complex64::new(7.0, 0.0), // X tag
            Complex64::new(8.0, 0.0), // I tag
            Complex64::new(9.0, 0.0), // Z tag
        ];
        sort_phase(&mut x, &mut z, &mut c, 3);
        assert_eq!(x[0], [0]);
        assert_eq!(z[0], [0]);
        assert_eq!(c[0], Complex64::new(8.0, 0.0));
        assert_eq!(x[1], [0]);
        assert_eq!(z[1], [1]);
        assert_eq!(c[1], Complex64::new(9.0, 0.0));
        assert_eq!(x[2], [1]);
        assert_eq!(z[2], [0]);
        assert_eq!(c[2], Complex64::new(7.0, 0.0));
    }
619
    /// A pre-sorted scratch survives sort_phase intact.
    #[test]
    fn sort_phase_preserves_already_sorted() {
        // Already in lex order: I < Z < X.
        let mut x: Vec<[u64; 1]> = vec![[0], [0], [1]];
        let mut z: Vec<[u64; 1]> = vec![[0], [1], [0]];
        let mut c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
        ];
        sort_phase(&mut x, &mut z, &mut c, 3);
        assert_eq!(x, vec![[0u64], [0u64], [1u64]]);
        assert_eq!(z, vec![[0u64], [1u64], [0u64]]);
        assert_eq!(
            c,
            vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(2.0, 0.0),
                Complex64::new(3.0, 0.0),
            ]
        );
    }
642
    /// Stable sort: two outputs with identical (x, z) keep their input
    /// relative order. Distinguishable via their coefficients.
    #[test]
    fn sort_phase_is_stable_on_equal_keys() {
        // Three Z terms with coeffs 1, 2, 3 in input order, plus one X
        // separating coefficient 1 and 2 to force a non-trivial permutation.
        let mut x: Vec<[u64; 1]> = vec![[0], [1], [0], [0]];
        let mut z: Vec<[u64; 1]> = vec![[1], [0], [1], [1]];
        let mut c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),  // Z (first)
            Complex64::new(99.0, 0.0), // X
            Complex64::new(2.0, 0.0),  // Z (second)
            Complex64::new(3.0, 0.0),  // Z (third)
        ];
        sort_phase(&mut x, &mut z, &mut c, 4);
        // Sorted order: Z, Z, Z, X. Z coeffs must come out in input order
        // 1, 2, 3 — the stability check.
        assert_eq!(x[0], [0]);
        assert_eq!(z[0], [1]);
        assert_eq!(c[0], Complex64::new(1.0, 0.0));
        assert_eq!(x[1], [0]);
        assert_eq!(z[1], [1]);
        assert_eq!(c[1], Complex64::new(2.0, 0.0));
        assert_eq!(x[2], [0]);
        assert_eq!(z[2], [1]);
        assert_eq!(c[2], Complex64::new(3.0, 0.0));
        assert_eq!(x[3], [1]);
        assert_eq!(z[3], [0]);
        assert_eq!(c[3], Complex64::new(99.0, 0.0));
    }
673
    /// W=2: lex priority is `x[0]` first (low word), then `x[1]`, then `z[0]`,
    /// `z[1]`. Word 0 dominates word 1 in cmp.
    #[test]
    fn sort_phase_w2_cross_word_priority() {
        // Term A: x=[0, 99], z=[0, 0]   (X-bits in word 1)
        // Term B: x=[1, 0],  z=[0, 0]   (X-bit in word 0, low value)
        // Lex cmp: A.x[0]=0 < B.x[0]=1, so A < B — even though A's word-1
        // value (99) is far larger than anything in B.
        let mut x: Vec<[u64; 2]> = vec![[1, 0], [0, 99]];
        let mut z: Vec<[u64; 2]> = vec![[0, 0], [0, 0]];
        let mut c: Vec<Complex64> = vec![
            Complex64::new(11.0, 0.0), // B
            Complex64::new(22.0, 0.0), // A
        ];
        sort_phase(&mut x, &mut z, &mut c, 2);
        assert_eq!(x[0], [0, 99]);
        assert_eq!(c[0], Complex64::new(22.0, 0.0));
        assert_eq!(x[1], [1, 0]);
        assert_eq!(c[1], Complex64::new(11.0, 0.0));
    }
693
    /// `len < 2` is a no-op short-circuit. Verify behavior on empty and
    /// single-element prefixes.
    #[test]
    fn sort_phase_len_lt_2_is_noop() {
        // Single element: untouched.
        let mut x: Vec<[u64; 1]> = vec![[5]];
        let mut z: Vec<[u64; 1]> = vec![[7]];
        let mut c: Vec<Complex64> = vec![Complex64::new(1.0, 2.0)];
        sort_phase(&mut x, &mut z, &mut c, 1);
        assert_eq!(x[0], [5]);
        assert_eq!(z[0], [7]);
        assert_eq!(c[0], Complex64::new(1.0, 2.0));

        // Empty prefix: no panic, no writes.
        let mut empty_x: Vec<[u64; 1]> = vec![];
        let mut empty_z: Vec<[u64; 1]> = vec![];
        let mut empty_c: Vec<Complex64> = vec![];
        sort_phase(&mut empty_x, &mut empty_z, &mut empty_c, 0);
        assert!(empty_x.is_empty());
    }
712
    /// Truncation policy that always keeps terms, leaving merge output
    /// unaffected — presumably via the trait's accept-all default
    /// `keep_term` (empty impl below); confirm in `truncation.rs`.
    struct AlwaysKeep;
    impl<const W: usize> TruncationPolicy<W> for AlwaysKeep {}
718
    /// `len == 0` short-circuits into an empty sum that still carries the
    /// qubit count.
    #[test]
    fn merge_phase_empty_input() {
        let x: Vec<[u64; 1]> = vec![];
        let z: Vec<[u64; 1]> = vec![];
        let c: Vec<Complex64> = vec![];
        let out = merge_phase::<1, _>(&x, &z, &c, 0, 4, &AlwaysKeep);
        assert!(out.is_empty());
        assert_eq!(out.num_qubits(), 4);
        out.assert_invariants();
    }
729
    /// All-distinct keys: every run has length 1, so the merge is an
    /// identity copy.
    #[test]
    fn merge_phase_distinct_keys_pass_through() {
        // Sorted: I, Z, X.
        let x: Vec<[u64; 1]> = vec![[0], [0], [1]];
        let z: Vec<[u64; 1]> = vec![[0], [1], [0]];
        let c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
        ];
        let out = merge_phase::<1, _>(&x, &z, &c, 3, 1, &AlwaysKeep);
        assert_eq!(out.len(), 3);
        assert_eq!(out.x(), &[[0u64], [0u64], [1u64]]);
        assert_eq!(out.z(), &[[0u64], [1u64], [0u64]]);
        assert_eq!(
            out.coeff(),
            &[
                Complex64::new(1.0, 0.0),
                Complex64::new(2.0, 0.0),
                Complex64::new(3.0, 0.0),
            ]
        );
        out.assert_invariants();
    }
754
    /// A length-3 run of equal keys is reduced to one term whose coefficient
    /// is the sum of the run's coefficients.
    #[test]
    fn merge_phase_combines_adjacent_duplicates() {
        // Three Z entries (coeffs 1, 2, 3) followed by one X (coeff 7).
        let x: Vec<[u64; 1]> = vec![[0], [0], [0], [1]];
        let z: Vec<[u64; 1]> = vec![[1], [1], [1], [0]];
        let c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(7.0, 0.0),
        ];
        let out = merge_phase::<1, _>(&x, &z, &c, 4, 1, &AlwaysKeep);
        assert_eq!(out.len(), 2);
        // Z with summed coeff 6, then X with 7.
        assert_eq!(out.x()[0], [0]);
        assert_eq!(out.z()[0], [1]);
        assert_eq!(out.coeff()[0], Complex64::new(6.0, 0.0));
        assert_eq!(out.x()[1], [1]);
        assert_eq!(out.z()[1], [0]);
        assert_eq!(out.coeff()[1], Complex64::new(7.0, 0.0));
        out.assert_invariants();
    }
777
    /// Exact cancellation (`acc == 0+0i`) removes a run entirely — no
    /// zero-coefficient term leaks into the output.
    #[test]
    fn merge_phase_drops_exact_zero_runs() {
        // Z with coeffs +1 and -1 → cancels to zero, dropped. X with 5
        // survives.
        let x: Vec<[u64; 1]> = vec![[0], [0], [1]];
        let z: Vec<[u64; 1]> = vec![[1], [1], [0]];
        let c: Vec<Complex64> = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(5.0, 0.0),
        ];
        let out = merge_phase::<1, _>(&x, &z, &c, 3, 1, &AlwaysKeep);
        assert_eq!(out.len(), 1);
        assert_eq!(out.x()[0], [1]);
        assert_eq!(out.z()[0], [0]);
        assert_eq!(out.coeff()[0], Complex64::new(5.0, 0.0));
        out.assert_invariants();
    }
796
    /// Slice 7.1: the policy's `keep_term` runs inside the merge loop.
    /// `CoefficientThreshold(1e-6)` drops the X term (coeff 1e-9) but keeps
    /// the Z term (coeff 0.5).
    #[test]
    fn merge_phase_drops_below_threshold() {
        let x: Vec<[u64; 1]> = vec![[0], [1]]; // Z, then X
        let z: Vec<[u64; 1]> = vec![[1], [0]];
        let c: Vec<Complex64> = vec![Complex64::new(0.5, 0.0), Complex64::new(1e-9, 0.0)];
        let out = merge_phase::<1, _>(&x, &z, &c, 2, 1, &CoefficientThreshold(1e-6));
        // Only the surviving Z term remains.
        assert_eq!(out.len(), 1);
        assert_eq!(out.x()[0], [0]);
        assert_eq!(out.z()[0], [1]);
        assert!(approx_eq(out.coeff()[0], Complex64::new(0.5, 0.0), TOL));
        out.assert_invariants();
    }
812
813    /// Slice 7.1: the threshold is checked *after* coefficients are summed,
814    /// not on individual scratch entries. Two Z terms with coeffs 0.5 and
815    /// -0.4999999 sum to 1e-7, which is below threshold 1e-6 — drop the
816    /// merged term, even though each summand individually exceeds 1e-6.
817    #[test]
818    fn merge_phase_threshold_applied_after_summation() {
819        let x: Vec<[u64; 1]> = vec![[0], [0]];
820        let z: Vec<[u64; 1]> = vec![[1], [1]];
821        let c: Vec<Complex64> = vec![Complex64::new(0.5, 0.0), Complex64::new(-0.4999999, 0.0)];
822        let out = merge_phase::<1, _>(&x, &z, &c, 2, 1, &CoefficientThreshold(1e-6));
823        assert_eq!(out.len(), 0);
824        out.assert_invariants();
825    }
826
827    /// `CoefficientThreshold(0.0)` keeps every (non-zero) summed term — the
828    /// no-op case for the new keep_term path.
829    #[test]
830    fn merge_phase_zero_threshold_keeps_everything() {
831        let x: Vec<[u64; 1]> = vec![[0], [0], [1]];
832        let z: Vec<[u64; 1]> = vec![[0], [1], [0]];
833        let c: Vec<Complex64> = vec![
834            Complex64::new(1.0, 0.0),
835            Complex64::new(2.0, 0.0),
836            Complex64::new(3.0, 0.0),
837        ];
838        let out = merge_phase::<1, _>(&x, &z, &c, 3, 1, &CoefficientThreshold(0.0));
839        assert_eq!(out.len(), 3);
840        out.assert_invariants();
841    }
842
843    /// Buffer with a populated prefix `[0..len)` and trailing junk: the
844    /// junk must not be merged.
845    #[test]
846    fn merge_phase_ignores_trailing_junk() {
847        let x: Vec<[u64; 1]> = vec![[0], [99], [99]]; // junk at idx 1, 2
848        let z: Vec<[u64; 1]> = vec![[1], [99], [99]];
849        let c: Vec<Complex64> = vec![
850            Complex64::new(2.0, 0.0),
851            Complex64::new(99.0, 0.0),
852            Complex64::new(99.0, 0.0),
853        ];
854        let out = merge_phase::<1, _>(&x, &z, &c, 1, 1, &AlwaysKeep);
855        assert_eq!(out.len(), 1);
856        assert_eq!(out.x()[0], [0]);
857        assert_eq!(out.z()[0], [1]);
858        assert_eq!(out.coeff()[0], Complex64::new(2.0, 0.0));
859    }
860
861    /// Slice 8.2 stress: a run of identical keys straddles the `len/2`
862    /// boundary. With `nchunks = 2`, the natural boundary at index 4 lands
863    /// inside the Z-run; alignment must advance it forward to index 7 so
864    /// the Z-run merges into a single term. Lex `(x, z)` sort puts
865    /// `I < Z < X`.
866    #[test]
867    fn merge_phase_run_spans_chunk_boundary() {
868        // Sorted: I, I, Z, Z, Z, Z, Z, X, X   (len = 9, mid = 4 inside Z-run)
869        let x: Vec<[u64; 1]> = vec![[0], [0], [0], [0], [0], [0], [0], [1], [1]];
870        let z: Vec<[u64; 1]> = vec![[0], [0], [1], [1], [1], [1], [1], [0], [0]];
871        let c: Vec<Complex64> = vec![
872            Complex64::new(1.0, 0.0),
873            Complex64::new(1.0, 0.0),
874            Complex64::new(0.5, 0.0),
875            Complex64::new(0.5, 0.0),
876            Complex64::new(0.5, 0.0),
877            Complex64::new(0.5, 0.0),
878            Complex64::new(0.5, 0.0),
879            Complex64::new(2.0, 0.0),
880            Complex64::new(2.0, 0.0),
881        ];
882        let out = merge_phase_with_nchunks::<1, _>(&x, &z, &c, 9, 1, &AlwaysKeep, 2);
883        assert_eq!(out.len(), 3);
884        assert_eq!(out.x()[0], [0]);
885        assert_eq!(out.z()[0], [0]);
886        assert_eq!(out.coeff()[0], Complex64::new(2.0, 0.0));
887        assert_eq!(out.x()[1], [0]);
888        assert_eq!(out.z()[1], [1]);
889        assert_eq!(out.coeff()[1], Complex64::new(2.5, 0.0));
890        assert_eq!(out.x()[2], [1]);
891        assert_eq!(out.z()[2], [0]);
892        assert_eq!(out.coeff()[2], Complex64::new(4.0, 0.0));
893        out.assert_invariants();
894    }
895
896    /// Slice 8.2: when the natural midpoint already falls at a run break,
897    /// alignment leaves it alone — chunks split exactly between runs.
898    #[test]
899    fn merge_phase_aligned_boundary_no_shift() {
900        // Sorted: I, I, I, X, X, X   (len = 6, mid = 3 lands on run break)
901        let x: Vec<[u64; 1]> = vec![[0], [0], [0], [1], [1], [1]];
902        let z: Vec<[u64; 1]> = vec![[0], [0], [0], [0], [0], [0]];
903        let c: Vec<Complex64> = vec![
904            Complex64::new(1.0, 0.0),
905            Complex64::new(2.0, 0.0),
906            Complex64::new(3.0, 0.0),
907            Complex64::new(4.0, 0.0),
908            Complex64::new(5.0, 0.0),
909            Complex64::new(6.0, 0.0),
910        ];
911        let out = merge_phase_with_nchunks::<1, _>(&x, &z, &c, 6, 1, &AlwaysKeep, 2);
912        assert_eq!(out.len(), 2);
913        assert_eq!(out.x()[0], [0]);
914        assert_eq!(out.z()[0], [0]);
915        assert_eq!(out.coeff()[0], Complex64::new(6.0, 0.0));
916        assert_eq!(out.x()[1], [1]);
917        assert_eq!(out.z()[1], [0]);
918        assert_eq!(out.coeff()[1], Complex64::new(15.0, 0.0));
919        out.assert_invariants();
920    }
921
922    /// Slice 8.2: degenerate input where every key collapses to one run.
923    /// All chunk boundaries advance to `len`; only one effective chunk.
924    #[test]
925    fn merge_phase_all_same_key_with_nchunks() {
926        let x: Vec<[u64; 1]> = vec![[0], [0], [0], [0], [0]];
927        let z: Vec<[u64; 1]> = vec![[1], [1], [1], [1], [1]];
928        let c: Vec<Complex64> = vec![
929            Complex64::new(1.0, 0.0),
930            Complex64::new(1.0, 0.0),
931            Complex64::new(1.0, 0.0),
932            Complex64::new(1.0, 0.0),
933            Complex64::new(1.0, 0.0),
934        ];
935        let out = merge_phase_with_nchunks::<1, _>(&x, &z, &c, 5, 1, &AlwaysKeep, 4);
936        assert_eq!(out.len(), 1);
937        assert_eq!(out.x()[0], [0]);
938        assert_eq!(out.z()[0], [1]);
939        assert_eq!(out.coeff()[0], Complex64::new(5.0, 0.0));
940        out.assert_invariants();
941    }
942}