1use derive_more::From;
10use ego_tree::Tree;
11use std::collections::VecDeque;
12
13use super::graph::VDblGraph;
14use super::tree_algorithms::*;
15use crate::one::path::Path;
16
/// Value carried by a node of a [`DblTree`].
///
/// Each variant wraps one kind of element of a virtual double graph
/// ([`VDblGraph`]): a square, a pro-edge, or an edge.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DblNode<E, ProE, Sq> {
    /// A cell given by a square of the graph; its boundary and arity are
    /// looked up in the graph.
    Cell(Sq),

    /// An identity node on a pro-edge.
    ///
    /// Its domain and codomain are both obtained from the pro-edge itself and
    /// its arity is one. `DblTree::is_empty` asserts that such a node, when it
    /// is the root, has no children.
    Id(ProE),

    /// A spine node on an edge.
    ///
    /// Its source and target are both obtained from the edge itself and its
    /// arity is zero. `DblTree::cod` expects that the root of a tree is never
    /// a spine.
    Spine(E),
}
44
45impl<E, ProE, Sq> DblNode<E, ProE, Sq> {
46 pub fn is_cell(&self) -> bool {
48 matches!(*self, DblNode::Cell(_))
49 }
50
51 pub fn is_id(&self) -> bool {
53 matches!(*self, DblNode::Id(_))
54 }
55
56 pub fn is_spine(&self) -> bool {
58 matches!(*self, DblNode::Spine(_))
59 }
60
61 pub fn dom<V>(
63 &self,
64 graph: &impl VDblGraph<V = V, E = E, ProE = ProE, Sq = Sq>,
65 ) -> Path<V, ProE>
66 where
67 ProE: Clone,
68 {
69 match self {
70 DblNode::Cell(sq) => graph.square_dom(sq),
71 DblNode::Id(p) => p.clone().into(),
72 DblNode::Spine(e) => Path::empty(graph.dom(e)),
73 }
74 }
75
76 pub fn cod<V>(
78 &self,
79 graph: &impl VDblGraph<V = V, E = E, ProE = ProE, Sq = Sq>,
80 ) -> Path<V, ProE>
81 where
82 ProE: Clone,
83 {
84 match self {
85 DblNode::Cell(sq) => graph.square_cod(sq).into(),
86 DblNode::Id(p) => p.clone().into(),
87 DblNode::Spine(e) => Path::empty(graph.cod(e)),
88 }
89 }
90
91 pub fn src<V>(&self, graph: &impl VDblGraph<V = V, E = E, ProE = ProE, Sq = Sq>) -> Path<V, E>
93 where
94 E: Clone,
95 {
96 match self {
97 DblNode::Cell(sq) => graph.square_src(sq).into(),
98 DblNode::Id(p) => Path::empty(graph.src(p)),
99 DblNode::Spine(e) => e.clone().into(),
100 }
101 }
102
103 pub fn tgt<V>(&self, graph: &impl VDblGraph<V = V, E = E, ProE = ProE, Sq = Sq>) -> Path<V, E>
105 where
106 E: Clone,
107 {
108 match self {
109 DblNode::Cell(sq) => graph.square_tgt(sq).into(),
110 DblNode::Id(p) => Path::empty(graph.tgt(p)),
111 DblNode::Spine(e) => e.clone().into(),
112 }
113 }
114
115 pub fn arity(&self, graph: &impl VDblGraph<E = E, ProE = ProE, Sq = Sq>) -> usize {
117 match self {
118 DblNode::Cell(sq) => graph.arity(sq),
119 DblNode::Id(_) => 1,
120 DblNode::Spine(_) => 0,
121 }
122 }
123
124 pub fn contained_in(&self, graph: &impl VDblGraph<E = E, ProE = ProE, Sq = Sq>) -> bool {
126 match self {
127 DblNode::Cell(sq) => graph.has_square(sq),
128 DblNode::Id(p) => graph.has_proedge(p),
129 DblNode::Spine(e) => graph.has_edge(e),
130 }
131 }
132}
133
/// A double tree: a newtype around an [`ego_tree::Tree`] whose node values
/// are [`DblNode`]s.
///
/// The wrapped tree is a public field, and `From` is derived, so a
/// `Tree<DblNode<..>>` can be converted with `.into()`.
#[derive(Clone, Debug, From, PartialEq, Eq)]
pub struct DblTree<E, ProE, Sq>(pub Tree<DblNode<E, ProE, Sq>>);
148
149impl<E, ProE, Sq> DblTree<E, ProE, Sq> {
150 pub fn empty(p: ProE) -> Self {
152 Tree::new(DblNode::Id(p)).into()
153 }
154
155 pub fn single(sq: Sq) -> Self {
157 Tree::new(DblNode::Cell(sq)).into()
158 }
159
160 pub fn linear(iter: impl IntoIterator<Item = Sq>) -> Option<Self> {
162 DblTree::from_nodes(iter.into_iter().map(DblNode::Cell))
163 }
164
165 pub fn spine(e: E) -> Self {
167 Tree::new(DblNode::Spine(e)).into()
168 }
169
170 pub fn spines<V>(path: Path<V, E>) -> Option<Self> {
172 DblTree::from_nodes(path.into_iter().map(DblNode::Spine))
173 }
174
175 pub fn from_nodes(iter: impl IntoIterator<Item = DblNode<E, ProE, Sq>>) -> Option<Self> {
177 let mut values: Vec<_> = iter.into_iter().collect();
178 let value = values.pop()?;
179 let mut tree = Tree::new(value);
180 let mut node_id = tree.root().id();
181 for value in values.into_iter().rev() {
182 node_id = tree.get_mut(node_id).unwrap().append(value).id();
183 }
184 Some(tree.into())
185 }
186
187 pub fn two_level(leaves: impl IntoIterator<Item = Sq>, base: Sq) -> Self {
189 Self::graft(leaves.into_iter().map(DblTree::single), base)
190 }
191
192 pub fn graft(subtrees: impl IntoIterator<Item = Self>, base: Sq) -> Self {
194 let mut tree = Tree::new(DblNode::Cell(base));
195 for subtree in subtrees {
196 tree.root_mut().append_subtree(subtree.0);
197 }
198 tree.into()
199 }
200
201 pub fn size(&self) -> usize {
206 self.0.values().filter(|dn| dn.is_cell()).count()
207 }
208
209 pub fn is_empty(&self) -> bool {
214 let root = self.0.root();
215 let root_is_id = root.value().is_id();
216 assert!(!(root_is_id && root.has_children()), "Identity node should not have children");
217 root_is_id
218 }
219
220 pub fn root(&self) -> &DblNode<E, ProE, Sq> {
222 self.0.root().value()
223 }
224
225 pub fn leaves(&self) -> impl Iterator<Item = &DblNode<E, ProE, Sq>> {
227 self.0.root().descendants().filter_map(|node| {
228 if node.has_children() {
229 None
230 } else {
231 Some(node.value())
232 }
233 })
234 }
235
236 pub fn src_nodes(&self) -> impl Iterator<Item = &DblNode<E, ProE, Sq>> {
242 let mut maybe_node = Some(self.0.root());
243 std::iter::from_fn(move || {
244 let prev = maybe_node;
245 maybe_node = maybe_node.and_then(|node| node.first_child());
246 prev.map(|node| node.value())
247 })
248 }
249
250 pub fn tgt_nodes(&self) -> impl Iterator<Item = &DblNode<E, ProE, Sq>> {
256 let mut maybe_node = Some(self.0.root());
257 std::iter::from_fn(move || {
258 let prev = maybe_node;
259 maybe_node = maybe_node.and_then(|node| node.last_child());
260 prev.map(|node| node.value())
261 })
262 }
263
264 pub fn dom<V>(
266 &self,
267 graph: &impl VDblGraph<V = V, E = E, ProE = ProE, Sq = Sq>,
268 ) -> Path<V, ProE>
269 where
270 ProE: Clone,
271 {
272 Path::collect(self.leaves().map(|dn| dn.dom(graph))).unwrap().flatten()
273 }
274
275 pub fn cod(&self, graph: &impl VDblGraph<E = E, ProE = ProE, Sq = Sq>) -> ProE
277 where
278 ProE: Clone,
279 {
280 self.root()
281 .cod(graph)
282 .only()
283 .expect("The root of a double tree should not be a spine")
284 }
285
286 pub fn src<V>(&self, graph: &impl VDblGraph<V = V, E = E, ProE = ProE, Sq = Sq>) -> Path<V, E>
288 where
289 E: Clone,
290 {
291 let mut edges: Vec<_> = self.src_nodes().map(|dn| dn.src(graph)).collect();
292 edges.reverse();
293 Path::from_vec(edges).unwrap().flatten()
294 }
295
296 pub fn tgt<V>(&self, graph: &impl VDblGraph<V = V, E = E, ProE = ProE, Sq = Sq>) -> Path<V, E>
298 where
299 E: Clone,
300 {
301 let mut edges: Vec<_> = self.tgt_nodes().map(|dn| dn.tgt(graph)).collect();
302 edges.reverse();
303 Path::from_vec(edges).unwrap().flatten()
304 }
305
306 pub fn arity(&self, graph: &impl VDblGraph<E = E, ProE = ProE, Sq = Sq>) -> usize {
308 self.leaves().map(|dn| dn.arity(graph)).sum()
309 }
310
311 pub fn contained_in(&self, graph: &impl VDblGraph<E = E, ProE = ProE, Sq = Sq>) -> bool
317 where
318 E: Eq + Clone,
319 ProE: Eq + Clone,
320 {
321 let mut traverse = self.0.bfs();
322 while let Some(node) = traverse.next() {
323 let dn = node.value();
324 if !dn.contained_in(graph) {
326 return false;
327 }
328 if !traverse
330 .peek_at_same_level()
331 .is_none_or(|next| dn.tgt(graph) == next.value().src(graph))
332 {
333 return false;
334 }
335 if node.has_children() {
337 let codomains = node.children().map(|child| child.value().cod(graph));
338 if Path::collect(codomains).unwrap().flatten() != dn.dom(graph) {
339 return false;
340 }
341 }
342 }
343 true
344 }
345
346 pub fn is_isomorphic_to(&self, other: &Self) -> bool
348 where
349 E: Eq,
350 ProE: Eq,
351 Sq: Eq,
352 {
353 self.0.is_isomorphic_to(&other.0)
354 }
355
356 pub fn map<CodE, CodSq>(
358 self,
359 mut fn_e: impl FnMut(E) -> CodE,
360 mut fn_sq: impl FnMut(Sq) -> CodSq,
361 ) -> DblTree<CodE, ProE, CodSq> {
362 self.0
363 .map(|dn| match dn {
364 DblNode::Cell(sq) => DblNode::Cell(fn_sq(sq)),
365 DblNode::Spine(e) => DblNode::Spine(fn_e(e)),
366 DblNode::Id(m) => DblNode::Id(m),
367 })
368 .into()
369 }
370}
371
372impl<V, E, ProE, Sq> DblNode<Path<V, E>, ProE, DblTree<E, ProE, Sq>> {
373 fn flatten(self) -> DblTree<E, ProE, Sq> {
375 match self {
376 DblNode::Cell(tree) => tree,
377 DblNode::Id(m) => DblTree::empty(m),
378 DblNode::Spine(path) => {
379 DblTree::spines(path).expect("Spine should be a non-empty path")
380 }
381 }
382 }
383}
384
385impl<V, E, ProE, Sq> DblTree<Path<V, E>, ProE, DblTree<E, ProE, Sq>>
386where
387 V: Clone,
388 E: Clone,
389 ProE: Clone + Eq + std::fmt::Debug,
390 Sq: Clone,
391{
392 pub fn flatten_in(
394 &self,
395 graph: &impl VDblGraph<V = V, E = E, ProE = ProE, Sq = Sq>,
396 ) -> DblTree<E, ProE, Sq> {
397 let outer_root = self.0.root();
399 let mut tree = outer_root.value().clone().flatten().0;
400
401 let mut outer_nodes = self.0.bfs();
403 outer_nodes.next();
404
405 let mut queue = VecDeque::new();
408 if outer_root.has_children() {
409 queue.push_back(tree.root().id());
410 }
411
412 while let Some(node_id) = queue.pop_front() {
413 let leaf_ids: Vec<_> = tree
414 .get(node_id)
415 .unwrap()
416 .descendants()
417 .filter_map(|node| {
418 if node.has_children() {
419 None
420 } else {
421 Some(node.id())
422 }
423 })
424 .collect();
425 for leaf_id in leaf_ids {
426 let mut leaf = tree.get_mut(leaf_id).unwrap();
427 for m in leaf.value().dom(graph) {
428 let outer_node =
429 outer_nodes.next().expect("Outer tree should have enough nodes");
430
431 let inner_tree = outer_node.value().clone().flatten();
432 assert_eq!(m, inner_tree.cod(graph), "(Co)domains should be compatible");
433
434 let subtree_id = leaf.append_subtree(inner_tree.0).id();
435 if outer_node.has_children() {
436 queue.push_back(subtree_id);
437 }
438 }
439 }
440 }
441
442 assert!(outer_nodes.next().is_none(), "Outer tree should not have extra nodes");
443 tree.into()
444 }
445}
446
#[cfg(test)]
mod tests {
    use ego_tree::tree;
    use nonempty::nonempty;

