1use std::{collections::HashMap, fmt};
9
10use indexmap::IndexMap;
11use nalgebra::DVector;
12use num_traits::Zero;
13
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Serialize};
16#[cfg(feature = "serde-wasm")]
17use tsify::Tsify;
18
19use super::{ODEAnalysis, Parameter};
20use crate::dbl::{
21 model::{DiscreteTabModel, FpDblModel, ModalDblModel, TabEdge},
22 theory::{ModalMorType, ModalObType, TabMorType, TabObType, Unital},
23};
24use crate::one::FgCategory;
25use crate::simulate::ode::{NumericalPolynomialSystem, ODEProblem, PolynomialSystem};
26use crate::stdlib::analyses::petri::transition_interface;
27use crate::zero::{QualifiedName, alg::Polynomial, name, rig::Monomial};
28
29#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
34#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
35#[cfg_attr(feature = "serde", serde(tag = "type", content = "granularity"))]
36#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
37#[cfg_attr(feature = "serde-wasm", tsify(into_wasm_abi, from_wasm_abi))]
38pub enum MassConservationType {
39 Balanced,
41 Unbalanced(RateGranularity),
43}
44
/// Granularity at which rate parameters are assigned when mass is not
/// conserved (see [`MassConservationType::Unbalanced`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
#[cfg_attr(feature = "serde-wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum RateGranularity {
    /// One rate parameter per transition and flow direction.
    PerTransition,

    /// One rate parameter per (transition, place) pair and flow direction.
    PerPlace,
}
59
60#[derive(PartialEq, Eq, PartialOrd, Ord, Clone)]
63pub enum FlowParameter {
64 Balanced {
66 transition: QualifiedName,
68 },
69 Unbalanced {
71 direction: Direction,
73 parameter: RateParameter,
75 },
76}
77
78#[derive(PartialEq, Eq, PartialOrd, Ord, Clone)]
80pub enum RateParameter {
81 PerTransition {
83 transition: QualifiedName,
85 },
86
87 PerPlace {
90 transition: QualifiedName,
92 place: QualifiedName,
94 },
95}
96
/// Direction of a flow relative to a place.
///
/// In the system builders of this module, outputs of a transition receive
/// [`IncomingFlow`](Direction::IncomingFlow) parameters and inputs receive
/// [`OutgoingFlow`](Direction::OutgoingFlow) parameters.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Direction {
    /// Flow into a place, i.e. production.
    IncomingFlow,

    /// Flow out of a place, i.e. consumption.
    OutgoingFlow,
}
108
109impl fmt::Display for FlowParameter {
110 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111 match &self {
112 FlowParameter::Balanced { transition: trans } => {
113 write!(f, "{}", trans)
114 }
115 FlowParameter::Unbalanced {
116 direction: Direction::IncomingFlow,
117 parameter: RateParameter::PerTransition { transition: trans },
118 } => {
119 write!(f, "Incoming({})", trans)
120 }
121 FlowParameter::Unbalanced {
122 direction: Direction::IncomingFlow,
123 parameter: RateParameter::PerPlace { transition: trans, place: output },
124 } => {
125 write!(f, "([{}]->{})", trans, output)
126 }
127 FlowParameter::Unbalanced {
128 direction: Direction::OutgoingFlow,
129 parameter: RateParameter::PerTransition { transition: trans },
130 } => {
131 write!(f, "Outgoing({})", trans)
132 }
133 FlowParameter::Unbalanced {
134 direction: Direction::OutgoingFlow,
135 parameter: RateParameter::PerPlace { transition: trans, place: input },
136 } => {
137 write!(f, "({}->[{}])", input, trans)
138 }
139 }
140 }
141}
142
143#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
145#[cfg_attr(feature = "serde-wasm", derive(Tsify))]
146#[cfg_attr(
147 feature = "serde-wasm",
148 tsify(into_wasm_abi, from_wasm_abi, hashmap_as_object)
149)]
150pub struct MassActionProblemData {
151 #[cfg_attr(feature = "serde", serde(rename = "massConservationType"))]
153 pub mass_conservation_type: MassConservationType,
154
155 #[cfg_attr(feature = "serde", serde(rename = "rates"))]
159 transition_rates: HashMap<QualifiedName, f32>,
160
161 #[cfg_attr(feature = "serde", serde(rename = "transitionConsumptionRates"))]
164 transition_consumption_rates: HashMap<QualifiedName, f32>,
165
166 #[cfg_attr(feature = "serde", serde(rename = "transitionProductionRates"))]
169 transition_production_rates: HashMap<QualifiedName, f32>,
170
171 #[cfg_attr(feature = "serde", serde(rename = "placeConsumptionRates"))]
174 place_consumption_rates: HashMap<QualifiedName, HashMap<QualifiedName, f32>>,
175
176 #[cfg_attr(feature = "serde", serde(rename = "placeProductionRates"))]
179 place_production_rates: HashMap<QualifiedName, HashMap<QualifiedName, f32>>,
180
181 #[cfg_attr(feature = "serde", serde(rename = "initialValues"))]
183 pub initial_values: HashMap<QualifiedName, f32>,
184
185 pub duration: f32,
187}
188
189pub struct PetriNetMassActionAnalysis {
194 pub place_ob_type: ModalObType,
196 pub transition_mor_type: ModalMorType,
198}
199
200impl Default for PetriNetMassActionAnalysis {
201 fn default() -> Self {
202 let ob_type = ModalObType::new(name("Object"));
203 Self {
204 place_ob_type: ob_type.clone(),
205 transition_mor_type: ModalMorType::Zero(ob_type),
206 }
207 }
208}
209
210impl PetriNetMassActionAnalysis {
211 pub fn build_system(
213 &self,
214 model: &ModalDblModel<Unital>,
215 mass_conservation_type: MassConservationType,
216 ) -> PolynomialSystem<QualifiedName, Parameter<FlowParameter>, i8> {
217 let mut sys = PolynomialSystem::new();
218 for ob in model.ob_generators_with_type(&self.place_ob_type) {
219 sys.add_term(ob, Polynomial::zero());
220 }
221 for mor in model.mor_generators_with_type(&self.transition_mor_type) {
222 let (inputs, outputs) = transition_interface(model, &mor);
223 let term: Monomial<_, _> =
224 inputs.iter().map(|ob| (ob.clone().unwrap_generator(), 1)).collect();
225
226 match mass_conservation_type {
227 MassConservationType::Balanced => {
228 let term: Polynomial<_, _, _> = [(
229 Parameter::generator(FlowParameter::Balanced { transition: mor }),
230 term.clone(),
231 )]
232 .into_iter()
233 .collect();
234
235 for input in inputs {
236 sys.add_term(input.unwrap_generator(), -term.clone());
237 }
238
239 for output in outputs {
240 sys.add_term(output.unwrap_generator(), term.clone());
241 }
242 }
243
244 MassConservationType::Unbalanced(granularity) => {
245 for input in inputs {
246 let input_term: Polynomial<_, _, _> = match granularity {
247 RateGranularity::PerTransition => [(
248 Parameter::generator(FlowParameter::Unbalanced {
249 direction: Direction::OutgoingFlow,
250 parameter: RateParameter::PerTransition {
251 transition: mor.clone(),
252 },
253 }),
254 term.clone(),
255 )],
256 RateGranularity::PerPlace => [(
257 Parameter::generator(FlowParameter::Unbalanced {
258 direction: Direction::OutgoingFlow,
259 parameter: RateParameter::PerPlace {
260 transition: mor.clone(),
261 place: input.clone().unwrap_generator(),
262 },
263 }),
264 term.clone(),
265 )],
266 }
267 .into_iter()
268 .collect();
269
270 sys.add_term(input.unwrap_generator(), -input_term.clone());
271 }
272 for output in outputs {
273 let output_term: Polynomial<_, _, _> = match granularity {
274 RateGranularity::PerTransition => [(
275 Parameter::generator(FlowParameter::Unbalanced {
276 direction: Direction::IncomingFlow,
277 parameter: RateParameter::PerTransition {
278 transition: mor.clone(),
279 },
280 }),
281 term.clone(),
282 )],
283 RateGranularity::PerPlace => [(
284 Parameter::generator(FlowParameter::Unbalanced {
285 direction: Direction::IncomingFlow,
286 parameter: RateParameter::PerPlace {
287 transition: mor.clone(),
288 place: output.clone().unwrap_generator(),
289 },
290 }),
291 term.clone(),
292 )],
293 }
294 .into_iter()
295 .collect();
296
297 sys.add_term(output.unwrap_generator(), output_term.clone());
298 }
299 }
300 }
301 }
302
303 sys.normalize()
304 }
305}
306
307pub struct StockFlowMassActionAnalysis {
309 pub stock_ob_type: TabObType,
311 pub flow_mor_type: TabMorType,
313 pub pos_link_mor_type: TabMorType,
315 pub neg_link_mor_type: TabMorType,
317}
318
319impl Default for StockFlowMassActionAnalysis {
320 fn default() -> Self {
321 let stock_ob_type = TabObType::Basic(name("Object"));
322 let flow_mor_type = TabMorType::Hom(Box::new(stock_ob_type.clone()));
323 Self {
324 stock_ob_type,
325 flow_mor_type,
326 pos_link_mor_type: TabMorType::Basic(name("Link")),
327 neg_link_mor_type: TabMorType::Basic(name("NegativeLink")),
328 }
329 }
330}
331
332impl StockFlowMassActionAnalysis {
333 pub fn build_system(
335 &self,
336 model: &DiscreteTabModel,
337 mass_conservation_type: MassConservationType,
338 ) -> PolynomialSystem<QualifiedName, Parameter<FlowParameter>, i8> {
339 let terms: Vec<_> = self.flow_monomials(model).into_iter().collect();
340
341 let mut sys = PolynomialSystem::new();
342 for ob in model.ob_generators_with_type(&self.stock_ob_type) {
343 sys.add_term(ob, Polynomial::zero());
344 }
345 for (flow, term) in terms {
346 let dom = model.mor_generator_dom(&flow).unwrap_basic();
347 let cod = model.mor_generator_cod(&flow).unwrap_basic();
348 match mass_conservation_type {
349 MassConservationType::Balanced => {
350 let param = Parameter::generator(FlowParameter::Balanced { transition: flow });
351 let term: Polynomial<_, _, _> = [(param, term.clone())].into_iter().collect();
352 sys.add_term(dom, -term.clone());
353 sys.add_term(cod, term);
354 }
355 MassConservationType::Unbalanced(_) => {
356 let dom_param = Parameter::generator(FlowParameter::Unbalanced {
357 direction: Direction::OutgoingFlow,
358 parameter: RateParameter::PerTransition { transition: flow.clone() },
359 });
360 let cod_param = Parameter::generator(FlowParameter::Unbalanced {
361 direction: Direction::IncomingFlow,
362 parameter: RateParameter::PerTransition { transition: flow },
363 });
364 let dom_term: Polynomial<_, _, _> =
365 [(dom_param, term.clone())].into_iter().collect();
366 let cod_term: Polynomial<_, _, _> = [(cod_param, term)].into_iter().collect();
367 sys.add_term(dom, -dom_term);
368 sys.add_term(cod, cod_term);
369 }
370 }
371 }
372 sys
373 }
374
375 pub(super) fn flow_monomials(
377 &self,
378 model: &DiscreteTabModel,
379 ) -> HashMap<QualifiedName, Monomial<QualifiedName, i8>> {
380 let mut terms: HashMap<_, _> = model
381 .mor_generators_with_type(&self.flow_mor_type)
382 .map(|flow| {
383 let dom = model.mor_generator_dom(&flow).unwrap_basic();
384 (flow, Monomial::generator(dom))
385 })
386 .collect();
387
388 let mut multiply_for_link = |link: QualifiedName, exponent: i8| {
389 let dom = model.mor_generator_dom(&link).unwrap_basic();
390 let path = model.mor_generator_cod(&link).unwrap_tabulated();
391 let Some(TabEdge::Basic(cod)) = path.only() else {
392 panic!("Codomain of link should be basic morphism");
393 };
394 if let Some(term) = terms.get_mut(&cod) {
395 let mon: Monomial<_, i8> = [(dom, exponent)].into_iter().collect();
396 *term = std::mem::take(term) * mon;
397 } else {
398 panic!("Codomain of link does not belong to model");
399 };
400 };
401
402 for link in model.mor_generators_with_type(&self.pos_link_mor_type) {
403 multiply_for_link(link, 1);
404 }
405 for link in model.mor_generators_with_type(&self.neg_link_mor_type) {
406 multiply_for_link(link, -1);
407 }
408
409 terms
410 }
411}
412
413pub fn extend_mass_action_scalars(
415 sys: PolynomialSystem<QualifiedName, Parameter<FlowParameter>, i8>,
416 data: &MassActionProblemData,
417) -> PolynomialSystem<QualifiedName, f32, i8> {
418 let sys = sys.extend_scalars(|poly| {
419 poly.eval(|flow| match flow {
420 FlowParameter::Balanced { transition } => {
421 data.transition_rates.get(transition).cloned().unwrap_or_default()
422 }
423 FlowParameter::Unbalanced { direction, parameter } => match (direction, parameter) {
424 (Direction::IncomingFlow, RateParameter::PerTransition { transition }) => {
425 data.transition_production_rates.get(transition).cloned().unwrap_or_default()
426 }
427 (Direction::OutgoingFlow, RateParameter::PerTransition { transition }) => {
428 data.transition_consumption_rates.get(transition).cloned().unwrap_or_default()
429 }
430 (Direction::IncomingFlow, RateParameter::PerPlace { transition, place }) => data
431 .place_production_rates
432 .get(transition)
433 .and_then(|rate| rate.get(place))
434 .copied()
435 .unwrap_or_default(),
436 (Direction::OutgoingFlow, RateParameter::PerPlace { transition, place }) => data
437 .place_consumption_rates
438 .get(transition)
439 .and_then(|rate| rate.get(place))
440 .copied()
441 .unwrap_or_default(),
442 },
443 })
444 });
445
446 sys.normalize()
447}
448
449pub fn into_mass_action_analysis(
451 sys: PolynomialSystem<QualifiedName, f32, i8>,
452 data: MassActionProblemData,
453) -> ODEAnalysis<NumericalPolynomialSystem<i8>> {
454 let ob_index: IndexMap<_, _> =
455 sys.components.keys().cloned().enumerate().map(|(i, x)| (x, i)).collect();
456 let n = ob_index.len();
457
458 let initial_values = ob_index
459 .keys()
460 .map(|ob| data.initial_values.get(ob).copied().unwrap_or_default());
461 let x0 = DVector::from_iterator(n, initial_values);
462
463 let num_sys = sys.to_numerical();
464 let problem = ODEProblem::new(num_sys, x0).end_time(data.duration);
465
466 ODEAnalysis::new(problem, ob_index)
467}
468
#[cfg(test)]
mod tests {
    use expect_test::expect;
    use std::rc::Rc;

