Hi,

I have encountered a problem using SSL sockets in blocking mode. My
application is multithreaded, with one thread waiting to read and
another waiting to write. Upon some external input, a third
thread tries to shut down the connection and then close the socket.
However, the call to PR_Shutdown or PR_Close never returns; it blocks in
SSL_LOCK_READER(ss).

Below are a simple client and server that demonstrate the problem
(search the client for "999" and the server for "delayed_close").

Thanks,
Yahel Zamir,
SW Engineer at CWNT

===
# Build the test client and server against a local NSS build.
# M_DIST may be overridden from the command line or environment, e.g.
#   make M_DIST=/path/to/mozilla/dist
M_DIST ?= /usr/home/yahel/programs/nss/mozilla/dist

M_DIST_LINUX = ${M_DIST}/Linux2.4_x86_glibc_PTH_DBG.OBJ

M_INC =  -I${M_DIST_LINUX}/include
M_INC += -I${M_DIST}/public/nss
# private headers are needed for the error strings (secutil.h)
M_INC += -I${M_DIST}/private/nss

M_FLAGS = -g -Wall -pthread

# libsectool (SECU_Strerror) is listed before the NSS libraries so that, if it
# is a static archive, its references to them can still be resolved.
M_LIBS = -L${M_DIST_LINUX}/lib
M_LIBS += -lsectool
M_LIBS += -lssl3 -lnss3 -lnspr4 -lplds4 -lplc4 -lsoftokn3
M_LIBS += -lfreebl3 -lpthread -ldl -lc

.PHONY: all clean

# NOTE: recipe lines below must start with a TAB character, not spaces.
all: server client
	@echo done

clean:
	rm -f client server client.o server.o

server.o: server.c
	gcc server.c -c ${M_FLAGS} ${M_INC} -o server.o

client.o: client.c
	gcc client.c -c ${M_FLAGS} ${M_INC} -o client.o

server: server.o
	gcc server.o ${M_FLAGS} ${M_LIBS} -o server

client: client.o
	gcc client.o ${M_FLAGS} ${M_LIBS} -o client

===

// client.c
// --------

#include <unistd.h>
#include <plgetopt.h>
#include <nspr.h>
#include <nss.h>
#include <pk11pub.h>
#include <ssl.h>
#include <sslproto.h>
#include <key.h>
#include <secutil.h> // a private API needed for SECU_Strerror

static char *s_progName = NULL;   /* argv[0]; prefix of every diagnostic */
static char *s_password = NULL;   /* -w: passed to SSL_SetPKCS11PinArg() */
static char *s_hostName = NULL;   /* positional argument: server host name */

/*
 * errWarn()
 *
 * Print a warning for the most recent NSPR/NSS error to stderr, in the form
 *   "<program>: <funcString> returned error <code> (<description>)".
 * funcString names the call that failed; it is only read.
 * More detailed explanations for the error can be found at:
 * http://www.mozilla.org/projects/security/pki/nss/ref/ssl/sslerr.html
 */
static void
errWarn(const char *funcString)
{
    PRErrorCode perr = PR_GetError();
    const char *errString = SECU_Strerror(perr);

    /* Never pass NULL to %s: fall back if no description is available. */
    if (errString == NULL) {
        errString = "unknown error";
    }
    fprintf(stderr, "%s: %s returned error %d (%s)\n", s_progName,
            funcString, perr, errString);
}

/*
 * errExit()
 *
 * Report the failed call through errWarn(), then terminate the process
 * with exit status 3.
 */
static void
errExit(char * funcString)
{
    const int failureStatus = 3;

    errWarn(funcString);
    exit(failureStatus);
}

/*
 * Usage()
 *
 * Print the command-line synopsis to stderr and exit with status 1.
 * (-d is optional: main() falls back to the current directory.)
 */
static void
Usage(const char *progName)
{
    fprintf(stderr,
            "Usage: %s hostname -n cert_name -p port [-d cert_dir] "
            "-w password \n",
            progName);
    exit(1);
}

/* fakePasswd()
 *
 * Custom password callback (registered with PK11_SetPasswordFunc) that NSS
 * calls when retrieving private certificates and keys from the database.
 *
 * The password comes from 'arg', which this program sets with
 * SSL_SetPKCS11PinArg(). On the first attempt a PL_strdup'ed copy is
 * returned; that memory is freed later by its consumer. On a retry, or
 * when no password was supplied, NULL is returned.
 */
