1use std::{collections::HashMap, fmt};
9
10use indexmap::IndexMap;
11use nalgebra::DVector;
12use num_traits::Zero;
13
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Serialize};
16#[cfg(feature = "serde-wasm")]
17use tsify::Tsify;
18
19use super::ODEAnalysis;
20use crate::dbl::{
21 model::{DiscreteTabModel, FgDblModel, ModalDblModel, TabEdge},
22 theory::{ModalMorType, ModalObType, TabMorType, TabObType},
23};
24use crate::one::FgCategory;
25use crate::simulate::ode::{NumericalPolynomialSystem, ODEProblem, PolynomialSystem};
26use crate::stdlib::analyses::petri::transition_interface;
27use crate::zero::{QualifiedName, alg::Polynomial, name, rig::Monomial};
28
/// Selects how rate parameters are assigned when building a mass-action system.
///
/// With the `serde` feature the enum is adjacently tagged: the variant name is
/// stored under the key `type` and its payload under the key `granularity`.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(tag = "type", content = "granularity"))]
#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
#[cfg_attr(feature = "serde-wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum MassConservationType {
    /// One rate parameter per transition, used with opposite signs for the
    /// consumed inputs and the produced outputs.
    Balanced,

    /// Separate rate parameters for consumption (outgoing flow) and production
    /// (incoming flow), at the given granularity.
    Unbalanced(RateGranularity),
}
44
/// How finely rate parameters are assigned in the unbalanced case.
///
/// Derives `Debug` so that values can be printed in logs, assertion failures
/// and any containing type that wants to derive `Debug` itself.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
#[cfg_attr(feature = "serde-wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum RateGranularity {
    /// One consumption rate and one production rate per transition.
    PerTransition,

    /// One rate per (transition, place) pair, separately for consumption and
    /// production.
    PerPlace,
}
59
/// Symbolic rate parameter appearing as a coefficient in a mass-action system.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum FlowParameter {
    /// Single rate shared by the consumption and production sides of a
    /// transition.
    Balanced {
        /// Transition (or flow) that the rate belongs to.
        transition: QualifiedName,
    },
    /// Rate that applies to only one side of a transition.
    Unbalanced {
        /// Whether the rate governs flow into or out of a place.
        direction: Direction,
        /// Which transition, and optionally which place, the rate is attached to.
        parameter: RateParameter,
    },
}
77
/// Identifies what an unbalanced rate parameter is attached to.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum RateParameter {
    /// Rate attached to a whole transition.
    PerTransition {
        /// The transition carrying the rate.
        transition: QualifiedName,
    },

    /// Rate attached to one place of a transition.
    PerPlace {
        /// The transition carrying the rate.
        transition: QualifiedName,
        /// The input or output place of the transition that the rate applies to.
        place: QualifiedName,
    },
}
96
/// Direction of a flow relative to a place.
///
/// Derives `Debug` so that values can be printed in logs, assertion failures
/// and any containing type that wants to derive `Debug` itself.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum Direction {
    /// Flow into a place, i.e. production at an output of a transition.
    IncomingFlow,

    /// Flow out of a place, i.e. consumption at an input of a transition.
    OutgoingFlow,
}
108
109impl fmt::Display for FlowParameter {
110 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111 match &self {
112 FlowParameter::Balanced { transition: trans } => {
113 write!(f, "{}", trans)
114 }
115 FlowParameter::Unbalanced {
116 direction: Direction::IncomingFlow,
117 parameter: RateParameter::PerTransition { transition: trans },
118 } => {
119 write!(f, "Incoming({})", trans)
120 }
121 FlowParameter::Unbalanced {
122 direction: Direction::IncomingFlow,
123 parameter: RateParameter::PerPlace { transition: trans, place: output },
124 } => {
125 write!(f, "([{}]->{})", trans, output)
126 }
127 FlowParameter::Unbalanced {
128 direction: Direction::OutgoingFlow,
129 parameter: RateParameter::PerTransition { transition: trans },
130 } => {
131 write!(f, "Outgoing({})", trans)
132 }
133 FlowParameter::Unbalanced {
134 direction: Direction::OutgoingFlow,
135 parameter: RateParameter::PerPlace { transition: trans, place: input },
136 } => {
137 write!(f, "({}->[{}])", input, trans)
138 }
139 }
140 }
141}
142
/// Numerical data needed to turn a symbolic mass-action system into an ODE
/// problem: rate values, initial values and the simulated duration.
///
/// Any rate or initial value that is missing from its map is treated as the
/// default `f32` value (zero) when the data is consumed.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
#[cfg_attr(
    feature = "serde-wasm",
    tsify(into_wasm_abi, from_wasm_abi, hashmap_as_object)
)]
pub struct MassActionProblemData {
    /// Which mass-conservation scheme the rates below are meant for.
    #[cfg_attr(feature = "serde", serde(rename = "massConservationType"))]
    pub mass_conservation_type: MassConservationType,

