1 /**
2 Copyright: Copyright (c) 2017, Oleg Butko. All rights reserved.
3 Copyright: Copyright (c) 2018-2019, Joakim Brännström. All rights reserved.
4 License: MIT
5 Author: Joakim Brännström (joakim.brannstrom@gmx.com)
6 Author: Oleg Butko (deviator)
7 */
8 module miniorm.api;
9 
10 import core.time : dur;
11 import logger = std.experimental.logger;
12 import std.array : Appender;
13 import std.datetime : SysTime, Duration;
14 import std.range;
15 
16 import miniorm.exception;
17 import miniorm.queries;
18 
19 import d2sqlite3;
20 
21 version (unittest) {
22     import std.algorithm : map;
23     import unit_threaded.assertions;
24 }
25 
/** Minimal ORM on top of a d2sqlite3 `Database`.
 *
 * Queries built with `miniorm.queries` are executed through the `run`
 * overloads. Prepared INSERT statements are cached by their SQL text and are
 * finalized when the instance is closed, assigned to or destroyed. Members of
 * the wrapped `Database` are reachable through `alias this`.
 */
struct Miniorm {
    /// Prepared INSERT statements keyed by their SQL text.
    private Statement[string] cachedStmt;
    /// True means that all queries are logged.
    private bool log_;

    /// The wrapped database connection.
    Database db;
    alias getUnderlyingDb this;

    /// Returns: a reference to the wrapped database connection.
    ref Database getUnderlyingDb() {
        return db;
    }

    /// Wrap an already opened connection.
    this(Database db) {
        this.db = db;
    }

    /// Open the database at `path` with the sqlite open `flags`.
    this(string path, int flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE) {
        this(Database(path, flags));
    }

    /* A copy (e.g. the by-value argument of `Transaction`'s constructor)
     * shares the connection but gets its own, empty statement cache. A copied
     * associative array refers to the same table as the original, so without
     * this the destructor of the copy would finalize statements that the
     * original still has cached and will reuse.
     */
    this(this) {
        cachedStmt = null;
    }

    ~this() {
        cleanupCache;
    }

    /// Start a RAII handled transaction.
    Transaction transaction() {
        return Transaction(this);
    }

    /// Toggle logging.
    void log(bool v) nothrow {
        this.log_ = v;
    }

    /// Returns: True if logging is activated
    private bool isLog() {
        return log_;
    }

    /// Finalize and forget all cached prepared statements.
    private void cleanupCache() {
        foreach (ref s; cachedStmt.byValue)
            s.finalize;
        cachedStmt = null;
    }

    /// Switch to the connection of `rhs`. The statement cache is dropped
    /// because its statements were prepared on the previous connection.
    void opAssign(ref typeof(this) rhs) {
        cleanupCache;
        db = rhs.db;
    }

    /// Execute the raw `sql` via `Database.run`, forwarding `dg`.
    void run(string sql, bool delegate(ResultRange) dg = null) {
        if (isLog)
            logger.trace(sql);
        db.run(sql, dg);
    }

    /// Finalize the cached statements and close the connection.
    void close() {
        cleanupCache;
        db.close();
    }

    /// Returns: the value produced by the count query `v`.
    size_t run(T)(Count!T v) {
        const sql = v.toSql.toString;
        if (isLog)
            logger.trace(sql);
        return db.executeCheck(sql).front.front.as!size_t;
    }

