This is an automated email from the ASF dual-hosted git repository.

jongyoul pushed a commit to branch branch-0.12
in repository https://gitbox.apache.org/repos/asf/zeppelin.git


The following commit(s) were added to refs/heads/branch-0.12 by this push:
     new f4847ea8eb [NO-ISSUE] Implement Origin check for terminal interpreter 
WebSocket connections
f4847ea8eb is described below

commit f4847ea8ebc03a70768c91af0a33154768076fd0
Author: ChanHo Lee <chanho0...@gmail.com>
AuthorDate: Sat Nov 2 21:18:32 2024 +0900

    [NO-ISSUE] Implement Origin check for terminal interpreter WebSocket 
connections
    
    ### What is this PR for?
    
    This PR adds an Origin check to ensure that WebSocket connections are 
initiated from trusted sources only.
    By validating the `Origin` header in the initial WebSocket handshake, we 
can prevent unauthorized or malicious websites from establishing WebSocket 
connections with our server.
    
    Changes:
    - Added server-side validation of the `Origin` header during WebSocket 
connection requests.
    
    Other security enhancements may be needed and can be handled in future 
iterations.
    
    ### What type of PR is it?
    
    Improvement
    
    ### Todos
    * [ ] - Task
    
    ### How should this be tested?
    * Added unit tests `testValidOrigin` and `testInvalidOrigin` in
    `TerminalInterpreterTest`; the existing terminal tests now send a
    matching `Origin` header when connecting.
    
    ### Screenshots (if appropriate)
    
    ### Questions:
    * Do the license files need to be updated? No
    * Are there breaking changes for older versions? No
    * Does this need documentation? No
    
    Closes #4823 from tbonelee/websocket.
    
    Signed-off-by: Jongyoul Lee <jongy...@gmail.com>
    (cherry picked from commit 3575a3cf8834cdc0100f0a554846829584df0c30)
    Signed-off-by: Jongyoul Lee <jongy...@gmail.com>
---
 .../apache/zeppelin/shell/TerminalInterpreter.java |  26 +++-
 .../zeppelin/shell/terminal/TerminalThread.java    |  11 +-
 .../websocket/TerminalSessionConfigurator.java     |  39 ++++++
 .../zeppelin/shell/TerminalInterpreterTest.java    | 141 ++++++++++++++++++++-
 .../shell/terminal/TerminalSocketTest.java         |  55 ++++----
 5 files changed, 228 insertions(+), 44 deletions(-)