    /// Rates keyed by transition, looked up for balanced parameters.
    #[cfg_attr(feature = "serde", serde(rename = "rates"))]
    transition_rates: HashMap<QualifiedName, f32>,

    /// Consumption rates keyed by transition, looked up for unbalanced
    /// per-transition parameters with outgoing direction.
    #[cfg_attr(feature = "serde", serde(rename = "transitionConsumptionRates"))]
    transition_consumption_rates: HashMap<QualifiedName, f32>,

    /// Production rates keyed by transition, looked up for unbalanced
    /// per-transition parameters with incoming direction.
    #[cfg_attr(feature = "serde", serde(rename = "transitionProductionRates"))]
    transition_production_rates: HashMap<QualifiedName, f32>,

    /// Consumption rates keyed first by transition and then by place, looked up
    /// for unbalanced per-place parameters with outgoing direction.
    #[cfg_attr(feature = "serde", serde(rename = "placeConsumptionRates"))]
    place_consumption_rates: HashMap<QualifiedName, HashMap<QualifiedName, f32>>,

    /// Production rates keyed first by transition and then by place, looked up
    /// for unbalanced per-place parameters with incoming direction.
    #[cfg_attr(feature = "serde", serde(rename = "placeProductionRates"))]
    place_production_rates: HashMap<QualifiedName, HashMap<QualifiedName, f32>>,

    /// Initial value of each state variable, keyed by object name.
    #[cfg_attr(feature = "serde", serde(rename = "initialValues"))]
    pub initial_values: HashMap<QualifiedName, f32>,

