Author: markt Date: Thu Jan 24 22:54:41 2019 New Revision: 1852079 URL: http://svn.apache.org/viewvc?rev=1852079&view=rev Log: Fix https://bz.apache.org/bugzilla/show_bug.cgi?id=57974 Re-work code that supports Session.getOpenSessions() to ensure that both client-side and server-side calls behave as the EG intended
Added: tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainerGetOpenSessions.java (with props) Modified: tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java tomcat/trunk/java/org/apache/tomcat/websocket/WsWebSocketContainer.java tomcat/trunk/java/org/apache/tomcat/websocket/server/WsHttpUpgradeHandler.java tomcat/trunk/java/org/apache/tomcat/websocket/server/WsServerContainer.java tomcat/trunk/webapps/docs/changelog.xml Modified: tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java?rev=1852079&r1=1852078&r2=1852079&view=diff ============================================================================== --- tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java (original) +++ tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java Thu Jan 24 22:54:41 2019 @@ -45,6 +45,7 @@ import javax.websocket.RemoteEndpoint; import javax.websocket.SendResult; import javax.websocket.Session; import javax.websocket.WebSocketContainer; +import javax.websocket.server.ServerEndpointConfig; import org.apache.juli.logging.Log; import org.apache.juli.logging.LogFactory; @@ -416,7 +417,7 @@ public class WsSession implements Sessio @Override public Set<Session> getOpenSessions() { checkState(); - return webSocketContainer.getOpenSessions(localEndpoint); + return webSocketContainer.getOpenSessions(getSessionMapKey()); } @@ -605,11 +606,21 @@ public class WsSession implements Sessio localEndpoint.onError(this, e); } } finally { - webSocketContainer.unregisterSession(localEndpoint, this); + webSocketContainer.unregisterSession(getSessionMapKey(), this); } } + private Object getSessionMapKey() { + if (endpointConfig instanceof ServerEndpointConfig) { + // Server + return ((ServerEndpointConfig) endpointConfig).getPath(); + } else { + // Client + return localEndpoint; + } + } + /** * Use protected so unit tests can access this method directly. 
* @param msg The message Modified: tomcat/trunk/java/org/apache/tomcat/websocket/WsWebSocketContainer.java URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/WsWebSocketContainer.java?rev=1852079&r1=1852078&r2=1852079&view=diff ============================================================================== --- tomcat/trunk/java/org/apache/tomcat/websocket/WsWebSocketContainer.java (original) +++ tomcat/trunk/java/org/apache/tomcat/websocket/WsWebSocketContainer.java Thu Jan 24 22:54:41 2019 @@ -89,8 +89,9 @@ public class WsWebSocketContainer implem private final Object asynchronousChannelGroupLock = new Object(); private final Log log = LogFactory.getLog(WsWebSocketContainer.class); // must not be static - private final Map<Endpoint, Set<WsSession>> endpointSessionMap = - new HashMap<>(); + // Server side uses the endpoint path as the key + // Client side uses the client endpoint instance + private final Map<Object, Set<WsSession>> endpointSessionMap = new HashMap<>(); private final Map<WsSession,WsSession> sessions = new ConcurrentHashMap<>(); private final Object endPointSessionMapLock = new Object(); @@ -578,7 +579,7 @@ public class WsWebSocketContainer implem return ByteBuffer.wrap(bytes); } - protected void registerSession(Endpoint endpoint, WsSession wsSession) { + protected void registerSession(Object key, WsSession wsSession) { if (!wsSession.isOpen()) { // The session was closed during onOpen. No need to register it. 
@@ -588,10 +589,10 @@ public class WsWebSocketContainer implem if (endpointSessionMap.size() == 0) { BackgroundProcessManager.getInstance().register(this); } - Set<WsSession> wsSessions = endpointSessionMap.get(endpoint); + Set<WsSession> wsSessions = endpointSessionMap.get(key); if (wsSessions == null) { wsSessions = new HashSet<>(); - endpointSessionMap.put(endpoint, wsSessions); + endpointSessionMap.put(key, wsSessions); } wsSessions.add(wsSession); } @@ -599,14 +600,14 @@ public class WsWebSocketContainer implem } - protected void unregisterSession(Endpoint endpoint, WsSession wsSession) { + protected void unregisterSession(Object key, WsSession wsSession) { synchronized (endPointSessionMapLock) { - Set<WsSession> wsSessions = endpointSessionMap.get(endpoint); + Set<WsSession> wsSessions = endpointSessionMap.get(key); if (wsSessions != null) { wsSessions.remove(wsSession); if (wsSessions.size() == 0) { - endpointSessionMap.remove(endpoint); + endpointSessionMap.remove(key); } } if (endpointSessionMap.size() == 0) { @@ -617,10 +618,10 @@ public class WsWebSocketContainer implem } - Set<Session> getOpenSessions(Endpoint endpoint) { + Set<Session> getOpenSessions(Object key) { HashSet<Session> result = new HashSet<>(); synchronized (endPointSessionMapLock) { - Set<WsSession> sessions = endpointSessionMap.get(endpoint); + Set<WsSession> sessions = endpointSessionMap.get(key); if (sessions != null) { result.addAll(sessions); } Modified: tomcat/trunk/java/org/apache/tomcat/websocket/server/WsHttpUpgradeHandler.java URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/server/WsHttpUpgradeHandler.java?rev=1852079&r1=1852078&r2=1852079&view=diff ============================================================================== --- tomcat/trunk/java/org/apache/tomcat/websocket/server/WsHttpUpgradeHandler.java (original) +++ tomcat/trunk/java/org/apache/tomcat/websocket/server/WsHttpUpgradeHandler.java Thu Jan 24 22:54:41 2019 @@ -26,8 +26,8 @@ 
import javax.websocket.CloseReason; import javax.websocket.CloseReason.CloseCodes; import javax.websocket.DeploymentException; import javax.websocket.Endpoint; -import javax.websocket.EndpointConfig; import javax.websocket.Extension; +import javax.websocket.server.ServerEndpointConfig; import org.apache.coyote.http11.upgrade.InternalHttpUpgradeHandler; import org.apache.juli.logging.Log; @@ -54,7 +54,7 @@ public class WsHttpUpgradeHandler implem private SocketWrapperBase<?> socketWrapper; private Endpoint ep; - private EndpointConfig endpointConfig; + private ServerEndpointConfig serverEndpointConfig; private WsServerContainer webSocketContainer; private WsHandshakeRequest handshakeRequest; private List<Extension> negotiatedExtensions; @@ -80,13 +80,13 @@ public class WsHttpUpgradeHandler implem } - public void preInit(Endpoint ep, EndpointConfig endpointConfig, + public void preInit(Endpoint ep, ServerEndpointConfig serverEndpointConfig, WsServerContainer wsc, WsHandshakeRequest handshakeRequest, List<Extension> negotiatedExtensionsPhase2, String subProtocol, Transformation transformation, Map<String,String> pathParameters, boolean secure) { this.ep = ep; - this.endpointConfig = endpointConfig; + this.serverEndpointConfig = serverEndpointConfig; this.webSocketContainer = wsc; this.handshakeRequest = handshakeRequest; this.negotiatedExtensions = negotiatedExtensionsPhase2; @@ -124,14 +124,14 @@ public class WsHttpUpgradeHandler implem handshakeRequest.getQueryString(), handshakeRequest.getUserPrincipal(), httpSessionId, negotiatedExtensions, subProtocol, pathParameters, secure, - endpointConfig); + serverEndpointConfig); wsFrame = new WsFrameServer(socketWrapper, wsSession, transformation, applicationClassLoader); // WsFrame adds the necessary final transformations. Copy the // completed transformation chain to the remote end point. 
wsRemoteEndpointServer.setTransformation(wsFrame.getTransformation()); - ep.onOpen(wsSession, endpointConfig); - webSocketContainer.registerSession(ep, wsSession); + ep.onOpen(wsSession, serverEndpointConfig); + webSocketContainer.registerSession(serverEndpointConfig.getPath(), wsSession); } catch (DeploymentException e) { throw new IllegalArgumentException(e); } finally { Modified: tomcat/trunk/java/org/apache/tomcat/websocket/server/WsServerContainer.java URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/server/WsServerContainer.java?rev=1852079&r1=1852078&r2=1852079&view=diff ============================================================================== --- tomcat/trunk/java/org/apache/tomcat/websocket/server/WsServerContainer.java (original) +++ tomcat/trunk/java/org/apache/tomcat/websocket/server/WsServerContainer.java Thu Jan 24 22:54:41 2019 @@ -37,7 +37,6 @@ import javax.websocket.CloseReason; import javax.websocket.CloseReason.CloseCodes; import javax.websocket.DeploymentException; import javax.websocket.Encoder; -import javax.websocket.Endpoint; import javax.websocket.server.ServerContainer; import javax.websocket.server.ServerEndpoint; import javax.websocket.server.ServerEndpointConfig; @@ -341,8 +340,8 @@ public class WsServerContainer extends W * Overridden to make it visible to other classes in this package. */ @Override - protected void registerSession(Endpoint endpoint, WsSession wsSession) { - super.registerSession(endpoint, wsSession); + protected void registerSession(Object key, WsSession wsSession) { + super.registerSession(key, wsSession); if (wsSession.isOpen() && wsSession.getUserPrincipal() != null && wsSession.getHttpSessionId() != null) { @@ -358,13 +357,13 @@ public class WsServerContainer extends W * Overridden to make it visible to other classes in this package. 
*/ @Override - protected void unregisterSession(Endpoint endpoint, WsSession wsSession) { + protected void unregisterSession(Object key, WsSession wsSession) { if (wsSession.getUserPrincipal() != null && wsSession.getHttpSessionId() != null) { unregisterAuthenticatedSession(wsSession, wsSession.getHttpSessionId()); } - super.unregisterSession(endpoint, wsSession); + super.unregisterSession(key, wsSession); } Added: tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainerGetOpenSessions.java URL: http://svn.apache.org/viewvc/tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainerGetOpenSessions.java?rev=1852079&view=auto ============================================================================== --- tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainerGetOpenSessions.java (added) +++ tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainerGetOpenSessions.java Thu Jan 24 22:54:41 2019 @@ -0,0 +1,388 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.tomcat.websocket; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import javax.servlet.ServletContextEvent; +import javax.websocket.ClientEndpointConfig; +import javax.websocket.CloseReason; +import javax.websocket.ContainerProvider; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.MessageHandler; +import javax.websocket.OnMessage; +import javax.websocket.Session; +import javax.websocket.WebSocketContainer; +import javax.websocket.server.ServerContainer; +import javax.websocket.server.ServerEndpoint; +import javax.websocket.server.ServerEndpointConfig; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.catalina.Context; +import org.apache.catalina.servlets.DefaultServlet; +import org.apache.catalina.startup.Tomcat; +import org.apache.tomcat.websocket.server.Constants; +import org.apache.tomcat.websocket.server.WsContextListener; + +/* + * This test is split out into a separate class to make it easier to track the + * various permutations and combinations of client and server endpoints. + * + * Each test uses 2 client endpoints and 2 server endpoints with each client + * connecting to each server for a total of four connections (note sometimes + * the two clients and/or the two servers will be the same). 
+ */ +public class TestWsWebSocketContainerGetOpenSessions extends WebSocketBaseTest { + + @Test + public void testClientAClientAPojoAPojoA() throws Exception { + Endpoint client1 = new ClientEndpointA(); + Endpoint client2 = new ClientEndpointA(); + + doTest(client1, client2, "/pojoA", "/pojoA", 2, 2, 4, 4); + } + + + @Test + public void testClientAClientBPojoAPojoA() throws Exception { + Endpoint client1 = new ClientEndpointA(); + Endpoint client2 = new ClientEndpointB(); + + doTest(client1, client2, "/pojoA", "/pojoA", 2, 2, 4, 4); + } + + + @Test + public void testClientAClientAPojoAPojoB() throws Exception { + Endpoint client1 = new ClientEndpointA(); + Endpoint client2 = new ClientEndpointA(); + + doTest(client1, client2, "/pojoA", "/pojoB", 2, 2, 2, 2); + } + + + @Test + public void testClientAClientBPojoAPojoB() throws Exception { + Endpoint client1 = new ClientEndpointA(); + Endpoint client2 = new ClientEndpointB(); + + doTest(client1, client2, "/pojoA", "/pojoB", 2, 2, 2, 2); + } + + + @Test + public void testClientAClientAProgAProgA() throws Exception { + Endpoint client1 = new ClientEndpointA(); + Endpoint client2 = new ClientEndpointA(); + + doTest(client1, client2, "/progA", "/progA", 2, 2, 4, 4); + } + + + @Test + public void testClientAClientBProgAProgA() throws Exception { + Endpoint client1 = new ClientEndpointA(); + Endpoint client2 = new ClientEndpointB(); + + doTest(client1, client2, "/progA", "/progA", 2, 2, 4, 4); + } + + + @Test + public void testClientAClientAProgAProgB() throws Exception { + Endpoint client1 = new ClientEndpointA(); + Endpoint client2 = new ClientEndpointA(); + + doTest(client1, client2, "/progA", "/progB", 2, 2, 2, 2); + } + + + @Test + public void testClientAClientBProgAProgB() throws Exception { + Endpoint client1 = new ClientEndpointA(); + Endpoint client2 = new ClientEndpointB(); + + doTest(client1, client2, "/progA", "/progB", 2, 2, 2, 2); + } + + + @Test + public void testClientAClientAPojoAProgA() throws Exception 
{ + Endpoint client1 = new ClientEndpointA(); + Endpoint client2 = new ClientEndpointA(); + + doTest(client1, client2, "/pojoA", "/progA", 2, 2, 2, 2); + } + + + @Test + public void testClientAClientBPojoAProgA() throws Exception { + Endpoint client1 = new ClientEndpointA(); + Endpoint client2 = new ClientEndpointB(); + + doTest(client1, client2, "/pojoA", "/progA", 2, 2, 2, 2); + } + + + @Test + public void testClientAClientAPojoAProgB() throws Exception { + Endpoint client1 = new ClientEndpointA(); + Endpoint client2 = new ClientEndpointA(); + + doTest(client1, client2, "/pojoA", "/progB", 2, 2, 2, 2); + } + + + @Test + public void testClientAClientBPojoAProgB() throws Exception { + Endpoint client1 = new ClientEndpointA(); + Endpoint client2 = new ClientEndpointB(); + + doTest(client1, client2, "/pojoA", "/progB", 2, 2, 2, 2); + } + + + private void doTest(Endpoint client1, Endpoint client2, String server1, String server2, + int client1Count, int client2Count, int server1Count, int server2Count) throws Exception { + Tracker.reset(); + Tomcat tomcat = getTomcatInstance(); + // No file system docBase required + Context ctx = tomcat.addContext("", null); + ctx.addApplicationListener(Config.class.getName()); + Tomcat.addServlet(ctx, "default", new DefaultServlet()); + ctx.addServletMappingDecoded("/", "default"); + + tomcat.start(); + + WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer(); + + Session sClient1Server1 = createSession(wsContainer, client1, "client1", server1); + Session sClient1Server2 = createSession(wsContainer, client1, "client1", server2); + Session sClient2Server1 = createSession(wsContainer, client2, "client2", server1); + Session sClient2Server2 = createSession(wsContainer, client2, "client2", server2); + + + int delayCount = 0; + // Wait for up to 20s for this to complete. It should be a lot faster + // but some CI systems can be slow at times. 
+ while (Tracker.getUpdateCount() < 8 && delayCount < 400) { + Thread.sleep(50); + delayCount++; + } + + Assert.assertTrue(Tracker.checkRecord("client1", client1Count)); + Assert.assertTrue(Tracker.checkRecord("client2", client2Count)); + // Note: need to strip leading '/' from path + Assert.assertTrue(Tracker.checkRecord(server1.substring(1), server1Count)); + Assert.assertTrue(Tracker.checkRecord(server2.substring(1), server2Count)); + + sClient1Server1.close(); + sClient1Server2.close(); + sClient2Server1.close(); + sClient2Server2.close(); + } + + + private Session createSession(WebSocketContainer wsContainer, Endpoint client, + String clientName, String server) + throws DeploymentException, IOException, URISyntaxException { + + Session s = wsContainer.connectToServer(client, + ClientEndpointConfig.Builder.create().build(), + new URI("ws://localhost:" + getPort() + server)); + Tracker.addRecord(clientName, s.getOpenSessions().size()); + s.getBasicRemote().sendText("X"); + return s; + } + + + public static class Config extends WsContextListener { + + @Override + public void contextInitialized(ServletContextEvent sce) { + super.contextInitialized(sce); + ServerContainer sc = + (ServerContainer) sce.getServletContext().getAttribute( + Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); + + try { + sc.addEndpoint(PojoEndpointA.class); + sc.addEndpoint(PojoEndpointB.class); + sc.addEndpoint(ServerEndpointConfig.Builder.create( + ServerEndpointA.class, "/progA").build()); + sc.addEndpoint(ServerEndpointConfig.Builder.create( + ServerEndpointB.class, "/progB").build()); + } catch (DeploymentException e) { + throw new IllegalStateException(e); + } + } + } + + + public abstract static class PojoEndpointBase { + + @OnMessage + public void onMessage(@SuppressWarnings("unused") String msg, Session session) { + Tracker.addRecord(getTrackingName(), session.getOpenSessions().size()); + } + + protected abstract String getTrackingName(); + } + + + @ServerEndpoint("/pojoA") + 
public static class PojoEndpointA extends PojoEndpointBase { + + @Override + protected String getTrackingName() { + return "pojoA"; + } + } + + + @ServerEndpoint("/pojoB") + public static class PojoEndpointB extends PojoEndpointBase { + + @Override + protected String getTrackingName() { + return "pojoB"; + } + } + + + public abstract static class ServerEndpointBase extends Endpoint{ + + @Override + public void onOpen(Session session, EndpointConfig config) { + session.addMessageHandler(new TrackerMessageHandler(session, getTrackingName())); + } + + protected abstract String getTrackingName(); + } + + + public static final class ServerEndpointA extends ServerEndpointBase { + + @Override + protected String getTrackingName() { + return "progA"; + } + } + + + public static final class ServerEndpointB extends ServerEndpointBase { + + @Override + protected String getTrackingName() { + return "progB"; + } + } + + + public static final class TrackerMessageHandler implements MessageHandler.Whole<String> { + + private final Session session; + private final String trackingName; + + public TrackerMessageHandler(Session session, String trackingName) { + this.session = session; + this.trackingName = trackingName; + } + + @Override + public void onMessage(String message) { + Tracker.addRecord(trackingName, session.getOpenSessions().size()); + } + } + + + public abstract static class ClientEndpointBase extends Endpoint { + + @Override + public void onOpen(Session session, EndpointConfig config) { + // NO-OP + } + + @Override + public void onClose(Session session, CloseReason closeReason) { + // NO-OP + } + + protected abstract String getTrackingName(); + } + + + public static final class ClientEndpointA extends ClientEndpointBase { + + @Override + protected String getTrackingName() { + return "clientA"; + } + } + + + public static final class ClientEndpointB extends ClientEndpointBase { + + @Override + protected String getTrackingName() { + return "clientB"; + } + } + + + public 
static class Tracker { + + private static final Map<String, Integer> records = new ConcurrentHashMap<>(); + private static final AtomicInteger updateCount = new AtomicInteger(0); + + public static void addRecord(String key, int count) { + records.put(key, Integer.valueOf(count)); + updateCount.incrementAndGet(); + } + + public static boolean checkRecord(String key, int expectedCount) { + Integer actualCount = records.get(key); + if (actualCount == null) { + if (expectedCount == 0) { + return true; + } else { + return false; + } + } else { + return actualCount.intValue() == expectedCount; + } + } + + public static int getUpdateCount() { + return updateCount.intValue(); + } + + public static void reset() { + records.clear(); + updateCount.set(0); + } + } +} Propchange: tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainerGetOpenSessions.java ------------------------------------------------------------------------------ svn:eol-style = native Modified: tomcat/trunk/webapps/docs/changelog.xml URL: http://svn.apache.org/viewvc/tomcat/trunk/webapps/docs/changelog.xml?rev=1852079&r1=1852078&r2=1852079&view=diff ============================================================================== --- tomcat/trunk/webapps/docs/changelog.xml (original) +++ tomcat/trunk/webapps/docs/changelog.xml Thu Jan 24 22:54:41 2019 @@ -167,6 +167,11 @@ <subsection name="WebSocket"> <changelog> <fix> + <bug>57974</bug>: Ensure implementation of + <code>Session.getOpenSessions()</code> returns correct value for both + client-side and server-side calls. (markt) + </fix> + <fix> <bug>63019</bug>: Use payload remaining bytes rather than limit when writing. Submitted by Benoit Courtilly. (remm) </fix> --------------------------------------------------------------------- To unsubscribe, e-mail: dev-unsubscr...@tomcat.apache.org For additional commands, e-mail: dev-h...@tomcat.apache.org