catlog/stdlib/analyses/ode/
mass_action.rs

1//! Mass-action ODE analysis of models.
2//!
3//! Such ODEs are based on the *law of mass action* familiar from chemistry and
4//! mathematical epidemiology. Here, however, we also consider a generalised version
5//! where we do not require that mass be preserved. This allows the construction
6//! of systems of arbitrary polynomial (first-order) ODEs.
7
8use std::{collections::HashMap, fmt};
9
10use indexmap::IndexMap;
11use nalgebra::DVector;
12use num_traits::Zero;
13
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Serialize};
16#[cfg(feature = "serde-wasm")]
17use tsify::Tsify;
18
19use super::ODEAnalysis;
20use crate::dbl::{
21    model::{DiscreteTabModel, FgDblModel, ModalDblModel, TabEdge},
22    theory::{ModalMorType, ModalObType, TabMorType, TabObType},
23};
24use crate::one::FgCategory;
25use crate::simulate::ode::{NumericalPolynomialSystem, ODEProblem, PolynomialSystem};
26use crate::stdlib::analyses::petri::transition_interface;
27use crate::zero::{QualifiedName, alg::Polynomial, name, rig::Monomial};
28
/// There are three types of mass-action semantics, each more expressive than the previous:
/// - balanced
/// - unbalanced (rates per transition)
/// - unbalanced (rates per place)
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(tag = "type", content = "granularity"))]
#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
#[cfg_attr(feature = "serde-wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum MassConservationType {
    /// Mass is conserved: each transition has a single rate parameter, and the same
    /// rate term is subtracted from its inputs and added to its outputs.
    Balanced,
    /// Mass is not conserved: consumption and production have independent rate
    /// parameters, specified at the given granularity.
    Unbalanced(RateGranularity),
}

/// When mass is not necessarily conserved, consumption/production rate parameters
/// can be set either *per transition* or *per place*.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
#[cfg_attr(feature = "serde-wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum RateGranularity {
    /// Each transition gets assigned a single consumption and single production rate.
    PerTransition,

    /// Each transition gets assigned a consumption rate for each input place and
    /// a production rate for each output place.
    PerPlace,
}
59
/// Parameters in the generated polynomial equations are *undirected* in the
/// balanced case and *directed* in the unbalanced case.
///
/// These are the variables of the symbolic rate coefficients of a mass-action
/// system; numerical values are later substituted for them from the problem data.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum FlowParameter {
    /// If mass is conserved, we don't need to worry whether a flow is incoming or outgoing.
    Balanced {
        /// Since there is no direction, the rate parameter corresponds to a single transition.
        transition: QualifiedName,
    },
    /// If mass is not conserved, then we need to know whether a flow is incoming or outgoing.
    Unbalanced {
        /// The direction of the flow (production into an output or consumption from an input).
        direction: Direction,
        /// The structure of the rate parameter can be either per transition or per place.
        parameter: RateParameter,
    },
}
77
/// Depending on the rate granularity, the parameters are specified by different structures.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum RateParameter {
    /// For per transition rates, we simply need to know the associated transition.
    PerTransition {
        /// The transition to which we associate the rate parameter.
        transition: QualifiedName,
    },

    /// For per place rates, we need to know both the transition and the corresponding
    /// input/output place.
    PerPlace {
        /// The transition whose input or output place this rate parameter belongs to.
        transition: QualifiedName,
        /// The input/output place (object) to which we associate the rate parameter.
        place: QualifiedName,
    },
}
96
/// The associated direction of a "flow" term. Note that this is *opposite* from
/// the terminology of "input" and "output", i.e. a flow A=>B gives rise to an
/// *incoming flow to B* and an *outgoing flow from A*.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum Direction {
    /// The parameter corresponds to an incoming flow to a specific output,
    /// i.e. it is a production rate.
    IncomingFlow,

    /// The parameter corresponds to an outgoing flow from a specific input,
    /// i.e. it is a consumption rate.
    OutgoingFlow,
}
108
109impl fmt::Display for FlowParameter {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        match &self {
112            FlowParameter::Balanced { transition: trans } => {
113                write!(f, "{}", trans)
114            }
115            FlowParameter::Unbalanced {
116                direction: Direction::IncomingFlow,
117                parameter: RateParameter::PerTransition { transition: trans },
118            } => {
119                write!(f, "Incoming({})", trans)
120            }
121            FlowParameter::Unbalanced {
122                direction: Direction::IncomingFlow,
123                parameter: RateParameter::PerPlace { transition: trans, place: output },
124            } => {
125                write!(f, "([{}]->{})", trans, output)
126            }
127            FlowParameter::Unbalanced {
128                direction: Direction::OutgoingFlow,
129                parameter: RateParameter::PerTransition { transition: trans },
130            } => {
131                write!(f, "Outgoing({})", trans)
132            }
133            FlowParameter::Unbalanced {
134                direction: Direction::OutgoingFlow,
135                parameter: RateParameter::PerPlace { transition: trans, place: input },
136            } => {
137                write!(f, "({}->[{}])", input, trans)
138            }
139        }
140    }
141}
142
/// Data defining a mass-action ODE problem for a model, balanced or unbalanced.
///
/// Which rate map is consulted for a given symbolic rate parameter depends on the
/// kind of parameter (see [`FlowParameter`] and [`RateParameter`]); rates missing
/// from the relevant map are treated as zero.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
#[cfg_attr(
    feature = "serde-wasm",
    tsify(into_wasm_abi, from_wasm_abi, hashmap_as_object)
)]
pub struct MassActionProblemData {
    /// Whether or not mass is conserved and, if not, the granularity of the rates.
    #[cfg_attr(feature = "serde", serde(rename = "massConservationType"))]
    pub mass_conservation_type: MassConservationType,

