paulistrings/
pauli_sum.rs

1//! [`PauliSum<W>`] — weighted sum of Pauli strings in structure-of-arrays form.
2//!
3//! Storage is parallel `Vec<[u64; W]>` columns for the `x` and `z` parts plus
4//! a `Vec<Complex64>` for coefficients. SoA is chosen so coefficient-only and
5//! key-only scans get full cache utilization, and so each `Vec` maps directly
6//! to a GPU device buffer.
7//!
8//! **Invariant:** the `x` and `z` columns are sorted in lexicographic order as
9//! a single key, and no two entries share a key. Every public operation
10//! either preserves this invariant or returns a fresh [`PauliSum`] that does.
11//!
12//! Build a [`PauliSum`] from unsorted inputs via [`BuildAccumulator`]; once
13//! built, combine sums with [`PauliSum::add`] or scale coefficients with
14//! [`PauliSum::scale`].
15//!
16//! See design doc §3.2.
17//!
18//! # Examples
19//!
20//! Construct the observable `Z₀ + 0.5·X₁` on two qubits via
21//! [`BuildAccumulator`], then merge in a second sum.
22//!
23//! ```
24//! use paulistrings::{BuildAccumulator, PauliString, PauliSum, Phase};
25//! use num_complex::Complex64;
26//!
27//! let mut acc = BuildAccumulator::<1>::new(2);
28//! acc.add_term(PauliString::<1>::z(0), Phase::ONE, Complex64::new(1.0, 0.0));
29//! acc.add_term(PauliString::<1>::x(1), Phase::ONE, Complex64::new(0.5, 0.0));
30//! let a = acc.finalize();
31//! assert_eq!(a.len(), 2);
32//!
33//! let mut acc2 = BuildAccumulator::<1>::new(2);
34//! acc2.add_term(PauliString::<1>::x(1), Phase::ONE, Complex64::new(-0.25, 0.0));
35//! let b = acc2.finalize();
36//!
37//! let merged = a.add(&b);
38//! assert_eq!(merged.len(), 2); // Z₀ + 0.25·X₁
39//! ```
40//!
41//! [`BuildAccumulator`]: crate::BuildAccumulator
42
43#![allow(unused)]
44
45use num_complex::Complex64;
46
47use crate::pauli_string::PauliString;
48
/// Weighted sum of Pauli operators, stored SoA, sorted and deduplicated.
///
/// The three columns are parallel: entry `i` of `x`, `z` and `coeff`
/// together describe one term. Terms are ordered by the pair `(x, z)`
/// compared lexicographically, and no pair appears twice.
#[derive(Clone, Debug, Default)]
pub struct PauliSum<const W: usize> {
    /// X-part bit masks, one `[u64; W]` per term.
    pub(crate) x: Vec<[u64; W]>,
    /// Z-part bit masks, one `[u64; W]` per term.
    pub(crate) z: Vec<[u64; W]>,
    /// Complex coefficient of each term.
    pub(crate) coeff: Vec<Complex64>,
    /// Number of qubits the sum is defined over; expected to be at most
    /// `64 * W`.
    pub(crate) num_qubits: usize,
}
57
58impl<const W: usize> PauliSum<W> {
59    /// Empty sum on `num_qubits` qubits.
60    ///
61    /// # Panics
62    ///
63    /// Panics in debug builds if `num_qubits > 64 · W`. Caller is responsible
64    /// for ensuring `num_qubits <= 64 · W`.
65    pub fn empty(num_qubits: usize) -> Self {
66        debug_assert!(num_qubits <= 64 * W);
67        Self {
68            x: Vec::new(),
69            z: Vec::new(),
70            coeff: Vec::new(),
71            num_qubits,
72        }
73    }
74
    /// Number of qubits this sum is defined over.
    ///
    /// Returns the value stored at construction; it is not derived from the
    /// terms currently held.
    #[inline]
    pub fn num_qubits(&self) -> usize {
        self.num_qubits
    }
80
    /// Number of stored terms (the length of the coefficient column).
    ///
    /// Every stored entry is counted, whatever its key or coefficient: a
    /// term whose coefficient has become zero (for example after
    /// [`PauliSum::scale`] by zero) still counts until it is removed, e.g.
    /// by [`PauliSum::truncate_by_magnitude`].
    #[inline]
    pub fn len(&self) -> usize {
        self.coeff.len()
    }
86
    /// `true` iff the sum has no stored terms, i.e. [`PauliSum::len`] is 0.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.coeff.is_empty()
    }
