// paulistrings/truncation/builtin.rs

1//! Built-in truncation policies and combinators. See design doc §7.
2
3#![allow(unused)]
4
5use super::TruncationPolicy;
6use crate::pauli_sum::PauliSum;
7use num_complex::Complex64;
8
9/// Drop terms whose coefficient magnitude is at most `epsilon`.
10///
11/// # Examples
12///
13/// ```
14/// use paulistrings::truncation::CoefficientThreshold;
15/// let policy = CoefficientThreshold(1e-9);
16/// # let _ = policy;
17/// ```
18pub struct CoefficientThreshold(
19    /// Magnitude threshold. Terms with `|coeff| <= epsilon` are dropped.
20    pub f64,
21);
22
23impl<const W: usize> TruncationPolicy<W> for CoefficientThreshold {
24    #[inline]
25    fn keep_term(&self, _x: &[u64; W], _z: &[u64; W], c: Complex64) -> bool {
26        c.norm() > self.0
27    }
28}
29
30/// Drop terms whose Pauli weight (number of non-identity qubits) exceeds `k`.
31///
32/// # Examples
33///
34/// ```
35/// use paulistrings::truncation::WeightCutoff;
36/// let policy = WeightCutoff(4);
37/// # let _ = policy;
38/// ```
39pub struct WeightCutoff(
40    /// Maximum allowed Pauli weight. Terms with weight `> k` are dropped.
41    pub u32,
42);
43
44impl<const W: usize> TruncationPolicy<W> for WeightCutoff {
45    #[inline]
46    fn keep_term(&self, x: &[u64; W], z: &[u64; W], _c: Complex64) -> bool {
47        let weight: u32 = (0..W).map(|i| (x[i] | z[i]).count_ones()).sum();
48        weight <= self.0
49    }
50}
51
52/// Retain only the `n` terms with largest coefficient magnitude. Implemented
53/// as a `finalize_layer` partial sort (no per-term filter).
54///
55/// # Examples
56///
57/// ```
58/// use paulistrings::truncation::TopN;
59/// let policy = TopN(1_000_000);
60/// # let _ = policy;
61/// ```
62pub struct TopN(
63    /// Number of terms to retain. Terms outside the top-`n` by magnitude
64    /// are dropped at layer finalization.
65    pub usize,
66);
67
68impl<const W: usize> TruncationPolicy<W> for TopN {
69    fn finalize_layer(&self, sum: &mut PauliSum<W>) {
70        let n = self.0;
71        let len = sum.coeff.len();
72        if len <= n {
73            return;
74        }
75        if n == 0 {
76            sum.x.clear();
77            sum.z.clear();
78            sum.coeff.clear();
79            return;
80        }
81        let mut perm: Vec<usize> = (0..len).collect();
82        // Partition descending by |coeff|: indices [0..n) hold the n largest.
83        perm.select_nth_unstable_by(n - 1, |&a, &b| {
84            sum.coeff[b]
85                .norm()
86                .partial_cmp(&sum.coeff[a].norm())
87                .unwrap()
88        });
89        perm.truncate(n);
90        // The survivors of an already-sorted sum are still sorted once we
91        // restore their original index order.
92        perm.sort_unstable();
93        let new_x: Vec<[u64; W]> = perm.iter().map(|&i| sum.x[i]).collect();
94        let new_z: Vec<[u64; W]> = perm.iter().map(|&i| sum.z[i]).collect();
95        let new_c: Vec<Complex64> = perm.iter().map(|&i| sum.coeff[i]).collect();
96        sum.x = new_x;
97        sum.z = new_z;
98        sum.coeff = new_c;
99    }
100}
101
102/// Logical AND of two policies — both must accept.
103///
104/// # Examples
105///
106/// ```
107/// use paulistrings::truncation::{And, CoefficientThreshold, WeightCutoff};
108/// let policy = And(CoefficientThreshold(1e-6), WeightCutoff(4));
109/// # let _ = policy;
110/// ```
111pub struct And<A, B>(
112    /// First policy. `keep_term` and `finalize_layer` both consult this first.
113    pub A,
114    /// Second policy.
115    pub B,
116);
117
118impl<const W: usize, A, B> TruncationPolicy<W> for And<A, B>
119where
120    A: TruncationPolicy<W>,
121    B: TruncationPolicy<W>,
122{
123    #[inline]
124    fn keep_term(&self, x: &[u64; W], z: &[u64; W], c: Complex64) -> bool {
125        self.0.keep_term(x, z, c) && self.1.keep_term(x, z, c)
126    }
127
128    fn finalize_layer(&self, sum: &mut PauliSum<W>) {
129        self.0.finalize_layer(sum);
130        self.1.finalize_layer(sum);
131    }
132}
133
134/// Logical OR of two policies — either accepting is enough.
135///
136/// Only `keep_term` is combined disjunctively; `finalize_layer` falls through
137/// to the trait default (no-op) because the layer-finalization semantics of
138/// "either policy's finalize pass" are not well-defined.
139///
140/// # Examples
141///
142/// ```
143/// use paulistrings::truncation::{Or, CoefficientThreshold, WeightCutoff};
144/// // Keep a term if |coeff| > 0.1 OR weight == 0 (identity).
145/// let policy = Or(CoefficientThreshold(0.1), WeightCutoff(0));
146/// # let _ = policy;
147/// ```
148pub struct Or<A, B>(
149    /// First policy.
150    pub A,
151    /// Second policy.
152    pub B,
153);
154
155impl<const W: usize, A, B> TruncationPolicy<W> for Or<A, B>
156where
157    A: TruncationPolicy<W>,
158    B: TruncationPolicy<W>,
159{
160    #[inline]
161    fn keep_term(&self, x: &[u64; W], z: &[u64; W], c: Complex64) -> bool {
162        self.0.keep_term(x, z, c) || self.1.keep_term(x, z, c)
163    }
164}
165
#[cfg(all(test, debug_assertions))]
mod tests {
    use super::*;

