catlog/stdlib/analyses/ode/mass_action.rs

//! Mass-action ODE analysis of models.
//!
//! Such ODEs are based on the *law of mass action* familiar from chemistry and
//! mathematical epidemiology. Here, however, we also consider a generalised version
//! in which mass need not be conserved. This allows the construction of systems of
//! arbitrary polynomial (first-order) ODEs.
7
8use std::{collections::HashMap, fmt};
9
10use indexmap::IndexMap;
11use nalgebra::DVector;
12use num_traits::Zero;
13
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Serialize};
16#[cfg(feature = "serde-wasm")]
17use tsify::Tsify;
18
19use super::{ODEAnalysis, Parameter};
20use crate::dbl::{
21    model::{DiscreteTabModel, FpDblModel, ModalDblModel, TabEdge},
22    theory::{ModalMorType, ModalObType, TabMorType, TabObType, Unital},
23};
24use crate::one::FgCategory;
25use crate::simulate::ode::{NumericalPolynomialSystem, ODEProblem, PolynomialSystem};
26use crate::stdlib::analyses::petri::transition_interface;
27use crate::zero::{QualifiedName, alg::Polynomial, name, rig::Monomial};
28
/// Selects which mass-action semantics is used to build a system of ODEs.
///
/// There are three kinds of mass-action semantics, each more expressive than the previous:
/// - balanced: a single rate per transition;
/// - unbalanced, with rates per transition (see [`RateGranularity::PerTransition`]);
/// - unbalanced, with rates per place (see [`RateGranularity::PerPlace`]).
///
/// With the `serde` feature, this is serialized as an adjacently tagged enum, with the
/// variant under the key `type` and the granularity (if any) under `granularity`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(tag = "type", content = "granularity"))]
#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
#[cfg_attr(feature = "serde-wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum MassConservationType {
    /// Mass is conserved: each transition has one rate, shared by the consumption of its
    /// inputs and the production of its outputs.
    Balanced,
    /// Mass is not necessarily conserved: consumption and production rates are
    /// independent, and are assigned at the given granularity.
    Unbalanced(RateGranularity),
}

