]> git.meshlink.io Git - meshlink/blobdiff - src/utcp.c
Fix cornercases closing channels.
[meshlink] / src / utcp.c
index 47acf221318bca518eca4d88bdd7ba04b0982071..20dd0aba049757ca30833779ec772c63ba685310 100644 (file)
     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
 
-#define _GNU_SOURCE
-
-#include <assert.h>
-#include <errno.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <stdint.h>
-#include <stdbool.h>
-#include <string.h>
-#include <unistd.h>
+#include "system.h"
 #include <time.h>
 
 #include "utcp_priv.h"
@@ -268,7 +259,7 @@ static ssize_t buffer_put_at(struct buffer *buf, size_t offset, const void *data
 
        uint32_t realoffset = buf->offset + offset;
 
-       if(buf->size - buf->offset < offset) {
+       if(buf->size - buf->offset <= offset) {
                // The offset wrapped
                realoffset -= buf->size;
        }
@@ -305,7 +296,7 @@ static ssize_t buffer_copy(struct buffer *buf, void *data, size_t offset, size_t
 
        uint32_t realoffset = buf->offset + offset;
 
-       if(buf->size - buf->offset < offset) {
+       if(buf->size - buf->offset <= offset) {
                // The offset wrapped
                realoffset -= buf->size;
        }
@@ -322,7 +313,11 @@ static ssize_t buffer_copy(struct buffer *buf, void *data, size_t offset, size_t
 }
 
 // Copy data from the buffer without removing it.
-static ssize_t buffer_call(struct buffer *buf, utcp_recv_t cb, void *arg, size_t offset, size_t len) {
+static ssize_t buffer_call(struct utcp_connection *c, struct buffer *buf, size_t offset, size_t len) {
+       if(!c->recv) {
+               return len;
+       }
+
        // Ensure we don't copy more than is actually stored in the buffer
        if(offset >= buf->used) {
                return 0;
@@ -334,20 +329,25 @@ static ssize_t buffer_call(struct buffer *buf, utcp_recv_t cb, void *arg, size_t
 
        uint32_t realoffset = buf->offset + offset;
 
-       if(buf->size - buf->offset < offset) {
+       if(buf->size - buf->offset <= offset) {
                // The offset wrapped
                realoffset -= buf->size;
        }
 
        if(buf->size - realoffset < len) {
                // The data is wrapped
-               ssize_t rx1 = cb(arg, buf->data + realoffset, buf->size - realoffset);
+               ssize_t rx1 = c->recv(c, buf->data + realoffset, buf->size - realoffset);
 
                if(rx1 < buf->size - realoffset) {
                        return rx1;
                }
 
-               ssize_t rx2 = cb(arg, buf->data, len - (buf->size - realoffset));
+               // The channel might have been closed by the previous callback
+               if(!c->recv) {
+                       return len;
+               }
+
+               ssize_t rx2 = c->recv(c, buf->data, len - (buf->size - realoffset));
 
                if(rx2 < 0) {
                        return rx2;
@@ -355,7 +355,7 @@ static ssize_t buffer_call(struct buffer *buf, utcp_recv_t cb, void *arg, size_t
                        return rx1 + rx2;
                }
        } else {
-               return cb(arg, buf->data + realoffset, len);
+               return c->recv(c, buf->data + realoffset, len);
        }
 }
 
@@ -365,7 +365,7 @@ static ssize_t buffer_discard(struct buffer *buf, size_t len) {
                len = buf->used;
        }
 
-       if(buf->size - buf->offset < len) {
+       if(buf->size - buf->offset <= len) {
                buf->offset -= buf->size;
        }
 
@@ -414,7 +414,6 @@ static int compare(const void *va, const void *vb) {
        const struct utcp_connection *b = *(struct utcp_connection **)vb;
 
        assert(a && b);
-       assert(a->src && b->src);
 
        int c = (int)a->src - (int)b->src;
 
@@ -651,6 +650,7 @@ void utcp_accept(struct utcp_connection *c, utcp_recv_t recv, void *priv) {
        debug(c, "accepted %p %p\n", c, recv, priv);
        c->recv = recv;
        c->priv = priv;
+       c->do_poll = true;
        set_state(c, ESTABLISHED);
 }
 
@@ -1110,13 +1110,11 @@ static void handle_in_order(struct utcp_connection *c, const void *data, size_t
                        len = c->sacks[0].offset + c->sacks[0].len;
                        size_t remainder = len - offset;
 
-                       if(c->recv) {
-                               ssize_t rxd = buffer_call(&c->rcvbuf, c->recv, c, offset, remainder);
+                       ssize_t rxd = buffer_call(c, &c->rcvbuf, offset, remainder);
 
-                               if(rxd != (ssize_t)remainder) {
-                                       // TODO: handle the application not accepting all data.
-                                       abort();
-                               }
+                       if(rxd != (ssize_t)remainder) {
+                               // TODO: handle the application not accepting all data.
+                               abort();
                        }
                }
        }
@@ -1161,8 +1159,8 @@ static void handle_unreliable(struct utcp_connection *c, const struct hdr *hdr,
        }
 
        // Send the packet if it's the final fragment
-       if(!(hdr->ctl & MF) && c->recv) {
-               buffer_call(&c->rcvbuf, c->recv, c, 0, hdr->wnd + len);
+       if(!(hdr->ctl & MF)) {
+               buffer_call(c, &c->rcvbuf, 0, hdr->wnd + len);
        }
 
        c->rcv.nxt = hdr->seq + len;
@@ -1295,7 +1293,7 @@ ssize_t utcp_recv(struct utcp *utcp, const void *data, size_t len) {
 
                if(hdr.ctl & SYN && !(hdr.ctl & ACK) && utcp->accept) {
                        // If we don't want to accept it, send a RST back
-                       if((utcp->pre_accept && !utcp->pre_accept(utcp, hdr.dst))) {
+                       if((utcp->listen && !utcp->listen(utcp, hdr.dst))) {
                                len = 1;
                                goto reset;
                        }
@@ -1370,7 +1368,7 @@ synack:
 
        if(c->state == CLOSED) {
                debug(c, "got packet for closed connection\n");
-               return 0;
+               goto reset;
        }
 
        // It is for an existing connection.
@@ -1458,17 +1456,6 @@ synack:
                }
        }
 
-       if(hdr.ctl & ACK && (seqdiff(hdr.ack, c->snd.last) > 0 || seqdiff(hdr.ack, c->snd.una) < 0)) {
-               debug(c, "packet ack seqno out of range, %u <= %u < %u\n", c->snd.una, hdr.ack, c->snd.una + c->sndbuf.used);
-
-               // Ignore unacceptable RST packets.
-               if(hdr.ctl & RST) {
-                       return 0;
-               }
-
-               goto reset;
-       }
-
        // 2. Handle RST packets
 
        if(hdr.ctl & RST) {
@@ -1558,6 +1545,11 @@ synack:
 
        // 3. Advance snd.una
 
+       if(seqdiff(hdr.ack, c->snd.last) > 0 || seqdiff(hdr.ack, c->snd.una) < 0) {
+               debug(c, "packet ack seqno out of range, %u <= %u < %u\n", c->snd.una, hdr.ack, c->snd.una + c->sndbuf.used);
+               goto reset;
+       }
+
        advanced = seqdiff(hdr.ack, c->snd.una);
 
        if(advanced) {
@@ -1721,6 +1713,7 @@ skip_ack:
                                c->snd.last++;
                                set_state(c, FIN_WAIT_1);
                        } else {
+                               c->do_poll = true;
                                set_state(c, ESTABLISHED);
                        }
 
@@ -1779,8 +1772,15 @@ skip_ack:
                        return 0;
 
                case ESTABLISHED:
+                       break;
+
                case FIN_WAIT_1:
                case FIN_WAIT_2:
+                       if(c->reapable) {
+                               // We already closed the connection and are not interested in more data.
+                               goto reset;
+                       }
+
                        break;
 
                case CLOSE_WAIT:
@@ -2002,7 +2002,7 @@ static bool reset_connection(struct utcp_connection *c) {
        hdr.src = c->src;
        hdr.dst = c->dst;
        hdr.seq = c->snd.nxt;
-       hdr.ack = 0;
+       hdr.ack = c->rcv.nxt;
        hdr.wnd = 0;
        hdr.ctl = RST;
 
@@ -2155,7 +2155,7 @@ bool utcp_is_active(struct utcp *utcp) {
        return false;
 }
 
-struct utcp *utcp_init(utcp_accept_t accept, utcp_pre_accept_t pre_accept, utcp_send_t send, void *priv) {
+struct utcp *utcp_init(utcp_accept_t accept, utcp_listen_t listen, utcp_send_t send, void *priv) {
        if(!send) {
                errno = EFAULT;
                return NULL;
@@ -2181,7 +2181,7 @@ struct utcp *utcp_init(utcp_accept_t accept, utcp_pre_accept_t pre_accept, utcp_
        }
 
        utcp->accept = accept;
-       utcp->pre_accept = pre_accept;
+       utcp->listen = listen;
        utcp->send = send;
        utcp->priv = priv;
        utcp->timeout = DEFAULT_USER_TIMEOUT; // sec
@@ -2398,10 +2398,10 @@ void utcp_set_poll_cb(struct utcp_connection *c, utcp_poll_t poll) {
        }
 }
 
-void utcp_set_accept_cb(struct utcp *utcp, utcp_accept_t accept, utcp_pre_accept_t pre_accept) {
+void utcp_set_accept_cb(struct utcp *utcp, utcp_accept_t accept, utcp_listen_t listen) {
        if(utcp) {
                utcp->accept = accept;
-               utcp->pre_accept = pre_accept;
+               utcp->listen = listen;
        }
 }
 
@@ -2455,8 +2455,8 @@ void utcp_offline(struct utcp *utcp, bool offline) {
        }
 }
 
-void utcp_set_retransmit_cb(struct utcp *utcp, utcp_retransmit_t retransmit) {
-       utcp->retransmit = retransmit;
+void utcp_set_retransmit_cb(struct utcp *utcp, utcp_retransmit_t cb) {
+       utcp->retransmit = cb;
 }
 
 void utcp_set_clock_granularity(long granularity) {