    /** Execute the select query `v`.
     *
     * Returns: a lazy range that converts each result row to a `T`.
     */
    auto run(T)(Select!T v) {
        import std.algorithm : map;
        import std.format : format;

        const sql = v.toSql.toString;
        if (isLog)
            logger.trace(sql);

        auto result = db.executeCheck(sql);

        // Build a `T` from one result row. The conversion is generated per
        // field: DATETIME columns are parsed from their textual form, static
        // arrays are copied from the raw bytes (limited to the shorter of the
        // two lengths), enums are cast from peek!ET and everything else is
        // read with peek!ET.
        static T qconv(typeof(result.front) e) {
            import miniorm.schema : fieldToCol;

            T ret;
            static string rr() {
                string[] res;
                res ~= "import std.traits : isStaticArray, OriginalType;";
                res ~= "import miniorm.api : fromSqLiteDateTime;";
                foreach (i, a; fieldToCol!("", T)()) {
                    res ~= `{`;
                    if (a.columnType == "DATETIME") {
                        res ~= `{ ret.%1$s = fromSqLiteDateTime(e.peek!string(%2$d)); }`.format(a.identifier,
                                i);
                    } else {
                        res ~= q{alias ET = typeof(ret.%s);}.format(a.identifier);
                        res ~= q{static if (isStaticArray!ET)};
                        res ~= `
                            {
                                import std.algorithm : min;
                                auto ubval = e[%2$d].as!(ubyte[]);
                                auto etval = cast(typeof(ET.init[]))ubval;
                                auto ln = min(ret.%1$s.length, etval.length);
                                ret.%1$s[0..ln] = etval[0..ln];
                            }
                            `.format(a.identifier, i);
                        res ~= q{else static if (is(ET == enum))};
                        res ~= format(q{ret.%1$s = cast(ET) e.peek!ET(%2$d);}, a.identifier, i);
                        res ~= q{else};
                        res ~= format(q{ret.%1$s = e.peek!ET(%2$d);}, a.identifier, i);
                    }
                    res ~= `}`;
                }
                return res.join("\n");
            }

            mixin(rr());
            return ret;
        }

        return result.map!qconv;
    }

    /// Execute the delete query `v`.
    void run(T)(Delete!T v) {
        const sql = v.toSql.toString;
        if (isLog)
            logger.trace(sql);
        db.run(sql);
    }

    /// Insert the values `arr` using the insert query `v`.
    void run(AggregateInsert all = AggregateInsert.no, T0, T1)(Insert!T0 v, T1[] arr...)
            if (!isInputRange!T1) {
        procInsert!all(v, arr);
    }

    /// Insert every element of `rng` using the insert query `v`.
    void run(AggregateInsert all = AggregateInsert.no, T, R)(Insert!T v, R rng)
            if (isInputRange!R) {
        procInsert!all(v, rng);
    }

