As U-Boot has no ARP table support, these
connections must come from the same host.

Signed-off-by: Mikhail Kshevetskiy <[email protected]>
---
 net/Kconfig        | 10 ++++++++
 net/fastboot_tcp.c | 12 +++++++++
 net/httpd.c        | 20 ++++++++++++++-
 net/netcat.c       |  7 +++++
 net/tcp.c          | 64 ++++++++++++++++++++++++++++++----------------
 net/wget.c         |  7 +++++
 6 files changed, 97 insertions(+), 23 deletions(-)

diff --git a/net/Kconfig b/net/Kconfig
index 424c5f0dae8..e46d519f507 100644
--- a/net/Kconfig
+++ b/net/Kconfig
@@ -257,6 +257,16 @@ config HTTPD_COMMON
            * object will be stored to a memory area specified in
              image_load_addr variable
 
+config PROT_TCP_MAX_CONNS
+       int "Maximum number of TCP connections"
+       depends on PROT_TCP
+       default 1
+       help
+         In some cases (like httpd support) it may be desirable
+         to support more than one TCP connection at a time.
+         As there is no ARP table support, these connections MUST
+         come from a single host.
+
 config IPV6
        bool "IPv6 support"
        help
diff --git a/net/fastboot_tcp.c b/net/fastboot_tcp.c
index 30eacca8d1e..dfa02709393 100644
--- a/net/fastboot_tcp.c
+++ b/net/fastboot_tcp.c
@@ -19,6 +19,12 @@ static char txbuf[sizeof(u64) + FASTBOOT_RESPONSE_LEN + 1];
 
 static u32 data_read;
 static u32 tx_last_offs, tx_last_len;
