// catlog/stdlib/analyses/ode/mod.rs
1use std::{collections::HashMap, hash::Hash};
4
5use derivative::Derivative;
6use ode_solvers::dop_shared::IntegrationError;
7
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10#[cfg(feature = "serde-wasm")]
11use tsify_next::Tsify;
12
13use crate::simulate::ode::{ODEProblem, ODESystem};
14
15#[derive(Clone, Derivative)]
17#[derivative(Default(bound = ""))]
18#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
19#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
20#[cfg_attr(feature = "serde-wasm", tsify(into_wasm_abi, from_wasm_abi))]
21pub struct ODESolution<Id>
22where
23 Id: Eq + Hash,
24{
25 time: Vec<f32>,
27
28 states: HashMap<Id, Vec<f32>>,
30}
31
32pub struct ODEAnalysis<Id, Sys> {
34 pub problem: ODEProblem<Sys>,
36
37 pub variable_index: HashMap<Id, usize>,
39}
40
41impl<Id, Sys> ODEAnalysis<Id, Sys> {
42 pub fn new(problem: ODEProblem<Sys>, variable_index: HashMap<Id, usize>) -> Self {
44 Self {
45 problem,
46 variable_index,
47 }
48 }
49
50 pub fn solve_with_defaults(self) -> Result<ODESolution<Id>, IntegrationError>
52 where
53 Id: Eq + Hash,
54 Sys: ODESystem,
55 {
56 if self.variable_index.is_empty() {
58 return Ok(Default::default());
59 }
60
61 let duration = self.problem.end_time - self.problem.start_time;
62 let output_step_size = (duration / 100.0).min(0.01f32);
63 let result = self.problem.solve_dopri5(output_step_size)?;
64
65 let (t_out, x_out) = result.get();
66 Ok(ODESolution {
67 time: t_out.clone(),
68 states: self
69 .variable_index
70 .into_iter()
71 .map(|(ob, i)| (ob, x_out.iter().map(|x| x[i]).collect()))
72 .collect(),
73 })
74 }
75}
76
77#[allow(non_snake_case)]
78pub mod lotka_volterra;
79#[allow(clippy::type_complexity)]
80pub mod mass_action;
81
82pub use lotka_volterra::*;
83pub use mass_action::*;