This is an automated email from the ASF dual-hosted git repository.

lgoldstein pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git

commit 0bbdc772e2b36a71af3c1bf4caf98cef1ae27819
Author: Lyor Goldstein <lgoldst...@apache.org>
AuthorDate: Wed Dec 9 21:53:16 2020 +0200

    [SSHD-1085] Added more notifications related to channel state change for 
detecting channel closing or closed earlier
---
 CHANGES.md                                         |   2 +
 .../org/apache/sshd/cli/client/SshClientMain.java  |   2 +-
 .../sshd/common/util/logging/LoggingUtils.java     |  40 +++---
 .../sshd/client/channel/AbstractClientChannel.java |   2 +-
 .../sshd/common/channel/AbstractChannel.java       | 144 +++++++++++++--------
 .../org/apache/sshd/common/channel/Channel.java    |  18 +++
 .../session/helpers/AbstractConnectionService.java |  24 ++--
 7 files changed, 142 insertions(+), 90 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index c51a3ca..fc90adf 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -17,3 +17,5 @@
 ## Minor code helpers
 
 ## Behavioral changes and enhancements
+
+* [SSHD-1085](https://issues.apache.org/jira/browse/SSHD-1085) Added more notifications related to channel state change for detecting earlier that a channel is closing or closed.
diff --git 
a/sshd-cli/src/main/java/org/apache/sshd/cli/client/SshClientMain.java 
b/sshd-cli/src/main/java/org/apache/sshd/cli/client/SshClientMain.java
index 7c1265e..c71f4fa 100644
--- a/sshd-cli/src/main/java/org/apache/sshd/cli/client/SshClientMain.java
+++ b/sshd-cli/src/main/java/org/apache/sshd/cli/client/SshClientMain.java
@@ -108,7 +108,7 @@ public class SshClientMain extends SshClientCliSupport {
                     error = true;
                     break;
                 }
-                if (GenericUtils.isEmpty(command) && target == null) {
+                if (GenericUtils.isEmpty(command) && (target == null)) {
                     target = argName;
                 } else {
                     if (command == null) {
diff --git 
a/sshd-common/src/main/java/org/apache/sshd/common/util/logging/LoggingUtils.java
 
b/sshd-common/src/main/java/org/apache/sshd/common/util/logging/LoggingUtils.java
index 08c90d7..19aa8e9 100644
--- 
a/sshd-common/src/main/java/org/apache/sshd/common/util/logging/LoggingUtils.java
+++ 
b/sshd-common/src/main/java/org/apache/sshd/common/util/logging/LoggingUtils.java
@@ -556,7 +556,7 @@ public final class LoggingUtils {
     }
 
     public static void debug(Logger log, String message, Object o1, Object o2, 
Throwable t) {
-        if (log.isTraceEnabled() && t != null) {
+        if (log.isTraceEnabled() && (t != null)) {
             log.debug(message, o1, o2, t);
         } else if (log.isDebugEnabled()) {
             log.debug(message, o1, o2);
@@ -564,7 +564,7 @@ public final class LoggingUtils {
     }
 
     public static void debug(Logger log, String message, Object o1, Object o2, 
Object o3, Throwable t) {
-        if (log.isTraceEnabled() && t != null) {
+        if (log.isTraceEnabled() && (t != null)) {
             log.debug(message, o1, o2, o3, t);
         } else if (log.isDebugEnabled()) {
             log.debug(message, o1, o2, o3);
@@ -572,7 +572,7 @@ public final class LoggingUtils {
     }
 
     public static void debug(Logger log, String message, Object o1, Object o2, 
Object o3, Object o4, Throwable t) {
-        if (log.isTraceEnabled() && t != null) {
+        if (log.isTraceEnabled() && (t != null)) {
             log.debug(message, o1, o2, o3, o4, t);
         } else if (log.isDebugEnabled()) {
             log.debug(message, o1, o2, o3, o4);
@@ -580,7 +580,7 @@ public final class LoggingUtils {
     }
 
     public static void debug(Logger log, String message, Object o1, Object o2, 
Object o3, Object o4, Object o5, Throwable t) {
-        if (log.isTraceEnabled() && t != null) {
+        if (log.isTraceEnabled() && (t != null)) {
             log.debug(message, o1, o2, o3, o4, o5, t);
         } else if (log.isDebugEnabled()) {
             log.debug(message, o1, o2, o3, o4, o5);
@@ -590,7 +590,7 @@ public final class LoggingUtils {
     @SuppressWarnings("all")
     public static void debug(
             Logger log, String message, Object o1, Object o2, Object o3, 
Object o4, Object o5, Object o6, Throwable t) {
-        if (log.isTraceEnabled() && t != null) {
+        if (log.isTraceEnabled() && (t != null)) {
             log.debug(message, o1, o2, o3, o4, o5, o6, t);
         } else if (log.isDebugEnabled()) {
             log.debug(message, o1, o2, o3, o4, o5, o6);
@@ -598,7 +598,7 @@ public final class LoggingUtils {
     }
 
     public static void info(Logger log, String message, Object o1, Object o2, 
Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.info(message, o1, o2, t);
         } else {
             log.info(message, o1, o2);
@@ -606,7 +606,7 @@ public final class LoggingUtils {
     }
 
     public static void info(Logger log, String message, Object o1, Object o2, 
Object o3, Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.info(message, o1, o2, o3, t);
         } else {
             log.info(message, o1, o2, o3);
@@ -614,7 +614,7 @@ public final class LoggingUtils {
     }
 
     public static void warn(Logger log, String message, Object o1, Object o2, 
Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.warn(message, o1, o2, t);
         } else {
             log.warn(message, o1, o2);
@@ -622,7 +622,7 @@ public final class LoggingUtils {
     }
 
     public static void warn(Logger log, String message, Object o1, Object o2, 
Object o3, Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.warn(message, o1, o2, o3, t);
         } else {
             log.warn(message, o1, o2, o3);
@@ -630,7 +630,7 @@ public final class LoggingUtils {
     }
 
     public static void warn(Logger log, String message, Object o1, Object o2, 
Object o3, Object o4, Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.warn(message, o1, o2, o3, o4, t);
         } else if (log.isDebugEnabled()) {
             log.warn(message, o1, o2, o3, o4);
@@ -638,7 +638,7 @@ public final class LoggingUtils {
     }
 
     public static void warn(Logger log, String message, Object o1, Object o2, 
Object o3, Object o4, Object o5, Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.warn(message, o1, o2, o3, o4, o5, t);
         } else {
             log.warn(message, o1, o2, o3, o4, o5);
@@ -648,7 +648,7 @@ public final class LoggingUtils {
     @SuppressWarnings("all")
     public static void warn(
             Logger log, String message, Object o1, Object o2, Object o3, 
Object o4, Object o5, Object o6, Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.warn(message, o1, o2, o3, o4, o5, o6, t);
         } else {
             log.warn(message, o1, o2, o3, o4, o5, o6);
@@ -659,7 +659,7 @@ public final class LoggingUtils {
     public static void warn(
             Logger log, String message, Object o1, Object o2, Object o3, 
Object o4, Object o5, Object o6, Object o7,
             Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.warn(message, o1, o2, o3, o4, o5, o6, o7, t);
         } else {
             log.warn(message, o1, o2, o3, o4, o5, o6, o7);
@@ -670,7 +670,7 @@ public final class LoggingUtils {
     public static void warn(
             Logger log, String message, Object o1, Object o2, Object o3, 
Object o4, Object o5, Object o6, Object o7, Object o8,
             Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.warn(message, o1, o2, o3, o4, o5, o6, o7, o8, t);
         } else {
             log.warn(message, o1, o2, o3, o4, o5, o6, o7, o8);
@@ -681,7 +681,7 @@ public final class LoggingUtils {
     public static void warn(
             Logger log, String message, Object o1, Object o2, Object o3, 
Object o4, Object o5, Object o6, Object o7, Object o8,
             Object o9, Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.warn(message, o1, o2, o3, o4, o5, o6, o7, o8, o9, t);
         } else {
             log.warn(message, o1, o2, o3, o4, o5, o6, o7, o8, o9);
@@ -689,7 +689,7 @@ public final class LoggingUtils {
     }
 
     public static void error(Logger log, String message, Object o1, Object o2, 
Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.error(message, o1, o2, t);
         } else {
             log.error(message, o1, o2);
@@ -697,7 +697,7 @@ public final class LoggingUtils {
     }
 
     public static void error(Logger log, String message, Object o1, Object o2, 
Object o3, Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.error(message, o1, o2, o3, t);
         } else {
             log.error(message, o1, o2, o3);
@@ -705,7 +705,7 @@ public final class LoggingUtils {
     }
 
     public static void error(Logger log, String message, Object o1, Object o2, 
Object o3, Object o4, Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.error(message, o1, o2, o3, o4, t);
         } else if (log.isDebugEnabled()) {
             log.error(message, o1, o2, o3, o4);
@@ -713,7 +713,7 @@ public final class LoggingUtils {
     }
 
     public static void error(Logger log, String message, Object o1, Object o2, 
Object o3, Object o4, Object o5, Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.error(message, o1, o2, o3, o4, o5, t);
         } else {
             log.error(message, o1, o2, o3, o4, o5);
@@ -723,7 +723,7 @@ public final class LoggingUtils {
     @SuppressWarnings("all")
     public static void error(
             Logger log, String message, Object o1, Object o2, Object o3, 
Object o4, Object o5, Object o6, Throwable t) {
-        if (log.isDebugEnabled() && t != null) {
+        if (log.isDebugEnabled() && (t != null)) {
             log.error(message, o1, o2, o3, o4, o5, o6, t);
         } else {
             log.error(message, o1, o2, o3, o4, o5, o6);
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/client/channel/AbstractClientChannel.java
 
b/sshd-core/src/main/java/org/apache/sshd/client/channel/AbstractClientChannel.java
index 3857313..c82c404 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/client/channel/AbstractClientChannel.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/client/channel/AbstractClientChannel.java
@@ -297,7 +297,7 @@ public abstract class AbstractClientChannel extends 
AbstractChannel implements C
         if ((openFuture != null) && openFuture.isOpened()) {
             state.add(ClientChannelEvent.OPENED);
         }
-        if (closeFuture.isClosed()) {
+        if (closeFuture.isClosed() || closeSignaled.get() || 
unregisterSignaled.get() || isClosed()) {
             state.add(ClientChannelEvent.CLOSED);
         }
         if (isEofSignalled()) {
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/channel/AbstractChannel.java 
b/sshd-core/src/main/java/org/apache/sshd/common/channel/AbstractChannel.java
index 8bb4d99..f888d74 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/common/channel/AbstractChannel.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/channel/AbstractChannel.java
@@ -69,9 +69,7 @@ import org.apache.sshd.core.CoreModuleProperties;
  *
  * @author <a href="mailto:d...@mina.apache.org";>Apache MINA SSHD Project</a>
  */
-public abstract class AbstractChannel
-        extends AbstractInnerCloseable
-        implements Channel, ExecutorServiceCarrier {
+public abstract class AbstractChannel extends AbstractInnerCloseable 
implements Channel, ExecutorServiceCarrier {
 
     /**
      * Default growth factor function used to resize response buffers
@@ -89,6 +87,8 @@ public abstract class AbstractChannel
     protected final AtomicBoolean initialized = new AtomicBoolean(false);
     protected final AtomicBoolean eofReceived = new AtomicBoolean(false);
     protected final AtomicBoolean eofSent = new AtomicBoolean(false);
+    protected final AtomicBoolean unregisterSignaled = new 
AtomicBoolean(false);
+    protected final AtomicBoolean closeSignaled = new AtomicBoolean(false);
     protected AtomicReference<GracefulState> gracefulState = new 
AtomicReference<>(GracefulState.Opened);
     protected final DefaultCloseFuture gracefulFuture;
     /**
@@ -274,17 +274,17 @@ public abstract class AbstractChannel
             try {
                 result = handler.process(this, req, wantReply, buffer);
             } catch (Throwable e) {
-                debug("handleRequest({}) {} while 
{}#process({})[want-reply={}]: {}",
-                        this, e.getClass().getSimpleName(), 
handler.getClass().getSimpleName(),
-                        req, wantReply, e.getMessage(), e);
+                debug("handleRequest({}) {} while 
{}#process({})[want-reply={}]: {}", this,
+                        e.getClass().getSimpleName(), 
handler.getClass().getSimpleName(), req, wantReply,
+                        e.getMessage(), e);
                 result = RequestHandler.Result.ReplyFailure;
             }
 
             // if Unsupported then check the next handler in line
             if (RequestHandler.Result.Unsupported.equals(result)) {
                 if (traceEnabled) {
-                    
log.trace("handleRequest({})[{}#process({})[want-reply={}]]: {}",
-                            this, handler.getClass().getSimpleName(), req, 
wantReply, result);
+                    
log.trace("handleRequest({})[{}#process({})[want-reply={}]]: {}", this,
+                            handler.getClass().getSimpleName(), req, 
wantReply, result);
                 }
             } else {
                 sendResponse(buffer, req, result, wantReply);
@@ -305,11 +305,11 @@ public abstract class AbstractChannel
      * @throws IOException If failed to send the response (if needed)
      * @see                #handleInternalRequest(String, boolean, Buffer)
      */
-    protected void handleUnknownChannelRequest(String req, boolean wantReply, 
Buffer buffer)
-            throws IOException {
+    protected void handleUnknownChannelRequest(String req, boolean wantReply, 
Buffer buffer) throws IOException {
         RequestHandler.Result r = handleInternalRequest(req, wantReply, 
buffer);
         if ((r == null) || RequestHandler.Result.Unsupported.equals(r)) {
-            log.warn("handleUnknownChannelRequest({}) Unknown channel request: 
{}[want-reply={}]", this, req, wantReply);
+            log.warn("handleUnknownChannelRequest({}) Unknown channel request: 
{}[want-reply={}]", this, req,
+                    wantReply);
             sendResponse(buffer, req, RequestHandler.Result.Unsupported, 
wantReply);
         } else {
             sendResponse(buffer, req, r, wantReply);
@@ -335,8 +335,7 @@ public abstract class AbstractChannel
         return RequestHandler.Result.Unsupported;
     }
 
-    protected IoWriteFuture sendResponse(
-            Buffer buffer, String req, RequestHandler.Result result, boolean 
wantReply)
+    protected IoWriteFuture sendResponse(Buffer buffer, String req, 
RequestHandler.Result result, boolean wantReply)
             throws IOException {
         if (log.isDebugEnabled()) {
             log.debug("sendResponse({}) request={} result={}, want-reply={}", 
this, req, result, wantReply);
@@ -379,6 +378,8 @@ public abstract class AbstractChannel
                 signalChannelInitialized(l);
                 return null;
             });
+
+            notifyStateChanged("init");
         } catch (Throwable err) {
             Throwable e = GenericUtils.peelException(err);
             if (e instanceof IOException) {
@@ -387,8 +388,8 @@ public abstract class AbstractChannel
                 throw (RuntimeException) e;
             } else {
                 throw new IOException(
-                        "Failed (" + e.getClass().getSimpleName() + ") to 
notify channel " + this + " initialization: "
-                                      + e.getMessage(),
+                        "Failed (" + e.getClass().getSimpleName() + ") to 
notify channel " + this
+                                      + " initialization: " + e.getMessage(),
                         e);
             }
         }
@@ -432,6 +433,21 @@ public abstract class AbstractChannel
         return initialized.get();
     }
 
+    @Override
+    public void handleChannelRegistrationResult(
+            ConnectionService service, Session session, int channelId,
+            boolean registered) {
+        notifyStateChanged("registered=" + registered);
+        if (registered) {
+            return;
+        }
+
+        RuntimeException reason = new IllegalStateException(
+                "Channel id=" + channelId + " not registered because session 
is being closed: " + this);
+        signalChannelClosed(reason);
+        throw reason;
+    }
+
     protected void signalChannelOpenFailure(Throwable reason) {
         try {
             invokeChannelSignaller(l -> {
@@ -440,8 +456,9 @@ public abstract class AbstractChannel
             });
         } catch (Throwable err) {
             Throwable ignored = GenericUtils.peelException(err);
-            debug("signalChannelOpenFailure({}) failed ({}) to inform listener 
of open failure={}: {}",
-                    this, ignored.getClass().getSimpleName(), 
reason.getClass().getSimpleName(), ignored.getMessage(), ignored);
+            debug("signalChannelOpenFailure({}) failed ({}) to inform listener 
of open failure={}: {}", this,
+                    ignored.getClass().getSimpleName(), 
reason.getClass().getSimpleName(), ignored.getMessage(),
+                    ignored);
         }
     }
 
@@ -461,8 +478,8 @@ public abstract class AbstractChannel
             });
         } catch (Throwable err) {
             Throwable e = GenericUtils.peelException(err);
-            debug("notifyStateChanged({})[{}] {} while signal channel state 
change: {}",
-                    this, hint, e.getClass().getSimpleName(), e.getMessage(), 
e);
+            debug("notifyStateChanged({})[{}] {} while signal channel state 
change: {}", this, hint,
+                    e.getClass().getSimpleName(), e.getMessage(), e);
         } finally {
             synchronized (futureLock) {
                 futureLock.notifyAll();
@@ -528,29 +545,31 @@ public abstract class AbstractChannel
             log.debug("handleClose({}) SSH_MSG_CHANNEL_CLOSE", this);
         }
 
-        if (!isEofSent()) {
-            if (debugEnabled) {
-                log.debug("handleClose({}) prevent sending EOF", this);
+        try {
+            if (!isEofSent()) {
+                if (debugEnabled) {
+                    log.debug("handleClose({}) prevent sending EOF", this);
+                }
             }
-        }
 
-        if (gracefulState.compareAndSet(GracefulState.Opened, 
GracefulState.CloseReceived)) {
-            close(false);
-        } else if (gracefulState.compareAndSet(GracefulState.CloseSent, 
GracefulState.Closed)) {
-            gracefulFuture.setClosed();
+            if (gracefulState.compareAndSet(GracefulState.Opened, 
GracefulState.CloseReceived)) {
+                close(false);
+            } else if (gracefulState.compareAndSet(GracefulState.CloseSent, 
GracefulState.Closed)) {
+                gracefulFuture.setClosed();
+            }
+        } finally {
+            notifyStateChanged("SSH_MSG_CHANNEL_CLOSE");
         }
     }
 
     @Override
     protected Closeable getInnerCloseable() {
-        Closeable closer = builder()
-                .sequential(new GracefulChannelCloseable(), 
getExecutorService())
+        Closeable closer = builder().sequential(new 
GracefulChannelCloseable(), getExecutorService())
                 .run(toString(), () -> {
                     if (service != null) {
                         service.unregisterChannel(AbstractChannel.this);
                     }
-                })
-                .build();
+                }).build();
         closer.addCloseFutureListener(future -> clearAttributes());
         return closer;
     }
@@ -675,15 +694,33 @@ public abstract class AbstractChannel
 
         IOException err = IoUtils.closeQuietly(getLocalWindow(), 
getRemoteWindow());
         if (err != null) {
-            debug("Failed ({}) to pre-close window(s) of {}: {}",
-                    err.getClass().getSimpleName(), this, err.getMessage(), 
err);
+            debug("Failed ({}) to pre-close window(s) of {}: {}", 
err.getClass().getSimpleName(), this,
+                    err.getMessage(), err);
         }
 
         super.preClose();
     }
 
+    @Override
+    public void handleChannelUnregistration(ConnectionService service) {
+        if (!unregisterSignaled.getAndSet(true)) {
+            if (log.isTraceEnabled()) {
+                log.trace("handleChannelUnregistration({}) via service={}", 
this, service);
+            }
+        }
+
+        notifyStateChanged("unregistered");
+    }
+
     public void signalChannelClosed(Throwable reason) {
+        String event = (reason == null) ? "signalChannelClosed" : 
reason.getClass().getSimpleName();
         try {
+            if (!closeSignaled.getAndSet(true)) {
+                if (log.isTraceEnabled()) {
+                    log.trace("signalChannelClosed({})[{}]", this, event);
+                }
+            }
+
             invokeChannelSignaller(l -> {
                 signalChannelClosed(l, reason);
                 return null;
@@ -692,6 +729,8 @@ public abstract class AbstractChannel
             Throwable e = GenericUtils.peelException(err);
             debug("signalChannelClosed({}) {} while signal channel closed: 
{}", this, e.getClass().getSimpleName(),
                     e.getMessage(), e);
+        } finally {
+            notifyStateChanged(event);
         }
     }
 
@@ -708,9 +747,7 @@ public abstract class AbstractChannel
         FactoryManager manager = (session == null) ? null : 
session.getFactoryManager();
         ChannelListener[] listeners = {
                 (manager == null) ? null : manager.getChannelListenerProxy(),
-                (session == null) ? null : session.getChannelListenerProxy(),
-                getChannelListenerProxy()
-        };
+                (session == null) ? null : session.getChannelListenerProxy(), 
getChannelListenerProxy() };
 
         Throwable err = null;
         for (ChannelListener l : listeners) {
@@ -783,8 +820,9 @@ public abstract class AbstractChannel
             log.debug("handleExtendedData({}) SSH_MSG_CHANNEL_EXTENDED_DATA 
len={}", this, len);
         }
         if (log.isTraceEnabled()) {
-            BufferUtils.dumpHex(getSimplifiedLogger(), 
BufferUtils.DEFAULT_HEXDUMP_LEVEL, "handleExtendedData(" + this + ")",
-                    this, BufferUtils.DEFAULT_HEX_SEPARATOR, buffer.array(), 
buffer.rpos(), (int) len);
+            BufferUtils.dumpHex(getSimplifiedLogger(), 
BufferUtils.DEFAULT_HEXDUMP_LEVEL,
+                    "handleExtendedData(" + this + ")", this, 
BufferUtils.DEFAULT_HEX_SEPARATOR, buffer.array(),
+                    buffer.rpos(), (int) len);
         }
         if (isEofSignalled()) {
             // TODO consider throwing an exception
@@ -793,7 +831,9 @@ public abstract class AbstractChannel
         doWriteExtendedData(buffer.array(), buffer.rpos(), len);
     }
 
-    protected long validateIncomingDataSize(int cmd, long len /* actually a 
uint32 */) {
+    protected long validateIncomingDataSize(
+            int cmd,
+            long len /* actually a uint32 */) {
         if (!BufferUtils.isValidUint32Value(len)) {
             throw new IllegalArgumentException(
                     "Non UINT32 length (" + len + ") for command=" + 
SshConstants.getCommandMessageName(cmd));
@@ -802,23 +842,23 @@ public abstract class AbstractChannel
         /*
          * According to RFC 4254 section 5.1
          *
-         * The 'maximum packet size' specifies the maximum size of an 
individual data packet that can be sent to the
-         * sender
+         * The 'maximum packet size' specifies the maximum size of an 
individual
+         * data packet that can be sent to the sender
          *
-         * The local window reflects our preference - i.e., how much our peer 
should send at most
+         * The local window reflects our preference - i.e., how much our peer
+         * should send at most
          */
         Window wLocal = getLocalWindow();
         long maxLocalSize = wLocal.getPacketSize();
 
         /*
-         * The reason for the +4 is that there seems to be some confusion 
whether the max. packet size includes the
-         * length field or not
+         * The reason for the +4 is that there seems to be some confusion
+         * whether the max. packet size includes the length field or not
          */
         if (len > (maxLocalSize + 4L)) {
             throw new IllegalStateException(
-                    "Bad length (" + len + ") "
-                                            + " for cmd=" + 
SshConstants.getCommandMessageName(cmd)
-                                            + " - max. allowed=" + 
maxLocalSize);
+                    "Bad length (" + len + ") " + " for cmd="
+                                            + 
SshConstants.getCommandMessageName(cmd) + " - max. allowed=" + maxLocalSize);
         }
 
         return len;
@@ -905,7 +945,8 @@ public abstract class AbstractChannel
         Buffer buffer = s.createBuffer(SshConstants.SSH_MSG_CHANNEL_EOF, 
Short.SIZE);
         buffer.putInt(getRecipient());
         /*
-         * The default "writePacket" does not send packets if state is not 
open so we need to bypass it.
+         * The default "writePacket" does not send packets if state is not open
+         * so we need to bypass it.
          */
         return s.writePacket(buffer);
     }
@@ -946,9 +987,7 @@ public abstract class AbstractChannel
     @Override
     @SuppressWarnings("unchecked")
     public <T> T setAttribute(AttributeRepository.AttributeKey<T> key, T 
value) {
-        return (T) attributes.put(
-                Objects.requireNonNull(key, "No key"),
-                Objects.requireNonNull(value, "No value"));
+        return (T) attributes.put(Objects.requireNonNull(key, "No key"), 
Objects.requireNonNull(value, "No value"));
     }
 
     @Override
@@ -979,6 +1018,7 @@ public abstract class AbstractChannel
 
     @Override
     public String toString() {
-        return getClass().getSimpleName() + "[id=" + getId() + ", recipient=" 
+ getRecipient() + "]" + "-" + getSession();
+        return getClass().getSimpleName() + "[id=" + getId() + ", recipient=" 
+ getRecipient() + "]" + "-"
+               + getSession();
     }
 }
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/channel/Channel.java 
b/sshd-core/src/main/java/org/apache/sshd/common/channel/Channel.java
index dae4aa1..2bfd095 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/channel/Channel.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/Channel.java
@@ -153,6 +153,24 @@ public interface Channel
     void init(ConnectionService service, Session session, int id) throws 
IOException;
 
     /**
+     * Invoked by the connection service once its attempt to register the channel completes - should throw a
+     * {@link RuntimeException} if not registered
+     *
+     * @param service    The {@link ConnectionService} through which the 
channel is registered
+     * @param session    The {@link Session} associated with the channel
+     * @param id         The locally assigned channel identifier
+     * @param registered Whether registration was successful or not
+     */
+    void handleChannelRegistrationResult(ConnectionService service, Session 
session, int id, boolean registered);
+
+    /**
+     * Called by the connection service to inform the channel that it has been unregistered.
+     *
+     * @param service The {@link ConnectionService} through which the channel 
is unregistered
+     */
+    void handleChannelUnregistration(ConnectionService service);
+
+    /**
      * @return {@code true} if call to {@link #init(ConnectionService, 
Session, int)} was successfully completed
      */
     boolean isInitialized();
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractConnectionService.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractConnectionService.java
index 7a9eb07..d1f3ac0 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractConnectionService.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractConnectionService.java
@@ -43,7 +43,6 @@ import org.apache.sshd.client.future.OpenFuture;
 import org.apache.sshd.common.Closeable;
 import org.apache.sshd.common.FactoryManager;
 import org.apache.sshd.common.SshConstants;
-import org.apache.sshd.common.channel.AbstractChannel;
 import org.apache.sshd.common.channel.Channel;
 import org.apache.sshd.common.channel.ChannelFactory;
 import org.apache.sshd.common.channel.RequestHandler;
@@ -389,7 +388,7 @@ public abstract class AbstractConnectionService
     protected Closeable getInnerCloseable() {
         return builder()
                 .sequential(forwarderHolder.get(), agentForwardHolder.get(), 
x11ForwardHolder.get())
-                .parallel(toString(), channels.values())
+                .parallel(toString(), getChannels())
                 .build();
     }
 
@@ -417,23 +416,12 @@ public abstract class AbstractConnectionService
             }
         }
 
-        if (!registered) {
-            handleChannelRegistrationFailure(channel, channelId);
-        }
-
         if (log.isDebugEnabled()) {
-            log.debug("registerChannel({})[id={}] {}", this, channelId, 
channel);
+            log.debug("registerChannel({})[id={}, registered={}] {}", this, 
channelId, registered, channel);
         }
-        return channelId;
-    }
 
-    protected void handleChannelRegistrationFailure(Channel channel, int 
channelId) throws IOException {
-        RuntimeException reason = new IllegalStateException(
-                "Channel id=" + channelId + " not registered because session 
is being closed: " + this);
-        AbstractChannel notifier
-                = ValidateUtils.checkInstanceOf(channel, 
AbstractChannel.class, "Non abstract channel for id=%d", channelId);
-        notifier.signalChannelClosed(reason);
-        throw reason;
+        channel.handleChannelRegistrationResult(this, session, channelId, 
registered);
+        return channelId;
     }
 
     /**
@@ -452,6 +440,10 @@ public abstract class AbstractConnectionService
         if (log.isDebugEnabled()) {
             log.debug("unregisterChannel({}) result={}", channel, result);
         }
+
+        if (result != null) {
+            result.handleChannelUnregistration(this);
+        }
     }
 
     @Override

Reply via email to