This is an automated email from the ASF dual-hosted git repository. jongyoul pushed a commit to branch branch-0.12 in repository https://gitbox.apache.org/repos/asf/zeppelin.git
The following commit(s) were added to refs/heads/branch-0.12 by this push: new f4847ea8eb [NO-ISSUE] Implement Origin check for terminal interpreter WebSocket connections f4847ea8eb is described below commit f4847ea8ebc03a70768c91af0a33154768076fd0 Author: ChanHo Lee <chanho0...@gmail.com> AuthorDate: Sat Nov 2 21:18:32 2024 +0900 [NO-ISSUE] Implement Origin check for terminal interpreter WebSocket connections ### What is this PR for? This PR adds an Origin check to ensure that WebSocket connections are initiated from trusted sources only. By validating the `Origin` header in the initial WebSocket handshake, we can prevent unauthorized or malicious websites from establishing WebSocket connections with our server. Changes: - Added server-side validation of the `Origin` header during WebSocket connection requests. Other security enhancements may be needed and can be handled in future iterations. ### What type of PR is it? Improvement ### Todos * [ ] - Task ### How should this be tested? * Strongly recommended: add automated unit tests for any new or changed behavior * Outline any manual steps to test the PR here. ### Screenshots (if appropriate) ### Questions: * Do the license files need to be updated? No * Are there breaking changes for older versions? No * Does this need documentation? No Closes #4823 from tbonelee/websocket. 
Signed-off-by: Jongyoul Lee <jongy...@gmail.com> (cherry picked from commit 3575a3cf8834cdc0100f0a554846829584df0c30) Signed-off-by: Jongyoul Lee <jongy...@gmail.com> --- .../apache/zeppelin/shell/TerminalInterpreter.java | 26 +++- .../zeppelin/shell/terminal/TerminalThread.java | 11 +- .../websocket/TerminalSessionConfigurator.java | 39 ++++++ .../zeppelin/shell/TerminalInterpreterTest.java | 141 ++++++++++++++++++++- .../shell/terminal/TerminalSocketTest.java | 55 ++++---- 5 files changed, 228 insertions(+), 44 deletions(-) diff --git a/shell/src/main/java/org/apache/zeppelin/shell/TerminalInterpreter.java b/shell/src/main/java/org/apache/zeppelin/shell/TerminalInterpreter.java index d54495e2c2..5a4eb8bae7 100644 --- a/shell/src/main/java/org/apache/zeppelin/shell/TerminalInterpreter.java +++ b/shell/src/main/java/org/apache/zeppelin/shell/TerminalInterpreter.java @@ -63,6 +63,7 @@ public class TerminalInterpreter extends KerberosInterpreter { private InterpreterContext intpContext; private int terminalPort = 0; + private String terminalHostIp; // Internal and external IP mapping of zeppelin server private HashMap<String, String> mapIpMapping = new HashMap<>(); @@ -109,7 +110,11 @@ public class TerminalInterpreter extends KerberosInterpreter { if (null == terminalThread) { try { terminalPort = RemoteInterpreterUtils.findRandomAvailablePortOnAllLocalInterfaces(); - terminalThread = new TerminalThread(terminalPort); + terminalHostIp = RemoteInterpreterUtils.findAvailableHostAddress(); + LOGGER.info("Terminal host IP: " + terminalHostIp); + LOGGER.info("Terminal port: " + terminalPort); + String allowedOrigin = generateOrigin(terminalHostIp, terminalPort); + terminalThread = new TerminalThread(terminalPort, allowedOrigin); terminalThread.start(); } catch (IOException e) { LOGGER.error(e.getMessage(), e); @@ -136,20 +141,20 @@ public class TerminalInterpreter extends KerberosInterpreter { mapIpMapping = gson.fromJson(strIpMapping, new TypeToken<Map<String, 
String>>(){}.getType()); } - createTerminalDashboard(context.getNoteId(), context.getParagraphId(), terminalPort); + createTerminalDashboard(context.getNoteId(), context.getParagraphId(), + terminalHostIp, terminalPort); return new InterpreterResult(Code.SUCCESS); } - public void createTerminalDashboard(String noteId, String paragraphId, int port) { - String hostName = "", hostIp = ""; + public void createTerminalDashboard(String noteId, String paragraphId, String hostIp, int port) { + String hostName = ""; URL urlTemplate = Resources.getResource("ui_templates/terminal-dashboard.jinja"); String template = null; try { template = Resources.toString(urlTemplate, Charsets.UTF_8); InetAddress addr = InetAddress.getLocalHost(); hostName = addr.getHostName().toString(); - hostIp = RemoteInterpreterUtils.findAvailableHostAddress(); // Internal and external IP mapping of zeppelin server if (mapIpMapping.containsKey(hostIp)) { @@ -164,7 +169,7 @@ public class TerminalInterpreter extends KerberosInterpreter { Jinjava jinjava = new Jinjava(); HashMap<String, Object> jinjaParams = new HashMap(); Date now = new Date(); - String terminalServerUrl = "http://" + hostIp + ":" + port + + String terminalServerUrl = generateOrigin(hostIp, port) + "?noteId=" + noteId + "&paragraphId=" + paragraphId + "&t=" + now.getTime(); jinjaParams.put("HOST_NAME", hostName); jinjaParams.put("HOST_IP", hostIp); @@ -183,6 +188,10 @@ public class TerminalInterpreter extends KerberosInterpreter { } } + private String generateOrigin(String hostIp, int port) { + return "http://" + hostIp + ":" + port; + } + @Override public void cancel(InterpreterContext context) { } @@ -238,6 +247,11 @@ public class TerminalInterpreter extends KerberosInterpreter { return terminalPort; } + @VisibleForTesting + public String getTerminalHostIp() { + return terminalHostIp; + } + @VisibleForTesting public boolean terminalThreadIsRunning() { return terminalThread.isRunning(); diff --git 
a/shell/src/main/java/org/apache/zeppelin/shell/terminal/TerminalThread.java b/shell/src/main/java/org/apache/zeppelin/shell/terminal/TerminalThread.java index 620a6dee5e..31c7e22655 100644 --- a/shell/src/main/java/org/apache/zeppelin/shell/terminal/TerminalThread.java +++ b/shell/src/main/java/org/apache/zeppelin/shell/terminal/TerminalThread.java @@ -18,7 +18,9 @@ package org.apache.zeppelin.shell.terminal; import javax.websocket.server.ServerContainer; +import javax.websocket.server.ServerEndpointConfig; +import org.apache.zeppelin.shell.terminal.websocket.TerminalSessionConfigurator; import org.apache.zeppelin.shell.terminal.websocket.TerminalSocket; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; @@ -38,9 +40,11 @@ public class TerminalThread extends Thread { private Server jettyServer = new Server(); private int port = 0; + private String allwedOrigin; - public TerminalThread(int port) { + public TerminalThread(int port, String allwedOrigin) { this.port = port; + this.allwedOrigin = allwedOrigin; } public void run() { @@ -72,7 +76,10 @@ public class TerminalThread extends Thread { try { ServerContainer container = WebSocketServerContainerInitializer.configureContext(context); - container.addEndpoint(TerminalSocket.class); + container.addEndpoint( + ServerEndpointConfig.Builder.create(TerminalSocket.class, "/") + .configurator(new TerminalSessionConfigurator(allwedOrigin)) + .build()); jettyServer.start(); jettyServer.join(); } catch (Exception e) { diff --git a/shell/src/main/java/org/apache/zeppelin/shell/terminal/websocket/TerminalSessionConfigurator.java b/shell/src/main/java/org/apache/zeppelin/shell/terminal/websocket/TerminalSessionConfigurator.java new file mode 100644 index 0000000000..0e7d20f218 --- /dev/null +++ b/shell/src/main/java/org/apache/zeppelin/shell/terminal/websocket/TerminalSessionConfigurator.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.shell.terminal.websocket; + +import javax.websocket.server.ServerEndpointConfig.Configurator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class TerminalSessionConfigurator extends Configurator { + private static final Logger LOGGER = LoggerFactory.getLogger(TerminalSessionConfigurator.class); + private String allowedOrigin; + + public TerminalSessionConfigurator(String allowedOrigin) { + this.allowedOrigin = allowedOrigin; + } + + @Override + public boolean checkOrigin(String originHeaderValue) { + boolean allowed = allowedOrigin.equals(originHeaderValue); + LOGGER.info("Checking origin for TerminalSessionConfigurator: " + + originHeaderValue + " allowed: " + allowed); + return allowed; + } +} diff --git a/shell/src/test/java/org/apache/zeppelin/shell/TerminalInterpreterTest.java b/shell/src/test/java/org/apache/zeppelin/shell/TerminalInterpreterTest.java index 6365a0a4ae..4d71889952 100644 --- a/shell/src/test/java/org/apache/zeppelin/shell/TerminalInterpreterTest.java +++ b/shell/src/test/java/org/apache/zeppelin/shell/TerminalInterpreterTest.java @@ -17,6 +17,12 @@ package org.apache.zeppelin.shell; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import 
javax.websocket.ClientEndpointConfig; +import javax.websocket.ClientEndpointConfig.Builder; +import javax.websocket.ClientEndpointConfig.Configurator; import org.apache.zeppelin.interpreter.InterpreterContext; import org.apache.zeppelin.interpreter.InterpreterException; import org.apache.zeppelin.interpreter.InterpreterResult; @@ -34,6 +40,7 @@ import javax.websocket.Session; import javax.websocket.WebSocketContainer; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.IOException; @@ -81,11 +88,17 @@ class TerminalInterpreterTest extends BaseInterpreterTest { boolean running = terminal.terminalThreadIsRunning(); assertTrue(running); - URI uri = URI.create("ws://localhost:" + terminal.getTerminalPort() + "/terminal/"); + URI webSocketConnectionUri = URI.create("ws://" + terminal.getTerminalHostIp() + + ":" + terminal.getTerminalPort() + "/terminal/"); + LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri); + String origin = "http://" + terminal.getTerminalHostIp() + ":" + terminal.getTerminalPort(); + LOGGER.info("origin: " + origin); + ClientEndpointConfig clientEndpointConfig = getOriginRequestHeaderConfig(origin); webSocketContainer = ContainerProvider.getWebSocketContainer(); // Attempt Connect - session = webSocketContainer.connectToServer(TerminalSocketTest.class, uri); + session = webSocketContainer.connectToServer( + TerminalSocketTest.class, clientEndpointConfig, webSocketConnectionUri); // Send Start terminal service message String terminalReadyCmd = String.format("{\"type\":\"TERMINAL_READY\"," + @@ -161,11 +174,17 @@ class TerminalInterpreterTest extends BaseInterpreterTest { boolean running = terminal.terminalThreadIsRunning(); assertTrue(running); - URI uri = URI.create("ws://localhost:" + terminal.getTerminalPort() + "/terminal/"); + URI webSocketConnectionUri = URI.create("ws://" + 
terminal.getTerminalHostIp() + + ":" + terminal.getTerminalPort() + "/terminal/"); + LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri); + String origin = "http://" + terminal.getTerminalHostIp() + ":" + terminal.getTerminalPort(); + LOGGER.info("origin: " + origin); + ClientEndpointConfig clientEndpointConfig = getOriginRequestHeaderConfig(origin); webSocketContainer = ContainerProvider.getWebSocketContainer(); // Attempt Connect - session = webSocketContainer.connectToServer(TerminalSocketTest.class, uri); + session = webSocketContainer.connectToServer( + TerminalSocketTest.class, clientEndpointConfig, webSocketConnectionUri); // Send Start terminal service message String terminalReadyCmd = String.format("{\"type\":\"TERMINAL_READY\"," + @@ -229,4 +248,118 @@ class TerminalInterpreterTest extends BaseInterpreterTest { } } } + + @Test + void testValidOrigin() { + Session session = null; + + // mock connect terminal + boolean running = terminal.terminalThreadIsRunning(); + assertTrue(running); + + URI webSocketConnectionUri = URI.create("ws://" + terminal.getTerminalHostIp() + + ":" + terminal.getTerminalPort() + "/terminal/"); + LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri); + String origin = "http://" + terminal.getTerminalHostIp() + ":" + terminal.getTerminalPort(); + LOGGER.info("origin: " + origin); + ClientEndpointConfig clientEndpointConfig = getOriginRequestHeaderConfig(origin); + WebSocketContainer webSocketContainer = ContainerProvider.getWebSocketContainer(); + + Throwable exception = null; + try { + // Attempt Connect + session = webSocketContainer.connectToServer( + TerminalSocketTest.class, clientEndpointConfig, webSocketConnectionUri); + } catch (DeploymentException e) { + exception = e; + } catch (IOException e) { + exception = e; + } finally { + if (session != null) { + try { + session.close(); + } catch (IOException e) { + LOGGER.error(e.getMessage(), e); + } + } + + // Force lifecycle stop when done with 
container. + // This is to free up threads and resources that the + // JSR-356 container allocates. But unfortunately + // the JSR-356 spec does not handle lifecycles (yet) + if (webSocketContainer instanceof LifeCycle) { + try { + ((LifeCycle) webSocketContainer).stop(); + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + } + } + } + + assertNull(exception); + } + + @Test + void testInvalidOrigin() { + Session session = null; + + // mock connect terminal + boolean running = terminal.terminalThreadIsRunning(); + assertTrue(running); + + URI webSocketConnectionUri = URI.create("ws://" + terminal.getTerminalHostIp() + + ":" + terminal.getTerminalPort() + "/terminal/"); + LOGGER.info("webSocketConnectionUri: " + webSocketConnectionUri); + String origin = "http://invalid-origin"; + LOGGER.info("origin: " + origin); + ClientEndpointConfig clientEndpointConfig = getOriginRequestHeaderConfig(origin); + WebSocketContainer webSocketContainer = ContainerProvider.getWebSocketContainer(); + + Throwable exception = null; + try { + // Attempt Connect + session = webSocketContainer.connectToServer( + TerminalSocketTest.class, clientEndpointConfig, webSocketConnectionUri); + } catch (DeploymentException e) { + exception = e; + } catch (IOException e) { + exception = e; + } finally { + if (session != null) { + try { + session.close(); + } catch (IOException e) { + LOGGER.error(e.getMessage(), e); + } + } + + // Force lifecycle stop when done with container. + // This is to free up threads and resources that the + // JSR-356 container allocates. 
But unfortunately + // the JSR-356 spec does not handle lifecycles (yet) + if (webSocketContainer instanceof LifeCycle) { + try { + ((LifeCycle) webSocketContainer).stop(); + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + } + } + } + + assertTrue(exception instanceof IOException); + assertEquals("Connect failure", exception.getMessage()); + } + + private static ClientEndpointConfig getOriginRequestHeaderConfig(String origin) { + Configurator configurator = new Configurator() { + @Override + public void beforeRequest(Map<String, List<String>> headers) { + headers.put("Origin", Arrays.asList(origin)); + } + }; + ClientEndpointConfig clientEndpointConfig = Builder.create() + .configurator(configurator) + .build(); + return clientEndpointConfig; + } } diff --git a/shell/src/test/java/org/apache/zeppelin/shell/terminal/TerminalSocketTest.java b/shell/src/test/java/org/apache/zeppelin/shell/terminal/TerminalSocketTest.java index 7861256462..051d73e569 100644 --- a/shell/src/test/java/org/apache/zeppelin/shell/terminal/TerminalSocketTest.java +++ b/shell/src/test/java/org/apache/zeppelin/shell/terminal/TerminalSocketTest.java @@ -17,49 +17,40 @@ package org.apache.zeppelin.shell.terminal; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import javax.websocket.ClientEndpoint; -import javax.websocket.CloseReason; -import javax.websocket.OnClose; -import javax.websocket.OnError; -import javax.websocket.OnMessage; -import javax.websocket.OnOpen; -import javax.websocket.Session; -import javax.websocket.server.ServerEndpoint; import java.util.ArrayList; import java.util.List; +import javax.websocket.CloseReason; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.Session; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -@ClientEndpoint -@ServerEndpoint(value = "/") -public class TerminalSocketTest { +public class TerminalSocketTest extends Endpoint { private static final Logger LOGGER = 
LoggerFactory.getLogger(TerminalSocketTest.class); public static final List<String> ReceivedMsg = new ArrayList(); - @OnOpen - public void onWebSocketConnect(Session sess) - { - LOGGER.info("Socket Connected: " + sess); - } - - @OnMessage - public void onWebSocketText(String message) - { - LOGGER.info("Received TEXT message: " + message); - ReceivedMsg.add(message); + @Override + public void onOpen(Session session, EndpointConfig endpointConfig) { + LOGGER.info("Socket Connected: " + session); + + session.addMessageHandler(new javax.websocket.MessageHandler.Whole<String>() { + @Override + public void onMessage(String message) { + LOGGER.info("Received TEXT message: " + message); + ReceivedMsg.add(message); + } + }); } - @OnClose - public void onWebSocketClose(CloseReason reason) - { - LOGGER.info("Socket Closed: " + reason); + @Override + public void onClose(Session session, CloseReason closeReason) { + LOGGER.info("Socket Closed: " + closeReason); } - @OnError - public void onWebSocketError(Throwable cause) - { + @Override + public void onError(Session session, Throwable cause) { LOGGER.error(cause.getMessage(), cause); } }