    /// Map from morphism IDs to rate coefficients (nonnegative reals), for the
    /// balanced case, where a single rate governs both consumption and production.
    /// N.B. Under the `serde` feature this is (de)serialized as "rates" for
    /// backwards compatibility.
    #[cfg_attr(feature = "serde", serde(rename = "rates"))]
    transition_rates: HashMap<QualifiedName, f32>,

    /// Map from morphism IDs to consumption rate coefficients (nonnegative reals),
    /// for the unbalanced per transition case.
    #[cfg_attr(feature = "serde", serde(rename = "transitionConsumptionRates"))]
    transition_consumption_rates: HashMap<QualifiedName, f32>,

    /// Map from morphism IDs to production rate coefficients (nonnegative reals),
    /// for the unbalanced per transition case.
    #[cfg_attr(feature = "serde", serde(rename = "transitionProductionRates"))]
    transition_production_rates: HashMap<QualifiedName, f32>,

    /// Map from morphism IDs to (map from input objects to consumption rate
    /// coefficients), for the unbalanced per place case (nonnegative reals).
    #[cfg_attr(feature = "serde", serde(rename = "placeConsumptionRates"))]
    place_consumption_rates: HashMap<QualifiedName, HashMap<QualifiedName, f32>>,

    /// Map from morphism IDs to (map from output objects to production rate
    /// coefficients), for the unbalanced per place case (nonnegative reals).
    #[cfg_attr(feature = "serde", serde(rename = "placeProductionRates"))]
    place_production_rates: HashMap<QualifiedName, HashMap<QualifiedName, f32>>,

    /// Map from object IDs to initial values (nonnegative reals); objects missing
    /// from the map start at zero.
    #[cfg_attr(feature = "serde", serde(rename = "initialValues"))]
    pub initial_values: HashMap<QualifiedName, f32>,

