1use derivative::Derivative;
4use std::{fmt, hash::Hash};
5
6use crate::tt::util::{Row, pretty::*};
7use crate::validate::{self, Validate};
8use crate::zero::{Column, HashColumn, LabelSegment, Mapping, MutMapping, NameSegment};
9
/// The ports of a box, or of a diagram's outer boundary: a row of named,
/// labeled entries whose values are the port types.
pub type Ports<T> = Row<T>;
15
16#[derive(Clone, Derivative)]
18#[derivative(Default(bound = ""))]
19struct PortMap<T, J> {
20 ports: Ports<T>,
21 mapping: HashColumn<NameSegment, J>,
22}
23
24impl<T, J> PortMap<T, J> {
25 fn new(ports: Ports<T>) -> Self {
26 Self { ports, mapping: Default::default() }
27 }
28}
29
30#[derive(Clone, Derivative)]
32#[derivative(Default(bound = ""))]
33pub struct UWD<T, J> {
34 outer: PortMap<T, J>,
35 inner: Row<PortMap<T, J>>,
36 junctions: HashColumn<J, T>,
37}
38
39impl<T, J> UWD<T, J> {
40 pub fn empty() -> Self {
42 Self::default()
43 }
44
45 pub fn with_ports(outer_ports: Ports<T>) -> Self {
47 Self {
48 outer: PortMap::new(outer_ports),
49 inner: Default::default(),
50 junctions: Default::default(),
51 }
52 }
53
54 pub fn boxes(&self) -> impl Iterator<Item = (&NameSegment, &LabelSegment, &Ports<T>)> {
56 self.inner.iter().map(|(name, (label, pm))| (name, label, &pm.ports))
57 }
58
59 pub fn outer_ports(&self) -> &Ports<T> {
61 &self.outer.ports
62 }
63
64 pub fn has_box(&self, box_name: NameSegment) -> bool {
66 self.outer.ports.has(box_name)
67 }
68
69 pub fn has_port(&self, box_name: NameSegment, port_name: NameSegment) -> bool {
71 self.inner.get(box_name).is_some_and(|inner| inner.ports.has(port_name))
72 }
73
74 pub fn has_outer_port(&self, port_name: NameSegment) -> bool {
76 self.outer.ports.has(port_name)
77 }
78
79 pub fn add_box(&mut self, name: NameSegment, label: LabelSegment) {
81 self.inner.insert(name, label, PortMap::default())
82 }
83
84 pub fn add_box_with_ports(&mut self, name: NameSegment, label: LabelSegment, ports: Ports<T>) {
86 self.inner.insert(name, label, PortMap::new(ports));
87 }
88
89 pub fn add_port(
91 &mut self,
92 box_name: NameSegment,
93 port_name: NameSegment,
94 label: LabelSegment,
95 ty: T,
96 ) -> Option<()> {
97 let inner = self.inner.get_mut(box_name)?;
98 inner.ports.insert(port_name, label, ty);
99 Some(())
100 }
101
102 pub fn add_outer_port(&mut self, name: NameSegment, label: LabelSegment, ty: T) {
104 self.outer.ports.insert(name, label, ty);
105 }
106}
107
108impl<T: Clone + Eq, J: Clone + Eq + Hash> UWD<T, J> {
109 pub fn junctions(&self) -> impl Iterator<Item = J> {
111 self.junctions.iter().map(|(j, _)| j)
112 }
113
114 pub fn has_junction(&self, junction: &J) -> bool {
116 self.junctions.is_set(junction)
117 }
118
119 pub fn get(&self, box_name: NameSegment, port_name: NameSegment) -> Option<&J> {
121 self.inner.get(box_name).and_then(|inner| inner.mapping.get(&port_name))
122 }
123
124 pub fn get_outer(&self, port_name: NameSegment) -> Option<&J> {
126 self.outer.mapping.get(&port_name)
127 }
128
129 pub fn set(
131 &mut self,
132 box_name: NameSegment,
133 port_name: NameSegment,
134 junction: J,
135 ) -> Option<()> {
136 let inner = self.inner.get_mut(box_name)?;
137 let ty = inner.ports.get(port_name)?;
138 if !self.junctions.is_set(&junction) {
139 self.junctions.set(junction.clone(), ty.clone());
140 }
141 inner.mapping.set(port_name, junction);
142 Some(())
143 }
144
145 pub fn set_outer(&mut self, port_name: NameSegment, junction: J) -> Option<()> {
147 let ty = self.outer.ports.get(port_name)?;
148 if !self.junctions.is_set(&junction) {
149 self.junctions.set(junction.clone(), ty.clone());
150 }
151 self.outer.mapping.set(port_name, junction);
152 Some(())
153 }
154}
155
156pub enum InvalidUWD {
158 OuterPortType {
160 port_name: NameSegment,
162 },
163 InnerPortType {
165 box_name: NameSegment,
167 port_name: NameSegment,
169 },
170}
171
172impl<T: Clone + Eq, J: Clone + Eq + Hash> UWD<T, J> {
173 fn iter_invalid(&self) -> impl Iterator<Item = InvalidUWD> + use<'_, T, J> {
174 let junctions = &self.junctions;
175 let outer_errors = self.outer.ports.iter().filter_map(|(&port_name, (_, ty))| {
176 let valid = self
177 .outer
178 .mapping
179 .get(&port_name)
180 .is_some_and(|j| junctions.get(j).is_some_and(|jty| jty == ty));
181 (!valid).then_some(InvalidUWD::OuterPortType { port_name })
182 });
183 let inner_errors = self.inner.iter().flat_map(move |(&box_name, (_, port_map))| {
184 port_map.ports.iter().filter_map(move |(&port_name, (_, ty))| {
185 let valid = port_map
186 .mapping
187 .get(&port_name)
188 .is_some_and(|j| junctions.get(j).is_some_and(|jty| jty == ty));
189 (!valid).then_some(InvalidUWD::InnerPortType { box_name, port_name })
190 })
191 });
192 outer_errors.chain(inner_errors)
193 }
194}
195
196impl<T: Clone + Eq, J: Clone + Eq + Hash> Validate for UWD<T, J> {
197 type ValidationError = InvalidUWD;
198
199 fn validate(&self) -> Result<(), nonempty::NonEmpty<Self::ValidationError>> {
200 validate::wrap_errors(self.iter_invalid())
201 }
202}
203
204impl<T: fmt::Display, J: fmt::Display + Clone + Eq> ToDoc for PortMap<T, J> {
206 fn to_doc<'a>(&self) -> D<'a> {
207 let args = self.ports.iter().map(|(port_name, (label, ty))| {
208 let arg = binop(t(":"), t(label.to_string()), t(ty.to_string()));
209 let var = match self.mapping.get(port_name) {
210 Some(junction) => t(junction.to_string()),
211 None => t("_"),
212 };
213 binop(t(":="), arg, var)
214 });
215 tuple(args)
216 }
217}
218
219#[derive(Derivative)]
221#[derivative(Default(new = "true"))]
222pub struct UWDPrinter {
223 #[derivative(Default(value = "true"))]
224 include_summary: bool,
225}
226
227impl UWDPrinter {
228 pub fn include_summary(mut self, value: bool) -> Self {
230 self.include_summary = value;
231 self
232 }
233
234 pub fn summary<T: Clone + Eq, J: Clone + Eq + Hash>(&self, uwd: &UWD<T, J>) -> String {
236 let n_boxes = uwd.inner.len();
237 let n_junctions = uwd.junctions.len();
238 format!(
239 "UWD with {n_boxes} box{} and {n_junctions} junction{}",
240 if n_boxes != 1 { "es" } else { "" },
241 if n_junctions != 1 { "s" } else { "" },
242 )
243 }
244
245 pub fn doc<'a, T: fmt::Display + Clone + Eq, J: fmt::Display + Clone + Eq + Hash>(
249 &self,
250 uwd: &UWD<T, J>,
251 ) -> D<'a> {
252 let head = uwd.outer.to_doc();
253 let clauses = uwd
254 .inner
255 .iter()
256 .map(|(_, (label, port_map))| unop(t(label.to_string()), port_map.to_doc()));
257 let body = intersperse(clauses, t(",") + s());
258 let result = head + t(" :-") + (s() + body).indented();
259 if self.include_summary {
260 t(self.summary(uwd)) + hardline() + result
261 } else {
262 result
263 }
264 }
265}
266
267impl<T: fmt::Display + Clone + Eq, J: fmt::Display + Clone + Eq + Hash> ToDoc for UWD<T, J> {
268 fn to_doc<'a>(&self) -> D<'a> {
269 UWDPrinter::new().doc(self)
270 }
271}
272
273impl<T: fmt::Display + Clone + Eq, J: fmt::Display + Clone + Eq + Hash> fmt::Display for UWD<T, J> {
274 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
275 write!(f, "{}", self.to_doc().pretty())
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use expect_test::expect;

    /// Builds a UWD with outer ports `x: X` and `z: Z` and boxes `R(a: X, b: Y)`
    /// and `S(c: Y, d: Z)`, wired through junctions `u`, `v`, and `w`.
    fn binary_composite_uwd() -> UWD<&'static str, &'static str> {
        let mut uwd: UWD<_, _> = UWD::with_ports(Ports::from_iter([("x", "X"), ("z", "Z")]));
        uwd.add_box_with_ports("R".into(), "R".into(), Ports::from_iter([("a", "X"), ("b", "Y")]));
        uwd.add_box_with_ports("S".into(), "S".into(), Ports::from_iter([("c", "Y"), ("d", "Z")]));
        // `set` and `set_outer` return `None` for a missing box or port. Fail loudly
        // so that a typo here cannot silently produce a different diagram.
        for (box_name, port_name, junction) in
            [("R", "a", "u"), ("R", "b", "v"), ("S", "c", "v"), ("S", "d", "w")]
        {
            uwd.set(box_name.into(), port_name.into(), junction)
                .expect("box and port should exist");
        }
        for (port_name, junction) in [("x", "u"), ("z", "w")] {
            uwd.set_outer(port_name.into(), junction).expect("outer port should exist");
        }
        uwd
    }

    #[test]
    fn pretty_print() {
        let uwd = binary_composite_uwd();
        let expected = expect![[r#"
            UWD with 2 boxes and 3 junctions
            [x : X := u, z : Z := w] :-
              R [a : X := u, b : Y := v],
              S [c : Y := v, d : Z := w]"#]];
        expected.assert_eq(&uwd.to_string());
    }

    #[test]
    fn validate() {
        assert!(binary_composite_uwd().validate().is_ok());
    }

    #[test]
    fn validate_unassigned_outer_port() {
        let mut uwd = binary_composite_uwd();
        uwd.add_outer_port("y".into(), "y".into(), "Y");
        let errors = uwd.validate().expect_err("unassigned outer port should be invalid");
        assert_eq!(errors.len(), 1);
        assert!(matches!(errors.first(), InvalidUWD::OuterPortType { .. }));
    }

    #[test]
    fn validate_mistyped_inner_port() {
        let mut uwd = binary_composite_uwd();
        // Junction `u` already has type `X`, but port `c` of box `S` has type `Y`.
        uwd.set("S".into(), "c".into(), "u").expect("box and port should exist");
        let errors = uwd.validate().expect_err("mistyped inner port should be invalid");
        assert_eq!(errors.len(), 1);
        assert!(matches!(errors.first(), InvalidUWD::InnerPortType { .. }));
    }
}