1 /**
2 Copyright: Copyright (c) 2018, Joakim Brännström. All rights reserved.
3 License: $(LINK2 http://www.boost.org/LICENSE_1_0.txt, Boost Software License 1.0)
4 Author: Joakim Brännström (joakim.brannstrom@gmx.com)
5 
6 This module defines the protocol for data transfer and functionality to use it.
7 */
8 module distssh.protocol;
9 
10 import std.array : appender;
11 import std.range : put;
12 import logger = std.experimental.logger;
13 import msgpack_ll;
14 
/// Identifies the type of a packet. Every packet written by `Serialize`
/// starts with its `Kind`.
enum Kind : ubyte {
    /// Not a valid packet kind. `Deserialize.nextKind` also reports it when
    /// too few bytes are buffered to decode a kind.
    none = 0,
    /// A packet without payload.
    heartBeat = 1,
    /// A packet carrying a `ProtocolEnv`.
    environment = 2,
    /// A packet carrying a `RemoteHost`.
    remoteHost = 3,
}
21 
/// Number of bytes a packed `Kind` occupies. Kinds are always written as a
/// fixed-size msgpack uint8, so the kind header of every packet has this length.
enum KindSize = DataSize!(MsgpackType.uint8);
23 
/** Serialize packets into the output range `WriterT`.
 *
 * Every packet starts with its `Kind` followed by the packet specific payload.
 */
struct Serialize(WriterT) {
@safe:

    /// Output range the packed bytes are written to.
    WriterT w;

    /// Write `k` as a msgpack uint8.
    void pack(Kind k) {
        ubyte[KindSize] pkgtype;
        formatType!(MsgpackType.uint8)(k, pkgtype);
        put(w, pkgtype[]);
    }

    /** Write `s` as a msgpack str32: a 32-bit length header followed by the raw bytes.
     *
     * Throws: Exception if `s` is longer than a str32 can describe (2^32-1 bytes).
     */
    void pack(const string s) {
        import std.exception : enforce;

        // The str32 length field is 32 bits wide. Silently truncating the
        // length would desynchronize the receiver, so reject instead.
        enforce(s.length <= uint.max, "String too long to be packed as a msgpack str32");

        ubyte[DataSize!(MsgpackType.str32)] hdr;
        formatType!(MsgpackType.str32)(cast(uint) s.length, hdr);
        put(w, hdr[]);
        put(w, cast(const(ubyte)[]) s);
    }

    /// Write `v` encoded as the msgpack type `Type`.
    void pack(MsgpackType Type, T)(T v) {
        ubyte[DataSize!Type] buf;
        formatType!Type(v, buf);
        put(w, buf[]);
    }

    /// Write a heartbeat packet, which consists of only its kind.
    void pack(T)() if (is(T == HeartBeat)) {
        pack(Kind.heartBeat);
    }

    /** Write an environment packet.
     *
     * Layout: kind, total packet size in bytes (uint32), number of pairs
     * (uint32), then key and value as str32 for every pair. The total size
     * lets the receiver wait until the whole packet has arrived.
     *
     * Throws: Exception if the packet size does not fit in a uint32.
     */
    void pack(const ProtocolEnv env) {
        import std.algorithm : map, sum;
        import std.exception : enforce;

        // dfmt off
        const tot_size =
            KindSize +
            DataSize!(MsgpackType.uint32) +
            DataSize!(MsgpackType.uint32) +
            env.value.map!(a => 2*DataSize!(MsgpackType.str32) + a.key.length + a.value.length).sum;
        // dfmt on

        // The size is transmitted as a uint32; reject instead of silently truncating.
        enforce(tot_size <= uint.max, "Environment too large to be packed");

        pack(Kind.environment);
        pack!(MsgpackType.uint32)(cast(uint) tot_size);
        pack!(MsgpackType.uint32)(cast(uint) env.length);

        foreach (const kv; env) {
            pack(kv.key);
            pack(kv.value);
        }
    }

    /// Write a remote host packet: the kind followed by the address as a str32.
    void pack(const RemoteHost host) {
        pack(Kind.remoteHost);
        pack(host.address);
    }
}
83 
/** Incrementally decode packets produced by `Serialize`.
 *
 * Received bytes are appended with `put`. The `unpack` functions consume a
 * packet only when it is of the requested kind and has been completely
 * received. Otherwise they return null and leave the buffer untouched.
 *
 * Malformed data results in an `Exception` rather than an out of bounds error.
 */
struct Deserialize {
    import std.conv : to;
    import std.typecons : Nullable;

    /// Bytes received but not yet consumed.
    ubyte[] buf;

    /// Append received bytes to the buffer.
    void put(const ubyte[] v) {
        buf ~= v;
    }

