This is an automated email from the ASF dual-hosted git repository.
jongyoul pushed a commit to branch branch-0.12
in repository https://gitbox.apache.org/repos/asf/zeppelin.git
The following commit(s) were added to refs/heads/branch-0.12 by this push:
new f4847ea8eb [NO-ISSUE] Implement Origin check for terminal interpreter
WebSocket connections
f4847ea8eb is described below
commit f4847ea8ebc03a70768c91af0a33154768076fd0
Author: ChanHo Lee <[email protected]>
AuthorDate: Sat Nov 2 21:18:32 2024 +0900
[NO-ISSUE] Implement Origin check for terminal interpreter WebSocket
connections
### What is this PR for?
This PR adds an Origin check to ensure that WebSocket connections are
initiated from trusted sources only.
By validating the `Origin` header in the initial WebSocket handshake, we
can prevent unauthorized or malicious websites from establishing WebSocket
connections with our server.
Changes:
- Added server-side validation of the `Origin` header during WebSocket
connection requests.
Other security enhancements may be needed and can be handled in future
iterations.
### What type of PR is it?
Improvement
### Todos
* [ ] - Task
### How should this be tested?
* Strongly recommended: add automated unit tests for any new or changed
behavior
* Outline any manual steps to test the PR here.
### Screenshots (if appropriate)
### Questions:
* Do the license files need to be updated? No
* Are there breaking changes for older versions? No
* Does this need documentation? No
Closes #4823 from tbonelee/websocket.
Signed-off-by: Jongyoul Lee <[email protected]>
(cherry picked from commit 3575a3cf8834cdc0100f0a554846829584df0c30)
Signed-off-by: Jongyoul Lee <[email protected]>
---
.../apache/zeppelin/shell/TerminalInterpreter.java | 26 +++-
.../zeppelin/shell/terminal/TerminalThread.java | 11 +-
.../websocket/TerminalSessionConfigurator.java | 39 ++++++
.../zeppelin/shell/TerminalInterpreterTest.java | 141 ++++++++++++++++++++-
.../shell/terminal/TerminalSocketTest.java | 55 ++++----
5 files changed, 228 insertions(+), 44 deletions(-)
diff --git
a/shell/src/main/java/org/apache/zeppelin/shell/TerminalInterpreter.java
b/shell/src/main/java/org/apache/zeppelin/shell/TerminalInterpreter.java
index d54495e2c2..5a4eb8bae7 100644
--- a/shell/src/main/java/org/apache/zeppelin/shell/TerminalInterpreter.java
+++ b/shell/src/main/java/org/apache/zeppelin/shell/TerminalInterpreter.java
@@ -63,6 +63,7 @@ public class TerminalInterpreter extends KerberosInterpreter {
private InterpreterContext intpContext;
private int terminalPort = 0;
+ private String terminalHostIp;
// Internal and external IP mapping of zeppelin server
private HashMap<String, String> mapIpMapping = new HashMap<>();
@@ -109,7 +110,11 @@ public class TerminalInterpreter extends
KerberosInterpreter {
if (null == terminalThread) {
try {
terminalPort =
RemoteInterpreterUtils.findRandomAvailablePortOnAllLocalInterfaces();
- terminalThread = new TerminalThread(terminalPort);
+ terminalHostIp = RemoteInterpreterUtils.findAvailableHostAddress();
+ LOGGER.info("Terminal host IP: " + terminalHostIp);
+ LOGGER.info("Terminal port: " + terminalPort);
+ String allowedOrigin = generateOrigin(terminalHostIp, terminalPort);
+ terminalThread = new TerminalThread(terminalPort, allowedOrigin);
terminalThread.start();
} catch (IOException e) {
LOGGER.error(e.getMessage(), e);
@@ -136,20 +141,20 @@ public class TerminalInterpreter extends
KerberosInterpreter {
mapIpMapping = gson.fromJson(strIpMapping, new TypeToken<Map<String,
String>>(){}.getType());
}
- createTerminalDashboard(context.getNoteId(), context.getParagraphId(),
terminalPort);
+ createTerminalDashboard(context.getNoteId(), context.getParagraphId(),
+ terminalHostIp, terminalPort);
return new InterpreterResult(Code.SUCCESS);
}
- public void createTerminalDashboard(String noteId, String paragraphId, int
port) {
- String hostName = "", hostIp = "";
+ public void createTerminalDashboard(String noteId, String paragraphId,
String hostIp, int port) {
+ String hostName = "";
URL urlTemplate =
Resources.getResource("ui_templates/terminal-dashboard.jinja");
String template = null;
try {
template = Resources.toString(urlTemplate, Charsets.UTF_8);
InetAddress addr = InetAddress.getLocalHost();
hostName = addr.getHostName().toString();
- hostIp = RemoteInterpreterUtils.findAvailableHostAddress();
// Internal and external IP mapping of zeppelin server
if (mapIpMapping.containsKey(hostIp)) {
@@ -164,7 +169,7 @@ public class TerminalInterpreter extends
KerberosInterpreter {
Jinjava jinjava = new Jinjava();
HashMap<String, Object> jinjaParams = new HashMap();
Date now = new Date();
- String terminalServerUrl = "http://" + hostIp + ":" + port +
+ String terminalServerUrl = generateOrigin(hostIp, port) +
"?noteId=" + noteId + "&paragraphId=" + paragraphId + "&t=" +
now.getTime();
jinjaParams.put("HOST_NAME", hostName);
jinjaParams.put("HOST_IP", hostIp);
@@ -183,6 +188,10 @@ public class TerminalInterpreter extends
KerberosInterpreter {
}
}
+ private String generateOrigin(String hostIp, int port) {
+ return "http://" + hostIp + ":" + port;
+ }
+
@Override
public void cancel(InterpreterContext context) {
}
@@ -238,6 +247,11 @@ public class TerminalInterpreter extends
KerberosInterpreter {
return terminalPort;
}
+ @VisibleForTesting
+ public String getTerminalHostIp() {
+ return terminalHostIp;
+ }
+
@VisibleForTesting
public boolean terminalThreadIsRunning() {
return terminalThread.isRunning();
diff --git
a/shell/src/main/java/org/apache/zeppelin/shell/terminal/TerminalThread.java
b/shell/src/main/java/org/apache/zeppelin/shell/terminal/TerminalThread.java
index 620a6dee5e..31c7e22655 100644
--- a/shell/src/main/java/org/apache/zeppelin/shell/terminal/TerminalThread.java
+++ b/shell/src/main/java/org/apache/zeppelin/shell/terminal/TerminalThread.java
@@ -18,7 +18,9 @@
package org.apache.zeppelin.shell.terminal;
import javax.websocket.server.ServerContainer;
+import javax.websocket.server.ServerEndpointConfig;
+import
org.apache.zeppelin.shell.terminal.websocket.TerminalSessionConfigurator;
import org.apache.zeppelin.shell.terminal.websocket.TerminalSocket;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
@@ -38,9 +40,11 @@ public class TerminalThread extends Thread {
private Server jettyServer = new Server();
private int port = 0;
+ private String allwedOrigin;
- public TerminalThread(int port) {
+ public TerminalThread(int port, String allwedOrigin) {
this.port = port;
+ this.allwedOrigin = allwedOrigin;
}
public void run() {
@@ -72,7 +76,10 @@ public class TerminalThread extends Thread {
try {
ServerContainer container =
WebSocketServerContainerInitializer.configureContext(context);
- container.addEndpoint(TerminalSocket.class);
+ container.addEndpoint(
+ ServerEndpointConfig.Builder.create(TerminalSocket.class, "/")
+ .configurator(new TerminalSessionConfigurator(allwedOrigin))
+ .build());
jettyServer.start();
jettyServer.join();
} catch (Exception e) {
diff --git
a/shell/src/main/java/org/apache/zeppelin/shell/terminal/websocket/TerminalSessionConfigurator.java
b/shell/src/main/java/org/apache/zeppelin/shell/terminal/websocket/TerminalSessionConfigurator.java
new file mode 100644
index 0000000000..0e7d20f218
--- /dev/null
+++
b/shell/src/main/java/org/apache/zeppelin/shell/terminal/websocket/TerminalSessionConfigurator.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.shell.terminal.websocket;
+
+import javax.websocket.server.ServerEndpointConfig.Configurator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class TerminalSessionConfigurator extends Configurator {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(TerminalSessionConfigurator.class);
+ private String allowedOrigin;
+
+ public TerminalSessionConfigurator(String allowedOrigin) {
+ this.allowedOrigin = allowedOrigin;
+ }
+
+ @Override
+ public boolean checkOrigin(String originHeaderValue) {
+ boolean allowed = allowedOrigin.equals(originHeaderValue);
+ LOGGER.info("Checking origin for TerminalSessionConfigurator: " +
+ originHeaderValue + " allowed: " + allowed);
+ return allowed;
+ }
+}
diff --git
a/shell/src/test/java/org/apache/zeppelin/shell/TerminalInterpreterTest.java
b/shell/src/test/java/org/apache/zeppelin/shell/TerminalInterpreterTest.java
index 6365a0a4ae..4d71889952 100644
--- a/shell/src/test/java/org/apache/zeppelin/shell/TerminalInterpreterTest.java
+++ b/shell/src/test/java/org/apache/zeppelin/shell/TerminalInterpreterTest.java
@@ -17,6 +17,12 @@
package org.apache.zeppelin.shell;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import javax.websocket.ClientEndpointConfig;
+import javax.websocket.ClientEndpointConfig.Builder;
+import javax.websocket.ClientEndpointConfig.Configurator;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterResult;
@@ -34,6 +40,7 @@ import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.IOException;
@@ -81,11 +88,17 @@ class TerminalInterpreterTest extends BaseInterpreterTest {
boolean running = terminal.terminalThreadIsRunning();
assertTrue(running);
- URI uri = URI.create("ws://localhost:" + terminal.getTerminalPort() +
"/terminal/");
+ URI webSocketConnectionUri = URI.create("ws://" +
terminal.getTerminalHostIp() +
+ ":" + terminal.getTerminalPort() + "/terminal/");
+ LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri);
+ String origin = "http://" + terminal.getTerminalHostIp() + ":" +
terminal.getTerminalPort();
+ LOGGER.info("origin: " + origin);
+ ClientEndpointConfig clientEndpointConfig =
getOriginRequestHeaderConfig(origin);
webSocketContainer = ContainerProvider.getWebSocketContainer();
// Attempt Connect
- session = webSocketContainer.connectToServer(TerminalSocketTest.class,
uri);
+ session = webSocketContainer.connectToServer(
+ TerminalSocketTest.class, clientEndpointConfig,
webSocketConnectionUri);
// Send Start terminal service message
String terminalReadyCmd = String.format("{\"type\":\"TERMINAL_READY\"," +
@@ -161,11 +174,17 @@ class TerminalInterpreterTest extends BaseInterpreterTest
{
boolean running = terminal.terminalThreadIsRunning();
assertTrue(running);
- URI uri = URI.create("ws://localhost:" + terminal.getTerminalPort() +
"/terminal/");
+ URI webSocketConnectionUri = URI.create("ws://" +
terminal.getTerminalHostIp() +
+ ":" + terminal.getTerminalPort() + "/terminal/");
+ LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri);
+ String origin = "http://" + terminal.getTerminalHostIp() + ":" +
terminal.getTerminalPort();
+ LOGGER.info("origin: " + origin);
+ ClientEndpointConfig clientEndpointConfig =
getOriginRequestHeaderConfig(origin);
webSocketContainer = ContainerProvider.getWebSocketContainer();
// Attempt Connect
- session = webSocketContainer.connectToServer(TerminalSocketTest.class,
uri);
+ session = webSocketContainer.connectToServer(
+ TerminalSocketTest.class, clientEndpointConfig,
webSocketConnectionUri);
// Send Start terminal service message
String terminalReadyCmd = String.format("{\"type\":\"TERMINAL_READY\"," +
@@ -229,4 +248,118 @@ class TerminalInterpreterTest extends BaseInterpreterTest
{
}
}
}
+
+ @Test
+ void testValidOrigin() {
+ Session session = null;
+
+ // mock connect terminal
+ boolean running = terminal.terminalThreadIsRunning();
+ assertTrue(running);
+
+ URI webSocketConnectionUri = URI.create("ws://" +
terminal.getTerminalHostIp() +
+ ":" + terminal.getTerminalPort() + "/terminal/");
+ LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri);
+ String origin = "http://" + terminal.getTerminalHostIp() + ":" +
terminal.getTerminalPort();
+ LOGGER.info("origin: " + origin);
+ ClientEndpointConfig clientEndpointConfig =
getOriginRequestHeaderConfig(origin);
+ WebSocketContainer webSocketContainer =
ContainerProvider.getWebSocketContainer();
+
+ Throwable exception = null;
+ try {
+ // Attempt Connect
+ session = webSocketContainer.connectToServer(
+ TerminalSocketTest.class, clientEndpointConfig,
webSocketConnectionUri);
+ } catch (DeploymentException e) {
+ exception = e;
+ } catch (IOException e) {
+ exception = e;
+ } finally {
+ if (session != null) {
+ try {
+ session.close();
+ } catch (IOException e) {
+ LOGGER.error(e.getMessage(), e);
+ }
+ }
+
+ // Force lifecycle stop when done with container.
+ // This is to free up threads and resources that the
+ // JSR-356 container allocates. But unfortunately
+ // the JSR-356 spec does not handle lifecycles (yet)
+ if (webSocketContainer instanceof LifeCycle) {
+ try {
+ ((LifeCycle) webSocketContainer).stop();
+ } catch (Exception e) {
+ LOGGER.error(e.getMessage(), e);
+ }
+ }
+ }
+
+ assertNull(exception);
+ }
+
+ @Test
+ void testInvalidOrigin() {
+ Session session = null;
+
+ // mock connect terminal
+ boolean running = terminal.terminalThreadIsRunning();
+ assertTrue(running);
+
+ URI webSocketConnectionUri = URI.create("ws://" +
terminal.getTerminalHostIp() +
+ ":" + terminal.getTerminalPort() + "/terminal/");
+ LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri);
+ String origin = "http://invalid-origin";
+ LOGGER.info("origin: " + origin);
+ ClientEndpointConfig clientEndpointConfig =
getOriginRequestHeaderConfig(origin);
+ WebSocketContainer webSocketContainer =
ContainerProvider.getWebSocketContainer();
+
+ Throwable exception = null;
+ try {
+ // Attempt Connect
+ session = webSocketContainer.connectToServer(
+ TerminalSocketTest.class, clientEndpointConfig,
webSocketConnectionUri);
+ } catch (DeploymentException e) {
+ exception = e;
+ } catch (IOException e) {
+ exception = e;
+ } finally {
+ if (session != null) {
+ try {
+ session.close();
+ } catch (IOException e) {
+ LOGGER.error(e.getMessage(), e);
+ }
+ }
+
+ // Force lifecycle stop when done with container.
+ // This is to free up threads and resources that the
+ // JSR-356 container allocates. But unfortunately
+ // the JSR-356 spec does not handle lifecycles (yet)
+ if (webSocketContainer instanceof LifeCycle) {
+ try {
+ ((LifeCycle) webSocketContainer).stop();
+ } catch (Exception e) {
+ LOGGER.error(e.getMessage(), e);
+ }
+ }
+ }
+
+ assertTrue(exception instanceof IOException);
+ assertEquals("Connect failure", exception.getMessage());
+ }
+
+ private static ClientEndpointConfig getOriginRequestHeaderConfig(String
origin) {
+ Configurator configurator = new Configurator() {
+ @Override
+ public void beforeRequest(Map<String, List<String>> headers) {
+ headers.put("Origin", Arrays.asList(origin));
+ }
+ };
+ ClientEndpointConfig clientEndpointConfig = Builder.create()
+ .configurator(configurator)
+ .build();
+ return clientEndpointConfig;
+ }
}
diff --git
a/shell/src/test/java/org/apache/zeppelin/shell/terminal/TerminalSocketTest.java
b/shell/src/test/java/org/apache/zeppelin/shell/terminal/TerminalSocketTest.java
index 7861256462..051d73e569 100644
---
a/shell/src/test/java/org/apache/zeppelin/shell/terminal/TerminalSocketTest.java
+++
b/shell/src/test/java/org/apache/zeppelin/shell/terminal/TerminalSocketTest.java
@@ -17,49 +17,40 @@
package org.apache.zeppelin.shell.terminal;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import javax.websocket.ClientEndpoint;
-import javax.websocket.CloseReason;
-import javax.websocket.OnClose;
-import javax.websocket.OnError;
-import javax.websocket.OnMessage;
-import javax.websocket.OnOpen;
-import javax.websocket.Session;
-import javax.websocket.server.ServerEndpoint;
import java.util.ArrayList;
import java.util.List;
+import javax.websocket.CloseReason;
+import javax.websocket.Endpoint;
+import javax.websocket.EndpointConfig;
+import javax.websocket.Session;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-@ClientEndpoint
-@ServerEndpoint(value = "/")
-public class TerminalSocketTest {
+public class TerminalSocketTest extends Endpoint {
private static final Logger LOGGER =
LoggerFactory.getLogger(TerminalSocketTest.class);
public static final List<String> ReceivedMsg = new ArrayList();
- @OnOpen
- public void onWebSocketConnect(Session sess)
- {
- LOGGER.info("Socket Connected: " + sess);
- }
-
- @OnMessage
- public void onWebSocketText(String message)
- {
- LOGGER.info("Received TEXT message: " + message);
- ReceivedMsg.add(message);
+ @Override
+ public void onOpen(Session session, EndpointConfig endpointConfig) {
+ LOGGER.info("Socket Connected: " + session);
+
+ session.addMessageHandler(new
javax.websocket.MessageHandler.Whole<String>() {
+ @Override
+ public void onMessage(String message) {
+ LOGGER.info("Received TEXT message: " + message);
+ ReceivedMsg.add(message);
+ }
+ });
}
- @OnClose
- public void onWebSocketClose(CloseReason reason)
- {
- LOGGER.info("Socket Closed: " + reason);
+ @Override
+ public void onClose(Session session, CloseReason closeReason) {
+ LOGGER.info("Socket Closed: " + closeReason);
}
- @OnError
- public void onWebSocketError(Throwable cause)
- {
+ @Override
+ public void onError(Session session, Throwable cause) {
LOGGER.error(cause.getMessage(), cause);
}
}