[SSHD-792] Use a more fine-grained decision on whether to close a tunnel session gracefully or not


Project: http://git-wip-us.apache.org/repos/asf/mina-sshd/repo
Commit: http://git-wip-us.apache.org/repos/asf/mina-sshd/commit/169ff4e4
Tree: http://git-wip-us.apache.org/repos/asf/mina-sshd/tree/169ff4e4
Diff: http://git-wip-us.apache.org/repos/asf/mina-sshd/diff/169ff4e4

Branch: refs/heads/master
Commit: 169ff4e43dde1b939ed193fc89e5b615457e0fc2
Parents: deb2445
Author: Goldstein Lyor <l...@c-b4.com>
Authored: Wed Feb 28 15:18:36 2018 +0200
Committer: Lyor Goldstein <lyor.goldst...@gmail.com>
Committed: Wed Feb 28 19:47:04 2018 +0200

----------------------------------------------------------------------
 .../common/forward/DefaultForwardingFilter.java | 49 +++++++++++---------
 1 file changed, 26 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/169ff4e4/sshd-core/src/main/java/org/apache/sshd/common/forward/DefaultForwardingFilter.java
----------------------------------------------------------------------
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/forward/DefaultForwardingFilter.java b/sshd-core/src/main/java/org/apache/sshd/common/forward/DefaultForwardingFilter.java
index a6f46fa..21de6dc 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/forward/DefaultForwardingFilter.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/forward/DefaultForwardingFilter.java
@@ -34,6 +34,7 @@ import java.util.Set;
 import java.util.TreeMap;
 import java.util.concurrent.CopyOnWriteArraySet;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
 
 import org.apache.sshd.client.channel.ClientChannelEvent;
 import org.apache.sshd.common.Closeable;
@@ -915,18 +916,17 @@ public class DefaultForwardingFilter
         return getClass().getSimpleName() + "[" + getSession() + "]";
     }
 
-    //
-    // Static IoHandler implementation
-    //
-
+    @SuppressWarnings("synthetic-access")
     class StaticIoHandler implements IoHandler {
+        private final AtomicLong messagesCounter = new AtomicLong(0L);
+        private final boolean traceEnabled = log.isTraceEnabled();
+
         StaticIoHandler() {
             super();
         }
 
         @Override
-        @SuppressWarnings("synthetic-access")
-        public void sessionCreated(final IoSession session) throws Exception {
+        public void sessionCreated(IoSession session) throws Exception {
             InetSocketAddress local = (InetSocketAddress) 
session.getLocalAddress();
             int localPort = local.getPort();
             SshdSocketAddress remote = localToRemote.get(localPort);
@@ -934,12 +934,8 @@ public class DefaultForwardingFilter
                 log.debug("sessionCreated({}) remote={}", session, remote);
             }
 
-            TcpipClientChannel channel;
-            if (remote != null) {
-                channel = new 
TcpipClientChannel(TcpipClientChannel.Type.Direct, session, remote);
-            } else {
-                channel = new 
TcpipClientChannel(TcpipClientChannel.Type.Forwarded, session, null);
-            }
+            TcpipClientChannel.Type channelType = (remote == null) ? 
TcpipClientChannel.Type.Forwarded : TcpipClientChannel.Type.Direct;
+            TcpipClientChannel channel = new TcpipClientChannel(channelType, 
session, remote);
             session.setAttribute(TcpipClientChannel.class, channel);
 
             service.registerChannel(channel);
@@ -958,28 +954,35 @@ public class DefaultForwardingFilter
         }
 
         @Override
-        @SuppressWarnings("synthetic-access")
         public void sessionClosed(IoSession session) throws Exception {
-            TcpipClientChannel channel = (TcpipClientChannel) 
session.getAttribute(TcpipClientChannel.class);
+            TcpipClientChannel channel = (TcpipClientChannel) 
session.removeAttribute(TcpipClientChannel.class);
             if (channel != null) {
+                Throwable cause = (Throwable) 
session.getAttribute(Throwable.class);
                 if (log.isDebugEnabled()) {
-                    log.debug("sessionClosed({}) closing channel={}", session, 
channel);
+                    log.debug("sessionClosed({}) closing channel={} after {} 
messages - cause={}",
+                            session, channel, messagesCounter, (cause == null) 
? null : cause.getClass().getSimpleName());
                 }
-                channel.close(false);
+                // If exception signaled then close channel immediately
+                channel.close(cause != null);
             }
         }
 
         @Override
-        @SuppressWarnings("synthetic-access")
         public void messageReceived(IoSession session, Readable message) 
throws Exception {
             TcpipClientChannel channel = (TcpipClientChannel) 
session.getAttribute(TcpipClientChannel.class);
+            long totalMessages = messagesCounter.incrementAndGet();
             Buffer buffer = new ByteArrayBuffer(message.available() + 
Long.SIZE, false);
             buffer.putBuffer(message);
 
+            if (traceEnabled) {
+                log.trace("messageReceived({}) channel={}, count={}, handle 
len={}",
+                          session, channel, totalMessages, 
message.available());
+            }
+
             Collection<ClientChannelEvent> result = 
channel.waitFor(STATIC_IO_MSG_RECEIVED_EVENTS, Long.MAX_VALUE);
-            if (log.isTraceEnabled()) {
-                log.trace("messageReceived({}) channel={}, len={} wait result: 
{}",
-                          session, channel, result, buffer.array());
+            if (traceEnabled) {
+                log.trace("messageReceived({}) channel={}, count={}, len={} 
wait result: {}",
+                          session, channel, totalMessages, 
message.available(), result);
             }
 
             OutputStream outputStream = channel.getInvertedIn();
@@ -988,15 +991,15 @@ public class DefaultForwardingFilter
         }
 
         @Override
-        @SuppressWarnings("synthetic-access")
         public void exceptionCaught(IoSession session, Throwable cause) throws 
Exception {
+            session.setAttribute(Throwable.class, cause);
             if (log.isDebugEnabled()) {
                 log.debug("exceptionCaught({}) {}: {}", session, 
cause.getClass().getSimpleName(), cause.getMessage());
             }
-            if (log.isTraceEnabled()) {
+            if (traceEnabled) {
                 log.trace("exceptionCaught(" + session + ") caught exception 
details", cause);
             }
-            session.close(false);
+            session.close(true);
         }
     }
 }

Reply via email to