Author: markt
Date: Wed Dec 19 20:02:04 2012
New Revision: 1424066

URL: http://svn.apache.org/viewvc?rev=1424066&view=rev
Log:
WebSocket 1.0 implementation part 17 of many
Improve the handling of fragmented messages

Modified: tomcat/trunk/java/org/apache/tomcat/websocket/WsFrame.java Modified: tomcat/trunk/java/org/apache/tomcat/websocket/WsFrame.java URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/tomcat/websocket/WsFrame.java?rev=1424066&r1=1424065&r2=1424066&view=diff ============================================================================== --- tomcat/trunk/java/org/apache/tomcat/websocket/WsFrame.java (original) +++ tomcat/trunk/java/org/apache/tomcat/websocket/WsFrame.java Wed Dec 19 20:02:04 2012 @@ -19,7 +19,6 @@ package org.apache.tomcat.websocket; import java.io.EOFException; import java.io.IOException; import java.nio.ByteBuffer; -import java.nio.charset.Charset; import javax.servlet.ServletInputStream; import javax.websocket.MessageHandler; @@ -33,28 +32,40 @@ import org.apache.tomcat.util.res.String */ public class WsFrame { - private static StringManager sm = StringManager.getManager(Constants.PACKAGE_NAME); + private static StringManager sm = + StringManager.getManager(Constants.PACKAGE_NAME); + + // Connection level attributes private final ServletInputStream sis; private final WsSession wsSession; private final byte[] inputBuffer; - private int pos = 0; - private State state = State.NEW_FRAME; - private int headerLength = 0; - private boolean continutationExpected = false; + + // Attributes of the current message + private final ByteBuffer messageBuffer; + private boolean continuationExpected = false; private boolean textMessage = false; - private long payloadSent = 0; - private long payloadLength = 0; - private boolean fin; - private int rsv; - private byte opCode; + + // Attributes of the current frame + private boolean fin = false; + private int rsv = 0; + private byte opCode = 0; + private int frameStart = 0; + private int headerLength = 0; private byte[] mask = new byte[4]; - int maskIndex = 0; + private int maskIndex = 0; + private long payloadLength = 0; + private int payloadRead = 0; + private long payloadWritten = 0; + // Attributes 
tracking state + private State state = State.NEW_FRAME; + private int writePos = 0; public WsFrame(ServletInputStream sis, WsSession wsSession) { this.sis = sis; this.wsSession = wsSession; inputBuffer = new byte[8192]; + messageBuffer = ByteBuffer.allocate(8192); } @@ -64,14 +75,15 @@ public class WsFrame { public void onDataAvailable() throws IOException { while (sis.isReady()) { // Fill up the input buffer with as much data as we can - int read = sis.read(inputBuffer, pos, inputBuffer.length - pos); + int read = sis.read(inputBuffer, writePos, + inputBuffer.length - writePos); if (read == 0) { return; } if (read == -1) { throw new EOFException(); } - pos += read; + writePos += read; while (true) { if (state == State.NEW_FRAME) { if (!processInitialHeader()) { @@ -99,15 +111,15 @@ public class WsFrame { */ private boolean processInitialHeader() throws IOException { // Need at least two bytes of data to do this - if (pos < 2) { + if (writePos - frameStart < 2) { return false; } - int b = inputBuffer[0]; + int b = inputBuffer[frameStart]; fin = (b & 0x80) > 0; rsv = (b & 0x70) >>> 4; opCode = (byte) (b & 0x0F); if (!isControl()) { - if (continutationExpected) { + if (continuationExpected) { if (opCode != Constants.OPCODE_CONTINUATION) { // TODO i18n throw new IllegalStateException(); @@ -122,9 +134,9 @@ public class WsFrame { throw new UnsupportedOperationException(); } } - continutationExpected = !fin; + continuationExpected = !fin; } - b = inputBuffer[1]; + b = inputBuffer[frameStart + 1]; // Client data must be masked if ((b & 0x80) == 0) { throw new IOException(sm.getString("wsFrame.notMasked")); @@ -148,7 +160,7 @@ public class WsFrame { } else if (payloadLength == 127) { headerLength += 8; } - if (pos < headerLength) { + if (writePos - frameStart < headerLength) { return false; } // Calculate new payload length if necessary @@ -167,57 +179,69 @@ public class WsFrame { throw new IOException("wsFrame.controlNoFin"); } } - System.arraycopy(inputBuffer, 
headerLength - 4, mask, 0, 4); + System.arraycopy(inputBuffer, frameStart + headerLength - 4, mask, 0, 4); state = State.DATA; + payloadRead = frameStart + headerLength; return true; } private boolean processData() throws IOException { + checkRoomPayload(); + appendPayloadToMessage(); if (isControl()) { - if (!isPayloadComplete()) { + if (writePos < frameStart + headerLength + payloadLength) { return false; } if (opCode == Constants.OPCODE_CLOSE) { wsSession.close(); } else if (opCode == Constants.OPCODE_PING) { - wsSession.getRemote().sendPong(getPayloadBinary()); + messageBuffer.flip(); + wsSession.getRemote().sendPong(messageBuffer); } else if (opCode == Constants.OPCODE_PONG) { MessageHandler.Basic<PongMessage> mhPong = wsSession.getPongMessageHandler(); if (mhPong != null) { - mhPong.onMessage(new WsPongMessage(getPayloadBinary())); + messageBuffer.flip(); + mhPong.onMessage(new WsPongMessage(messageBuffer)); } } else { // TODO i18n throw new UnsupportedOperationException(); } + newMessage(); return true; } - if (!isPayloadComplete()) { - if (usePartial()) { - sendPayload(false); - return false; - } else { - if (inputBuffer.length - pos > 0) { - return false; + if (payloadWritten == payloadLength) { + if (continuationExpected) { + if (usePartial()) { + messageBuffer.flip(); + sendMessage(false); + messageBuffer.clear(); } - throw new UnsupportedOperationException(); + newFrame(); + return true; + } else { + messageBuffer.flip(); + sendMessage(true); + newMessage(); + return true; } } else { - sendPayload(true); + if (usePartial()) { + messageBuffer.flip(); + sendMessage(false); + messageBuffer.clear(); + } + return false; } - state = State.NEW_FRAME; - payloadLength = 0; - payloadSent = 0; - maskIndex = 0; - return true; } @SuppressWarnings("unchecked") - private void sendPayload(boolean last) { + private void sendMessage(boolean last) { if (textMessage) { - String payload = getPayloadText(); + String payload = + new String(messageBuffer.array(), 0, 
messageBuffer.limit()); MessageHandler mh = wsSession.getTextMessageHandler(); if (mh != null) { if (mh instanceof MessageHandler.Async<?>) { @@ -227,27 +251,83 @@ public class WsFrame { } } } else { - ByteBuffer payload = getPayloadBinary(); MessageHandler mh = wsSession.getBinaryMessageHandler(); if (mh != null) { if (mh instanceof MessageHandler.Async<?>) { - ((MessageHandler.Async<ByteBuffer>) mh).onMessage(payload, - last); + ((MessageHandler.Async<ByteBuffer>) mh).onMessage( + messageBuffer, last); } else { - ((MessageHandler.Basic<ByteBuffer>) mh).onMessage(payload); + ((MessageHandler.Basic<ByteBuffer>) mh).onMessage( + messageBuffer); } } } } - private boolean isControl() { - return (opCode & 0x08) > 0; + private void newMessage() { + messageBuffer.clear(); + continuationExpected = false; + newFrame(); + } + + + private void newFrame() { + if (frameStart + headerLength + payloadLength == writePos) { + frameStart = 0; + writePos = 0; + } else { + frameStart = frameStart + headerLength + (int) payloadLength; + } + + // These get reset in processInitialHeader() + // fin, rsv, opCode, headerLength, payloadLength, mask + maskIndex = 0; + payloadRead = 0; + payloadWritten = 0; + state = State.NEW_FRAME; + checkRoomHeaders(); + } + + + private void checkRoomHeaders() { + // Is the start of the current frame too near the end of the input + // buffer? 
+ if (inputBuffer.length - frameStart < 131) { + // Limit based on a control frame with a full payload + makeRoom(); + } + } + + + private void checkRoomPayload() throws IOException { + long frameSize = headerLength + payloadLength; + if (inputBuffer.length - frameStart - frameSize < 0) { + if (isControl()) { + makeRoom(); + return; + } + // Might not be enough room + if (usePartial()) { + // Not a problem - can use partial messages + return; + } + if (inputBuffer.length < frameSize) { + // Never going to work + // TODO i18n - buffer too small + throw new IOException(); + } + makeRoom(); + } } - private boolean isPayloadComplete() { - return (payloadSent + pos - headerLength) >= payloadLength; + private void makeRoom() { + System.arraycopy(inputBuffer, frameStart, inputBuffer, 0, + writePos - frameStart); + writePos = writePos - frameStart; + payloadRead = payloadRead - frameStart; + frameStart = 0; } @@ -271,32 +351,17 @@ public class WsFrame { } - private ByteBuffer getPayloadBinary() { - int end; - if (isPayloadComplete()) { - end = (int) (payloadLength - payloadSent) + headerLength; - } else { - end = pos; - } - ByteBuffer result = ByteBuffer.allocate(end - headerLength); - for (int i = headerLength; i < end; i++) { - result.put(i - headerLength, - (byte) ((inputBuffer[i] ^ mask[maskIndex]) & 0xFF)); + private void appendPayloadToMessage() { + while (payloadWritten < payloadLength && payloadRead < writePos) { + byte b = (byte) ((inputBuffer[payloadRead] ^ mask[maskIndex]) & 0xFF); maskIndex++; if (maskIndex == 4) { maskIndex = 0; } + payloadRead++; + payloadWritten++; + messageBuffer.put(b); } - // May have read past end of current frame into next - pos = 0; - headerLength = 0; - return result; - } - - - private String getPayloadText() { - ByteBuffer bb = getPayloadBinary(); - return new String(bb.array(), Charset.forName("UTF-8")); } @@ -315,6 +380,12 @@ public class WsFrame { return result; } + + private boolean isControl() { + return (opCode & 0x08) > 0; + } 
+ + private static enum State { NEW_FRAME, PARTIAL_HEADER, DATA } --------------------------------------------------------------------- To unsubscribe, e-mail: dev-unsubscr...@tomcat.apache.org For additional commands, e-mail: dev-h...@tomcat.apache.org