1use num_traits::{One, Pow, Zero};
4use std::collections::BTreeMap;
5use std::fmt::Display;
6use std::iter::{Product, Sum};
7use std::ops::{Add, AddAssign, Mul, Neg};
8
9use derivative::Derivative;
10
11use super::rig::*;
12
13pub trait CommAlg: CommRing + Module<Ring = Self::R> {
15 type R: CommRing;
17
18 fn from_scalar(r: Self::R) -> Self {
23 Self::one() * r
24 }
25}
26
/// A polynomial in variables `Var` with coefficients `Coef` and exponents
/// `Exp`, stored as a formal combination of monomials with coefficients.
///
/// `Default` is derived through `derivative` with an empty bound, so
/// `Polynomial::default()` exists without requiring `Var`, `Coef` or `Exp`
/// to implement `Default`; it wraps the default (empty) combination.
///
/// NOTE(review): `PartialEq`/`Eq` compare the underlying combinations as
/// stored, so two polynomials that differ only in zero-coefficient terms
/// presumably compare unequal until `normalize` is applied — confirm against
/// `Combination`'s equality.
#[derive(Clone, PartialEq, Eq, Derivative)]
#[derivative(Default(bound = ""))]
pub struct Polynomial<Var, Coef, Exp>(Combination<Monomial<Var, Exp>, Coef>);
41
42impl<Var, Coef, Exp> Polynomial<Var, Coef, Exp>
43where
44 Var: Ord,
45 Exp: Ord,
46{
47 pub fn generator(var: Var) -> Self
49 where
50 Coef: One,
51 Exp: One,
52 {
53 Polynomial::from_monomial(Monomial::generator(var))
54 }
55
56 pub fn from_monomial(m: Monomial<Var, Exp>) -> Self
58 where
59 Coef: One,
60 {
61 Polynomial(Combination::generator(m))
62 }
63
64 pub fn monomials(&self) -> impl ExactSizeIterator<Item = &Monomial<Var, Exp>> {
66 self.0.variables()
67 }
68
69 pub fn extend_scalars<NewCoef, F>(self, f: F) -> Polynomial<Var, NewCoef, Exp>
75 where
76 F: FnMut(Coef) -> NewCoef,
77 {
78 Polynomial(self.0.extend_scalars(f))
79 }
80
81 pub fn eval<A, F>(&self, f: F) -> A
83 where
84 A: Clone + Mul<Coef, Output = A> + Pow<Exp, Output = A> + Sum + Product,
85 F: Clone + FnMut(&Var) -> A,
86 Coef: Clone,
87 Exp: Clone,
88 {
89 self.0.eval_with_order(self.monomials().map(|m| m.eval(f.clone())))
90 }
91
92 pub fn eval_pairs<A>(&self, pairs: impl IntoIterator<Item = (Var, A)>) -> A
97 where
98 A: Clone + Mul<Coef, Output = A> + Pow<Exp, Output = A> + Sum + Product,
99 Coef: Clone,
100 Exp: Clone,
101 {
102 let map: BTreeMap<Var, A> = pairs.into_iter().collect();
103 self.eval(|var| map.get(var).cloned().unwrap())
104 }
105
106 pub fn map_variables<NewVar, F>(&self, mut f: F) -> Polynomial<NewVar, Coef, Exp>
114 where
115 Coef: Clone + Add<Output = Coef>,
116 Exp: Clone + Add<Output = Exp>,
117 NewVar: Clone + Ord,
118 F: FnMut(&Var) -> NewVar,
119 {
120 (&self.0)
121 .into_iter()
122 .map(|(coef, m)| (coef.clone(), m.map_variables(|var| f(var))))
123 .collect()
124 }
125
126 pub fn normalize(self) -> Self
132 where
133 Coef: Zero,
134 Exp: Zero,
135 {
136 self.0
137 .into_iter()
138 .filter_map(|(coef, m)| {
139 if coef.is_zero() {
140 None
141 } else {
142 Some((coef, m.normalize()))
143 }
144 })
145 .collect()
146 }
147}
148
149impl<Var, Coef, Exp> FromIterator<(Coef, Monomial<Var, Exp>)> for Polynomial<Var, Coef, Exp>
150where
151 Var: Ord,
152 Coef: Add<Output = Coef>,
153 Exp: Ord,
154{
155 fn from_iter<T: IntoIterator<Item = (Coef, Monomial<Var, Exp>)>>(iter: T) -> Self {
156 Polynomial(iter.into_iter().collect())
157 }
158}
159
160impl<Var, Coef, Exp> Display for Polynomial<Var, Coef, Exp>
161where
162 Var: Display,
163 Coef: Display + PartialEq + One,
164 Exp: Display + PartialEq + One,
165{
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 self.0.fmt(f)
168 }
169}
170
171impl<Var, Coef, Exp> AddAssign<(Coef, Monomial<Var, Exp>)> for Polynomial<Var, Coef, Exp>
175where
176 Var: Ord,
177 Coef: Add<Output = Coef>,
178 Exp: Ord,
179{
180 fn add_assign(&mut self, rhs: (Coef, Monomial<Var, Exp>)) {
181 self.0 += rhs;
182 }
183}
184
185impl<Var, Coef, Exp> Add for Polynomial<Var, Coef, Exp>
186where
187 Var: Ord,
188 Coef: Add<Output = Coef>,
189 Exp: Ord,
190{
191 type Output = Self;
192
193 fn add(self, rhs: Self) -> Self::Output {
194 Polynomial(self.0 + rhs.0)
195 }
196}
197
198impl<Var, Coef, Exp> Zero for Polynomial<Var, Coef, Exp>
199where
200 Var: Ord,
201 Coef: Add<Output = Coef> + Zero,
202 Exp: Ord,
203{
204 fn zero() -> Self {
205 Polynomial(Combination::default())
206 }
207
208 fn is_zero(&self) -> bool {
209 self.0.is_zero()
210 }
211}
212
213impl<Var, Coef, Exp> AdditiveMonoid for Polynomial<Var, Coef, Exp>
214where
215 Var: Ord,
216 Coef: AdditiveMonoid,
217 Exp: Ord,
218{
219}
220
221impl<Var, Coef, Exp> Mul<Coef> for Polynomial<Var, Coef, Exp>
222where
223 Var: Ord,
224 Coef: Clone + Default + Mul<Output = Coef>,
225 Exp: Ord,
226{
227 type Output = Self;
228
229 fn mul(self, a: Coef) -> Self::Output {
230 Polynomial(self.0 * a)
231 }
232}
233
234impl<Var, Coef, Exp> RigModule for Polynomial<Var, Coef, Exp>
235where
236 Var: Ord,
237 Coef: Clone + Default + CommRig,
238 Exp: Ord,
239{
240 type Rig = Coef;
241}
242
243impl<Var, Coef, Exp> Neg for Polynomial<Var, Coef, Exp>
244where
245 Var: Ord,
246 Coef: Default + Neg<Output = Coef>,
247 Exp: Ord,
248{
249 type Output = Self;
250
251 fn neg(self) -> Self::Output {
252 Polynomial(self.0.neg())
253 }
254}
255
256impl<Var, Coef, Exp> AbGroup for Polynomial<Var, Coef, Exp>
257where
258 Var: Ord,
259 Coef: Default + AbGroup,
260 Exp: Ord,
261{
262}
263
264impl<Var, Coef, Exp> Module for Polynomial<Var, Coef, Exp>
265where
266 Var: Ord,
267 Coef: Clone + Default + CommRing,
268 Exp: Ord,
269{
270 type Ring = Coef;
271}
272
/// Polynomial multiplication: every term of `self` is multiplied by every
/// term of `rhs` (coefficient times coefficient, monomial times monomial),
/// and the products are accumulated into the result with `+=`.
impl<Var, Coef, Exp> Mul for Polynomial<Var, Coef, Exp>
where
    Var: Clone + Ord,
    Coef: Clone + Add<Output = Coef> + Mul<Output = Coef>,
    Exp: Clone + Ord + Add<Output = Exp>,
{
    type Output = Self;

    fn mul(self, rhs: Self) -> Self::Output {
        let mut result = Polynomial::default();
        // Both operands are taken by value. To minimize cloning, every product
        // is formed from clones except where a value is used for the last time:
        // the final term of `self` consumes `rhs`'s terms, and within that,
        // the final term of `rhs` consumes `a` and `m` themselves.
        let (outer, inner) = (self.0, rhs.0);
        let mut outer_iter = outer.into_iter();
        while let Some((a, m)) = outer_iter.next() {
            // `len()` (ExactSizeIterator) == 0 means `(a, m)` is the last outer term.
            if outer_iter.len() == 0 {
                let mut inner_iter = inner.into_iter();
                while let Some((b, n)) = inner_iter.next() {
                    if inner_iter.len() == 0 {
                        // Last pair overall: move `a` and `m` instead of cloning.
                        result += (a * b, m * n);
                        break;
                    } else {
                        result += (a.clone() * b, m.clone() * n);
                    }
                }
                // `inner` has been moved out above; leave the outer loop explicitly.
                break;
            } else {
                // Not the last outer term: borrow `inner` and clone every factor.
                for (b, n) in &inner {
                    result += (a.clone() * b.clone(), m.clone() * n.clone());
                }
            }
        }
        result
    }
}
309
310impl<Var, Coef, Exp> One for Polynomial<Var, Coef, Exp>
311where
312 Var: Clone + Ord,
313 Coef: Clone + Add<Output = Coef> + One,
314 Exp: Clone + Ord + Add<Output = Exp>,
315{
316 fn one() -> Self {
317 Polynomial::from_monomial(Default::default())
318 }
319}
320
321impl<Var, Coef, Exp> Monoid for Polynomial<Var, Coef, Exp>
322where
323 Var: Clone + Ord,
324 Coef: Clone + Rig,
325 Exp: Clone + Ord + AdditiveMonoid,
326{
327}
328
329impl<Var, Coef, Exp> Rig for Polynomial<Var, Coef, Exp>
330where
331 Var: Clone + Ord,
332 Coef: Clone + Rig,
333 Exp: Clone + Ord + AdditiveMonoid,
334{
335}
336
337impl<Var, Coef, Exp> Ring for Polynomial<Var, Coef, Exp>
338where
339 Var: Clone + Ord,
340 Coef: Clone + Default + Ring,
341 Exp: Clone + Ord + AdditiveMonoid,
342{
343}
344
345impl<Var, Coef, Exp> CommMonoid for Polynomial<Var, Coef, Exp>
346where
347 Var: Clone + Ord,
348 Coef: Clone + CommRig,
349 Exp: Clone + Ord + AdditiveMonoid,
350{
351}
352
353impl<Var, Coef, Exp> CommRig for Polynomial<Var, Coef, Exp>
354where
355 Var: Clone + Ord,
356 Coef: Clone + CommRig,
357 Exp: Clone + Ord + AdditiveMonoid,
358{
359}
360
361impl<Var, Coef, Exp> CommRing for Polynomial<Var, Coef, Exp>
362where
363 Var: Clone + Ord,
364 Coef: Clone + Default + CommRing,
365 Exp: Clone + Ord + AdditiveMonoid,
366{
367}
368
369impl<Var, Coef, Exp> CommAlg for Polynomial<Var, Coef, Exp>
370where
371 Var: Clone + Ord,
372 Coef: Clone + Default + CommRing,
373 Exp: Clone + Ord + AdditiveMonoid,
374{
375 type R = Coef;
376
377 fn from_scalar(r: Self::R) -> Self {
378 [(r, Monomial::one())].into_iter().collect()
379 }
380}
381
/// Unit tests for `Polynomial`: construction, display, variable renaming,
/// evaluation and normalization.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn polynomials() {
        // Fresh generators on each call, since the arithmetic operators consume their operands.
        let x = || Polynomial::<_, i32, u32>::generator('x');
        let y = || Polynomial::<_, i32, u32>::generator('y');
        assert_eq!(x().to_string(), "x");

        // A constant polynomial evaluates to itself with no variable bindings.
        let p = Polynomial::<char, i32, u32>::from_scalar(-5);
        assert_eq!(p.eval_pairs::<i32>([]), -5);

        // p = 2 x^2 y + 3 x y^2; variable order in products does not matter.
        let p = x() * y() * x() * 2 + y() * x() * y() * 3;
        assert_eq!(p.to_string(), "3 x y^2 + 2 x^2 y");
        // Renaming both variables to `x` merges exponents and then coefficients.
        assert_eq!(p.map_variables(|_| 'x').to_string(), "5 x^3");
        // 2*1*1 + 3*1*1 = 5; 2*1*2 + 3*1*4 = 16; 2*4*1 + 3*2*1 = 14.
        assert_eq!(p.eval_pairs([('x', 1), ('y', 1)]), 5);
        assert_eq!(p.eval_pairs([('x', 1), ('y', 2)]), 16);
        assert_eq!(p.eval_pairs([('y', 1), ('x', 2)]), 14);

        let p = (x() + y()) * (x() + y());
        assert_eq!(p.to_string(), "2 x y + x^2 + y^2");

        // The cross terms cancel; `normalize` is expected to drop the resulting
        // zero-coefficient `x y` term from the output.
        let p = (x() + y()) * (x() + y().neg());
        assert_eq!(p.normalize().to_string(), "x^2 + (-1) y^2");
    }
}