    use super::super::category::{WalkingBimodule as Bimod, WalkingFunctor as Funct, *};
    use super::*;

    /// Leaves, arity, domain, codomain and containment of a two-level tree.
    #[test]
    fn tree_dom_cod() {
        let bimod = Bimod::Main();
        let graph = UnderlyingDblGraph(Bimod::Main());
        let seq = Path::Seq(nonempty![Bimod::Pro::Left, Bimod::Pro::Middle, Bimod::Pro::Right]);
        let mid = bimod.composite_ext(seq).unwrap();

        // The constructor and the `tree!` macro should build the same tree.
        let built = DblTree::two_level(
            vec![bimod.id_cell(Bimod::Pro::Left), mid.clone(), bimod.id_cell(Bimod::Pro::Right)],
            mid.clone(),
        );
        let by_macro = tree!(
            mid.clone() => {
                bimod.id_cell(Bimod::Pro::Left), mid.clone(), bimod.id_cell(Bimod::Pro::Right)
            }
        );
        let by_macro = DblTree(by_macro.map(DblNode::Cell));
        assert_eq!(built, by_macro);
        assert!(built.contained_in(&graph));

        // Boundary data of the well-formed tree.
        assert_eq!(built.leaves().count(), 3);
        assert_eq!(built.arity(&graph), 5);
        assert_eq!(
            built.dom(&graph),
            Path::Seq(nonempty![
                Bimod::Pro::Left,
                Bimod::Pro::Left,
                Bimod::Pro::Middle,
                Bimod::Pro::Right,
                Bimod::Pro::Right
            ])
        );
        assert_eq!(built.cod(&graph), Bimod::Pro::Middle);

        // A tree with one child missing is rejected.
        let missing_child = tree!(
            mid.clone() => {
                bimod.id_cell(Bimod::Pro::Left), mid.clone()
            }
        );
        assert!(!DblTree(missing_child.map(DblNode::Cell)).contained_in(&graph));

        // A tree with the outer children swapped is rejected.
        let swapped_children = tree!(
            mid.clone() => {
                bimod.id_cell(Bimod::Pro::Right), mid.clone(), bimod.id_cell(Bimod::Pro::Left)
            }
        );
        assert!(!DblTree(swapped_children.map(DblNode::Cell)).contained_in(&graph));
    }

