This is an automated email from the ASF dual-hosted git repository. lgoldstein pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/mina-sshd.git
commit 0bbdc772e2b36a71af3c1bf4caf98cef1ae27819 Author: Lyor Goldstein <lgoldst...@apache.org> AuthorDate: Wed Dec 9 21:53:16 2020 +0200 [SSHD-1085] Added more notifications related to channel state change for detecting channel closing or closed earlier --- CHANGES.md | 2 + .../org/apache/sshd/cli/client/SshClientMain.java | 2 +- .../sshd/common/util/logging/LoggingUtils.java | 40 +++--- .../sshd/client/channel/AbstractClientChannel.java | 2 +- .../sshd/common/channel/AbstractChannel.java | 144 +++++++++++++-------- .../org/apache/sshd/common/channel/Channel.java | 18 +++ .../session/helpers/AbstractConnectionService.java | 24 ++-- 7 files changed, 142 insertions(+), 90 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index c51a3ca..fc90adf 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -17,3 +17,5 @@ ## Minor code helpers ## Behavioral changes and enhancements + +* [SSHD-1085](https://issues.apache.org/jira/browse/SSHD-1085) Added more notifications related to channel state change for detecting channel closing or closed earlier. 
diff --git a/sshd-cli/src/main/java/org/apache/sshd/cli/client/SshClientMain.java b/sshd-cli/src/main/java/org/apache/sshd/cli/client/SshClientMain.java index 7c1265e..c71f4fa 100644 --- a/sshd-cli/src/main/java/org/apache/sshd/cli/client/SshClientMain.java +++ b/sshd-cli/src/main/java/org/apache/sshd/cli/client/SshClientMain.java @@ -108,7 +108,7 @@ public class SshClientMain extends SshClientCliSupport { error = true; break; } - if (GenericUtils.isEmpty(command) && target == null) { + if (GenericUtils.isEmpty(command) && (target == null)) { target = argName; } else { if (command == null) { diff --git a/sshd-common/src/main/java/org/apache/sshd/common/util/logging/LoggingUtils.java b/sshd-common/src/main/java/org/apache/sshd/common/util/logging/LoggingUtils.java index 08c90d7..19aa8e9 100644 --- a/sshd-common/src/main/java/org/apache/sshd/common/util/logging/LoggingUtils.java +++ b/sshd-common/src/main/java/org/apache/sshd/common/util/logging/LoggingUtils.java @@ -556,7 +556,7 @@ public final class LoggingUtils { } public static void debug(Logger log, String message, Object o1, Object o2, Throwable t) { - if (log.isTraceEnabled() && t != null) { + if (log.isTraceEnabled() && (t != null)) { log.debug(message, o1, o2, t); } else if (log.isDebugEnabled()) { log.debug(message, o1, o2); @@ -564,7 +564,7 @@ public final class LoggingUtils { } public static void debug(Logger log, String message, Object o1, Object o2, Object o3, Throwable t) { - if (log.isTraceEnabled() && t != null) { + if (log.isTraceEnabled() && (t != null)) { log.debug(message, o1, o2, o3, t); } else if (log.isDebugEnabled()) { log.debug(message, o1, o2, o3); @@ -572,7 +572,7 @@ public final class LoggingUtils { } public static void debug(Logger log, String message, Object o1, Object o2, Object o3, Object o4, Throwable t) { - if (log.isTraceEnabled() && t != null) { + if (log.isTraceEnabled() && (t != null)) { log.debug(message, o1, o2, o3, o4, t); } else if (log.isDebugEnabled()) { log.debug(message, 
o1, o2, o3, o4); @@ -580,7 +580,7 @@ public final class LoggingUtils { } public static void debug(Logger log, String message, Object o1, Object o2, Object o3, Object o4, Object o5, Throwable t) { - if (log.isTraceEnabled() && t != null) { + if (log.isTraceEnabled() && (t != null)) { log.debug(message, o1, o2, o3, o4, o5, t); } else if (log.isDebugEnabled()) { log.debug(message, o1, o2, o3, o4, o5); @@ -590,7 +590,7 @@ public final class LoggingUtils { @SuppressWarnings("all") public static void debug( Logger log, String message, Object o1, Object o2, Object o3, Object o4, Object o5, Object o6, Throwable t) { - if (log.isTraceEnabled() && t != null) { + if (log.isTraceEnabled() && (t != null)) { log.debug(message, o1, o2, o3, o4, o5, o6, t); } else if (log.isDebugEnabled()) { log.debug(message, o1, o2, o3, o4, o5, o6); @@ -598,7 +598,7 @@ public final class LoggingUtils { } public static void info(Logger log, String message, Object o1, Object o2, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.info(message, o1, o2, t); } else { log.info(message, o1, o2); @@ -606,7 +606,7 @@ public final class LoggingUtils { } public static void info(Logger log, String message, Object o1, Object o2, Object o3, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.info(message, o1, o2, o3, t); } else { log.info(message, o1, o2, o3); @@ -614,7 +614,7 @@ public final class LoggingUtils { } public static void warn(Logger log, String message, Object o1, Object o2, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.warn(message, o1, o2, t); } else { log.warn(message, o1, o2); @@ -622,7 +622,7 @@ public final class LoggingUtils { } public static void warn(Logger log, String message, Object o1, Object o2, Object o3, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { 
log.warn(message, o1, o2, o3, t); } else { log.warn(message, o1, o2, o3); @@ -630,7 +630,7 @@ public final class LoggingUtils { } public static void warn(Logger log, String message, Object o1, Object o2, Object o3, Object o4, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.warn(message, o1, o2, o3, o4, t); } else if (log.isDebugEnabled()) { log.warn(message, o1, o2, o3, o4); @@ -638,7 +638,7 @@ public final class LoggingUtils { } public static void warn(Logger log, String message, Object o1, Object o2, Object o3, Object o4, Object o5, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.warn(message, o1, o2, o3, o4, o5, t); } else { log.warn(message, o1, o2, o3, o4, o5); @@ -648,7 +648,7 @@ public final class LoggingUtils { @SuppressWarnings("all") public static void warn( Logger log, String message, Object o1, Object o2, Object o3, Object o4, Object o5, Object o6, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.warn(message, o1, o2, o3, o4, o5, o6, t); } else { log.warn(message, o1, o2, o3, o4, o5, o6); @@ -659,7 +659,7 @@ public final class LoggingUtils { public static void warn( Logger log, String message, Object o1, Object o2, Object o3, Object o4, Object o5, Object o6, Object o7, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.warn(message, o1, o2, o3, o4, o5, o6, o7, t); } else { log.warn(message, o1, o2, o3, o4, o5, o6, o7); @@ -670,7 +670,7 @@ public final class LoggingUtils { public static void warn( Logger log, String message, Object o1, Object o2, Object o3, Object o4, Object o5, Object o6, Object o7, Object o8, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.warn(message, o1, o2, o3, o4, o5, o6, o7, o8, t); } else { log.warn(message, o1, o2, o3, o4, o5, o6, o7, o8); @@ 
-681,7 +681,7 @@ public final class LoggingUtils { public static void warn( Logger log, String message, Object o1, Object o2, Object o3, Object o4, Object o5, Object o6, Object o7, Object o8, Object o9, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.warn(message, o1, o2, o3, o4, o5, o6, o7, o8, o9, t); } else { log.warn(message, o1, o2, o3, o4, o5, o6, o7, o8, o9); @@ -689,7 +689,7 @@ public final class LoggingUtils { } public static void error(Logger log, String message, Object o1, Object o2, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.error(message, o1, o2, t); } else { log.error(message, o1, o2); @@ -697,7 +697,7 @@ public final class LoggingUtils { } public static void error(Logger log, String message, Object o1, Object o2, Object o3, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.error(message, o1, o2, o3, t); } else { log.error(message, o1, o2, o3); @@ -705,7 +705,7 @@ public final class LoggingUtils { } public static void error(Logger log, String message, Object o1, Object o2, Object o3, Object o4, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.error(message, o1, o2, o3, o4, t); } else if (log.isDebugEnabled()) { log.error(message, o1, o2, o3, o4); @@ -713,7 +713,7 @@ public final class LoggingUtils { } public static void error(Logger log, String message, Object o1, Object o2, Object o3, Object o4, Object o5, Throwable t) { - if (log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.error(message, o1, o2, o3, o4, o5, t); } else { log.error(message, o1, o2, o3, o4, o5); @@ -723,7 +723,7 @@ public final class LoggingUtils { @SuppressWarnings("all") public static void error( Logger log, String message, Object o1, Object o2, Object o3, Object o4, Object o5, Object o6, Throwable t) { - if 
(log.isDebugEnabled() && t != null) { + if (log.isDebugEnabled() && (t != null)) { log.error(message, o1, o2, o3, o4, o5, o6, t); } else { log.error(message, o1, o2, o3, o4, o5, o6); diff --git a/sshd-core/src/main/java/org/apache/sshd/client/channel/AbstractClientChannel.java b/sshd-core/src/main/java/org/apache/sshd/client/channel/AbstractClientChannel.java index 3857313..c82c404 100644 --- a/sshd-core/src/main/java/org/apache/sshd/client/channel/AbstractClientChannel.java +++ b/sshd-core/src/main/java/org/apache/sshd/client/channel/AbstractClientChannel.java @@ -297,7 +297,7 @@ public abstract class AbstractClientChannel extends AbstractChannel implements C if ((openFuture != null) && openFuture.isOpened()) { state.add(ClientChannelEvent.OPENED); } - if (closeFuture.isClosed()) { + if (closeFuture.isClosed() || closeSignaled.get() || unregisterSignaled.get() || isClosed()) { state.add(ClientChannelEvent.CLOSED); } if (isEofSignalled()) { diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/AbstractChannel.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/AbstractChannel.java index 8bb4d99..f888d74 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/channel/AbstractChannel.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/AbstractChannel.java @@ -69,9 +69,7 @@ import org.apache.sshd.core.CoreModuleProperties; * * @author <a href="mailto:d...@mina.apache.org">Apache MINA SSHD Project</a> */ -public abstract class AbstractChannel - extends AbstractInnerCloseable - implements Channel, ExecutorServiceCarrier { +public abstract class AbstractChannel extends AbstractInnerCloseable implements Channel, ExecutorServiceCarrier { /** * Default growth factor function used to resize response buffers @@ -89,6 +87,8 @@ public abstract class AbstractChannel protected final AtomicBoolean initialized = new AtomicBoolean(false); protected final AtomicBoolean eofReceived = new AtomicBoolean(false); protected final AtomicBoolean 
eofSent = new AtomicBoolean(false); + protected final AtomicBoolean unregisterSignaled = new AtomicBoolean(false); + protected final AtomicBoolean closeSignaled = new AtomicBoolean(false); protected AtomicReference<GracefulState> gracefulState = new AtomicReference<>(GracefulState.Opened); protected final DefaultCloseFuture gracefulFuture; /** @@ -274,17 +274,17 @@ public abstract class AbstractChannel try { result = handler.process(this, req, wantReply, buffer); } catch (Throwable e) { - debug("handleRequest({}) {} while {}#process({})[want-reply={}]: {}", - this, e.getClass().getSimpleName(), handler.getClass().getSimpleName(), - req, wantReply, e.getMessage(), e); + debug("handleRequest({}) {} while {}#process({})[want-reply={}]: {}", this, + e.getClass().getSimpleName(), handler.getClass().getSimpleName(), req, wantReply, + e.getMessage(), e); result = RequestHandler.Result.ReplyFailure; } // if Unsupported then check the next handler in line if (RequestHandler.Result.Unsupported.equals(result)) { if (traceEnabled) { - log.trace("handleRequest({})[{}#process({})[want-reply={}]]: {}", - this, handler.getClass().getSimpleName(), req, wantReply, result); + log.trace("handleRequest({})[{}#process({})[want-reply={}]]: {}", this, + handler.getClass().getSimpleName(), req, wantReply, result); } } else { sendResponse(buffer, req, result, wantReply); @@ -305,11 +305,11 @@ public abstract class AbstractChannel * @throws IOException If failed to send the response (if needed) * @see #handleInternalRequest(String, boolean, Buffer) */ - protected void handleUnknownChannelRequest(String req, boolean wantReply, Buffer buffer) - throws IOException { + protected void handleUnknownChannelRequest(String req, boolean wantReply, Buffer buffer) throws IOException { RequestHandler.Result r = handleInternalRequest(req, wantReply, buffer); if ((r == null) || RequestHandler.Result.Unsupported.equals(r)) { - log.warn("handleUnknownChannelRequest({}) Unknown channel request: 
{}[want-reply={}]", this, req, wantReply); + log.warn("handleUnknownChannelRequest({}) Unknown channel request: {}[want-reply={}]", this, req, + wantReply); sendResponse(buffer, req, RequestHandler.Result.Unsupported, wantReply); } else { sendResponse(buffer, req, r, wantReply); @@ -335,8 +335,7 @@ public abstract class AbstractChannel return RequestHandler.Result.Unsupported; } - protected IoWriteFuture sendResponse( - Buffer buffer, String req, RequestHandler.Result result, boolean wantReply) + protected IoWriteFuture sendResponse(Buffer buffer, String req, RequestHandler.Result result, boolean wantReply) throws IOException { if (log.isDebugEnabled()) { log.debug("sendResponse({}) request={} result={}, want-reply={}", this, req, result, wantReply); @@ -379,6 +378,8 @@ public abstract class AbstractChannel signalChannelInitialized(l); return null; }); + + notifyStateChanged("init"); } catch (Throwable err) { Throwable e = GenericUtils.peelException(err); if (e instanceof IOException) { @@ -387,8 +388,8 @@ public abstract class AbstractChannel throw (RuntimeException) e; } else { throw new IOException( - "Failed (" + e.getClass().getSimpleName() + ") to notify channel " + this + " initialization: " - + e.getMessage(), + "Failed (" + e.getClass().getSimpleName() + ") to notify channel " + this + + " initialization: " + e.getMessage(), e); } } @@ -432,6 +433,21 @@ public abstract class AbstractChannel return initialized.get(); } + @Override + public void handleChannelRegistrationResult( + ConnectionService service, Session session, int channelId, + boolean registered) { + notifyStateChanged("registered=" + registered); + if (registered) { + return; + } + + RuntimeException reason = new IllegalStateException( + "Channel id=" + channelId + " not registered because session is being closed: " + this); + signalChannelClosed(reason); + throw reason; + } + protected void signalChannelOpenFailure(Throwable reason) { try { invokeChannelSignaller(l -> { @@ -440,8 +456,9 @@ 
public abstract class AbstractChannel }); } catch (Throwable err) { Throwable ignored = GenericUtils.peelException(err); - debug("signalChannelOpenFailure({}) failed ({}) to inform listener of open failure={}: {}", - this, ignored.getClass().getSimpleName(), reason.getClass().getSimpleName(), ignored.getMessage(), ignored); + debug("signalChannelOpenFailure({}) failed ({}) to inform listener of open failure={}: {}", this, + ignored.getClass().getSimpleName(), reason.getClass().getSimpleName(), ignored.getMessage(), + ignored); } } @@ -461,8 +478,8 @@ public abstract class AbstractChannel }); } catch (Throwable err) { Throwable e = GenericUtils.peelException(err); - debug("notifyStateChanged({})[{}] {} while signal channel state change: {}", - this, hint, e.getClass().getSimpleName(), e.getMessage(), e); + debug("notifyStateChanged({})[{}] {} while signal channel state change: {}", this, hint, + e.getClass().getSimpleName(), e.getMessage(), e); } finally { synchronized (futureLock) { futureLock.notifyAll(); @@ -528,29 +545,31 @@ public abstract class AbstractChannel log.debug("handleClose({}) SSH_MSG_CHANNEL_CLOSE", this); } - if (!isEofSent()) { - if (debugEnabled) { - log.debug("handleClose({}) prevent sending EOF", this); + try { + if (!isEofSent()) { + if (debugEnabled) { + log.debug("handleClose({}) prevent sending EOF", this); + } } - } - if (gracefulState.compareAndSet(GracefulState.Opened, GracefulState.CloseReceived)) { - close(false); - } else if (gracefulState.compareAndSet(GracefulState.CloseSent, GracefulState.Closed)) { - gracefulFuture.setClosed(); + if (gracefulState.compareAndSet(GracefulState.Opened, GracefulState.CloseReceived)) { + close(false); + } else if (gracefulState.compareAndSet(GracefulState.CloseSent, GracefulState.Closed)) { + gracefulFuture.setClosed(); + } + } finally { + notifyStateChanged("SSH_MSG_CHANNEL_CLOSE"); } } @Override protected Closeable getInnerCloseable() { - Closeable closer = builder() - .sequential(new 
GracefulChannelCloseable(), getExecutorService()) + Closeable closer = builder().sequential(new GracefulChannelCloseable(), getExecutorService()) .run(toString(), () -> { if (service != null) { service.unregisterChannel(AbstractChannel.this); } - }) - .build(); + }).build(); closer.addCloseFutureListener(future -> clearAttributes()); return closer; } @@ -675,15 +694,33 @@ public abstract class AbstractChannel IOException err = IoUtils.closeQuietly(getLocalWindow(), getRemoteWindow()); if (err != null) { - debug("Failed ({}) to pre-close window(s) of {}: {}", - err.getClass().getSimpleName(), this, err.getMessage(), err); + debug("Failed ({}) to pre-close window(s) of {}: {}", err.getClass().getSimpleName(), this, + err.getMessage(), err); } super.preClose(); } + @Override + public void handleChannelUnregistration(ConnectionService service) { + if (!unregisterSignaled.getAndSet(true)) { + if (log.isTraceEnabled()) { + log.trace("handleChannelUnregistration({}) via service={}", this, service); + } + } + + notifyStateChanged("unregistered"); + } + public void signalChannelClosed(Throwable reason) { + String event = (reason == null) ? "signalChannelClosed" : reason.getClass().getSimpleName(); try { + if (!closeSignaled.getAndSet(true)) { + if (log.isTraceEnabled()) { + log.trace("signalChannelClosed({})[{}]", this, event); + } + } + invokeChannelSignaller(l -> { signalChannelClosed(l, reason); return null; @@ -692,6 +729,8 @@ public abstract class AbstractChannel Throwable e = GenericUtils.peelException(err); debug("signalChannelClosed({}) {} while signal channel closed: {}", this, e.getClass().getSimpleName(), e.getMessage(), e); + } finally { + notifyStateChanged(event); } } @@ -708,9 +747,7 @@ public abstract class AbstractChannel FactoryManager manager = (session == null) ? null : session.getFactoryManager(); ChannelListener[] listeners = { (manager == null) ? null : manager.getChannelListenerProxy(), - (session == null) ? 
null : session.getChannelListenerProxy(), - getChannelListenerProxy() - }; + (session == null) ? null : session.getChannelListenerProxy(), getChannelListenerProxy() }; Throwable err = null; for (ChannelListener l : listeners) { @@ -783,8 +820,9 @@ public abstract class AbstractChannel log.debug("handleExtendedData({}) SSH_MSG_CHANNEL_EXTENDED_DATA len={}", this, len); } if (log.isTraceEnabled()) { - BufferUtils.dumpHex(getSimplifiedLogger(), BufferUtils.DEFAULT_HEXDUMP_LEVEL, "handleExtendedData(" + this + ")", - this, BufferUtils.DEFAULT_HEX_SEPARATOR, buffer.array(), buffer.rpos(), (int) len); + BufferUtils.dumpHex(getSimplifiedLogger(), BufferUtils.DEFAULT_HEXDUMP_LEVEL, + "handleExtendedData(" + this + ")", this, BufferUtils.DEFAULT_HEX_SEPARATOR, buffer.array(), + buffer.rpos(), (int) len); } if (isEofSignalled()) { // TODO consider throwing an exception @@ -793,7 +831,9 @@ public abstract class AbstractChannel doWriteExtendedData(buffer.array(), buffer.rpos(), len); } - protected long validateIncomingDataSize(int cmd, long len /* actually a uint32 */) { + protected long validateIncomingDataSize( + int cmd, + long len /* actually a uint32 */) { if (!BufferUtils.isValidUint32Value(len)) { throw new IllegalArgumentException( "Non UINT32 length (" + len + ") for command=" + SshConstants.getCommandMessageName(cmd)); @@ -802,23 +842,23 @@ public abstract class AbstractChannel /* * According to RFC 4254 section 5.1 * - * The 'maximum packet size' specifies the maximum size of an individual data packet that can be sent to the - * sender + * The 'maximum packet size' specifies the maximum size of an individual + * data packet that can be sent to the sender * - * The local window reflects our preference - i.e., how much our peer should send at most + * The local window reflects our preference - i.e., how much our peer + * should send at most */ Window wLocal = getLocalWindow(); long maxLocalSize = wLocal.getPacketSize(); /* - * The reason for the +4 is that there seems 
to be some confusion whether the max. packet size includes the - * length field or not + * The reason for the +4 is that there seems to be some confusion + * whether the max. packet size includes the length field or not */ if (len > (maxLocalSize + 4L)) { throw new IllegalStateException( - "Bad length (" + len + ") " - + " for cmd=" + SshConstants.getCommandMessageName(cmd) - + " - max. allowed=" + maxLocalSize); + "Bad length (" + len + ") " + " for cmd=" + + SshConstants.getCommandMessageName(cmd) + " - max. allowed=" + maxLocalSize); } return len; @@ -905,7 +945,8 @@ public abstract class AbstractChannel Buffer buffer = s.createBuffer(SshConstants.SSH_MSG_CHANNEL_EOF, Short.SIZE); buffer.putInt(getRecipient()); /* - * The default "writePacket" does not send packets if state is not open so we need to bypass it. + * The default "writePacket" does not send packets if state is not open + * so we need to bypass it. */ return s.writePacket(buffer); } @@ -946,9 +987,7 @@ public abstract class AbstractChannel @Override @SuppressWarnings("unchecked") public <T> T setAttribute(AttributeRepository.AttributeKey<T> key, T value) { - return (T) attributes.put( - Objects.requireNonNull(key, "No key"), - Objects.requireNonNull(value, "No value")); + return (T) attributes.put(Objects.requireNonNull(key, "No key"), Objects.requireNonNull(value, "No value")); } @Override @@ -979,6 +1018,7 @@ public abstract class AbstractChannel @Override public String toString() { - return getClass().getSimpleName() + "[id=" + getId() + ", recipient=" + getRecipient() + "]" + "-" + getSession(); + return getClass().getSimpleName() + "[id=" + getId() + ", recipient=" + getRecipient() + "]" + "-" + + getSession(); } } diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/Channel.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/Channel.java index dae4aa1..2bfd095 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/channel/Channel.java +++ 
b/sshd-core/src/main/java/org/apache/sshd/common/channel/Channel.java @@ -153,6 +153,24 @@ public interface Channel void init(ConnectionService service, Session session, int id) throws IOException; /** + * Invoked after being successfully registered by the connection service - should throw a {@link RuntimeException} + * if not registered + * + * @param service The {@link ConnectionService} through which the channel is registered + * @param session The {@link Session} associated with the channel + * @param id The locally assigned channel identifier + * @param registered Whether registration was successful or not + */ + void handleChannelRegistrationResult(ConnectionService service, Session session, int id, boolean registered); + + /** + * Called by the connection service to inform the channel that it has been unregistered. + * + * @param service The {@link ConnectionService} through which the channel is unregistered + */ + void handleChannelUnregistration(ConnectionService service); + + /** * @return {@code true} if call to {@link #init(ConnectionService, Session, int)} was successfully completed */ boolean isInitialized(); diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractConnectionService.java b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractConnectionService.java index 7a9eb07..d1f3ac0 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractConnectionService.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractConnectionService.java @@ -43,7 +43,6 @@ import org.apache.sshd.client.future.OpenFuture; import org.apache.sshd.common.Closeable; import org.apache.sshd.common.FactoryManager; import org.apache.sshd.common.SshConstants; -import org.apache.sshd.common.channel.AbstractChannel; import org.apache.sshd.common.channel.Channel; import org.apache.sshd.common.channel.ChannelFactory; import org.apache.sshd.common.channel.RequestHandler; @@ -389,7 +388,7 
@@ public abstract class AbstractConnectionService protected Closeable getInnerCloseable() { return builder() .sequential(forwarderHolder.get(), agentForwardHolder.get(), x11ForwardHolder.get()) - .parallel(toString(), channels.values()) + .parallel(toString(), getChannels()) .build(); } @@ -417,23 +416,12 @@ public abstract class AbstractConnectionService } } - if (!registered) { - handleChannelRegistrationFailure(channel, channelId); - } - if (log.isDebugEnabled()) { - log.debug("registerChannel({})[id={}] {}", this, channelId, channel); + log.debug("registerChannel({})[id={}, registered={}] {}", this, channelId, registered, channel); } - return channelId; - } - protected void handleChannelRegistrationFailure(Channel channel, int channelId) throws IOException { - RuntimeException reason = new IllegalStateException( - "Channel id=" + channelId + " not registered because session is being closed: " + this); - AbstractChannel notifier - = ValidateUtils.checkInstanceOf(channel, AbstractChannel.class, "Non abstract channel for id=%d", channelId); - notifier.signalChannelClosed(reason); - throw reason; + channel.handleChannelRegistrationResult(this, session, channelId, registered); + return channelId; } /** @@ -452,6 +440,10 @@ public abstract class AbstractConnectionService if (log.isDebugEnabled()) { log.debug("unregisterChannel({}) result={}", channel, result); } + + if (result != null) { + result.handleChannelUnregistration(this); + } } @Override