char *
fakePasswd(PK11SlotInfo *info, PRBool retry, void *arg)
{
    if (retry || arg == NULL) {
        return NULL;
    }
    return PL_strdup((const char *) arg);
}


/* Function:  setupSSLSocket()
 *
 * Purpose:  Configure a socket for SSL.
 */
PRFileDesc *
setupSSLSocket()
{
        PRFileDesc         *sslSocket;
        PRFileDesc         *tcpSocket;
        SECStatus           secStatus;
        PRSocketOptionData      socketOption;
        PRStatus            prStatus;

    tcpSocket = PR_NewTCPSocket();
    if (tcpSocket == NULL) {
        errWarn("PR_NewTCPSocket");
        return NULL;
    }

    // Ensure the socket is blocking.
    socketOption.option             = PR_SockOpt_Nonblocking;
    socketOption.value.non_blocking = PR_FALSE;

    prStatus = PR_SetSocketOption(tcpSocket, &socketOption);
    if (prStatus != PR_SUCCESS) {
        errWarn("PR_SetSocketOption");
        PR_Close(tcpSocket);
        return NULL;
    }

        sslSocket = SSL_ImportFD(NULL, tcpSocket);
        if (sslSocket == NULL) {
                errWarn("SSL_ImportFD");
        PR_Close(tcpSocket);
        return NULL;
        }

    // ensure original socket is not used
    tcpSocket = NULL;

    do {

        // handshake as client
        secStatus = SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT,
PR_TRUE);
        if (secStatus != SECSuccess) {
                errWarn("SSL_OptionSet:SSL_HANDSHAKE_AS_CLIENT");
                break;
        }

        // enable full duplex
        secStatus = SSL_OptionSet(sslSocket, SSL_ENABLE_FDX, PR_TRUE);
        if (secStatus != SECSuccess) {
            errWarn("SSL_OptionSet(SSL_ENABLE_FDX)");
            break;
        }

        // allow communication without encryption (YZ - must be only
upon client request)
        if (SECSuccess != SSL_CipherPrefSet(sslSocket,
SSL_RSA_WITH_NULL_SHA, PR_TRUE)
            || SECSuccess != SSL_CipherPrefSet(sslSocket,
SSL_RSA_WITH_NULL_MD5, PR_TRUE)) {
            errWarn("SSL_CipherPrefSet - Null Cipher");
            break;
        }

        secStatus = SSL_SetPKCS11PinArg(sslSocket, s_password);
        if (secStatus != SECSuccess) {
                errWarn("SSL_SetPKCS11PinArg");
                break;
        }

        secStatus = SSL_SetURL(sslSocket, s_hostName);
        if (secStatus != SECSuccess) {
            errWarn("SSL_SetURL");
            break;
        }

        // Success
        return sslSocket;

    } while (0);

    // Failure
        PR_Close(sslSocket);
        return NULL;
}

/* Function:  clientMain()
 *
 * Purpose:  Setup an SSL socket and connect a server.
 *           Send a small message and expect a reply message.
 */
void
clientMain(unsigned short port)
{
        SECStatus       secStatus;
        PRStatus    prStatus;
        PRInt32     rv;
        PRNetAddr       addr;
        PRHostEnt   hostEntry;
    PRFileDesc *sslSocket;
        char        buffer[256];

    sslSocket = setupSSLSocket();
    if (sslSocket == NULL) {
                errExit("setupSSLSocket");
    }

        prStatus = PR_GetHostByName(s_hostName, buffer, sizeof(buffer),
&hostEntry);
        if (prStatus != PR_SUCCESS) {
                errExit("PR_GetHostByName");
        }

        rv = PR_EnumerateHostEnt(0, &hostEntry, port, &addr);
        if (rv < 0) {
                errExit("PR_EnumerateHostEnt");
        }

    prStatus = PR_Connect(sslSocket, &addr, PR_INTERVAL_NO_TIMEOUT);
    if (prStatus != PR_SUCCESS) {
        errExit("PR_Connect");
    }

    secStatus = SSL_ResetHandshake(sslSocket, /* asServer */
PR_FALSE);
    if (secStatus != SECSuccess) {
        errExit("SSL_ResetHandshake");
    }

    // single write, then single read

    sprintf(buffer, "hello from client");
    int msgSize = strlen(buffer);

        sleep(999);

    rv = PR_Write(sslSocket, buffer, msgSize);
    if (rv != msgSize) {
        errWarn("PR_Write");
        PR_Close(sslSocket);
        return;
    }

    printf("client sent: %s \n", buffer);

    rv = PR_Read(sslSocket, buffer, sizeof(buffer));
    // (rv == 0) is EOF
    if (rv <= 0) {
        errWarn("PR_Read");
        PR_Close(sslSocket);
        return;
    }

    buffer[rv] = 0;
    printf("client received: %s \n", buffer);

    prStatus = PR_Close(sslSocket);
    if (prStatus != PR_SUCCESS) {
        errExit("PR_Close");
    }

    return;
}

