]> git.meshlink.io Git - meshlink/blobdiff - src/protocol_key.c
Include our own key in REQ_PUBKEY requests.
[meshlink] / src / protocol_key.c
index 641e2b969434c8941a5ca752b9dca69822175dea..16e97eb2c6053652668412982bafd2bfc7b98b25 100644 (file)
 
 static const int req_key_timeout = 2;
 
-void send_key_changed(meshlink_handle_t *mesh) {
-       send_request(mesh, mesh->everyone, NULL, "%d %x %s", KEY_CHANGED, prng(mesh, UINT_MAX), mesh->self->name);
-
-       /* Force key exchange for connections using SPTPS */
-
-       for splay_each(node_t, n, mesh->nodes)
-               if(n->status.reachable && n->status.validkey) {
-                       sptps_force_kex(&n->sptps);
-               }
-}
-
 bool key_changed_h(meshlink_handle_t *mesh, connection_t *c, const char *request) {
        assert(request);
        assert(*request);
@@ -91,7 +80,9 @@ static bool send_initial_sptps_data(void *handle, uint8_t type, const void *data
 bool send_req_key(meshlink_handle_t *mesh, node_t *to) {
        if(!node_read_public_key(mesh, to)) {
                logger(mesh, MESHLINK_DEBUG, "No ECDSA key known for %s", to->name);
-               send_request(mesh, to->nexthop->connection, NULL, "%d %s %s %d", REQ_KEY, mesh->self->name, to->name, REQ_PUBKEY);
+               char *pubkey = ecdsa_get_base64_public_key(mesh->private_key);
+               send_request(mesh, to->nexthop->connection, NULL, "%d %s %s %d %s", REQ_KEY, mesh->self->name, to->name, REQ_PUBKEY, pubkey);
+               free(pubkey);
                return true;
        }
 
@@ -121,6 +112,19 @@ static bool req_key_ext_h(meshlink_handle_t *mesh, connection_t *c, const char *
                        return false;
                }
 
+               if(!node_read_public_key(mesh, from)) {
+                       char hiskey[MAX_STRING_SIZE];
+
+                       if(sscanf(request, "%*d %*s %*s %*d " MAX_STRING, hiskey) == 1) {
+                               from->ecdsa = ecdsa_set_base64_public_key(hiskey);
+
+                               if(!from->ecdsa) {
+                                       logger(mesh, MESHLINK_ERROR, "Got bad %s from %s: %s", "REQ_PUBKEY", from->name, "invalid pubkey");
+                                       return true;
+                               }
+                       }
+               }
+
                send_request(mesh, from->nexthop->connection, NULL, "%d %s %s %d %s", REQ_KEY, mesh->self->name, from->name, ANS_PUBKEY, pubkey);
                free(pubkey);
                return true;
@@ -146,7 +150,7 @@ static bool req_key_ext_h(meshlink_handle_t *mesh, connection_t *c, const char *
                for list_each(outgoing_t, outgoing, mesh->outgoings) {
                        if(outgoing->node == from && outgoing->ev.cb) {
                                outgoing->timeout = 0;
-                               timeout_set(&mesh->loop, &outgoing->ev, &(struct timeval) {
+                               timeout_set(&mesh->loop, &outgoing->ev, &(struct timespec) {
                                        0, 0
                                });
                        }
@@ -188,8 +192,17 @@ static bool req_key_ext_h(meshlink_handle_t *mesh, connection_t *c, const char *
                from->status.validkey = false;
                from->status.waitingforkey = true;
                from->last_req_key = mesh->loop.now.tv_sec;
-               sptps_start(&from->sptps, from, false, true, mesh->private_key, from->ecdsa, label, sizeof(label) - 1, send_sptps_data, receive_sptps_record);
-               sptps_receive_data(&from->sptps, buf, len);
+
+               if(!sptps_start(&from->sptps, from, false, true, mesh->private_key, from->ecdsa, label, sizeof(label) - 1, send_sptps_data, receive_sptps_record)) {
+                       logger(mesh, MESHLINK_ERROR, "Could not start SPTPS session with %s: %s", from->name, strerror(errno));
+                       return true;
+               }
+
+               if(!sptps_receive_data(&from->sptps, buf, len)) {
+                       logger(mesh, MESHLINK_ERROR, "Could not process SPTPS data from %s: %s", from->name, strerror(errno));
+                       return true;
+               }
+
                return true;
        }
 
@@ -207,7 +220,11 @@ static bool req_key_ext_h(meshlink_handle_t *mesh, connection_t *c, const char *
                        return true;
                }
 
-               sptps_receive_data(&from->sptps, buf, len);
+               if(!sptps_receive_data(&from->sptps, buf, len)) {
+                       logger(mesh, MESHLINK_ERROR, "Could not process SPTPS data from %s: %s", from->name, strerror(errno));
+                       return true;
+               }
+
                return true;
        }
 
@@ -332,12 +349,12 @@ bool ans_key_h(meshlink_handle_t *mesh, connection_t *c, const char *request) {
 
                /* Append the known UDP address of the from node, if we have a confirmed one */
                if(!*address && from->status.udp_confirmed && from->address.sa.sa_family != AF_UNSPEC) {
-                       char *address, *port;
+                       char *reflexive_address, *reflexive_port;
                        logger(mesh, MESHLINK_DEBUG, "Appending reflexive UDP address to ANS_KEY from %s to %s", from->name, to->name);
-                       sockaddr2str(&from->address, &address, &port);
-                       send_request(mesh, to->nexthop->connection, NULL, "%s %s %s", request, address, port);
-                       free(address);
-                       free(port);
+                       sockaddr2str(&from->address, &reflexive_address, &reflexive_port);
+                       send_request(mesh, to->nexthop->connection, NULL, "%s %s %s", request, reflexive_address, reflexive_port);
+                       free(reflexive_address);
+                       free(reflexive_port);
                        return true;
                }
 
@@ -352,7 +369,7 @@ bool ans_key_h(meshlink_handle_t *mesh, connection_t *c, const char *request) {
 
                        /* Inform all other nodes we want to communicate with and which are reachable via this connection */
                        for splay_each(node_t, n, mesh->nodes) {
-                               if(n->nexthop == c->node) {
+                               if(n->nexthop != c->node) {
                                        continue;
                                }