    /** Bind the elements of `rng` to the cached prepared statement of `q` and
     * execute it, once per element (`AggregateInsert.no`) or once for all
     * elements (`AggregateInsert.yes`, which needs the length of `rng`).
     */
    private void procInsert(AggregateInsert all = AggregateInsert.no, T, R)(Insert!T q, R rng)
            if ((all && hasLength!R) || !all) {
        import std.exception : collectException;

        // generate code for binding values in a struct to a prepared
        // statement.
        // Expects an external variable "n" to exist that keeps track of the
        // index. This is required when the binding is for multiple values.
        // Expects the statement to be named "stmt".
        // Expects the variable to read values from to be named "v".
        // Indexing start from 1 according to the sqlite manual.
        // The primary key is only bound when `replace` is true.
        static string genBinding(T)(bool replace) {
            import miniorm.schema : fieldToCol;

            string s;
            foreach (i, v; fieldToCol!("", T)) {
                if (!replace && v.isPrimaryKey)
                    continue;
                if (v.columnType == "DATETIME")
                    s ~= "stmt.bind(n+1, v." ~ v.identifier ~ ".toUTC.toSqliteDateTime);";
                else
                    s ~= "stmt.bind(n+1, v." ~ v.identifier ~ ");";
                s ~= "++n;";
            }
            return s;
        }

        alias T = ElementType!R;

        const replace = q.query.opt == InsertOpt.InsertOrReplace;

        static if (all == AggregateInsert.yes)
            q = q.values(rng.length);
        else
            q = q.values(1);

        const sql = q.toSql.toString;

        if (sql !in cachedStmt)
            cachedStmt[sql] = db.prepare(sql);
        auto stmt = cachedStmt[sql];

        // SQLite requires a statement whose step failed to be reset before it
        // can be bound and stepped again. Without this a single failed insert
        // (e.g. a busy database) would make every later insert that reuses
        // this cached statement fail. The reset repeats the error of the
        // failed step; that is ignored so the original exception propagates.
        scope (failure)
            stmt.reset.collectException;

        static if (all == AggregateInsert.yes) {
            int n;
            foreach (v; rng) {
                if (replace) {
                    mixin(genBinding!T(true));
                } else {
                    mixin(genBinding!T(false));
                }
            }
            if (isLog)
                logger.trace(sql, " -> ", rng);
            stmt.execute();
            stmt.reset();
        } else {
            foreach (v; rng) {
                int n;
                if (replace) {
                    mixin(genBinding!T(true));
                } else {
                    mixin(genBinding!T(false));
                }
                if (isLog)
                    logger.trace(sql, " -> ", v);
                stmt.execute();
                stmt.reset();
            }
        }
    }
}
239 
/** Whether one aggregated insert or multiple inserts should be generated.
 *
 * no: one statement is executed per element.
 * ---
 * INSERT INTO foo ('v0') VALUES (?)
 * INSERT INTO foo ('v0') VALUES (?)
 * INSERT INTO foo ('v0') VALUES (?)
 * ---
 *
 * yes: a single statement with one row of placeholders per element. This
 * requires the range of values to have a length.
 * ---
 * INSERT INTO foo ('v0') VALUES (?),(?),(?)
 * ---
 */
