This is an automated email from the ASF dual-hosted git repository.

markt pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tomcat.git

commit e03da0f2245af6381fd5081ef5f7436a740b8dec
Author: Mark Thomas <ma...@apache.org>
AuthorDate: Tue Dec 21 14:42:50 2021 +0000

    Add support for the new WebSocket 2.1 TLS configuration API for client connections
---
 java/jakarta/websocket/ClientEndpointConfig.java   | 20 +++++---
 .../websocket/DefaultClientEndpointConfig.java     | 13 ++++-
 .../tomcat/websocket/WsWebSocketContainer.java     | 14 ++++--
 res/checkstyle/jakarta-import-control.xml          |  1 +
 .../websocket/TestWsWebSocketContainerSSL.java     | 58 +++++++++++++++++++++-
 webapps/docs/changelog.xml                         |  4 ++
 6 files changed, 98 insertions(+), 12 deletions(-)

diff --git a/java/jakarta/websocket/ClientEndpointConfig.java 
b/java/jakarta/websocket/ClientEndpointConfig.java
index a56af4b..fbf752b 100644
--- a/java/jakarta/websocket/ClientEndpointConfig.java
+++ b/java/jakarta/websocket/ClientEndpointConfig.java
@@ -20,12 +20,16 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
+import javax.net.ssl.SSLContext;
+
 public interface ClientEndpointConfig extends EndpointConfig {
 
     List<String> getPreferredSubprotocols();
 
     List<Extension> getExtensions();
 
+    SSLContext getSSLContext();
+
     public Configurator getConfigurator();
 
     public final class Builder {
@@ -46,15 +50,13 @@ public interface ClientEndpointConfig extends 
EndpointConfig {
         private Configurator configurator = DEFAULT_CONFIGURATOR;
         private List<String> preferredSubprotocols = Collections.emptyList();
         private List<Extension> extensions = Collections.emptyList();
-        private List<Class<? extends Encoder>> encoders =
-                Collections.emptyList();
-        private List<Class<? extends Decoder>> decoders =
-                Collections.emptyList();
-
+        private List<Class<? extends Encoder>> encoders = 
Collections.emptyList();
+        private List<Class<? extends Decoder>> decoders = 
Collections.emptyList();
+        private SSLContext sslContext = null;
 
         public ClientEndpointConfig build() {
             return new DefaultClientEndpointConfig(preferredSubprotocols,
-                    extensions, encoders, decoders, configurator);
+                    extensions, encoders, decoders, sslContext, configurator);
         }
 
 
@@ -110,6 +112,12 @@ public interface ClientEndpointConfig extends 
EndpointConfig {
             }
             return this;
         }
+
+
+        public Builder sslContext(SSLContext sslContext) {
+            this.sslContext = sslContext;
+            return this;
+        }
     }
 
 