    /// End time handed to the ODE problem.
    pub duration: f32,
}
188
/// Polynomial over identifiers `Id` with `f32` and `i8` as its remaining type
/// parameters; used as the coefficient type of the symbolic systems built in
/// this module.
type Parameter<Id> = Polynomial<Id, f32, i8>;
191
/// Mass-action analysis for Petri nets given as modal double models.
///
/// The two fields select which generators of a model are treated as places
/// and which as transitions.
pub struct PetriNetMassActionAnalysis {
    /// Object type whose generators become the state variables (places).
    pub place_ob_type: ModalObType,
    /// Morphism type whose generators are treated as transitions.
    pub transition_mor_type: ModalMorType,
}
202
203impl Default for PetriNetMassActionAnalysis {
204 fn default() -> Self {
205 let ob_type = ModalObType::new(name("Object"));
206 Self {
207 place_ob_type: ob_type.clone(),
208 transition_mor_type: ModalMorType::Zero(ob_type),
209 }
210 }
211}
212
213impl PetriNetMassActionAnalysis {
214 pub fn build_system(
216 &self,
217 model: &ModalDblModel,
218 mass_conservation_type: MassConservationType,
219 ) -> PolynomialSystem<QualifiedName, Parameter<FlowParameter>, i8> {
220 let mut sys = PolynomialSystem::new();
221 for ob in model.ob_generators_with_type(&self.place_ob_type) {
222 sys.add_term(ob, Polynomial::zero());
223 }
224 for mor in model.mor_generators_with_type(&self.transition_mor_type) {
225 let (inputs, outputs) = transition_interface(model, &mor);
226 let term: Monomial<_, _> =
227 inputs.iter().map(|ob| (ob.clone().unwrap_generator(), 1)).collect();
228
229 match mass_conservation_type {
230 MassConservationType::Balanced => {
231 let term: Polynomial<_, _, _> = [(
232 Parameter::generator(FlowParameter::Balanced { transition: mor }),
233 term.clone(),
234 )]
235 .into_iter()
236 .collect();
237
238 for input in inputs {
239 sys.add_term(input.unwrap_generator(), -term.clone());
240 }
241
242 for output in outputs {
243 sys.add_term(output.unwrap_generator(), term.clone());
244 }
245 }
246
247 MassConservationType::Unbalanced(granularity) => {
248 for input in inputs {
249 let input_term: Polynomial<_, _, _> = match granularity {
250 RateGranularity::PerTransition => [(
251 Parameter::generator(FlowParameter::Unbalanced {
252 direction: Direction::OutgoingFlow,
253 parameter: RateParameter::PerTransition {
254 transition: mor.clone(),
255 },
256 }),
257 term.clone(),
258 )],
259 RateGranularity::PerPlace => [(
260 Parameter::generator(FlowParameter::Unbalanced {
261 direction: Direction::OutgoingFlow,
262 parameter: RateParameter::PerPlace {
263 transition: mor.clone(),
264 place: input.clone().unwrap_generator(),
265 },
266 }),
267 term.clone(),
268 )],
269 }
270 .into_iter()
271 .collect();
272
273 sys.add_term(input.unwrap_generator(), -input_term.clone());
274 }
275 for output in outputs {
276 let output_term: Polynomial<_, _, _> = match granularity {
277 RateGranularity::PerTransition => [(
278 Parameter::generator(FlowParameter::Unbalanced {
279 direction: Direction::IncomingFlow,
280 parameter: RateParameter::PerTransition {
281 transition: mor.clone(),
282 },
283 }),
284 term.clone(),
285 )],
286 RateGranularity::PerPlace => [(
287 Parameter::generator(FlowParameter::Unbalanced {
288 direction: Direction::IncomingFlow,
289 parameter: RateParameter::PerPlace {
290 transition: mor.clone(),
291 place: output.clone().unwrap_generator(),
292 },
293 }),
294 term.clone(),
295 )],
296 }
297 .into_iter()
298 .collect();
299
300 sys.add_term(output.unwrap_generator(), output_term.clone());
301 }
302 }
303 }
304 }
305
306 sys.normalize()
307 }
308}
309
/// Mass-action analysis for stock-flow models given as discrete tabulator
/// models.
///
/// The fields select which generators of a model are treated as stocks, flows
/// and positive or negative links.
pub struct StockFlowMassActionAnalysis {
    /// Object type whose generators become the state variables (stocks).
    pub stock_ob_type: TabObType,
    /// Morphism type whose generators are treated as flows between stocks.
    pub flow_mor_type: TabMorType,
    /// Morphism type of links that multiply a flow's rate by the linked stock.
    pub pos_link_mor_type: TabMorType,
    /// Morphism type of links that divide a flow's rate by the linked stock
    /// (the stock enters the monomial with exponent -1).
    pub neg_link_mor_type: TabMorType,
}
321
322impl Default for StockFlowMassActionAnalysis {
323 fn default() -> Self {
324 let stock_ob_type = TabObType::Basic(name("Object"));
325 let flow_mor_type = TabMorType::Hom(Box::new(stock_ob_type.clone()));
326 Self {
327 stock_ob_type,
328 flow_mor_type,
329 pos_link_mor_type: TabMorType::Basic(name("Link")),
330 neg_link_mor_type: TabMorType::Basic(name("NegativeLink")),
331 }
332 }
333}
334
335impl StockFlowMassActionAnalysis {
336 pub fn build_system(
338 &self,
339 model: &DiscreteTabModel,
340 mass_conservation_type: MassConservationType,
341 ) -> PolynomialSystem<QualifiedName, Parameter<FlowParameter>, i8> {
342 let terms: Vec<_> = self.flow_monomials(model).into_iter().collect();
343
344 let mut sys = PolynomialSystem::new();
345 for ob in model.ob_generators_with_type(&self.stock_ob_type) {
346 sys.add_term(ob, Polynomial::zero());
347 }
348 for (flow, term) in terms {
349 let dom = model.mor_generator_dom(&flow).unwrap_basic();
350 let cod = model.mor_generator_cod(&flow).unwrap_basic();
351 match mass_conservation_type {
352 MassConservationType::Balanced => {
353 let param = Parameter::generator(FlowParameter::Balanced { transition: flow });
354 let term: Polynomial<_, _, _> = [(param, term.clone())].into_iter().collect();
355 sys.add_term(dom, -term.clone());
356 sys.add_term(cod, term);
357 }
358 MassConservationType::Unbalanced(_) => {
359 let dom_param = Parameter::generator(FlowParameter::Unbalanced {
360 direction: Direction::OutgoingFlow,
361 parameter: RateParameter::PerTransition { transition: flow.clone() },
362 });
363 let cod_param = Parameter::generator(FlowParameter::Unbalanced {
364 direction: Direction::IncomingFlow,
365 parameter: RateParameter::PerTransition { transition: flow },
366 });
367 let dom_term: Polynomial<_, _, _> =
368 [(dom_param, term.clone())].into_iter().collect();
369 let cod_term: Polynomial<_, _, _> = [(cod_param, term)].into_iter().collect();
370 sys.add_term(dom, -dom_term);
371 sys.add_term(cod, cod_term);
372 }
373 }
374 }
375 sys
376 }
377
378 pub(super) fn flow_monomials(
380 &self,
381 model: &DiscreteTabModel,
382 ) -> HashMap<QualifiedName, Monomial<QualifiedName, i8>> {
383 let mut terms: HashMap<_, _> = model
384 .mor_generators_with_type(&self.flow_mor_type)
385 .map(|flow| {
386 let dom = model.mor_generator_dom(&flow).unwrap_basic();
387 (flow, Monomial::generator(dom))
388 })
389 .collect();
390
391 let mut multiply_for_link = |link: QualifiedName, exponent: i8| {
392 let dom = model.mor_generator_dom(&link).unwrap_basic();
393 let path = model.mor_generator_cod(&link).unwrap_tabulated();
394 let Some(TabEdge::Basic(cod)) = path.only() else {
395 panic!("Codomain of link should be basic morphism");
396 };
397 if let Some(term) = terms.get_mut(&cod) {
398 let mon: Monomial<_, i8> = [(dom, exponent)].into_iter().collect();
399 *term = std::mem::take(term) * mon;
400 } else {
401 panic!("Codomain of link does not belong to model");
402 };
403 };
404
405 for link in model.mor_generators_with_type(&self.pos_link_mor_type) {
406 multiply_for_link(link, 1);
407 }
408 for link in model.mor_generators_with_type(&self.neg_link_mor_type) {
409 multiply_for_link(link, -1);
410 }
411
412 terms
413 }
414}
415
416pub fn extend_mass_action_scalars(
418 sys: PolynomialSystem<QualifiedName, Parameter<FlowParameter>, i8>,
419 data: &MassActionProblemData,
420) -> PolynomialSystem<QualifiedName, f32, i8> {
421 let sys = sys.extend_scalars(|poly| {
422 poly.eval(|flow| match flow {
423 FlowParameter::Balanced { transition } => {
424 data.transition_rates.get(transition).cloned().unwrap_or_default()
425 }
426 FlowParameter::Unbalanced { direction, parameter } => match (direction, parameter) {
427 (Direction::IncomingFlow, RateParameter::PerTransition { transition }) => {
428 data.transition_production_rates.get(transition).cloned().unwrap_or_default()
429 }
430 (Direction::OutgoingFlow, RateParameter::PerTransition { transition }) => {
431 data.transition_consumption_rates.get(transition).cloned().unwrap_or_default()
432 }
433 (Direction::IncomingFlow, RateParameter::PerPlace { transition, place }) => data
434 .place_production_rates
435 .get(transition)
436 .and_then(|rate| rate.get(place))
437 .copied()
438 .unwrap_or_default(),
439 (Direction::OutgoingFlow, RateParameter::PerPlace { transition, place }) => data
440 .place_consumption_rates
441 .get(transition)
442 .and_then(|rate| rate.get(place))
443 .copied()
444 .unwrap_or_default(),
445 },
446 })
447 });
448
449 sys.normalize()
450}
451
452pub fn into_mass_action_analysis(
454 sys: PolynomialSystem<QualifiedName, f32, i8>,
455 data: MassActionProblemData,
456) -> ODEAnalysis<NumericalPolynomialSystem<i8>> {
457 let ob_index: IndexMap<_, _> =
458 sys.components.keys().cloned().enumerate().map(|(i, x)| (x, i)).collect();
459 let n = ob_index.len();
460
461 let initial_values = ob_index
462 .keys()
463 .map(|ob| data.initial_values.get(ob).copied().unwrap_or_default());
464 let x0 = DVector::from_iterator(n, initial_values);
465
466 let num_sys = sys.to_numerical();
467 let problem = ODEProblem::new(num_sys, x0).end_time(data.duration);
468
469 ODEAnalysis::new(problem, ob_index)
470}
471
// Snapshot tests: each test builds the symbolic system for a small standard
// model and compares its printed form against the expected equations.
#[cfg(test)]
mod tests {
    use expect_test::expect;
    use std::rc::Rc;