/// When mass is not necessarily conserved, consumption/production rate parameters
/// can be set either *per transition* or *per place*.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
#[cfg_attr(feature = "serde-wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum RateGranularity {
    /// Each transition gets assigned a single consumption and a single production rate.
    PerTransition,

    /// Each transition gets assigned a consumption rate for each input place and
    /// a production rate for each output place.
    PerPlace,
}
59
60/// Parameters in the generated polynomial equations are *undirected* in the
61/// balanced case and *directed* in the unbalanced case.
62#[derive(PartialEq, Eq, PartialOrd, Ord, Clone)]
63pub enum FlowParameter {
64    /// If mass is conserved, we don't need to worry whether a flow is incoming or outgoing.
65    Balanced {
66        /// Since there is no direction, the rate parameter corresponds to a single transition.
67        transition: QualifiedName,
68    },
69    /// If mass is not conserved, then we need to know whether a flow is incoming or outgoing.
70    Unbalanced {
71        /// The direction of the flow.
72        direction: Direction,
73        /// The structure of the rate parameter can be either per transition or per place.
74        parameter: RateParameter,
75    },
76}
77
78/// Depending on the rate granularity, the parameters are specified by different structures.
79#[derive(PartialEq, Eq, PartialOrd, Ord, Clone)]
80pub enum RateParameter {
81    /// For per transition rates, we simply need to know the associated transition.
82    PerTransition {
83        /// The transition to which we associate the rate parameter.
84        transition: QualifiedName,
85    },
86
87    /// For per place rates, we need to know both the transition and the corresponding
88    /// input/output place.
89    PerPlace {
90        /// The transition whose input/output objects we wish to associate rate parameters.
91        transition: QualifiedName,
92        /// The input/output object to which we associate the rate parameter.
93        place: QualifiedName,
94    },
95}
96
97/// The associated direction of a "flow" term. Note that this is *opposite* from
98/// the terminology of "input" and "output", i.e. a flow A=>B gives rise to an
99/// *incoming flow to B* and an *outgoing flow from A*.
100#[derive(PartialEq, Eq, PartialOrd, Ord, Clone)]
101pub enum Direction {
102    /// The parameter corresponds to an incoming flow to a specific output.
103    IncomingFlow,
104
105    /// The parameter corresponds to an outgoing flow to a specific input.
106    OutgoingFlow,
107}
108
109impl fmt::Display for FlowParameter {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        match &self {
112            FlowParameter::Balanced { transition: trans } => {
113                write!(f, "{}", trans)
114            }
115            FlowParameter::Unbalanced {
116                direction: Direction::IncomingFlow,
117                parameter: RateParameter::PerTransition { transition: trans },
118            } => {
119                write!(f, "Incoming({})", trans)
120            }
121            FlowParameter::Unbalanced {
122                direction: Direction::IncomingFlow,
123                parameter: RateParameter::PerPlace { transition: trans, place: output },
124            } => {
125                write!(f, "([{}]->{})", trans, output)
126            }
127            FlowParameter::Unbalanced {
128                direction: Direction::OutgoingFlow,
129                parameter: RateParameter::PerTransition { transition: trans },
130            } => {
131                write!(f, "Outgoing({})", trans)
132            }
133            FlowParameter::Unbalanced {
134                direction: Direction::OutgoingFlow,
135                parameter: RateParameter::PerPlace { transition: trans, place: input },
136            } => {
137                write!(f, "({}->[{}])", input, trans)
138            }
139        }
140    }
141}
142
143/// Data defining an unbalanced mass-action ODE problem for a model.
144#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
145#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
146#[cfg_attr(
147    feature = "serde-wasm",
148    tsify(into_wasm_abi, from_wasm_abi, hashmap_as_object)
149)]
150pub struct MassActionProblemData {
151    /// Whether or not mass is conserved.
152    #[cfg_attr(feature = "serde", serde(rename = "massConservationType"))]
153    pub mass_conservation_type: MassConservationType,
154
155    /// Map from morphism IDs to consumption rate coefficients (nonnegative reals),
156    /// for the balanced per transition case.
157    /// N.B. This is renamed to "rates" in catlog-wasm for backwards compatibility.
158    #[cfg_attr(feature = "serde", serde(rename = "rates"))]
159    transition_rates: HashMap<QualifiedName, f32>,
160
161    /// Map from morphism IDs to consumption rate coefficients (nonnegative reals),
162    /// for the unbalanced per transition case.
163    #[cfg_attr(feature = "serde", serde(rename = "transitionConsumptionRates"))]
164    transition_consumption_rates: HashMap<QualifiedName, f32>,
165
166    /// Map from morphism IDs to production rate coefficients (nonnegative reals),
167    /// for the unbalanced per transition case.
168    #[cfg_attr(feature = "serde", serde(rename = "transitionProductionRates"))]
169    transition_production_rates: HashMap<QualifiedName, f32>,
170
171    /// Map from morphism IDs to (map from input objects to consumption rate coefficients),
172    /// for the unbalanced per place case (nonnegative reals).
173    #[cfg_attr(feature = "serde", serde(rename = "placeConsumptionRates"))]
174    place_consumption_rates: HashMap<QualifiedName, HashMap<QualifiedName, f32>>,
175
176    /// Map from morphism IDs to (map from output objects to production rate coefficients),
177    /// for the unbalanced per place case (nonnegative reals).
178    #[cfg_attr(feature = "serde", serde(rename = "placeProductionRates"))]
179    place_production_rates: HashMap<QualifiedName, HashMap<QualifiedName, f32>>,
180
181    /// Map from object IDs to initial values (nonnegative reals).
182    #[cfg_attr(feature = "serde", serde(rename = "initialValues"))]
183    pub initial_values: HashMap<QualifiedName, f32>,
184
185    /// Duration of simulation.
186    pub duration: f32,
187}
188
189/// Mass-action ODE analysis for Petri nets.
190///
191/// This struct implements the object part of the functorial semantics for reaction
192/// networks (aka, Petri nets) due to [Baez & Pollard](crate::refs::ReactionNets).
193pub struct PetriNetMassActionAnalysis {
194    /// Object type for places.
195    pub place_ob_type: ModalObType,
196    /// Morphism type for transitions.
197    pub transition_mor_type: ModalMorType,
198}
199
200impl Default for PetriNetMassActionAnalysis {
201    fn default() -> Self {
202        let ob_type = ModalObType::new(name("Object"));
203        Self {
204            place_ob_type: ob_type.clone(),
205            transition_mor_type: ModalMorType::Zero(ob_type),
206        }
207    }
208}
209
210impl PetriNetMassActionAnalysis {
211    /// Creates a mass-action system with symbolic rate coefficients.
212    pub fn build_system(
213        &self,
214        model: &ModalDblModel<Unital>,
215        mass_conservation_type: MassConservationType,
216    ) -> PolynomialSystem<QualifiedName, Parameter<FlowParameter>, i8> {
217        let mut sys = PolynomialSystem::new();
218        for ob in model.ob_generators_with_type(&self.place_ob_type) {
219            sys.add_term(ob, Polynomial::zero());
220        }
221        for mor in model.mor_generators_with_type(&self.transition_mor_type) {
222            let (inputs, outputs) = transition_interface(model, &mor);
223            let term: Monomial<_, _> =
224                inputs.iter().map(|ob| (ob.clone().unwrap_generator(), 1)).collect();
225
226            match mass_conservation_type {
227                MassConservationType::Balanced => {
228                    let term: Polynomial<_, _, _> = [(
229                        Parameter::generator(FlowParameter::Balanced { transition: mor }),
230                        term.clone(),
231                    )]
232                    .into_iter()
233                    .collect();
234
235                    for input in inputs {
236                        sys.add_term(input.unwrap_generator(), -term.clone());
237                    }
238
239                    for output in outputs {
240                        sys.add_term(output.unwrap_generator(), term.clone());
241                    }
242                }
243
244                MassConservationType::Unbalanced(granularity) => {
245                    for input in inputs {
246                        let input_term: Polynomial<_, _, _> = match granularity {
247                            RateGranularity::PerTransition => [(
248                                Parameter::generator(FlowParameter::Unbalanced {
249                                    direction: Direction::OutgoingFlow,
250                                    parameter: RateParameter::PerTransition {
251                                        transition: mor.clone(),
252                                    },
253                                }),
254                                term.clone(),
255                            )],
256                            RateGranularity::PerPlace => [(
257                                Parameter::generator(FlowParameter::Unbalanced {
258                                    direction: Direction::OutgoingFlow,
259                                    parameter: RateParameter::PerPlace {
260                                        transition: mor.clone(),
261                                        place: input.clone().unwrap_generator(),
262                                    },
263                                }),
264                                term.clone(),
265                            )],
266                        }
267                        .into_iter()
268                        .collect();
269
270                        sys.add_term(input.unwrap_generator(), -input_term.clone());
271                    }
272                    for output in outputs {
273                        let output_term: Polynomial<_, _, _> = match granularity {
274                            RateGranularity::PerTransition => [(
275                                Parameter::generator(FlowParameter::Unbalanced {
276                                    direction: Direction::IncomingFlow,
277                                    parameter: RateParameter::PerTransition {
278                                        transition: mor.clone(),
279                                    },
280                                }),
281                                term.clone(),
282                            )],
283                            RateGranularity::PerPlace => [(
284                                Parameter::generator(FlowParameter::Unbalanced {
285                                    direction: Direction::IncomingFlow,
286                                    parameter: RateParameter::PerPlace {
287                                        transition: mor.clone(),
288                                        place: output.clone().unwrap_generator(),
289                                    },
290                                }),
291                                term.clone(),
292                            )],
293                        }
294                        .into_iter()
295                        .collect();
296
297                        sys.add_term(output.unwrap_generator(), output_term.clone());
298                    }
299                }
300            }
301        }
302
303        sys.normalize()
304    }
305}
306
307/// Mass-action ODE analysis for stock-flow models.
308pub struct StockFlowMassActionAnalysis {
309    /// Object type for stocks.
310    pub stock_ob_type: TabObType,
311    /// Morphism type for flows between stocks.
312    pub flow_mor_type: TabMorType,
313    /// Morphism type for positive links from stocks to flows.
314    pub pos_link_mor_type: TabMorType,
315    /// Morphism type for negative links from stocks to flows.
316    pub neg_link_mor_type: TabMorType,
317}
318
319impl Default for StockFlowMassActionAnalysis {
320    fn default() -> Self {
321        let stock_ob_type = TabObType::Basic(name("Object"));
322        let flow_mor_type = TabMorType::Hom(Box::new(stock_ob_type.clone()));
323        Self {
324            stock_ob_type,
325            flow_mor_type,
326            pos_link_mor_type: TabMorType::Basic(name("Link")),
327            neg_link_mor_type: TabMorType::Basic(name("NegativeLink")),
328        }
329    }
330}
331
332impl StockFlowMassActionAnalysis {
333    /// Creates a mass-action system with symbolic rate coefficients.
334    pub fn build_system(
335        &self,
336        model: &DiscreteTabModel,
337        mass_conservation_type: MassConservationType,
338    ) -> PolynomialSystem<QualifiedName, Parameter<FlowParameter>, i8> {
339        let terms: Vec<_> = self.flow_monomials(model).into_iter().collect();
340
341        let mut sys = PolynomialSystem::new();
342        for ob in model.ob_generators_with_type(&self.stock_ob_type) {
343            sys.add_term(ob, Polynomial::zero());
344        }
345        for (flow, term) in terms {
346            let dom = model.mor_generator_dom(&flow).unwrap_basic();
347            let cod = model.mor_generator_cod(&flow).unwrap_basic();
348            match mass_conservation_type {
349                MassConservationType::Balanced => {
350                    let param = Parameter::generator(FlowParameter::Balanced { transition: flow });
351                    let term: Polynomial<_, _, _> = [(param, term.clone())].into_iter().collect();
352                    sys.add_term(dom, -term.clone());
353                    sys.add_term(cod, term);
354                }
355                MassConservationType::Unbalanced(_) => {
356                    let dom_param = Parameter::generator(FlowParameter::Unbalanced {
357                        direction: Direction::OutgoingFlow,
358                        parameter: RateParameter::PerTransition { transition: flow.clone() },
359                    });
360                    let cod_param = Parameter::generator(FlowParameter::Unbalanced {
361                        direction: Direction::IncomingFlow,
362                        parameter: RateParameter::PerTransition { transition: flow },
363                    });
364                    let dom_term: Polynomial<_, _, _> =
365                        [(dom_param, term.clone())].into_iter().collect();
366                    let cod_term: Polynomial<_, _, _> = [(cod_param, term)].into_iter().collect();
367                    sys.add_term(dom, -dom_term);
368                    sys.add_term(cod, cod_term);
369                }
370            }
371        }
372        sys
373    }
374
375    /// Constructs a monomial for each flow in the model.
376    pub(super) fn flow_monomials(
377        &self,
378        model: &DiscreteTabModel,
379    ) -> HashMap<QualifiedName, Monomial<QualifiedName, i8>> {
380        let mut terms: HashMap<_, _> = model
381            .mor_generators_with_type(&self.flow_mor_type)
382            .map(|flow| {
383                let dom = model.mor_generator_dom(&flow).unwrap_basic();
384                (flow, Monomial::generator(dom))
385            })
386            .collect();
387
388        let mut multiply_for_link = |link: QualifiedName, exponent: i8| {
389            let dom = model.mor_generator_dom(&link).unwrap_basic();
390            let path = model.mor_generator_cod(&link).unwrap_tabulated();
391            let Some(TabEdge::Basic(cod)) = path.only() else {
392                panic!("Codomain of link should be basic morphism");
393            };
394            if let Some(term) = terms.get_mut(&cod) {
395                let mon: Monomial<_, i8> = [(dom, exponent)].into_iter().collect();
396                *term = std::mem::take(term) * mon;
397            } else {
398                panic!("Codomain of link does not belong to model");
399            };
400        };
401
402        for link in model.mor_generators_with_type(&self.pos_link_mor_type) {
403            multiply_for_link(link, 1);
404        }
405        for link in model.mor_generators_with_type(&self.neg_link_mor_type) {
406            multiply_for_link(link, -1);
407        }
408
409        terms
410    }
411}
412
413/// Substitutes numerical rate coefficients into a symbolic mass-action system.
414pub fn extend_mass_action_scalars(
415    sys: PolynomialSystem<QualifiedName, Parameter<FlowParameter>, i8>,
416    data: &MassActionProblemData,
417) -> PolynomialSystem<QualifiedName, f32, i8> {
418    let sys = sys.extend_scalars(|poly| {
419        poly.eval(|flow| match flow {
420            FlowParameter::Balanced { transition } => {
421                data.transition_rates.get(transition).cloned().unwrap_or_default()
422            }
423            FlowParameter::Unbalanced { direction, parameter } => match (direction, parameter) {
424                (Direction::IncomingFlow, RateParameter::PerTransition { transition }) => {
425                    data.transition_production_rates.get(transition).cloned().unwrap_or_default()
426                }
427                (Direction::OutgoingFlow, RateParameter::PerTransition { transition }) => {
428                    data.transition_consumption_rates.get(transition).cloned().unwrap_or_default()
429                }
430                (Direction::IncomingFlow, RateParameter::PerPlace { transition, place }) => data
431                    .place_production_rates
432                    .get(transition)
433                    .and_then(|rate| rate.get(place))
434                    .copied()
435                    .unwrap_or_default(),
436                (Direction::OutgoingFlow, RateParameter::PerPlace { transition, place }) => data
437                    .place_consumption_rates
438                    .get(transition)
439                    .and_then(|rate| rate.get(place))
440                    .copied()
441                    .unwrap_or_default(),
442            },
443        })
444    });
445
446    sys.normalize()
447}
448
449/// Builds the numerical ODE analysis for a mass-action system whose scalars have been substituted.
450pub fn into_mass_action_analysis(
451    sys: PolynomialSystem<QualifiedName, f32, i8>,
452    data: MassActionProblemData,
453) -> ODEAnalysis<NumericalPolynomialSystem<i8>> {
454    let ob_index: IndexMap<_, _> =
455        sys.components.keys().cloned().enumerate().map(|(i, x)| (x, i)).collect();
456    let n = ob_index.len();
457
458    let initial_values = ob_index
459        .keys()
460        .map(|ob| data.initial_values.get(ob).copied().unwrap_or_default());
461    let x0 = DVector::from_iterator(n, initial_values);
462
463    let num_sys = sys.to_numerical();
464    let problem = ODEProblem::new(num_sys, x0).end_time(data.duration);
465
466    ODEAnalysis::new(problem, ob_index)
467}
468
#[cfg(test)]
mod tests {
    use expect_test::expect;
    use std::rc::Rc;

