]> git.meshlink.io Git - utcp/blobdiff - utcp.c
Fix potential segmentation fault when the receive callback is not set.
[utcp] / utcp.c
diff --git a/utcp.c b/utcp.c
index 75d4f83417a3be30bccc6346757abef8f7643585..f7c1859de2abe5ef05873bb1ef6565d449810e3d 100644 (file)
--- a/utcp.c
+++ b/utcp.c
@@ -401,7 +401,7 @@ static void buffer_exit(struct buffer *buf) {
 }
 
 static uint32_t buffer_free(const struct buffer *buf) {
-       return buf->maxsize - buf->used;
+       return buf->maxsize > buf->used ? buf->maxsize - buf->used : 0;
 }
 
 // Connections are stored in a sorted list.
@@ -816,7 +816,6 @@ ssize_t utcp_send(struct utcp_connection *c, const void *data, size_t len) {
        if(!is_reliable(c)) {
                c->snd.una = c->snd.nxt = c->snd.last;
                buffer_discard(&c->sndbuf, c->sndbuf.used);
-               c->do_poll = true;
        }
 
        if(is_reliable(c) && !timespec_isset(&c->rtrx_timeout)) {
@@ -1114,11 +1113,14 @@ static void handle_in_order(struct utcp_connection *c, const void *data, size_t
                        size_t offset = len;
                        len = c->sacks[0].offset + c->sacks[0].len;
                        size_t remainder = len - offset;
-                       ssize_t rxd = buffer_call(&c->rcvbuf, c->recv, c, offset, remainder);
 
-                       if(rxd != (ssize_t)remainder) {
-                               // TODO: handle the application not accepting all data.
-                               abort();
+                       if(c->recv) {
+                               ssize_t rxd = buffer_call(&c->rcvbuf, c->recv, c, offset, remainder);
+
+                               if(rxd != (ssize_t)remainder) {
+                                       // TODO: handle the application not accepting all data.
+                                       abort();
+                               }
                        }
                }
        }
@@ -1133,7 +1135,10 @@ static void handle_in_order(struct utcp_connection *c, const void *data, size_t
 static void handle_unreliable(struct utcp_connection *c, const struct hdr *hdr, const void *data, size_t len) {
        // Fast path for unfragmented packets
        if(!hdr->wnd && !(hdr->ctl & MF)) {
-               c->recv(c, data, len);
+               if(c->recv) {
+                       c->recv(c, data, len);
+               }
+
                c->rcv.nxt = hdr->seq + len;
                return;
        }
@@ -1160,7 +1165,7 @@ static void handle_unreliable(struct utcp_connection *c, const struct hdr *hdr,
        }
 
        // Send the packet if it's the final fragment
-       if(!(hdr->ctl & MF)) {
+       if(!(hdr->ctl & MF) && c->recv) {
                buffer_call(&c->rcvbuf, c->recv, c, 0, hdr->wnd + len);
        }
 
@@ -1594,7 +1599,10 @@ synack:
 
                if(data_acked) {
                        buffer_discard(&c->sndbuf, data_acked);
-                       c->do_poll = true;
+
+                       if(is_reliable(c)) {
+                               c->do_poll = true;
+                       }
                }
 
                // Also advance snd.nxt if possible
@@ -2318,7 +2326,7 @@ void utcp_set_sndbuf(struct utcp_connection *c, size_t size) {
                c->sndbuf.maxsize = -1;
        }
 
-       c->do_poll = buffer_free(&c->sndbuf);
+       c->do_poll = is_reliable(c) && buffer_free(&c->sndbuf);
 }
 
 size_t utcp_get_rcvbuf(struct utcp_connection *c) {
@@ -2386,7 +2394,7 @@ void utcp_set_recv_cb(struct utcp_connection *c, utcp_recv_t recv) {
 void utcp_set_poll_cb(struct utcp_connection *c, utcp_poll_t poll) {
        if(c) {
                c->poll = poll;
-               c->do_poll = buffer_free(&c->sndbuf);
+               c->do_poll = is_reliable(c) && buffer_free(&c->sndbuf);
        }
 }