    use super::*;
    use crate::simulate::ode::LatexEquation;
    use crate::stdlib::{analyses, models::*, theories::*};

    /// Balanced system of a stock-flow model with a link.
    #[test]
    fn balanced_stock_flow() {
        let model = backward_link(Rc::new(th_category_links()));
        let conservation = analyses::ode::MassConservationType::Balanced;
        let sys = StockFlowMassActionAnalysis::default().build_system(&model, conservation);
        expect!([r#"
            dx = (-f) x y
            dy = f x y
        "#])
        .assert_eq(&sys.to_string());
    }

    /// Unbalanced, per-transition system of a stock-flow model with a link.
    #[test]
    fn unbalanced_stock_flow() {
        let model = backward_link(Rc::new(th_category_links()));
        let conservation = analyses::ode::MassConservationType::Unbalanced(
            analyses::ode::RateGranularity::PerTransition,
        );
        let sys = StockFlowMassActionAnalysis::default().build_system(&model, conservation);
        expect!([r#"
            dx = (-Outgoing(f)) x y
            dy = (Incoming(f)) x y
        "#])
        .assert_eq(&sys.to_string());
    }

    /// Balanced system of a stock-flow model with a negative link.
    #[test]
    fn balanced_signed_stock_flow() {
        let model = negative_backward_link(Rc::new(th_category_signed_links()));
        let conservation = analyses::ode::MassConservationType::Balanced;
        let sys = StockFlowMassActionAnalysis::default().build_system(&model, conservation);
        expect!([r#"
            dx = (-f) x y^{-1}
            dy = f x y^{-1}
        "#])
        .assert_eq(&sys.to_string());
    }

    /// Unbalanced, per-transition system of a stock-flow model with a negative
    /// link.
    #[test]
    fn unbalanced_signed_stock_flow() {
        let model = negative_backward_link(Rc::new(th_category_signed_links()));
        let conservation = analyses::ode::MassConservationType::Unbalanced(
            analyses::ode::RateGranularity::PerTransition,
        );
        let sys = StockFlowMassActionAnalysis::default().build_system(&model, conservation);
        expect!([r#"
            dx = (-Outgoing(f)) x y^{-1}
            dy = (Incoming(f)) x y^{-1}
        "#])
        .assert_eq(&sys.to_string());
    }

    /// Balanced system of a Petri net with a catalyzed reaction.
    #[test]
    fn balanced_petri() {
        let model = catalyzed_reaction(Rc::new(th_sym_monoidal_category()));
        let conservation = analyses::ode::MassConservationType::Balanced;
        let sys = PetriNetMassActionAnalysis::default().build_system(&model, conservation);
        expect!([r#"
            dx = (-f) c x
            dy = f c x
            dc = 0
        "#])
        .assert_eq(&sys.to_string());
    }

    /// Unbalanced, per-transition system of a Petri net with a catalyzed
    /// reaction.
    #[test]
    fn unbalanced_petri_per_transition() {
        let model = catalyzed_reaction(Rc::new(th_sym_monoidal_category()));
        let conservation = analyses::ode::MassConservationType::Unbalanced(
            analyses::ode::RateGranularity::PerTransition,
        );
        let sys = PetriNetMassActionAnalysis::default().build_system(&model, conservation);
        expect!([r#"
            dx = (-Outgoing(f)) c x
            dy = (Incoming(f)) c x
            dc = (Incoming(f) + -Outgoing(f)) c x
        "#])
        .assert_eq(&sys.to_string());
    }

    /// Unbalanced, per-place system of a Petri net with a catalyzed reaction.
    #[test]
    fn unbalanced_petri_per_place() {
        let model = catalyzed_reaction(Rc::new(th_sym_monoidal_category()));
        let conservation = analyses::ode::MassConservationType::Unbalanced(
            analyses::ode::RateGranularity::PerPlace,
        );
        let sys = PetriNetMassActionAnalysis::default().build_system(&model, conservation);
        expect!([r#"
            dx = (-(x->[f])) c x
            dy = (([f]->y)) c x
            dc = (([f]->c) + -(c->[f])) c x
        "#])
        .assert_eq(&sys.to_string());
    }

    /// LaTeX rendering of an unbalanced stock-flow system.
    #[test]
    fn to_latex() {
        let model = backward_link(Rc::new(th_category_links()));
        let conservation = analyses::ode::MassConservationType::Unbalanced(
            analyses::ode::RateGranularity::PerTransition,
        );
        let sys = StockFlowMassActionAnalysis::default().build_system(&model, conservation);
        let expected = vec![
            LatexEquation {
                lhs: "\\frac{\\mathrm{d}}{\\mathrm{d}t} x".to_string(),
                rhs: "(-Outgoing(f)) x y".to_string(),
            },
            LatexEquation {
                lhs: "\\frac{\\mathrm{d}}{\\mathrm{d}t} y".to_string(),
                rhs: "(Incoming(f)) x y".to_string(),
            },
        ];
        assert_eq!(expected, sys.to_latex_equations());
    }
}