1 /**
2 Copyright: Copyright (c) 2018, Joakim Brännström. All rights reserved.
3 License: $(LINK2 http://www.boost.org/LICENSE_1_0.txt, Boost Software License 1.0)
4 Author: Joakim Brännström (joakim.brannstrom@gmx.com)
5 
6 This module defines the protocol for data transfer and functionality to use it.
7 */
8 module distssh.protocol;
9 
10 import std.array : appender;
11 import std.range : put;
12 import logger = std.experimental.logger;
13 
14 import msgpack_ll;
15 import sumtype;
16 
/** Tag that identifies the type of a protocol message.
 *
 * `Serialize.pack(Kind)` writes the tag as a msgpack uint8, so the numeric
 * value of a member, i.e. its position in this list, is what is sent.
 * `Deserialize` treats every value above `Kind.max` as invalid.
 */
enum Kind : ubyte {
    /// No message.
    none,
    /// Message without payload, only the tag is sent.
    heartBeat,
    /// The shell environment.
    environment,
    /// The working directory to execute the command in.
    workdir,
    /// Command to execute
    command,
    /// All configuration data has been sent.
    confDone,
    /// One or more key strokes to be written to stdin
    key,
    /// terminal capabilities
    terminalCapability,
}
33 
/// Number of bytes a `Kind` occupies on the wire when encoded as a msgpack uint8.
enum KindSize = DataSize!(MsgpackType.uint8);
35 
/** Encodes protocol messages as msgpack and writes them to an output range.
 *
 * Every message starts with its `Kind` encoded as a msgpack uint8. `HeartBeat`
 * and `ConfDone` consist of the tag only. All other messages continue with a
 * msgpack uint32 holding the size in bytes of the whole message, the tag and
 * the size field itself included, followed by the payload. `Deserialize` uses
 * that size to decide when a complete message has been received, thus it must
 * never be smaller than the number of bytes that are actually written.
 *
 * Params:
 *  WriterT = output range that accepts both `ubyte` and `ubyte[]` via `put`.
 */
struct Serialize(WriterT) {
@safe:

    /// Destination of the encoded bytes.
    WriterT w;

    /// Write the message tag `k` as a msgpack uint8.
    void pack(Kind k) {
        ubyte[KindSize] pkgtype;
        formatType!(MsgpackType.uint8)(k, pkgtype);
        put(w, pkgtype[]);
    }

    /** Write `s` as a msgpack str32, a 32-bit length followed by the raw bytes.
     *
     * The length field is 32 bit wide. A longer string is rejected by the
     * contract instead of being written with a truncated length.
     */
    void pack(const string s)
    in (s.length <= uint.max, "a msgpack str32 can hold at most 2^32-1 bytes") {
        ubyte[DataSize!(MsgpackType.str32)] hdr;
        formatType!(MsgpackType.str32)(cast(uint) s.length, hdr);
        put(w, hdr[]);
        put(w, cast(const(ubyte)[]) s);
    }

    /** Write a msgpack array16 header followed by the elements of `value`.
     *
     * The elements are handed to the writer one by one, as they are, without
     * any msgpack encoding of their own.
     */
    void packArray(T)(T[] value)
    in (value.length < ushort.max) {
        ubyte[DataSize!(MsgpackType.array16)] hdr;
        formatType!(MsgpackType.array16)(cast(ushort) value.length, hdr);
        put(w, hdr[]);
        foreach (v; value) {
            put(w, v);
        }
    }

    /// Write `v` encoded as the msgpack type `Type`.
    void pack(MsgpackType Type, T)(T v) {
        ubyte[DataSize!Type] buf;
        formatType!Type(v, buf);
        put(w, buf[]);
    }

    /// Write a heart beat. It consists of the tag only.
    void pack(T)() if (is(T == HeartBeat)) {
        pack(Kind.heartBeat);
    }

    /// Write the end-of-configuration marker. It consists of the tag only.
    void pack(T)() if (is(T == ConfDone)) {
        pack(Kind.confDone);
    }