int
main(int argc, char **argv)
{
        char *              certName      = NULL;
        char *              certDir       = ".";
        unsigned short      port          = 0;
        PLOptState *        optstate;
        PLOptStatus         status;
        SECStatus           secStatus;

        s_progName = argv[0];

        optstate = PL_CreateOptState(argc, argv, "d:p:n:w:");
        while ((status = PL_GetNextOpt(optstate)) == PL_OPT_OK) {
                switch(optstate->option) {
                case 0:   s_hostName = PL_strdup(optstate->value);      break;
                case 'd': certDir = PL_strdup(optstate->value);         break;
                case 'n': certName = PL_strdup(optstate->value);        break;
                case 'p': port = atoi(optstate->value);                 break;
                case 'w': s_password = PL_strdup(optstate->value);      break;
        case '?': Usage(s_progName);                            break;
        default:  Usage(s_progName);
break;
                }
        }

        if (certName == NULL || s_hostName == NULL || s_password == NULL ||
port == 0) {
        Usage(s_progName);
    }


    // Client set up
    // -------------

    PR_Init( PR_SYSTEM_THREAD, PR_PRIORITY_NORMAL, 1);

    PK11_SetPasswordFunc(fakePasswd);

    secStatus = NSS_Init(certDir);
    if (secStatus != SECSuccess) {
        errExit("NSS_Init");
    }

        secStatus = NSS_SetExportPolicy();
        if (secStatus != SECSuccess) {
                errExit("NSS_SetExportPolicy");
        }

    // optional - clear client cache
    SSL_ClearSessionCache();


    // Client main function
    // --------------------

    clientMain(port);


    // Client shutdown
    // ---------------

    SSL_ClearSessionCache(); // otherwise, NSS_Shutdown fails.
    if (NSS_Shutdown() != SECSuccess) {
        errExit("NSS_Shutdown");
    }

    PR_Cleanup();

    if (certName) {
        free(certName);
    }

    if (certDir) {
        free(certDir);
    }

    if (s_hostName) {
        free(s_hostName);
    }

    printf("%s: normal termination\n", s_progName);
    return 0;
}

====

// server.c
// --------

#include <unistd.h>
#include <pthread.h>
#include <plgetopt.h>
#include <nspr.h>
#include <nss.h>
#include <pk11pub.h>
#include <ssl.h>
#include <sslproto.h>
#include <key.h>
#include <secutil.h> // a private API needed for SECU_Strerror

static char *s_progName = NULL;   /* argv[0]; prefix of every diagnostic */

/*
 * errWarn()
 *
 * Print a warning for the most recent NSPR/NSS error to stderr, in the form
 *   "<program>: <funcString> returned error <code> (<description>)".
 * funcString names the call that failed; it is only read.
 * More detailed explanations for the error can be found at:
 * http://www.mozilla.org/projects/security/pki/nss/ref/ssl/sslerr.html
 */
static void
errWarn(const char *funcString)
{
    PRErrorCode perr = PR_GetError();
    const char *errString = SECU_Strerror(perr);

    /* Never pass NULL to %s: fall back if no description is available. */
    if (errString == NULL) {
        errString = "unknown error";
    }
    fprintf(stderr, "%s: %s returned error %d (%s)\n", s_progName,
            funcString, perr, errString);
}

/*
 * errExit()
 *
 * Fatal variant of errWarn(): print the warning, then exit with status 3.
 */
