1use derive_more::From;
33use ego_tree::{NodeRef, Tree};
34use itertools::{Itertools, zip_eq};
35use std::collections::VecDeque;
36
37use super::tree_algorithms::TreeIsomorphism;
38
39#[derive(Clone, Debug, From, PartialEq, Eq)]
50pub enum OpenTree<Ty, Op> {
51 Id(Ty),
53
54 #[from]
56 Comp(Tree<Option<Op>>),
57}
58
59impl<Ty, Op> OpenTree<Ty, Op> {
60 pub fn empty(ty: Ty) -> Self {
62 OpenTree::Id(ty)
63 }
64
65 pub fn single(op: Op, arity: usize) -> Self {
67 let mut tree = Tree::new(Some(op));
68 for _ in 0..arity {
69 tree.root_mut().append(None);
70 }
71 tree.into()
72 }
73
74 pub fn graft(subtrees: impl IntoIterator<Item = Self>, op: Op) -> Self {
80 let mut tree = Tree::new(Some(op));
81 for subtree in subtrees {
82 match subtree {
83 OpenTree::Id(_) => tree.root_mut().append(None),
84 OpenTree::Comp(subtree) => tree.root_mut().append_subtree(subtree),
85 };
86 }
87 tree.into()
88 }
89
90 pub fn linear(iter: impl IntoIterator<Item = Op>) -> Option<Self> {
96 let mut values: Vec<_> = iter.into_iter().collect();
97 let value = values.pop()?;
98 let mut tree = Tree::new(Some(value));
99 let mut node_id = tree.root().id();
100 for value in values.into_iter().rev() {
101 node_id = tree.get_mut(node_id).unwrap().append(Some(value)).id();
102 }
103 tree.get_mut(node_id).unwrap().append(None);
104 Some(tree.into())
105 }
106
107 pub fn arity(&self) -> usize {
112 match self {
113 OpenTree::Comp(tree) => tree.root().boundary().count(),
114 OpenTree::Id(_) => 1,
115 }
116 }
117
118 pub fn size(&self) -> usize {
124 match self {
125 OpenTree::Comp(tree) => {
126 tree.root().descendants().filter(|node| node.value().is_some()).count()
127 }
128 OpenTree::Id(_) => 0,
129 }
130 }
131
132 pub fn is_empty(&self) -> bool {
134 matches!(self, OpenTree::Id(_))
135 }
136
137 pub fn is_isomorphic_to(&self, other: &Self) -> bool
145 where
146 Ty: Eq,
147 Op: Eq,
148 {
149 match (self, other) {
150 (OpenTree::Comp(tree1), OpenTree::Comp(tree2)) => tree1.is_isomorphic_to(tree2),
151 (OpenTree::Id(type1), OpenTree::Id(type2)) => *type1 == *type2,
152 _ => false,
153 }
154 }
155
156 pub fn map<CodOp>(self, mut f: impl FnMut(Op) -> CodOp) -> OpenTree<Ty, CodOp> {
158 match self {
159 OpenTree::Comp(tree) => tree.map(|value| value.map(&mut f)).into(),
160 OpenTree::Id(ty) => OpenTree::Id(ty),
161 }
162 }
163}
164
/// Helpers for nodes of a tree whose values are `Option<T>`, where `None`
/// marks an open boundary slot and `Some` marks an operation.
pub trait OpenNodeRef<T> {
    /// The operation stored at this node, or `None` for a boundary slot.
    fn get_value(&self) -> Option<&T>;

    /// The operation stored at this node's parent, or `None` at the root.
    fn parent_value(&self) -> Option<&T>;

    /// Whether this node is a boundary slot.
    fn is_boundary(&self) -> bool;