    use super::*;
    use crate::simulate::ode::LatexEquation;
    use crate::stdlib::{models::*, theories::*};

    // Tests for stock-flow diagrams. These all use the backward_link() model,
    // which has a single flow x==f==>y and a single link y->f.

    #[test]
    fn balanced_stock_flow() {
        let th = Rc::new(th_category_links());
        let model = backward_link(th);
        let sys = StockFlowMassActionAnalysis::default()
            .build_system(&model, MassConservationType::Balanced);
        let expected = expect!([r#"
            dx = (-f) x y
            dy = f x y
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    #[test]
    fn unbalanced_stock_flow() {
        let th = Rc::new(th_category_links());
        let model = backward_link(th);
        let sys = StockFlowMassActionAnalysis::default().build_system(
            &model,
            MassConservationType::Unbalanced(RateGranularity::PerTransition),
        );
        let expected = expect!([r#"
            dx = (-Outgoing(f)) x y
            dy = (Incoming(f)) x y
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    // Tests for signed stock-flow diagrams. These all use the negative_backward_link()
    // model, which has a single flow x==f==>y and a single negative link y->f.

    #[test]
    fn balanced_signed_stock_flow() {
        let th = Rc::new(th_category_signed_links());
        let model = negative_backward_link(th);
        let sys = StockFlowMassActionAnalysis::default()
            .build_system(&model, MassConservationType::Balanced);
        let expected = expect!([r#"
            dx = (-f) x y^{-1}
            dy = f x y^{-1}
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    #[test]
    fn unbalanced_signed_stock_flow() {
        let th = Rc::new(th_category_signed_links());
        let model = negative_backward_link(th);
        let sys = StockFlowMassActionAnalysis::default().build_system(
            &model,
            MassConservationType::Unbalanced(RateGranularity::PerTransition),
        );
        let expected = expect!([r#"
            dx = (-Outgoing(f)) x y^{-1}
            dy = (Incoming(f)) x y^{-1}
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    // Tests for Petri nets. These all use the catalyzed_reaction() model, which
    // has a single transition [x,c]-->f-->[y,c].

    #[test]
    fn balanced_petri() {
        let th = Rc::new(th_sym_monoidal_category());
        let model = catalyzed_reaction(th);
        let sys = PetriNetMassActionAnalysis::default()
            .build_system(&model, MassConservationType::Balanced);
        let expected = expect!([r#"
            dx = (-f) c x
            dy = f c x
            dc = 0
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    #[test]
    fn unbalanced_petri_per_transition() {
        let th = Rc::new(th_sym_monoidal_category());
        let model = catalyzed_reaction(th);
        let sys = PetriNetMassActionAnalysis::default().build_system(
            &model,
            MassConservationType::Unbalanced(RateGranularity::PerTransition),
        );
        let expected = expect!([r#"
            dx = (-Outgoing(f)) c x
            dy = (Incoming(f)) c x
            dc = (Incoming(f) + -Outgoing(f)) c x
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    #[test]
    fn unbalanced_petri_per_place() {
        let th = Rc::new(th_sym_monoidal_category());
        let model = catalyzed_reaction(th);
        let sys = PetriNetMassActionAnalysis::default().build_system(
            &model,
            MassConservationType::Unbalanced(RateGranularity::PerPlace),
        );
        let expected = expect!([r#"
            dx = (-(x->[f])) c x
            dy = (([f]->y)) c x
            dc = (([f]->c) + -(c->[f])) c x
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    // Test for LaTeX output, using the backward_link() stock-flow model.

    #[test]
    fn to_latex() {
        let th = Rc::new(th_category_links());
        let model = backward_link(th);
        let sys = StockFlowMassActionAnalysis::default().build_system(
            &model,
            MassConservationType::Unbalanced(RateGranularity::PerTransition),
        );
        let expected = vec![
            LatexEquation {
                lhs: "\\frac{\\mathrm{d}}{\\mathrm{d}t} x".to_string(),
                rhs: "(-Outgoing(f)) x y".to_string(),
            },
            LatexEquation {
                lhs: "\\frac{\\mathrm{d}}{\\mathrm{d}t} y".to_string(),
                rhs: "(Incoming(f)) x y".to_string(),
            },
        ];
        assert_eq!(expected, sys.to_latex_equations());
    }
}