enum AggregateInsert {
    /// Execute one INSERT per element.
    no,
    /// Execute a single multi-row INSERT for all elements.
    yes
}
258 
259 version (unittest) {
260     import miniorm.schema;
261 
262     import std.conv : text, to;
263     import std.range;
264     import std.algorithm;
265     import std.datetime;
266     import std.array;
267     import std.stdio;
268 
269     import unit_threaded.assertions;
270 }
271 
@("shall operate on a database allocted in std.experimental.allocators without any errors")
unittest {
    // NOTE(review): despite its name this test currently uses a plain,
    // stack allocated Miniorm; the allocator based variant is still a TODO.
    struct One {
        ulong id;
        string text;
    }

    import std.experimental.allocator;
    import std.experimental.allocator.mallocator;
    import std.experimental.allocator.building_blocks.scoped_allocator;

    // TODO: fix this
    //Miniorm* db;
    //ScopedAllocator!Mallocator scalloc;
    //db = scalloc.make!Miniorm(":memory:");
    //scope (exit) {
    //    db.close;
    //    scalloc.dispose(db);
    //}

    // TODO: replace the one below with the above code.
    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);
    // A plain insert skips the primary key when binding (presumably `id`), so
    // the ids given here are not stored and the database assigns them.
    db.run(insert!One.insert, iota(0, 10).map!(i => One(i * 100, "hello" ~ text(i))));
    db.run(count!One).shouldEqual(10);

    auto ones = db.run(select!One).array;
    ones.length.shouldEqual(10);
    assert(ones.all!(a => a.id < 100));
    db.getUnderlyingDb.lastInsertRowid.shouldEqual(ones[$ - 1].id);

    db.run(delete_!One);
    db.run(count!One).shouldEqual(0);
    // insertOrReplace also binds the primary key, so the explicit ids are kept.
    db.run(insertOrReplace!One, iota(0, 499).map!(i => One((i + 1) * 100, "hello" ~ text(i))));
    ones = db.run(select!One).array;
    ones.length.shouldEqual(499);
    assert(ones.all!(a => a.id >= 100));
    // lastInsertRowid is also reachable directly through Miniorm's alias this.
    db.lastInsertRowid.shouldEqual(ones[$ - 1].id);
}
311 
@("shall insert and extract datetime from the table")
unittest {
    import std.datetime : Clock;
    import core.thread : Thread;
    import core.time : dur;

    struct One {
        ulong id;
        SysTime time;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);

    const time = Clock.currTime;
    // Stored timestamps only keep whole milliseconds. Sleeping at least one
    // millisecond keeps the stored value strictly after `time`.
    Thread.sleep(1.dur!"msecs");

    db.run(insert!One.insert, One(0, Clock.currTime));

    // The value read back must survive the text round trip through the
    // DATETIME column.
    auto ones = db.run(select!One).array;
    ones.length.shouldEqual(1);
    ones[0].time.shouldBeGreaterThan(time);
}
335 
unittest {
    // Insert/select/delete round trip using aggregated (multi-row) inserts,
    // i.e. one INSERT statement per run call.
    struct One {
        ulong id;
        string text;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!One);

    db.run(count!One).shouldEqual(0);
    db.run!(AggregateInsert.yes)(insert!One.insert, iota(0, 10)
            .map!(i => One(i * 100, "hello" ~ text(i))));
    db.run(count!One).shouldEqual(10);

    // A plain insert does not bind the primary key, so the ids come from the
    // database rather than from the values above.
    auto ones = db.run(select!One).array;
    assert(ones.length == 10);
    assert(ones.all!(a => a.id < 100));
    assert(db.lastInsertRowid == ones[$ - 1].id);

    db.run(delete_!One);
    db.run(count!One).shouldEqual(0);

    import std.datetime;
    import std.conv : to;

    // insertOrReplace binds the primary key, so the explicit ids are stored.
    db.run!(AggregateInsert.yes)(insertOrReplace!One, iota(0, 499)
            .map!(i => One((i + 1) * 100, "hello" ~ text(i))));
    ones = db.run(select!One).array;
    assert(ones.length == 499);
    assert(ones.all!(a => a.id >= 100));
    assert(db.lastInsertRowid == ones[$ - 1].id);
}
368 
@("shall convert the database type to the enum when retrieving via select")
unittest {
    // The enum has a string base type; select must map the stored value back
    // to the corresponding enum member.
    static struct Foo {
        enum MyEnum : string {
            foo = "batman",
            bar = "robin",
        }

        ulong id;
        MyEnum enum_;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Foo);

    db.run(insert!Foo.insert, Foo(0, Foo.MyEnum.bar));
    auto res = db.run(select!Foo).array;

    res.length.shouldEqual(1);
    res[0].enum_.shouldEqual(Foo.MyEnum.bar);
}
390 
unittest {
    // Nested structs are flattened into columns named by their member path,
    // e.g. "limits.volt.max", which the where clauses below refer to.
    struct Limit {
        int min, max;
    }

    struct Limits {
        Limit volt, curr;
    }

    struct Settings {
        ulong id;
        Limits limits;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Settings);
    assert(db.run(count!Settings) == 0);
    db.run(insertOrReplace!Settings, Settings(10, Limits(Limit(0, 12), Limit(-10, 10))));
    assert(db.run(count!Settings) == 1);

    // Reusing id 10 replaces the existing row instead of adding a new one.
    db.run(insertOrReplace!Settings, Settings(10, Limits(Limit(0, 2), Limit(-3, 3))));
    db.run(insertOrReplace!Settings, Settings(11, Limits(Limit(0, 11), Limit(-11, 11))));
    db.run(insertOrReplace!Settings, Settings(12, Limits(Limit(0, 12), Limit(-12, 12))));

    assert(db.run(count!Settings) == 3);
    assert(db.run(count!Settings.where(`"limits.volt.max" = 2`)) == 1);
    assert(db.run(count!Settings.where(`"limits.volt.max" > 10`)) == 2);
    db.run(delete_!Settings.where(`"limits.volt.max" < 10`));
    assert(db.run(count!Settings) == 2);
}
421 
unittest {
    // A fixed-size array member is read back from the raw bytes of its
    // column and copied element by element into the struct.
    struct Settings {
        ulong id;
        int[5] data;
    }

    auto db = Miniorm(":memory:");
    db.run(buildSchema!Settings);

    db.run(insert!Settings.insert, Settings(0, [1, 2, 3, 4, 5]));

    assert(db.run(count!Settings) == 1);
    auto s = db.run(select!Settings).front;
    assert(s.data == [1, 2, 3, 4, 5]);
}
437 
/** Parse a DATETIME column value into a `SysTime`.
 *
 * The expected layout is `YYYY-MM-DD HH:MM:SS[.fff]`. The fractional part is
 * optional, so timestamps without milliseconds (such as those produced by
 * SQLite's `CURRENT_TIMESTAMP`) are accepted too. The digits after the dot
 * are read as an integer number of milliseconds, which is how
 * `toSqliteDateTime` writes them.
 *
 * Params:
 *  raw_dt = the textual timestamp, interpreted as UTC.
 *
 * Returns: the timestamp as a `SysTime` in the UTC time zone.
 */
