1use num_complex::Complex64;
31use std::ops::{Add, AddAssign};
32
/// A unit phase that is an integer power of the imaginary unit, `i^k`.
///
/// Only the exponent `k` is stored. The intended invariant is `k` in `0..=3`.
/// `#[repr(transparent)]` gives the type exactly the layout of a single `u8`.
/// The derived [`Default`] stores exponent `0`, i.e. the phase `1`.
#[repr(transparent)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Phase(u8);
41
42impl Phase {
43 pub const ONE: Self = Self(0);
45 pub const I: Self = Self(1);
47 pub const MINUS_ONE: Self = Self(2);
49 pub const MINUS_I: Self = Self(3);
51
52 #[inline]
55 pub const fn new(k: u8) -> Self {
56 Self(k & 3)
57 }
58
59 #[inline]
61 pub const fn exponent(self) -> u8 {
62 self.0
63 }
64
65 #[inline]
67 pub fn to_complex(self) -> Complex64 {
68 match self.0 {
69 0 => Complex64::new(1.0, 0.0),
70 1 => Complex64::new(0.0, 1.0),
71 2 => Complex64::new(-1.0, 0.0),
72 3 => Complex64::new(0.0, -1.0),
73 _ => unreachable!(),
74 }
75 }
76
77 #[inline]
80 pub fn apply(self, c: Complex64) -> Complex64 {
81 match self.0 {
82 0 => c,
83 1 => Complex64::new(-c.im, c.re),
84 2 => Complex64::new(-c.re, -c.im),
85 3 => Complex64::new(c.im, -c.re),
86 _ => unreachable!(),
87 }
88 }
89}
90
91impl Add for Phase {
92 type Output = Phase;
93 #[inline]
94 fn add(self, other: Phase) -> Phase {
95 Phase((self.0 + other.0) & 3)
98 }
99}
100
101impl AddAssign for Phase {
102 #[inline]
103 fn add_assign(&mut self, other: Phase) {
104 *self = *self + other;
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Every phase, listed in exponent order.
    const ALL: [Phase; 4] = [Phase::ONE, Phase::I, Phase::MINUS_ONE, Phase::MINUS_I];

    #[test]
    fn constants_have_expected_exponents() {
        assert_eq!(Phase::ONE.exponent(), 0);
        assert_eq!(Phase::I.exponent(), 1);
        assert_eq!(Phase::MINUS_ONE.exponent(), 2);
        assert_eq!(Phase::MINUS_I.exponent(), 3);
    }

    #[test]
    fn constants_are_pairwise_distinct() {
        for (i, a) in ALL.iter().enumerate() {
            for (j, b) in ALL.iter().enumerate() {
                assert_eq!(a == b, i == j);
            }
        }
    }

    #[test]
    fn default_is_one() {
        assert_eq!(Phase::default(), Phase::ONE);
    }

    #[test]
    fn new_reduces_mod_4() {
        assert_eq!(Phase::new(0), Phase::ONE);
        assert_eq!(Phase::new(1), Phase::I);
        assert_eq!(Phase::new(4), Phase::ONE);
        assert_eq!(Phase::new(5), Phase::I);
        assert_eq!(Phase::new(255), Phase::MINUS_I);
    }

    #[test]
    fn new_reduces_every_u8_mod_4() {
        for k in 0..=u8::MAX {
            assert_eq!(Phase::new(k).exponent(), k % 4, "k = {}", k);
        }
    }

    #[test]
    fn to_complex_matches_i_powers() {
        assert_eq!(Phase::ONE.to_complex(), Complex64::new(1.0, 0.0));
        assert_eq!(Phase::I.to_complex(), Complex64::new(0.0, 1.0));
        assert_eq!(Phase::MINUS_ONE.to_complex(), Complex64::new(-1.0, 0.0));
        assert_eq!(Phase::MINUS_I.to_complex(), Complex64::new(0.0, -1.0));
    }

    #[test]
    fn apply_agrees_with_to_complex_times_c() {
        // Small dyadic values keep the reference multiplication exact.
        let samples = [
            Complex64::new(2.0, 3.0),
            Complex64::new(-1.5, 0.25),
            Complex64::new(0.0, -4.0),
        ];
        for c in samples {
            for p in ALL {
                assert_eq!(p.apply(c), p.to_complex() * c);
            }
        }
    }

    #[test]
    fn apply_composes_like_add() {
        let c = Complex64::new(2.0, -3.0);
        for p in ALL {
            for q in ALL {
                assert_eq!(p.apply(q.apply(c)), (p + q).apply(c));
            }
        }
    }

    #[test]
    fn add_wraps_mod_4() {
        assert_eq!(Phase::I + Phase::I, Phase::MINUS_ONE);
        assert_eq!(Phase::MINUS_ONE + Phase::I, Phase::MINUS_I);
        assert_eq!(Phase::MINUS_I + Phase::I, Phase::ONE);
        assert_eq!(Phase::MINUS_I + Phase::MINUS_I, Phase::MINUS_ONE);
    }

    #[test]
    fn add_matches_exponent_arithmetic() {
        for a in ALL {
            assert_eq!(a + Phase::ONE, a);
            assert_eq!(Phase::ONE + a, a);
            for b in ALL {
                assert_eq!((a + b).exponent(), (a.exponent() + b.exponent()) % 4);
                assert_eq!(a + b, b + a);
            }
        }
    }

    #[test]
    fn add_assign_wraps_mod_4() {
        let mut p = Phase::I;
        p += Phase::I;
        assert_eq!(p, Phase::MINUS_ONE);
        p += Phase::MINUS_I;
        assert_eq!(p, Phase::I);
    }

    #[test]
    fn add_assign_matches_add() {
        for a in ALL {
            for b in ALL {
                let mut p = a;
                p += b;
                assert_eq!(p, a + b);
            }
        }
    }

    #[test]
    fn repr_transparent_size_is_one_byte() {
        use std::mem::size_of;
        assert_eq!(size_of::<Phase>(), 1);
    }
}