1use std::cell::RefCell;
18use std::collections::HashMap;
19use std::fmt::Debug;
20use std::hash::{BuildHasher, BuildHasherDefault, Hash, RandomState};
21
22use derivative::Derivative;
23use egglog::{EGraph, ast::*, span};
24use nonempty::NonEmpty;
25use ref_cast::RefCast;
26use thiserror::Error;
27use ustr::{IdentityHasher, Ustr};
28
29use super::{category::*, graph::*, path::*};
30use crate::egglog_util::{CommandRewrite, CommandRule, Program};
31use crate::validate::{self, Validate};
32
/// A category presented by generators and path equations.
///
/// Object generators are the vertices and morphism generators the edges of
/// `generators`; morphisms are paths in that graph, identified modulo
/// `equations`. Every mutation is also mirrored into an egglog program (via
/// `builder`), which is run lazily in `egraph` when equality of morphisms is
/// queried.
///
/// Equality and `Debug` consider only the generators and equations; the
/// egglog builder and e-graph are ignored.
#[derive(Clone, Derivative)]
#[derivative(Debug(bound = "V: Debug, E: Debug, S: Debug"))]
#[derivative(Default(bound = "S: Default", new = "true"))]
#[derivative(PartialEq(bound = "V: Eq + Hash, E: Eq + Hash, S: BuildHasher"))]
#[derivative(Eq(bound = "V: Eq + Hash, E: Eq + Hash, S: BuildHasher"))]
pub struct FpCategory<V, E, S = RandomState> {
    /// Graph whose vertices are object generators and edges are morphism generators.
    generators: HashGraph<V, E, S>,
    /// Path equations, in the order they were added.
    equations: Vec<PathEq<V, E>>,
    /// Accumulates egglog commands mirroring the presentation; commands are
    /// handed off to `egraph` when morphisms are compared.
    #[derivative(Debug = "ignore", PartialEq = "ignore")]
    builder: RefCell<CategoryProgramBuilder<V, E, S>>,
    /// E-graph that persists across equality queries; only commands queued
    /// since the previous query are run in it.
    #[derivative(Debug = "ignore", PartialEq = "ignore")]
    egraph: RefCell<EGraph>,
}
63
/// An [`FpCategory`] whose object and morphism generators are interned `Ustr`
/// strings, hashed with `ustr`'s `IdentityHasher`.
pub type UstrFpCategory = FpCategory<Ustr, Ustr, BuildHasherDefault<IdentityHasher>>;
66
67impl<V, E, S> FpCategory<V, E, S>
68where
69 V: Eq + Clone + Hash,
70 E: Eq + Clone + Hash,
71 S: BuildHasher,
72{
73 pub fn generators(&self) -> &(impl FinGraph<V = V, E = E> + use<V, E, S>) {
75 &self.generators
76 }
77
78 pub fn equations(&self) -> impl Iterator<Item = &PathEq<V, E>> {
80 self.equations.iter()
81 }
82
83 pub fn is_free(&self) -> bool {
85 self.equations.is_empty()
86 }
87
88 pub fn add_ob_generator(&mut self, v: V) {
90 assert!(self.generators.add_vertex(v.clone()), "Object generator already exists");
91 self.builder.get_mut().add_ob_generator(v);
92 }
93
94 pub fn add_ob_generators(&mut self, iter: impl IntoIterator<Item = V>) {
96 for v in iter {
97 self.add_ob_generator(v)
98 }
99 }
100
101 pub fn add_mor_generator(&mut self, e: E, dom: V, cod: V) {
103 assert!(
104 self.generators.add_edge(e.clone(), dom.clone(), cod.clone()),
105 "Morphism generator already exists"
106 );
107 let (dom, cod) = (self.ob_generator_expr(dom), self.ob_generator_expr(cod));
108 self.builder.get_mut().add_mor_generator(e, dom, cod);
109 }
110
111 pub fn make_mor_generator(&mut self, e: E) {
113 assert!(self.generators.make_edge(e.clone()), "Morphism generator already exists");
114 self.builder.get_mut().make_mor_generator(e);
115 }
116
117 pub fn get_dom(&self, e: &E) -> Option<&V> {
119 self.generators.get_src(e)
120 }
121
122 pub fn get_cod(&self, e: &E) -> Option<&V> {
124 self.generators.get_tgt(e)
125 }
126
127 pub fn set_dom(&mut self, e: E, v: V) {
129 assert!(
130 self.generators.set_src(e.clone(), v.clone()).is_none(),
131 "Domain of morphism generator should not already be set"
132 );
133 let (mor, ob) = (self.mor_generator_expr(e), self.ob_generator_expr(v));
134 self.builder.get_mut().set_dom(mor, ob);
135 }
136
137 pub fn set_cod(&mut self, e: E, v: V) {
139 assert!(
140 self.generators.set_tgt(e.clone(), v.clone()).is_none(),
141 "Codomain of morphism generator should not already be set"
142 );
143 let (mor, ob) = (self.mor_generator_expr(e), self.ob_generator_expr(v));
144 self.builder.get_mut().set_cod(mor, ob);
145 }
146
147 pub fn add_equation(&mut self, eq: PathEq<V, E>) {
149 self.equations.push(eq.clone());
150 let (lhs, rhs) = (self.path_expr(eq.lhs), self.path_expr(eq.rhs));
151 self.builder.get_mut().equate(lhs, rhs);
152 }
153
154 pub fn equate(&mut self, lhs: Path<V, E>, rhs: Path<V, E>) {
156 self.add_equation(PathEq::new(lhs, rhs));
157 }
158
159 fn ob_generator_expr(&self, v: V) -> Expr {
160 self.builder.borrow_mut().ob_generator(v)
161 }
162 fn mor_generator_expr(&self, e: E) -> Expr {
163 self.builder.borrow_mut().mor_generator(e)
164 }
165 fn path_expr(&self, path: Path<V, E>) -> Expr {
166 path.map_reduce(
167 |v| {
168 let ob = self.ob_generator_expr(v);
169 self.builder.borrow().id(ob)
170 },
171 |e| self.mor_generator_expr(e),
172 |f, g| self.builder.borrow().compose2(f, g),
173 )
174 }
175
176 pub fn iter_invalid(&self) -> impl Iterator<Item = InvalidFpCategory<E>> + '_ {
178 let generator_errors = self.generators.iter_invalid().map(|err| match err {
179 InvalidGraphData::Src(e) => InvalidFpCategory::Dom(e),
180 InvalidGraphData::Tgt(e) => InvalidFpCategory::Cod(e),
181 });
182 let equation_errors = self.equations.iter().enumerate().flat_map(|(i, eq)| {
183 eq.iter_invalid_in(&self.generators).map(move |err| match err {
184 InvalidPathEq::Lhs() => InvalidFpCategory::EqLhs(i),
185 InvalidPathEq::Rhs() => InvalidFpCategory::EqRhs(i),
186 InvalidPathEq::Src() => InvalidFpCategory::EqSrc(i),
187 InvalidPathEq::Tgt() => InvalidFpCategory::EqTgt(i),
188 })
189 });
190 generator_errors.chain(equation_errors)
191 }
192}
193
194impl<V, E, S> Category for FpCategory<V, E, S>
195where
196 V: Eq + Clone + Hash,
197 E: Eq + Clone + Hash,
198 S: BuildHasher,
199{
200 type Ob = V;
201 type Mor = Path<V, E>;
202
203 fn has_ob(&self, x: &Self::Ob) -> bool {
204 self.generators.has_vertex(x)
205 }
206 fn has_mor(&self, path: &Self::Mor) -> bool {
207 path.contained_in(&self.generators)
208 }
209 fn dom(&self, path: &Self::Mor) -> Self::Ob {
210 path.src(&self.generators)
211 }
212 fn cod(&self, path: &Self::Mor) -> Self::Ob {
213 path.tgt(&self.generators)
214 }
215
216 fn compose(&self, path: Path<Self::Ob, Self::Mor>) -> Self::Mor {
217 path.flatten_in(&self.generators).expect("Paths should be composable")
218 }
219 fn compose2(&self, path1: Self::Mor, path2: Self::Mor) -> Self::Mor {
220 path1
221 .concat_in(&self.generators, path2)
222 .expect("Target of first path should equal source of second path")
223 }
224
225 fn morphisms_are_equal(&self, path1: Self::Mor, path2: Self::Mor) -> bool {
226 let (lhs, rhs) = (self.path_expr(path1), self.path_expr(path2));
227 self.builder.borrow_mut().check_equal(lhs, rhs);
228 self.builder
229 .borrow_mut()
230 .program()
231 .check_in(&mut self.egraph.borrow_mut())
232 .expect("Unexpected egglog error")
233 }
234}
235
236impl<V, E, S> FgCategory for FpCategory<V, E, S>
237where
238 V: Eq + Clone + Hash,
239 E: Eq + Clone + Hash,
240 S: BuildHasher,
241{
242 type ObGen = V;
243 type MorGen = E;
244
245 fn ob_generators(&self) -> impl Iterator<Item = Self::ObGen> {
246 self.generators.vertices()
247 }
248 fn mor_generators(&self) -> impl Iterator<Item = Self::MorGen> {
249 self.generators.edges()
250 }
251 fn mor_generator_dom(&self, f: &Self::MorGen) -> Self::Ob {
252 self.generators.src(f)
253 }
254 fn mor_generator_cod(&self, f: &Self::MorGen) -> Self::Ob {
255 self.generators.tgt(f)
256 }
257}
258
259impl<V, E, S> Validate for FpCategory<V, E, S>
260where
261 V: Eq + Clone + Hash,
262 E: Eq + Clone + Hash,
263 S: BuildHasher,
264{
265 type ValidationError = InvalidFpCategory<E>;
266
267 fn validate(&self) -> Result<(), NonEmpty<Self::ValidationError>> {
268 validate::wrap_errors(self.iter_invalid())
269 }
270}
271
/// A validation error in an [`FpCategory`].
///
/// Equation variants carry the index of the offending equation, in the order
/// equations were added.
#[derive(Debug, Error)]
pub enum InvalidFpCategory<E> {
    /// The generator graph reports an invalid source for this morphism generator.
    #[error("Domain of morphism generator `{0}` is not in the category")]
    Dom(E),