    /// Source and target of a tree with a spine node below a cell.
    #[test]
    fn tree_src_tgt() {
        let funct = Funct::Main();
        let graph = UnderlyingDblGraph(Funct::Main());
        let f = Funct::Arr::Arrow;
        let unit1 = funct.unit_ext(Funct::Ob::One).unwrap();

        // `from_nodes` puts the last value at the root.
        let chain = DblTree::from_nodes(vec![DblNode::Spine(f), DblNode::Cell(unit1)]).unwrap();
        let by_macro = DblTree(tree!(
            DblNode::Cell(unit1) => { DblNode::Spine(f) }
        ));
        assert_eq!(chain, by_macro);
        assert!(chain.contained_in(&graph));

        assert_eq!(chain.src_nodes().count(), 2);
        assert_eq!(chain.tgt_nodes().count(), 2);
        assert_eq!(chain.src(&graph), Path::pair(f, Funct::Arr::One));
        assert_eq!(chain.tgt(&graph), Path::pair(f, Funct::Arr::One));
        assert!(chain.dom(&graph).is_empty());

        // This tree is expected to be rejected.
        let ill_formed = DblTree(tree!(
            DblNode::Cell(funct.composite2_ext(Funct::Ob::One, Funct::Ob::One).unwrap()) => {
                DblNode::Cell(unit1) => { DblNode::Spine(Funct::Arr::One) },
                DblNode::Cell(unit1) => { DblNode::Spine(f) },
            }
        ));
        assert!(!ill_formed.contained_in(&graph));
    }

