Author: markt Date: Wed Feb 6 19:06:56 2013 New Revision: 1443135 URL: http://svn.apache.org/viewvc?rev=1443135&view=rev Log: Refactor the RemoteEndpoint implementation. - Add support for masking client data - Add support for batching (a.k.a. buffering) messages - Provide building blocks for Stream, Writer, etc. support
Modified: tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointBase.java tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointClient.java tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java tomcat/trunk/java/org/apache/tomcat/websocket/server/WsRemoteEndpointServer.java tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainer.java Modified: tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointBase.java URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointBase.java?rev=1443135&r1=1443134&r2=1443135&view=diff ============================================================================== --- tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointBase.java (original) +++ tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointBase.java Wed Feb 6 19:06:56 2013 @@ -21,7 +21,6 @@ import java.io.OutputStream; import java.io.Writer; import java.nio.ByteBuffer; import java.nio.CharBuffer; -import java.nio.channels.CompletionHandler; import java.nio.charset.Charset; import java.nio.charset.CharsetEncoder; import java.nio.charset.CoderResult; @@ -31,6 +30,8 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; import javax.websocket.EncodeException; import javax.websocket.RemoteEndpoint; @@ -44,17 +45,18 @@ public abstract class WsRemoteEndpointBa private static final StringManager sm = StringManager.getManager(Constants.PACKAGE_NAME); - // TODO Make the size of these buffers configurable - private final ByteBuffer intermediateBuffer = ByteBuffer.allocate(8192); - protected final ByteBuffer outputBuffer = ByteBuffer.allocate(8192); - private final AtomicBoolean charToByteInProgress = new AtomicBoolean(false); - private final CharsetEncoder encoder = 
Charset.forName("UTF8").newEncoder(); + private final ReentrantLock writeLock = new ReentrantLock(); + private final Condition notInProgress = writeLock.newCondition(); + // Must hold writeLock above to modify state private final MessageSendStateMachine state = new MessageSendStateMachine(); - + // Max size of WebSocket header is 14 bytes + private final ByteBuffer headerBuffer = ByteBuffer.allocate(14); + private final ByteBuffer outputBuffer = ByteBuffer.allocate(8192); + private final CharsetEncoder encoder = Charset.forName("UTF8").newEncoder(); + private final ByteBuffer encoderBuffer = ByteBuffer.allocate(8192); + private AtomicBoolean batchingAllowed = new AtomicBoolean(false); private volatile long asyncSendTimeout = -1; - protected ByteBuffer payload = null; - @Override public long getAsyncSendTimeout() { @@ -70,66 +72,79 @@ public abstract class WsRemoteEndpointBa @Override public void setBatchingAllowed(boolean batchingAllowed) { - // TODO Auto-generated method stub + boolean oldValue = this.batchingAllowed.getAndSet(batchingAllowed); + if (oldValue && !batchingAllowed) { + // Just disabled batched. Must flush. + flushBatch(); + } } @Override public boolean getBatchingAllowed() { - // TODO Auto-generated method stub - return false; + return batchingAllowed.get(); } @Override public void flushBatch() { - // TODO Auto-generated method stub - + // Have to hold lock to flush output buffer + writeLock.lock(); + try { + while (state.isInProgress()) { + notInProgress.await(); + } + FutureToSendHandler f2sh = new FutureToSendHandler(); + doWrite(f2sh, outputBuffer); + f2sh.get(); + } catch (InterruptedException | ExecutionException e) { + // TODO Log this? Runtime exception? Something else? 
+ } finally { + writeLock.unlock(); + } } @Override - public final void sendString(String text) throws IOException { - sendPartialString(text, true); + public void sendBytes(ByteBuffer data) throws IOException { + Future<SendResult> f = sendBytesByFuture(data); + try { + SendResult sr = f.get(); + if (!sr.isOK()) { + if (sr.getException() == null) { + throw new IOException(); + } else { + throw new IOException(sr.getException()); + } + } + } catch (InterruptedException | ExecutionException e) { + throw new IOException(e); + } } @Override - public final void sendBytes(ByteBuffer data) throws IOException { - sendPartialBytes(data, true); + public Future<SendResult> sendBytesByFuture(ByteBuffer data) { + FutureToSendHandler f2sh = new FutureToSendHandler(); + sendBytesByCompletion(data, f2sh); + return f2sh; } @Override - public void sendPartialString(String fragment, boolean isLast) - throws IOException { - - // The toBytes buffer needs to be protected from multiple threads and - // the state check happens to late. 
- if (!charToByteInProgress.compareAndSet(false, true)) { - throw new IllegalStateException(sm.getString( - "wsRemoteEndpoint.concurrentMessageSend")); + public void sendBytesByCompletion(ByteBuffer data, SendHandler completion) { + boolean locked = writeLock.tryLock(); + if (!locked) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.concurrentMessageSend")); } - try { - encoder.reset(); - intermediateBuffer.clear(); - CharBuffer cb = CharBuffer.wrap(fragment); - CoderResult cr = encoder.encode(cb, intermediateBuffer, true); - intermediateBuffer.flip(); - while (cr.isOverflow()) { - sendMessageBlocking( - Constants.OPCODE_TEXT, intermediateBuffer, false); - intermediateBuffer.clear(); - cr = encoder.encode(cb, intermediateBuffer, true); - intermediateBuffer.flip(); - } - sendMessageBlocking( - Constants.OPCODE_TEXT, intermediateBuffer, isLast); + byte opCode = Constants.OPCODE_BINARY; + boolean isLast = true; + sendMessage(opCode, data, isLast, completion); } finally { - // Make sure flag is reset before method exists - charToByteInProgress.set(false); + writeLock.unlock(); } } @@ -137,130 +152,181 @@ public abstract class WsRemoteEndpointBa @Override public void sendPartialBytes(ByteBuffer partialByte, boolean isLast) throws IOException { - sendMessageBlocking(Constants.OPCODE_BINARY, partialByte, isLast); + boolean locked = writeLock.tryLock(); + if (!locked) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.concurrentMessageSend")); + } + try { + byte opCode = Constants.OPCODE_BINARY; + FutureToSendHandler f2sh = new FutureToSendHandler(); + sendMessage(opCode, partialByte, isLast, f2sh); + f2sh.get(); + } catch (InterruptedException | ExecutionException e) { + throw new IOException(e); + } finally { + writeLock.unlock(); + } } @Override - public void sendPing(ByteBuffer applicationData) throws IOException { - sendMessageBlocking(Constants.OPCODE_PING, applicationData, true); + public void sendPing(ByteBuffer 
applicationData) throws IOException, + IllegalArgumentException { + sendControlMessage(Constants.OPCODE_PING, applicationData); } @Override - public void sendPong(ByteBuffer applicationData) throws IOException { - sendMessageBlocking(Constants.OPCODE_PONG, applicationData, true); + public void sendPong(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + sendControlMessage(Constants.OPCODE_PONG, applicationData); } @Override - public Future<SendResult> sendBytesByFuture(ByteBuffer data) { - this.payload = data; - return sendMessageByFuture(Constants.OPCODE_BINARY, true); + public void sendString(String text) throws IOException { + Future<SendResult> f = sendStringByFuture(text); + try { + SendResult sr = f.get(); + if (!sr.isOK()) { + if (sr.getException() == null) { + throw new IOException(); + } else { + throw new IOException(sr.getException()); + } + } + } catch (InterruptedException | ExecutionException e) { + throw new IOException(e); + } } @Override - public void sendBytesByCompletion(ByteBuffer data, SendHandler completion) { - this.payload = data; - sendMessageByCompletion(Constants.OPCODE_BINARY, true, - new WsCompletionHandler(this, completion, state, false)); + public Future<SendResult> sendStringByFuture(String text) { + FutureToSendHandler f2sh = new FutureToSendHandler(); + sendStringByCompletion(text, f2sh); + return f2sh; } + @Override + public void sendStringByCompletion(String text, SendHandler completion) { + boolean locked = writeLock.tryLock(); + if (!locked) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.concurrentMessageSend")); + } + try { + TextMessageSendHandler tmsh = new TextMessageSendHandler( + completion, text, true, encoder, encoderBuffer, this); + tmsh.write(); + } finally { + writeLock.unlock(); + } + } - - - - protected void sendMessageBlocking(byte opCode, ByteBuffer payload, - boolean isLast) throws IOException { - - this.payload = payload; - - Future<SendResult> f = 
sendMessageByFuture(opCode, isLast); - SendResult sr = null; + @Override + public void sendPartialString(String fragment, boolean isLast) + throws IOException { + boolean locked = writeLock.tryLock(); + if (!locked) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.concurrentMessageSend")); + } try { - sr = f.get(); + FutureToSendHandler f2sh = new FutureToSendHandler(); + TextMessageSendHandler tmsh = new TextMessageSendHandler( + f2sh, fragment, isLast, encoder, encoderBuffer, this); + tmsh.write(); + f2sh.get(); } catch (InterruptedException | ExecutionException e) { throw new IOException(e); - } - - if (!sr.isOK()) { - throw new IOException(sr.getException()); + } finally { + writeLock.unlock(); } } - private Future<SendResult> sendMessageByFuture(byte opCode, - boolean isLast) { - WsCompletionHandler wsCompletionHandler = new WsCompletionHandler( - this, state, opCode == Constants.OPCODE_CLOSE); - sendMessageByCompletion(opCode, isLast, wsCompletionHandler); - return wsCompletionHandler; - } + /** + * Sends a control message, blocking until the message is sent. + */ + void sendControlMessage(byte opCode, ByteBuffer payload) + throws IOException{ + // Close needs to be sent so disable batching. 
This will flush any + // messages in the buffer + if (opCode == Constants.OPCODE_CLOSE) { + setBatchingAllowed(false); + } - private void sendMessageByCompletion(byte opCode, boolean isLast, - WsCompletionHandler handler) { + writeLock.lock(); + try { + if (state.isInProgress()) { + notInProgress.await(); + } + FutureToSendHandler f2sh = new FutureToSendHandler(); + sendMessage(opCode, payload, true, f2sh); + f2sh.get(); + } catch (InterruptedException | ExecutionException e) { + throw new IOException(e); + } finally { + notInProgress.signal(); + writeLock.unlock(); + } + } - boolean isFirst = state.startMessage(opCode, isLast); - outputBuffer.clear(); - byte first = 0; + private void sendMessage(byte opCode, ByteBuffer payload, boolean last, + SendHandler completion) { - if (isLast) { - // Set the fin bit - first = -128; + if (!writeLock.isHeldByCurrentThread()) { + // Coding problem + throw new IllegalStateException( + "Must hold writeLock before calling this method"); } - if (isFirst) { - // This is the first fragment of this message - first = (byte) (first + opCode); - } - // If not the first fragment, it is a continuation with opCode of zero + state.startMessage(opCode, last); - outputBuffer.put(first); + SendMessageSendHandler smsh = + new SendMessageSendHandler(state, completion, this); - byte masked = getMasked(); + byte[] mask; - // Next write the mask && length length - if (payload.limit() < 126) { - outputBuffer.put((byte) (payload.limit() | masked)); - } else if (payload.limit() < 65536) { - outputBuffer.put((byte) (126 | masked)); - outputBuffer.put((byte) (payload.limit() >>> 8)); - outputBuffer.put((byte) (payload.limit() & 0xFF)); + if (isMasked()) { + mask = Util.generateMask(); } else { - // Will never be more than 2^31-1 - outputBuffer.put((byte) (127 | masked)); - outputBuffer.put((byte) 0); - outputBuffer.put((byte) 0); - outputBuffer.put((byte) 0); - outputBuffer.put((byte) 0); - outputBuffer.put((byte) (payload.limit() >>> 24)); - 
outputBuffer.put((byte) (payload.limit() >>> 16)); - outputBuffer.put((byte) (payload.limit() >>> 8)); - outputBuffer.put((byte) (payload.limit() & 0xFF)); - } - if (masked != 0) { - // TODO Mask the data properly - outputBuffer.put((byte) 0); - outputBuffer.put((byte) 0); - outputBuffer.put((byte) 0); - outputBuffer.put((byte) 0); + mask = null; } - outputBuffer.flip(); - sendMessage(handler); + headerBuffer.clear(); + writeHeader(headerBuffer, opCode, payload, state.isFirst(), last, + isMasked(), mask); + headerBuffer.flip(); + + if (getBatchingAllowed() || isMasked()) { + // Need to write via output buffer + OutputBufferSendHandler obsh = new OutputBufferSendHandler( + smsh, headerBuffer, payload, mask, outputBuffer, + !getBatchingAllowed(), this); + obsh.write(); + } else { + // Can write directly + doWrite(smsh, headerBuffer, payload); + } } - protected abstract byte getMasked(); - - protected abstract void sendMessage(WsCompletionHandler handler); - protected abstract void close(); + private void endMessage() { + writeLock.lock(); + try { + notInProgress.signal(); + } finally { + writeLock.unlock(); + } + } @@ -276,146 +342,88 @@ public abstract class WsRemoteEndpointBa return null; } - @Override public Writer getSendWriter() throws IOException { // TODO Auto-generated method stub return null; } - - @Override - public Future<SendResult> sendStringByFuture(String text) { - // TODO Auto-generated method stub - return null; - } - - @Override public void sendObject(Object o) throws IOException, EncodeException { // TODO Auto-generated method stub - } - - @Override - public void sendStringByCompletion(String text, SendHandler completion) { - // TODO Auto-generated method stub } - @Override public Future<SendResult> sendObjectByFuture(Object obj) { // TODO Auto-generated method stub return null; } - @Override public void sendObjectByCompletion(Object obj, SendHandler completion) { // TODO Auto-generated method stub - } - - - - - - - - - protected static class 
WsCompletionHandler implements Future<SendResult>, - CompletionHandler<Long,Void> { - - private final WsRemoteEndpointBase wsRemoteEndpoint; - private final MessageSendStateMachine state; - private final SendHandler sendHandler; - private final boolean close; - private final CountDownLatch latch = new CountDownLatch(1); - private volatile SendResult result = null; - - - public WsCompletionHandler(WsRemoteEndpointBase wsRemoteEndpoint, - MessageSendStateMachine state, boolean close) { - this(wsRemoteEndpoint, null, state, close); - } - - - public WsCompletionHandler(WsRemoteEndpointBase wsRemoteEndpoint, - SendHandler sendHandler, MessageSendStateMachine state, - boolean close) { - this.wsRemoteEndpoint = wsRemoteEndpoint; - this.sendHandler = sendHandler; - this.state = state; - this.close = close; - } + } - // ------------------------------------------- CompletionHandler methods - @Override - public void completed(Long result, Void attachment) { - state.endMessage(); - if (close) { - wsRemoteEndpoint.close(); - } - this.result = new SendResult(); - latch.countDown(); - if (sendHandler != null) { - sendHandler.setResult(this.result); - } - } - @Override - public void failed(Throwable exc, Void attachment) { - state.endMessage(); - if (close) { - wsRemoteEndpoint.close(); - } - this.result = new SendResult(exc); - latch.countDown(); - if (sendHandler != null) { - sendHandler.setResult(this.result); - } - } + protected abstract void doWrite(SendHandler handler, ByteBuffer... 
data); + protected abstract boolean isMasked(); + protected abstract void close(); - // ------------------------------------------------------ Future methods + private static void writeHeader(ByteBuffer headerBuffer, byte opCode, + ByteBuffer payload, boolean first, boolean last, boolean masked, + byte[] mask) { - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - // Cancelling the task is not supported - return false; - } + byte b = 0; - - @Override - public boolean isCancelled() { - // Cancelling the task is not supported - return false; + if (last) { + // Set the fin bit + b = -128; } - - @Override - public boolean isDone() { - return latch.getCount() == 0; + if (first) { + // This is the first fragment of this message + b = (byte) (b + opCode); } + // If not the first fragment, it is a continuation with opCode of zero + headerBuffer.put(b); - @Override - public SendResult get() throws InterruptedException, ExecutionException { - latch.await(); - return result; + if (masked) { + b = (byte) 0x80; + } else { + b = 0; } - - @Override - public SendResult get(long timeout, TimeUnit unit) - throws InterruptedException, ExecutionException, - TimeoutException { - - latch.await(timeout, unit); - return result; + // Next write the mask && length length + if (payload.limit() < 126) { + headerBuffer.put((byte) (payload.limit() | b)); + } else if (payload.limit() < 65536) { + headerBuffer.put((byte) (126 | b)); + headerBuffer.put((byte) (payload.limit() >>> 8)); + headerBuffer.put((byte) (payload.limit() & 0xFF)); + } else { + // Will never be more than 2^31-1 + headerBuffer.put((byte) (127 | b)); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) (payload.limit() >>> 24)); + headerBuffer.put((byte) (payload.limit() >>> 16)); + headerBuffer.put((byte) (payload.limit() >>> 8)); + headerBuffer.put((byte) (payload.limit() & 0xFF)); + } + if (masked) { + 
headerBuffer.put(mask[0]); + headerBuffer.put(mask[1]); + headerBuffer.put(mask[2]); + headerBuffer.put(mask[3]); } } @@ -425,11 +433,12 @@ public abstract class WsRemoteEndpointBa private boolean inProgress = false; private boolean fragmented = false; private boolean text = false; + private boolean first = false; private boolean nextFragmented = false; private boolean nextText = false; - public synchronized boolean startMessage(byte opCode, boolean isLast) { + public synchronized void startMessage(byte opCode, boolean isLast) { if (closed) { throw new IllegalStateException( @@ -451,7 +460,8 @@ public abstract class WsRemoteEndpointBa if (opCode == Constants.OPCODE_CLOSE) { closed = true; } - return true; + first = true; + return; } boolean isText = Util.isText(opCode); @@ -464,7 +474,7 @@ public abstract class WsRemoteEndpointBa } nextText = text; nextFragmented = !isLast; - return false; + first = false; } else { // Wasn't fragmented. Might be now if (isLast) { @@ -473,7 +483,7 @@ public abstract class WsRemoteEndpointBa nextFragmented = true; nextText = isText; } - return true; + first = true; } } @@ -482,5 +492,219 @@ public abstract class WsRemoteEndpointBa fragmented = nextFragmented; text = nextText; } + + public synchronized boolean isInProgress() { + return inProgress; + } + + public synchronized boolean isFirst() { + return first; + } + } + + + private static class TextMessageSendHandler implements SendHandler { + + private final SendHandler handler; + private final CharBuffer message; + private final boolean isLast; + private final CharsetEncoder encoder; + private final ByteBuffer buffer; + private final WsRemoteEndpointBase endpoint; + private volatile boolean isDone = false; + + public TextMessageSendHandler(SendHandler handler, String message, + boolean isLast, CharsetEncoder encoder, + ByteBuffer encoderBuffer, WsRemoteEndpointBase endpoint) { + this.handler = handler; + this.message = CharBuffer.wrap(message); + this.isLast = isLast; + this.encoder 
= encoder.reset(); + this.buffer = encoderBuffer; + this.endpoint = endpoint; + } + + public void write() { + buffer.clear(); + CoderResult cr = encoder.encode(message, buffer, true); + if (cr.isError()) { + throw new IllegalArgumentException(cr.toString()); + } + isDone = !cr.isOverflow(); + buffer.flip(); + endpoint.sendMessage(Constants.OPCODE_TEXT, buffer, + isDone && isLast, this); + } + + @Override + public void setResult(SendResult result) { + if (isDone || !result.isOK()) { + handler.setResult(result); + } else { + write(); + } + } + } + + + /** + * Wraps user provided {@link SendHandler} so that state is updated when + * the message completes. + */ + private static class SendMessageSendHandler implements SendHandler { + + private final MessageSendStateMachine state; + private final SendHandler handler; + private final WsRemoteEndpointBase endpoint; + + public SendMessageSendHandler(MessageSendStateMachine state, + SendHandler handler, WsRemoteEndpointBase endpoint) { + this.state = state; + this.handler = handler; + this.endpoint = endpoint; + } + + @Override + public void setResult(SendResult result) { + state.endMessage(); + if (state.closed) { + endpoint.close(); + } + handler.setResult(result); + endpoint.endMessage(); + } + } + + + /** + * Used to write data to the output buffer, flushing the buffer if it fills + * up. 
+ */ + private static class OutputBufferSendHandler implements SendHandler { + + private final SendHandler handler; + private final ByteBuffer headerBuffer; + private final ByteBuffer payload; + private final byte[] mask; + private final ByteBuffer outputBuffer; + private volatile boolean flushRequired; + private final WsRemoteEndpointBase endpoint; + private int maskIndex = 0; + + public OutputBufferSendHandler(SendHandler completion, + ByteBuffer headerBuffer, ByteBuffer payload, byte[] mask, + ByteBuffer outputBuffer, boolean flushRequired, + WsRemoteEndpointBase endpoint) { + this.handler = completion; + this.headerBuffer = headerBuffer; + this.payload = payload; + this.mask = mask; + this.outputBuffer = outputBuffer; + this.flushRequired = flushRequired; + this.endpoint = endpoint; + } + + public void write() { + // Write the header + while (headerBuffer.hasRemaining() && outputBuffer.hasRemaining()) { + outputBuffer.put(headerBuffer.get()); + } + if (headerBuffer.hasRemaining()) { + // Still more headers to write, need to flush + flushRequired = true; + outputBuffer.flip(); + endpoint.doWrite(this, outputBuffer); + return; + } + + // Write the payload + while (payload.hasRemaining() && outputBuffer.hasRemaining()) { + outputBuffer.put( + (byte) (payload.get() ^ (mask[maskIndex++] & 0xFF))); + if (maskIndex > 3) { + maskIndex = 0; + } + } + if (payload.hasRemaining()) { + // Still more headers to write, need to flush + flushRequired = true; + outputBuffer.flip(); + endpoint.doWrite(this, outputBuffer); + return; + } + + if (flushRequired) { + outputBuffer.flip(); + endpoint.doWrite(this, outputBuffer); + flushRequired = false; + return; + } else { + handler.setResult(new SendResult()); + } + } + + // ------------------------------------------------- SendHandler methods + @Override + public void setResult(SendResult result) { + outputBuffer.clear(); + if (result.isOK()) { + write(); + } else { + handler.setResult(result); + } + } + } + + /** + * Converts a 
Future to a SendHandler. + */ + private static class FutureToSendHandler + implements Future<SendResult>, SendHandler { + + private final CountDownLatch latch = new CountDownLatch(1); + private volatile SendResult result = null; + + // --------------------------------------------------------- SendHandler + + @Override + public void setResult(SendResult result) { + this.result = result; + latch.countDown(); + } + + + // -------------------------------------------------------------- Future + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + // Cancelling the task is not supported + return false; + } + + @Override + public boolean isCancelled() { + // Cancelling the task is not supported + return false; + } + + @Override + public boolean isDone() { + return latch.getCount() == 0; + } + + @Override + public SendResult get() throws InterruptedException, + ExecutionException { + latch.await(); + return result; + } + + @Override + public SendResult get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + latch.await(timeout, unit); + return result; + } } } Modified: tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointClient.java URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointClient.java?rev=1443135&r1=1443134&r2=1443135&view=diff ============================================================================== --- tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointClient.java (original) +++ tomcat/trunk/java/org/apache/tomcat/websocket/WsRemoteEndpointClient.java Wed Feb 6 19:06:56 2013 @@ -19,8 +19,12 @@ package org.apache.tomcat.websocket; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.AsynchronousSocketChannel; +import java.nio.channels.CompletionHandler; import java.util.concurrent.TimeUnit; +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + public class 
WsRemoteEndpointClient extends WsRemoteEndpointBase { private final AsynchronousSocketChannel channel; @@ -31,20 +35,22 @@ public class WsRemoteEndpointClient exte @Override - protected byte getMasked() { - return (byte) 0x80; + protected boolean isMasked() { + return true; } @Override - protected void sendMessage(WsCompletionHandler handler) { + protected void doWrite(SendHandler handler, ByteBuffer... data) { long timeout = getAsyncSendTimeout(); if (timeout < 1) { timeout = Long.MAX_VALUE; } - channel.write(new ByteBuffer[] {outputBuffer, payload}, 0, 2, - getAsyncSendTimeout(), TimeUnit.MILLISECONDS, null, handler); + SendHandlerToCompletionHandler sh2ch = + new SendHandlerToCompletionHandler(handler); + channel.write(data, 0, data.length, getAsyncSendTimeout(), + TimeUnit.MILLISECONDS, null, sh2ch); } @Override @@ -55,4 +61,25 @@ public class WsRemoteEndpointClient exte // Ignore } } + + + private static class SendHandlerToCompletionHandler + implements CompletionHandler<Long,Void> { + + private SendHandler handler; + + public SendHandlerToCompletionHandler(SendHandler handler) { + this.handler = handler; + } + + @Override + public void completed(Long result, Void attachment) { + handler.setResult(new SendResult()); + } + + @Override + public void failed(Throwable exc, Void attachment) { + handler.setResult(new SendResult(exc)); + } + } } Modified: tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java?rev=1443135&r1=1443134&r2=1443135&view=diff ============================================================================== --- tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java (original) +++ tomcat/trunk/java/org/apache/tomcat/websocket/WsSession.java Wed Feb 6 19:06:56 2013 @@ -254,8 +254,8 @@ public class WsSession implements Sessio } msg.flip(); try { - wsRemoteEndpoint.sendMessageBlocking( - Constants.OPCODE_CLOSE, msg, true); + 
wsRemoteEndpoint.sendControlMessage( + Constants.OPCODE_CLOSE, msg); } catch (IOException ioe) { // Unable to send close message. // TODO - Ignore? Modified: tomcat/trunk/java/org/apache/tomcat/websocket/server/WsRemoteEndpointServer.java URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/server/WsRemoteEndpointServer.java?rev=1443135&r1=1443134&r2=1443135&view=diff ============================================================================== --- tomcat/trunk/java/org/apache/tomcat/websocket/server/WsRemoteEndpointServer.java (original) +++ tomcat/trunk/java/org/apache/tomcat/websocket/server/WsRemoteEndpointServer.java Wed Feb 6 19:06:56 2013 @@ -18,8 +18,11 @@ package org.apache.tomcat.websocket.serv import java.io.IOException; import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; import javax.servlet.ServletOutputStream; +import javax.websocket.SendHandler; +import javax.websocket.SendResult; import org.apache.juli.logging.Log; import org.apache.juli.logging.LogFactory; @@ -40,12 +43,11 @@ public class WsRemoteEndpointServer exte private final ServletOutputStream sos; private final WsTimeout wsTimeout; - private volatile WsCompletionHandler handler = null; + private volatile SendHandler handler = null; + private volatile ByteBuffer[] buffers = null; + private volatile long timeoutExpiry = -1; private volatile boolean close; - private volatile Long size = null; - private volatile boolean headerWritten = false; - private volatile boolean payloadWritten = false; public WsRemoteEndpointServer(ServletOutputStream sos, @@ -56,50 +58,59 @@ public class WsRemoteEndpointServer exte @Override - protected byte getMasked() { - // Messages from the server are not masked - return 0; + protected final boolean isMasked() { + return false; } @Override - protected void sendMessage(WsCompletionHandler handler) { + protected void doWrite(SendHandler handler, ByteBuffer... 
buffers) { this.handler = handler; + this.buffers = buffers; onWritePossible(); } public void onWritePossible() { + boolean complete = true; try { // If this is false there will be a call back when it is true while (sos.canWrite()) { - if (!headerWritten) { - headerWritten = true; - size = Long.valueOf( - outputBuffer.remaining() + payload.remaining()); - sos.write(outputBuffer.array(), outputBuffer.arrayOffset(), - outputBuffer.limit()); - } else if (!payloadWritten) { - payloadWritten = true; - sos.write(payload.array(), payload.arrayOffset(), - payload.limit()); - } else { + complete = true; + for (ByteBuffer buffer : buffers) { + if (buffer.hasRemaining()) { + complete = false; + sos.write(buffer.array(), buffer.arrayOffset(), + buffer.limit()); + buffer.position(buffer.limit()); + break; + } + } + if (complete) { wsTimeout.unregister(this); if (close) { close(); } - handler.completed(size, null); - nextWrite(); + // Setting the result marks this (partial) message as + // complete which means the next one may be sent which + // could update the value of the handler. Therefore, keep a + // local copy before signalling the end of the (partial) + // message. 
+ SendHandler sh = handler; + handler = null; + sh.setResult(new SendResult()); break; } } + } catch (IOException ioe) { wsTimeout.unregister(this); close(); - handler.failed(ioe, null); - nextWrite(); + SendHandler sh = handler; + handler = null; + sh.setResult(new SendResult(ioe)); } - if (handler != null) { + if (!complete) { // Async write is in progress long timeout = getAsyncSendTimeout(); @@ -132,15 +143,7 @@ public class WsRemoteEndpointServer exte protected void onTimeout() { close(); - handler.failed(new SocketTimeoutException(), null); - nextWrite(); - } - - - private void nextWrite() { + handler.setResult(new SendResult(new SocketTimeoutException())); handler = null; - size = null; - headerWritten = false; - payloadWritten = false; } } Modified: tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainer.java URL: http://svn.apache.org/viewvc/tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainer.java?rev=1443135&r1=1443134&r2=1443135&view=diff ============================================================================== --- tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainer.java (original) +++ tomcat/trunk/test/org/apache/tomcat/websocket/TestWsWebSocketContainer.java Wed Feb 6 19:06:56 2013 @@ -160,7 +160,7 @@ public class TestWsWebSocketContainer ex @Test public void testSmallBinaryBufferClientTextMessage() throws Exception { - doBufferTest(false, false, true, false); + doBufferTest(false, false, true, true); } @@ -172,7 +172,7 @@ public class TestWsWebSocketContainer ex @Test public void testSmallBinaryBufferServerTextMessage() throws Exception { - doBufferTest(false, true, true, false); + doBufferTest(false, true, true, true); } @@ -382,7 +382,6 @@ public class TestWsWebSocketContainer ex // Check nothing really bad happened Assert.assertNull(ConstantTxEndpoint.getException()); - System.out.println(ConstantTxEndpoint.getTimeout()); // Check correct time passed 
Assert.assertTrue(ConstantTxEndpoint.getTimeout() >= TIMEOUT_MS); --------------------------------------------------------------------- To unsubscribe, e-mail: dev-unsubscr...@tomcat.apache.org For additional commands, e-mail: dev-h...@tomcat.apache.org