    /// Write the working directory: tag, total size, path as str32.
    void pack(const Workdir wd) {
        // The size field counts itself. Without it the announced size is
        // smaller than the message and the receiver starts to decode before
        // the whole path has arrived.
        // dfmt off
        const sz =
            KindSize +
            DataSize!(MsgpackType.uint32) +
            DataSize!(MsgpackType.str32) +
            (cast(const(ubyte)[]) wd.value).length;
        // dfmt on

        pack(Kind.workdir);
        pack!(MsgpackType.uint32)(cast(uint) sz);
        pack(wd.value);
    }

    /// Write key strokes: tag, total size, the bytes as an array16.
    void pack(const Key key) {
        // dfmt off
        const sz =
            KindSize +
            DataSize!(MsgpackType.uint32) +
            DataSize!(MsgpackType.array16) +
            (cast(const(ubyte)[]) key.value).length;
        // dfmt on

        pack(Kind.key);
        pack!(MsgpackType.uint32)(cast(uint) sz);
        packArray(key.value);
    }

    /// Write a command: tag, total size, number of arguments, each argument as str32.
    void pack(const Command cmd) {
        import std.algorithm : map, sum;

        // Each argument is written with its own str32 header which has to be
        // part of the announced size, see pack(const string).
        // dfmt off
        const sz =
            KindSize +
            DataSize!(MsgpackType.uint32) +
            DataSize!(MsgpackType.uint32) +
            cmd.value.map!(a => DataSize!(MsgpackType.str32) +
                           (cast(const(ubyte)[]) a).length).sum;
        // dfmt on

        pack(Kind.command);
        pack!(MsgpackType.uint32)(cast(uint) sz);
        pack!(MsgpackType.uint32)(cast(uint) cmd.value.length);

        foreach (a; cmd.value) {
            pack(a);
        }
    }

    /// Write the environment: tag, total size, number of pairs, each key and value as str32.
    void pack(const ProtocolEnv env) {
        import std.algorithm : map, sum;

        // dfmt off
        const tot_size =
            KindSize +
            DataSize!(MsgpackType.uint32) +
            DataSize!(MsgpackType.uint32) +
            env.value.map!(a => 2*DataSize!(MsgpackType.str32) +
                           (cast(const(ubyte)[]) a.key).length +
                           (cast(const(ubyte)[]) a.value).length).sum;
        // dfmt on

        pack(Kind.environment);
        pack!(MsgpackType.uint32)(cast(uint) tot_size);
        pack!(MsgpackType.uint32)(cast(uint) env.length);

        foreach (const kv; env) {
            pack(kv.key);
            pack(kv.value);
        }
    }

    /// Write the terminal settings field by field in the order of `termios`.
    void pack(const TerminalCapability t) {
        // modelled after:
        // struct termios {
        //     tcflag_t   c_iflag;
        //     tcflag_t   c_oflag;
        //     tcflag_t   c_cflag;
        //     tcflag_t   c_lflag;
        //     cc_t       c_line;
        //     cc_t[NCCS] c_cc;
        //     speed_t    c_ispeed;
        //     speed_t    c_ospeed;
        // }

        // dfmt off
        const uint sz =
            KindSize +
            DataSize!(MsgpackType.uint32) +

            DataSize!(MsgpackType.uint32) +
            DataSize!(MsgpackType.uint32) +
            DataSize!(MsgpackType.uint32) +
            DataSize!(MsgpackType.uint32) +
            DataSize!(MsgpackType.uint8) +
            DataSize!(MsgpackType.array16) + t.value.c_cc.length +
            DataSize!(MsgpackType.uint32) +
            DataSize!(MsgpackType.uint32);
        // dfmt on

        pack(Kind.terminalCapability);
        pack!(MsgpackType.uint32)(sz);
        pack!(MsgpackType.uint32)(t.value.c_iflag);
        pack!(MsgpackType.uint32)(t.value.c_oflag);
        pack!(MsgpackType.uint32)(t.value.c_cflag);
        pack!(MsgpackType.uint32)(t.value.c_lflag);
        pack!(MsgpackType.uint8)(t.value.c_line);
        packArray(t.value.c_cc[]);
        pack!(MsgpackType.uint32)(t.value.c_ispeed);
        pack!(MsgpackType.uint32)(t.value.c_ospeed);
    }
}
195 
/** Decodes the messages produced by `Serialize` from a buffer of received bytes.
 *
 * Bytes are appended with `put` and decoded one message at a time with
 * `unpack`. All lengths read from the wire are checked against the buffer
 * before it is sliced, so malformed input results in an `Exception` instead of
 * an out of bounds access.
 */
