X-Git-Url: http://git.meshlink.io/?a=blobdiff_plain;f=src%2Fsptps.c;h=fab2e9d9b8a87a05dba12cb00e0cacc910d5a419;hb=HEAD;hp=f44374ee4bf472b7fea3127ff3b777ffa26aa517;hpb=dc0e52cb3e42620c3139e713b373d130aa30b698;p=meshlink diff --git a/src/sptps.c b/src/sptps.c index f44374ee..d4a31e81 100644 --- a/src/sptps.c +++ b/src/sptps.c @@ -27,8 +27,6 @@ #include "prf.h" #include "sptps.h" -unsigned int sptps_replaywin = 32; - /* Nonce MUST be exchanged first (done) Signatures MUST be done over both nonces, to guarantee the signature is fresh @@ -54,19 +52,27 @@ void sptps_log_quiet(sptps_t *s, int s_errno, const char *format, va_list ap) { (void)s_errno; (void)format; (void)ap; + + assert(format); } void sptps_log_stderr(sptps_t *s, int s_errno, const char *format, va_list ap) { (void)s; (void)s_errno; + + assert(format); + vfprintf(stderr, format, ap); fputc('\n', stderr); } -void (*sptps_log)(sptps_t *s, int s_errno, const char *format, va_list ap) = sptps_log_stderr; +void (*sptps_log)(sptps_t *s, int s_errno, const char *format, va_list ap) = sptps_log_quiet; // Log an error message. static bool error(sptps_t *s, int s_errno, const char *format, ...) { + assert(s_errno); + assert(format); + if(format) { va_list ap; va_start(ap, format); @@ -79,6 +85,8 @@ static bool error(sptps_t *s, int s_errno, const char *format, ...) { } static void warning(sptps_t *s, const char *format, ...) { + assert(format); + va_list ap; va_start(ap, format); sptps_log(s, 0, format, ap); @@ -87,7 +95,7 @@ static void warning(sptps_t *s, const char *format, ...) { // Send a record (datagram version, accepts all record types, handles encryption and authentication). static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const void *data, uint16_t len) { - char buffer[len + 21UL]; + char buffer[len + SPTPS_DATAGRAM_OVERHEAD]; // Create header with sequence number, length and record type uint32_t seqno = s->outseqno++; @@ -100,7 +108,7 @@ static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const void *data if(s->outstate) { // If first handshake has finished, encrypt and HMAC chacha_poly1305_encrypt(s->outcipher, seqno, buffer + 4, len + 1, buffer + 4, NULL); - return s->send_data(s->handle, type, buffer, len + 21UL); + return s->send_data(s->handle, type, buffer, len + SPTPS_DATAGRAM_OVERHEAD); } else { // Otherwise send as plaintext return s->send_data(s->handle, type, buffer, len + 5UL); @@ -112,7 +120,7 @@ static bool send_record_priv(sptps_t *s, uint8_t type, const void *data, uint16_ return send_record_priv_datagram(s, type, data, len); } - char buffer[len + 19UL]; + char buffer[len + SPTPS_OVERHEAD]; // Create header with sequence number, length and record type uint32_t seqno = s->outseqno++; @@ -125,7 +133,7 @@ static bool send_record_priv(sptps_t *s, uint8_t type, const void *data, uint16_ if(s->outstate) { // If first handshake has finished, encrypt and HMAC chacha_poly1305_encrypt(s->outcipher, seqno, buffer + 2, len + 1, buffer + 2, NULL); - return s->send_data(s->handle, type, buffer, len + 19UL); + return s->send_data(s->handle, type, buffer, len + SPTPS_OVERHEAD); } else { // Otherwise send as plaintext return s->send_data(s->handle, type, buffer, len + 3UL); @@ -134,6 +142,8 @@ static bool send_record_priv(sptps_t *s, uint8_t type, const void *data, uint16_ // Send an application record. bool sptps_send_record(sptps_t *s, uint8_t type, const void *data, uint16_t len) { + assert(!len || data); + // Sanity checks: application cannot send data before handshake is finished, // and only record types 0..127 are allowed. if(!s->outstate) { @@ -201,6 +211,9 @@ static bool send_sig(sptps_t *s) { // Generate key material from the shared secret created from the ECDHE key exchange. static bool generate_key_material(sptps_t *s, const char *shared, size_t len) { + assert(shared); + assert(len); + // Initialise cipher and digest structures if necessary if(!s->outstate) { s->incipher = chacha_poly1305_init(); @@ -361,10 +374,15 @@ static bool receive_sig(sptps_t *s, const char *data, uint16_t len) { // Force another Key EXchange (for testing purposes). bool sptps_force_kex(sptps_t *s) { - if(!s->outstate || s->state != SPTPS_SECONDARY_KEX) { + if(!s->outstate) { return error(s, EINVAL, "Cannot force KEX in current state"); } + if(s->state != SPTPS_SECONDARY_KEX) { + // We are already in the middle of a key exchange + return true; + } + s->state = SPTPS_KEX; return send_kex(s); } @@ -436,7 +454,7 @@ bool sptps_verify_datagram(sptps_t *s, const void *data, size_t len) { return error(s, EIO, "SPTPS state not ready to verify this datagram"); } - if(len < 21) { + if(len < SPTPS_DATAGRAM_OVERHEAD) { return error(s, EIO, "Received short packet in sptps_verify_datagram"); } @@ -445,16 +463,14 @@ bool sptps_verify_datagram(sptps_t *s, const void *data, size_t len) { seqno = ntohl(seqno); // TODO: check whether seqno makes sense, to avoid CPU intensive decrypt - char buffer[len]; - size_t outlen; - return chacha_poly1305_decrypt(s->incipher, seqno, (const char *)data + 4, len - 4, buffer, &outlen); + return chacha_poly1305_verify(s->incipher, seqno, (const char *)data + 4, len - 4); } // Receive incoming data, datagram version. static bool sptps_receive_data_datagram(sptps_t *s, const void *vdata, size_t len) { const char *data = vdata; - if(len < (s->instate ? 21 : 5)) { + if(len < (s->instate ? SPTPS_DATAGRAM_OVERHEAD : 5)) { return error(s, EIO, "Received short packet in sptps_receive_data_datagram"); } @@ -480,11 +496,20 @@ static bool sptps_receive_data_datagram(sptps_t *s, const void *vdata, size_t le // Decrypt - char buffer[len]; + if(len > s->decrypted_buffer_len) { + s->decrypted_buffer_len *= 2; + char *new_buffer = realloc(s->decrypted_buffer, s->decrypted_buffer_len); + + if(!new_buffer) { + return error(s, errno, strerror(errno)); + } + + s->decrypted_buffer = new_buffer; + } size_t outlen; - if(!chacha_poly1305_decrypt(s->incipher, seqno, data + 4, len - 4, buffer, &outlen)) { + if(!chacha_poly1305_decrypt(s->incipher, seqno, data + 4, len - 4, s->decrypted_buffer, &outlen)) { return error(s, EIO, "Failed to decrypt and verify packet"); } @@ -528,20 +553,20 @@ static bool sptps_receive_data_datagram(sptps_t *s, const void *vdata, size_t le } // Append a NULL byte for safety. - buffer[len - 20] = 0; + s->decrypted_buffer[len - 20] = 0; - uint8_t type = buffer[0]; + uint8_t type = s->decrypted_buffer[0]; if(type < SPTPS_HANDSHAKE) { if(!s->instate) { return error(s, EIO, "Application record received before handshake finished"); } - if(!s->receive_record(s->handle, type, buffer + 1, len - 21)) { + if(!s->receive_record(s->handle, type, s->decrypted_buffer + 1, len - SPTPS_DATAGRAM_OVERHEAD)) { abort(); } } else if(type == SPTPS_HANDSHAKE) { - if(!receive_handshake(s, buffer + 1, len - 21)) { + if(!receive_handshake(s, s->decrypted_buffer + 1, len - SPTPS_DATAGRAM_OVERHEAD)) { abort(); } } else { @@ -589,7 +614,7 @@ bool sptps_receive_data(sptps_t *s, const void *data, size_t len) { s->reclen = ntohs(s->reclen); // If we have the length bytes, ensure our buffer can hold the whole request. - s->inbuf = realloc(s->inbuf, s->reclen + 19UL); + s->inbuf = realloc(s->inbuf, s->reclen + SPTPS_OVERHEAD); if(!s->inbuf) { return error(s, errno, strerror(errno)); @@ -602,7 +627,7 @@ bool sptps_receive_data(sptps_t *s, const void *data, size_t len) { } // Read up to the end of the record. - size_t toread = s->reclen + (s->instate ? 19UL : 3UL) - s->buflen; + size_t toread = s->reclen + (s->instate ? SPTPS_OVERHEAD : 3UL) - s->buflen; if(toread > len) { toread = len; @@ -614,7 +639,7 @@ bool sptps_receive_data(sptps_t *s, const void *data, size_t len) { ptr += toread; // If we don't have a whole record, exit. - if(s->buflen < s->reclen + (s->instate ? 19UL : 3UL)) { + if(s->buflen < s->reclen + (s->instate ? SPTPS_OVERHEAD : 3UL)) { return true; } @@ -670,7 +695,13 @@ bool sptps_start(sptps_t *s, void *handle, bool initiator, bool datagram, ecdsa_ s->datagram = datagram; s->mykey = mykey; s->hiskey = hiskey; - s->replaywin = sptps_replaywin; + s->replaywin = 32; + s->decrypted_buffer_len = 1024; + s->decrypted_buffer = malloc(s->decrypted_buffer_len); + + if(!s->decrypted_buffer) { + return error(s, errno, strerror(errno)); + } if(s->replaywin) { s->late = malloc(s->replaywin); @@ -721,6 +752,8 @@ bool sptps_stop(sptps_t *s) { free(s->key); free(s->label); free(s->late); + memset(s->decrypted_buffer, 0, s->decrypted_buffer_len); + free(s->decrypted_buffer); memset(s, 0, sizeof(*s)); return true; }