SysTime fromSqLiteDateTime(string raw_dt) {
    import core.time : dur;
    import std.algorithm : findSplit;
    import std.datetime : DateTime, UTC;
    import std.format : formattedRead;

    // Split off the optional fraction before parsing the fixed part.
    auto parts = raw_dt.findSplit(".");
    string dateTimePart = parts[0];
    string fracPart = parts[2];

    int year, month, day, hour, minute, second;
    formattedRead(dateTimePart, "%s-%s-%s %s:%s:%s", year, month, day, hour, minute, second);

    int msecs;
    if (fracPart.length != 0)
        formattedRead(fracPart, "%s", msecs);

    auto dt = DateTime(year, month, day, hour, minute, second);
    return SysTime(dt, msecs.dur!"msecs", UTC());
}
449 
/** Format `ts` as `YYYY-MM-DD HH:MM:SS.fff` for storage in a DATETIME column.
 *
 * The calendar fields are taken from `ts` as-is, i.e. in whatever time zone
 * it carries. `fromSqLiteDateTime` interprets the text as UTC, so convert the
 * value with `toUTC` first when it is meant to round-trip.
 *
 * The milliseconds are zero padded to three digits. This keeps stored values
 * ordered chronologically when compared as text and matches the `SS.SSS`
 * layout used by SQLite's date and time functions.
 */
string toSqliteDateTime(SysTime ts) {
    import std.format : format;

    return format("%04d-%02d-%02d %02d:%02d:%02d.%03d", ts.year,
            cast(ushort) ts.month, ts.day, ts.hour, ts.minute, ts.second,
            ts.fracSecs.total!"msecs");
}
457 
/// Thrown by `spinSql` when a query did not succeed within the timeout.
class SpinSqlTimeout : Exception {
    /** Params:
     *  msg = description of the failure.
     *  file = source file where the exception was created.
     *  line = source line where the exception was created.
     *  next = optional exception that caused this one.
     */
    this(string msg, string file = __FILE__, size_t line = __LINE__, Throwable next = null) @safe pure nothrow {
        super(msg, file, line, next);
    }
}
463 
/** Execute an SQL query until it succeeds or `timeout` has elapsed.
 *
 * Each failed attempt is reported through `logFn` and followed by a random
 * sleep in [minTime, maxTime) before the next attempt.
 *
 * Note: any exception thrown by `query` is treated as a transient failure, so
 * a query that can never succeed (e.g. one with a syntax error) is retried
 * until `timeout` expires.
 *
 * Params:
 *  query = callable executed on each attempt.
 *  logFn = receives the message of every exception thrown by `query`.
 *  timeout = how long to keep retrying.
 *  minTime = lower bound of the sleep between attempts.
 *  maxTime = upper bound of the sleep between attempts. When it is not
 *            greater than `minTime` the sleep is exactly `minTime`.
 *
 * Returns: the value returned by the first successful call of `query`.
 *
 * Throws: `SpinSqlTimeout` if `query` did not succeed within `timeout`.
 */