diff --git 
a/shell/src/main/java/org/apache/zeppelin/shell/TerminalInterpreter.java 
b/shell/src/main/java/org/apache/zeppelin/shell/TerminalInterpreter.java
index d54495e2c2..5a4eb8bae7 100644
--- a/shell/src/main/java/org/apache/zeppelin/shell/TerminalInterpreter.java
+++ b/shell/src/main/java/org/apache/zeppelin/shell/TerminalInterpreter.java
@@ -63,6 +63,7 @@ public class TerminalInterpreter extends KerberosInterpreter {
   private InterpreterContext intpContext;
 
   private int terminalPort = 0;
+  private String terminalHostIp;
 
   // Internal and external IP mapping of zeppelin server
   private HashMap<String, String> mapIpMapping = new HashMap<>();
@@ -109,7 +110,11 @@ public class TerminalInterpreter extends 
KerberosInterpreter {
     if (null == terminalThread) {
       try {
         terminalPort = 
RemoteInterpreterUtils.findRandomAvailablePortOnAllLocalInterfaces();
-        terminalThread = new TerminalThread(terminalPort);
+        terminalHostIp =  RemoteInterpreterUtils.findAvailableHostAddress();
+        LOGGER.info("Terminal host IP: " + terminalHostIp);
+        LOGGER.info("Terminal port: " + terminalPort);
+        String allowedOrigin = generateOrigin(terminalHostIp, terminalPort);
+        terminalThread = new TerminalThread(terminalPort, allowedOrigin);
         terminalThread.start();
       } catch (IOException e) {
         LOGGER.error(e.getMessage(), e);
@@ -136,20 +141,20 @@ public class TerminalInterpreter extends 
KerberosInterpreter {
       mapIpMapping = gson.fromJson(strIpMapping, new TypeToken<Map<String, 
String>>(){}.getType());
     }
 
-    createTerminalDashboard(context.getNoteId(), context.getParagraphId(), 
terminalPort);
+    createTerminalDashboard(context.getNoteId(), context.getParagraphId(),
+        terminalHostIp, terminalPort);
 
     return new InterpreterResult(Code.SUCCESS);
   }
 
-  public void createTerminalDashboard(String noteId, String paragraphId, int 
port) {
-    String hostName = "", hostIp = "";
+  public void createTerminalDashboard(String noteId, String paragraphId, 
String hostIp, int port) {
+    String hostName = "";
     URL urlTemplate = 
Resources.getResource("ui_templates/terminal-dashboard.jinja");
     String template = null;
     try {
       template = Resources.toString(urlTemplate, Charsets.UTF_8);
       InetAddress addr = InetAddress.getLocalHost();
       hostName = addr.getHostName().toString();
-      hostIp = RemoteInterpreterUtils.findAvailableHostAddress();
 
       // Internal and external IP mapping of zeppelin server
       if (mapIpMapping.containsKey(hostIp)) {
@@ -164,7 +169,7 @@ public class TerminalInterpreter extends 
KerberosInterpreter {
     Jinjava jinjava = new Jinjava();
     HashMap<String, Object> jinjaParams = new HashMap();
     Date now = new Date();
-    String terminalServerUrl = "http://" + hostIp + ":" + port +
+    String terminalServerUrl = generateOrigin(hostIp, port) +
         "?noteId=" + noteId + "&paragraphId=" + paragraphId + "&t=" + 
now.getTime();
     jinjaParams.put("HOST_NAME", hostName);
     jinjaParams.put("HOST_IP", hostIp);
@@ -183,6 +188,10 @@ public class TerminalInterpreter extends 
KerberosInterpreter {
     }
   }
 
+  private String generateOrigin(String hostIp, int port) {
+    return "http://" + hostIp + ":" + port;
+  }
+
   @Override
   public void cancel(InterpreterContext context) {
   }
@@ -238,6 +247,11 @@ public class TerminalInterpreter extends 
KerberosInterpreter {
     return terminalPort;
   }
 
+  @VisibleForTesting
+  public String getTerminalHostIp() {
+    return terminalHostIp;
+  }
+
   @VisibleForTesting
   public boolean terminalThreadIsRunning() {
     return terminalThread.isRunning();
diff --git 
a/shell/src/main/java/org/apache/zeppelin/shell/terminal/TerminalThread.java 
b/shell/src/main/java/org/apache/zeppelin/shell/terminal/TerminalThread.java
index 620a6dee5e..31c7e22655 100644
--- a/shell/src/main/java/org/apache/zeppelin/shell/terminal/TerminalThread.java
+++ b/shell/src/main/java/org/apache/zeppelin/shell/terminal/TerminalThread.java
@@ -18,7 +18,9 @@
 package org.apache.zeppelin.shell.terminal;
 
 import javax.websocket.server.ServerContainer;
+import javax.websocket.server.ServerEndpointConfig;
 
+import 
org.apache.zeppelin.shell.terminal.websocket.TerminalSessionConfigurator;
 import org.apache.zeppelin.shell.terminal.websocket.TerminalSocket;
 import org.eclipse.jetty.server.Server;
 import org.eclipse.jetty.server.ServerConnector;
@@ -38,9 +40,11 @@ public class TerminalThread extends Thread {
   private Server jettyServer = new Server();
 
   private int port = 0;
+  private String allwedOrigin;
 
-  public TerminalThread(int port) {
+  public TerminalThread(int port, String allwedOrigin) {
     this.port = port;
+    this.allwedOrigin = allwedOrigin;
   }
 
   public void run() {
@@ -72,7 +76,10 @@ public class TerminalThread extends Thread {
 
     try {
       ServerContainer container = 
WebSocketServerContainerInitializer.configureContext(context);
-      container.addEndpoint(TerminalSocket.class);
+      container.addEndpoint(
+          ServerEndpointConfig.Builder.create(TerminalSocket.class, "/")
+              .configurator(new TerminalSessionConfigurator(allwedOrigin))
+              .build());
       jettyServer.start();
       jettyServer.join();
     } catch (Exception e) {
diff --git 
a/shell/src/main/java/org/apache/zeppelin/shell/terminal/websocket/TerminalSessionConfigurator.java
 
b/shell/src/main/java/org/apache/zeppelin/shell/terminal/websocket/TerminalSessionConfigurator.java
new file mode 100644
index 0000000000..0e7d20f218
--- /dev/null
+++ 
b/shell/src/main/java/org/apache/zeppelin/shell/terminal/websocket/TerminalSessionConfigurator.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.shell.terminal.websocket;
+
+import javax.websocket.server.ServerEndpointConfig.Configurator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class TerminalSessionConfigurator  extends Configurator {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(TerminalSessionConfigurator.class);
+  private String allowedOrigin;
+
+  public TerminalSessionConfigurator(String allowedOrigin) {
+    this.allowedOrigin = allowedOrigin;
+  }
+
+  @Override
+  public boolean checkOrigin(String originHeaderValue) {
+    boolean allowed = allowedOrigin.equals(originHeaderValue);
+    LOGGER.info("Checking origin for TerminalSessionConfigurator: " +
+        originHeaderValue + " allowed: " + allowed);
+    return allowed;
+  }
+}
diff --git 
a/shell/src/test/java/org/apache/zeppelin/shell/TerminalInterpreterTest.java 
b/shell/src/test/java/org/apache/zeppelin/shell/TerminalInterpreterTest.java
index 6365a0a4ae..4d71889952 100644
--- a/shell/src/test/java/org/apache/zeppelin/shell/TerminalInterpreterTest.java
+++ b/shell/src/test/java/org/apache/zeppelin/shell/TerminalInterpreterTest.java
@@ -17,6 +17,12 @@
 
 package org.apache.zeppelin.shell;
 
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import javax.websocket.ClientEndpointConfig;
+import javax.websocket.ClientEndpointConfig.Builder;
+import javax.websocket.ClientEndpointConfig.Configurator;
 import org.apache.zeppelin.interpreter.InterpreterContext;
 import org.apache.zeppelin.interpreter.InterpreterException;
 import org.apache.zeppelin.interpreter.InterpreterResult;
@@ -34,6 +40,7 @@ import javax.websocket.Session;
 import javax.websocket.WebSocketContainer;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 import java.io.IOException;
@@ -81,11 +88,17 @@ class TerminalInterpreterTest extends BaseInterpreterTest {
       boolean running = terminal.terminalThreadIsRunning();
       assertTrue(running);
 
-      URI uri = URI.create("ws://localhost:" + terminal.getTerminalPort() + 
"/terminal/");
+      URI webSocketConnectionUri = URI.create("ws://" + 
terminal.getTerminalHostIp() +
+          ":" + terminal.getTerminalPort() + "/terminal/");
+      LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri);
+      String origin = "http://" + terminal.getTerminalHostIp() + ":" + 
terminal.getTerminalPort();
+      LOGGER.info("origin: " + origin);
+      ClientEndpointConfig clientEndpointConfig = 
getOriginRequestHeaderConfig(origin);
       webSocketContainer = ContainerProvider.getWebSocketContainer();
 
       // Attempt Connect
-      session = webSocketContainer.connectToServer(TerminalSocketTest.class, 
uri);
+      session = webSocketContainer.connectToServer(
+          TerminalSocketTest.class, clientEndpointConfig, 
webSocketConnectionUri);
 
       // Send Start terminal service message
       String terminalReadyCmd = String.format("{\"type\":\"TERMINAL_READY\"," +
@@ -161,11 +174,17 @@ class TerminalInterpreterTest extends BaseInterpreterTest 
{
       boolean running = terminal.terminalThreadIsRunning();
       assertTrue(running);
 
-      URI uri = URI.create("ws://localhost:" + terminal.getTerminalPort() + 
"/terminal/");
+      URI webSocketConnectionUri = URI.create("ws://" + 
terminal.getTerminalHostIp() +
+          ":" + terminal.getTerminalPort() + "/terminal/");
+      LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri);
+      String origin = "http://" + terminal.getTerminalHostIp() + ":" + 
terminal.getTerminalPort();
+      LOGGER.info("origin: " + origin);
+      ClientEndpointConfig clientEndpointConfig = 
getOriginRequestHeaderConfig(origin);
       webSocketContainer = ContainerProvider.getWebSocketContainer();
 
       // Attempt Connect
-      session = webSocketContainer.connectToServer(TerminalSocketTest.class, 
uri);
+      session = webSocketContainer.connectToServer(
+          TerminalSocketTest.class, clientEndpointConfig, 
webSocketConnectionUri);
 
       // Send Start terminal service message
       String terminalReadyCmd = String.format("{\"type\":\"TERMINAL_READY\"," +
@@ -229,4 +248,118 @@ class TerminalInterpreterTest extends BaseInterpreterTest 
{
       }
     }
   }
+
+  @Test
+  void testValidOrigin() {
+    Session session = null;
+
+    // mock connect terminal
+    boolean running = terminal.terminalThreadIsRunning();
+    assertTrue(running);
+
+    URI webSocketConnectionUri = URI.create("ws://" + 
terminal.getTerminalHostIp() +
+        ":" + terminal.getTerminalPort() + "/terminal/");
+    LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri);
+    String origin = "http://" + terminal.getTerminalHostIp() + ":" + 
terminal.getTerminalPort();
+    LOGGER.info("origin: " + origin);
+    ClientEndpointConfig clientEndpointConfig = 
getOriginRequestHeaderConfig(origin);
+    WebSocketContainer webSocketContainer = 
ContainerProvider.getWebSocketContainer();
+
+    Throwable exception = null;
+    try {
+      // Attempt Connect
+      session = webSocketContainer.connectToServer(
+          TerminalSocketTest.class, clientEndpointConfig, 
webSocketConnectionUri);
+    } catch (DeploymentException e) {
+      exception = e;
+    } catch (IOException e) {
+      exception = e;
+    } finally {
+      if (session != null) {
+        try {
+          session.close();
+        } catch (IOException e) {
+          LOGGER.error(e.getMessage(), e);
+        }
+      }
+
+      // Force lifecycle stop when done with container.
+      // This is to free up threads and resources that the
+      // JSR-356 container allocates. But unfortunately
+      // the JSR-356 spec does not handle lifecycles (yet)
+      if (webSocketContainer instanceof LifeCycle) {
+        try {
+          ((LifeCycle) webSocketContainer).stop();
+        } catch (Exception e) {
+          LOGGER.error(e.getMessage(), e);
+        }
+      }
+    }
+
+    assertNull(exception);
+  }
+
+  @Test
+  void testInvalidOrigin() {
+    Session session = null;
+
+    // mock connect terminal
+    boolean running = terminal.terminalThreadIsRunning();
+    assertTrue(running);
+
+    URI webSocketConnectionUri = URI.create("ws://" + 
terminal.getTerminalHostIp() +
+        ":" + terminal.getTerminalPort() + "/terminal/");
+    LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri);
+    String origin = "http://invalid-origin";
+    LOGGER.info("origin: " + origin);
+    ClientEndpointConfig clientEndpointConfig = 
getOriginRequestHeaderConfig(origin);
+    WebSocketContainer webSocketContainer = 
ContainerProvider.getWebSocketContainer();
+
+    Throwable exception = null;
+    try {
+      // Attempt Connect
+      session = webSocketContainer.connectToServer(
+          TerminalSocketTest.class, clientEndpointConfig, 
webSocketConnectionUri);
+    } catch (DeploymentException e) {
+      exception = e;
+    } catch (IOException e) {
+      exception = e;
+    } finally {
+      if (session != null) {
+        try {
+          session.close();
+        } catch (IOException e) {
+          LOGGER.error(e.getMessage(), e);
+        }
+      }
+
+      // Force lifecycle stop when done with container.
+      // This is to free up threads and resources that the
+      // JSR-356 container allocates. But unfortunately
+      // the JSR-356 spec does not handle lifecycles (yet)
+      if (webSocketContainer instanceof LifeCycle) {
+        try {
+          ((LifeCycle) webSocketContainer).stop();
+        } catch (Exception e) {
+          LOGGER.error(e.getMessage(), e);
+        }
+      }
+    }
+
+    assertTrue(exception instanceof IOException);
+    assertEquals("Connect failure", exception.getMessage());
+  }
+
+  private static ClientEndpointConfig getOriginRequestHeaderConfig(String 
origin) {
+    Configurator configurator = new Configurator() {
+      @Override
+      public void beforeRequest(Map<String, List<String>> headers) {
+        headers.put("Origin", Arrays.asList(origin));
+      }
+    };
+    ClientEndpointConfig clientEndpointConfig = Builder.create()
+        .configurator(configurator)
+        .build();
+    return clientEndpointConfig;
+  }
 }
diff --git 
a/shell/src/test/java/org/apache/zeppelin/shell/terminal/TerminalSocketTest.java
 
b/shell/src/test/java/org/apache/zeppelin/shell/terminal/TerminalSocketTest.java
index 7861256462..051d73e569 100644
--- 
a/shell/src/test/java/org/apache/zeppelin/shell/terminal/TerminalSocketTest.java
+++ 
b/shell/src/test/java/org/apache/zeppelin/shell/terminal/TerminalSocketTest.java
@@ -17,49 +17,40 @@
 
 package org.apache.zeppelin.shell.terminal;
 
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import javax.websocket.ClientEndpoint;
-import javax.websocket.CloseReason;
-import javax.websocket.OnClose;
-import javax.websocket.OnError;
-import javax.websocket.OnMessage;
-import javax.websocket.OnOpen;
-import javax.websocket.Session;
-import javax.websocket.server.ServerEndpoint;
 import java.util.ArrayList;
 import java.util.List;
+import javax.websocket.CloseReason;
+import javax.websocket.Endpoint;
+import javax.websocket.EndpointConfig;
+import javax.websocket.Session;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
-@ClientEndpoint
-@ServerEndpoint(value = "/")
-public class TerminalSocketTest {
+public class TerminalSocketTest extends Endpoint {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(TerminalSocketTest.class);
 
   public static final List<String> ReceivedMsg = new ArrayList();
 
-  @OnOpen
-  public void onWebSocketConnect(Session sess)
-  {
-    LOGGER.info("Socket Connected: " + sess);
-  }
-
-  @OnMessage
-  public void onWebSocketText(String message)
-  {
-    LOGGER.info("Received TEXT message: " + message);
-    ReceivedMsg.add(message);
+  @Override
+  public void onOpen(Session session, EndpointConfig endpointConfig) {
+    LOGGER.info("Socket Connected: " + session);
+
+    session.addMessageHandler(new 
javax.websocket.MessageHandler.Whole<String>() {
+      @Override
+      public void onMessage(String message) {
+        LOGGER.info("Received TEXT message: " + message);
+        ReceivedMsg.add(message);
+      }
+    });
   }
 
-  @OnClose
-  public void onWebSocketClose(CloseReason reason)
-  {
-    LOGGER.info("Socket Closed: " + reason);
+  @Override
+  public void onClose(Session session, CloseReason closeReason) {
+    LOGGER.info("Socket Closed: " + closeReason);
   }
 
-  @OnError
-  public void onWebSocketError(Throwable cause)
-  {
+  @Override
+  public void onError(Session session, Throwable cause) {
     LOGGER.error(cause.getMessage(), cause);
   }
 }

Reply via email to