struct Deserialize {
    import std.conv : to;

    /// Every type that `unpack` can return.
    alias Result = SumType!(None, HeartBeat, ProtocolEnv, ConfDone, Command,
            Workdir, Key, TerminalCapability);

    /// Received bytes that have not been decoded yet.
    ubyte[] buf;

    /// Append received bytes to the buffer.
    void put(const ubyte[] v) {
        buf ~= v;
    }

    /** Decode the message at the front of the buffer.
     *
     * Leading bytes that can not be a message tag are dropped first.
     *
     * Returns: the decoded message. A default initialized `Result` is returned
     * when fewer than `KindSize` bytes are buffered or the tag is `Kind.none`.
     * When the tag of a message with payload is present but the payload is
     * incomplete nothing is consumed and the `.init` value of that message
     * type is returned.
     *
     * Throws: `Exception` when the buffered bytes do not follow the format
     * written by `Serialize`.
     */
    Result unpack() {
        cleanupUntilKind();

        Result rval;
        if (buf.length < KindSize)
            return rval;

        const k = () {
            auto raw = peek!(MsgpackType.uint8, ubyte)();
            if (raw > Kind.max)
                return Kind.none;
            return cast(Kind) raw;
        }();

        debug logger.tracef("%-(%X, %)", buf);

        final switch (k) {
        case Kind.none:
            consume!(MsgpackType.uint8);
            return rval;
        case Kind.heartBeat:
            consume!(MsgpackType.uint8);
            rval = HeartBeat.init;
            break;
        case Kind.environment:
            rval = unpackProtocolEnv;
            break;
        case Kind.confDone:
            consume!(MsgpackType.uint8);
            rval = ConfDone.init;
            break;
        case Kind.command:
            rval = unpackCommand;
            break;
        case Kind.workdir:
            rval = unpackWorkdir;
            break;
        case Kind.key:
            rval = unpackKey;
            break;
        case Kind.terminalCapability:
            rval = unpackTerminalCapability;
            break;
        }

        return rval;
    }

    /** Check that the message at the front of the buffer is complete.
     *
     * Reads the uint32 that follows the tag and compares it with the number
     * of buffered bytes. Nothing is consumed.
     *
     * Returns: true if at least as many bytes as the message announces are buffered.
     */
    private bool enoughData() {
        const hdrTotalSz = KindSize + DataSize!(MsgpackType.uint32);
        if (buf.length < hdrTotalSz)
            return false;

        const totalSz = () {
            auto s = buf[KindSize .. $];
            return peek!(MsgpackType.uint32, uint)(s);
        }();

        debug logger.trace("Bytes to unpack: ", totalSz);

        return buf.length >= totalSz;
    }

    /** Consume from the buffer until a valid kind is found.
     */
    private void cleanupUntilKind() nothrow {
        while (buf.length != 0) {
            if (buf.length < KindSize)
                break;

            try {
                auto raw = peek!(MsgpackType.uint8, ubyte)();
                if (raw <= Kind.max)
                    break;
                debug logger.trace("dropped ", raw);
            } catch (Exception e) {
                // the front is not a msgpack uint8, thus not a tag. Drop it.
            }

            buf = buf[1 .. $];
        }
    }

