1#![allow(unused)]
4
5use super::TruncationPolicy;
6use crate::pauli_sum::PauliSum;
7use num_complex::Complex64;
8
9pub struct CoefficientThreshold(
19 pub f64,
21);
22
23impl<const W: usize> TruncationPolicy<W> for CoefficientThreshold {
24 #[inline]
25 fn keep_term(&self, _x: &[u64; W], _z: &[u64; W], c: Complex64) -> bool {
26 c.norm() > self.0
27 }
28}
29
30pub struct WeightCutoff(
40 pub u32,
42);
43
44impl<const W: usize> TruncationPolicy<W> for WeightCutoff {
45 #[inline]
46 fn keep_term(&self, x: &[u64; W], z: &[u64; W], _c: Complex64) -> bool {
47 let weight: u32 = (0..W).map(|i| (x[i] | z[i]).count_ones()).sum();
48 weight <= self.0
49 }
50}
51
52pub struct TopN(
63 pub usize,
66);
67
68impl<const W: usize> TruncationPolicy<W> for TopN {
69 fn finalize_layer(&self, sum: &mut PauliSum<W>) {
70 let n = self.0;
71 let len = sum.coeff.len();
72 if len <= n {
73 return;
74 }
75 if n == 0 {
76 sum.x.clear();
77 sum.z.clear();
78 sum.coeff.clear();
79 return;
80 }
81 let mut perm: Vec<usize> = (0..len).collect();
82 perm.select_nth_unstable_by(n - 1, |&a, &b| {
84 sum.coeff[b]
85 .norm()
86 .partial_cmp(&sum.coeff[a].norm())
87 .unwrap()
88 });
89 perm.truncate(n);
90 perm.sort_unstable();
93 let new_x: Vec<[u64; W]> = perm.iter().map(|&i| sum.x[i]).collect();
94 let new_z: Vec<[u64; W]> = perm.iter().map(|&i| sum.z[i]).collect();
95 let new_c: Vec<Complex64> = perm.iter().map(|&i| sum.coeff[i]).collect();
96 sum.x = new_x;
97 sum.z = new_z;
98 sum.coeff = new_c;
99 }
100}
101
102pub struct And<A, B>(
112 pub A,
114 pub B,
116);
117
118impl<const W: usize, A, B> TruncationPolicy<W> for And<A, B>
119where
120 A: TruncationPolicy<W>,
121 B: TruncationPolicy<W>,
122{
123 #[inline]
124 fn keep_term(&self, x: &[u64; W], z: &[u64; W], c: Complex64) -> bool {
125 self.0.keep_term(x, z, c) && self.1.keep_term(x, z, c)
126 }
127
128 fn finalize_layer(&self, sum: &mut PauliSum<W>) {
129 self.0.finalize_layer(sum);
130 self.1.finalize_layer(sum);
131 }
132}
133
134pub struct Or<A, B>(
149 pub A,
151 pub B,
153);
154
155impl<const W: usize, A, B> TruncationPolicy<W> for Or<A, B>
156where
157 A: TruncationPolicy<W>,
158 B: TruncationPolicy<W>,
159{
160 #[inline]
161 fn keep_term(&self, x: &[u64; W], z: &[u64; W], c: Complex64) -> bool {
162 self.0.keep_term(x, z, c) || self.1.keep_term(x, z, c)
163 }
164}
165
// Gated on `debug_assertions` as well as `test`, matching the original gate.
// NOTE(review): presumably because `PauliSum::assert_invariants` is only
// meaningful (or only available) in debug builds — confirm.
#[cfg(all(test, debug_assertions))]
mod tests {
    use super::*;

    /// Evaluates `policy.keep_term` on a single-word (`W = 1`) term whose
    /// coefficient is the real number `re`.
    fn keeps1<P: TruncationPolicy<1>>(policy: &P, x: u64, z: u64, re: f64) -> bool {
        policy.keep_term(&[x], &[z], Complex64::new(re, 0.0))
    }

