paulistrings/engine/
mod.rs

//! The propagation engine: [`propagate`] is the front door, and the
//! per-layer pipeline lives in [`sort_merge`].
//!
//! See design doc §8 (loop) and §5 (sort-merge pipeline).
5
6pub mod sort_merge;
7
8use crate::channel::Channel;
9use crate::circuit::Circuit;
10use crate::pauli_sum::PauliSum;
11use crate::truncation::TruncationPolicy;
12
/// Order in which [`propagate`] walks the channels of a circuit.
///
/// [`Direction::Forward`] visits channels first-to-last.
/// [`Direction::Heisenberg`] visits them last-to-first and applies adjoints,
/// which is the mode used for backpropagating observables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    /// Visit channels in the same order they were pushed onto the
    /// [`Circuit`].
    Forward,
    /// Visit channels in reverse push order, using each channel's
    /// [`Channel::apply_adjoint`].
    Heisenberg,
}
26
27/// Propagate `initial` through `circuit` under `policy`.
28///
29/// Iterates the circuit's channels — in order for [`Direction::Forward`], in
30/// reverse for [`Direction::Heisenberg`], calling [`Channel::apply_adjoint`]
31/// in the latter case (default = self-adjoint; overridden on
32/// [`PauliRotation`](crate::channel::PauliRotation) and
33/// [`Clifford1Q`](crate::channel::Clifford1Q)).
34///
35/// # Examples
36///
37/// ```
38/// use paulistrings::{
39///     BuildAccumulator, Circuit, Direction, PauliString, Phase, TruncationPolicy,
40///     channel::Clifford1Q, propagate,
41/// };
42/// use num_complex::Complex64;
43///
44/// let mut acc = BuildAccumulator::<1>::new(1);
45/// acc.add_term(PauliString::<1>::z(0), Phase::ONE, Complex64::new(1.0, 0.0));
46/// let observable = acc.finalize();
47///
48/// let mut circuit = Circuit::<1>::new(1);
49/// circuit.push(Clifford1Q::h(0));
50///
51/// struct KeepAll;
52/// impl<const W: usize> TruncationPolicy<W> for KeepAll {}
53///
54/// // H conjugates Z to X, so propagating Z₀ through H gives X₀.
55/// let evolved = propagate(&circuit, observable, &KeepAll, Direction::Heisenberg);
56/// assert_eq!(evolved.len(), 1);
57/// assert_eq!(evolved.x()[0], [1]);
58/// assert_eq!(evolved.z()[0], [0]);
59/// ```
60///
61/// See design doc §8.1.
62pub fn propagate<const W: usize, T>(
63    circuit: &Circuit<W>,
64    initial: PauliSum<W>,
65    policy: &T,
66    direction: Direction,
67) -> PauliSum<W>
68where
69    T: TruncationPolicy<W>,
70{
71    let mut sum = initial;
72    let n = circuit.channels.len();
73    for k in 0..n {
74        let idx = match direction {
75            Direction::Forward => k,
76            Direction::Heisenberg => n - 1 - k,
77        };
78        let ch: &dyn Channel<W> = circuit.channels[idx].as_ref();
79        sum = match direction {
80            Direction::Forward => sort_merge::apply_layer(&sum, ch, policy),
81            Direction::Heisenberg => sort_merge::apply_layer_adjoint(&sum, ch, policy),
82        };
83        policy.finalize_layer(&mut sum);
84    }
85    sum
86}