diff --git a/java/jakarta/websocket/DefaultClientEndpointConfig.java 
b/java/jakarta/websocket/DefaultClientEndpointConfig.java
index e166925..cf29809 100644
--- a/java/jakarta/websocket/DefaultClientEndpointConfig.java
+++ b/java/jakarta/websocket/DefaultClientEndpointConfig.java
@@ -20,12 +20,15 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 
+import javax.net.ssl.SSLContext;
+
 final class DefaultClientEndpointConfig implements ClientEndpointConfig {
 
     private final List<String> preferredSubprotocols;
     private final List<Extension> extensions;
     private final List<Class<? extends Encoder>> encoders;
     private final List<Class<? extends Decoder>> decoders;
+    private final SSLContext sslContext;
     private final Map<String,Object> userProperties = new 
ConcurrentHashMap<>();
     private final Configurator configurator;
 
@@ -34,11 +37,13 @@ final class DefaultClientEndpointConfig implements 
ClientEndpointConfig {
             List<Extension> extensions,
             List<Class<? extends Encoder>> encoders,
             List<Class<? extends Decoder>> decoders,
+            SSLContext sslContext,
             Configurator configurator) {
         this.preferredSubprotocols = preferredSubprotocols;
         this.extensions = extensions;
-        this.decoders = decoders;
         this.encoders = encoders;
+        this.decoders = decoders;
+        this.sslContext = sslContext;
         this.configurator = configurator;
     }
 
@@ -68,6 +73,12 @@ final class DefaultClientEndpointConfig implements 
ClientEndpointConfig {
 
 
     @Override
+    public SSLContext getSSLContext() {
+        return sslContext;
+    }
+
+
+    @Override
     public final Map<String, Object> getUserProperties() {
         return userProperties;
     }
diff --git a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java 
b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java
index 7792122..e6c5f92 100644
--- a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java
+++ b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java
@@ -316,7 +316,7 @@ public class WsWebSocketContainer implements 
WebSocketContainer, BackgroundProce
                 // Regardless of whether a non-secure wrapper was created for a
                 // proxy CONNECT, need to use TLS from this point on so wrap 
the
                 // original AsynchronousSocketChannel
-                SSLEngine sslEngine = createSSLEngine(userProperties, host, 
port);
+                SSLEngine sslEngine = 
createSSLEngine(clientEndpointConfiguration, host, port);
                 channel = new AsyncChannelWrapperSecure(socketChannel, 
sslEngine);
             } else if (channel == null) {
                 // Only need to wrap as this point if it wasn't wrapped to 
process a
@@ -900,13 +900,19 @@ public class WsWebSocketContainer implements 
WebSocketContainer, BackgroundProce
     }
 
 
-    private SSLEngine createSSLEngine(Map<String,Object> userProperties, 
String host, int port)
+    private SSLEngine createSSLEngine(ClientEndpointConfig 
clientEndpointConfig, String host, int port)
             throws DeploymentException {
 
+        Map<String,Object> userProperties = 
clientEndpointConfig.getUserProperties();
         try {
             // See if a custom SSLContext has been provided
-            SSLContext sslContext =
-                    (SSLContext) 
userProperties.get(Constants.SSL_CONTEXT_PROPERTY);
+            SSLContext sslContext = clientEndpointConfig.getSSLContext();
+
+            // If no SSLContext was set on the ClientEndpointConfig, fall back
+            // to the Tomcat-specific user property used before WebSocket 2.1
+            if (sslContext == null) {
+                sslContext = (SSLContext) 
userProperties.get(Constants.SSL_CONTEXT_PROPERTY);
+            }
 
             if (sslContext == null) {
                 // Create the SSL Context
diff --git a/res/checkstyle/jakarta-import-control.xml 
b/res/checkstyle/jakarta-import-control.xml
index 4ac0792..9034aa9 100644
--- a/res/checkstyle/jakarta-import-control.xml
+++ b/res/checkstyle/jakarta-import-control.xml
@@ -72,5 +72,6 @@
   </subpackage>
   <subpackage name="websocket">
     <allow pkg="jakarta.websocket"/>
+    <allow pkg="javax.net.ssl"/>
   </subpackage>
 </import-control>
\ No newline at end of file
diff --git a/test/org/apache/tomcat/websocket/TestWsWebSocketContainerSSL.java 
b/test/org/apache/tomcat/websocket/TestWsWebSocketContainerSSL.java
index b30df96..bcc0a24 100644
--- a/test/org/apache/tomcat/websocket/TestWsWebSocketContainerSSL.java
+++ b/test/org/apache/tomcat/websocket/TestWsWebSocketContainerSSL.java
@@ -16,7 +16,11 @@
  */
 package org.apache.tomcat.websocket;
 
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.InputStream;
 import java.net.URI;
+import java.security.KeyStore;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
@@ -24,6 +28,9 @@ import java.util.Queue;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.TrustManagerFactory;
+
 import jakarta.websocket.ClientEndpointConfig;
 import jakarta.websocket.ContainerProvider;
 import jakarta.websocket.Session;
@@ -42,6 +49,7 @@ import org.apache.catalina.core.StandardServer;
 import org.apache.catalina.servlets.DefaultServlet;
 import org.apache.catalina.startup.Tomcat;
 import org.apache.tomcat.util.net.TesterSupport;
+import org.apache.tomcat.util.security.KeyStoreUtil;
 import org.apache.tomcat.websocket.TesterMessageCountClient.BasicText;
 import 
org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
 
@@ -74,7 +82,7 @@ public class TestWsWebSocketContainerSSL extends 
WebSocketBaseTest {
     private static final String MESSAGE_STRING_1 = "qwerty";
 
     @Test
-    public void testConnectToServerEndpointSSL() throws Exception {
+    public void testConnectToServerEndpointSslLegacy() throws Exception {
 
         Tomcat tomcat = getTomcatInstance();
         // No file system docBase required
@@ -112,6 +120,54 @@ public class TestWsWebSocketContainerSSL extends 
WebSocketBaseTest {
     }
 
 
+    @Test
+    public void testConnectToServerEndpointSSL() throws Exception {
+
+        Tomcat tomcat = getTomcatInstance();
+        // No file system docBase required
+        Context ctx = tomcat.addContext("", null);
+        ctx.addApplicationListener(TesterEchoServer.Config.class.getName());
+        Tomcat.addServlet(ctx, "default", new DefaultServlet());
+        ctx.addServletMappingDecoded("/", "default");
+
+        tomcat.start();
+
+        WebSocketContainer wsContainer = 
ContainerProvider.getWebSocketContainer();
+
+        // Build the SSLContext
+        SSLContext sslContext = SSLContext.getInstance("TLS");
+        File trustStoreFile = new File(TesterSupport.CA_JKS);
+        KeyStore ks = KeyStore.getInstance("JKS");
+        try (InputStream is = new FileInputStream(trustStoreFile)) {
+            KeyStoreUtil.load(ks, is, 
Constants.SSL_TRUSTSTORE_PWD_DEFAULT.toCharArray());
+        }
+        TrustManagerFactory tmf = 
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+        tmf.init(ks);
+        sslContext.init(null,  tmf.getTrustManagers(), null);
+
+        ClientEndpointConfig clientEndpointConfig =
+                
ClientEndpointConfig.Builder.create().sslContext(sslContext).build();
+
+        Session wsSession = wsContainer.connectToServer(
+                TesterProgrammaticEndpoint.class,
+                clientEndpointConfig,
+                new URI("wss://localhost:" + getPort() +
+                        TesterEchoServer.Config.PATH_ASYNC));
+        CountDownLatch latch = new CountDownLatch(1);
+        BasicText handler = new BasicText(latch);
+        wsSession.addMessageHandler(handler);
+        wsSession.getBasicRemote().sendText(MESSAGE_STRING_1);
+
+        boolean latchResult = handler.getLatch().await(10, TimeUnit.SECONDS);
+
+        Assert.assertTrue(latchResult);
+
+        Queue<String> messages = handler.getMessages();
+        Assert.assertEquals(1, messages.size());
+        Assert.assertEquals(MESSAGE_STRING_1, messages.peek());
+    }
+
+
     @Override
     public void setUp() throws Exception {
         super.setUp();
diff --git a/webapps/docs/changelog.xml b/webapps/docs/changelog.xml
index 4153ffd..bcc9d24 100644
--- a/webapps/docs/changelog.xml
+++ b/webapps/docs/changelog.xml
@@ -163,6 +163,10 @@
         that allows applications to opt to upgrade an HTTP connection to
         WebSocket. (markt)
       </add>
+      <add>
+        Add support for the WebSocket 2.1 client-side API for configuring TLS
+        for wss client connections. (markt)
+      </add>
     </changelog>
   </subsection>
 </section>

---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@tomcat.apache.org
For additional commands, e-mail: dev-h...@tomcat.apache.org

Reply via email to