/* This is a sample implementation of a libssh-based SSH echo server that sends the client's input back upper-cased. */
/*
Copyright 2014 Audrius Butkevicius

This file is part of the SSH Library

You are free to copy this file, modify it in any way, consider it being public
domain. This does not apply to the rest of the library though, but it is
allowed to cut-and-paste working code from this file to any license of
program.
The goal is to show the API in action.
*/

#include "config.h"

#include <libssh/callbacks.h>
#include <libssh/server.h>

#include <ctype.h>
#include <stdlib.h>
#include <stdio.h>

/* Bits of ssh_get_status() meaning the session has closed, either cleanly or
 * with an error; handle_session() tests these while waiting for the client
 * to disconnect. */
#define SESSION_END (SSH_CLOSED | SSH_CLOSED_ERROR)


/* A userdata struct for session: per-connection state that handle_session()
 * passes to the libssh server callbacks through their userdata pointer. */
struct session_data_struct {
    /* Pointer to the channel the session will allocate. Initialised to NULL by
     * handle_session() and set by channel_open(). */
    ssh_channel channel;
    /* Compared against a limit of 3 in handle_session().
     * NOTE(review): only meaningful if the auth callbacks increment it —
     * confirm that they do. */
    int auth_attempts;
    /* 1 once auth_publickey() has accepted the client's key, 0 otherwise. */
    int authenticated;
};

/* Channel data callback: upper-cases the bytes received from the client in
 * place and writes them back on the same channel.
 *
 * Returns the result of ssh_channel_write(): the number of bytes written, or
 * SSH_ERROR. libssh's channel-data callback contract interprets a positive
 * return as the number of bytes consumed from its receive buffer. */
static int data_function(ssh_session session, ssh_channel channel, void *data,
                         uint32_t len, int is_stderr, void *userdata) {
    unsigned char *buf = data;

    (void) session;
    (void) is_stderr;
    (void) userdata;

    for (uint32_t i = 0; i < len; ++i) {
        /* toupper() is only defined for EOF and values representable as
         * unsigned char; passing a negative plain char (a byte >= 0x80 where
         * char is signed) is undefined behaviour, so go through unsigned char. */
        buf[i] = (unsigned char) toupper(buf[i]);
    }

    return ssh_channel_write(channel, data, len);
}


static int auth_publickey(ssh_session session,
                          const char *user,
                          struct ssh_key_struct *pubkey,
                          char signature_state,
                          void *userdata)
{
    struct session_data_struct *sdata = (struct session_data_struct *) userdata;

    (void) user;
    (void) session;

    fprintf(stderr,"auth_publickey signature_state %d\n",signature_state);
    if (signature_state == SSH_PUBLICKEY_STATE_NONE) {
        return SSH_AUTH_SUCCESS;
    }

    if (signature_state != SSH_PUBLICKEY_STATE_VALID) {
        return SSH_AUTH_DENIED;
    }

    // valid so far. Use my hardcoded RSA key file for now
    {
        char const* PUBKEY_FILE = "id_rsa.pub";
        ssh_key key = NULL;
        int result = ssh_pki_import_pubkey_file(PUBKEY_FILE, &key);
        if ((result != SSH_OK) || (key==NULL)) {
          fprintf(stderr, "Unable to import public key file %s\n",
                  PUBKEY_FILE);
        } else {
          result = ssh_key_cmp(key, pubkey, SSH_KEY_CMP_PUBLIC);
          ssh_key_free(key);
          fprintf(stderr,"auth_publickey ssh_key_cmp %d\n", result);
          if (result == 0) {
            sdata->authenticated = 1;
            return SSH_AUTH_SUCCESS;
          }
        }
    }

    // no matches
    sdata->authenticated = 0;
    return SSH_AUTH_DENIED;
}

/* Server callback for a client's "session" channel-open request.
 *
 * Creates the channel and stores it in the connection's session data so that
 * handle_session() can service it. This server handles a single session
 * channel per connection. A further request is refused by returning NULL,
 * which libssh treats as "do not allow", so that the channel handle_session()
 * is already using is not overwritten.
 *
 * userdata must point to the connection's struct session_data_struct.
 * Returns the new channel, or NULL if refused or if ssh_channel_new() fails. */
static ssh_channel channel_open(ssh_session session, void *userdata) {
    struct session_data_struct *sdata = (struct session_data_struct *) userdata;

    fprintf(stderr,"channel_open\n");
    if (sdata->channel != NULL) {
        fprintf(stderr, "channel_open: rejecting additional session channel\n");
        return NULL;
    }

    sdata->channel = ssh_channel_new(session);
    return sdata->channel;
}


/* Runs one accepted connection to completion:
 *   1. performs the key exchange,
 *   2. polls until the client has authenticated and opened a session channel,
 *      giving up after too many failed attempts, too many polls, or a poll error,
 *   3. services that channel (data_function) until it is no longer open,
 *   4. sends EOF/close and polls briefly so the client can disconnect.
 * The caller still owns `session` and is responsible for disconnecting and
 * freeing it. */