    /** Consume from the buffer until a valid kind is found.
     *
     * Bytes are dropped one at a time until the front of the buffer decodes
     * as a msgpack uint8 holding a known, non-`none` `Kind`, or until fewer
     * than `KindSize` bytes remain.
     */
    void cleanupUntilKind() nothrow {
        while (buf.length >= KindSize) {
            try {
                auto raw = peek!(MsgpackType.uint8, ubyte)();
                if (raw <= Kind.max && raw != Kind.none)
                    break;
                debug logger.trace("dropped ", raw);
            }
            catch (Exception e) {
                // not a msgpack uint8, drop the byte
            }

            buf = buf[1 .. $];
        }
    }

    /** Peek at the kind of the next packet without consuming it.
     *
     * Returns: `Kind.none` when fewer than `KindSize` bytes are buffered.
     * Throws: Exception if the front of the buffer is not a msgpack uint8 or
     * holds a value outside of `Kind`.
     */
    Kind nextKind() {
        if (buf.length < KindSize)
            return Kind.none;
        auto raw = peek!(MsgpackType.uint8, ubyte)();
        if (raw > Kind.max)
            throw new Exception("Malformed packet kind: " ~ raw.to!string);
        return cast(Kind) raw;
    }

    /// Unpack a heartbeat. Packets of another kind are left in the buffer.
    Nullable!HeartBeat unpack(T)() if (is(T == HeartBeat)) {
        // Check before consuming so that e.g. an environment packet is not
        // destroyed by stripping its kind.
        if (nextKind != Kind.heartBeat)
            return typeof(return)();

        demux!(MsgpackType.uint8, ubyte);
        return typeof(return)(HeartBeat());
    }

    /** Unpack an environment packet once it has been completely received.
     *
     * Keys or values that are not valid UTF-8 are replaced by empty strings.
     *
     * Throws: Exception if the packet is malformed.
     */
    Nullable!ProtocolEnv unpack(T)() if (is(T == ProtocolEnv)) {
        import std.utf : UTFException;

        if (nextKind != Kind.environment)
            return typeof(return)();

        const kind_totsize = KindSize + DataSize!(MsgpackType.uint32);
        if (buf.length < kind_totsize)
            return typeof(return)();

        // the total packet size follows directly after the kind
        const tot_size = () {
            auto s = buf[KindSize .. $];
            return peek!(MsgpackType.uint32, uint)(s);
        }();

        debug logger.trace("Bytes to unpack: ", tot_size);

        if (buf.length < tot_size)
            return typeof(return)();

        // all data is received, start unpacking
        ProtocolEnv env;
        demux!(MsgpackType.uint8, ubyte);
        demux!(MsgpackType.uint32, uint);

        const kv_pairs = demux!(MsgpackType.uint32, uint);
        for (uint i; i < kv_pairs; ++i) {
            string key;
            string value;

            // Invalid UTF-8 is tolerated because demux!string consumes the
            // bytes before validating them, so the stream stays aligned. Any
            // other error means the packet is malformed and is propagated.
            // Swallowing it would leave the stream misaligned for every
            // remaining pair.
            try {
                key = demux!string();
            }
            catch (UTFException e) {
            }

            try {
                value = demux!string();
            }
            catch (UTFException e) {
            }

            env ~= EnvVariable(key, value);
        }

        return typeof(return)(env);
    }

    /** Unpack a remote host packet once it has been completely received.
     *
     * Returns: null if the next packet is of another kind, if it is incomplete
     * (nothing is consumed in both cases), or if the address failed to decode
     * (the packet is consumed).
     */
    Nullable!RemoteHost unpack(T)() if (is(T == RemoteHost)) {
        if (nextKind != Kind.remoteHost)
            return typeof(return)();

        const hdr_size = KindSize + DataSize!(MsgpackType.str32);
        if (buf.length < hdr_size)
            return typeof(return)();

        // the address length follows directly after the kind
        const str_len = () {
            auto s = buf[KindSize .. $];
            return peek!(MsgpackType.str32, uint)(s);
        }();

        // wait for the rest of the packet instead of reading past the buffer
        if (buf.length < hdr_size + str_len)
            return typeof(return)();

        // strip the kind
        demux!(MsgpackType.uint8, ubyte);

        try {
            auto host = RemoteHost(demux!string());
            return typeof(return)(host);
        }
        catch (Exception e) {
        }

        return typeof(return)();
    }

