]> git.meshlink.io Git - meshlink/blob - src/sptps.c
Move the routing header out of the SPTPS payload.
[meshlink] / src / sptps.c
1 /*
2     sptps.c -- Simple Peer-to-Peer Security
3     Copyright (C) 2014-2017 Guus Sliepen <guus@meshlink.io>
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "system.h"
21
22 #include "chacha-poly1305/chacha-poly1305.h"
23 #include "crypto.h"
24 #include "ecdh.h"
25 #include "ecdsa.h"
26 #include "logger.h"
27 #include "prf.h"
28 #include "sptps.h"
29
30 unsigned int sptps_replaywin = 32;
31
32 /*
33    Nonce MUST be exchanged first (done)
34    Signatures MUST be done over both nonces, to guarantee the signature is fresh
35    Otherwise: if ECDHE key of one side is compromised, it can be reused!
36
37    Add explicit tag to beginning of structure to distinguish the client and server when signing. (done)
38
39    Sign all handshake messages up to ECDHE kex with long-term public keys. (done)
40
41    HMACed KEX finished message to prevent downgrade attacks and prove you have the right key material (done by virtue of ECDSA over the whole ECDHE exchange?)
42
43    Explicit close message needs to be added.
44
45    Maybe do add some alert messages to give helpful error messages? Not more than TLS sends.
46
47    Use counter mode instead of OFB. (done)
48
49    Make sure ECC operations are fixed time (aka prevent side-channel attacks).
50 */
51
52 void sptps_log_quiet(sptps_t *s, int s_errno, const char *format, va_list ap) {
53         (void)s;
54         (void)s_errno;
55         (void)format;
56         (void)ap;
57 }
58
59 void sptps_log_stderr(sptps_t *s, int s_errno, const char *format, va_list ap) {
60         (void)s;
61         (void)s_errno;
62         vfprintf(stderr, format, ap);
63         fputc('\n', stderr);
64 }
65
66 void (*sptps_log)(sptps_t *s, int s_errno, const char *format, va_list ap) = sptps_log_stderr;
67
68 // Log an error message.
69 static bool error(sptps_t *s, int s_errno, const char *format, ...) {
70         if(format) {
71                 va_list ap;
72                 va_start(ap, format);
73                 sptps_log(s, s_errno, format, ap);
74                 va_end(ap);
75         }
76
77         errno = s_errno;
78         return false;
79 }
80
81 static void warning(sptps_t *s, const char *format, ...) {
82         va_list ap;
83         va_start(ap, format);
84         sptps_log(s, 0, format, ap);
85         va_end(ap);
86 }
87
88 // Send a record (datagram version, accepts all record types, handles encryption and authentication).
89 static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
90         char buffer[len + 21UL];
91
92         // Create header with sequence number, length and record type
93         uint32_t seqno = s->outseqno++;
94         uint32_t netseqno = ntohl(seqno);
95
96         memcpy(buffer, &netseqno, 4);
97         buffer[4] = type;
98         memcpy(buffer + 5, data, len);
99
100         if(s->outstate) {
101                 // If first handshake has finished, encrypt and HMAC
102                 chacha_poly1305_encrypt(s->outcipher, seqno, buffer + 4, len + 1, buffer + 4, NULL);
103                 return s->send_data(s->handle, type, buffer, len + 21UL);
104         } else {
105                 // Otherwise send as plaintext
106                 return s->send_data(s->handle, type, buffer, len + 5UL);
107         }
108 }
109
110 // Send a record (private version, accepts all record types, handles encryption and authentication).
111 static bool send_record_priv(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
112         if(s->datagram) {
113                 return send_record_priv_datagram(s, type, data, len);
114         }
115
116         char buffer[len + 19UL];
117
118         // Create header with sequence number, length and record type
119         uint32_t seqno = s->outseqno++;
120         uint16_t netlen = htons(len);
121
122         memcpy(buffer, &netlen, 2);
123         buffer[2] = type;
124         memcpy(buffer + 3, data, len);
125
126         if(s->outstate) {
127                 // If first handshake has finished, encrypt and HMAC
128                 chacha_poly1305_encrypt(s->outcipher, seqno, buffer + 2, len + 1, buffer + 2, NULL);
129                 return s->send_data(s->handle, type, buffer, len + 19UL);
130         } else {
131                 // Otherwise send as plaintext
132                 return s->send_data(s->handle, type, buffer, len + 3UL);
133         }
134 }
135
136 // Send an application record.
137 bool sptps_send_record(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
138         // Sanity checks: application cannot send data before handshake is finished,
139         // and only record types 0..127 are allowed.
140         if(!s->outstate) {
141                 return error(s, EINVAL, "Handshake phase not finished yet");
142         }
143
144         if(type >= SPTPS_HANDSHAKE) {
145                 return error(s, EINVAL, "Invalid application record type");
146         }
147
148         return send_record_priv(s, type, data, len);
149 }
150
151 // Pass through unencrypted data.
152 bool sptps_send_unencrypted(sptps_t *s, const void *data, uint16_t len) {
153         // Sanity checks: application cannot send data before handshake is finished,
154         // and only non-datagram allowed.
155         if(!s->outstate) {
156                 return error(s, EINVAL, "Handshake phase not finished yet");
157         }
158
159         if(s->datagram) {
160                 return error(s, EINVAL, "Not allowed for datagrams");
161         }
162
163         return s->send_data(s->handle, SPTPS_UNENCRYPTED, data, len);
164 }
165
166 // Expect a given number of unencrypted bytes.
167 bool sptps_expect_unencrypted(sptps_t *s, uint16_t len) {
168         // Sanity checks: application cannot send data before handshake is finished,
169         // and only non-datagram allowed.
170         if(!s->instate) {
171                 return error(s, EINVAL, "Handshake phase not finished yet");
172         }
173
174         if(s->datagram) {
175                 return error(s, EINVAL, "Not allowed for datagrams");
176         }
177
178         s->reclen = len;
179         s->passthrough = true;
180         return true;
181 }
182
183 // Send a Key EXchange record, containing a random nonce and an ECDHE public key.
184 static bool send_kex(sptps_t *s) {
185         size_t keylen = ECDH_SIZE;
186
187         // Make room for our KEX message, which we will keep around since send_sig() needs it.
188         if(s->mykex) {
189                 return false;
190         }
191
192         s->mykex = realloc(s->mykex, 1 + 32 + keylen);
193
194         if(!s->mykex) {
195                 return error(s, errno, strerror(errno));
196         }
197
198         // Set version byte to zero.
199         s->mykex[0] = SPTPS_VERSION;
200
201         // Create a random nonce.
202         randomize(s->mykex + 1, 32);
203
204         // Create a new ECDH public key.
205         if(!(s->ecdh = ecdh_generate_public(s->mykex + 1 + 32))) {
206                 return error(s, EINVAL, "Failed to generate ECDH public key");
207         }
208
209         return send_record_priv(s, SPTPS_HANDSHAKE, s->mykex, 1 + 32 + keylen);
210 }
211
212 // Send a SIGnature record, containing an ECDSA signature over both KEX records.
213 static bool send_sig(sptps_t *s) {
214         size_t keylen = ECDH_SIZE;
215         size_t siglen = ecdsa_size(s->mykey);
216
217         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator, plus label
218         char msg[(1 + 32 + keylen) * 2 + 1 + s->labellen];
219         char sig[siglen];
220
221         msg[0] = s->initiator;
222         memcpy(msg + 1, s->mykex, 1 + 32 + keylen);
223         memcpy(msg + 1 + 33 + keylen, s->hiskex, 1 + 32 + keylen);
224         memcpy(msg + 1 + 2 * (33 + keylen), s->label, s->labellen);
225
226         // Sign the result.
227         if(!ecdsa_sign(s->mykey, msg, sizeof(msg), sig)) {
228                 return error(s, EINVAL, "Failed to sign SIG record");
229         }
230
231         // Send the SIG exchange record.
232         return send_record_priv(s, SPTPS_HANDSHAKE, sig, sizeof(sig));
233 }
234
235 // Generate key material from the shared secret created from the ECDHE key exchange.
236 static bool generate_key_material(sptps_t *s, const char *shared, size_t len) {
237         // Initialise cipher and digest structures if necessary
238         if(!s->outstate) {
239                 s->incipher = chacha_poly1305_init();
240                 s->outcipher = chacha_poly1305_init();
241
242                 if(!s->incipher || !s->outcipher) {
243                         return error(s, EINVAL, "Failed to open cipher");
244                 }
245         }
246
247         // Allocate memory for key material
248         size_t keylen = 2 * CHACHA_POLY1305_KEYLEN;
249
250         s->key = realloc(s->key, keylen);
251
252         if(!s->key) {
253                 return error(s, errno, strerror(errno));
254         }
255
256         // Create the HMAC seed, which is "key expansion" + session label + server nonce + client nonce
257         char seed[s->labellen + 64 + 13];
258         strcpy(seed, "key expansion");
259
260         if(s->initiator) {
261                 memcpy(seed + 13, s->mykex + 1, 32);
262                 memcpy(seed + 45, s->hiskex + 1, 32);
263         } else {
264                 memcpy(seed + 13, s->hiskex + 1, 32);
265                 memcpy(seed + 45, s->mykex + 1, 32);
266         }
267
268         memcpy(seed + 77, s->label, s->labellen);
269
270         // Use PRF to generate the key material
271         if(!prf(shared, len, seed, s->labellen + 64 + 13, s->key, keylen)) {
272                 return error(s, EINVAL, "Failed to generate key material");
273         }
274
275         return true;
276 }
277
278 // Send an ACKnowledgement record.
279 static bool send_ack(sptps_t *s) {
280         return send_record_priv(s, SPTPS_HANDSHAKE, "", 0);
281 }
282
283 // Receive an ACKnowledgement record.
284 static bool receive_ack(sptps_t *s, const char *data, uint16_t len) {
285         (void)data;
286
287         if(len) {
288                 return error(s, EIO, "Invalid ACK record length");
289         }
290
291         if(s->initiator) {
292                 if(!chacha_poly1305_set_key(s->incipher, s->key)) {
293                         return error(s, EINVAL, "Failed to set counter");
294                 }
295         } else {
296                 if(!chacha_poly1305_set_key(s->incipher, s->key + CHACHA_POLY1305_KEYLEN)) {
297                         return error(s, EINVAL, "Failed to set counter");
298                 }
299         }
300
301         free(s->key);
302         s->key = NULL;
303         s->instate = true;
304
305         return true;
306 }
307
308 // Receive a Key EXchange record, respond by sending a SIG record.
309 static bool receive_kex(sptps_t *s, const char *data, uint16_t len) {
310         // Verify length of the HELLO record
311         if(len != 1 + 32 + ECDH_SIZE) {
312                 return error(s, EIO, "Invalid KEX record length");
313         }
314
315         // Ignore version number for now.
316
317         // Make a copy of the KEX message, send_sig() and receive_sig() need it
318         if(s->hiskex) {
319                 return error(s, EINVAL, "Received a second KEX message before first has been processed");
320         }
321
322         s->hiskex = realloc(s->hiskex, len);
323
324         if(!s->hiskex) {
325                 return error(s, errno, strerror(errno));
326         }
327
328         memcpy(s->hiskex, data, len);
329
330         return send_sig(s);
331 }
332
333 // Receive a SIGnature record, verify it, if it passed, compute the shared secret and calculate the session keys.
334 static bool receive_sig(sptps_t *s, const char *data, uint16_t len) {
335         size_t keylen = ECDH_SIZE;
336         size_t siglen = ecdsa_size(s->hiskey);
337
338         // Verify length of KEX record.
339         if(len != siglen) {
340                 return error(s, EIO, "Invalid KEX record length");
341         }
342
343         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator
344         char msg[(1 + 32 + keylen) * 2 + 1 + s->labellen];
345
346         msg[0] = !s->initiator;
347         memcpy(msg + 1, s->hiskex, 1 + 32 + keylen);
348         memcpy(msg + 1 + 33 + keylen, s->mykex, 1 + 32 + keylen);
349         memcpy(msg + 1 + 2 * (33 + keylen), s->label, s->labellen);
350
351         // Verify signature.
352         if(!ecdsa_verify(s->hiskey, msg, sizeof(msg), data)) {
353                 return error(s, EIO, "Failed to verify SIG record");
354         }
355
356         // Compute shared secret.
357         char shared[ECDH_SHARED_SIZE];
358
359         if(!ecdh_compute_shared(s->ecdh, s->hiskex + 1 + 32, shared)) {
360                 return error(s, EINVAL, "Failed to compute ECDH shared secret");
361         }
362
363         s->ecdh = NULL;
364
365         // Generate key material from shared secret.
366         if(!generate_key_material(s, shared, sizeof(shared))) {
367                 return false;
368         }
369
370         free(s->mykex);
371         free(s->hiskex);
372
373         s->mykex = NULL;
374         s->hiskex = NULL;
375
376         // Send cipher change record
377         if(s->outstate && !send_ack(s)) {
378                 return false;
379         }
380
381         // TODO: only set new keys after ACK has been set/received
382         if(s->initiator) {
383                 if(!chacha_poly1305_set_key(s->outcipher, s->key + CHACHA_POLY1305_KEYLEN)) {
384                         return error(s, EINVAL, "Failed to set key");
385                 }
386         } else {
387                 if(!chacha_poly1305_set_key(s->outcipher, s->key)) {
388                         return error(s, EINVAL, "Failed to set key");
389                 }
390         }
391
392         return true;
393 }
394
395 // Force another Key EXchange (for testing purposes).
396 bool sptps_force_kex(sptps_t *s) {
397         if(!s->outstate || s->state != SPTPS_SECONDARY_KEX) {
398                 return error(s, EINVAL, "Cannot force KEX in current state");
399         }
400
401         s->state = SPTPS_KEX;
402         return send_kex(s);
403 }
404
405 // Receive a handshake record.
406 static bool receive_handshake(sptps_t *s, const char *data, uint16_t len) {
407         // Only a few states to deal with handshaking.
408         switch(s->state) {
409         case SPTPS_SECONDARY_KEX:
410
411                 // We receive a secondary KEX request, first respond by sending our own.
412                 if(!send_kex(s)) {
413                         return false;
414                 }
415
416         // fallthrough
417         case SPTPS_KEX:
418
419                 // We have sent our KEX request, we expect our peer to sent one as well.
420                 if(!receive_kex(s, data, len)) {
421                         return false;
422                 }
423
424                 s->state = SPTPS_SIG;
425                 return true;
426
427         case SPTPS_SIG:
428
429                 // If we already sent our secondary public ECDH key, we expect the peer to send his.
430                 if(!receive_sig(s, data, len)) {
431                         return false;
432                 }
433
434                 if(s->outstate) {
435                         s->state = SPTPS_ACK;
436                 } else {
437                         s->outstate = true;
438
439                         if(!receive_ack(s, NULL, 0)) {
440                                 return false;
441                         }
442
443                         s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
444                         s->state = SPTPS_SECONDARY_KEX;
445                 }
446
447                 return true;
448
449         case SPTPS_ACK:
450
451                 // We expect a handshake message to indicate transition to the new keys.
452                 if(!receive_ack(s, data, len)) {
453                         return false;
454                 }
455
456                 s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
457                 s->state = SPTPS_SECONDARY_KEX;
458                 return true;
459
460         // TODO: split ACK into a VERify and ACK?
461         default:
462                 return error(s, EIO, "Invalid session state %d", s->state);
463         }
464 }
465
466 // Check datagram for valid HMAC
467 bool sptps_verify_datagram(sptps_t *s, const void *data, size_t len) {
468         if(!s->instate) {
469                 return error(s, EIO, "SPTPS state not ready to verify this datagram");
470         }
471
472         if(len < 21) {
473                 return error(s, EIO, "Received short packet in sptps_verify_datagram");
474         }
475
476         uint32_t seqno;
477         memcpy(&seqno, data, 4);
478         seqno = ntohl(seqno);
479         // TODO: check whether seqno makes sense, to avoid CPU intensive decrypt
480
481         char buffer[len];
482         size_t outlen;
483         return chacha_poly1305_decrypt(s->incipher, seqno, (const char *)data + 4, len - 4, buffer, &outlen);
484 }
485
486 // Receive incoming data, datagram version.
487 static bool sptps_receive_data_datagram(sptps_t *s, const void *vdata, size_t len) {
488         const char *data = vdata;
489
490         if(len < (s->instate ? 21 : 5)) {
491                 return error(s, EIO, "Received short packet in sptps_receive_data_datagram");
492         }
493
494         uint32_t seqno;
495         memcpy(&seqno, data, 4);
496         seqno = ntohl(seqno);
497
498         if(!s->instate) {
499                 if(seqno != s->inseqno) {
500                         return error(s, EIO, "Invalid packet seqno: %d != %d", seqno, s->inseqno);
501                 }
502
503                 s->inseqno = seqno + 1;
504
505                 uint8_t type = data[4];
506
507                 if(type != SPTPS_HANDSHAKE) {
508                         return error(s, EIO, "Application record received before handshake finished");
509                 }
510
511                 return receive_handshake(s, data + 5, len - 5);
512         }
513
514         // Decrypt
515
516         char buffer[len];
517
518         size_t outlen;
519
520         if(!chacha_poly1305_decrypt(s->incipher, seqno, data + 4, len - 4, buffer, &outlen)) {
521                 return error(s, EIO, "Failed to decrypt and verify packet");
522         }
523
524         // Replay protection using a sliding window of configurable size.
525         // s->inseqno is expected sequence number
526         // seqno is received sequence number
527         // s->late[] is a circular buffer, a 1 bit means a packet has not been received yet
528         // The circular buffer contains bits for sequence numbers from s->inseqno - s->replaywin * 8 to (but excluding) s->inseqno.
529         if(s->replaywin) {
530                 if(seqno != s->inseqno) {
531                         if(seqno >= s->inseqno + s->replaywin * 8) {
532                                 // TODO: Prevent packets that jump far ahead of the queue from causing many others to be dropped.
533                                 warning(s, "Lost %d packets\n", seqno - s->inseqno);
534                                 // Mark all packets in the replay window as being late.
535                                 memset(s->late, 255, s->replaywin);
536                         } else if(seqno < s->inseqno) {
537                                 // If the sequence number is farther in the past than the bitmap goes, or if the packet was already received, drop it.
538                                 if((s->inseqno >= s->replaywin * 8 && seqno < s->inseqno - s->replaywin * 8) || !(s->late[(seqno / 8) % s->replaywin] & (1 << seqno % 8))) {
539                                         return error(s, EIO, "Received late or replayed packet, seqno %d, last received %d\n", seqno, s->inseqno);
540                                 }
541                         } else {
542                                 // We missed some packets. Mark them in the bitmap as being late.
543                                 for(uint32_t i = s->inseqno; i < seqno; i++) {
544                                         s->late[(i / 8) % s->replaywin] |= 1 << i % 8;
545                                 }
546                         }
547                 }
548
549                 // Mark the current packet as not being late.
550                 s->late[(seqno / 8) % s->replaywin] &= ~(1 << seqno % 8);
551         }
552
553         if(seqno >= s->inseqno) {
554                 s->inseqno = seqno + 1;
555         }
556
557         if(!s->inseqno) {
558                 s->received = 0;
559         } else {
560                 s->received++;
561         }
562
563         // Append a NULL byte for safety.
564         buffer[len - 20] = 0;
565
566         uint8_t type = buffer[0];
567
568         if(type < SPTPS_HANDSHAKE) {
569                 if(!s->instate) {
570                         return error(s, EIO, "Application record received before handshake finished");
571                 }
572
573                 if(!s->receive_record(s->handle, type, buffer + 1, len - 21)) {
574                         abort();
575                 }
576         } else if(type == SPTPS_HANDSHAKE) {
577                 if(!receive_handshake(s, buffer + 1, len - 21)) {
578                         abort();
579                 }
580         } else {
581                 return error(s, EIO, "Invalid record type %d", type);
582         }
583
584         return true;
585 }
586
587 // Receive incoming data. Check if it contains a complete record, if so, handle it.
588 bool sptps_receive_data(sptps_t *s, const void *data, size_t len) {
589         if(!s->state) {
590                 return error(s, EIO, "Invalid session state zero");
591         }
592
593         if(s->datagram) {
594                 return sptps_receive_data_datagram(s, data, len);
595         }
596
597         const char *ptr = data;
598
599         while(len) {
600                 if(s->passthrough) {
601                         if(!s->buflen && s->reclen <= len) {
602                                 len -= s->reclen;
603                                 ptr += s->reclen;
604
605                                 s->reclen = 0;
606                                 s->passthrough = false;
607
608                                 if(!s->receive_record(s->handle, SPTPS_UNENCRYPTED, data, s->reclen)) {
609                                         return false;
610                                 }
611
612                                 continue;
613                         }
614
615                         size_t toread = s->reclen - s->buflen;
616                         if (toread >= len) {
617                                 toread = len;
618                         }
619
620                         memcpy(s->inbuf + s->buflen, ptr, toread);
621                         s->buflen += toread;
622                         len -= toread;
623                         ptr += toread;
624
625                         if(s->buflen < s->reclen) {
626                                 return;
627                         }
628
629                         s->reclen = 0;
630                         s->passthrough = false;
631
632                         if(!s->receive_record(s->handle, SPTPS_UNENCRYPTED, data, s->reclen)) {
633                                 return false;
634                         }
635
636                         s->buflen = 0;
637                         continue;
638                 }
639
640                 // First read the 2 length bytes.
641                 if(s->buflen < 2) {
642                         size_t toread = 2 - s->buflen;
643
644                         if(toread > len) {
645                                 toread = len;
646                         }
647
648                         memcpy(s->inbuf + s->buflen, ptr, toread);
649
650                         s->buflen += toread;
651                         len -= toread;
652                         ptr += toread;
653
654                         // Exit early if we don't have the full length.
655                         if(s->buflen < 2) {
656                                 return true;
657                         }
658
659                         // Get the length bytes
660
661                         memcpy(&s->reclen, s->inbuf, 2);
662                         s->reclen = ntohs(s->reclen);
663
664                         // If we have the length bytes, ensure our buffer can hold the whole request.
665                         s->inbuf = realloc(s->inbuf, s->reclen + 19UL);
666
667                         if(!s->inbuf) {
668                                 return error(s, errno, strerror(errno));
669                         }
670
671                         // Exit early if we have no more data to process.
672                         if(!len) {
673                                 return true;
674                         }
675                 }
676
677                 // Read up to the end of the record.
678                 size_t toread = s->reclen + (s->instate ? 19UL : 3UL) - s->buflen;
679
680                 if(toread > len) {
681                         toread = len;
682                 }
683
684                 memcpy(s->inbuf + s->buflen, ptr, toread);
685                 s->buflen += toread;
686                 len -= toread;
687                 ptr += toread;
688
689                 // If we don't have a whole record, exit.
690                 if(s->buflen < s->reclen + (s->instate ? 19UL : 3UL)) {
691                         return true;
692                 }
693
694                 // Update sequence number.
695
696                 uint32_t seqno = s->inseqno++;
697
698                 // Check HMAC and decrypt.
699                 if(s->instate) {
700                         if(!chacha_poly1305_decrypt(s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL)) {
701                                 return error(s, EINVAL, "Failed to decrypt and verify record");
702                         }
703                 }
704
705                 // Append a NULL byte for safety.
706                 s->inbuf[s->reclen + 3UL] = 0;
707
708                 uint8_t type = s->inbuf[2];
709
710                 if(type < SPTPS_HANDSHAKE) {
711                         if(!s->instate) {
712                                 return error(s, EIO, "Application record received before handshake finished");
713                         }
714
715                         if(!s->receive_record(s->handle, type, s->inbuf + 3, s->reclen)) {
716                                 return false;
717                         }
718                 } else if(type == SPTPS_HANDSHAKE) {
719                         if(!receive_handshake(s, s->inbuf + 3, s->reclen)) {
720                                 return false;
721                         }
722                 } else {
723                         return error(s, EIO, "Invalid record type %d", type);
724                 }
725
726                 s->buflen = 0;
727         }
728
729         return true;
730 }
731
732 // Start a SPTPS session.
733 bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_t *mykey, ecdsa_t *hiskey, const char *label, size_t labellen, send_data_t send_data, receive_record_t receive_record) {
734         if(!s || !mykey || !hiskey || !label || !labellen || !send_data || !receive_record) {
735                 return error(s, EINVAL, "Invalid argument to sptps_start()");
736         }
737
738         // Initialise struct sptps
739         memset(s, 0, sizeof(*s));
740
741         s->handle = handle;
742         s->initiator = initiator;
743         s->datagram = datagram;
744         s->mykey = mykey;
745         s->hiskey = hiskey;
746         s->replaywin = sptps_replaywin;
747
748         if(s->replaywin) {
749                 s->late = malloc(s->replaywin);
750
751                 if(!s->late) {
752                         return error(s, errno, strerror(errno));
753                 }
754
755                 memset(s->late, 0, s->replaywin);
756         }
757
758         s->label = malloc(labellen);
759
760         if(!s->label) {
761                 return error(s, errno, strerror(errno));
762         }
763
764         if(!datagram) {
765                 s->inbuf = malloc(7);
766
767                 if(!s->inbuf) {
768                         return error(s, errno, strerror(errno));
769                 }
770
771                 s->buflen = 0;
772         }
773
774         memcpy(s->label, label, labellen);
775         s->labellen = labellen;
776
777         s->send_data = send_data;
778         s->receive_record = receive_record;
779
780         // Do first KEX immediately
781         s->state = SPTPS_KEX;
782         return send_kex(s);
783 }
784
785 // Stop a SPTPS session.
786 bool sptps_stop(sptps_t *s) {
787         // Clean up any resources.
788         chacha_poly1305_exit(s->incipher);
789         chacha_poly1305_exit(s->outcipher);
790         ecdh_free(s->ecdh);
791         free(s->inbuf);
792         free(s->mykex);
793         free(s->hiskex);
794         free(s->key);
795         free(s->label);
796         free(s->late);
797         memset(s, 0, sizeof(*s));
798         return true;
799 }