    use super::*;
    use crate::simulate::ode::LatexEquation;
    use crate::stdlib::{analyses, models::*, theories::*};

    // Stock-flow model with a link: balanced rates share the parameter `f`.
    #[test]
    fn balanced_stock_flow() {
        let th = Rc::new(th_category_links());
        let model = backward_link(th);
        let sys = StockFlowMassActionAnalysis::default()
            .build_system(&model, analyses::ode::MassConservationType::Balanced);
        let expected = expect!([r#"
            dx = (-f) x y
            dy = f x y
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    // Same model, but with separate outgoing and incoming parameters.
    #[test]
    fn unbalanced_stock_flow() {
        let th = Rc::new(th_category_links());
        let model = backward_link(th);
        let sys = StockFlowMassActionAnalysis::default().build_system(
            &model,
            analyses::ode::MassConservationType::Unbalanced(
                analyses::ode::RateGranularity::PerTransition,
            ),
        );
        let expected = expect!([r#"
            dx = (-Outgoing(f)) x y
            dy = (Incoming(f)) x y
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    // Negative link: the linked stock enters the monomial with exponent -1.
    #[test]
    fn balanced_signed_stock_flow() {
        let th = Rc::new(th_category_signed_links());
        let model = negative_backward_link(th);
        let sys = StockFlowMassActionAnalysis::default()
            .build_system(&model, analyses::ode::MassConservationType::Balanced);
        let expected = expect!([r#"
            dx = (-f) x y^{-1}
            dy = f x y^{-1}
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    // Negative link combined with unbalanced per-transition parameters.
    #[test]
    fn unbalanced_signed_stock_flow() {
        let th = Rc::new(th_category_signed_links());
        let model = negative_backward_link(th);
        let sys = StockFlowMassActionAnalysis::default().build_system(
            &model,
            analyses::ode::MassConservationType::Unbalanced(
                analyses::ode::RateGranularity::PerTransition,
            ),
        );
        let expected = expect!([r#"
            dx = (-Outgoing(f)) x y^{-1}
            dy = (Incoming(f)) x y^{-1}
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    // Petri net where `c` is both consumed and produced: with a balanced rate
    // its contributions cancel, leaving `dc = 0`.
    #[test]
    fn balanced_petri() {
        let th = Rc::new(th_sym_monoidal_category());
        let model = catalyzed_reaction(th);
        let sys = PetriNetMassActionAnalysis::default()
            .build_system(&model, analyses::ode::MassConservationType::Balanced);
        let expected = expect!([r#"
            dx = (-f) c x
            dy = f c x
            dc = 0
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    // With unbalanced per-transition rates the two contributions to `c` use
    // different parameters and no longer cancel.
    #[test]
    fn unbalanced_petri_per_transition() {
        let th = Rc::new(th_sym_monoidal_category());
        let model = catalyzed_reaction(th);
        let sys = PetriNetMassActionAnalysis::default().build_system(
            &model,
            analyses::ode::MassConservationType::Unbalanced(
                analyses::ode::RateGranularity::PerTransition,
            ),
        );
        let expected = expect!([r#"
            dx = (-Outgoing(f)) c x
            dy = (Incoming(f)) c x
            dc = (Incoming(f) + -Outgoing(f)) c x
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    // Per-place granularity: every (transition, place) pair has its own parameter.
    #[test]
    fn unbalanced_petri_per_place() {
        let th = Rc::new(th_sym_monoidal_category());
        let model = catalyzed_reaction(th);
        let sys = PetriNetMassActionAnalysis::default().build_system(
            &model,
            analyses::ode::MassConservationType::Unbalanced(
                analyses::ode::RateGranularity::PerPlace,
            ),
        );
        let expected = expect!([r#"
            dx = (-(x->[f])) c x
            dy = (([f]->y)) c x
            dc = (([f]->c) + -(c->[f])) c x
        "#]);
        expected.assert_eq(&sys.to_string());
    }

    // The symbolic system can also be rendered as LaTeX equations.
    #[test]
    fn to_latex() {
        let th = Rc::new(th_category_links());
        let model = backward_link(th);
        let sys = StockFlowMassActionAnalysis::default().build_system(
            &model,
            analyses::ode::MassConservationType::Unbalanced(
                analyses::ode::RateGranularity::PerTransition,
            ),
        );
        let expected = vec![
            LatexEquation {
                lhs: "\\frac{\\mathrm{d}}{\\mathrm{d}t} x".to_string(),
                rhs: "(-Outgoing(f)) x y".to_string(),
            },
            LatexEquation {
                lhs: "\\frac{\\mathrm{d}}{\\mathrm{d}t} y".to_string(),
                rhs: "(Incoming(f)) x y".to_string(),
            },
        ];
        assert_eq!(expected, sys.to_latex_equations());
    }
}