    @("shall not consume a packet of another kind nor a partially received one")
    unittest {
        auto app = appender!(ubyte[])();
        auto ser = Serialize!(typeof(app))(app);
        ser.pack(RemoteHost("localhost"));
        const full = app.data.dup;

        // only the first half of the packet has arrived
        auto deser = Deserialize(full[0 .. $ / 2].dup);
        assert(deser.unpack!HeartBeat().isNull);
        assert(deser.unpack!RemoteHost().isNull);
        assert(deser.buf.length == full.length / 2);

        // the rest arrives
        deser.put(full[$ / 2 .. $]);
        auto host = deser.unpack!RemoteHost();
        assert(!host.isNull);
        assert(host.get.address == "localhost");
        assert(deser.buf.length == 0);
    }

private:
    void consume(MsgpackType type)() {
        buf = buf[DataSize!type .. $];
    }

    void consume(size_t len) {
        buf = buf[len .. $];
    }

    /// Decode the front of the internal buffer as `Type` without consuming it.
    T peek(MsgpackType Type, T)() {
        return peek!(Type, T)(buf);
    }

    /** Decode the front of `buf` as `Type` without consuming it.
     *
     * Throws: Exception if `buf` is too short or the front is not of `Type`.
     */
    static T peek(MsgpackType Type, T)(ref ubyte[] buf) {
        import std.exception : enforce;

        // checked explicitly so malformed input throws instead of raising a RangeError
        enforce(buf.length >= DataSize!Type, "Not enough data to decode the msgpack type");
        enforce(getType(buf[0]) == Type);
        T v = parseType!Type(buf[0 .. DataSize!Type]);

        return v;
    }

    /// Decode the front of the buffer as `Type` and consume it.
    T demux(MsgpackType Type, T)() {
        T v = peek!(Type, T)();
        consume!Type();

        return v;
    }

    /** Decode and consume a msgpack str32.
     *
     * The string bytes are consumed before they are validated so that invalid
     * UTF-8 does not misalign the stream.
     *
     * Throws: UTFException if the string is not valid UTF-8. Exception if the
     * data is not a str32 or is shorter than the encoded length.
     */
    string demux(T)() if (is(T == string)) {
        import std.exception : enforce;
        import std.utf : validate;

        const len = peek!(MsgpackType.str32, uint)();

        // 2^32-1 according to the standard
        enforce(len < int.max);
        enforce(buf.length >= DataSize!(MsgpackType.str32) + len,
                "Not enough data to decode the string");
        consume!(MsgpackType.str32);

        char[] raw = cast(char[]) buf[0 .. len];
        consume(len);
        validate(raw);

        return raw.idup;
    }
}
249 
/// Packet without payload; it is serialized as only its `Kind`.
struct HeartBeat {
}

/// A key/value pair transferred in an environment packet.
struct EnvVariable {
    string key, value;
}

/// Payload of a `Kind.environment` packet.
struct ProtocolEnv {
    EnvVariable[] value;
    // lets a ProtocolEnv be used directly as its EnvVariable[] (length, indexing, ~=, foreach)
    alias value this;
}

/// Payload of a `Kind.remoteHost` packet.
struct RemoteHost {
    /// Address of the remote host. NOTE(review): format (hostname/IP) not checked here.
    string address;
}
266 
@("shall pack and unpack a HeartBeat")
unittest {
    // serialize a single heartbeat into an in-memory buffer
    auto sink = appender!(ubyte[])();
    auto serializer = Serialize!(typeof(sink))(sink);
    serializer.pack!HeartBeat;
    assert(sink.data.length > 0);

    // and read it back
    auto deserializer = Deserialize(sink.data);
    assert(deserializer.nextKind == Kind.heartBeat);
    assert(!deserializer.unpack!HeartBeat().isNull);
}
280 
@("shall clean the buffer until a valid kind is found")
unittest {
    auto sink = appender!(ubyte[])();
    // a leading byte that does not decode as a packet kind
    sink.put(cast(ubyte) 42);
    auto serializer = Serialize!(typeof(sink))(sink);
    serializer.pack!HeartBeat;

    auto deserializer = Deserialize(sink.data);
    assert(deserializer.buf.length == 3);
    deserializer.cleanupUntilKind;
    // only the junk byte is dropped, the heartbeat packet remains
    assert(deserializer.buf.length == 2);
}
293 
@("shall pack and unpack an environment")
unittest {
    auto app = appender!(ubyte[])();
    auto ser = Serialize!(typeof(app))(app);

    ser.pack(ProtocolEnv([EnvVariable("foo", "bar")]));
    logger.trace(app.data);
    logger.trace(app.data.length);
    assert(app.data.length > 0);

    auto deser = Deserialize(app.data);
    assert(deser.nextKind == Kind.environment);

    auto env = deser.unpack!ProtocolEnv;
    logger.trace(env);

    // the round trip must reproduce the packed environment exactly
    assert(!env.isNull);
    assert(env.get.length == 1);
    assert(env.get[0] == EnvVariable("foo", "bar"));
    // and the whole packet must have been consumed
    assert(deser.buf.length == 0);
}