    /// The boundary slots at or below this node, in traversal order.
    fn boundary(&self) -> impl Iterator<Item = Self>;
}
179
180impl<'a, T: 'a> OpenNodeRef<T> for NodeRef<'a, Option<T>> {
181 fn is_boundary(&self) -> bool {
182 let is_null = self.value().is_none();
183 assert!(!(is_null && self.has_children()), "Boundary nodes should be leaves");
184 is_null
185 }
186
187 fn boundary(&self) -> impl Iterator<Item = Self> {
188 self.descendants().filter(|node| node.is_boundary())
189 }
190
191 fn get_value(&self) -> Option<&T> {
192 self.value().as_ref()
193 }
194
195 fn parent_value(&self) -> Option<&T> {
196 self.parent()
197 .map(|p| p.value().as_ref().expect("Inner nodes should not be null"))
198 }
199}
200
impl<Ty, Op> OpenTree<Ty, OpenTree<Ty, Op>> {
    /// Flattens a tree whose operations are themselves `OpenTree`s into a
    /// single `OpenTree` by substituting each inner tree in place.
    ///
    /// The children of an outer node are matched, left to right, with the
    /// boundary slots of the inner tree stored at that node. Each slot is then
    /// filled by whatever the corresponding outer child flattens to. An outer
    /// boundary node (`None`) leaves its slot open, and an inner `Id` passes
    /// its slot straight through to the outer node's single child.
    ///
    /// Returns `OpenTree::Id` when no operation ends up at the root. This
    /// happens when `self` is `Id`, or when the outer root holds an `Id` whose
    /// chain of identities reaches an outer boundary node.
    ///
    /// # Panics
    ///
    /// - If the outer root holds `None`.
    /// - If an outer node's child count differs from the number of boundary
    ///   slots of the inner tree it holds (`zip_eq` length mismatch).
    /// - If an outer node holding `Id` does not have exactly one child.
    /// - If the outer root holds a `Comp` whose own root is `None` (the final
    ///   `unwrap` on `root_type`).
    pub fn flatten(self) -> OpenTree<Ty, Op> {
        // An identity outer tree has nothing to substitute.
        let mut outer_tree = match self {
            OpenTree::Id(x) => return OpenTree::Id(x),
            OpenTree::Comp(tree) => tree,
        };

        // Seed the result with the inner tree stored at the outer root. An
        // identity root starts out as a lone boundary slot; its type is kept
        // in case nothing is ever substituted into that slot.
        let value = std::mem::take(outer_tree.root_mut().value())
            .expect("Root node of outer tree should contain a tree");
        let (mut tree, root_type) = match value {
            OpenTree::Id(x) => (Tree::new(None), Some(x)),
            OpenTree::Comp(tree) => (tree, None),
        };

        // Work list of (outer node, boundary leaf of `tree` it will fill),
        // processed in FIFO order.
        let mut queue = VecDeque::new();
        for (child, leaf) in zip_eq(outer_tree.root().children(), tree.root().boundary()) {
            queue.push_back((child.id(), leaf.id()));
        }

        while let Some((outer_id, leaf_id)) = queue.pop_front() {
            // Move the inner tree out of the outer node. An outer boundary
            // node (`None`) leaves the leaf as an open slot of the result.
            let Some(value) = std::mem::take(outer_tree.get_mut(outer_id).unwrap().value()) else {
                continue;
            };
            match value {
                OpenTree::Id(_) => {
                    // Identity: the same leaf is filled by this outer node's
                    // only child instead.
                    let Ok(outer_parent) =
                        outer_tree.get(outer_id).unwrap().children().exactly_one()
                    else {
                        panic!("Identity tree should have exactly one parent")
                    };
                    queue.push_back((outer_parent.id(), leaf_id));
                }
                OpenTree::Comp(inner_tree) => {
                    // Splice `inner_tree` in at `leaf_id`:
                    // 1. add it to `tree` as a detached subtree;
                    // 2. move its root value onto the leaf;
                    // 3. move its root's children under the leaf.
                    // NOTE(review): the emptied node at `subtree_id` is never
                    // attached anywhere, so it presumably stays in `tree`'s node
                    // storage as an unreachable orphan. Traversals from the root
                    // skip it, but it may affect `Tree` equality — confirm.
                    let subtree_id = tree.extend_tree(inner_tree).id();
                    let value = std::mem::take(tree.get_mut(subtree_id).unwrap().value());

                    let mut inner_node = tree.get_mut(leaf_id).unwrap();
                    *inner_node.value() = value;
                    inner_node.reparent_from_id_append(subtree_id);

                    // The outer node's children fill the newly exposed slots,
                    // left to right.
                    let outer_node = outer_tree.get(outer_id).unwrap();
                    let inner_node: NodeRef<_> = inner_node.into();
                    for (child, leaf) in zip_eq(outer_node.children(), inner_node.boundary()) {
                        queue.push_back((child.id(), leaf.id()));
                    }
                }
            }
        }

        // The root is still an open slot: the result is the identity on the
        // outer root's type.
        if tree.root().value().is_none() {
            OpenTree::Id(root_type.unwrap())
        } else {
            tree.into()
        }
    }
}
260
#[cfg(test)]
mod tests {
    use super::*;
    use ego_tree::tree;

