]> git.meshlink.io Git - utcp/blobdiff - utcp.c
Return an error if we can't allocate the packet buffer during utcp_init().
[utcp] / utcp.c
diff --git a/utcp.c b/utcp.c
index fd507f91abecc9d674ab99c0cb47ef1d94b1f96e..5eca4d3897a82eec6bbc1bca6de33c324fb73daa 100644 (file)
--- a/utcp.c
+++ b/utcp.c
@@ -401,7 +401,7 @@ static void buffer_exit(struct buffer *buf) {
 }
 
 static uint32_t buffer_free(const struct buffer *buf) {
-       return buf->maxsize - buf->used;
+       return buf->maxsize > buf->used ? buf->maxsize - buf->used : 0;
 }
 
 // Connections are stored in a sorted list.
@@ -816,7 +816,6 @@ ssize_t utcp_send(struct utcp_connection *c, const void *data, size_t len) {
        if(!is_reliable(c)) {
                c->snd.una = c->snd.nxt = c->snd.last;
                buffer_discard(&c->sndbuf, c->sndbuf.used);
-               c->do_poll = true;
        }
 
        if(is_reliable(c) && !timespec_isset(&c->rtrx_timeout)) {
@@ -848,13 +847,7 @@ static void fast_retransmit(struct utcp_connection *c) {
        struct {
                struct hdr hdr;
                uint8_t data[];
-       } *pkt;
-
-       pkt = malloc(c->utcp->mtu);
-
-       if(!pkt) {
-               return;
-       }
+       } *pkt = c->utcp->pkt;
 
        pkt->hdr.src = c->src;
        pkt->hdr.dst = c->dst;
@@ -886,8 +879,6 @@ static void fast_retransmit(struct utcp_connection *c) {
        default:
                break;
        }
-
-       free(pkt);
 }
 
 static void retransmit(struct utcp_connection *c) {
@@ -899,6 +890,10 @@ static void retransmit(struct utcp_connection *c) {
 
        struct utcp *utcp = c->utcp;
 
+       if (utcp->retransmit) {
+               utcp->retransmit(c);
+       }
+
        struct {
                struct hdr hdr;
                uint8_t data[];
@@ -1046,8 +1041,14 @@ static void handle_out_of_order(struct utcp_connection *c, uint32_t offset, cons
        // Packet loss or reordering occured. Store the data in the buffer.
        ssize_t rxd = buffer_put_at(&c->rcvbuf, offset, data, len);
 
-       if(rxd < 0 || (size_t)rxd < len) {
-               abort();
+       if(rxd <= 0) {
+               debug(c, "packet outside receive buffer, dropping\n");
+               return;
+       }
+
+       if((size_t)rxd < len) {
+               debug(c, "packet partially outside receive buffer\n");
+               len = rxd;
        }
 
        // Make note of where we put it.
@@ -1108,11 +1109,14 @@ static void handle_in_order(struct utcp_connection *c, const void *data, size_t
                        size_t offset = len;
                        len = c->sacks[0].offset + c->sacks[0].len;
                        size_t remainder = len - offset;
-                       ssize_t rxd = buffer_call(&c->rcvbuf, c->recv, c, offset, remainder);
 
-                       if(rxd != (ssize_t)remainder) {
-                               // TODO: handle the application not accepting all data.
-                               abort();
+                       if(c->recv) {
+                               ssize_t rxd = buffer_call(&c->rcvbuf, c->recv, c, offset, remainder);
+
+                               if(rxd != (ssize_t)remainder) {
+                                       // TODO: handle the application not accepting all data.
+                                       abort();
+                               }
                        }
                }
        }
@@ -1127,7 +1131,10 @@ static void handle_in_order(struct utcp_connection *c, const void *data, size_t
 static void handle_unreliable(struct utcp_connection *c, const struct hdr *hdr, const void *data, size_t len) {
        // Fast path for unfragmented packets
        if(!hdr->wnd && !(hdr->ctl & MF)) {
-               c->recv(c, data, len);
+               if(c->recv) {
+                       c->recv(c, data, len);
+               }
+
                c->rcv.nxt = hdr->seq + len;
                return;
        }
@@ -1154,7 +1161,7 @@ static void handle_unreliable(struct utcp_connection *c, const struct hdr *hdr,
        }
 
        // Send the packet if it's the final fragment
-       if(!(hdr->ctl & MF)) {
+       if(!(hdr->ctl & MF) && c->recv) {
                buffer_call(&c->rcvbuf, c->recv, c, 0, hdr->wnd + len);
        }
 
@@ -1169,10 +1176,6 @@ static void handle_incoming_data(struct utcp_connection *c, const struct hdr *hd
 
        uint32_t offset = seqdiff(hdr->seq, c->rcv.nxt);
 
-       if(offset + len > c->rcvbuf.maxsize) {
-               abort();
-       }
-
        if(offset) {
                handle_out_of_order(c, offset, data, len);
        } else {
@@ -1592,7 +1595,10 @@ synack:
 
                if(data_acked) {
                        buffer_discard(&c->sndbuf, data_acked);
-                       c->do_poll = true;
+
+                       if(is_reliable(c)) {
+                               c->do_poll = true;
+                       }
                }
 
                // Also advance snd.nxt if possible
@@ -2163,6 +2169,13 @@ struct utcp *utcp_init(utcp_accept_t accept, utcp_pre_accept_t pre_accept, utcp_
                return NULL;
        }
 
+       utcp_set_mtu(utcp, DEFAULT_MTU);
+
+       if(!utcp->pkt) {
+               free(utcp);
+               return NULL;
+       }
+
        if(!CLOCK_GRANULARITY) {
                struct timespec res;
                clock_getres(UTCP_CLOCK, &res);
@@ -2173,7 +2186,6 @@ struct utcp *utcp_init(utcp_accept_t accept, utcp_pre_accept_t pre_accept, utcp_
        utcp->pre_accept = pre_accept;
        utcp->send = send;
        utcp->priv = priv;
-       utcp_set_mtu(utcp, DEFAULT_MTU);
        utcp->timeout = DEFAULT_USER_TIMEOUT; // sec
 
        return utcp;
@@ -2316,7 +2328,7 @@ void utcp_set_sndbuf(struct utcp_connection *c, size_t size) {
                c->sndbuf.maxsize = -1;
        }
 
-       c->do_poll = buffer_free(&c->sndbuf);
+       c->do_poll = is_reliable(c) && buffer_free(&c->sndbuf);
 }
 
 size_t utcp_get_rcvbuf(struct utcp_connection *c) {
@@ -2384,7 +2396,7 @@ void utcp_set_recv_cb(struct utcp_connection *c, utcp_recv_t recv) {
 void utcp_set_poll_cb(struct utcp_connection *c, utcp_poll_t poll) {
        if(c) {
                c->poll = poll;
-               c->do_poll = buffer_free(&c->sndbuf);
+               c->do_poll = is_reliable(c) && buffer_free(&c->sndbuf);
        }
 }
 
@@ -2445,6 +2457,10 @@ void utcp_offline(struct utcp *utcp, bool offline) {
        }
 }
 
+void utcp_set_retransmit_cb(struct utcp *utcp, utcp_retransmit_t retransmit) {
+       utcp->retransmit = retransmit;
+}
+
 void utcp_set_clock_granularity(long granularity) {
        CLOCK_GRANULARITY = granularity;
 }