1use itertools::Itertools;
4use num_traits::{One, Pow, Zero, one, zero};
5use std::collections::BTreeMap;
6use std::fmt::Display;
7use std::iter::{Product, Sum};
8use std::ops::{Add, AddAssign, Mul, Neg};
9
10use derivative::Derivative;
11
12use super::rig::*;
13
14pub trait CommAlg: CommRing + Module<Ring = Self::R> {
16 type R: CommRing;
18
19 fn from_scalar(r: Self::R) -> Self {
24 Self::one() * r
25 }
26}
27
/// A polynomial in variables of type `Var`, with coefficients of type `Coef` and
/// exponents of type `Exp`.
///
/// The representation is a [`Combination`] of [`Monomial`]s, i.e. each monomial is
/// paired with its coefficient.
///
/// `Default` is derived via `derivative` with an empty bound (`bound = ""`), so
/// `Polynomial::default()` exists without requiring `Default` on the type parameters.
///
/// NOTE(review): the derived `PartialEq`/`Eq` delegate to `Combination`'s equality;
/// whether terms with zero coefficient or un-normalized monomials affect equality
/// depends on `Combination`/`Monomial` — consider calling `normalize` before comparing.
#[derive(Clone, PartialEq, Eq, Derivative, Debug)]
#[derivative(Default(bound = ""))]
pub struct Polynomial<Var, Coef, Exp>(Combination<Monomial<Var, Exp>, Coef>);
42
43impl<Var, Coef, Exp> Polynomial<Var, Coef, Exp>
44where
45 Var: Ord,
46 Exp: Ord,
47{
48 pub fn generator(var: Var) -> Self
50 where
51 Coef: One,
52 Exp: One,
53 {
54 Polynomial::from_monomial(Monomial::generator(var))
55 }
56
57 pub fn from_monomial(m: Monomial<Var, Exp>) -> Self
59 where
60 Coef: One,
61 {
62 Polynomial(Combination::generator(m))
63 }
64
65 pub fn monomials(&self) -> impl ExactSizeIterator<Item = &Monomial<Var, Exp>> {
67 self.0.variables()
68 }
69
70 pub fn extend_scalars<NewCoef, F>(self, f: F) -> Polynomial<Var, NewCoef, Exp>
76 where
77 F: FnMut(Coef) -> NewCoef,
78 {
79 Polynomial(self.0.extend_scalars(f))
80 }
81
82 pub fn eval<A, F>(&self, f: F) -> A
84 where
85 A: Clone + Mul<Coef, Output = A> + Pow<Exp, Output = A> + Sum + Product,
86 F: Clone + FnMut(&Var) -> A,
87 Coef: Clone,
88 Exp: Clone,
89 {
90 self.0.eval_with_order(self.monomials().map(|m| m.eval(f.clone())))
91 }
92
93 pub fn eval_pairs<A>(&self, pairs: impl IntoIterator<Item = (Var, A)>) -> A
98 where
99 A: Clone + Mul<Coef, Output = A> + Pow<Exp, Output = A> + Sum + Product,
100 Coef: Clone,
101 Exp: Clone,
102 {
103 let map: BTreeMap<Var, A> = pairs.into_iter().collect();
104 self.eval(|var| map.get(var).cloned().unwrap())
105 }
106
107 pub fn map_variables<NewVar, F>(&self, mut f: F) -> Polynomial<NewVar, Coef, Exp>
115 where
116 Coef: Clone + Add<Output = Coef>,
117 Exp: Clone + Add<Output = Exp>,
118 NewVar: Clone + Ord,
119 F: FnMut(&Var) -> NewVar,
120 {
121 (&self.0)
122 .into_iter()
123 .map(|(coef, m)| (coef.clone(), m.map_variables(|var| f(var))))
124 .collect()
125 }
126
127 pub fn normalize(self) -> Self
133 where
134 Coef: Zero,
135 Exp: Zero,
136 {
137 self.0
138 .into_iter()
139 .filter_map(|(coef, m)| {
140 if coef.is_zero() {
141 None
142 } else {
143 Some((coef, m.normalize()))
144 }
145 })
146 .collect()
147 }
148}
149
150impl<Var, Coef, Exp> Polynomial<Var, Coef, Exp>
151where
152 Var: Display,
153 Coef: Display + DisplayCoef + Clone + PartialEq + One + Neg<Output = Coef>,
154 Exp: Display + PartialEq + One,
155{
156 pub fn to_latex(&self) -> String {
158 let fmt_term = |coef: &Coef, monomial: &Monomial<Var, Exp>| -> String {
159 let monomial = monomial.to_latex();
160 if coef.is_one() {
161 monomial
162 } else if *coef == Coef::one().neg() {
163 format!("-{monomial}")
164 } else if coef.needs_parentheses() {
165 format!("({coef}) \\cdot {monomial}")
166 } else {
167 format!("{coef} \\cdot {monomial}")
168 }
169 };
170
171 let mut terms = (&self.0).into_iter();
172 let Some((coef, monomial)) = terms.next() else {
173 return "0".to_string();
174 };
175 let mut output = fmt_term(coef, monomial);
176 for (coef, monomial) in terms {
177 if coef.has_negative_sign() {
178 output.push_str(" - ");
179 output.push_str(&fmt_term(&coef.clone().neg(), monomial));
180 } else {
181 output.push_str(" + ");
182 output.push_str(&fmt_term(coef, monomial));
183 }
184 }
185 output
186 }
187}
188
189impl<Var, Coef, Exp> FromIterator<(Coef, Monomial<Var, Exp>)> for Polynomial<Var, Coef, Exp>
190where
191 Var: Ord,
192 Coef: Add<Output = Coef>,
193 Exp: Ord,
194{
195 fn from_iter<T: IntoIterator<Item = (Coef, Monomial<Var, Exp>)>>(iter: T) -> Self {
196 Polynomial(iter.into_iter().collect())
197 }
198}
199
200impl<Var, Coef, Exp> Display for Polynomial<Var, Coef, Exp>
201where
202 Combination<Monomial<Var, Exp>, Coef>: Display,
203{
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 write!(f, "{}", self.0)
206 }
207}
208
209impl<Var, Coef, Exp> DisplayCoef for Polynomial<Var, Coef, Exp>
210where
211 Var: Ord,
212 Coef: DisplayCoef,
213 Exp: Ord,
214{
215 fn has_negative_sign(&self) -> bool {
216 if let Some(((coef, _),)) = (&self.0).into_iter().collect_tuple() {
217 coef.has_negative_sign()
218 } else {
219 false
220 }
221 }
222
223 fn needs_parentheses(&self) -> bool {
224 self.0.len() != 1
225 }
226}
227
228impl<Var, Coef, Exp> Sum for Polynomial<Var, Coef, Exp>
232where
233 Var: Ord,
234 Coef: AdditiveMonoid,
235 Exp: Ord,
236{
237 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
238 iter.fold(zero(), |acc, x| acc + x)
239 }
240}
241
242impl<Var, Coef, Exp> AddAssign<(Coef, Monomial<Var, Exp>)> for Polynomial<Var, Coef, Exp>
243where
244 Var: Ord,
245 Coef: Add<Output = Coef>,
246 Exp: Ord,
247{
248 fn add_assign(&mut self, rhs: (Coef, Monomial<Var, Exp>)) {
249 self.0 += rhs;
250 }
251}
252
253impl<Var, Coef, Exp> Add for Polynomial<Var, Coef, Exp>
254where
255 Var: Ord,
256 Coef: Add<Output = Coef>,
257 Exp: Ord,
258{
259 type Output = Self;
260
261 fn add(self, rhs: Self) -> Self::Output {
262 Polynomial(self.0 + rhs.0)
263 }
264}
265
266impl<Var, Coef, Exp> Add<Coef> for Polynomial<Var, Coef, Exp>
267where
268 Var: Ord,
269 Coef: Add<Output = Coef>,
270 Exp: Ord + Zero,
271{
272 type Output = Polynomial<Var, Coef, Exp>;
273 fn add(mut self, a: Coef) -> Self::Output {
274 self += (a, one());
275 self
276 }
277}
278
279impl<Var, Coef, Exp> Zero for Polynomial<Var, Coef, Exp>
280where
281 Var: Ord,
282 Coef: Zero,
283 Exp: Ord,
284{
285 fn zero() -> Self {
286 Polynomial(Combination::default())
287 }
288
289 fn is_zero(&self) -> bool {
290 self.0.is_zero()
291 }
292}
293
/// Polynomials form an additive monoid whenever their coefficients do.
///
/// The impl body is empty: it only declares the structure provided by the `Add` and
/// `Zero` impls of `Polynomial`.
impl<Var, Coef, Exp> AdditiveMonoid for Polynomial<Var, Coef, Exp>
where
    Var: Ord,
    Coef: AdditiveMonoid,
    Exp: Ord,
{
}
301
302impl<Var, Coef, Exp> Mul<Coef> for Polynomial<Var, Coef, Exp>
303where
304 Var: Ord,
305 Coef: Clone + Default + Mul<Output = Coef>,
306 Exp: Ord,
307{
308 type Output = Self;
309
310 fn mul(self, a: Coef) -> Self::Output {
311 Polynomial(self.0 * a)
312 }
313}
314
/// Polynomials are a module over their coefficient rig, with scalar multiplication
/// given by the `Mul<Coef>` impl.
impl<Var, Coef, Exp> RigModule for Polynomial<Var, Coef, Exp>
where
    Var: Ord,
    Coef: Clone + Default + CommRig,
    Exp: Ord,
{
    // The scalars acting on the polynomial are its coefficients.
    type Rig = Coef;
}
323
324impl<Var, Coef, Exp> Neg for Polynomial<Var, Coef, Exp>
325where
326 Var: Ord,
327 Coef: Default + Neg<Output = Coef>,
328 Exp: Ord,
329{
330 type Output = Self;
331
332 fn neg(self) -> Self::Output {
333 Polynomial(self.0.neg())
334 }
335}
336
/// Polynomials form an abelian group whenever their coefficients do.
///
/// The body is empty; the group operations come from the `Add`, `Zero` and `Neg` impls.
/// The `Default` bound on `Coef` mirrors the bound required by the `Neg` impl.
impl<Var, Coef, Exp> AbGroup for Polynomial<Var, Coef, Exp>
where
    Var: Ord,
    Coef: Default + AbGroup,
    Exp: Ord,
{
}
344
/// Polynomials are a module over their coefficient ring (scalar multiplication is the
/// `Mul<Coef>` impl).
impl<Var, Coef, Exp> Module for Polynomial<Var, Coef, Exp>
where
    Var: Ord,
    Coef: Clone + Default + CommRing,
    Exp: Ord,
{
    // The scalar ring is the coefficient type.
    type Ring = Coef;
}
353
/// Multiplies two polynomials by distributing every term of `self` over every term of
/// `rhs`.
///
/// Partial products are accumulated with `+=` in order: outer (`self`) terms first, and
/// inner (`rhs`) terms within each. Combining of coinciding monomials is left to that
/// `+=`.
impl<Var, Coef, Exp> Mul for Polynomial<Var, Coef, Exp>
where
    Var: Clone + Ord,
    Coef: Clone + Add<Output = Coef> + Mul<Output = Coef>,
    Exp: Clone + Ord + Add<Output = Exp>,
{
    type Output = Self;

    fn mul(self, rhs: Self) -> Self::Output {
        let mut result = Polynomial::default();
        // Take both operands by value so the final pass can move terms instead of
        // cloning them.
        let (outer, inner) = (self.0, rhs.0);
        let mut outer_iter = outer.into_iter();
        while let Some((a, m)) = outer_iter.next() {
            if outer_iter.len() == 0 {
                // Last outer term: `inner` is not needed afterwards, so consume it by
                // value. `a` and `m` are cloned for every inner term except the last,
                // where they are moved.
                let mut inner_iter = inner.into_iter();
                while let Some((b, n)) = inner_iter.next() {
                    if inner_iter.len() == 0 {
                        result += (a * b, m * n);
                        // `a` and `m` have been moved; leaving the loop is what makes
                        // that move legal.
                        break;
                    } else {
                        result += (a.clone() * b, m.clone() * n);
                    }
                }
                // `inner` was moved above; this `break` lets the borrow checker see it
                // is moved at most once.
                break;
            } else {
                // Not the last outer term: `inner` is still needed, so borrow it and
                // clone both factors of every product.
                for (b, n) in &inner {
                    result += (a.clone() * b.clone(), m.clone() * n.clone());
                }
            }
        }
        result
    }
}
390
391impl<Var, Coef, Exp> One for Polynomial<Var, Coef, Exp>
392where
393 Var: Clone + Ord,
394 Coef: Clone + Add<Output = Coef> + One,
395 Exp: Clone + Ord + Add<Output = Exp>,
396{
397 fn one() -> Self {
398 Polynomial::from_monomial(Default::default())
399 }
400}
401
// Algebraic structure markers for `Polynomial`.
//
// Each impl below has an empty body. The operations come from the `Add`, `Zero`, `Neg`,
// `Mul` and `One` impls of `Polynomial`, and these impls only declare which laws the
// result satisfies given the corresponding structure on `Coef`. The extra `Default`
// bound on the ring variants mirrors the bound required by the `Neg` impl.