    /** Decode an environment message.
     *
     * A key or value that is not valid UTF-8 is consumed and replaced by an
     * empty string. Decoding stops at the first pair that is not correctly
     * framed; the pairs decoded up to that point are returned.
     */
    private ProtocolEnv unpackProtocolEnv() {
        import std.utf : UTFException;

        if (!enoughData)
            return typeof(return)();

        // all data is received, start unpacking
        ProtocolEnv env;
        demux!(MsgpackType.uint8, ubyte);
        demux!(MsgpackType.uint32, uint);

        // may contain invalid utf8 chars but still have to consume everything.
        // demux!string consumes the bytes before it validates them.
        string lenientString() {
            try {
                return demux!string();
            } catch (UTFException e) {
                return null;
            }
        }

        const kv_pairs = demux!(MsgpackType.uint32, uint);
        foreach (_; 0 .. kv_pairs) {
            string key;
            string value;

            try {
                key = lenientString();
                value = lenientString();
            } catch (Exception e) {
                // The framing is broken so nothing was consumed. Continuing
                // would fail in the same way for all the remaining pairs.
                debug logger.trace("Malformed environment entry: ", e.msg);
                break;
            }

            env ~= EnvVariable(key, value);
        }

        return typeof(return)(env);
    }

    /// Decode a command message.
    private Command unpackCommand() {
        if (!enoughData)
            return Command.init;

        // all data is received, start unpacking
        demux!(MsgpackType.uint8, ubyte);
        demux!(MsgpackType.uint32, uint);

        Command cmd;
        const elems = demux!(MsgpackType.uint32, uint);
        foreach (_; 0 .. elems) {
            cmd.value ~= demux!string();
        }

        return cmd;
    }

    /// Decode a working directory message.
    private Workdir unpackWorkdir() {
        if (!enoughData)
            return Workdir.init;

        // all data is received, start unpacking
        demux!(MsgpackType.uint8, ubyte);
        demux!(MsgpackType.uint32, uint);

        return Workdir(demux!string);
    }

    /** Decode a key stroke message.
     *
     * The returned `Key.value` is a slice of the buffer, not a copy.
     */
    private Key unpackKey() {
        import std.exception : enforce;

        if (!enoughData)
            return Key.init;

        // all data is received, start unpacking
        demux!(MsgpackType.uint8, ubyte);
        demux!(MsgpackType.uint32, uint);

        Key key;
        const elems = demux!(MsgpackType.array16, ushort);
        enforce(elems <= buf.length, "key: array is longer than the received data");
        key.value = buf[0 .. elems];
        consume(elems);

        return key;
    }

    /// Decode a terminal capability message, field by field in the order of `termios`.
    private TerminalCapability unpackTerminalCapability() {
        import std.exception : enforce;

        if (!enoughData)
            return TerminalCapability.init;

        // all data is received, start unpacking
        demux!(MsgpackType.uint8, ubyte);
        demux!(MsgpackType.uint32, uint);

        TerminalCapability t;

        t.value.c_iflag = demux!(MsgpackType.uint32, uint);
        t.value.c_oflag = demux!(MsgpackType.uint32, uint);
        t.value.c_cflag = demux!(MsgpackType.uint32, uint);
        t.value.c_lflag = demux!(MsgpackType.uint32, uint);
        t.value.c_line = demux!(MsgpackType.uint8, ubyte);

        const elems = demux!(MsgpackType.array16, ushort);
        // c_cc has a fixed length, assigning a slice of any other length is an error.
        enforce(elems == t.value.c_cc.length, "terminal capability: unexpected length of c_cc");
        enforce(elems <= buf.length, "terminal capability: c_cc is longer than the received data");
        t.value.c_cc = buf[0 .. elems];
        consume(elems);

        t.value.c_ispeed = demux!(MsgpackType.uint32, uint);
        t.value.c_ospeed = demux!(MsgpackType.uint32, uint);

        return t;
    }

private:
    /// Drop the encoded size of `type` from the front of the buffer.
    void consume(MsgpackType type)() {
        buf = buf[DataSize!type .. $];
    }

    /// Drop `len` bytes from the front of the buffer.
    void consume(size_t len) {
        buf = buf[len .. $];
    }

    /// Decode a value of msgpack type `Type` at the front of the buffer without consuming it.
    T peek(MsgpackType Type, T)() {
        return peek!(Type, T)(buf);
    }

