catlog/stdlib/analyses/stochastic/
mass_action.rs

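//! Stochastic mass-action analysis of Petri nets.
//!
//! Instead of integrating mass-action ODEs, these analyses simulate a stochastic
//! counterpart in which transitions fire at random times, using Gillespie's
//! stochastic simulation algorithm.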
4
5use indexmap::IndexMap;
6use rebop::gillespie;
7use std::collections::HashMap;
8
9use crate::{
10    dbl::{modal::*, model::FgDblModel},
11    stdlib::analyses::{ode::ODESolution, petri::transition_interface},
12    zero::{QualifiedName, name},
13};
14
15#[cfg(feature = "serde")]
16use serde::{Deserialize, Serialize};
17#[cfg(feature = "serde-wasm")]
18use tsify::Tsify;
19
20/// Data defining the stochastic mass-action ODE problem.
21#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
23#[cfg_attr(
24    feature = "serde-wasm",
25    tsify(into_wasm_abi, from_wasm_abi, hashmap_as_object)
26)]
27pub struct StochasticMassActionProblemData {
28    /// Map from morphism IDs to rate coefficients (nonnegative reals).
29    rates: HashMap<QualifiedName, f32>,
30
31    /// Map from object IDs to initial values (nonnegative integers).
32    #[cfg_attr(feature = "serde", serde(rename = "initialValues"))]
33    pub initial_values: HashMap<QualifiedName, u32>,
34
35    /// Duration of simulation.
36    pub duration: f32,
37}
38
39/// Stochastic mass-action analysis of a model.
40pub struct StochasticMassActionAnalysis {
41    /// Reaction network for the analysis.
42    pub problem: rebop::gillespie::Gillespie,
43
44    /// Map from object IDs to variable indices.
45    pub variable_index: IndexMap<QualifiedName, usize>,
46
47    /// Map from object IDs to initial values (nonnegative integers).
48    pub initial_values: HashMap<QualifiedName, u32>,
49
50    /// Duration of simulation.
51    pub duration: f32,
52}
53
54impl StochasticMassActionAnalysis {
55    /// Simulates the stochastic mass-action system and collects the results.
56    pub fn simulate(&mut self) -> ODESolution {
57        let mut time = vec![0.0];
58        let mut states: HashMap<_, _> = self
59            .variable_index
60            .keys()
61            .map(|id| {
62                let initial = self.initial_values.get(id).copied().unwrap_or_default();
63                (id.clone(), vec![initial as f32])
64            })
65            .collect();
66        for t in 0..(self.duration as usize) {
67            self.problem.advance_until(t as f64);
68            time.push(self.problem.get_time() as f32);
69            for (id, idx) in self.variable_index.iter() {
70                states.get_mut(id).unwrap().push(self.problem.get_species(*idx) as f32)
71            }
72        }
73        ODESolution { time, states }
74    }
75}
76
77/// Stochastic mass-action analysis for Petri nets.
78pub struct PetriNetStochasticMassActionAnalysis {
79    /// Object type for places.
80    pub place_ob_type: ModalObType,
81    /// Morphism type for transitions.
82    pub transition_mor_type: ModalMorType,
83}
84
85impl Default for PetriNetStochasticMassActionAnalysis {
86    fn default() -> Self {
87        let ob_type = ModalObType::new(name("Object"));
88        Self {
89            place_ob_type: ob_type.clone(),
90            transition_mor_type: ModalMorType::Zero(ob_type),
91        }
92    }
93}
94
95impl PetriNetStochasticMassActionAnalysis {
96    /// Creates a stochastic mass-action system.
97    pub fn build_stochastic_system(
98        &self,
99        model: &ModalDblModel,
100        data: StochasticMassActionProblemData,
101    ) -> StochasticMassActionAnalysis {
102        let ob_generators: Vec<_> = model.ob_generators_with_type(&self.place_ob_type).collect();
103
104        let initial: Vec<_> = ob_generators
105            .iter()
106            .map(|id| data.initial_values.get(id).copied().unwrap_or_default() as isize)
107            .collect();
108        let mut problem = gillespie::Gillespie::new(initial, false);
109
110        for mor in model.mor_generators_with_type(&self.transition_mor_type) {
111            let (inputs, outputs) = transition_interface(model, &mor);
112
113            // 1. convert the inputs/outputs to sequences of counts
114            let input_vec = ob_generators.iter().map(|id| {
115                inputs
116                    .iter()
117                    .filter(|&ob| matches!(ob, ModalOb::Generator(id2) if id2 == id))
118                    .count() as u32
119            });
120            let output_vec = ob_generators.iter().map(|id| {
121                outputs
122                    .iter()
123                    .filter(|&ob| matches!(ob, ModalOb::Generator(id2) if id2 == id))
124                    .count() as isize
125            });
126
127            // 2. output := output - input
128            let input_vec: Vec<_> = input_vec.collect();
129            let output_vec: Vec<_> = output_vec
130                .zip(input_vec.iter().copied())
131                .map(|(o, i)| o - (i as isize))
132                .collect();
133            if let Some(rate) = data.rates.get(&mor) {
134                problem.add_reaction(gillespie::Rate::lma(*rate as f64, input_vec), output_vec)
135            }
136        }
137
138        let variable_index: IndexMap<_, _> =
139            ob_generators.into_iter().enumerate().map(|(i, x)| (x, i)).collect();
140
141        StochasticMassActionAnalysis {
142            problem,
143            variable_index,
144            initial_values: data.initial_values,
145            duration: data.duration,
146        }
147    }
148}
149
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use super::*;
    use crate::stdlib::sir_petri;
    use crate::stdlib::theories::*;
    use crate::zero::name;

    /// SIR problem data: one infected among 100000 susceptibles.
    fn sir_data() -> StochasticMassActionProblemData {
        StochasticMassActionProblemData {
            rates: HashMap::from_iter([(name("infect"), 1e-5f32), (name("recover"), 1e-2f32)]),
            initial_values: HashMap::from_iter([
                (name("S"), 1e5 as u32),
                (name("I"), 1),
                (name("R"), 0),
            ]),
            duration: 10f32,
        }
    }

    #[test]
    fn sir_petri_stochastic_dynamics() {
        let th = Rc::new(th_sym_monoidal_category());
        let model = sir_petri(th);
        let sys = PetriNetStochasticMassActionAnalysis::default()
            .build_stochastic_system(&model, sir_data());
        assert_eq!(2, sys.problem.nb_reactions());
        assert_eq!(3, sys.problem.nb_species());
    }

    #[test]
    fn sir_petri_stochastic_simulation_conserves_population() {
        let th = Rc::new(th_sym_monoidal_category());
        let model = sir_petri(th);
        let mut sys = PetriNetStochasticMassActionAnalysis::default()
            .build_stochastic_system(&model, sir_data());
        let solution = sys.simulate();

        assert_eq!(solution.time.first(), Some(&0.0));
        assert!(solution.time.windows(2).all(|w| w[0] <= w[1]));

        // Infection (S + I -> 2I) and recovery (I -> R) both preserve S + I + R.
        for i in 0..solution.time.len() {
            let total: f32 =
                ["S", "I", "R"].into_iter().map(|id| solution.states[&name(id)][i]).sum();
            assert_eq!(total, 100_001.0);
        }
    }
}