diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.Handshake.cs b/src/DotNetty.Handlers/Tls/TlsHandler.Handshake.cs index 666885c2..b0a98566 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.Handshake.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.Handshake.cs @@ -43,6 +43,7 @@ namespace DotNetty.Handlers.Tls partial class TlsHandler { private static readonly Action s_handshakeCompletionCallback = (t, s) => HandleHandshakeCompleted((Task)t, (TlsHandler)s); + private static readonly TaskCanceledException s_taskCanceledException = new TaskCanceledException(); public static readonly AttributeKey SslStreamAttrKey = AttributeKey.ValueOf("SSLSTREAM"); private bool EnsureAuthenticated(IChannelHandlerContext ctx) @@ -215,9 +216,9 @@ private static void HandleHandshakeCompleted(Task task, TlsHandler self) var cause = taskExc.Unwrap(); try { - if (self._handshakePromise.TrySetException(taskExc)) + if (task.IsFaulted ? self._handshakePromise.TrySetException(taskExc) : self._handshakePromise.TrySetCanceled()) { - TlsUtils.NotifyHandshakeFailure(capturedContext, cause, true); + TlsUtils.NotifyHandshakeFailure(capturedContext, task.IsFaulted ? cause : s_taskCanceledException, true); } } finally diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetCore.cs b/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetCore.cs index 7bf2f59b..387ae186 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetCore.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetCore.cs @@ -177,10 +177,10 @@ private int ReadFromInput(Memory destination) // byte[] destination, int d } public override void Write(ReadOnlySpan buffer) - => _owner.FinishWrap(buffer, _owner._lastContextWritePromise); + => _owner.FinishWrap(buffer, _owner._lastContextWritePromise ?? _owner.CapturedContext.NewPromise()); public override void Write(byte[] buffer, int offset, int count) - => _owner.FinishWrap(buffer, offset, count, _owner._lastContextWritePromise); + => _owner.FinishWrap(buffer, offset, count, _owner._lastContextWritePromise ?? _owner.CapturedContext.NewPromise()); public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetFx.cs b/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetFx.cs index c6c81e6f..18494090 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetFx.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetFx.cs @@ -145,7 +145,7 @@ private int ReadFromInput(byte[] destination, int destinationOffset, int destina return length; } - public override void Write(byte[] buffer, int offset, int count) => _owner.FinishWrap(buffer, offset, count, _owner._lastContextWritePromise); + public override void Write(byte[] buffer, int offset, int count) => _owner.FinishWrap(buffer, offset, count, _owner._lastContextWritePromise ?? _owner.CapturedContext.NewPromise()); public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => _owner.FinishWrapNonAppDataAsync(buffer, offset, count, _owner.CapturedContext.NewPromise()); diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetStandard20.cs b/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetStandard20.cs index 2889bf3f..e95c0ef8 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetStandard20.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.MediationStream.NetStandard20.cs @@ -108,7 +108,7 @@ private int ReadFromInput(byte[] destination, int destinationOffset, int destina return length; } - public override void Write(byte[] buffer, int offset, int count) => _owner.FinishWrap(buffer, offset, count, _owner._lastContextWritePromise); + public override void Write(byte[] buffer, int offset, int count) => _owner.FinishWrap(buffer, offset, count, _owner._lastContextWritePromise ?? _owner.CapturedContext.NewPromise()); public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => _owner.FinishWrapNonAppDataAsync(buffer, offset, count, _owner.CapturedContext.NewPromise());