    /// The generator graph reports an invalid target for this morphism generator.
    #[error("Codomain of morphism generator `{0}` is not in the category")]
    Cod(E),

    /// Left-hand side of the equation at this index is not a path in the category.
    #[error("LHS of path equation `{0}` is not in the category")]
    EqLhs(usize),

    /// Right-hand side of the equation at this index is not a path in the category.
    #[error("RHS of path equation `{0}` is not in the category")]
    EqRhs(usize),

    /// The two sides of the equation at this index have different sources.
    #[error("Path equation `{0}` has sources that are not equal")]
    EqSrc(usize),

    /// The two sides of the equation at this index have different targets.
    #[error("Path equation `{0}` has targets that are not equal")]
    EqTgt(usize),
}
299
/// Incrementally builds an egglog program encoding a presented category.
///
/// Commands accumulate in `prog` until `program` takes them. Object and
/// morphism generators are mapped to consecutive integer IDs in order of first
/// use; the IDs index the `ObGen`/`MorGen` constructors of the preamble.
#[derive(Clone)]
struct CategoryProgramBuilder<V, E, S = RandomState> {
    /// Commands queued since the last call to `program` (the preamble initially).
    prog: Vec<Command>,
    /// Names of the egglog sorts, functions, relation, and ruleset used.
    sym: CategorySymbols,
    /// IDs assigned to object generators.
    ob_generators: HashMap<V, usize, S>,
    /// IDs assigned to morphism generators.
    mor_generators: HashMap<E, usize, S>,
}
314
315impl<V, E, S> CategoryProgramBuilder<V, E, S>
316where
317 V: Eq + Hash,
318 E: Eq + Hash,
319 S: BuildHasher,
320{
321 pub fn add_ob_generator(&mut self, v: V) -> usize {
323 let id = self.ob_generator_id(v);
324 let action = Action::Expr(span!(), self.ob_generator_with_id(id));
325 self.prog.push(Command::Action(action));
326 id
327 }
328
329 pub fn add_mor_generator(&mut self, e: E, dom: Expr, cod: Expr) -> usize {
331 let id = self.make_mor_generator(e);
332 self.set_dom(self.mor_generator_with_id(id), dom);
333 self.set_cod(self.mor_generator_with_id(id), cod);
334 id
335 }
336
337 pub fn make_mor_generator(&mut self, e: E) -> usize {
339 let id = self.mor_generator_id(e);
340 let action = Action::Expr(span!(), self.mor_generator_with_id(id));
341 self.prog.push(Command::Action(action));
342 id
343 }
344
345 pub fn set_dom(&mut self, mor: Expr, ob: Expr) {
347 let dom = self.dom(mor);
348 Program::ref_cast_mut(&mut self.prog).union(dom, ob);
349 }
350
351 pub fn set_cod(&mut self, mor: Expr, ob: Expr) {
353 let cod = self.cod(mor);
354 Program::ref_cast_mut(&mut self.prog).union(cod, ob);
355 }
356
357 pub fn ob_generator(&mut self, v: V) -> Expr {
359 let id = self.ob_generator_id(v);
360 self.ob_generator_with_id(id)
361 }
362
363 pub fn ob_generator_id(&mut self, v: V) -> usize {
365 let n = self.ob_generators.len();
366 *self.ob_generators.entry(v).or_insert(n)
367 }
368
369 pub fn mor_generator(&mut self, e: E) -> Expr {
371 let id = self.mor_generator_id(e);
372 self.mor_generator_with_id(id)
373 }
374
375 pub fn mor_generator_id(&mut self, e: E) -> usize {
377 let n = self.mor_generators.len();
378 *self.mor_generators.entry(e).or_insert(n)
379 }
380}
381
impl<V, E, S> CategoryProgramBuilder<V, E, S> {
    /// Takes the commands queued so far, leaving the queue empty.
    ///
    /// A later call returns only commands queued after this one, so the
    /// caller can feed the same e-graph incrementally.
    pub fn program(&mut self) -> Program {
        Program(std::mem::take(&mut self.prog))
    }