/// Polynomials form a multiplicative monoid when the coefficients form a rig.
impl<Var, Coef, Exp> Monoid for Polynomial<Var, Coef, Exp>
where
    Var: Clone + Ord,
    Coef: Clone + Rig,
    Exp: Clone + Ord + AdditiveMonoid,
{
}

/// Polynomials over a rig form a rig.
impl<Var, Coef, Exp> Rig for Polynomial<Var, Coef, Exp>
where
    Var: Clone + Ord,
    Coef: Clone + Rig,
    Exp: Clone + Ord + AdditiveMonoid,
{
}

/// Polynomials over a ring form a ring.
impl<Var, Coef, Exp> Ring for Polynomial<Var, Coef, Exp>
where
    Var: Clone + Ord,
    Coef: Clone + Default + Ring,
    Exp: Clone + Ord + AdditiveMonoid,
{
}

/// Multiplication of polynomials is declared commutative when the coefficients form a
/// commutative rig.
impl<Var, Coef, Exp> CommMonoid for Polynomial<Var, Coef, Exp>
where
    Var: Clone + Ord,
    Coef: Clone + CommRig,
    Exp: Clone + Ord + AdditiveMonoid,
{
}

/// Polynomials over a commutative rig form a commutative rig.
impl<Var, Coef, Exp> CommRig for Polynomial<Var, Coef, Exp>
where
    Var: Clone + Ord,
    Coef: Clone + CommRig,
    Exp: Clone + Ord + AdditiveMonoid,
{
}