    /// The magnitude must be strictly greater than the threshold, and the
    /// imaginary part contributes to the magnitude.
    #[test]
    fn coefficient_threshold_is_strict_on_magnitude() {
        let cut = CoefficientThreshold(5.0);
        let keeps = |c: Complex64| {
            <CoefficientThreshold as TruncationPolicy<1>>::keep_term(&cut, &[0], &[0], c)
        };
        assert!(keeps(Complex64::new(6.0, 0.0)));
        assert!(keeps(Complex64::new(0.0, -6.0)));
        assert!(keeps(Complex64::new(3.0, 4.5)));
        assert!(!keeps(Complex64::new(5.0, 0.0)));
        assert!(!keeps(Complex64::new(0.0, -5.0)));
        assert!(!keeps(Complex64::new(1.0, 0.0)));
        assert!(!keeps(Complex64::new(f64::NAN, 0.0)));
    }

    /// Weight is the popcount of `x | z`; terms with weight `<= cutoff` survive.
    #[test]
    fn weight_cutoff_keeps_below_or_equal() {
        let cut = WeightCutoff(2);
        assert!(keeps1(&cut, 0, 0, 1.0));
        assert!(keeps1(&cut, 1, 0, 1.0));
        // x | z = 0b11 -> weight 2, exactly at the cutoff.
        assert!(keeps1(&cut, 0b01, 0b10, 1.0));
        // x | z = 0b111 -> weight 3, over the cutoff.
        assert!(!keeps1(&cut, 0b011, 0b110, 1.0));
    }

    /// A zero cutoff keeps only the all-identity term (`x == z == 0`).
    #[test]
    fn weight_cutoff_zero_keeps_only_identity() {
        let cut = WeightCutoff(0);
        assert!(keeps1(&cut, 0, 0, 1.0));
        assert!(!keeps1(&cut, 1, 0, 1.0));
        assert!(!keeps1(&cut, 0, 1, 1.0));
        assert!(!keeps1(&cut, 1, 1, 1.0));
    }

    /// Weight is accumulated across every word, not just the first.
    #[test]
    fn weight_cutoff_w2_word_boundary() {
        let cut = WeightCutoff(1);
        let one = Complex64::new(1.0, 0.0);
        assert!(<WeightCutoff as TruncationPolicy<2>>::keep_term(
            &cut,
            &[0u64, 1u64],
            &[0u64, 0u64],
            one
        ));
        assert!(!<WeightCutoff as TruncationPolicy<2>>::keep_term(
            &cut,
            &[1u64, 1u64],
            &[0u64, 0u64],
            one
        ));
    }

    /// Keeps the three largest magnitudes and leaves them in term order.
    #[test]
    fn top_n_keeps_largest_three_of_ten() {
        // The term with x = i carries coefficient 11 - i. The three largest
        // magnitudes therefore sit on the first three terms.
        let mut sum = PauliSum::<1> {
            x: (1u64..=10).map(|i| [i]).collect(),
            z: vec![[0u64]; 10],
            coeff: (1u64..=10)
                .rev()
                .map(|m| Complex64::new(m as f64, 0.0))
                .collect(),
            num_qubits: 4,
        };
        sum.assert_invariants();
        TopN(3).finalize_layer(&mut sum);
        assert_eq!(sum.len(), 3);
        assert_eq!(sum.x(), &[[1u64], [2u64], [3u64]]);
        let mags: Vec<f64> = sum.coeff().iter().map(|c| c.norm()).collect();
        assert_eq!(mags, vec![10.0, 9.0, 8.0]);
        sum.assert_invariants();
    }