    /// Purely real coefficient `v + 0i`.
    fn real(v: f64) -> Complex64 {
        Complex64::new(v, 0.0)
    }

    /// Asks `policy` whether it keeps the key `(x, z)` with coefficient
    /// `re + 0i`. `W` is inferred from the array arguments, so call sites need
    /// not spell out `<P as TruncationPolicy<W>>::keep_term`.
    fn keeps<const W: usize, P: TruncationPolicy<W>>(
        policy: &P,
        x: [u64; W],
        z: [u64; W],
        re: f64,
    ) -> bool {
        policy.keep_term(&x, &z, real(re))
    }

    /// Slice 7.2: `WeightCutoff(2)` keeps weights 0, 1, 2 and drops 3.
    #[test]
    fn weight_cutoff_keeps_below_or_equal() {
        let cut = WeightCutoff(2);
        // Identity: weight 0.
        assert!(keeps(&cut, [0], [0], 1.0));
        // X on q0 (x bit only): weight 1.
        assert!(keeps(&cut, [1], [0], 1.0));
        // X on q0, Z on q1: weight 2.
        assert!(keeps(&cut, [0b01], [0b10], 1.0));
        // X on q0, Y on q1 (x and z bits), Z on q2: weight 3, dropped.
        assert!(!keeps(&cut, [0b011], [0b110], 1.0));
    }

    /// Slice 7.2: `WeightCutoff(0)` keeps only the identity.
    #[test]
    fn weight_cutoff_zero_keeps_only_identity() {
        let cut = WeightCutoff(0);
        assert!(keeps(&cut, [0], [0], 1.0));
        // X, Z and Y on q0 are all non-identity and therefore dropped.
        assert!(!keeps(&cut, [1], [0], 1.0));
        assert!(!keeps(&cut, [0], [1], 1.0));
        assert!(!keeps(&cut, [1], [1], 1.0));
    }

    /// Slice 7.2: multi-word popcount. Qubit 64 lives in word 1, bit 0.
    #[test]
    fn weight_cutoff_w2_word_boundary() {
        let cut = WeightCutoff(1);
        // X on qubit 64 alone: weight 1, kept.
        assert!(keeps(&cut, [0, 1], [0, 0], 1.0));
        // X on qubit 0 and on qubit 64: weight 2, dropped.
        assert!(!keeps(&cut, [1, 1], [0, 0], 1.0));
    }