static void
errExit(char * funcString)
{
    const int failureStatus = 3;

    errWarn(funcString);
    exit(failureStatus);
}

/*
 * Usage()
 *
 * Print the command-line synopsis to stderr and exit with status 1.
 */
static void
Usage(const char *progName)
{
    fprintf(stderr,
            "Usage: %s -n cert_name -p port -w password [-d cert_dir] \n",
            progName);
    exit(1);
}

/* fakePasswd()
 *
 * Password callback installed with PK11_SetPasswordFunc(); NSS invokes it
 * when retrieving private certificates and keys from the database.
 *
 * 'arg' carries the password, set per socket with SSL_SetPKCS11PinArg().
 * Returns a PL_strdup'ed copy of it on the first attempt (memory freed
 * later by its consumer), and NULL on a retry or when no password is set.
 */
char *
fakePasswd(PK11SlotInfo *info, PRBool retry, void *arg)
{
    const char *supplied = (const char *) arg;

    return (!retry && supplied != NULL) ? PL_strdup(supplied) : NULL;
}


/* Function:  setupSSLSocket()
 *
 * Purpose:  Configure a socket for SSL as a server.
 * NSS has 3 methods for configuring a TCP socket for SSL:
 * 1. Do the configuration step by step.
 * 2. Inherit a model socket using SSL_ImportFD.
 * 3. Do the configuration on the listen socket. This way,
 *    sockets created by PR_Accept inherit the configuration.
 *
 * This application uses method 3, so it is called on the listen socket.
 *
 * Parameters:
 *   tcpSocket - socket to import. This function takes ownership: it is
 *               closed here on failure, or lives under the returned socket.
 *   cert, privKey, sslKEA - server certificate, its private key and KEA type
 *               (passed to SSL_ConfigSecureServer)
 *   password  - set as the socket's PKCS#11 PIN argument
 * Returns the configured SSL socket, or NULL after printing a warning.
 */
PRFileDesc *
setupSSLSocket(PRFileDesc *tcpSocket,
               CERTCertificate *cert,
               SECKEYPrivateKey *privKey,
               SSLKEAType sslKEA,
               char* password)
{
    PRFileDesc *sslSocket;

    sslSocket = SSL_ImportFD(NULL, tcpSocket);
    if (sslSocket == NULL) {
        errWarn("SSL_ImportFD");
        PR_Close(tcpSocket);
        return NULL;
    }

    // The TCP socket now sits underneath sslSocket: only sslSocket is used
    // (and closed) from here on.
    tcpSocket = NULL;

    // handshake as server
    if (SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_SERVER, PR_TRUE) != SECSuccess) {
        errWarn("SSL_OptionSet:SSL_HANDSHAKE_AS_SERVER");
        goto fail;
    }

    // enable full duplex: one thread may read while another writes
    if (SSL_OptionSet(sslSocket, SSL_ENABLE_FDX, PR_TRUE) != SECSuccess) {
        errWarn("SSL_OptionSet(SSL_ENABLE_FDX)");
        goto fail;
    }

    // Allow communication without encryption (YZ - must be only upon client
    // request). NOTE(review): these suites do not encrypt traffic; keep
    // them out of production configurations.
    if (SSL_CipherPrefSet(sslSocket, SSL_RSA_WITH_NULL_SHA, PR_TRUE) != SECSuccess
        || SSL_CipherPrefSet(sslSocket, SSL_RSA_WITH_NULL_MD5, PR_TRUE) != SECSuccess) {
        errWarn("SSL_CipherPrefSet - Null Cipher");
        goto fail;
    }

    // Optional (see the SSL reference): SSL_AuthCertificateHook,
    // SSL_BadCertHook and SSL_HandshakeCallback can be installed here.

    if (SSL_SetPKCS11PinArg(sslSocket, password) != SECSuccess) {
        errWarn("SSL_SetPKCS11PinArg");
        goto fail;
    }

    if (SSL_ConfigSecureServer(sslSocket, cert, privKey, sslKEA) != SECSuccess) {
        errWarn("SSL_ConfigSecureServer");
        goto fail;
    }

    return sslSocket;

fail:
    PR_Close(sslSocket);
    return NULL;
}

void* delayed_close(void *pSocket)
{
        PRFileDesc *socket = (PRFileDesc *) pSocket;
        sleep(1);
        PR_Shutdown(socket, PR_SHUTDOWN_BOTH);
        PR_Close(socket);
        printf("\n" "socket closed .\n");
        return NULL;
}