    /// `ObGen` term for the object generator with the given ID.
    ///
    /// # Panics
    ///
    /// Panics if the ID does not fit in an `i64`.
    fn ob_generator_with_id(&self, id: usize) -> Expr {
        let id: i64 = id.try_into().expect("Shouldn't have too many object generators");
        call!(self.sym.ob_gen, vec![lit!(id)])
    }

    /// `MorGen` term for the morphism generator with the given ID.
    ///
    /// # Panics
    ///
    /// Panics if the ID does not fit in an `i64`.
    fn mor_generator_with_id(&self, id: usize) -> Expr {
        let id: i64 = id.try_into().expect("Shouldn't have too many morphism generators");
        call!(self.sym.mor_gen, vec![lit!(id)])
    }

    /// Fact/term asserting that `mor` is a valid morphism.
    pub fn mor_is_valid(&self, mor: Expr) -> Expr {
        call!(self.sym.mor_is_valid, vec![mor])
    }

    /// Term for the domain of `mor`.
    pub fn dom(&self, mor: Expr) -> Expr {
        call!(self.sym.dom, vec![mor])
    }

    /// Term for the codomain of `mor`.
    pub fn cod(&self, mor: Expr) -> Expr {
        call!(self.sym.cod, vec![mor])
    }

    /// Term for the identity morphism on `ob`.
    pub fn id(&self, ob: Expr) -> Expr {
        call!(self.sym.id, vec![ob])
    }