    /// Slice 7.3: ten distinct keys with decreasing |coeff| (10, 9, …, 1);
    /// `TopN(3)` keeps the three with magnitudes 10, 9, 8.
    #[test]
    fn top_n_keeps_largest_three_of_ten() {
        // Keys (x, 0) with x ∈ {1..=10}, sorted by x ascending. Magnitudes run
        // 10, 9, …, 1 in that same order, so the largest sit at the *front*
        // of the sort order; back-loaded magnitudes are covered by
        // `top_n_preserves_sort_order`.
        let mut sum = PauliSum::<1> {
            x: (1u64..=10).map(|i| [i]).collect(),
            z: vec![[0u64]; 10],
            coeff: (1u32..=10).rev().map(|m| real(f64::from(m))).collect(),
            num_qubits: 4,
        };
        sum.assert_invariants();

        TopN(3).finalize_layer(&mut sum);

        assert_eq!(sum.len(), 3);
        // Magnitudes 10, 9, 8 sat at indices 0, 1, 2, i.e. x = [1], [2], [3].
        assert_eq!(sum.x(), &[[1u64], [2u64], [3u64]]);
        let magnitudes: Vec<f64> = sum.coeff().iter().map(|c| c.norm()).collect();
        assert_eq!(magnitudes, vec![10.0, 9.0, 8.0]);
        sum.assert_invariants();
    }

    /// Slice 7.3: `TopN(N) where N >= len` is a no-op.
    #[test]
    fn top_n_no_op_when_n_ge_len() {
        let mut sum = PauliSum::<1> {
            x: vec![[0], [0], [1]],
            z: vec![[0], [1], [0]],
            coeff: vec![real(1.0), real(2.0), real(3.0)],
            num_qubits: 1,
        };
        let (before_x, before_z, before_c) =
            (sum.x().to_vec(), sum.z().to_vec(), sum.coeff().to_vec());

        TopN(5).finalize_layer(&mut sum);

        assert_eq!(sum.x(), before_x.as_slice());
        assert_eq!(sum.z(), before_z.as_slice());
        assert_eq!(sum.coeff(), before_c.as_slice());
    }

    /// Slice 7.3: `TopN(0)` empties the sum.
    #[test]
    fn top_n_zero_empties_sum() {
        let mut sum = PauliSum::<1> {
            x: vec![[0], [1]],
            z: vec![[1], [0]],
            coeff: vec![real(1.0), real(2.0)],
            num_qubits: 1,
        };

        TopN(0).finalize_layer(&mut sum);

        assert!(sum.is_empty());
        sum.assert_invariants();
    }

    /// Slice 7.3: the largest coefficients sit at the *end* of the sort order;
    /// survivors must come back in (x, z) order, not magnitude order.
    #[test]
    fn top_n_preserves_sort_order() {
        // Five keys with back-loaded magnitudes 1, 2, 3, 4, 5.
        let mut sum = PauliSum::<1> {
            x: vec![[1], [2], [3], [4], [5]],
            z: vec![[0]; 5],
            coeff: vec![real(1.0), real(2.0), real(3.0), real(4.0), real(5.0)],
            num_qubits: 4,
        };
        sum.assert_invariants();

        TopN(3).finalize_layer(&mut sum);

        assert_eq!(sum.len(), 3);
        // Magnitudes 5, 4, 3 belong to x = [5], [4], [3]; preserving the sort
        // order yields [3], [4], [5].
        assert_eq!(sum.x(), &[[3u64], [4u64], [5u64]]);
        assert_eq!(sum.coeff(), &[real(3.0), real(4.0), real(5.0)]);
        sum.assert_invariants();
    }

    /// `And` requires both policies to accept: only terms passing the
    /// coefficient threshold *and* the weight cutoff survive.
    #[test]
    fn and_requires_both_keep() {
        let policy = And(CoefficientThreshold(0.5), WeightCutoff(1));
        // (X, 1.0): |c| = 1.0 > 0.5 and weight 1 <= 1 → kept.
        assert!(keeps(&policy, [1], [0], 1.0));
        // (X, 0.1): |c| = 0.1 <= 0.5 → dropped.
        assert!(!keeps(&policy, [1], [0], 0.1));
        // (XZ, 1.0): weight 2 > 1 → dropped.
        assert!(!keeps(&policy, [0b01], [0b10], 1.0));
    }

    /// `Or` accepts if *either* policy accepts.
    #[test]
    fn or_keeps_if_either() {
        let policy = Or(CoefficientThreshold(0.5), WeightCutoff(0));
        // (I, 0.1): magnitude fails, identity weight 0 passes → kept.
        assert!(keeps(&policy, [0], [0], 0.1));
        // (X, 1.0): weight fails, |c| = 1.0 > 0.5 passes → kept.
        assert!(keeps(&policy, [1], [0], 1.0));
        // (X, 0.1): both fail → dropped.
        assert!(!keeps(&policy, [1], [0], 0.1));
    }
}