]> git.meshlink.io Git - utcp/blobdiff - utcp.c
Implement slow start threshold according to RFC 5681.
[utcp] / utcp.c
diff --git a/utcp.c b/utcp.c
index f35f96560af29eaafcf67a95030884414d8e3cfd..7aeb2d358cf7b3157672f8bb2f25b488b9a40606 100644 (file)
--- a/utcp.c
+++ b/utcp.c
        } while (0)
 #endif
 
+static inline size_t min(size_t a, size_t b) {
+       return a < b ? a : b;
+}
+
 static inline size_t max(size_t a, size_t b) {
        return a > b ? a : b;
 }
@@ -113,9 +117,14 @@ static void print_packet(struct utcp *utcp, const char *dir, const void *pkt, si
 
        debug("\n");
 }
+
+static void debug_cwnd(struct utcp_connection *c) {
+       debug("snd.cwnd = %u\n", c->snd.cwnd);
+}
 #else
 #define debug(...) do {} while(0)
 #define print_packet(...) do {} while(0)
+#define debug_cwnd(...) do {} while(0)
 #endif
 
 static void set_state(struct utcp_connection *c, enum state state) {
@@ -382,9 +391,10 @@ static struct utcp_connection *allocate_connection(struct utcp *utcp, uint16_t s
 #endif
        c->snd.una = c->snd.iss;
        c->snd.nxt = c->snd.iss + 1;
-       c->rcv.wnd = utcp->mtu;
        c->snd.last = c->snd.nxt;
-       c->snd.cwnd = utcp->mtu;
+       c->snd.cwnd = (utcp->mtu > 2190 ? 2 : utcp->mtu > 1095 ? 3 : 4) * utcp->mtu;
+       c->snd.ssthresh = ~0;
+       debug_cwnd(c);
        c->utcp = utcp;
 
        // Add it to the sorted list of connections
@@ -415,13 +425,13 @@ static void update_rtt(struct utcp_connection *c, uint32_t rtt) {
        if(!utcp->srtt) {
                utcp->srtt = rtt;
                utcp->rttvar = rtt / 2;
-               utcp->rto = rtt + max(2 * rtt, CLOCK_GRANULARITY);
        } else {
                utcp->rttvar = (utcp->rttvar * 3 + absdiff(utcp->srtt, rtt)) / 4;
                utcp->srtt = (utcp->srtt * 7 + rtt) / 8;
-               utcp->rto = utcp->srtt + max(utcp->rttvar, CLOCK_GRANULARITY);
        }
 
+       utcp->rto = utcp->srtt + max(4 * utcp->rttvar, CLOCK_GRANULARITY);
+
        if(utcp->rto > MAX_RTO) {
                utcp->rto = MAX_RTO;
        }
@@ -468,7 +478,7 @@ struct utcp_connection *utcp_connect_ex(struct utcp *utcp, uint16_t dst, utcp_re
        pkt.hdr.dst = c->dst;
        pkt.hdr.seq = c->snd.iss;
        pkt.hdr.ack = 0;
-       pkt.hdr.wnd = c->rcv.wnd;
+       pkt.hdr.wnd = c->rcvbuf.maxsize;
        pkt.hdr.ctl = SYN;
        pkt.hdr.aux = 0x0101;
        pkt.init[0] = 1;
@@ -507,19 +517,18 @@ void utcp_accept(struct utcp_connection *c, utcp_recv_t recv, void *priv) {
 
 static void ack(struct utcp_connection *c, bool sendatleastone) {
        int32_t left = seqdiff(c->snd.last, c->snd.nxt);
-       int32_t cwndleft = c->snd.cwnd - seqdiff(c->snd.nxt, c->snd.una);
-       debug("cwndleft = %d\n", cwndleft);
+       int32_t cwndleft = min(c->snd.cwnd, c->snd.wnd) - seqdiff(c->snd.nxt, c->snd.una);
 
        assert(left >= 0);
 
-       if(cwndleft <= 0) {
-               cwndleft = 0;
-       }
-
-       if(cwndleft < left) {
+       if(cwndleft < 0) {
+               left = 0;
+       } else if(cwndleft < left) {
                left = cwndleft;
        }
 
+       debug("cwndleft = %d, left = %d\n", cwndleft, left);
+
        if(!left && !sendatleastone) {
                return;
        }
@@ -538,7 +547,7 @@ static void ack(struct utcp_connection *c, bool sendatleastone) {
        pkt->hdr.src = c->src;
        pkt->hdr.dst = c->dst;
        pkt->hdr.ack = c->rcv.nxt;
-       pkt->hdr.wnd = c->snd.wnd;
+       pkt->hdr.wnd = c->rcvbuf.maxsize;
        pkt->hdr.ctl = ACK;
        pkt->hdr.aux = 0;
 
@@ -629,6 +638,8 @@ ssize_t utcp_send(struct utcp_connection *c, const void *data, size_t len) {
 
        if(is_reliable(c) || (c->state != SYN_SENT && c->state != SYN_RECEIVED)) {
                len = buffer_put(&c->sndbuf, data, len);
+       } else {
+               return 0;
        }
 
        if(len <= 0) {
@@ -695,7 +706,7 @@ static void retransmit(struct utcp_connection *c) {
 
        pkt->hdr.src = c->src;
        pkt->hdr.dst = c->dst;
-       pkt->hdr.wnd = c->rcv.wnd;
+       pkt->hdr.wnd = c->rcvbuf.maxsize;
        pkt->hdr.aux = 0;
 
        switch(c->state) {
@@ -743,7 +754,12 @@ static void retransmit(struct utcp_connection *c) {
                }
 
                c->snd.nxt = c->snd.una + len;
-               c->snd.cwnd = utcp->mtu; // reduce cwnd on retransmit
+
+               // RFC 5681 slow start after timeout
+               c->snd.ssthresh = max(c->snd.cwnd / 2, utcp->mtu * 2); // eq. 4
+               c->snd.cwnd = utcp->mtu;
+               debug_cwnd(c);
+
                buffer_copy(&c->sndbuf, pkt->data, 0, len);
                print_packet(c->utcp, "rtrx", pkt, sizeof(pkt->hdr) + len);
                utcp->send(utcp, pkt, sizeof(pkt->hdr) + len);
@@ -1071,7 +1087,7 @@ synack:
                        pkt.hdr.dst = c->dst;
                        pkt.hdr.ack = c->rcv.irs + 1;
                        pkt.hdr.seq = c->snd.iss;
-                       pkt.hdr.wnd = c->rcv.wnd;
+                       pkt.hdr.wnd = c->rcvbuf.maxsize;
                        pkt.hdr.ctl = SYN | ACK;
 
                        if(init) {
@@ -1314,8 +1330,10 @@ synack:
 
                assert(data_acked >= 0);
 
+#ifndef NDEBUG
                int32_t bufused = seqdiff(c->snd.last, c->snd.una);
                assert(data_acked <= bufused);
+#endif
 
                if(data_acked) {
                        buffer_get(&c->sndbuf, NULL, data_acked);
@@ -1329,12 +1347,20 @@ synack:
                c->snd.una = hdr.ack;
 
                c->dupack = 0;
-               c->snd.cwnd += utcp->mtu;
+
+               // Increase the congestion window according to RFC 5681
+               if(c->snd.cwnd < c->snd.ssthresh) {
+                       c->snd.cwnd += min(advanced, utcp->mtu); // eq. 2
+               } else {
+                       c->snd.cwnd += max(1, (utcp->mtu * utcp->mtu) / c->snd.cwnd); // eq. 3
+               }
 
                if(c->snd.cwnd > c->sndbuf.maxsize) {
                        c->snd.cwnd = c->sndbuf.maxsize;
                }
 
+               debug_cwnd(c);
+
                // Check if we have sent a FIN that is now ACKed.
                switch(c->state) {
                case FIN_WAIT_1:
@@ -1367,6 +1393,7 @@ synack:
                                //Reset the congestion window so we wait for ACKs.
                                c->snd.nxt = c->snd.una;
                                c->snd.cwnd = utcp->mtu;
+                               debug_cwnd(c);
                                start_retransmit_timer(c);
                        }
                }