    /// Duration of simulation, used as the end time of the ODE problem.
    pub duration: f32,
}
188
/// Symbolic parameter in mass-action polynomial system: a polynomial in the
/// parameter variables `Id`, with `f32` coefficients and `i8` exponents.
type Parameter<Id> = Polynomial<Id, f32, i8>;
191
/// Mass-action ODE analysis for Petri nets.
///
/// This struct implements the object part of the functorial semantics for reaction
/// networks (aka, Petri nets) due to [Baez & Pollard](crate::refs::ReactionNets).
/// It records which object and morphism types of the model play the roles of
/// places and transitions.
pub struct PetriNetMassActionAnalysis {
    /// Object type for places; generators of this type become state variables.
    pub place_ob_type: ModalObType,
    /// Morphism type for transitions; generators of this type contribute rate terms.
    pub transition_mor_type: ModalMorType,
}
202
203impl Default for PetriNetMassActionAnalysis {
204    fn default() -> Self {
205        let ob_type = ModalObType::new(name("Object"));
206        Self {
207            place_ob_type: ob_type.clone(),
208            transition_mor_type: ModalMorType::Zero(ob_type),
209        }
210    }
211}
212
213impl PetriNetMassActionAnalysis {
214    /// Creates a mass-action system with symbolic rate coefficients.
215    pub fn build_system(
216        &self,
217        model: &ModalDblModel,
218        mass_conservation_type: MassConservationType,
219    ) -> PolynomialSystem<QualifiedName, Parameter<FlowParameter>, i8> {
220        let mut sys = PolynomialSystem::new();
221        for ob in model.ob_generators_with_type(&self.place_ob_type) {
222            sys.add_term(ob, Polynomial::zero());
223        }
224        for mor in model.mor_generators_with_type(&self.transition_mor_type) {
225            let (inputs, outputs) = transition_interface(model, &mor);
226            let term: Monomial<_, _> =
227                inputs.iter().map(|ob| (ob.clone().unwrap_generator(), 1)).collect();
228
229            match mass_conservation_type {
230                MassConservationType::Balanced => {
231                    let term: Polynomial<_, _, _> = [(
232                        Parameter::generator(FlowParameter::Balanced { transition: mor }),
233                        term.clone(),
234                    )]
235                    .into_iter()
236                    .collect();
237
238                    for input in inputs {
239                        sys.add_term(input.unwrap_generator(), -term.clone());
240                    }
241
242                    for output in outputs {
243                        sys.add_term(output.unwrap_generator(), term.clone());
244                    }
245                }
246
247                MassConservationType::Unbalanced(granularity) => {
248                    for input in inputs {
249                        let input_term: Polynomial<_, _, _> = match granularity {
250                            RateGranularity::PerTransition => [(
251                                Parameter::generator(FlowParameter::Unbalanced {
252                                    direction: Direction::OutgoingFlow,
253                                    parameter: RateParameter::PerTransition {
254                                        transition: mor.clone(),
255                                    },
256                                }),
257                                term.clone(),
258                            )],
259                            RateGranularity::PerPlace => [(
260                                Parameter::generator(FlowParameter::Unbalanced {
261                                    direction: Direction::OutgoingFlow,
262                                    parameter: RateParameter::PerPlace {
263                                        transition: mor.clone(),
264                                        place: input.clone().unwrap_generator(),
265                                    },
266                                }),
267                                term.clone(),
268                            )],
269                        }
270                        .into_iter()
271                        .collect();
272
273                        sys.add_term(input.unwrap_generator(), -input_term.clone());
274                    }
275                    for output in outputs {
276                        let output_term: Polynomial<_, _, _> = match granularity {
277                            RateGranularity::PerTransition => [(
278                                Parameter::generator(FlowParameter::Unbalanced {
279                                    direction: Direction::IncomingFlow,
280                                    parameter: RateParameter::PerTransition {
281                                        transition: mor.clone(),
282                                    },
283                                }),
284                                term.clone(),
285                            )],
286                            RateGranularity::PerPlace => [(
287                                Parameter::generator(FlowParameter::Unbalanced {
288                                    direction: Direction::IncomingFlow,
289                                    parameter: RateParameter::PerPlace {
290                                        transition: mor.clone(),
291                                        place: output.clone().unwrap_generator(),
292                                    },
293                                }),
294                                term.clone(),
295                            )],
296                        }
297                        .into_iter()
298                        .collect();
299
300                        sys.add_term(output.unwrap_generator(), output_term.clone());
301                    }
302                }
303            }
304        }
305
306        sys.normalize()
307    }
308}
309
/// Mass-action ODE analysis for stock-flow models.
///
/// Records which object and morphism types of the model play the roles of
/// stocks, flows, and (positive and negative) links.
pub struct StockFlowMassActionAnalysis {
    /// Object type for stocks; generators of this type become state variables.
    pub stock_ob_type: TabObType,
    /// Morphism type for flows between stocks.
    pub flow_mor_type: TabMorType,
    /// Morphism type for positive links from stocks to flows. A positive link
    /// multiplies the flow's monomial by the link's source stock.
    pub pos_link_mor_type: TabMorType,
    /// Morphism type for negative links from stocks to flows. A negative link
    /// multiplies the flow's monomial by the inverse of the link's source stock.
    pub neg_link_mor_type: TabMorType,
}
321
322impl Default for StockFlowMassActionAnalysis {
323    fn default() -> Self {
324        let stock_ob_type = TabObType::Basic(name("Object"));
325        let flow_mor_type = TabMorType::Hom(Box::new(stock_ob_type.clone()));
326        Self {
327            stock_ob_type,
328            flow_mor_type,
329            pos_link_mor_type: TabMorType::Basic(name("Link")),
330            neg_link_mor_type: TabMorType::Basic(name("NegativeLink")),
331        }
332    }
333}
334
335impl StockFlowMassActionAnalysis {
336    /// Creates a mass-action system with symbolic rate coefficients.
337    pub fn build_system(
338        &self,
339        model: &DiscreteTabModel,
340        mass_conservation_type: MassConservationType,
341    ) -> PolynomialSystem<QualifiedName, Parameter<FlowParameter>, i8> {
342        let terms: Vec<_> = self.flow_monomials(model).into_iter().collect();
343
344        let mut sys = PolynomialSystem::new();
345        for ob in model.ob_generators_with_type(&self.stock_ob_type) {
346            sys.add_term(ob, Polynomial::zero());
347        }
348        for (flow, term) in terms {
349            let dom = model.mor_generator_dom(&flow).unwrap_basic();
350            let cod = model.mor_generator_cod(&flow).unwrap_basic();
351            match mass_conservation_type {
352                MassConservationType::Balanced => {
353                    let param = Parameter::generator(FlowParameter::Balanced { transition: flow });
354                    let term: Polynomial<_, _, _> = [(param, term.clone())].into_iter().collect();
355                    sys.add_term(dom, -term.clone());
356                    sys.add_term(cod, term);
357                }
358                MassConservationType::Unbalanced(_) => {
359                    let dom_param = Parameter::generator(FlowParameter::Unbalanced {
360                        direction: Direction::OutgoingFlow,
361                        parameter: RateParameter::PerTransition { transition: flow.clone() },
362                    });
363                    let cod_param = Parameter::generator(FlowParameter::Unbalanced {
364                        direction: Direction::IncomingFlow,
365                        parameter: RateParameter::PerTransition { transition: flow },
366                    });
367                    let dom_term: Polynomial<_, _, _> =
368                        [(dom_param, term.clone())].into_iter().collect();
369                    let cod_term: Polynomial<_, _, _> = [(cod_param, term)].into_iter().collect();
370                    sys.add_term(dom, -dom_term);
371                    sys.add_term(cod, cod_term);
372                }
373            }
374        }
375        sys
376    }
377
378    /// Constructs a monomial for each flow in the model.
379    pub(super) fn flow_monomials(
380        &self,
381        model: &DiscreteTabModel,
382    ) -> HashMap<QualifiedName, Monomial<QualifiedName, i8>> {
383        let mut terms: HashMap<_, _> = model
384            .mor_generators_with_type(&self.flow_mor_type)
385            .map(|flow| {
386                let dom = model.mor_generator_dom(&flow).unwrap_basic();
387                (flow, Monomial::generator(dom))
388            })
389            .collect();
390
391        let mut multiply_for_link = |link: QualifiedName, exponent: i8| {
392            let dom = model.mor_generator_dom(&link).unwrap_basic();
393            let path = model.mor_generator_cod(&link).unwrap_tabulated();
394            let Some(TabEdge::Basic(cod)) = path.only() else {
395                panic!("Codomain of link should be basic morphism");
396            };
397            if let Some(term) = terms.get_mut(&cod) {
398                let mon: Monomial<_, i8> = [(dom, exponent)].into_iter().collect();
399                *term = std::mem::take(term) * mon;
400            } else {
401                panic!("Codomain of link does not belong to model");
402            };
403        };
404
405        for link in model.mor_generators_with_type(&self.pos_link_mor_type) {
406            multiply_for_link(link, 1);
407        }
408        for link in model.mor_generators_with_type(&self.neg_link_mor_type) {
409            multiply_for_link(link, -1);
410        }
411
412        terms
413    }
414}
415
416/// Substitutes numerical rate coefficients into a symbolic mass-action system.
417pub fn extend_mass_action_scalars(
418    sys: PolynomialSystem<QualifiedName, Parameter<FlowParameter>, i8>,
419    data: &MassActionProblemData,
420) -> PolynomialSystem<QualifiedName, f32, i8> {
421    let sys = sys.extend_scalars(|poly| {
422        poly.eval(|flow| match flow {
423            FlowParameter::Balanced { transition } => {
424                data.transition_rates.get(transition).cloned().unwrap_or_default()
425            }
426            FlowParameter::Unbalanced { direction, parameter } => match (direction, parameter) {
427                (Direction::IncomingFlow, RateParameter::PerTransition { transition }) => {
428                    data.transition_production_rates.get(transition).cloned().unwrap_or_default()
429                }
430                (Direction::OutgoingFlow, RateParameter::PerTransition { transition }) => {
431                    data.transition_consumption_rates.get(transition).cloned().unwrap_or_default()
432                }
433                (Direction::IncomingFlow, RateParameter::PerPlace { transition, place }) => data
434                    .place_production_rates
435                    .get(transition)
436                    .and_then(|rate| rate.get(place))
437                    .copied()
438                    .unwrap_or_default(),
439                (Direction::OutgoingFlow, RateParameter::PerPlace { transition, place }) => data
440                    .place_consumption_rates
441                    .get(transition)
442                    .and_then(|rate| rate.get(place))
443                    .copied()
444                    .unwrap_or_default(),
445            },
446        })
447    });
448
449    sys.normalize()
450}
451
452/// Builds the numerical ODE analysis for a mass-action system whose scalars have been substituted.
453pub fn into_mass_action_analysis(
454    sys: PolynomialSystem<QualifiedName, f32, i8>,
455    data: MassActionProblemData,
456) -> ODEAnalysis<NumericalPolynomialSystem<i8>> {
457    let ob_index: IndexMap<_, _> =
458        sys.components.keys().cloned().enumerate().map(|(i, x)| (x, i)).collect();
459    let n = ob_index.len();
460
461    let initial_values = ob_index
462        .keys()
463        .map(|ob| data.initial_values.get(ob).copied().unwrap_or_default());
464    let x0 = DVector::from_iterator(n, initial_values);
465
466    let num_sys = sys.to_numerical();
467    let problem = ODEProblem::new(num_sys, x0).end_time(data.duration);
468
469    ODEAnalysis::new(problem, ob_index)
470}
471
#[cfg(test)]
mod tests {
    use expect_test::expect;
    use std::rc::Rc;