    /// Term for the composite of `f` followed by `g` (diagrammatic order).
    pub fn compose2(&self, f: Expr, g: Expr) -> Expr {
        call!(self.sym.compose, vec![f, g])
    }

    /// Queues a union of `lhs` and `rhs` (via `Program::union`).
    pub fn equate(&mut self, lhs: Expr, rhs: Expr) {
        Program::ref_cast_mut(&mut self.prog).union(lhs, rhs);
    }

    /// Queues an equality check of `lhs` and `rhs`, passing along the schedule
    /// that saturates the category axioms (how the schedule is used is defined
    /// by `Program::check_equal`).
    pub fn check_equal(&mut self, lhs: Expr, rhs: Expr) {
        let schedule = self.schedule();
        Program::ref_cast_mut(&mut self.prog).check_equal(lhs, rhs, Some(schedule));
    }

    /// Schedule running the axioms ruleset repeatedly until saturation.
    fn schedule(&self) -> Schedule {
        Schedule::Saturate(
            span!(),
            Box::new(Schedule::Run(
                span!(),
                GenericRunConfig {
                    ruleset: self.sym.axioms,
                    until: None,
                },
            )),
        )
    }

    /// Replaces the queued commands with the preamble for the theory of
    /// categories: the `Ob`/`Mor` sorts with integer-indexed generators, the
    /// `dom`, `cod`, `id`, and `compose` constructors, a validity relation on
    /// morphism terms, and the category axioms as rules in their own ruleset.
    fn preamble(&mut self) {
        let sym = &self.sym;
        self.prog = vec![
            // Objects are built from object generators, indexed by integers.
            Command::Datatype {
                span: span!(),
                name: sym.ob,
                variants: vec![Variant {
                    span: span!(),
                    name: sym.ob_gen,
                    types: vec!["i64".into()],
                    cost: Some(0),
                }],
            },
            // Morphisms are built from morphism generators, indexed by integers.
            Command::Datatype {
                span: span!(),
                name: sym.mor,
                variants: vec![Variant {
                    span: span!(),
                    name: sym.mor_gen,
                    types: vec!["i64".into()],
                    cost: Some(0),
                }],
            },
            // Domain of a morphism.
            Command::Constructor {
                span: span!(),
                name: sym.dom,
                schema: Schema {
                    input: vec![sym.mor],
                    output: sym.ob,
                },
                cost: Some(1),
                unextractable: false,
            },
            // Codomain of a morphism.
            Command::Constructor {
                span: span!(),
                name: sym.cod,
                schema: Schema {
                    input: vec![sym.mor],
                    output: sym.ob,
                },
                cost: Some(1),
                unextractable: false,
            },
            // Identity morphism on an object.
            Command::Constructor {
                span: span!(),
                name: sym.id,
                schema: Schema {
                    input: vec![sym.ob],
                    output: sym.mor,
                },
                cost: Some(1),
                unextractable: false,
            },
            // Binary composition of morphisms.
            Command::Constructor {
                span: span!(),
                name: sym.compose,
                schema: Schema {
                    input: vec![sym.mor, sym.mor],
                    output: sym.mor,
                },
                cost: Some(1),
                unextractable: false,
            },
            Command::AddRuleset(sym.axioms),
            // Marks morphism terms known to be well-typed; the axioms below
            // fire only on valid terms.
            Command::Relation {
                span: span!(),
                name: sym.mor_is_valid,
                inputs: vec![sym.mor],
            },
            // Every morphism generator is valid.
            Command::from(CommandRule {
                ruleset: sym.axioms,
                head: vec![Action::Expr(span!(), self.mor_is_valid(var!("f")))],
                body: vec![Fact::Eq(span!(), var!("f"), call!(sym.mor_gen, vec![var!("name")]))],
            }),
            // Every identity is valid.
            Command::from(CommandRule {
                ruleset: sym.axioms,
                head: vec![Action::Expr(span!(), self.mor_is_valid(var!("f")))],
                body: vec![Fact::Eq(span!(), var!("f"), self.id(var!("x")))],
            }),
            // A composite is valid when both factors are valid and the
            // codomain of the first equals the domain of the second.
            Command::from(CommandRule {
                ruleset: sym.axioms,
                head: vec![Action::Expr(span!(), self.mor_is_valid(var!("fg")))],
                body: vec![
                    Fact::Eq(span!(), var!("fg"), self.compose2(var!("f"), var!("g"))),
                    Fact::Fact(self.mor_is_valid(var!("f"))),
                    Fact::Fact(self.mor_is_valid(var!("g"))),
                    Fact::Eq(span!(), self.cod(var!("f")), self.dom(var!("g"))),
                ],
            }),
            // A valid composite has the domain of its first factor and the
            // codomain of its second.
            Command::from(CommandRule {
                ruleset: sym.axioms,
                head: vec![
                    Action::Union(span!(), self.dom(var!("fg")), self.dom(var!("f"))),
                    Action::Union(span!(), self.cod(var!("fg")), self.cod(var!("g"))),
                ],
                body: vec![
                    Fact::Eq(span!(), var!("fg"), self.compose2(var!("f"), var!("g"))),
                    Fact::Fact(self.mor_is_valid(var!("fg"))),
                ],
            }),
            // Domain and codomain of an identity.
            Command::from(CommandRewrite {
                ruleset: sym.axioms,
                lhs: self.dom(self.id(var!("x"))),
                rhs: var!("x"),
            }),
            Command::from(CommandRewrite {
                ruleset: sym.axioms,
                lhs: self.cod(self.id(var!("x"))),
                rhs: var!("x"),
            }),
            // Associativity, stated in both directions and restricted to
            // valid composites.
            Command::from(CommandRule {
                ruleset: sym.axioms,
                head: vec![Action::Union(
                    span!(),
                    var!("fgh"),
                    self.compose2(var!("f"), self.compose2(var!("g"), var!("h"))),
                )],
                body: vec![
                    Fact::Eq(
                        span!(),
                        var!("fgh"),
                        self.compose2(self.compose2(var!("f"), var!("g")), var!("h")),
                    ),
                    Fact::Fact(self.mor_is_valid(var!("fgh"))),
                ],
            }),
            Command::from(CommandRule {
                ruleset: sym.axioms,
                head: vec![Action::Union(
                    span!(),
                    var!("fgh"),
                    self.compose2(self.compose2(var!("f"), var!("g")), var!("h")),
                )],
                body: vec![
                    Fact::Eq(
                        span!(),
                        var!("fgh"),
                        self.compose2(var!("f"), self.compose2(var!("g"), var!("h"))),
                    ),
                    Fact::Fact(self.mor_is_valid(var!("fgh"))),
                ],
            }),
            // Right and left unit laws.
            Command::from(CommandRewrite {
                ruleset: sym.axioms,
                lhs: self.compose2(var!("f"), self.id(self.cod(var!("f")))),
                rhs: var!("f"),
            }),
            Command::from(CommandRewrite {
                ruleset: sym.axioms,
                lhs: self.compose2(self.id(self.dom(var!("f"))), var!("f")),
                rhs: var!("f"),
            }),
        ]
    }
}
617
618impl<V, E, S: Default> Default for CategoryProgramBuilder<V, E, S> {
619 fn default() -> Self {
620 let mut result = Self {
621 prog: Default::default(),
622 sym: Default::default(),
623 ob_generators: Default::default(),
624 mor_generators: Default::default(),
625 };
626 result.preamble();
627 result
628 }
629}
630
/// Names used for the egglog sorts, functions, relation, and ruleset that
/// encode a category.
#[derive(Clone)]
struct CategorySymbols {
    /// Sort of objects.
    ob: Symbol,
    /// Sort of morphisms.
    mor: Symbol,
    /// Relation marking valid (well-typed) morphism terms.
    mor_is_valid: Symbol,
    /// Constructor of object generators, indexed by integer.
    ob_gen: Symbol,
    /// Constructor of morphism generators, indexed by integer.
    mor_gen: Symbol,
    /// Domain function.
    dom: Symbol,
    /// Codomain function.
    cod: Symbol,
    /// Identity constructor.
    id: Symbol,
    /// Binary composition constructor.
    compose: Symbol,
    /// Ruleset holding the category axioms.
    axioms: Symbol,
}
644
645impl Default for CategorySymbols {
646 fn default() -> Self {
647 Self {
648 ob: "Ob".into(),
649 mor: "Mor".into(),
650 mor_is_valid: "is_mor_valid".into(),
651 ob_gen: "ObGen".into(),
652 mor_gen: "MorGen".into(),
653 dom: "dom".into(),
654 cod: "cod".into(),
655 id: "id".into(),
656 compose: "compose".into(),
657 axioms: "CatAxioms".into(),
658 }
659 }
660}
661
#[cfg(test)]
mod tests {
    use super::*;
    use expect_test::expect;
    use nonempty::nonempty;
    use ustr::ustr;