static void handle_session(ssh_session session) {
    enum {
        /* Timeout for each ssh_event_dopoll() call in the bounded loops. */
        POLL_INTERVAL_MS = 100,
        /* Maximum polls while waiting for auth + channel. Each poll waits at
         * most POLL_INTERVAL_MS, so this is roughly a 100 second upper bound. */
        AUTH_MAX_POLLS = 1000,
        /* Failed authentication attempts tolerated before disconnecting. */
        AUTH_MAX_ATTEMPTS = 3,
        /* Polls (~5 seconds) allowed for the client to close the session. */
        CLOSE_MAX_POLLS = 50
    };

    ssh_event event = ssh_event_new();
    int n;

    /* Our struct holding information about the session. */
    struct session_data_struct sdata = {
        .channel = NULL,
        .auth_attempts = 0,
        .authenticated = 0
    };

    struct ssh_channel_callbacks_struct channel_cb = {
        .userdata = NULL,
        .channel_data_function = data_function
    };

    struct ssh_server_callbacks_struct server_cb = {
        .userdata = &sdata,
        .auth_pubkey_function = auth_publickey,
        .channel_open_request_session_function = channel_open,
    };

    if (event == NULL) {
        fprintf(stderr, "handle_session: ssh_event_new failed\n");
        return;
    }

    ssh_set_auth_methods(session, SSH_AUTH_METHOD_PUBLICKEY);

    ssh_callbacks_init(&server_cb);
    ssh_callbacks_init(&channel_cb);

    ssh_set_server_callbacks(session, &server_cb);

    if (ssh_handle_key_exchange(session) != SSH_OK) {
        fprintf(stderr, "%s\n", ssh_get_error(session));
        goto out;
    }

    if (ssh_event_add_session(event, session) != SSH_OK) {
        fprintf(stderr, "handle_session: ssh_event_add_session failed\n");
        goto out;
    }
    fprintf(stderr,"handle_session: authenticating\n");

    /* The server callbacks update sdata while ssh_event_dopoll() runs. */
    for (n = 0; sdata.authenticated == 0 || sdata.channel == NULL; n++) {
        if (sdata.auth_attempts >= AUTH_MAX_ATTEMPTS || n >= AUTH_MAX_POLLS) {
            fprintf(stderr,"handle_session: authentication failure attempts %d\n",sdata.auth_attempts);
            goto out;
        }

        if (ssh_event_dopoll(event, POLL_INTERVAL_MS) == SSH_ERROR) {
            fprintf(stderr, "%s\n", ssh_get_error(session));
            goto out;
        }
    }

    ssh_set_channel_callbacks(sdata.channel, &channel_cb);

    fprintf(stderr,"handle_session: main event loop\n");

    do {
        /* Poll the main event which takes care of the session & the channel */
        if (ssh_event_dopoll(event, -1) == SSH_ERROR) {
            ssh_channel_close(sdata.channel);
            /* Don't keep polling a failed session; if the close itself failed,
             * the channel could still report open and the loop would never end. */
            break;
        }
    } while (ssh_channel_is_open(sdata.channel));

    fprintf(stderr,"handle_session: cleaning up\n");

    ssh_channel_send_eof(sdata.channel);
    ssh_channel_close(sdata.channel);

    /* Give the client a short window to terminate the session itself. */
    for (n = 0; n < CLOSE_MAX_POLLS && (ssh_get_status(session) & SESSION_END) == 0; n++) {
        ssh_event_dopoll(event, POLL_INTERVAL_MS);
    }

out:
    ssh_event_free(event);
}


/* Entry point: listens on TCP port 47888 with the RSA host key in
 * "ssh_host_rsa_key" and serves clients one at a time. Each accepted
 * connection is run to completion by handle_session() before the next
 * ssh_bind_accept().
 * Returns 1 on setup failure or when a session object cannot be allocated. */
int main(int argc, char **argv) {
    ssh_bind sshbind = ssh_bind_new();
    int port = 47888;
    int rc = 0;

    (void) argc;
    (void) argv;

    if (sshbind == NULL) {
        fprintf(stderr, "ssh_bind_new failed\n");
        return 1;
    }

    ssh_set_log_level(SSH_LOG_PACKET); // SSH_LOG_NOLOG,SSH_LOG_WARNING,SSH_LOG_PROTOCOL,SSH_LOG_PACKET,SSH_LOG_FUNCTIONS

    if (ssh_bind_options_set(sshbind, SSH_BIND_OPTIONS_BINDPORT, &port) < 0 ||
        ssh_bind_options_set(sshbind, SSH_BIND_OPTIONS_RSAKEY, "ssh_host_rsa_key") < 0) {
        fprintf(stderr, "%s\n", ssh_get_error(sshbind));
        rc = 1;
        goto out;
    }

    if (ssh_bind_listen(sshbind) < 0) {
        fprintf(stderr, "%s\n", ssh_get_error(sshbind));
        rc = 1;
        goto out;  /* free sshbind instead of leaking it */
    }

    for (;;) {
        ssh_session session = ssh_new();
        if (session == NULL) {
            /* Retrying immediately would just spin on the same failure. */
            fprintf(stderr, "Failed to allocate session\n");
            rc = 1;
            break;
        }

        /* Blocks until there is a new incoming connection. Make me non-blocking ... */
        fprintf(stderr,"Listening on port %d\n",port);
        if (ssh_bind_accept(sshbind, session) != SSH_ERROR) {
            /* Blocks until the SSH session ends by client disconnecting. */
            handle_session(session);  /* Make me non-blocking... */
        } else {
            fprintf(stderr, "%s\n", ssh_get_error(sshbind));
        }

        ssh_disconnect(session);
        ssh_free(session);
    }

out:
    ssh_bind_free(sshbind);
    ssh_finalize();
    return rc;
}