    use super::*;
    use crate::simulate::ode::LatexEquation;
    use crate::stdlib::{models::*, theories::*};

    // Tests for stock-flow diagrams. These all use the backward_link() model,
    // which has a single flow x==f==>y and a single link y->f.

    #[test]
    fn balanced_stock_flow() {
        let th = Rc::new(th_category_links());
        let model = backward_link(th);
        let sys = StockFlowMassActionAnalysis::default()
            .build_system(&model, MassConservationType::Balanced);
        let expected = expect!([r#"
            dx = (-f) x y
            dy = f x y
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    #[test]
    fn unbalanced_stock_flow() {
        let th = Rc::new(th_category_links());
        let model = backward_link(th);
        let sys = StockFlowMassActionAnalysis::default().build_system(
            &model,
            MassConservationType::Unbalanced(RateGranularity::PerTransition),
        );
        let expected = expect!([r#"
            dx = (-Outgoing(f)) x y
            dy = (Incoming(f)) x y
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    // Tests for signed stock-flow diagrams. These all use the negative_backward_link()
    // model, which has a single flow x==f==>y and a single negative link y->f.

    #[test]
    fn balanced_signed_stock_flow() {
        let th = Rc::new(th_category_signed_links());
        let model = negative_backward_link(th);
        let sys = StockFlowMassActionAnalysis::default()
            .build_system(&model, MassConservationType::Balanced);
        let expected = expect!([r#"
            dx = (-f) x y^{-1}
            dy = f x y^{-1}
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    #[test]
    fn unbalanced_signed_stock_flow() {
        let th = Rc::new(th_category_signed_links());
        let model = negative_backward_link(th);
        let sys = StockFlowMassActionAnalysis::default().build_system(
            &model,
            MassConservationType::Unbalanced(RateGranularity::PerTransition),
        );
        let expected = expect!([r#"
            dx = (-Outgoing(f)) x y^{-1}
            dy = (Incoming(f)) x y^{-1}
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    // Tests for Petri nets. These all use the catalyzed_reaction() model, which
    // has a single transition [x,c]-->f-->[y,c].

    #[test]
    fn balanced_petri() {
        let th = Rc::new(th_sym_monoidal_category());
        let model = catalyzed_reaction(th);
        let sys = PetriNetMassActionAnalysis::default()
            .build_system(&model, MassConservationType::Balanced);
        let expected = expect!([r#"
            dx = (-f) c x
            dy = f c x
            dc = 0
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    #[test]
    fn unbalanced_petri_per_transition() {
        let th = Rc::new(th_sym_monoidal_category());
        let model = catalyzed_reaction(th);
        let sys = PetriNetMassActionAnalysis::default().build_system(
            &model,
            MassConservationType::Unbalanced(RateGranularity::PerTransition),
        );
        let expected = expect!([r#"
            dx = (-Outgoing(f)) c x
            dy = (Incoming(f)) c x
            dc = (Incoming(f) + -Outgoing(f)) c x
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    #[test]
    fn unbalanced_petri_per_place() {
        let th = Rc::new(th_sym_monoidal_category());
        let model = catalyzed_reaction(th);
        let sys = PetriNetMassActionAnalysis::default().build_system(
            &model,
            MassConservationType::Unbalanced(RateGranularity::PerPlace),
        );
        let expected = expect!([r#"
            dx = (-(x->[f])) c x
            dy = (([f]->y)) c x
            dc = (([f]->c) + -(c->[f])) c x
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    // Test for LaTeX.

    #[test]
    fn to_latex() {
        let th = Rc::new(th_category_links());
        let model = backward_link(th);
        let sys = StockFlowMassActionAnalysis::default().build_system(
            &model,
            MassConservationType::Unbalanced(RateGranularity::PerTransition),
        );
        let expected = vec![
            LatexEquation {
                lhs: "\\frac{\\mathrm{d}}{\\mathrm{d}t} x".to_string(),
                rhs: "(-Outgoing(f)) x y".to_string(),
            },
            LatexEquation {
                lhs: "\\frac{\\mathrm{d}}{\\mathrm{d}t} y".to_string(),
                rhs: "(Incoming(f)) x y".to_string(),
            },
        ];
        assert_eq!(expected, sys.to_latex_equations());
    }
}