catlog/simulate/ode/
mod.rs

//! Simulation of dynamical systems defined by ODEs.
2
3use nalgebra::DVector;
4use ode_solvers::{
5    self,
6    dop_shared::{IntegrationError, SolverResult},
7};
8
9#[cfg(test)]
10use textplots::{Chart, Plot, Shape};
11
12/** A system of ordinary differential equations (ODEs).
13
14An ODE system is anything that can compute a vector field.
15 */
16pub trait ODESystem {
17    /// Compute the vector field at the given time and state in place.
18    fn vector_field(&self, dx: &mut DVector<f32>, x: &DVector<f32>, t: f32);
19
20    /// Compute and return the vector field at the given time and state.
21    fn eval_vector_field(&self, x: &DVector<f32>, t: f32) -> DVector<f32> {
22        let mut dx = DVector::from_element(x.len(), 0.0f32);
23        self.vector_field(&mut dx, x, t);
24        dx
25    }
26}
27
28/** An ODE problem ready to be solved.
29
30An ODE problem comprises an [ODE system](ODESystem) plus the extra information
31needed to solve the system, namely the initial values and the time span.
32 */
33#[derive(Clone, Debug, PartialEq)]
34pub struct ODEProblem<Sys> {
35    pub(crate) system: Sys,
36    pub(crate) initial_values: DVector<f32>,
37    pub(crate) start_time: f32,
38    pub(crate) end_time: f32,
39    rtol: f32,
40    atol: f32,
41}
42
43impl<Sys> ODEProblem<Sys> {
44    /// Creates a new ODE problem.
45    pub fn new(system: Sys, initial_values: DVector<f32>) -> Self {
46        ODEProblem {
47            system,
48            initial_values,
49            start_time: 0.0,
50            end_time: 0.0,
51            // Same defaults as `scipy.integrate.RK45`.
52            rtol: 0.001,
53            atol: 1e-6,
54        }
55    }
56
57    /// Sets the start time for the problem.
58    pub fn start_time(mut self, t: f32) -> Self {
59        self.start_time = t;
60        self
61    }
62
63    /// Sets the end time for the problem.
64    pub fn end_time(mut self, t: f32) -> Self {
65        self.end_time = t;
66        self
67    }
68
69    /// Sets the time span (start and end time) for the problem.
70    pub fn time_span(mut self, tspan: (f32, f32)) -> Self {
71        (self.start_time, self.end_time) = tspan;
72        self
73    }
74}
75
76impl<Sys> ODEProblem<Sys>
77where
78    Sys: ODESystem,
79{
80    /** Solves the ODE system using the Runge-Kutta method.
81
82    Returns the solver results if successful and an integration error otherwise.
83     */
84    pub fn solve_rk4(
85        &self,
86        step_size: f32,
87    ) -> Result<SolverResult<f32, DVector<f32>>, IntegrationError> {
88        let mut stepper = ode_solvers::Rk4::new(
89            self,
90            self.start_time,
91            self.initial_values.clone(),
92            self.end_time,
93            step_size,
94        );
95        stepper.integrate()?;
96        Ok(stepper.into())
97    }
98
99    /** Solves the ODE system using the Dormand-Prince method.
100
101    A variant of Runge-Kutta with adaptive step size control and automatic
102    selection of initial step size.
103    */
104    pub fn solve_dopri5(
105        &self,
106        output_step_size: f32,
107    ) -> Result<SolverResult<f32, DVector<f32>>, IntegrationError> {
108        let mut stepper = ode_solvers::Dopri5::new(
109            self,
110            self.start_time,
111            self.end_time,
112            output_step_size,
113            self.initial_values.clone(),
114            self.rtol,
115            self.atol,
116        );
117        stepper.integrate()?;
118        Ok(stepper.into())
119    }
120}
121
122impl<Sys> ode_solvers::dop_shared::System<f32, DVector<f32>> for &ODEProblem<Sys>
123where
124    Sys: ODESystem,
125{
126    fn system(&self, x: f32, y: &DVector<f32>, dy: &mut DVector<f32>) {
127        self.system.vector_field(dy, y, x);
128    }
129}
130
/// Renders the solution of an ODE problem as a text plot (test helper).
///
/// Draws one line per state component, with time on the x-axis spanning the
/// problem's `[start_time, end_time]` interval. Returns the rendered chart.
#[cfg(test)]
pub(crate) fn textplot_ode_result<Sys>(
    problem: &ODEProblem<Sys>,
    result: &SolverResult<f32, DVector<f32>>,
) -> String {
    // Use the problem's real start time so that problems that do not start
    // at t = 0 are neither padded nor clipped.
    let mut chart = Chart::new(100, 80, problem.start_time, problem.end_time);
    let (t_out, x_out) = result.get();

    // One series of (time, value) points per state component.
    let dim = problem.initial_values.len();
    let line_data: Vec<Vec<(f32, f32)>> = (0..dim)
        .map(|i| t_out.iter().copied().zip(x_out.iter().map(|x| x[i])).collect())
        .collect();

    let lines: Vec<_> = line_data.iter().map(|data| Shape::Lines(data)).collect();

    let chart = lines.iter().fold(&mut chart, |chart, line| chart.lineplot(line));
    chart.axis();
    chart.figures();
    chart.to_string()
}
152
153#[allow(non_snake_case)]
154pub mod lotka_volterra;
155pub mod polynomial;
156
157pub use lotka_volterra::*;
158pub use polynomial::*;