    /// `TopN(n)` with `n >= len` leaves the sum untouched.
    #[test]
    fn top_n_no_op_when_n_ge_len() {
        let mut sum = PauliSum::<1> {
            x: vec![[0], [0], [1]],
            z: vec![[0], [1], [0]],
            coeff: vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(2.0, 0.0),
                Complex64::new(3.0, 0.0),
            ],
            num_qubits: 1,
        };
        let snapshot_x = sum.x().to_vec();
        let snapshot_z = sum.z().to_vec();
        let snapshot_c = sum.coeff().to_vec();
        TopN(5).finalize_layer(&mut sum);
        assert_eq!(sum.x(), snapshot_x.as_slice());
        assert_eq!(sum.z(), snapshot_z.as_slice());
        assert_eq!(sum.coeff(), snapshot_c.as_slice());
    }

    /// `TopN(0)` empties the sum.
    #[test]
    fn top_n_zero_empties_sum() {
        let mut sum = PauliSum::<1> {
            x: vec![[0], [1]],
            z: vec![[1], [0]],
            coeff: vec![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            num_qubits: 1,
        };
        TopN(0).finalize_layer(&mut sum);
        assert!(sum.is_empty());
        sum.assert_invariants();
    }

    /// The surviving terms stay in the original (sorted) order, even when the
    /// largest magnitudes are at the end.
    #[test]
    fn top_n_preserves_sort_order() {
        let mut sum = PauliSum::<1> {
            x: vec![[1], [2], [3], [4], [5]],
            z: vec![[0]; 5],
            coeff: vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(2.0, 0.0),
                Complex64::new(3.0, 0.0),
                Complex64::new(4.0, 0.0),
                Complex64::new(5.0, 0.0),
            ],
            num_qubits: 4,
        };
        sum.assert_invariants();
        TopN(3).finalize_layer(&mut sum);
        assert_eq!(sum.len(), 3);
        assert_eq!(sum.x(), &[[3u64], [4u64], [5u64]]);
        assert_eq!(
            sum.coeff(),
            &[
                Complex64::new(3.0, 0.0),
                Complex64::new(4.0, 0.0),
                Complex64::new(5.0, 0.0),
            ]
        );
        sum.assert_invariants();
    }

    /// `And` keeps a term only when both children keep it.
    #[test]
    fn and_requires_both_keep() {
        let policy = And(CoefficientThreshold(0.5), WeightCutoff(1));
        assert!(keeps1(&policy, 1, 0, 1.0));
        // Fails the coefficient threshold.
        assert!(!keeps1(&policy, 1, 0, 0.1));
        // Fails the weight cutoff (weight 2 > 1).
        assert!(!keeps1(&policy, 0b01, 0b10, 1.0));
    }

    /// `And` runs both children's layer hooks, so the tighter `TopN` cap
    /// applies regardless of which side it is on.
    #[test]
    fn and_finalize_applies_both_children() {
        let make = || PauliSum::<1> {
            x: vec![[1], [2], [3], [4], [5]],
            z: vec![[0]; 5],
            coeff: (1..=5).map(|m| Complex64::new(m as f64, 0.0)).collect(),
            num_qubits: 4,
        };

        let mut sum = make();
        sum.assert_invariants();
        And(TopN(4), TopN(2)).finalize_layer(&mut sum);
        assert_eq!(sum.x(), &[[4u64], [5u64]]);
        sum.assert_invariants();

        let mut sum = make();
        And(TopN(2), TopN(4)).finalize_layer(&mut sum);
        assert_eq!(sum.x(), &[[4u64], [5u64]]);
        sum.assert_invariants();
    }

    /// `Or` keeps a term when either child keeps it.
    #[test]
    fn or_keeps_if_either() {
        let policy = Or(CoefficientThreshold(0.5), WeightCutoff(0));
        // Identity passes the weight cutoff despite a small coefficient.
        assert!(keeps1(&policy, 0, 0, 0.1));
        // Large coefficient passes despite nonzero weight.
        assert!(keeps1(&policy, 1, 0, 1.0));
        // Fails both.
        assert!(!keeps1(&policy, 1, 0, 0.1));
    }
}