    type OT = OpenTree<char, char>;

    #[test]
    fn construct_tree() {
        assert_eq!(OT::empty('X').arity(), 1);

        let tree = OT::single('f', 2);
        assert_eq!(tree.arity(), 2);
        assert_eq!(tree, tree!(Some('f') => { None, None }).into());

        let tree = tree!(Some('h') => { Some('g') => { Some('f') => { None } } });
        assert_eq!(OT::linear(vec!['f', 'g', 'h']), Some(tree.into()));
    }

    #[test]
    fn identity_and_linear_edge_cases() {
        let id = OT::empty('X');
        assert!(id.is_empty());
        assert_eq!(id.size(), 0);
        assert!(id.is_isomorphic_to(&OT::Id('X')));
        assert!(!id.is_isomorphic_to(&OT::Id('Y')));
        // An identity never matches a composite tree, in either direction.
        assert!(!id.is_isomorphic_to(&OT::single('f', 1)));
        assert!(!OT::single('f', 1).is_isomorphic_to(&id));

        // No operations: there is no root to build a tree from.
        assert_eq!(OT::linear(Vec::new()), None);
    }

    #[test]
    fn graft_trees() {
        // `Id` subtrees become single boundary slots of the new root.
        let grafted = OT::graft(
            vec![OT::single('f', 2), OT::empty('X'), OT::single('g', 1)],
            'h',
        );
        assert!(!grafted.is_empty());
        assert_eq!(grafted.arity(), 4);
        assert_eq!(grafted.size(), 3);
        let expected = OT::from(tree!(
            Some('h') => {
                Some('f') => { None, None },
                None,
                Some('g') => { None },
            }
        ));
        assert!(grafted.is_isomorphic_to(&expected));
    }

    #[test]
    fn map_tree() {
        let upper = |c: char| c.to_ascii_uppercase();
        assert_eq!(OT::single('f', 2).map(upper), OT::single('F', 2));
        // Mapping only touches operations; an identity keeps its slot type.
        assert_eq!(OT::empty('x').map(upper), OT::Id('x'));
    }

    #[test]
    fn flatten_tree() {
        // Reference tree that each outer tree below should flatten to.
        let tree = OT::from(tree!(
            Some('f') => {
                Some('h') => {
                    Some('k') => { None, None},
                    None,
                },
                Some('g') => {
                    None,
                    Some('l') => { None, None }
                },
            }
        ));
        assert!(!tree.is_empty());
        assert_eq!(tree.size(), 5);
        assert_eq!(tree.arity(), 6);

        let subtree1 = OT::from(tree!(
            Some('f') => {
                None,
                Some('g') => { None, None },
            }
        ));
        let subtree2 = OT::from(tree!(
            Some('h') => {
                Some('k') => { None, None },
                None
            }
        ));
        let subtree3 = OT::from(tree!(
            Some('l') => { None, None }
        ));

        let outer_tree: OpenTree<_, _> = tree!(
            Some(subtree1.clone()) => {
                Some(subtree2.clone()) => { None, None, None },
                None,
                Some(subtree3.clone()) => { None, None },
            }
        )
        .into();
        assert!(outer_tree.flatten().is_isomorphic_to(&tree));

        // Same substitution with identities interposed on every edge.
        let outer_tree: OpenTree<_, _> = tree!(
            Some(subtree1) => {
                Some(OpenTree::Id('X')) => {
                    Some(subtree2) => { None, None, None },
                },
                Some(OpenTree::Id('X')) => { None },
                Some(OpenTree::Id('X')) => {
                    Some(subtree3) => { None, None },
                },
            }
        )
        .into();
        assert!(outer_tree.flatten().is_isomorphic_to(&tree));

        // Flattening an identity outer tree is the identity.
        let outer_tree: OpenTree<_, _> = OpenTree::Id('X');
        assert_eq!(outer_tree.flatten(), OT::Id('X'));

        // A chain of identities ending in a boundary flattens to the root's identity.
        let outer_tree: OpenTree<_, _> = tree!(
            Some(OT::Id('X')) => { Some(OT::Id('x')) => { None } }
        )
        .into();
        assert_eq!(outer_tree.flatten(), OT::Id('X'));
    }
}