catlog/simulate/ode/
mod.rs1use nalgebra::DVector;
4use ode_solvers::{
5 self,
6 dop_shared::{IntegrationError, SolverResult},
7};
8
9#[cfg(test)]
10use textplots::{Chart, Plot, Shape};
11
12pub trait ODESystem {
17 fn vector_field(&self, dx: &mut DVector<f32>, x: &DVector<f32>, t: f32);
19
20 fn eval_vector_field(&self, x: &DVector<f32>, t: f32) -> DVector<f32> {
22 let mut dx = DVector::from_element(x.len(), 0.0f32);
23 self.vector_field(&mut dx, x, t);
24 dx
25 }
26}
27
28#[derive(Clone, Debug, PartialEq)]
34pub struct ODEProblem<Sys> {
35 pub(crate) system: Sys,
36 pub(crate) initial_values: DVector<f32>,
37 pub(crate) start_time: f32,
38 pub(crate) end_time: f32,
39 rtol: f32,
40 atol: f32,
41}
42
43impl<Sys> ODEProblem<Sys> {
44 pub fn new(system: Sys, initial_values: DVector<f32>) -> Self {
46 ODEProblem {
47 system,
48 initial_values,
49 start_time: 0.0,
50 end_time: 0.0,
51 rtol: 0.001,
53 atol: 1e-6,
54 }
55 }
56
57 pub fn start_time(mut self, t: f32) -> Self {
59 self.start_time = t;
60 self
61 }
62
63 pub fn end_time(mut self, t: f32) -> Self {
65 self.end_time = t;
66 self
67 }
68
69 pub fn time_span(mut self, tspan: (f32, f32)) -> Self {
71 (self.start_time, self.end_time) = tspan;
72 self
73 }
74}
75
76impl<Sys> ODEProblem<Sys>
77where
78 Sys: ODESystem,
79{
80 pub fn solve_rk4(
85 &self,
86 step_size: f32,
87 ) -> Result<SolverResult<f32, DVector<f32>>, IntegrationError> {
88 let mut stepper = ode_solvers::Rk4::new(
89 self,
90 self.start_time,
91 self.initial_values.clone(),
92 self.end_time,
93 step_size,
94 );
95 stepper.integrate()?;
96 Ok(stepper.into())
97 }
98
99 pub fn solve_dopri5(
105 &self,
106 output_step_size: f32,
107 ) -> Result<SolverResult<f32, DVector<f32>>, IntegrationError> {
108 let mut stepper = ode_solvers::Dopri5::new(
109 self,
110 self.start_time,
111 self.end_time,
112 output_step_size,
113 self.initial_values.clone(),
114 self.rtol,
115 self.atol,
116 );
117 stepper.integrate()?;
118 Ok(stepper.into())
119 }
120}
121
122impl<Sys> ode_solvers::dop_shared::System<f32, DVector<f32>> for &ODEProblem<Sys>
123where
124 Sys: ODESystem,
125{
126 fn system(&self, x: f32, y: &DVector<f32>, dy: &mut DVector<f32>) {
127 self.system.vector_field(dy, y, x);
128 }
129}
130
/// Renders an ODE solution as a text chart (test-only helper).
///
/// Draws one line per state component; the number of components is taken from
/// `problem.initial_values`. Each line is plotted against the output times of
/// `result`, and the horizontal axis spans the problem's time interval.
/// Returns the rendered chart as a `String`.
#[cfg(test)]
pub(crate) fn textplot_ode_result<Sys>(
    problem: &ODEProblem<Sys>,
    result: &SolverResult<f32, DVector<f32>>,
) -> String {
    // The x-range must follow the problem's actual time span. The previous
    // hard-coded 0.0 lower bound misframed or clipped any solution that does
    // not start at t = 0.
    let mut chart = Chart::new(100, 80, problem.start_time, problem.end_time);
    let (t_out, x_out) = result.get();

    // One (t, x_i) series per state component.
    let dim = problem.initial_values.len();
    let line_data: Vec<Vec<(f32, f32)>> = (0..dim)
        .map(|i| t_out.iter().copied().zip(x_out.iter().map(|x| x[i])).collect())
        .collect();

    let lines: Vec<_> = line_data.iter().map(|data| Shape::Lines(data)).collect();

    let chart = lines.iter().fold(&mut chart, |chart, line| chart.lineplot(line));
    chart.axis();
    chart.figures();
    chart.to_string()
}
152
153#[allow(non_snake_case)]
154pub mod lotka_volterra;
155pub mod polynomial;
156
157pub use lotka_volterra::*;
158pub use polynomial::*;