]> git.meshlink.io Git - meshlink/commitdiff
Move join state out of meshlink_handle_t, and ensure proper cleanup on errors.
authorGuus Sliepen <guus@meshlink.io>
Tue, 11 Feb 2020 20:37:33 +0000 (21:37 +0100)
committerGuus Sliepen <guus@meshlink.io>
Tue, 11 Feb 2020 20:37:33 +0000 (21:37 +0100)
Move the state we keep when calling meshlink_join() out of meshlink_handle_t
and just put it on the stack of meshlink_join(). Also make sure we properly
release allocated resources in all error conditions during a join.

src/meshlink.c
src/meshlink_internal.h

index 5b01ad2c1fda7b12abb823baa9c80da391177fb7..eb9142cc743d842c7f8967e0b9a4408aa170113e 100644 (file)
@@ -609,7 +609,22 @@ static bool write_main_config_files(meshlink_handle_t *mesh) {
        return true;
 }
 
-static bool finalize_join(meshlink_handle_t *mesh, const void *buf, uint16_t len) {
+typedef struct {
+       meshlink_handle_t *mesh;
+       int sock;
+       char cookie[18 + 32];
+       char hash[18];
+       bool success;
+       sptps_t sptps;
+       char *data;
+       size_t thedatalen;
+       size_t blen;
+       char line[4096];
+       char buffer[4096];
+} join_state_t;
+
+static bool finalize_join(join_state_t *state, const void *buf, uint16_t len) {
+       meshlink_handle_t *mesh = state->mesh;
        packmsg_input_t in = {buf, len};
        uint32_t version = packmsg_get_uint32(&in);
 
@@ -706,7 +721,7 @@ static bool finalize_join(meshlink_handle_t *mesh, const void *buf, uint16_t len
                        sockaddr_t sa;
                        socklen_t salen = sizeof(sa);
 
-                       if(getpeername(mesh->sock, &sa.sa, &salen) == 0) {
+                       if(getpeername(state->sock, &sa.sa, &salen) == 0) {
                                node_add_recent_address(mesh, n, &sa);
                        }
                }
@@ -728,7 +743,7 @@ static bool finalize_join(meshlink_handle_t *mesh, const void *buf, uint16_t len
                return false;
        }
 
-       sptps_send_record(&mesh->sptps, 1, ecdsa_get_public_key(mesh->private_key), 32);
+       sptps_send_record(&state->sptps, 1, ecdsa_get_public_key(mesh->private_key), 32);
 
        logger(mesh, MESHLINK_DEBUG, "Configuration stored in: %s\n", mesh->confbase);
 
@@ -737,11 +752,11 @@ static bool finalize_join(meshlink_handle_t *mesh, const void *buf, uint16_t len
 
 static bool invitation_send(void *handle, uint8_t type, const void *data, size_t len) {
        (void)type;
-       meshlink_handle_t *mesh = handle;
+       join_state_t *state = handle;
        const char *ptr = data;
 
        while(len) {
-               int result = send(mesh->sock, ptr, len, 0);
+               int result = send(state->sock, ptr, len, 0);
 
                if(result == -1 && errno == EINTR) {
                        continue;
@@ -757,19 +772,20 @@ static bool invitation_send(void *handle, uint8_t type, const void *data, size_t
 }
 
 static bool invitation_receive(void *handle, uint8_t type, const void *msg, uint16_t len) {
-       meshlink_handle_t *mesh = handle;
+       join_state_t *state = handle;
+       meshlink_handle_t *mesh = state->mesh;
 
        switch(type) {
        case SPTPS_HANDSHAKE:
-               return sptps_send_record(&mesh->sptps, 0, mesh->cookie, sizeof(mesh)->cookie);
+               return sptps_send_record(&state->sptps, 0, state->cookie, 18);
 
        case 0:
-               return finalize_join(mesh, msg, len);
+               return finalize_join(state, msg, len);
 
        case 1:
                logger(mesh, MESHLINK_DEBUG, "Invitation successfully accepted.\n");
-               shutdown(mesh->sock, SHUT_RDWR);
-               mesh->success = true;
+               shutdown(state->sock, SHUT_RDWR);
+               state->success = true;
                break;
 
        default:
@@ -779,15 +795,11 @@ static bool invitation_receive(void *handle, uint8_t type, const void *msg, uint
        return true;
 }
 
-static bool recvline(meshlink_handle_t *mesh, size_t len) {
+static bool recvline(join_state_t *state) {
        char *newline = NULL;
 
-       if(!mesh->sock) {
-               abort();
-       }
-
-       while(!(newline = memchr(mesh->buffer, '\n', mesh->blen))) {
-               int result = recv(mesh->sock, mesh->buffer + mesh->blen, sizeof(mesh)->buffer - mesh->blen, 0);
+       while(!(newline = memchr(state->buffer, '\n', state->blen))) {
+               int result = recv(state->sock, state->buffer + state->blen, sizeof(state)->buffer - state->blen, 0);
 
                if(result == -1 && errno == EINTR) {
                        continue;
@@ -795,19 +807,19 @@ static bool recvline(meshlink_handle_t *mesh, size_t len) {
                        return false;
                }
 
-               mesh->blen += result;
+               state->blen += result;
        }
 
-       if((size_t)(newline - mesh->buffer) >= len) {
+       if((size_t)(newline - state->buffer) >= sizeof(state->line)) {
                return false;
        }
 
-       len = newline - mesh->buffer;
+       size_t len = newline - state->buffer;
 
-       memcpy(mesh->line, mesh->buffer, len);
-       mesh->line[len] = 0;
-       memmove(mesh->buffer, newline + 1, mesh->blen - len - 1);
-       mesh->blen -= len + 1;
+       memcpy(state->line, state->buffer, len);
+       state->line[len] = 0;
+       memmove(state->buffer, newline + 1, state->blen - len - 1);
+       state->blen -= len + 1;
 
        return true;
 }
@@ -1546,8 +1558,6 @@ bool meshlink_start(meshlink_handle_t *mesh) {
                return false;
        }
 
-       mesh->thedatalen = 0;
-
        // TODO: open listening sockets first
 
        //Check that a valid name is set
@@ -2587,26 +2597,33 @@ bool meshlink_join(meshlink_handle_t *mesh, const char *invitation) {
                return false;
        }
 
+       join_state_t state = {
+               .mesh = mesh,
+               .sock = -1,
+       };
+
+       ecdsa_t *key = NULL;
+       ecdsa_t *hiskey = NULL;
+
+       //TODO: think of a better name for this variable, or of a different way to tokenize the invitation URL.
+       char copy[strlen(invitation) + 1];
+
        pthread_mutex_lock(&mesh->mutex);
 
        //Before doing meshlink_join make sure we are not connected to another mesh
        if(mesh->threadstarted) {
                logger(mesh, MESHLINK_ERROR, "Cannot join while started\n");
                meshlink_errno = MESHLINK_EINVAL;
-               pthread_mutex_unlock(&mesh->mutex);
-               return false;
+               goto exit;
        }
 
        // Refuse to join a mesh if we are already part of one. We are part of one if we know at least one other node.
        if(mesh->nodes->count > 1) {
                logger(mesh, MESHLINK_ERROR, "Already part of an existing mesh\n");
                meshlink_errno = MESHLINK_EINVAL;
-               pthread_mutex_unlock(&mesh->mutex);
-               return false;
+               goto exit;
        }
 
-       //TODO: think of a better name for this variable, or of a different way to tokenize the invitation URL.
-       char copy[strlen(invitation) + 1];
        strcpy(copy, invitation);
 
        // Split the invitation URL into a list of hostname/port tuples, a key hash and a cookie.
@@ -2626,22 +2643,20 @@ bool meshlink_join(meshlink_handle_t *mesh, const char *invitation) {
        char *address = copy;
        char *port = NULL;
 
-       if(!b64decode(slash, mesh->hash, 18) || !b64decode(slash + 24, mesh->cookie, 18)) {
+       if(!b64decode(slash, state.hash, 18) || !b64decode(slash + 24, state.cookie, 18)) {
                goto invalid;
        }
 
        // Generate a throw-away key for the invitation.
-       ecdsa_t *key = ecdsa_generate();
+       key = ecdsa_generate();
 
        if(!key) {
                meshlink_errno = MESHLINK_EINTERNAL;
-               pthread_mutex_unlock(&mesh->mutex);
-               return false;
+               goto exit;
        }
 
        char *b64key = ecdsa_get_base64_public_key(key);
        char *comma;
-       mesh->sock = -1;
 
        while(address && *address) {
                // We allow commas in the address part to support multiple addresses in one invitation URL.
@@ -2681,21 +2696,21 @@ bool meshlink_join(meshlink_handle_t *mesh, const char *invitation) {
 
                if(ai) {
                        for(struct addrinfo *aip = ai; aip; aip = aip->ai_next) {
-                               mesh->sock = socket_in_netns(aip->ai_family, aip->ai_socktype, aip->ai_protocol, mesh->netns);
+                               state.sock = socket_in_netns(aip->ai_family, aip->ai_socktype, aip->ai_protocol, mesh->netns);
 
-                               if(mesh->sock == -1) {
+                               if(state.sock == -1) {
                                        logger(mesh, MESHLINK_DEBUG, "Could not open socket: %s\n", strerror(errno));
                                        meshlink_errno = MESHLINK_ENETWORK;
                                        continue;
                                }
 
-                               set_timeout(mesh->sock, 5000);
+                               set_timeout(state.sock, 5000);
 
-                               if(connect(mesh->sock, aip->ai_addr, aip->ai_addrlen)) {
+                               if(connect(state.sock, aip->ai_addr, aip->ai_addrlen)) {
                                        logger(mesh, MESHLINK_DEBUG, "Could not connect to %s port %s: %s\n", address, port, strerror(errno));
                                        meshlink_errno = MESHLINK_ENETWORK;
-                                       closesocket(mesh->sock);
-                                       mesh->sock = -1;
+                                       closesocket(state.sock);
+                                       state.sock = -1;
                                        continue;
                                }
                        }
@@ -2705,30 +2720,27 @@ bool meshlink_join(meshlink_handle_t *mesh, const char *invitation) {
                        meshlink_errno = MESHLINK_ERESOLV;
                }
 
-               if(mesh->sock != -1 || !comma) {
+               if(state.sock != -1 || !comma) {
                        break;
                }
 
                address = comma;
        }
 
-       if(mesh->sock == -1) {
-               pthread_mutex_unlock(&mesh->mutex);
-               return false;
+       if(state.sock == -1) {
+               goto exit;
        }
 
        logger(mesh, MESHLINK_DEBUG, "Connected to %s port %s...\n", address, port);
 
        // Tell him we have an invitation, and give him our throw-away key.
 
-       mesh->blen = 0;
+       state.blen = 0;
 
-       if(!sendline(mesh->sock, "0 ?%s %d.%d %s", b64key, PROT_MAJOR, PROT_MINOR, mesh->appname)) {
+       if(!sendline(state.sock, "0 ?%s %d.%d %s", b64key, PROT_MAJOR, PROT_MINOR, mesh->appname)) {
                logger(mesh, MESHLINK_DEBUG, "Error sending request to %s port %s: %s\n", address, port, strerror(errno));
-               closesocket(mesh->sock);
                meshlink_errno = MESHLINK_ENETWORK;
-               pthread_mutex_unlock(&mesh->mutex);
-               return false;
+               goto exit;
        }
 
        free(b64key);
@@ -2736,58 +2748,51 @@ bool meshlink_join(meshlink_handle_t *mesh, const char *invitation) {
        char hisname[4096] = "";
        int code, hismajor, hisminor = 0;
 
-       if(!recvline(mesh, sizeof(mesh)->line) || sscanf(mesh->line, "%d %s %d.%d", &code, hisname, &hismajor, &hisminor) < 3 || code != 0 || hismajor != PROT_MAJOR || !check_id(hisname) || !recvline(mesh, sizeof(mesh)->line) || !rstrip(mesh->line) || sscanf(mesh->line, "%d ", &code) != 1 || code != ACK || strlen(mesh->line) < 3) {
+       if(!recvline(&state) || sscanf(state.line, "%d %s %d.%d", &code, hisname, &hismajor, &hisminor) < 3 || code != 0 || hismajor != PROT_MAJOR || !check_id(hisname) || !recvline(&state) || !rstrip(state.line) || sscanf(state.line, "%d ", &code) != 1 || code != ACK || strlen(state.line) < 3) {
                logger(mesh, MESHLINK_DEBUG, "Cannot read greeting from peer\n");
-               closesocket(mesh->sock);
                meshlink_errno = MESHLINK_ENETWORK;
-               pthread_mutex_unlock(&mesh->mutex);
-               return false;
+               goto exit;
        }
 
        // Check if the hash of the key he gave us matches the hash in the URL.
-       char *fingerprint = mesh->line + 2;
+       char *fingerprint = state.line + 2;
        char hishash[64];
 
        if(sha512(fingerprint, strlen(fingerprint), hishash)) {
-               logger(mesh, MESHLINK_DEBUG, "Could not create hash\n%s\n", mesh->line + 2);
+               logger(mesh, MESHLINK_DEBUG, "Could not create hash\n%s\n", state.line + 2);
                meshlink_errno = MESHLINK_EINTERNAL;
-               pthread_mutex_unlock(&mesh->mutex);
-               return false;
+               goto exit;
        }
 
-       if(memcmp(hishash, mesh->hash, 18)) {
-               logger(mesh, MESHLINK_DEBUG, "Peer has an invalid key!\n%s\n", mesh->line + 2);
+       if(memcmp(hishash, state.hash, 18)) {
+               logger(mesh, MESHLINK_DEBUG, "Peer has an invalid key!\n%s\n", state.line + 2);
                meshlink_errno = MESHLINK_EPEER;
-               pthread_mutex_unlock(&mesh->mutex);
-               return false;
-
+               goto exit;
        }
 
-       ecdsa_t *hiskey = ecdsa_set_base64_public_key(fingerprint);
+       hiskey = ecdsa_set_base64_public_key(fingerprint);
 
        if(!hiskey) {
                meshlink_errno = MESHLINK_EINTERNAL;
-               pthread_mutex_unlock(&mesh->mutex);
-               return false;
+               goto exit;
        }
 
        // Start an SPTPS session
-       if(!sptps_start(&mesh->sptps, mesh, true, false, key, hiskey, meshlink_invitation_label, sizeof(meshlink_invitation_label), invitation_send, invitation_receive)) {
+       if(!sptps_start(&state.sptps, &state, true, false, key, hiskey, meshlink_invitation_label, sizeof(meshlink_invitation_label), invitation_send, invitation_receive)) {
                meshlink_errno = MESHLINK_EINTERNAL;
-               pthread_mutex_unlock(&mesh->mutex);
-               return false;
+               goto exit;
        }
 
        // Feed rest of input buffer to SPTPS
-       if(!sptps_receive_data(&mesh->sptps, mesh->buffer, mesh->blen)) {
+       if(!sptps_receive_data(&state.sptps, state.buffer, state.blen)) {
                meshlink_errno = MESHLINK_EPEER;
-               pthread_mutex_unlock(&mesh->mutex);
-               return false;
+               goto exit;
        }
 
-       int len;
+       ssize_t len;
+       logger(mesh, MESHLINK_DEBUG, "Starting invitation recv loop: %d %zu\n", state.sock, sizeof(state.line));
 
-       while((len = recv(mesh->sock, mesh->line, sizeof(mesh)->line, 0))) {
+       while((len = recv(state.sock, state.line, sizeof(state.line), 0))) {
                if(len < 0) {
                        if(errno == EINTR) {
                                continue;
@@ -2795,35 +2800,41 @@ bool meshlink_join(meshlink_handle_t *mesh, const char *invitation) {
 
                        logger(mesh, MESHLINK_DEBUG, "Error reading data from %s port %s: %s\n", address, port, strerror(errno));
                        meshlink_errno = MESHLINK_ENETWORK;
-                       pthread_mutex_unlock(&mesh->mutex);
-                       return false;
+                       goto exit;
                }
 
-               if(!sptps_receive_data(&mesh->sptps, mesh->line, len)) {
+               if(!sptps_receive_data(&state.sptps, state.line, len)) {
                        meshlink_errno = MESHLINK_EPEER;
-                       pthread_mutex_unlock(&mesh->mutex);
-                       return false;
+                       goto exit;
                }
        }
 
-       sptps_stop(&mesh->sptps);
-       ecdsa_free(hiskey);
-       ecdsa_free(key);
-       closesocket(mesh->sock);
-
-       if(!mesh->success) {
+       if(!state.success) {
                logger(mesh, MESHLINK_DEBUG, "Connection closed by peer, invitation cancelled.\n");
                meshlink_errno = MESHLINK_EPEER;
-               pthread_mutex_unlock(&mesh->mutex);
-               return false;
+               goto exit;
        }
 
+       sptps_stop(&state.sptps);
+       ecdsa_free(hiskey);
+       ecdsa_free(key);
+       closesocket(state.sock);
+
        pthread_mutex_unlock(&mesh->mutex);
        return true;
 
 invalid:
        logger(mesh, MESHLINK_DEBUG, "Invalid invitation URL\n");
        meshlink_errno = MESHLINK_EINVAL;
+exit:
+       sptps_stop(&state.sptps);
+       ecdsa_free(hiskey);
+       ecdsa_free(key);
+
+       if(state.sock != -1) {
+               closesocket(state.sock);
+       }
+
        pthread_mutex_unlock(&mesh->mutex);
        return false;
 }
@@ -3906,6 +3917,17 @@ void meshlink_set_dev_class_fast_retry_period(meshlink_handle_t *mesh, dev_class
        pthread_mutex_unlock(&mesh->mutex);
 }
 
+extern void meshlink_set_inviter_commits_first(struct meshlink_handle *mesh, bool inviter_commits_first) {
+       if(!mesh) {
+               meshlink_errno = EINVAL;
+               return;
+       }
+
+       pthread_mutex_lock(&mesh->mutex);
+       mesh->inviter_commits_first = inviter_commits_first;
+       pthread_mutex_unlock(&mesh->mutex);
+}
+
 void handle_network_change(meshlink_handle_t *mesh, bool online) {
        (void)online;
 
@@ -3931,7 +3953,6 @@ void call_error_cb(meshlink_handle_t *mesh, meshlink_errno_t meshlink_errno) {
        }
 }
 
-
 static void __attribute__((constructor)) meshlink_init(void) {
        crypto_init();
 }
index 6aa871717563e8ea07cf813fd9a5be11fddfae89..8104ba60e4c953489ca1102651d25675d90cd451 100644 (file)
@@ -192,17 +192,6 @@ struct meshlink_handle {
        char *catta_servicetype;
        unsigned int catta_interfaces;
 
-       // State used for meshlink_join()
-       int sock;
-       char cookie[18], hash[18];
-       bool success;
-       sptps_t sptps;
-       char *data;
-       size_t thedatalen;
-       size_t blen;
-       char line[4096];
-       char buffer[4096];
-
        // Proxy configuration, currently not exposed.
        char *proxyhost;
        char *proxyport;