    /// Flattening trivially nested trees gives back the original tree.
    #[test]
    fn flatten_tree() {
        let bimod = Bimod::Main();
        let graph = UnderlyingDblGraph(Bimod::Main());
        let seq = Path::Seq(nonempty![Bimod::Pro::Left, Bimod::Pro::Middle, Bimod::Pro::Right]);
        let unitl = bimod.unit_ext(Bimod::Ob::Left).unwrap();
        let unitr = bimod.unit_ext(Bimod::Ob::Right).unwrap();
        let mid = bimod.composite_ext(seq).unwrap();
        let cells = tree!(
            mid.clone() => {
                bimod.id_cell(Bimod::Pro::Left) => {
                    unitl.clone(),
                },
                mid => {
                    unitl, bimod.id_cell(Bimod::Pro::Middle), unitr.clone(),
                },
                bimod.id_cell(Bimod::Pro::Right) => {
                    unitr,
                }
            }
        );
        let tree = DblTree(cells.map(DblNode::Cell));
        assert_eq!(tree.dom(&graph), Path::single(Bimod::Pro::Middle));
        assert_eq!(tree.cod(&graph), Bimod::Pro::Middle);

        // An outer tree with a single node holding the whole inner tree.
        let single_outer = DblTree::single(tree.clone());
        assert_eq!(single_outer.flatten_in(&graph), tree);

        // An outer tree of the same shape whose nodes hold one-node trees.
        let nodewise_outer = tree.clone().map(Path::single, DblTree::single);
        let flattened = nodewise_outer.flatten_in(&graph);
        assert!(flattened.is_isomorphic_to(&tree));
    }
}