    /** Decode a value of msgpack type `Type` at the front of `buf` without consuming it.
     *
     * Throws: `Exception` if `buf` is too short or does not start with a `Type`.
     */
    static T peek(MsgpackType Type, T)(ref ubyte[] buf) {
        import std.exception : enforce;

        enforce(buf.length >= DataSize!Type, "not enough data to decode the value");
        enforce(getType(buf[0]) == Type);
        T v = parseType!Type(buf[0 .. DataSize!Type]);

        return v;
    }

    /** Decode and consume a value of msgpack type `Type` at the front of the buffer.
     *
     * Throws: `Exception` if the buffer is too short or does not start with a
     * `Type`. Nothing is consumed in that case.
     */
    T demux(MsgpackType Type, T)() {
        import std.exception : enforce;

        enforce(buf.length >= DataSize!Type, "not enough data to decode the value");
        enforce(getType(buf[0]) == Type);
        T v = parseType!Type(buf[0 .. DataSize!Type]);
        consume!Type();

        return v;
    }

    /** Decode and consume a msgpack str32 at the front of the buffer.
     *
     * Throws: `Exception` if the buffer does not start with a str32 or if the
     * string is longer than the buffered data; nothing is consumed in that
     * case. The exception thrown by `std.utf.validate` if the bytes are not
     * valid UTF-8; the string is consumed in that case.
     */
    string demux(T)() if (is(T == string)) {
        import std.exception : enforce;
        import std.utf : validate;

        enforce(buf.length >= DataSize!(MsgpackType.str32), "not enough data to decode a string header");
        enforce(getType(buf[0]) == MsgpackType.str32);
        auto len = parseType!(MsgpackType.str32)(buf[0 .. DataSize!(MsgpackType.str32)]);

        // 2^32-1 according to the standard
        enforce(len < int.max);
        enforce(len <= buf.length - DataSize!(MsgpackType.str32),
                "string is longer than the received data");
        consume!(MsgpackType.str32);

        char[] raw = cast(char[]) buf[0 .. len];
        // consume before validating so that invalid text is skipped by the caller
        consume(len);
        validate(raw);

        return raw.idup;
    }
}
489 
/// Empty marker type. It is the first member of `Deserialize.Result`.
struct None {
}
492 
/// Message without payload. It is sent as a lone `Kind.heartBeat` tag.
struct HeartBeat {
}
495 
/// Message without payload. It is sent as a lone `Kind.confDone` tag.
struct ConfDone {
}
498 
/// One key/value pair of a `ProtocolEnv`.
struct EnvVariable {
    /// Key of the pair.
    string key;
    /// Value of the pair.
    string value;
}
503 
/// Payload of a `Kind.environment` message, a list of key/value pairs.
struct ProtocolEnv {
    /// The pairs, in the order they are sent.
    EnvVariable[] value;
    // allows a ProtocolEnv to be used as the array it wraps (length, foreach, ~=).
    alias value this;
}
508 
/// Payload of a `Kind.command` message.
struct Command {
    /// The strings of the command; each one is sent as a separate string.
    string[] value;
}
512 
/// Payload of a `Kind.workdir` message.
struct Workdir {
    /// The working directory, sent as a single string.
    string value;
}
516 
/// Payload of a `Kind.key` message.
struct Key {
    /// The raw bytes of the message, sent as a length prefixed byte sequence.
    const(ubyte)[] value;
}
520 
/// Payload of a `Kind.terminalCapability` message, wraps a POSIX `termios`.
struct TerminalCapability {
    import core.sys.posix.termios;