/// Polynomials over a commutative ring form a commutative ring.
impl<Var, Coef, Exp> CommRing for Polynomial<Var, Coef, Exp>
where
    Var: Clone + Ord,
    Coef: Clone + Default + CommRing,
    Exp: Clone + Ord + AdditiveMonoid,
{
}
449
450impl<Var, Coef, Exp> CommAlg for Polynomial<Var, Coef, Exp>
451where
452 Var: Clone + Ord,
453 Coef: Clone + Default + CommRing,
454 Exp: Clone + Ord + AdditiveMonoid,
455{
456 type R = Coef;
457
458 fn from_scalar(r: Self::R) -> Self {
459 [(r, Monomial::one())].into_iter().collect()
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer polynomials in `char` variables with unsigned exponents.
    type P = Polynomial<char, i32, u32>;

    #[test]
    fn polynomials() {
        let x = || P::generator('x');
        let y = || P::generator('y');
        assert_eq!(x().to_string(), "x");

        // A constant evaluates to itself without any variable assignments.
        let constant = P::from_scalar(-5);
        assert_eq!(constant.eval_pairs::<i32>([]), -5);

        // Products written in mixed variable order: 2 x^2 y + 3 x y^2.
        let cubic = x() * y() * x() * 2 + y() * x() * y() * 3;
        assert_eq!(cubic.to_string(), "3 x y^2 + 2 x^2 y");
        assert_eq!(cubic.map_variables(|_| 'x').to_string(), "5 x^3");
        let cases = [
            ([('x', 1), ('y', 1)], 5),
            ([('x', 1), ('y', 2)], 16),
            ([('y', 1), ('x', 2)], 14),
        ];
        for (pairs, expected) in cases {
            assert_eq!(cubic.eval_pairs(pairs), expected);
        }

        let square = (x() + y()) * (x() + y());
        assert_eq!(square.to_string(), "2 x y + x^2 + y^2");

        let difference_of_squares = (x() + y()) * (x() + y().neg());
        assert_eq!(difference_of_squares.normalize().to_string(), "x^2 - y^2");
    }
}