+static int connections;
+
+static void tcp_stream_on_closed(struct tcp_stream *tcp)
+{
+       connections = 0;
+}
 
 static void tcp_stream_on_rcv_nxt_update(struct tcp_stream *tcp, u32 rx_bytes)
 {
@@ -96,10 +102,16 @@ static int tcp_stream_on_create(struct tcp_stream *tcp)
        if (tcp->lport != FASTBOOT_TCP_PORT)
                return 0;
 
+       if (connections > 0)
+               return 0;
+
+       connections++;
+
        data_read = 0;
        tx_last_offs = 0;
        tx_last_len = 0;
 
+       tcp->on_closed = tcp_stream_on_closed;
        tcp->on_rcv_nxt_update = tcp_stream_on_rcv_nxt_update;
        tcp->rx = tcp_stream_rx;
        tcp->tx = tcp_stream_tx;
diff --git a/net/httpd.c b/net/httpd.c
index 31c10843a44..ad13b155912 100644
--- a/net/httpd.c
+++ b/net/httpd.c
@@ -85,6 +85,8 @@ static struct http_reply options_reply = {
 
 static int stop_server;
 static int tsize_num_hash;
+static int connections;
+static struct httpd_priv *post_id;
 
 static struct httpd_config *cfg;
 
@@ -109,6 +111,12 @@ static void tcp_stream_on_closed(struct tcp_stream *tcp)
 {
        struct httpd_priv *priv = tcp->priv;
 
+       connections--;
+       if (tcp->priv == post_id) {
+               /* forget completed POST */
+               post_id = NULL;
+       }
+
        if ((priv->req_state != ST_REQ_DONE) &&
            (priv->req_state >= ST_REQ_MPFILE)) {
                printf("\nHTTPD: transfer was terminated\n");
@@ -119,7 +127,7 @@ static void tcp_stream_on_closed(struct tcp_stream *tcp)
 
        free(tcp->priv);
 
-       if (stop_server)
+       if (stop_server && (connections == 0))
                net_set_state(cfg->on_stop != NULL ?
                              cfg->on_stop() :
                              NETLOOP_SUCCESS);
@@ -334,6 +342,11 @@ static enum httpd_req_check http_parse_line(struct httpd_priv *priv, char *line)
                        /* expect: "\r\n--${boundary}--\r\n", so strlen() + 8 */
                        priv->post_flen -= strlen(priv->post_boundary) + 8;
 
+                       if (post_id != NULL) {
+                               /* do not allow multiple POST requests */
+                               return HTTPD_BAD_REQ;
+                       }
+
                        if (cfg->pre_post != NULL) {
                                post.addr     = NULL;
                                post.name     = priv->post_name;
@@ -345,6 +358,8 @@ static enum httpd_req_check http_parse_line(struct httpd_priv *priv, char *line)
                                        return ret;
                        }
 
+                       post_id = priv;
+
                        tsize_num_hash = 0;
                        printf("File: %s, %u bytes\n", priv->post_fname, priv->post_flen);
                        printf("Loading: ");
@@ -660,6 +675,7 @@ static int tcp_stream_on_create(struct tcp_stream *tcp)
        if (priv == NULL)
                return 0;
 
+       connections++;
        memset(priv, 0, sizeof(struct httpd_priv));
        priv->tcp = tcp;
 
@@ -688,6 +704,8 @@ void httpd_start(void)
                net_set_state(NETLOOP_FAIL);
                return;
        }
+       post_id = NULL;
+       connections = 0;
        stop_server = 0;
        memset(net_server_ethaddr, 0, 6);
        tcp_stream_set_on_create_handler(tcp_stream_on_create);
diff --git a/net/netcat.c b/net/netcat.c
index ea225c68c87..cfd39fafb61 100644
--- a/net/netcat.c
+++ b/net/netcat.c
@@ -23,6 +23,7 @@ static int                    listen;
 static int                     reading;
 static unsigned int            packets;
 static enum net_loop_state     netcat_loop_state;
+static int                     connections;
 
 static void show_block_marker(void)
 {
@@ -34,6 +35,8 @@ static void show_block_marker(void)
 
 static void tcp_stream_on_closed(struct tcp_stream *tcp)
 {
+       connections = 0;
+
        if (tcp->status != TCP_ERR_OK)
                netcat_loop_state = NETLOOP_FAIL;
 
@@ -101,6 +104,10 @@ static int tcp_stream_on_create(struct tcp_stream *tcp)
                        return 0;
        }
 
+       if (connections > 0)
+               return 0;
+
+       connections++;
        netcat_loop_state = NETLOOP_FAIL;
        net_boot_file_size = 0;
        packets = 0;
diff --git a/net/tcp.c b/net/tcp.c
index 9fb80f9c2a8..c0045e32198 100644
--- a/net/tcp.c
+++ b/net/tcp.c
@@ -25,6 +25,7 @@
 #include <net.h>
 #include <net/tcp.h>
 
+#define TCP_STREAM_MAX         (CONFIG_PROT_TCP_MAX_CONNS)
 #define TCP_SEND_RETRY         3
 #define TCP_SEND_TIMEOUT       2000UL
 #define TCP_RX_INACTIVE_TIMEOUT        30000UL
@@ -34,7 +35,7 @@
 #define TCP_PACKET_OK          0
 #define TCP_PACKET_DROP                1
 
-static struct tcp_stream tcp_stream;
+static struct tcp_stream tcp_streams[TCP_STREAM_MAX];
 
 static int (*tcp_stream_on_create)(struct tcp_stream *tcp);
 
@@ -155,17 +156,24 @@ static void tcp_stream_destroy(struct tcp_stream *tcp)
 void tcp_init(void)
 {
        static int initialized;
-       struct tcp_stream *tcp = &tcp_stream;
+       struct tcp_stream *tcp;
+       int i;
 
        tcp_stream_on_create = NULL;
        if (!initialized) {
                initialized = 1;
-               memset(tcp, 0, sizeof(struct tcp_stream));
+               for (i = 0; i < TCP_STREAM_MAX; i++) {
+                       tcp = &tcp_streams[i];
+                       memset(tcp, 0, sizeof(struct tcp_stream));
+               }
        }
 
-       tcp_stream_set_state(tcp, TCP_CLOSED);
-       tcp_stream_set_status(tcp, TCP_ERR_RST);
-       tcp_stream_destroy(tcp);
+       for (i = 0; i < TCP_STREAM_MAX; i++) {
+               tcp = &tcp_streams[i];
+               tcp_stream_set_state(tcp, TCP_CLOSED);
+               tcp_stream_set_status(tcp, TCP_ERR_RST);
+               tcp_stream_destroy(tcp);
+       }
 }
 
 void tcp_stream_set_on_create_handler(int (*on_create)(struct tcp_stream *))
@@ -176,28 +184,40 @@ void tcp_stream_set_on_create_handler(int (*on_create)(struct tcp_stream *))
 static struct tcp_stream *tcp_stream_add(struct in_addr rhost,
                                         u16 rport, u16 lport)
 {
-       struct tcp_stream *tcp = &tcp_stream;
+       int i;
+       struct tcp_stream *tcp;
 
-       if ((tcp_stream_on_create == NULL) ||
-           (tcp->state != TCP_CLOSED))
+       if (tcp_stream_on_create == NULL)
                return NULL;
 
-       tcp_stream_init(tcp, rhost, rport, lport);
-       if (!tcp_stream_on_create(tcp))
-               return NULL;
+       for (i = 0; i < TCP_STREAM_MAX; i++) {
+               tcp = &tcp_streams[i];
+               if (tcp->state != TCP_CLOSED)
+                       continue;
 
-       return tcp;
+               tcp_stream_init(tcp, rhost, rport, lport);
+               if (!tcp_stream_on_create(tcp))
+                       return NULL;
+
+               return tcp;
+       }
+       return NULL;
 }
 
 struct tcp_stream *tcp_stream_get(int is_new, struct in_addr rhost,
                                  u16 rport, u16 lport)
 {
-       struct tcp_stream *tcp = &tcp_stream;
+       int i;
+       struct tcp_stream *tcp;
 
-       if ((tcp->rhost.s_addr == rhost.s_addr) &&
-           (tcp->rport == rport) &&
-           (tcp->lport == lport))
-               return tcp;
+       for (i = 0; i < TCP_STREAM_MAX; i++) {
+               tcp = &tcp_streams[i];
+
+               if ((tcp->rhost.s_addr == rhost.s_addr) &&
+                   (tcp->rport == rport) &&
+                   (tcp->lport == lport))
+                       return tcp;
+       }
 
        return is_new ? tcp_stream_add(rhost, rport, lport) : NULL;
 }
@@ -390,12 +410,12 @@ static void tcp_stream_poll(struct tcp_stream *tcp, ulong time)
 
 void tcp_streams_poll(void)
 {
-       ulong                   time;
-       struct tcp_stream       *tcp;
+       int     i;
+       ulong   time;
 
        time = get_timer(0);
-       tcp = &tcp_stream;
-       tcp_stream_poll(tcp, time);
+       for (i = 0; i < TCP_STREAM_MAX; i++)
+               tcp_stream_poll(&tcp_streams[i], time);
 }
 
 /**
diff --git a/net/wget.c b/net/wget.c
index c91ca0bbc90..500cb58d777 100644
--- a/net/wget.c
+++ b/net/wget.c
@@ -41,6 +41,7 @@ static int wget_tsize_num_hash;
 
 static char *image_url;
 static enum net_loop_state wget_loop_state;
+static int connections;
 
 static ulong wget_load_size;
 
@@ -119,6 +120,8 @@ static void show_block_marker(void)
 
 static void tcp_stream_on_closed(struct tcp_stream *tcp)
 {
+       connections = 0;
+
        if (tcp->status != TCP_ERR_OK)
                wget_loop_state = NETLOOP_FAIL;
 
@@ -223,6 +226,10 @@ static int tcp_stream_on_create(struct tcp_stream *tcp)
            (tcp->rport != server_port))
                return 0;
 
+       if (connections > 0)
+               return 0;
+
+       connections++;
        tcp->max_retry_count = WGET_RETRY_COUNT;
        tcp->initial_timeout = WGET_TIMEOUT;
        tcp->on_closed = tcp_stream_on_closed;
-- 
2.43.0

Reply via email to