/* Function:  handleConnection()
 *
 * Purpose:  handle a single SSL connection.
 *
 */
SECStatus
handleConnection(PRFileDesc *pSocket)
{
        PRFileDesc         *sslSocket = (PRFileDesc *) pSocket;
        PRStatus            prStatus;
        PRSocketOptionData  socketOption;
    char                buffer[256];
    PRInt32             rv;

    // ensure the socket is blocking. this should be the default.
        socketOption.option = PR_SockOpt_Nonblocking;
        socketOption.value.non_blocking = PR_FALSE;
        PR_SetSocketOption(sslSocket, &socketOption);

    /* // handshake as server - required in addition to the SSL
option,   */
    /* // in case the listen socket was not an SSL
socket.                */
    /* secStatus = SSL_ResetHandshake(sslSocket,
PR_TRUE );               */
    /* if (secStatus != SECSuccess)
{                                     */
    /*
errWarn("SSL_ResetHandshake");                                 */
    /*     return
secStatus;                                              */
    /
* }
*/


    pthread_t pth;
        pthread_create(&pth, NULL, delayed_close, sslSocket);

    // use socket: single read, then single write

    rv = PR_Read(sslSocket, buffer, sizeof(buffer));
    // (rv == 0) is EOF
    if (rv <= 0) {
        errWarn("PR_Read");
        PR_Close(sslSocket);
        return SECFailure;
    }

    buffer[rv] = 0;
    printf("server received: %s \n", buffer);

    sprintf(buffer, "hello from server");
    int msgSize = strlen(buffer);

    rv = PR_Write(sslSocket, buffer, msgSize);
    if (rv != msgSize) {
        errWarn("PR_Write");
        PR_Close(sslSocket);
        return SECFailure;
    }

    printf("server sent: %s \n", buffer);

    // Finally
    // -------

    printf("\n" "Closing client connection.\n");

    prStatus = PR_Close(sslSocket);
        if (prStatus != PR_SUCCESS) {
                errWarn("PR_Close");
                return SECFailure;
        }

        return SECSuccess;

}

/* Function: startListening()
 *
 * Purpose: Create a new socket and starts listening.
 */

PRFileDesc *
startListening(unsigned short port)
{
    PRFileDesc *       listen_sock;
    PRStatus           prStatus;
    PRNetAddr          addr;
    PRSocketOptionData opt;

    addr.inet.family = PR_AF_INET;
    addr.inet.ip     = PR_INADDR_ANY;
    addr.inet.port   = PR_htons(port);

    listen_sock = PR_NewTCPSocket();
    if (listen_sock == NULL) {
        errExit("PR_NewTCPSocket");
    }

    // YZ set blocking mode. this should be the default anyway.
    opt.option = PR_SockOpt_Nonblocking;
    opt.value.non_blocking = PR_FALSE;
    prStatus = PR_SetSocketOption(listen_sock, &opt);
    if (prStatus < 0) {
        errExit("PR_SetSocketOption(PR_SockOpt_Nonblocking =
PR_FALSE)");
    }

    opt.option=PR_SockOpt_Reuseaddr;
    opt.value.reuse_addr = PR_TRUE;
    prStatus = PR_SetSocketOption(listen_sock, &opt);
    if (prStatus < 0) {
        errExit("PR_SetSocketOption(PR_SockOpt_Reuseaddr = PR_TRUE)");
    }

    prStatus = PR_Bind(listen_sock, &addr);
    if (prStatus < 0) {
        errExit("PR_Bind");
    }

    prStatus = PR_Listen(listen_sock, 5);
    if (prStatus < 0) {
        errExit("PR_Listen");
    }

    return listen_sock;
}

/* Function: serverMain()
 *
 * Purpose:
 * Loop to accept connections. For every connection,
 * receive a message and then send a reply message.
 * Since listenSocket is an SSL socket, tcpSocket is an SSL socket
too.
 */
