Signed-off-by: Yanis Kurganov <YKurganov@ptsecurity.com>

---
 include/libssh/dh.h |  2 ++
 src/dh.c            | 30 ++++++++++++++-------
 src/packet.c        |  6 ++++-
 src/server.c        | 77 ++++++++++++++++++++++++++++++++++++++++++++++++-----
 4 files changed, 98 insertions(+), 17 deletions(-)

diff --git a/include/libssh/dh.h b/include/libssh/dh.h
index a579a3d..e65ce42 100644
--- a/include/libssh/dh.h
+++ b/include/libssh/dh.h
@@ -26,6 +26,8 @@
 #include "libssh/crypto.h"
 
 void ssh_print_bignum(const char *which,bignum num);
+int dh_generate_p(ssh_session session);
+int dh_generate_g(ssh_session session);
 int dh_generate_e(ssh_session session);
 int dh_generate_f(ssh_session session);
 int dh_generate_x(ssh_session session);
diff --git a/src/dh.c b/src/dh.c
index 0beec39..cb606d5 100644
--- a/src/dh.c
+++ b/src/dh.c
@@ -389,7 +389,7 @@ int dh_import_e(ssh_session session, ssh_string e_string) {
 }
 
 /* p number */
-static int dh_import_p(ssh_session session, ssh_string p_string) {
+static int dh_import_p_string(ssh_session session, ssh_string p_string) {
   session->next_crypto->p = make_string_bn(p_string);
   if (session->next_crypto->p == NULL) {
     return SSH_ERROR;
@@ -402,7 +402,7 @@ static int dh_import_p(ssh_session session, ssh_string p_string) {
   return SSH_OK;
 }
 
-static int dh_generate_p(ssh_session session, const unsigned char* p_value, size_t p_size) {
+static int dh_import_p_value(ssh_session session, const unsigned char* p_value, size_t p_size) {
   ssh_string p_string = ssh_string_new(p_size);
   int rc;
 
@@ -411,7 +411,7 @@ static int dh_generate_p(ssh_session session, const unsigned char* p_value, size
   }
 
   ssh_string_fill(p_string, p_value, p_size);
-  rc = dh_import_p(session, p_string);
+  rc = dh_import_p_string(session, p_string);
 
   ssh_string_burn(p_string);
   ssh_string_free(p_string);
@@ -419,8 +419,14 @@ static int dh_generate_p(ssh_session session, const unsigned char* p_value, size
   return rc;
 }
 
+/* used by server */
+int dh_generate_p(ssh_session session)
+{
+  return dh_import_p_value(session, p_group14_value, P_GROUP14_LEN);
+}
+
 /* g number */
-static int dh_import_g(ssh_session session, ssh_string g_string) {
+static int dh_import_g_string(ssh_session session, ssh_string g_string) {
   session->next_crypto->g = make_string_bn(g_string);
   if (session->next_crypto->g == NULL) {
     return SSH_ERROR;
@@ -433,7 +439,7 @@ static int dh_import_g(ssh_session session, ssh_string g_string) {
   return SSH_OK;
 }
 
-static int dh_generate_g(ssh_session session, unsigned int g_value) {
+static int dh_import_g_value(ssh_session session, unsigned int g_value) {
   session->next_crypto->g = bignum_new();
   if (session->next_crypto->g == NULL) {
     return SSH_ERROR;
@@ -447,6 +453,12 @@ static int dh_generate_g(ssh_session session, unsigned int g_value) {
   return SSH_OK;
 }
 
+/* used by server */
+int dh_generate_g(ssh_session session)
+{
+  return dh_import_g_value(session, G_VALUE);
+}
+
 int dh_build_k(ssh_session session) {
 #ifdef HAVE_LIBCRYPTO
   bignum_CTX ctx = bignum_ctx_new();
@@ -583,13 +595,13 @@ int ssh_client_dh_group_init(ssh_session session){
     ssh_set_error(session, SSH_FATAL, "Cannot init pbits");
     return SSH_ERROR;
   }
-  if(dh_generate_p(session,
+  if(dh_import_p_value(session,
      session->next_crypto->kex_type == SSH_KEX_DH_GROUP1_SHA1 ? p_group1_value : p_group14_value,
      session->next_crypto->kex_type == SSH_KEX_DH_GROUP1_SHA1 ? P_GROUP1_LEN : P_GROUP14_LEN) < 0) {
     ssh_set_error(session, SSH_FATAL, "Cannot import p number");
     return SSH_ERROR;
   }
-  if(dh_generate_g(session, G_VALUE) < 0) {
+  if(dh_import_g_value(session, G_VALUE) < 0) {
     ssh_set_error(session, SSH_FATAL, "Cannot import g number");
     return SSH_ERROR;
   }
@@ -622,7 +634,7 @@ int ssh_client_dh_gex_reply(ssh_session session, ssh_buffer packet){
     ssh_set_error(session,SSH_FATAL, "No p number in packet");
     return SSH_ERROR;
   }
-  rc = dh_import_p(session, s);
+  rc = dh_import_p_string(session, s);
   ssh_string_burn(s);
   ssh_string_free(s);
   if (rc < 0) {
@@ -635,7 +647,7 @@ int ssh_client_dh_gex_reply(ssh_session session, ssh_buffer packet){
     ssh_set_error(session,SSH_FATAL, "No g number in packet");
     return SSH_ERROR;
   }
-  rc = dh_import_g(session, s);
+  rc = dh_import_g_string(session, s);
   ssh_string_burn(s);
   ssh_string_free(s);
   if (rc < 0) {
diff --git a/src/packet.c b/src/packet.c
index 4296a74..26e8965 100644
--- a/src/packet.c
+++ b/src/packet.c
@@ -75,7 +75,11 @@ static ssh_packet_callback default_packet_handlers[]= {
 #endif
   ssh_packet_dh_reply,                     // SSH2_MSG_KEXDH_REPLY                31
                                            // SSH2_MSG_KEX_DH_GEX_GROUP           31
-  NULL,                                    // SSH2_MSG_KEX_DH_GEX_INIT            32
+#if WITH_SERVER
+  ssh_packet_kexdh_init,                   // SSH2_MSG_KEX_DH_GEX_INIT            32
+#else
+  NULL,
+#endif
   ssh_packet_dh_reply,                     // SSH2_MSG_KEX_DH_GEX_REPLY           33
   NULL,                                    // SSH2_MSG_KEX_DH_GEX_REQUEST         34
   NULL, NULL, NULL, NULL, NULL, NULL,	NULL,
diff --git a/src/server.c b/src/server.c
index f910ebb..6c01ceb 100644
--- a/src/server.c
+++ b/src/server.c
@@ -65,7 +65,8 @@
             session->common.callbacks->connect_status_function(session->common.callbacks->userdata, status); \
     } while (0)
 
-static int dh_handshake_server(ssh_session session);
+static int dh_send_group_server(ssh_session session);
+static int dh_handshake_server(ssh_session session, uint8_t reply_type);
 
 
 /**
@@ -146,7 +147,7 @@ static int server_set_kex(ssh_session session) {
  * @brief parse an incoming SSH_MSG_KEXDH_INIT packet and complete
  *        key exchange
  **/
-static int ssh_server_kexdh_init(ssh_session session, ssh_buffer packet){
+static int ssh_server_kexdh_init(ssh_session session, ssh_buffer packet, uint8_t reply_type){
     ssh_string e;
     e = buffer_get_ssh_string(packet);
     if (e == NULL) {
@@ -158,26 +159,59 @@ static int ssh_server_kexdh_init(ssh_session session, ssh_buffer packet){
       session->session_state=SSH_SESSION_STATE_ERROR;
     } else {
       session->dh_handshake_state=DH_STATE_INIT_SENT;
-      dh_handshake_server(session);
+      dh_handshake_server(session, reply_type);
     }
     ssh_string_free(e);
     return SSH_OK;
 }
 
+/** @internal
+ * @brief parse an incoming SSH2_MSG_KEX_DH_GEX_REQUEST_OLD packet
+ **/
+static int ssh_server_kexdh_gex_init(ssh_session session, ssh_buffer packet){
+    uint32_t pbits;
+    if (buffer_get_u32(session->in_buffer, &pbits) != sizeof(uint32_t)) {
+      ssh_set_error(session, SSH_FATAL, "No pbits in client request");
+      return SSH_ERROR;
+    }
+    session->next_crypto->pbits = ntohl(pbits);
+    SSH_LOG(SSH_LOG_PACKET,"pbits = %u", session->next_crypto->pbits);
+    if (dh_generate_p(session) < 0) {
+      ssh_set_error(session, SSH_FATAL, "Could not create p number");
+      return SSH_ERROR;
+    }
+    if (dh_generate_g(session) < 0) {
+      ssh_set_error(session, SSH_FATAL, "Could not create g number");
+      return SSH_ERROR;
+    }
+    session->dh_handshake_state=DH_STATE_GEX_REQUEST_SENT;
+    return dh_send_group_server(session);
+}
+
 SSH_PACKET_CALLBACK(ssh_packet_kexdh_init){
   int rc;
   (void)type;
   (void)user;
 
   SSH_LOG(SSH_LOG_PACKET,"Received SSH_MSG_KEXDH_INIT");
-  if(session->dh_handshake_state != DH_STATE_INIT){
+  if(session->dh_handshake_state != DH_STATE_INIT &&
+     session->dh_handshake_state != DH_STATE_GEX_REQUEST_SENT){
     SSH_LOG(SSH_LOG_RARE,"Invalid state for SSH_MSG_KEXDH_INIT");
     goto error;
   }
   switch(session->next_crypto->kex_type){
       case SSH_KEX_DH_GROUP1_SHA1:
       case SSH_KEX_DH_GROUP14_SHA1:
-        rc=ssh_server_kexdh_init(session, packet);
+        rc=ssh_server_kexdh_init(session, packet, SSH2_MSG_KEXDH_REPLY);
+        break;
+      case SSH_KEX_DH_GROUP_SHA1:
+      case SSH_KEX_DH_GROUP_SHA256:
+        if (session->dh_handshake_state == DH_STATE_INIT) {
+          SSH_LOG(SSH_LOG_PACKET,"Actually it's SSH2_MSG_KEX_DH_GEX_REQUEST_OLD");
+          rc=ssh_server_kexdh_gex_init(session, packet);
+        } else {
+          rc=ssh_server_kexdh_init(session, packet, SSH2_MSG_KEX_DH_GEX_REPLY);
+        }
         break;
   #ifdef HAVE_ECDH
       case SSH_KEX_ECDH_SHA2_NISTP256:
@@ -239,7 +273,36 @@ int ssh_get_key_params(ssh_session session, ssh_key *privkey){
     return SSH_OK;
 }
 
-static int dh_handshake_server(ssh_session session) {
+static int dh_send_group_server(ssh_session session)
+{
+  ssh_string num;
+  int rc;
+
+  if (buffer_add_u8(session->out_buffer, SSH2_MSG_KEX_DH_GEX_GROUP) < 0)
+    return SSH_ERROR;
+
+  num = make_bignum_string(session->next_crypto->p);
+  if (num == NULL)
+    return SSH_ERROR;
+
+  rc = buffer_add_ssh_string(session->out_buffer, num);
+  ssh_string_free(num);
+  if (rc < 0)
+    return SSH_ERROR;
+
+  num = make_bignum_string(session->next_crypto->g);
+  if (num == NULL)
+    return SSH_ERROR;
+
+  rc = buffer_add_ssh_string(session->out_buffer, num);
+  ssh_string_free(num);
+  if (rc < 0)
+    return SSH_ERROR;
+
+  return packet_send(session);
+}
+
+static int dh_handshake_server(ssh_session session, uint8_t reply_type) {
   ssh_key privkey;
   //ssh_string pubkey_blob = NULL;
   ssh_string sig_blob;
@@ -284,7 +347,7 @@ static int dh_handshake_server(ssh_session session) {
     return -1;
   }
 
-  if (buffer_add_u8(session->out_buffer, SSH2_MSG_KEXDH_REPLY) < 0 ||
+  if (buffer_add_u8(session->out_buffer, reply_type) < 0 ||
       buffer_add_ssh_string(session->out_buffer,
               session->next_crypto->server_pubkey) < 0 ||
       buffer_add_ssh_string(session->out_buffer, f) < 0 ||
-- 
1.9.5.msysgit.0