auto spinSql(alias query, alias logFn = logger.warning)(Duration timeout,
        Duration minTime = 50.dur!"msecs", Duration maxTime = 150.dur!"msecs") {
    import core.thread : Thread;
    import core.time : dur;
    import std.datetime.stopwatch : StopWatch, AutoStart;
    import std.exception : collectException;
    import std.random : uniform;

    long minMsecs = minTime.total!"msecs";
    long maxMsecs = maxTime.total!"msecs";

    const sw = StopWatch(AutoStart.yes);

    while (sw.peek < timeout) {
        try {
            return query();
        } catch (Exception e) {
            logFn(e.msg).collectException;
            // Even though the database has a builtin sleep, retrying right away
            // still results in too much log spam. uniform() throws on an
            // empty interval, which would escape from this handler, so fall
            // back to a fixed pause when maxTime <= minTime.
            const sleepMsecs = minMsecs < maxMsecs ? uniform(minMsecs, maxMsecs) : minMsecs;
            () @trusted { Thread.sleep(sleepMsecs.dur!"msecs"); }();
        }
    }

    throw new SpinSqlTimeout("timeout when executing the SQL query");
}
492 
/** Execute an SQL query until it succeeds, without ever giving up.
 *
 * Any exception thrown by the timed `spinSql` overload is swallowed and the
 * attempt restarted, which is what makes this overload `nothrow`. A query
 * that can never succeed therefore loops forever.
 */
auto spinSql(alias query, alias logFn = logger.warning)() nothrow {
    for (;;) {
        try {
            return spinSql!(query, logFn)(Duration.max);
        } catch (Exception) {
            // best effort: try again
        }
    }
}
501 
/** Sleep for `min_` plus a random extra delay drawn from [0, span) msecs.
 *
 * Params:
 *  min_ = minimal time to sleep.
 *  span = exclusive upper bound of the random extra delay, in msecs.
 */
void rndSleep(Duration min_, ulong span) nothrow @trusted {
    import core.thread : Thread;
    import core.time : dur;
    import std.random : uniform;

    // Default to the whole span; it is only kept when no random value can be
    // drawn (uniform throws on the empty interval [0, 0)).
    Duration extra = span.dur!"msecs";
    try {
        extra = uniform(0, span).dur!"msecs";
    } catch (Exception) {
    }

    Thread.sleep(min_ + extra);
}
522 
/** RAII handling of a transaction.
 *
 * The constructor issues `begin` on the connection (retried through
 * `spinSql` until it succeeds). The destructor rolls the transaction back
 * unless `commit` completed successfully.
 */
struct Transaction {
    /// Connection the transaction runs on.
    Miniorm db;
    /// Set once the transaction has been committed or rolled back.
    bool isDone;

    /** Begin a transaction on the connection of `db`.
     *
     * `db` is taken by reference and only its connection and log setting are
     * taken over. A plain copy of `db` may share its statement cache (an
     * associative array), in which case destroying that copy would finalize
     * statements that `db` still uses.
     */
    this(ref Miniorm db) {
        this.db.db = db.db;
        this.db.log_ = db.log_;
        spinSql!(() { db.begin; });
    }

    /// Begin a transaction on the connection of the temporary `db`.
    this(Miniorm db) {
        // `db` is an lvalue here, so this resolves to the `ref` overload.
        this(db);
    }

    /// Roll the transaction back unless it has been committed.
    ~this() {
        scope (exit)
            isDone = true;
        if (!isDone)
            db.rollback;
    }

    /** Commit the transaction.
     *
     * `isDone` is only set once the commit succeeded. If it fails (e.g. the
     * database is busy) the transaction is still open, and the destructor
     * rolls it back instead of leaving it dangling on the connection.
     */
    void commit() {
        db.commit;
        isDone = true;
    }
}