92
93    /// Read-only view of the X-part column.
94    #[inline]
95    pub fn x(&self) -> &[[u64; W]] {
96        &self.x
97    }
98
99    /// Read-only view of the Z-part column.
100    #[inline]
101    pub fn z(&self) -> &[[u64; W]] {
102        &self.z
103    }
104
105    /// Read-only view of the coefficient column.
106    #[inline]
107    pub fn coeff(&self) -> &[Complex64] {
108        &self.coeff
109    }
110
111    /// Sum of two [`PauliSum`]s. Linear-time merge; preserves the sorted
112    /// invariant. Terms whose coefficients sum to exactly `0+0i` are dropped.
113    ///
114    /// # Examples
115    ///
116    /// Disjoint keys interleave in sort order; equal keys sum, and an
117    /// exact-zero combined coefficient drops the term.
118    ///
119    /// ```
120    /// use paulistrings::{BuildAccumulator, PauliString, Phase};
121    /// use num_complex::Complex64;
122    ///
123    /// let mut a = BuildAccumulator::<1>::new(2);
124    /// a.add_term(PauliString::<1>::z(0), Phase::ONE, Complex64::new(1.0, 0.0));
125    /// a.add_term(PauliString::<1>::x(1), Phase::ONE, Complex64::new(0.5, 0.0));
126    /// let a = a.finalize();
127    ///
128    /// let mut b = BuildAccumulator::<1>::new(2);
129    /// b.add_term(PauliString::<1>::z(0), Phase::ONE, Complex64::new(-1.0, 0.0));
130    /// let b = b.finalize();
131    ///
132    /// // Z₀ cancels exactly; only X₁ survives.
133    /// let r = a.add(&b);
134    /// assert_eq!(r.len(), 1);
135    /// assert_eq!(r.coeff()[0], Complex64::new(0.5, 0.0));
136    /// ```
137    pub fn add(&self, other: &Self) -> Self {
138        debug_assert_eq!(self.num_qubits, other.num_qubits);
139        let n_a = self.x.len();
140        let n_b = other.x.len();
141        let cap = n_a + n_b;
142        let mut x = Vec::with_capacity(cap);
143        let mut z = Vec::with_capacity(cap);
144        let mut coeff = Vec::with_capacity(cap);
145        let zero = Complex64::new(0.0, 0.0);
146        let (mut i, mut j) = (0usize, 0usize);
147        while i < n_a && j < n_b {
148            match (&self.x[i], &self.z[i]).cmp(&(&other.x[j], &other.z[j])) {
149                std::cmp::Ordering::Less => {
150                    x.push(self.x[i]);
151                    z.push(self.z[i]);
152                    coeff.push(self.coeff[i]);
153                    i += 1;
154                }
155                std::cmp::Ordering::Greater => {
156                    x.push(other.x[j]);
157                    z.push(other.z[j]);
158                    coeff.push(other.coeff[j]);
159                    j += 1;
160                }
161                std::cmp::Ordering::Equal => {
162                    let c = self.coeff[i] + other.coeff[j];
163                    if c != zero {
164                        x.push(self.x[i]);
165                        z.push(self.z[i]);
166                        coeff.push(c);
167                    }
168                    i += 1;
169                    j += 1;
170                }
171            }
172        }
173        while i < n_a {
174            x.push(self.x[i]);
175            z.push(self.z[i]);
176            coeff.push(self.coeff[i]);
177            i += 1;
178        }
179        while j < n_b {
180            x.push(other.x[j]);
181            z.push(other.z[j]);
182            coeff.push(other.coeff[j]);
183            j += 1;
184        }
185        Self {
186            x,
187            z,
188            coeff,
189            num_qubits: self.num_qubits,
190        }
191    }
192
193    /// Multiply every coefficient by `c` in place.
194    pub fn scale(&mut self, c: Complex64) {
195        for coeff in self.coeff.iter_mut() {
196            *coeff *= c;
197        }
198    }
199
200    /// Locate a Pauli key by binary search; returns `Ok(idx)` if present,
201    /// `Err(idx)` for the insertion point otherwise.
202    pub fn find(&self, x: &[u64; W], z: &[u64; W]) -> Result<usize, usize> {
203        let mut lo = 0;
204        let mut hi = self.x.len();
205        while lo < hi {
206            let mid = lo + (hi - lo) / 2;
207            match (&self.x[mid], &self.z[mid]).cmp(&(x, z)) {
208                std::cmp::Ordering::Less => lo = mid + 1,
209                std::cmp::Ordering::Greater => hi = mid,
210                std::cmp::Ordering::Equal => return Ok(mid),
211            }
212        }
213        Err(lo)
214    }
215
216    /// Drop terms whose coefficient magnitude is `<= eps`. Preserves sort.
217    pub fn truncate_by_magnitude(&mut self, eps: f64) {
218        let n = self.coeff.len();
219        let mut w = 0;
220        for r in 0..n {
221            if self.coeff[r].norm() > eps {
222                if w != r {
223                    self.x[w] = self.x[r];
224                    self.z[w] = self.z[r];
225                    self.coeff[w] = self.coeff[r];
226                }
227                w += 1;
228            }
229        }
230        self.x.truncate(w);
231        self.z.truncate(w);
232        self.coeff.truncate(w);
233    }
234
235    /// Debug-only invariant check. No-op in release builds.
236    #[cfg(debug_assertions)]
237    pub fn assert_invariants(&self) {
238        assert_eq!(self.x.len(), self.z.len());
239        assert_eq!(self.x.len(), self.coeff.len());
240        for i in 0..self.x.len() {
241            let term = PauliString::<W> {
242                x: self.x[i],
243                z: self.z[i],
244            };
245            assert!(
246                term.is_within(self.num_qubits),
247                "PauliSum term {} has bits beyond num_qubits={}",
248                i,
249                self.num_qubits,
250            );
251        }
252        for i in 1..self.x.len() {
253            let prev = (&self.x[i - 1], &self.z[i - 1]);
254            let cur = (&self.x[i], &self.z[i]);
255            assert!(prev < cur, "PauliSum out of order at {}", i);
256        }
257    }
258}
259
#[cfg(test)]
impl<const W: usize> PauliSum<W> {
    /// Test-only constructor from `(pauli_str, coeff)` pairs.
    ///
    /// Character `i` of each string describes qubit `i` and must be one of
    /// `I`, `X`, `Y`, `Z`. `X` sets the x-bit, `Z` the z-bit, and `Y` sets
    /// both. Each `Y` also adds one `Phase::I` to the phase passed to
    /// `BuildAccumulator::add_term` with the term, since the key
    /// `(x=1, z=1)` stands for `Y` only up to a factor of `i`.
    ///
    /// The qubit count is the length of the first string. Terms go through
    /// `BuildAccumulator`, which the tests in this file rely on to sum
    /// duplicate keys and to drop exact-zero results.
    ///
    /// # Panics
    ///
    /// Panics if `terms` is empty, if the qubit count exceeds `64 * W`, if
    /// any string has a different length from the first, or if a string
    /// contains a character other than `I/X/Y/Z`.
    pub(crate) fn from_strings(terms: &[(&str, Complex64)]) -> Self {
        use crate::phase::Phase;

        assert!(!terms.is_empty(), "from_strings requires at least one term");
        let num_qubits = terms[0].0.len();
        assert!(num_qubits <= 64 * W, "num_qubits must fit in W*64 bits");

        let mut acc = crate::accumulator::BuildAccumulator::<W>::new(num_qubits);
        for (s, c) in terms {
            assert_eq!(
                s.len(),
                num_qubits,
                "all pauli strings must have the same length",
            );

            let mut x = [0u64; W];
            let mut z = [0u64; W];
            let mut phase = Phase::ONE;
            for (qubit, symbol) in s.chars().enumerate() {
                // Decode the symbol into "sets x-bit" / "sets z-bit" flags.
                let (sets_x, sets_z) = match symbol {
                    'I' => (false, false),
                    'X' => (true, false),
                    'Z' => (false, true),
                    'Y' => (true, true),
                    other => panic!("unexpected Pauli char {:?} (expected I/X/Y/Z)", other),
                };
                let word = qubit / 64;
                let mask = 1u64 << (qubit % 64);
                if sets_x {
                    x[word] |= mask;
                }
                if sets_z {
                    z[word] |= mask;
                }
                // Both bits set means `Y`: account for its factor of `i`.
                if sets_x && sets_z {
                    phase += Phase::I;
                }
            }

            acc.add_term(PauliString::<W> { x, z }, phase, *c);
        }
        acc.finalize()
    }
}
308
309#[cfg(all(test, debug_assertions))]
310mod tests {
311    use super::*;
312
313    #[test]
314    fn assert_invariants_accepts_bits_within_num_qubits() {
315        // num_qubits=50, single term with X on qubit 49 (in range).
316        let sum = PauliSum::<1> {
317            x: vec![[1u64 << 49]],
318            z: vec![[0u64; 1]],
319            coeff: vec![Complex64::new(1.0, 0.0)],
320            num_qubits: 50,
321        };
322        sum.assert_invariants();
323    }
324
325    #[test]
326    #[should_panic(expected = "beyond num_qubits")]
327    fn assert_invariants_rejects_bit_beyond_num_qubits() {
328        // num_qubits=50, but X bit set at qubit 50 — must panic.
329        let sum = PauliSum::<1> {
330            x: vec![[1u64 << 50]],
331            z: vec![[0u64; 1]],
332            coeff: vec![Complex64::new(1.0, 0.0)],
333            num_qubits: 50,
334        };
335        sum.assert_invariants();
336    }
337
338    #[test]
339    #[should_panic(expected = "beyond num_qubits")]
340    fn assert_invariants_rejects_z_bit_beyond_num_qubits() {
341        // Same as above but on the Z-part: invariant must check both parts.
342        let sum = PauliSum::<1> {
343            x: vec![[0u64; 1]],
344            z: vec![[1u64 << 60]],
345            coeff: vec![Complex64::new(1.0, 0.0)],
346            num_qubits: 50,
347        };
348        sum.assert_invariants();
349    }
350
351    #[test]
352    #[should_panic(expected = "beyond num_qubits")]
353    fn assert_invariants_rejects_bit_in_unused_word() {
354        // num_qubits=64 (one full word), W=2. Bit on qubit 64 lives in word 1
355        // and is therefore out of range.
356        let sum = PauliSum::<2> {
357            x: vec![[0u64, 1u64]],
358            z: vec![[0u64; 2]],
359            coeff: vec![Complex64::new(1.0, 0.0)],
360            num_qubits: 64,
361        };
362        sum.assert_invariants();
363    }
364
365    // --- Slice 2.1: find() -----------------------------------------------
366
367    /// Three-term `PauliSum<1>` with sorted, distinct keys `K0 < K1 < K2`.
368    /// Used as the fixture for `find` hit/miss tests.
369    fn three_term_sum_w1() -> PauliSum<1> {
370        // K0 = (x=0, z=1), K1 = (x=1, z=0), K2 = (x=1, z=2). Sorted by lex
371        // on (x, z): K0 has smallest x; K1, K2 share x but K1 has smaller z.
372        PauliSum::<1> {
373            x: vec![[0u64], [1u64], [1u64]],
374            z: vec![[1u64], [0u64], [2u64]],
375            coeff: vec![
376                Complex64::new(1.0, 0.0),
377                Complex64::new(2.0, 0.0),
378                Complex64::new(3.0, 0.0),
379            ],
380            num_qubits: 4,
381        }
382    }
383
384    #[test]
385    fn find_on_empty_returns_err_zero() {
386        let s = PauliSum::<1>::empty(4);
387        assert_eq!(s.find(&[0u64], &[0u64]), Err(0));
388    }
389
390    #[test]
391    fn find_hit_at_index_zero() {
392        let s = three_term_sum_w1();
393        assert_eq!(s.find(&[0u64], &[1u64]), Ok(0));
394    }
395
396    #[test]
397    fn find_hit_in_middle() {
398        let s = three_term_sum_w1();
399        assert_eq!(s.find(&[1u64], &[0u64]), Ok(1));
400    }
401
402    #[test]
403    fn find_hit_at_last() {
404        let s = three_term_sum_w1();
405        assert_eq!(s.find(&[1u64], &[2u64]), Ok(2));
406    }
407
408    #[test]
409    fn find_miss_below_min_returns_err_zero() {
410        let s = three_term_sum_w1();
411        // Identity (0, 0) is below K0=(0, 1) under lex.
412        assert_eq!(s.find(&[0u64], &[0u64]), Err(0));
413    }
414
415    #[test]
416    fn find_miss_in_gap_returns_insertion_point() {
417        let s = three_term_sum_w1();
418        // (1, 1) sits between K1=(1,0) and K2=(1,2): insertion point is 2.
419        assert_eq!(s.find(&[1u64], &[1u64]), Err(2));
420    }
421
422    #[test]
423    fn find_miss_above_max_returns_err_len() {
424        let s = three_term_sum_w1();
425        // (2, 0) is above all keys (largest x).
426        assert_eq!(s.find(&[2u64], &[0u64]), Err(3));
427    }
428
429    #[test]
430    fn find_lex_orders_x_before_z() {
431        // Two terms with K_a=(x=0, z=5) and K_b=(x=1, z=0). Despite z_a > z_b,
432        // x_a < x_b, so K_a < K_b. A lex-on-x-only impl would invert this.
433        let s = PauliSum::<1> {
434            x: vec![[0u64], [1u64]],
435            z: vec![[5u64], [0u64]],
436            coeff: vec![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
437            num_qubits: 4,
438        };
439        assert_eq!(s.find(&[0u64], &[5u64]), Ok(0));
440        assert_eq!(s.find(&[1u64], &[0u64]), Ok(1));
441        // (0, 6) is between them under lex(x, z): same x as K_a but larger z.
442        assert_eq!(s.find(&[0u64], &[6u64]), Err(1));
443    }
444
445    #[test]
446    fn find_w2_hit_across_word_boundary() {
447        // W=2 sum, key bits live in word 1.
448        let s = PauliSum::<2> {
449            x: vec![[0u64, 1u64], [0u64, 1u64], [0u64, 2u64]],
450            z: vec![[0u64, 0u64], [0u64, 4u64], [0u64, 0u64]],
451            coeff: vec![
452                Complex64::new(1.0, 0.0),
453                Complex64::new(2.0, 0.0),
454                Complex64::new(3.0, 0.0),
455            ],
456            num_qubits: 128,
457        };
458        s.assert_invariants();
459        assert_eq!(s.find(&[0u64, 1u64], &[0u64, 4u64]), Ok(1));
460        assert_eq!(s.find(&[0u64, 2u64], &[0u64, 0u64]), Ok(2));
461        // Miss between idx 0 and 1: same x, z between 0 and 4.
462        assert_eq!(s.find(&[0u64, 1u64], &[0u64, 1u64]), Err(1));
463    }
464
465    // --- Slice 2.2: scale() ----------------------------------------------
466
467    #[test]
468    fn scale_by_zero_zeros_all_coeffs() {
469        let mut s = three_term_sum_w1();
470        s.scale(Complex64::new(0.0, 0.0));
471        assert_eq!(s.len(), 3);
472        for c in s.coeff() {
473            assert_eq!(*c, Complex64::new(0.0, 0.0));
474        }
475        s.assert_invariants();
476    }
477
478    #[test]
479    fn scale_by_one_is_identity() {
480        let mut s = three_term_sum_w1();
481        let before: Vec<Complex64> = s.coeff().to_vec();
482        s.scale(Complex64::new(1.0, 0.0));
483        assert_eq!(s.coeff(), before.as_slice());
484    }
485
486    #[test]
487    fn scale_by_i_rotates_phases() {
488        let mut s = PauliSum::<1> {
489            x: vec![[0u64], [1u64]],
490            z: vec![[1u64], [0u64]],
491            coeff: vec![Complex64::new(2.0, 0.0), Complex64::new(0.0, -3.0)],
492            num_qubits: 4,
493        };
494        s.scale(Complex64::new(0.0, 1.0));
495        // (2 + 0i) * i = 0 + 2i; (0 - 3i) * i = 3 + 0i.
496        assert_eq!(s.coeff()[0], Complex64::new(0.0, 2.0));
497        assert_eq!(s.coeff()[1], Complex64::new(3.0, 0.0));
498    }
499
500    #[test]
501    fn scale_preserves_sort_invariant() {
502        let mut s = three_term_sum_w1();
503        s.scale(Complex64::new(2.5, -0.5));
504        s.assert_invariants();
505    }
506
507    // --- Slice 2.3: truncate_by_magnitude() ------------------------------
508
509    #[test]
510    fn truncate_eps_zero_is_noop_on_nonzero_terms() {
511        let mut s = three_term_sum_w1();
512        let before_len = s.len();
513        s.truncate_by_magnitude(0.0);
514        assert_eq!(s.len(), before_len);
515        s.assert_invariants();
516    }
517
518    #[test]
519    fn truncate_eps_above_max_empties() {
520        let mut s = three_term_sum_w1();
521        s.truncate_by_magnitude(10.0);
522        assert!(s.is_empty());
523        s.assert_invariants();
524    }
525
526    #[test]
527    fn truncate_mixed_drops_only_below_threshold() {
528        // Four sorted terms with magnitudes [0.1, 0.5, 1.0, 0.05]; eps=0.2
529        // keeps 0.5 and 1.0 (originally at indices 1 and 2).
530        let mut s = PauliSum::<1> {
531            x: vec![[0u64], [0u64], [1u64], [1u64]],
532            z: vec![[1u64], [2u64], [0u64], [1u64]],
533            coeff: vec![
534                Complex64::new(0.1, 0.0),
535                Complex64::new(0.5, 0.0),
536                Complex64::new(1.0, 0.0),
537                Complex64::new(0.05, 0.0),
538            ],
539            num_qubits: 4,
540        };
541        s.assert_invariants();
542        s.truncate_by_magnitude(0.2);
543        assert_eq!(s.len(), 2);
544        assert_eq!(s.x()[0], [0u64]);
545        assert_eq!(s.z()[0], [2u64]);
546        assert_eq!(s.coeff()[0], Complex64::new(0.5, 0.0));
547        assert_eq!(s.x()[1], [1u64]);
548        assert_eq!(s.z()[1], [0u64]);
549        assert_eq!(s.coeff()[1], Complex64::new(1.0, 0.0));
550        s.assert_invariants();
551    }
552
553    #[test]
554    fn truncate_drops_exact_zero_at_eps_zero() {
555        // Include an exact (0+0i) term; eps=0 should drop only that one.
556        let mut s = PauliSum::<1> {
557            x: vec![[0u64], [1u64], [1u64]],
558            z: vec![[1u64], [0u64], [2u64]],
559            coeff: vec![
560                Complex64::new(1.0, 0.0),
561                Complex64::new(0.0, 0.0),
562                Complex64::new(2.0, 0.0),
563            ],
564            num_qubits: 4,
565        };
566        s.truncate_by_magnitude(0.0);
567        assert_eq!(s.len(), 2);
568        assert_eq!(s.x()[0], [0u64]);
569        assert_eq!(s.x()[1], [1u64]);
570        assert_eq!(s.z()[1], [2u64]);
571        s.assert_invariants();
572    }
573
574    #[test]
575    fn truncate_w2_preserves_sort() {
576        let mut s = PauliSum::<2> {
577            x: vec![[0u64, 0u64], [0u64, 1u64], [1u64, 0u64]],
578            z: vec![[0u64, 1u64], [0u64, 0u64], [0u64, 0u64]],
579            coeff: vec![
580                Complex64::new(0.01, 0.0),
581                Complex64::new(2.0, 0.0),
582                Complex64::new(0.005, 0.0),
583            ],
584            num_qubits: 128,
585        };
586        s.assert_invariants();
587        s.truncate_by_magnitude(0.1);
588        assert_eq!(s.len(), 1);
589        assert_eq!(s.x()[0], [0u64, 1u64]);
590        s.assert_invariants();
591    }
592
593    // --- Slice 2.4: add() ------------------------------------------------
594
595    #[test]
596    fn add_empty_left_is_other() {
597        let a = PauliSum::<1>::empty(4);
598        let b = three_term_sum_w1();
599        let r = a.add(&b);
600        assert_eq!(r.len(), 3);
601        assert_eq!(r.x(), b.x());
602        assert_eq!(r.z(), b.z());
603        assert_eq!(r.coeff(), b.coeff());
604        r.assert_invariants();
605    }
606
607    #[test]
608    fn add_empty_right_is_self() {
609        let a = three_term_sum_w1();
610        let b = PauliSum::<1>::empty(4);
611        let r = a.add(&b);
612        assert_eq!(r.len(), 3);
613        assert_eq!(r.x(), a.x());
614        assert_eq!(r.z(), a.z());
615        assert_eq!(r.coeff(), a.coeff());
616        r.assert_invariants();
617    }
618
619    #[test]
620    fn add_disjoint_keys_interleaves_in_sort_order() {
621        // a has K0=(0,1), K2=(1,2); b has K1=(1,0), K3=(2,0).
622        // Lex sort across the union: (0,1) < (1,0) < (1,2) < (2,0).
623        let a = PauliSum::<1> {
624            x: vec![[0u64], [1u64]],
625            z: vec![[1u64], [2u64]],
626            coeff: vec![Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
627            num_qubits: 4,
628        };
629        let b = PauliSum::<1> {
630            x: vec![[1u64], [2u64]],
631            z: vec![[0u64], [0u64]],
632            coeff: vec![Complex64::new(2.0, 0.0), Complex64::new(4.0, 0.0)],
633            num_qubits: 4,
634        };
635        let r = a.add(&b);
636        assert_eq!(r.len(), 4);
637        assert_eq!(r.x(), &[[0u64], [1u64], [1u64], [2u64]][..]);
638        assert_eq!(r.z(), &[[1u64], [0u64], [2u64], [0u64]][..]);
639        assert_eq!(
640            r.coeff(),
641            &[
642                Complex64::new(1.0, 0.0),
643                Complex64::new(2.0, 0.0),
644                Complex64::new(3.0, 0.0),
645                Complex64::new(4.0, 0.0),
646            ][..]
647        );
648        r.assert_invariants();
649    }
650
651    #[test]
652    fn add_equal_keys_sum_coeffs() {
653        let a = three_term_sum_w1();
654        let r = a.add(&a);
655        assert_eq!(r.len(), 3);
656        assert_eq!(r.x(), a.x());
657        assert_eq!(r.z(), a.z());
658        for k in 0..3 {
659            assert_eq!(r.coeff()[k], a.coeff()[k] * Complex64::new(2.0, 0.0));
660        }
661        r.assert_invariants();
662    }
663
664    #[test]
665    fn add_cancellation_drops_term() {
666        let a = PauliSum::<1> {
667            x: vec![[1u64]],
668            z: vec![[0u64]],
669            coeff: vec![Complex64::new(1.0, 0.0)],
670            num_qubits: 4,
671        };
672        let b = PauliSum::<1> {
673            x: vec![[1u64]],
674            z: vec![[0u64]],
675            coeff: vec![Complex64::new(-1.0, 0.0)],
676            num_qubits: 4,
677        };
678        let r = a.add(&b);
679        assert!(r.is_empty());
680        r.assert_invariants();
681    }
682
683    #[test]
684    fn add_mixed_cancellation_and_merge() {
685        // a = {K1: 1, K2: 2, K3: 3}, b = {K1: -1, K2: 0.5, K4: 4}
686        // K1 cancels, K2 sums to 2.5, K3 unique to a, K4 unique to b.
687        let a = PauliSum::<1> {
688            x: vec![[0u64], [1u64], [2u64]],
689            z: vec![[0u64], [0u64], [0u64]],
690            coeff: vec![
691                Complex64::new(1.0, 0.0),
692                Complex64::new(2.0, 0.0),
693                Complex64::new(3.0, 0.0),
694            ],
695            num_qubits: 4,
696        };
697        let b = PauliSum::<1> {
698            x: vec![[0u64], [1u64], [3u64]],
699            z: vec![[0u64], [0u64], [0u64]],
700            coeff: vec![
701                Complex64::new(-1.0, 0.0),
702                Complex64::new(0.5, 0.0),
703                Complex64::new(4.0, 0.0),
704            ],
705            num_qubits: 4,
706        };
707        let r = a.add(&b);
708        assert_eq!(r.len(), 3);
709        assert_eq!(r.x(), &[[1u64], [2u64], [3u64]][..]);
710        assert_eq!(r.z(), &[[0u64], [0u64], [0u64]][..]);
711        assert_eq!(
712            r.coeff(),
713            &[
714                Complex64::new(2.5, 0.0),
715                Complex64::new(3.0, 0.0),
716                Complex64::new(4.0, 0.0),
717            ][..]
718        );
719        r.assert_invariants();
720    }
721
722    #[test]
723    fn add_w2_across_word_boundary() {
724        let a = PauliSum::<2> {
725            x: vec![[0u64, 1u64], [0u64, 2u64]],
726            z: vec![[0u64, 0u64], [0u64, 0u64]],
727            coeff: vec![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
728            num_qubits: 128,
729        };
730        let b = PauliSum::<2> {
731            x: vec![[0u64, 1u64], [0u64, 4u64]],
732            z: vec![[0u64, 0u64], [0u64, 0u64]],
733            coeff: vec![Complex64::new(0.5, 0.0), Complex64::new(7.0, 0.0)],
734            num_qubits: 128,
735        };
736        let r = a.add(&b);
737        assert_eq!(r.len(), 3);
738        assert_eq!(r.x(), &[[0u64, 1u64], [0u64, 2u64], [0u64, 4u64]][..]);
739        assert_eq!(r.coeff()[0], Complex64::new(1.5, 0.0));
740        assert_eq!(r.coeff()[1], Complex64::new(2.0, 0.0));
741        assert_eq!(r.coeff()[2], Complex64::new(7.0, 0.0));
742        r.assert_invariants();
743    }
744
745    // --- Slice 3.2: PauliSum::from_strings test helper -------------------
746
747    #[test]
748    fn from_strings_single_x_term() {
749        let s = PauliSum::<1>::from_strings(&[("XII", Complex64::new(1.0, 0.0))]);
750        assert_eq!(s.len(), 1);
751        assert_eq!(s.num_qubits(), 3);
752        assert_eq!(s.x()[0], [0b001u64]);
753        assert_eq!(s.z()[0], [0u64]);
754        assert_eq!(s.coeff()[0], Complex64::new(1.0, 0.0));
755        s.assert_invariants();
756    }
757
758    #[test]
759    fn from_strings_x_z_combined() {
760        // "XZI": X on qubit 0, Z on qubit 1, I on qubit 2.
761        let s = PauliSum::<1>::from_strings(&[("XZI", Complex64::new(1.0, 0.0))]);
762        assert_eq!(s.x()[0], [0b001u64]);
763        assert_eq!(s.z()[0], [0b010u64]);
764        s.assert_invariants();
765    }
766
767    #[test]
768    fn from_strings_y_includes_i_phase() {
769        // Y_canonical = i · (x=1, z=1). Caller writes coeff=1, stored is i.
770        let s = PauliSum::<1>::from_strings(&[("Y", Complex64::new(1.0, 0.0))]);
771        assert_eq!(s.x()[0], [1u64]);
772        assert_eq!(s.z()[0], [1u64]);
773        assert_eq!(s.coeff()[0], Complex64::new(0.0, 1.0));
774    }
775
776    #[test]
777    fn from_strings_yy_phase_minus_one() {
778        // i^2 = -1.
779        let s = PauliSum::<1>::from_strings(&[("YY", Complex64::new(1.0, 0.0))]);
780        assert_eq!(s.coeff()[0], Complex64::new(-1.0, 0.0));
781    }
782
783    #[test]
784    fn from_strings_yyy_phase_minus_i() {
785        // i^3 = -i.
786        let s = PauliSum::<1>::from_strings(&[("YYY", Complex64::new(1.0, 0.0))]);
787        assert_eq!(s.coeff()[0], Complex64::new(0.0, -1.0));
788    }
789
790    #[test]
791    fn from_strings_yyyy_phase_one() {
792        // i^4 = 1.
793        let s = PauliSum::<1>::from_strings(&[("YYYY", Complex64::new(1.0, 0.0))]);
794        assert_eq!(s.coeff()[0], Complex64::new(1.0, 0.0));
795    }
796
797    #[test]
798    fn from_strings_dedup_sums_coeffs() {
799        let s = PauliSum::<1>::from_strings(&[
800            ("XI", Complex64::new(1.0, 0.0)),
801            ("XI", Complex64::new(0.5, -0.25)),
802        ]);
803        assert_eq!(s.len(), 1);
804        assert_eq!(s.coeff()[0], Complex64::new(1.5, -0.25));
805        s.assert_invariants();
806    }
807
808    #[test]
809    fn from_strings_cancellation_drops_term() {
810        let s = PauliSum::<1>::from_strings(&[
811            ("XI", Complex64::new(1.0, 0.0)),
812            ("XI", Complex64::new(-1.0, 0.0)),
813            ("ZI", Complex64::new(2.0, 0.0)),
814        ]);
815        assert_eq!(s.len(), 1);
816        assert_eq!(s.x()[0], [0u64]);
817        assert_eq!(s.z()[0], [1u64]);
818        assert_eq!(s.coeff()[0], Complex64::new(2.0, 0.0));
819        s.assert_invariants();
820    }
821
822    #[test]
823    fn from_strings_sorts_lex_keys() {
824        // Insert out of order: ZI=(0,1), XI=(1,0), YI=(1,1) — lex sorted is
825        // ZI < XI < YI.
826        let s = PauliSum::<1>::from_strings(&[
827            ("YI", Complex64::new(1.0, 0.0)),
828            ("ZI", Complex64::new(2.0, 0.0)),
829            ("XI", Complex64::new(3.0, 0.0)),
830        ]);
831        assert_eq!(s.len(), 3);
832        assert_eq!((s.x()[0], s.z()[0]), ([0u64], [1u64])); // ZI
833        assert_eq!((s.x()[1], s.z()[1]), ([1u64], [0u64])); // XI
834        assert_eq!((s.x()[2], s.z()[2]), ([1u64], [1u64])); // YI (with i factor)
835        assert_eq!(s.coeff()[0], Complex64::new(2.0, 0.0));
836        assert_eq!(s.coeff()[1], Complex64::new(3.0, 0.0));
837        assert_eq!(s.coeff()[2], Complex64::new(0.0, 1.0));
838        s.assert_invariants();
839    }
840
841    #[test]
842    fn from_strings_w2_qubit_64() {
843        // 65-character string: X at index 64 lands in word 1.
844        let mut s_chars: String = "I".repeat(65);
845        // Replace index 64 with 'X'.
846        unsafe {
847            let bytes = s_chars.as_bytes_mut();
848            bytes[64] = b'X';
849        }
850        let s = PauliSum::<2>::from_strings(&[(s_chars.as_str(), Complex64::new(1.0, 0.0))]);
851        assert_eq!(s.num_qubits(), 65);
852        assert_eq!(s.x()[0], [0u64, 1u64]);
853        assert_eq!(s.z()[0], [0u64, 0u64]);
854        s.assert_invariants();
855    }
856
857    #[test]
858    #[should_panic(expected = "unexpected Pauli char")]
859    fn from_strings_panics_on_invalid_char() {
860        let _ = PauliSum::<1>::from_strings(&[("AB", Complex64::new(1.0, 0.0))]);
861    }
862
863    #[test]
864    #[should_panic(expected = "all pauli strings must have the same length")]
865    fn from_strings_panics_on_length_mismatch() {
866        let _ = PauliSum::<1>::from_strings(&[
867            ("XI", Complex64::new(1.0, 0.0)),
868            ("XII", Complex64::new(1.0, 0.0)),
869        ]);
870    }
871}
872
873#[cfg(all(test, debug_assertions))]
874mod props {
875    use super::*;
876    use proptest::prelude::*;
877    use std::collections::BTreeMap;
878
879    /// Build a sorted, deduplicated `PauliSum<2>` from random `(x, z, coeff)`
880    /// triples. Uses `BTreeMap` keyed on `(x, z)` to enforce the sorted /
881    /// unique invariant before SoA materialization. Coefficients are kept
882    /// small (`re, im ∈ [-4, 4]`) so the property assertions don't run into
883    /// FP cancellation noise. Length capped at 8 — sufficient to exercise
884    /// merge interleaving without blowing up shrinking time.
885    fn arb_pauli_sum_w2() -> impl Strategy<Value = PauliSum<2>> {
886        prop::collection::vec(
887            (
888                any::<u64>(),
889                any::<u64>(),
890                any::<u64>(),
891                any::<u64>(),
892                -4.0f64..4.0,
893                -4.0f64..4.0,
894            ),
895            0..8,
896        )
897        .prop_map(|entries| {
898            let mut map: BTreeMap<([u64; 2], [u64; 2]), Complex64> = BTreeMap::new();
899            for (x0, x1, z0, z1, re, im) in entries {
900                map.insert(([x0, x1], [z0, z1]), Complex64::new(re, im));
901            }
902            let mut x = Vec::with_capacity(map.len());
903            let mut z = Vec::with_capacity(map.len());
904            let mut coeff = Vec::with_capacity(map.len());
905            for ((kx, kz), c) in map {
906                x.push(kx);
907                z.push(kz);
908                coeff.push(c);
909            }
910            PauliSum::<2> {
911                x,
912                z,
913                coeff,
914                num_qubits: 128,
915            }
916        })
917    }
918
919    proptest! {
920        #[test]
921        fn add_is_associative(
922            a in arb_pauli_sum_w2(),
923            b in arb_pauli_sum_w2(),
924            c in arb_pauli_sum_w2(),
925        ) {
926            let left = a.add(&b).add(&c);
927            let right = a.add(&b.add(&c));
928            left.assert_invariants();
929            right.assert_invariants();
930            prop_assert_eq!(left.x(), right.x());
931            prop_assert_eq!(left.z(), right.z());
932            prop_assert_eq!(left.coeff().len(), right.coeff().len());
933            for k in 0..left.coeff().len() {
934                let diff = left.coeff()[k] - right.coeff()[k];
935                prop_assert!(
936                    diff.norm() <= 1e-12,
937                    "coeff mismatch at idx {}: lhs={:?} rhs={:?}",
938                    k, left.coeff()[k], right.coeff()[k]
939                );
940            }
941        }
942    }
943}