    // Schema with vertices `V`, edges `E`, maps `src`/`tgt`: E -> V, and an
    // involution `inv` on edges that swaps source and target.
    #[test]
    fn sch_sgraph() {
        let mut sch_sgraph = UstrFpCategory::new();
        let (v, e) = (ustr("V"), ustr("E"));
        let (s, t, i) = (ustr("src"), ustr("tgt"), ustr("inv"));
        sch_sgraph.add_ob_generators([v, e]);
        sch_sgraph.add_mor_generator(s, e, v);
        sch_sgraph.add_mor_generator(t, e, v);
        sch_sgraph.add_mor_generator(i, e, e);
        assert!(sch_sgraph.is_free());
        sch_sgraph.equate(Path::pair(i, i), Path::empty(e));
        sch_sgraph.equate(Path::pair(i, s), Path::single(t));
        sch_sgraph.equate(Path::pair(i, t), Path::single(s));
        assert!(!sch_sgraph.is_free());
        assert!(sch_sgraph.validate().is_ok());

        // Distinct generators stay distinct; the imposed equations and their
        // consequences (inv;inv;inv;src = inv;src = tgt) are derivable.
        assert!(!sch_sgraph.morphisms_are_equal(Path::single(s), Path::single(t)));
        assert!(sch_sgraph.morphisms_are_equal(Path::pair(i, i), Path::empty(e)));
        assert!(sch_sgraph.morphisms_are_equal(Path::Seq(nonempty![i, i, i, s]), Path::single(t)));
    }