void
serverMain(PRFileDesc *listenSocket)
{
        PRNetAddr   addr;
        PRStatus    prStatus;

        while (1) {
                PRFileDesc *tcpSocket;

                printf("\n" "Waiting for new connection.\n");

                /* Accept a connection */
                tcpSocket = PR_Accept(listenSocket, &addr, 
PR_INTERVAL_NO_TIMEOUT);
                if (tcpSocket == NULL) {
                        errWarn("PR_Accept");
                        break;
                }

                // Handle one connection at a time.
        // Can be replaced with a new thread
        pthread_t pth;
                pthread_create(&pth, NULL, handleConnection, tcpSocket);
        }

        printf("\n" "Closing listen socket.\n");

        prStatus = PR_Close(listenSocket);
        if (prStatus != PR_SUCCESS) {
                errWarn("PR_Close");
        }

        return;
}

int
main(int argc, char **argv)
{
        char *              password      = NULL;
        char *              certName      = NULL;
        char *              certDir       = ".";
        char *              cipherString  = NULL;
        unsigned short      port          = 0;
        PLOptState *        optstate;
        PLOptStatus         status;
        SECStatus           secStatus;
    CERTCertificate    *cert = NULL;
        SSLKEAType          sslKEA;
    SECKEYPrivateKey   *privKey = NULL;
    PRFileDesc *        listenSocket;
    PRFileDesc *        sslSocket;

        s_progName = argv[0];

        optstate = PL_CreateOptState(argc, argv, "c:d:p:n:w:");
        while ((status = PL_GetNextOpt(optstate)) == PL_OPT_OK) {
                switch(optstate->option) {
                case 'c': cipherString = PL_strdup(optstate->value);    break;
                case 'd': certDir = PL_strdup(optstate->value);         break;
                case 'n': certName = PL_strdup(optstate->value);        break;
                case 'p': port = atoi(optstate->value);                 break;
                case 'w': password = PL_strdup(optstate->value);        break;
                case 0:                                                 break;
                case '?': Usage(s_progName);                            break;
        default:  Usage(s_progName);
break;
                }
        }

        if (certName == NULL || password == NULL || port == 0) {
        Usage(s_progName);
    }


    // Server set up
    // -------------

    PR_Init( PR_SYSTEM_THREAD, PR_PRIORITY_NORMAL, 1);

    PK11_SetPasswordFunc(fakePasswd);

    secStatus = NSS_Init(certDir);
    if (secStatus != SECSuccess) {
        errExit("NSS_Init");
    }

    // Using export policy.
        secStatus = NSS_SetExportPolicy();
        if (secStatus != SECSuccess) {
                errExit("NSS_SetExportPolicy");
        }

        /* Get own certificate and private key. */
        cert = PK11_FindCertFromNickname(certName, password);
        if (cert == NULL) {
                errExit("PK11_FindCertFromNickname");
        }

    sslKEA = NSS_FindCertKEAType(cert);
    if (sslKEA == kt_null) {
        errExit("NSS_FindCertKEAType");
    }

        privKey = PK11_FindKeyByAnyCert(cert, password);
        if (privKey == NULL) {
                errExit("PK11_FindKeyByAnyCert");
        }
    // Set a session cache for single-process server.
    secStatus = SSL_ConfigServerSessionIDCache(100, 0, 0, 0);
    if (secStatus != SECSuccess) {
        errExit("SSL_ConfigServerSessionIDCache");
    }

    listenSocket = startListening(port);
    if (listenSocket == NULL) {
        errExit("startListening");
    }

    sslSocket = setupSSLSocket(listenSocket, cert, privKey, sslKEA,
password);
    if (sslSocket == NULL) {
        errExit("setupSSLSocket");
    }

    // Main server loop
    // ----------------

    serverMain(sslSocket);


    // Server shutdown
    // ---------------

    if (cert) {
        CERT_DestroyCertificate(cert);
    }

    if (privKey) {
        SECKEY_DestroyPrivateKey(privKey);
    }

    SSL_ShutdownServerSessionIDCache();

    if (NSS_Shutdown() != SECSuccess) {
        errExit("NSS_Shutdown");
    }

    PR_Cleanup();

    if (certName) {
        free(certName);
    }

    if (certDir) {
        free(certDir);
    }

    if (password) {
        free(password);
    }

    printf("%s: normal termination\n", s_progName);
    return 0;
}

_______________________________________________
dev-tech-crypto mailing list
dev-tech-crypto@lists.mozilla.org
https://lists.mozilla.org/listinfo/dev-tech-crypto

Reply via email to