1use std::collections::BTreeMap;
4use std::fmt::Display;
5use std::ops::Add;
6
7use derivative::Derivative;
8use nalgebra::DVector;
9use num_traits::{One, Pow};
10
11#[cfg(test)]
12use super::ODEProblem;
13use super::ODESystem;
14use crate::zero::alg::Polynomial;
15
/// A system of polynomial ordinary differential equations.
///
/// Each entry of `components` maps a variable `var` to the polynomial right-hand side of
/// its equation; the `Display` impl renders an entry as `d{var} = {polynomial}`.
#[derive(Clone, Derivative)]
// `bound = ""` drops the `Var/Coef/Exp: Default` bounds a plain derive would add;
// an empty `BTreeMap` needs none of them.
#[derivative(Default(bound = ""))]
pub struct PolynomialSystem<Var, Coef, Exp> {
    /// Right-hand sides keyed by variable, iterated in sorted `Var` order.
    pub components: BTreeMap<Var, Polynomial<Var, Coef, Exp>>,
}
23
24impl<Var, Coef, Exp> PolynomialSystem<Var, Coef, Exp>
25where
26 Var: Ord,
27 Exp: Ord,
28{
29 pub fn new() -> Self {
31 Default::default()
32 }
33
34 pub fn add_term(&mut self, var: Var, term: Polynomial<Var, Coef, Exp>)
36 where
37 Coef: Add<Output = Coef>,
38 {
39 if let Some(component) = self.components.get_mut(&var) {
40 *component = std::mem::take(component) + term;
41 } else {
42 self.components.insert(var, term);
43 }
44 }
45
46 pub fn extend_scalars<NewCoef, F>(self, f: F) -> PolynomialSystem<Var, NewCoef, Exp>
48 where
49 F: Clone + FnMut(Coef) -> NewCoef,
50 {
51 let components = self
52 .components
53 .into_iter()
54 .map(|(var, poly)| (var, poly.extend_scalars(f.clone())))
55 .collect();
56 PolynomialSystem { components }
57 }
58}
59
60impl<Var, Exp> PolynomialSystem<Var, f32, Exp>
61where
62 Var: Clone + Ord,
63 Exp: Clone + Ord + Add<Output = Exp>,
64{
65 pub fn to_numerical(&self) -> NumericalPolynomialSystem<Exp> {
71 let indices: BTreeMap<Var, usize> =
72 self.components.keys().enumerate().map(|(i, var)| (var.clone(), i)).collect();
73 let components = self
74 .components
75 .values()
76 .map(|poly| poly.map_variables(|var| *indices.get(var).unwrap()))
77 .collect();
78 NumericalPolynomialSystem { components }
79 }
80}
81
82impl<Var, Coef, Exp> Display for PolynomialSystem<Var, Coef, Exp>
83where
84 Var: Display,
85 Coef: Display + PartialEq + One,
86 Exp: Display + PartialEq + One,
87{
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 for (var, component) in self.components.iter() {
90 writeln!(f, "d{} = {}", var, component)?;
91 }
92 Ok(())
93 }
94}
95
96impl<Var, Coef, Exp> FromIterator<(Var, Polynomial<Var, Coef, Exp>)>
97 for PolynomialSystem<Var, Coef, Exp>
98where
99 Var: Ord,
100 Coef: Add<Output = Coef>,
101 Exp: Ord,
102{
103 fn from_iter<T: IntoIterator<Item = (Var, Polynomial<Var, Coef, Exp>)>>(iter: T) -> Self {
104 let mut system: Self = Default::default();
105 for (var, term) in iter {
106 system.add_term(var, term);
107 }
108 system
109 }
110}
111
/// A polynomial ODE system whose variables are dense indices into the state vector.
///
/// `components[i]` is the right-hand side for state entry `i`, and a variable `j`
/// appearing in any polynomial stands for state entry `j`.
pub struct NumericalPolynomialSystem<Exp> {
    /// One polynomial per state entry, with `f32` coefficients and `usize` variables.
    pub components: Vec<Polynomial<usize, f32, Exp>>,
}
121
122impl<Exp> ODESystem for NumericalPolynomialSystem<Exp>
123where
124 Exp: Clone + Ord,
125 f32: Pow<Exp, Output = f32>,
126{
127 fn vector_field(&self, dx: &mut DVector<f32>, x: &DVector<f32>, _t: f32) {
128 for i in 0..dx.len() {
129 dx[i] = self.components[i].eval(|var| x[*var])
130 }
131 }
132}
133
#[cfg(test)]
mod tests {
    use expect_test::expect;