    // The preamble prints as expected and is accepted by a fresh e-graph.
    #[test]
    fn egraph_preamble() {
        let mut builder: CategoryProgramBuilder<char, char, RandomState> = Default::default();
        let prog = builder.program();

        let expected = expect![[r#"
            (datatype Ob (ObGen i64 :cost 0))
            (datatype Mor (MorGen i64 :cost 0))
            (constructor dom (Mor) Ob :cost 1)
            (constructor cod (Mor) Ob :cost 1)
            (constructor id (Ob) Mor :cost 1)
            (constructor compose (Mor Mor) Mor :cost 1)
            (ruleset CatAxioms)
            (relation is_mor_valid (Mor))
            (rule ((= f (MorGen name)))
                  ((is_mor_valid f))
                    :ruleset CatAxioms )
            (rule ((= f (id x)))
                  ((is_mor_valid f))
                    :ruleset CatAxioms )
            (rule ((= fg (compose f g))
                   (is_mor_valid f)
                   (is_mor_valid g)
                   (= (cod f) (dom g)))
                  ((is_mor_valid fg))
                    :ruleset CatAxioms )
            (rule ((= fg (compose f g))
                   (is_mor_valid fg))
                  ((union (dom fg) (dom f))
                   (union (cod fg) (cod g)))
                    :ruleset CatAxioms )
            (rewrite (dom (id x)) x :ruleset CatAxioms)
            (rewrite (cod (id x)) x :ruleset CatAxioms)
            (rule ((= fgh (compose (compose f g) h))
                   (is_mor_valid fgh))
                  ((union fgh (compose f (compose g h))))
                    :ruleset CatAxioms )
            (rule ((= fgh (compose f (compose g h)))
                   (is_mor_valid fgh))
                  ((union fgh (compose (compose f g) h)))
                    :ruleset CatAxioms )
            (rewrite (compose f (id (cod f))) f :ruleset CatAxioms)
            (rewrite (compose (id (dom f)) f) f :ruleset CatAxioms)
        "#]];
        expected.assert_eq(&prog.to_string());

        let mut egraph: EGraph = Default::default();
        assert!(prog.run_in(&mut egraph).is_ok());
    }
}