    /// The terminal settings that are transferred.
    termios value;
}
526 
@("shall pack and unpack a HeartBeat")
unittest {
    // encode a heart beat into an in-memory sink
    auto sink = appender!(ubyte[])();
    auto packer = Serialize!(typeof(sink))(sink);

    packer.pack!HeartBeat;
    assert(sink.data.length > 0);

    // decoding the bytes must give a HeartBeat, every other type is a failure
    auto unpacker = Deserialize(sink.data);
    // dfmt off
    unpacker.unpack.match!(
        (None msg) { assert(false); },
        (HeartBeat msg) { assert(true); },
        (ProtocolEnv msg) { assert(false); },
        (ConfDone msg) { assert(false); },
        (Command msg) { assert(false); },
        (Workdir msg) { assert(false); },
        (Key msg) { assert(false); },
        (TerminalCapability msg) { assert(false); });
    // dfmt on
}
544 
@("shall clean the buffer until a valid kind is found")
unittest {
    // one byte of garbage followed by a heart beat
    auto sink = appender!(ubyte[])();
    sink.put(cast(ubyte) 42);
    auto packer = Serialize!(typeof(sink))(sink);
    packer.pack!HeartBeat;

    auto unpacker = Deserialize(sink.data);
    assert(unpacker.buf.length == 3);

    // only the garbage is dropped, the heart beat stays in the buffer
    unpacker.cleanupUntilKind;
    assert(unpacker.buf.length == 2);
}
557 
@("shall pack and unpack an environment")
unittest {
    auto app = appender!(ubyte[])();
    auto ser = Serialize!(typeof(app))(app);

    auto expected = [EnvVariable("foo", "bar")];
    ser.pack(ProtocolEnv(expected));
    assert(app.data.length > 0);

    auto deser = Deserialize(app.data);
    // The decoded content is checked too, not only the type of the message.
    // dfmt off
    deser.unpack.match!(
        (None x) { assert(false); },
        (HeartBeat x) { assert(false); },
        (ProtocolEnv x) { logger.trace(x); assert(x.value == expected); },
        (ConfDone x) { assert(false); },
        (Command x) { assert(false); },
        (Workdir x) { assert(false); },
        (Key x) { assert(false); },
        (TerminalCapability x) { assert(false); });
    // dfmt on

    // the whole message shall have been consumed
    assert(deser.buf.length == 0);
}
575 
@("shall pack and unpack a key")
unittest {
    auto app = appender!(ubyte[])();
    auto ser = Serialize!(typeof(app))(app);

    const(ubyte)[] expected = [1, 2, 3];
    ser.pack(Key(expected));
    assert(app.data.length > 0);

    auto deser = Deserialize(app.data);
    // The decoded bytes are checked too, not only the type of the message.
    // dfmt off
    deser.unpack.match!(
        (None x) { assert(false); },
        (HeartBeat x) { assert(false); },
        (ProtocolEnv x) { assert(false); },
        (ConfDone x) { assert(false); },
        (Command x) { assert(false); },
        (Workdir x) { assert(false); },
        (Key x) { logger.trace(x); assert(x.value == expected); },
        (TerminalCapability x) { assert(false); });
    // dfmt on

    // the whole message shall have been consumed
    assert(deser.buf.length == 0);
}
592 
@("shall pack and unpack a termio")
unittest {
    import core.sys.posix.termios;

    auto app = appender!(ubyte[])();
    auto ser = Serialize!(typeof(app))(app);

    termios mode;
    if (tcgetattr(0, &mode) != 0) {
        // stdin is not a terminal, e.g. the tests are run by a CI server or
        // with redirected input. The round trip does not depend on the
        // settings being real, so fall back to hand made ones instead of
        // failing the test.
        mode = termios.init;
        mode.c_iflag = 0x0001;
        mode.c_oflag = 0x0020;
        mode.c_cflag = 0x0300;
        mode.c_lflag = 0x4000;
        mode.c_line = 1;
        foreach (i, ref c; mode.c_cc)
            c = cast(cc_t) i;
        mode.c_ispeed = 9600;
        mode.c_ospeed = 19200;
    }

    ser.pack(TerminalCapability(mode));
    assert(app.data.length > 0);

    auto deser = Deserialize(app.data);
    // dfmt off
    deser.unpack.match!(
        (None x) { assert(false); },
        (HeartBeat x) { assert(false); },
        (ProtocolEnv x) { assert(false); },
        (ConfDone x) { assert(false); },
        (Command x) { assert(false); },
        (Workdir x) { assert(false); },
        (Key x) { assert(false); },
        (TerminalCapability x) { assert(x.value == mode); });
    // dfmt on
}