1use crate::{
3 dbl::model::*,
4 one::{Path, graph::FinGraph, graph_algorithms::toposort},
5 zero::{QualifiedLabel, QualifiedName, label, name},
6};
7use indexmap::IndexMap;
8use itertools::Itertools;
9use sea_query::SchemaBuilder;
10use sea_query::{
11 ColumnDef, ForeignKey, ForeignKeyCreateStatement, Iden, MysqlQueryBuilder,
12 PostgresQueryBuilder, SqliteQueryBuilder, Table, TableCreateStatement, prepare::Write,
13};
14use sqlformat::{Dialect, format};
15use std::fmt;
16
17impl Iden for QualifiedName {
18 fn unquoted(&self, s: &mut dyn Write) {
19 Iden::unquoted(&format!("{self}").as_str(), s)
20 }
21}
22
23impl Iden for QualifiedLabel {
24 fn unquoted(&self, s: &mut dyn Write) {
25 Iden::unquoted(&format!("{self}").as_str(), s)
26 }
27}
28
29impl Iden for &QualifiedLabel {
30 fn unquoted(&self, s: &mut dyn Write) {
31 Iden::unquoted(&format!("{self}").as_str(), s)
32 }
33}
34
35pub struct SQLAnalysis {
37 backend: SQLBackend,
38}
39
40impl SQLAnalysis {
41 pub fn new(backend: SQLBackend) -> Self {
43 Self { backend }
44 }
45
46 pub fn render(
48 &self,
49 model: &DiscreteDblModel,
50 ob_label: impl Fn(&QualifiedName) -> QualifiedLabel,
51 mor_label: impl Fn(&QualifiedName) -> QualifiedLabel,
52 ) -> Result<String, String> {
53 let g = model.generating_graph();
54 let t = toposort(g).map_err(|e| format!("Topological sort failed: {}", e))?;
55 let morphisms: IndexMap<&QualifiedName, Vec<QualifiedName>> =
56 IndexMap::from_iter(t.iter().rev().filter_map(|v| {
57 (name("Entity") == model.ob_generator_type(v))
58 .then_some((v, g.out_edges(v).collect::<Vec<QualifiedName>>()))
59 }));
60
61 let tables = self.make_tables(model, morphisms, ob_label, mor_label);
62
63 let output: String = tables
64 .iter()
65 .map(|table| match self.backend {
66 SQLBackend::MySQL => table.to_string(MysqlQueryBuilder),
67 SQLBackend::SQLite => table.to_string(SqliteQueryBuilder),
68 SQLBackend::PostgresSQL => table.to_string(PostgresQueryBuilder),
69 })
70 .join(";\n")
71 + ";";
72
73 let formatted_output = format(
75 &output,
76 &sqlformat::QueryParams::None,
77 &sqlformat::FormatOptions {
78 lines_between_queries: 2,
79 dialect: self.backend.clone().into(),
80 ..Default::default()
81 },
82 );
83
84 let result = match self.backend {
85 SQLBackend::SQLite => ["PRAGMA foreign_keys = ON", &formatted_output].join(";\n\n"),
86 _ => formatted_output,
87 };
88 Ok(result)
89 }
90
91 fn fk(
92 &self,
93 src_name: QualifiedLabel,
94 tgt_name: QualifiedLabel,
95 mor_name: QualifiedLabel,
96 ) -> ForeignKeyCreateStatement {
97 ForeignKey::create()
98 .name(format!("FK_{}_{}_{}", mor_name, src_name, tgt_name))
99 .from(src_name.clone(), mor_name)
100 .to(tgt_name.clone(), "id")
101 .to_owned()
102 }
103
104 fn make_tables(
105 &self,
106 model: &DiscreteDblModel,
107 morphisms: IndexMap<&QualifiedName, Vec<QualifiedName>>,
108 ob_label: impl Fn(&QualifiedName) -> QualifiedLabel,
109 mor_label: impl Fn(&QualifiedName) -> QualifiedLabel,
110 ) -> Vec<TableCreateStatement> {
111 morphisms
112 .into_iter()
113 .map(|(ob, mors)| {
114 let mut tbl = Table::create();
115
116 let table_column_defs = mors.iter().fold(
118 tbl.table(ob_label(ob)).if_not_exists().col(
119 ColumnDef::new("id").integer().not_null().auto_increment().primary_key(),
120 ),
121 |acc, mor| {
122 let mor_name = mor_label(mor);
123 if model.mor_generator_type(mor) == Path::Id(name("Entity")) {
126 acc.col(ColumnDef::new(mor_name.clone()).integer().not_null())
127 } else {
128 let tgt =
129 model.get_cod(mor).map(&ob_label).unwrap_or_else(|| label(""));
130 let mut col = ColumnDef::new(mor_name);
131 col.not_null();
132 add_column_type(&mut col, &tgt);
133 acc.col(col)
134 }
135 },
136 );
137
138 mors.iter()
139 .filter(|mor| model.mor_generator_type(mor) == Path::Id(name("Entity")))
140 .fold(
141 table_column_defs,
143 |acc, mor| {
144 let tgt =
145 model.get_cod(mor).map(&ob_label).unwrap_or_else(|| label(""));
146 acc.foreign_key(&mut self.fk(ob_label(ob), tgt, mor_label(mor)))
147 },
148 )
149 .to_owned()
150 })
151 .collect()
152 }
153}
154
/// SQL backend targeted by [`SQLAnalysis`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SQLBackend {
    /// MySQL, rendered with `sea_query`'s `MysqlQueryBuilder`.
    MySQL,

    /// SQLite, rendered with `SqliteQueryBuilder`; the rendered script is
    /// prefixed with `PRAGMA foreign_keys = ON`.
    SQLite,

    /// PostgreSQL, rendered with `PostgresQueryBuilder`.
    PostgresSQL,
}
169
170impl SQLBackend {
171 pub fn as_type(&self) -> Box<dyn SchemaBuilder> {
173 match self {
174 SQLBackend::MySQL => Box::new(MysqlQueryBuilder),
175 SQLBackend::SQLite => Box::new(SqliteQueryBuilder),
176 SQLBackend::PostgresSQL => Box::new(PostgresQueryBuilder),
177 }
178 }
179}
180
181impl From<SQLBackend> for Dialect {
182 fn from(backend: SQLBackend) -> sqlformat::Dialect {
183 match backend {
184 SQLBackend::PostgresSQL => Dialect::PostgreSql,
185 _ => Dialect::Generic,
186 }
187 }
188}
189
190impl TryFrom<&str> for SQLBackend {
191 type Error = String;
192 fn try_from(backend: &str) -> Result<Self, Self::Error> {
193 match backend {
194 "MySQL" => Ok(SQLBackend::MySQL),
195 "SQLite" => Ok(SQLBackend::SQLite),
196 "PostgresSQL" => Ok(SQLBackend::PostgresSQL),
197 _ => Err(String::from("Invalid backend")),
198 }
199 }
200}
201
202impl fmt::Display for SQLBackend {
203 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
204 let string = match self {
205 SQLBackend::MySQL => "MySQL",
206 SQLBackend::SQLite => "SQLite",
207 SQLBackend::PostgresSQL => "PostgresSQL",
208 };
209 write!(f, "{}", string)
210 }
211}
212
213fn add_column_type(col: &mut ColumnDef, name: &QualifiedLabel) {
214 match format!("{}", name).as_str() {
215 "Int" => col.integer(),
216 "TinyInt" => col.tiny_integer(),
217 "Bool" => col.boolean(),
218 "Float" => col.float(),
219 "Time" => col.timestamp(),
220 "Date" => col.date(),
221 "DateTime" => col.date_time(),
222 _ => col.custom(name.clone()),
223 };
224}
225
#[cfg(test)]
mod tests {
    use expect_test::expect;
    use std::rc::Rc;