    use super::super::textplot_ode_result;
    use super::*;

    /// Symbolic parameters: polynomials in a parameter name `Id` with `f32` coefficients.
    type Parameter<Id> = Polynomial<Id, f32, u8>;

    /// Builds the SIR model symbolically, checks its rendering before and after
    /// substituting concrete parameter values, then integrates it with RK4 and compares a
    /// text plot of the solution against a snapshot.
    #[test]
    fn sir() {
        let param = |c: char| Parameter::<_>::generator(c);
        // State variables are polynomials whose coefficients are parameter polynomials.
        let var = |c: char| Polynomial::<_, Parameter<_>, u8>::generator(c);
        // The two 'I' terms are summed into a single equation when collected.
        let terms = [
            ('S', -var('S') * var('I') * param('β')),
            ('I', var('S') * var('I') * param('β')),
            ('I', -var('I') * param('γ')),
            ('R', var('I') * param('γ')),
        ];
        let sys: PolynomialSystem<_, _, _> = terms.into_iter().collect();
        let expected = expect![[r#"
            dI = ((-1) γ) I + β I S
            dR = γ I
            dS = ((-1) β) I S
        "#]];
        expected.assert_eq(&sys.to_string());

        // Evaluate every parameter polynomial with all parameters set to 1.0.
        let sys = sys.extend_scalars(|p| p.eval(|_| 1.0));
        let expected = expect![[r#"
            dI = (-1) I + I S
            dR = I
            dS = (-1) I S
        "#]];
        expected.assert_eq(&sys.to_string());

        // `to_numerical` numbers variables in sorted order (I, R, S), so this is
        // I = 1, R = 0, S = 4.
        let initial = DVector::from_column_slice(&[1.0, 0.0, 4.0]);
        let problem = ODEProblem::new(sys.to_numerical(), initial).end_time(5.0);
        let result = problem.solve_rk4(0.1).unwrap();
        // Snapshot of the plot produced by `textplot_ode_result` for t in [0, 5].
        let expected = expect![[r#"
            ⡁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⠤⠤⠤⠒⠒⠒⠒⠒⠉⠉⠉⠉⠁ 4.9
            ⠄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⠤⠒⠒⠉⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⠂⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⡁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠤⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⢇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠒⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⠚⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⡁⢣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⠄⠘⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⠂⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⡁⠀⠘⡄⠀⢀⠤⠒⠤⡀⠀⢠⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⠄⠀⠀⢣⡔⠁⠀⠀⠀⠈⢦⠃⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⠂⠀⠀⡜⡄⠀⠀⠀⠀⢠⠃⠑⢄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⡁⠀⡸⠀⢣⠀⠀⠀⢠⠃⠀⠀⠀⠣⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⠄⢠⠃⠀⠘⡄⠀⢠⠃⠀⠀⠀⠀⠀⠈⠢⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⢂⠇⠀⠀⠀⠱⣠⠃⠀⠀⠀⠀⠀⠀⠀⠀⠈⠢⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⡝⠀⠀⠀⠀⢠⢣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠢⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⠅⠀⠀⠀⢠⠃⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⠤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⠂⠀⠀⢠⠃⠀⠀⠀⠣⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⠤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⡁⠀⡠⠃⠀⠀⠀⠀⠀⠈⠒⠤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⠒⠤⠤⣀⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
            ⢄⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⠒⠤⠤⠤⠤⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣉⣉⣒⣒⣒⣒⣤⣤⣤⣤⠤⣀⣀⣀⣀⡀
            ⠁⠈⠀⠁⠈⠀⠁⠈⠀⠁⠈⠀⠁⠈⠀⠁⠈⠀⠁⠈⠀⠁⠈⠀⠁⠈⠀⠁⠈⠀⠁⠈⠀⠁⠈⠀⠁⠈⠀⠁⠈⠀⠁⠈⠀⠉⠉⠉⠉⠉⠁ 0.0
            0.0 5.0
        "#]];
        expected.assert_eq(&textplot_ode_result(&problem, &result));
    }
}