    use super::*;
    use crate::{stdlib::th_schema, tt};

    /// Renders a small schema model as MySQL DDL and compares it with a snapshot.
    ///
    /// The model has two entities, a `Hom Entity` morphism between them, and an
    /// attribute.
    #[test]
    fn sql_schema() {
        let th = Rc::new(th_schema());
        let model = tt::modelgen::Model::from_text(
            &th.into(),
            "[
                Person : Entity,
                Dog : Entity,
                walks : (Hom Entity)[Person, Dog],
                Hair : AttrType,
                has : Attr[Person, Hair],
            ]",
        );
        let model = model.and_then(|m| m.as_discrete()).unwrap();

        // Snapshot expectations:
        // - `Dog` is created before `Person`, whose `walks` column references it.
        // - The attribute `has` takes its column type from the label `Hair`.
        // - `AttrType` objects such as `Hair` get no table of their own.
        let expected = expect![[
            r#"CREATE TABLE IF NOT EXISTS `Dog` (`id` int NOT NULL AUTO_INCREMENT PRIMARY KEY);

CREATE TABLE IF NOT EXISTS `Person` (
  `id` int NOT NULL AUTO_INCREMENT PRIMARY KEY,
  `walks` int NOT NULL,
  `has` Hair NOT NULL,
  CONSTRAINT `FK_walks_Person_Dog` FOREIGN KEY (`walks`) REFERENCES `Dog` (`id`)
);"#
        ]];
        // Both label closures use each generator's `Display` form as the SQL name.
        let ddl = SQLAnalysis::new(SQLBackend::MySQL)
            .render(
                &model,
                |id| format!("{id}").as_str().into(),
                |id| format!("{id}").as_str().into(),
            )
            .expect("SQL should render");
        expected.assert_eq(&ddl);
    }
}