Created
December 23, 2019 02:10
-
-
Save TrueGeek/960641d12e42d477bf8f35e7365a46de to your computer and use it in GitHub Desktop.
Custom TlsHandler for DotNetty. This is needed when using Matrix vNext with Xamarin. From https://github.com/Azure/DotNetty/pull/374
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright (c) Microsoft. All rights reserved. | |
// Licensed under the MIT license. See LICENSE file in the project root for full license information. | |
namespace CustomMonoTlsHandler | |
{ | |
using System.Security.Authentication; | |
using System.Security.Cryptography.X509Certificates; | |
using DotNetty.Handlers.Tls; | |
/// <summary>
/// Server-side TLS configuration: the certificate presented to peers, whether a client
/// certificate is requested, whether revocation is checked, and which protocol versions are allowed.
/// </summary>
public sealed class ServerTlsSettings : TlsSettings
{
    // Protocol set used by the overloads that do not take an explicit SslProtocols value.
    const SslProtocols DefaultProtocols = SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12;

    /// <summary>Settings that do not request a client certificate and do not check revocation.</summary>
    /// <param name="certificate">Certificate the server presents.</param>
    public ServerTlsSettings(X509Certificate certificate)
        : this(certificate, negotiateClientCertificate: false)
    {
    }

    /// <summary>Settings that do not check certificate revocation.</summary>
    /// <param name="certificate">Certificate the server presents.</param>
    /// <param name="negotiateClientCertificate">Whether the client is asked for a certificate.</param>
    public ServerTlsSettings(X509Certificate certificate, bool negotiateClientCertificate)
        : this(certificate, negotiateClientCertificate, checkCertificateRevocation: false)
    {
    }

    /// <summary>Settings that allow TLS 1.0, 1.1 and 1.2.</summary>
    /// <param name="certificate">Certificate the server presents.</param>
    /// <param name="negotiateClientCertificate">Whether the client is asked for a certificate.</param>
    /// <param name="checkCertificateRevocation">Whether certificate revocation is checked.</param>
    public ServerTlsSettings(X509Certificate certificate, bool negotiateClientCertificate, bool checkCertificateRevocation)
        : this(certificate, negotiateClientCertificate, checkCertificateRevocation, DefaultProtocols)
    {
    }

    /// <summary>Fully specified server settings.</summary>
    /// <param name="certificate">Certificate the server presents.</param>
    /// <param name="negotiateClientCertificate">Whether the client is asked for a certificate.</param>
    /// <param name="checkCertificateRevocation">Whether certificate revocation is checked.</param>
    /// <param name="enabledProtocols">Protocol versions the server accepts.</param>
    public ServerTlsSettings(X509Certificate certificate, bool negotiateClientCertificate, bool checkCertificateRevocation, SslProtocols enabledProtocols)
        : base(enabledProtocols, checkCertificateRevocation)
    {
        this.Certificate = certificate;
        this.NegotiateClientCertificate = negotiateClientCertificate;
    }

    /// <summary>Certificate the server presents during the handshake.</summary>
    public X509Certificate Certificate { get; }

    /// <summary>Whether the server requests a certificate from the client.</summary>
    public bool NegotiateClientCertificate { get; }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright (c) Microsoft. All rights reserved. | |
// Licensed under the MIT license. See LICENSE file in the project root for full license information. | |
// from: https://github.com/Azure/DotNetty/pull/374 | |
namespace CustomMonoTlsHandler | |
{ | |
using System; | |
using System.Collections.Generic; | |
using System.Diagnostics; | |
using System.Diagnostics.Contracts; | |
using System.IO; | |
using System.Net.Security; | |
using System.Runtime.ExceptionServices; | |
using System.Security.Cryptography.X509Certificates; | |
using System.Threading; | |
using System.Threading.Tasks; | |
using DotNetty.Buffers; | |
using DotNetty.Codecs; | |
using DotNetty.Common.Concurrency; | |
using DotNetty.Common.Utilities; | |
using DotNetty.Handlers.Tls; | |
using DotNetty.Transport.Channels; | |
public sealed class TlsHandler : ByteToMessageDecoder | |
{ | |
// Size of the read buffer used when the amount of decrypted data cannot be estimated.
const int FallbackReadBufferSize = 256;
// Batch size given to the BatchingPendingWriteQueue that holds pending plaintext writes.
const int UnencryptedWriteBatchSize = 14 * 1024;

// Failure cause reported when the channel becomes inactive.
static readonly Exception ChannelClosedException = new IOException("Channel is closed");
// Cached delegates, so continuations do not allocate a new delegate per call.
static readonly Action<Task, object> HandshakeCompletionCallback = HandleHandshakeCompleted;
static readonly Action<Task<int>, object> UnwrapCompletedCallback = UnwrapCompleted;

// Client or server configuration; its concrete type decides the handler's role.
readonly TlsSettings settings;
// SslStream layered over mediationStream, which bridges it to the channel pipeline.
readonly SslStream sslStream;
readonly MediationStream mediationStream;
// Completed when the handler is closed; used to ignore ObjectDisposedException afterwards.
readonly TaskCompletionSource closeFuture;

// Handshake and read/flush bookkeeping flags.
TlsHandlerState state;
// Length of a TLS record whose header was seen but whose body has not fully arrived (0 when none).
int packetLength;
volatile IChannelHandlerContext capturedContext;
// Plaintext writes waiting to be encrypted by Wrap.
BatchingPendingWriteQueue pendingUnencryptedWrites;
// Last write issued by FinishWrap; Wrap links its outcome to the batch promise.
Task lastContextWriteTask;
// Whether decoding produced output since the last ChannelReadComplete.
bool firedChannelRead;
// Volatile because SslStream may call Write (and so FinishWrap) from another thread on Mono.
volatile FlushMode flushMode = FlushMode.ForceFlush;

// An SslStream read that Unwrap could not complete synchronously: its buffer, requested size and task.
IByteBuffer pendingSslStreamReadBuffer;
int pendingSslStreamReadLength;
Task<int> pendingSslStreamReadFuture;
/// <summary>
/// Creates a handler that uses a default <see cref="SslStream"/>, which leaves the inner
/// (mediation) stream open when it is disposed.
/// </summary>
/// <param name="settings">Client or server TLS settings.</param>
public TlsHandler(TlsSettings settings)
    : this(innerStream => new SslStream(innerStream, leaveInnerStreamOpen: true), settings)
{
}
/// <summary>
/// Creates a handler whose <see cref="SslStream"/> is produced by <paramref name="sslStreamFactory"/>
/// on top of the internal mediation stream.
/// </summary>
/// <param name="sslStreamFactory">Builds the SslStream around the given inner stream.</param>
/// <param name="settings">Client or server TLS settings.</param>
/// <exception cref="ArgumentNullException">Either argument is null.</exception>
public TlsHandler(Func<Stream, SslStream> sslStreamFactory, TlsSettings settings)
{
    // Contract.Requires is compiled out unless CONTRACTS_FULL is defined, so it never validated
    // anything in normal builds; check explicitly instead of failing later with a NullReferenceException.
    if (sslStreamFactory == null)
    {
        throw new ArgumentNullException(nameof(sslStreamFactory));
    }
    if (settings == null)
    {
        throw new ArgumentNullException(nameof(settings));
    }

    this.settings = settings;
    this.closeFuture = new TaskCompletionSource();
    this.mediationStream = new MediationStream(this);
    this.sslStream = sslStreamFactory(this.mediationStream);
}
/// <summary>Creates a client-side handler that authenticates against <paramref name="targetHost"/>.</summary>
public static TlsHandler Client(string targetHost)
{
    var clientSettings = new ClientTlsSettings(targetHost);
    return new TlsHandler(clientSettings);
}

/// <summary>
/// Creates a client-side handler for <paramref name="targetHost"/> whose settings carry
/// <paramref name="clientCertificate"/> as the single client certificate.
/// </summary>
public static TlsHandler Client(string targetHost, X509Certificate clientCertificate)
{
    var certificates = new List<X509Certificate> { clientCertificate };
    return new TlsHandler(new ClientTlsSettings(targetHost, certificates));
}

/// <summary>Creates a server-side handler that presents <paramref name="certificate"/>.</summary>
public static TlsHandler Server(X509Certificate certificate)
{
    var serverSettings = new ServerTlsSettings(certificate);
    return new TlsHandler(serverSettings);
}
// SslStream may expose a plain X509Certificate rather than an X509Certificate2; in that case an
// X509Certificate2 is rebuilt from its DER encoding (workaround from https://github.com/dotnet/corefx/issues/4510).

/// <summary>Certificate this side presented, or null if none is available.</summary>
public X509Certificate2 LocalCertificate => ToCertificate2(this.sslStream.LocalCertificate);

/// <summary>Certificate the peer presented, or null if none is available.</summary>
public X509Certificate2 RemoteCertificate => ToCertificate2(this.sslStream.RemoteCertificate);

// Returns the certificate as an X509Certificate2, or null when there is no certificate.
// The previous `?.Export` still handed null to new X509Certificate2(byte[]), which does not produce null.
static X509Certificate2 ToCertificate2(X509Certificate certificate)
{
    if (certificate == null)
    {
        return null;
    }
    return certificate as X509Certificate2 ?? new X509Certificate2(certificate.Export(X509ContentType.Cert));
}
// The handler acts as the TLS server exactly when it was configured with ServerTlsSettings.
bool IsServer
{
    get { return this.settings is ServerTlsSettings; }
}

/// <summary>Disposes the underlying SslStream, if one exists.</summary>
public void Dispose()
{
    SslStream stream = this.sslStream;
    if (stream != null)
    {
        stream.Dispose();
    }
}
/// <summary>
/// Propagates channel activation. A client then starts the handshake immediately; a server
/// starts it lazily (on the first inbound records or flush).
/// </summary>
public override void ChannelActive(IChannelHandlerContext context)
{
    base.ChannelActive(context);
    if (this.IsServer)
    {
        return;
    }
    this.EnsureAuthenticated();
}
/// <summary>
/// Tears down the SslStream and fails an unfinished handshake and pending writes, because the
/// connection is gone, then propagates the event.
/// </summary>
public override void ChannelInactive(IChannelHandlerContext context)
{
    Exception cause = ChannelClosedException;
    this.HandleFailure(cause);
    base.ChannelInactive(context);
}
/// <summary>
/// Forwards exceptions down the pipeline, except the ones IgnoreException classifies as noise
/// after close. For those the channel is closed explicitly if it is still active.
/// </summary>
public override void ExceptionCaught(IChannelHandlerContext context, Exception exception)
{
    if (!this.IgnoreException(exception))
    {
        base.ExceptionCaught(context, exception);
        return;
    }

    // The transport may not have closed the connection on its own; make sure it is closed.
    if (context.Channel.Active)
    {
        context.CloseAsync();
    }
}
// An ObjectDisposedException after the handler was closed is expected (CloseAsync disposes the
// SslStream) and is not reported.
bool IgnoreException(Exception t) => t is ObjectDisposedException && this.closeFuture.Task.IsCompleted;
/// <summary>
/// Continuation of the SslStream authentication task started by EnsureAuthenticated.
/// On success: updates the state, fires TlsHandshakeCompletionEvent.Success, re-issues a deferred
/// Read and performs a deferred flush. On failure or cancellation: fails the handler via HandleFailure.
/// </summary>
/// <param name="task">The authentication task.</param>
/// <param name="state">The owning <see cref="TlsHandler"/>.</param>
/// <exception cref="ArgumentOutOfRangeException">The task is in a non-final state.</exception>
static void HandleHandshakeCompleted(Task task, object state)
{
    var self = (TlsHandler)state;
    switch (task.Status)
    {
        case TaskStatus.RanToCompletion:
        {
            TlsHandlerState oldState = self.state;
            Contract.Assert(!oldState.HasAny(TlsHandlerState.AuthenticationCompleted));
            self.state = (oldState | TlsHandlerState.Authenticated) & ~(TlsHandlerState.Authenticating | TlsHandlerState.FlushedBeforeHandshake);
            self.capturedContext.FireUserEventTriggered(TlsHandshakeCompletionEvent.Success);

            // Read() calls made before authentication completed are re-issued here when auto-read is off.
            if (oldState.Has(TlsHandlerState.ReadRequestedBeforeAuthenticated) && !self.capturedContext.Channel.Configuration.AutoRead)
            {
                self.capturedContext.Read();
            }

            // Flush() during the handshake only set a flag; perform the deferred encrypt-and-flush now.
            if (oldState.Has(TlsHandlerState.FlushedBeforeHandshake))
            {
                self.WrapAndFlush(self.capturedContext);
            }
            break;
        }
        case TaskStatus.Canceled:
        case TaskStatus.Faulted:
        {
            TlsHandlerState oldState = self.state;
            Contract.Assert(!oldState.HasAny(TlsHandlerState.Authenticated));

            // Task.Exception is null for a canceled task. HandleFailure forwards the cause to the
            // handshake completion event and to every failed pending write, so never pass null.
            Exception cause = task.Exception;
            if (cause == null)
            {
                cause = new TaskCanceledException(task);
            }
            self.HandleFailure(cause);
            break;
        }
        default:
            throw new ArgumentOutOfRangeException(nameof(task), "Unexpected task status: " + task.Status);
    }
}
/// <summary>
/// Captures the context and creates the pending-write queue. A client added to an already-active
/// channel starts the handshake right away.
/// </summary>
public override void HandlerAdded(IChannelHandlerContext context)
{
    base.HandlerAdded(context);
    this.capturedContext = context;
    this.pendingUnencryptedWrites = new BatchingPendingWriteQueue(context, UnencryptedWriteBatchSize);

    bool startClientHandshakeNow = context.Channel.Active && !this.IsServer;
    if (startClientHandshakeNow)
    {
        // todo: support delayed initialization on an existing/active channel if in client mode
        this.EnsureAuthenticated();
    }
}
/// <summary>Fails every write still queued for encryption when the handler leaves the pipeline.</summary>
protected override void HandlerRemovedInternal(IChannelHandlerContext context)
{
    // Creating a ChannelException is comparatively expensive, so skip it when nothing is queued.
    if (this.pendingUnencryptedWrites.IsEmpty)
    {
        return;
    }
    this.pendingUnencryptedWrites.RemoveAndFailAll(new ChannelException("Write has failed due to TlsHandler being removed from channel pipeline."));
}
/// <summary>
/// Splits the cumulated inbound bytes into complete TLS records and hands them to Unwrap.
/// Collection stops at one of three points:
/// - the first incomplete record (its length is remembered in packetLength so the next call can
///   wait for it cheaply);
/// - a record that would push the total past TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH;
/// - bytes that are not a TLS record. These are discarded, reported via FireExceptionCaught,
///   and fail the handler through HandleFailure.
/// </summary>
/// <param name="context">Pipeline context.</param>
/// <param name="input">Cumulated encrypted bytes; consumed records are skipped.</param>
/// <param name="output">Receives decrypted buffers produced synchronously by Unwrap.</param>
protected override void Decode(IChannelHandlerContext context, IByteBuffer input, List<object> output) | |
{ | |
int startOffset = input.ReaderIndex; | |
int endOffset = input.WriterIndex; | |
// Position of the next record header within input.
int offset = startOffset; | |
// Sum of the lengths of the complete records collected so far.
int totalLength = 0; | |
// Length of each complete record, in input order; passed to Unwrap.
List<int> packetLengths; | |
// if we calculated the length of the current SSL record before, use that information. | |
if (this.packetLength > 0) | |
{ | |
if (endOffset - startOffset < this.packetLength) | |
{ | |
// input does not contain a single complete SSL record | |
return; | |
} | |
else | |
{ | |
packetLengths = new List<int>(4); | |
packetLengths.Add(this.packetLength); | |
offset += this.packetLength; | |
totalLength = this.packetLength; | |
this.packetLength = 0; | |
} | |
} | |
else | |
{ | |
packetLengths = new List<int>(4); | |
} | |
// Set when the bytes at offset do not start a TLS record.
bool nonSslRecord = false; | |
while (totalLength < TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH) | |
{ | |
int readableBytes = endOffset - offset; | |
if (readableBytes < TlsUtils.SSL_RECORD_HEADER_LENGTH) | |
{ | |
// Not even a full record header is available yet.
break; | |
} | |
// -1 means the header at offset is not a TLS record header.
int encryptedPacketLength = TlsUtils.GetEncryptedPacketLength(input, offset); | |
if (encryptedPacketLength == -1) | |
{ | |
nonSslRecord = true; | |
break; | |
} | |
Contract.Assert(encryptedPacketLength > 0); | |
if (encryptedPacketLength > readableBytes) | |
{ | |
// wait until the whole packet can be read | |
this.packetLength = encryptedPacketLength; | |
break; | |
} | |
int newTotalLength = totalLength + encryptedPacketLength; | |
if (newTotalLength > TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH) | |
{ | |
// Don't read too much. | |
break; | |
} | |
// 1. call unwrap with packet boundaries - call SslStream.ReadAsync only once. | |
// 2. once we're through all the whole packets, switch to reading out using fallback sized buffer | |
// We have a whole packet. | |
// Increment the offset to handle the next packet. | |
packetLengths.Add(encryptedPacketLength); | |
offset += encryptedPacketLength; | |
totalLength = newTotalLength; | |
} | |
if (totalLength > 0) | |
{ | |
// The buffer contains one or more full SSL records. | |
// Slice out the whole packet so unwrap will only be called with complete packets. | |
// Also directly reset the packetLength. This is needed as unwrap(..) may trigger | |
// decode(...) again via: | |
// 1) unwrap(..) is called | |
// 2) wrap(...) is called from within unwrap(...) | |
// 3) wrap(...) calls unwrapLater(...) | |
// 4) unwrapLater(...) calls decode(...) | |
// | |
// See https://github.com/netty/netty/issues/1534 | |
input.SkipBytes(totalLength); | |
this.Unwrap(context, input, startOffset, totalLength, packetLengths, output); | |
if (!this.firedChannelRead) | |
{ | |
// Check first if firedChannelRead is not set yet as it may have been set in a | |
// previous decode(...) call. | |
this.firedChannelRead = output.Count > 0; | |
} | |
} | |
if (nonSslRecord) | |
{ | |
// Not an SSL/TLS packet | |
// Report with a hex dump of the remaining input, drop that input, and fail the handler.
var ex = new NotSslRecordException( | |
"not an SSL/TLS record: " + ByteBufferUtil.HexDump(input)); | |
input.SkipBytes(input.ReadableBytes); | |
context.FireExceptionCaught(ex); | |
this.HandleFailure(ex); | |
} | |
} | |
/// <summary>
/// Ends a read batch. The cumulation buffer is compacted and another read is requested if
/// needed; the per-batch flag is then reset and the event is propagated.
/// </summary>
public override void ChannelReadComplete(IChannelHandlerContext ctx)
{
    // Compact the cumulation buffer before deciding whether more input must be requested.
    this.DiscardSomeReadBytes();
    this.ReadIfNeeded(ctx);

    // The next batch starts with no decoded output fired yet.
    this.firedChannelRead = false;
    ctx.FireChannelReadComplete();
}
// With auto-read disabled, explicitly requests more input when nothing was passed up the pipeline
// or the handshake is still in progress; otherwise the connection could stall.
void ReadIfNeeded(IChannelHandlerContext ctx)
{
    if (ctx.Channel.Configuration.AutoRead)
    {
        // The transport keeps reading on its own.
        return;
    }
    if (this.firedChannelRead && this.state.HasAny(TlsHandlerState.AuthenticationCompleted))
    {
        // Data reached the pipeline and the handshake is over: the application drives further reads.
        return;
    }
    ctx.Read();
}
/// <summary>
/// Unwraps inbound SSL records: exposes the complete TLS records in
/// <paramref name="packet"/>[offset, offset + length) to SslStream one record at a time, via the
/// mediation stream, and collects the decrypted output.
/// </summary>
/// <param name="ctx">Context whose allocator provides the output buffers.</param>
/// <param name="packet">Buffer holding the encrypted records.</param>
/// <param name="offset">Index in <paramref name="packet"/> where the first record starts.</param>
/// <param name="length">Total length of all records.</param>
/// <param name="packetLengths">Length of each record, in order; must be non-empty.</param>
/// <param name="output">Receives decrypted buffers whose reads completed synchronously.</param>
/// <remarks>
/// While the handshake is incomplete, records are only exposed to SslStream for the handshake.
/// A read still outstanding at the end is parked in the pendingSslStream* fields, to be finished
/// by a later call or by UnwrapCompleted. Any exception fails the handler via HandleFailure and is rethrown.
/// </remarks>
void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int length, List<int> packetLengths, List<object> output) | |
{ | |
Contract.Requires(packetLengths.Count > 0); | |
//bool notifyClosure = false; // todo: netty/issues/137 | |
// True once an outstanding read has been parked; the finally block must then leave outputBuffer alone.
bool pending = false; | |
// Buffer of the SslStream read currently in flight, if any.
IByteBuffer outputBuffer = null; | |
try | |
{ | |
// Point the mediation stream at the records' backing array; nothing is readable until ExpandSource.
ArraySegment<byte> inputIoBuffer = packet.GetIoBuffer(offset, length); | |
this.mediationStream.SetSource(inputIoBuffer.Array, inputIoBuffer.Offset, ctx.Allocator); | |
int packetIndex = 0; | |
// Handshake phase: expose records one by one (EnsureAuthenticated starts the handshake on first use).
// If all records are used up and the handshake is still running, wait for more input.
while (!this.EnsureAuthenticated()) | |
{ | |
this.mediationStream.ExpandSource(packetLengths[packetIndex]); | |
if (++packetIndex == packetLengths.Count) | |
{ | |
return; | |
} | |
} | |
// Resume a read that a previous call left outstanding.
Task<int> currentReadFuture = this.pendingSslStreamReadFuture; | |
int outputBufferLength; | |
if (currentReadFuture != null) | |
{ | |
// restoring context from previous read | |
Contract.Assert(this.pendingSslStreamReadBuffer != null); | |
outputBuffer = this.pendingSslStreamReadBuffer; | |
outputBufferLength = this.pendingSslStreamReadLength; | |
this.pendingSslStreamReadFuture = null; | |
this.pendingSslStreamReadBuffer = null; | |
this.pendingSslStreamReadLength = 0; | |
} | |
else | |
{ | |
outputBufferLength = 0; | |
} | |
// go through packets one by one (because SslStream does not consume more than 1 packet at a time) | |
for (; packetIndex < packetLengths.Count; packetIndex++) | |
{ | |
int currentPacketLength = packetLengths[packetIndex]; | |
this.mediationStream.ExpandSource(currentPacketLength); | |
while (true) | |
{ | |
int totalRead = 0; | |
if (currentReadFuture != null) | |
{ | |
// there was a read pending already, so we make sure we completed that first | |
if (!currentReadFuture.IsCompleted) | |
{ | |
// we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input | |
break; | |
} | |
int read = currentReadFuture.Result; | |
totalRead += read; | |
if (read == 0) | |
{ | |
//Stream closed | |
return; | |
} | |
// Now output the result of previous read and decide whether to do an extra read on the same source or move forward | |
AddBufferToOutput(outputBuffer, read, output); | |
currentReadFuture = null; | |
outputBuffer = null; | |
if (this.mediationStream.TotalReadableBytes == 0) | |
{ | |
// we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there | |
if (read < outputBufferLength) | |
{ | |
// SslStream returned non-full buffer and there's no more input to go through -> | |
// typically it means SslStream is done reading current frame so we skip | |
break; | |
} | |
// we've read out `read` bytes out of current packet to fulfil previously outstanding read | |
outputBufferLength = currentPacketLength - totalRead; | |
if (outputBufferLength <= 0) | |
{ | |
// after feeding to SslStream current frame it read out more bytes than current packet size | |
outputBufferLength = FallbackReadBufferSize; | |
} | |
} | |
else | |
{ | |
// SslStream did not get to reading current frame so it completed previous read sync | |
// and the next read will likely read out the new frame | |
outputBufferLength = currentPacketLength; | |
} | |
} | |
else | |
{ | |
// there was no pending read before so we estimate buffer of `currentPacketLength` bytes to be sufficient | |
outputBufferLength = currentPacketLength; | |
} | |
// Issue the next read with the size chosen above.
outputBuffer = ctx.Allocator.Buffer(outputBufferLength); | |
currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, outputBufferLength); | |
} | |
} | |
// A read is still outstanding: park it and let UnwrapCompleted (or the next call) finish it.
if (currentReadFuture != null) | |
{ | |
pending = true; | |
this.pendingSslStreamReadBuffer = outputBuffer; | |
this.pendingSslStreamReadFuture = currentReadFuture; | |
this.pendingSslStreamReadLength = outputBufferLength; | |
//Can't use ExecuteSynchronously here for it may change the order of output if task is already completed here. | |
currentReadFuture.ContinueWith(UnwrapCompletedCallback, this, TaskContinuationOptions.None); | |
} | |
} | |
catch (Exception ex) | |
{ | |
// HandleFailure disposes the streams and fails pending writes; the exception still propagates.
this.HandleFailure(ex); | |
throw; | |
} | |
finally | |
{ | |
// Copy bytes SslStream has not consumed into the mediation stream's own buffer and detach the source array.
this.mediationStream.ResetSource(ctx.Allocator); | |
// An unparked buffer is either emitted (if it holds data) or released.
if (!pending && outputBuffer != null) | |
{ | |
if (outputBuffer.IsReadable()) | |
{ | |
output.Add(outputBuffer); | |
} | |
else | |
{ | |
outputBuffer.SafeRelease(); | |
} | |
} | |
} | |
} | |
/// <summary>
/// Continuation for an SslStream read that Unwrap left outstanding. Mono (with the legacy provider)
/// completes ReadAsync asynchronously, so decrypted data arriving after Unwrap returned is delivered
/// from here. After each completed read it keeps reading while input remains.
/// </summary>
/// <param name="task">The completed read task.</param>
/// <param name="state">The owning <see cref="TlsHandler"/>.</param>
/// <exception cref="ArgumentOutOfRangeException">The task is in a non-final state.</exception>
static void UnwrapCompleted(Task<int> task, object state)
{
    var self = (TlsHandler)state;
    Debug.Assert(self.capturedContext.Executor.InEventLoop);

    // Reads that Unwrap already consumed synchronously are no longer the pending read; ignore them.
    if (task != self.pendingSslStreamReadFuture)
    {
        return;
    }

    IByteBuffer buf = self.pendingSslStreamReadBuffer;
    int outputBufferLength = self.pendingSslStreamReadLength;
    self.pendingSslStreamReadFuture = null;
    self.pendingSslStreamReadBuffer = null;
    self.pendingSslStreamReadLength = 0;

    while (true)
    {
        switch (task.Status)
        {
            case TaskStatus.RanToCompletion:
            {
                int read = task.Result;
                if (read == 0)
                {
                    // Stream closed. The buffer never reached the pipeline, so release it here
                    // (previously it leaked).
                    buf.SafeRelease();
                    return;
                }

                self.capturedContext.FireChannelRead(buf.SetWriterIndex(buf.WriterIndex + read));
                if (self.mediationStream.TotalReadableBytes == 0)
                {
                    self.capturedContext.FireChannelReadComplete();
                    self.mediationStream.ResetSource(self.capturedContext.Allocator);
                    if (read < outputBufferLength)
                    {
                        // SslStream returned non-full buffer and there's no more input to go through ->
                        // typically it means SslStream is done reading current frame so we skip
                        return;
                    }
                }

                // Size the next read by the input still available, or the fallback size when there is none.
                outputBufferLength = self.mediationStream.TotalReadableBytes;
                if (outputBufferLength <= 0)
                {
                    outputBufferLength = FallbackReadBufferSize;
                }
                buf = self.capturedContext.Allocator.Buffer(outputBufferLength);
                task = self.ReadFromSslStreamAsync(buf, outputBufferLength);
                if (task.IsCompleted)
                {
                    continue;
                }

                // Still outstanding: park it again and resume from its continuation.
                self.pendingSslStreamReadFuture = task;
                self.pendingSslStreamReadBuffer = buf;
                self.pendingSslStreamReadLength = outputBufferLength;
                task.ContinueWith(UnwrapCompletedCallback, self, TaskContinuationOptions.ExecuteSynchronously);
                return;
            }
            case TaskStatus.Canceled:
            case TaskStatus.Faulted:
            {
                buf.SafeRelease();
                // Task.Exception is null for a canceled task; HandleFailure must receive a real cause.
                Exception cause = task.Exception;
                if (cause == null)
                {
                    cause = new TaskCanceledException(task);
                }
                self.HandleFailure(cause);
                return;
            }
            default:
            {
                buf.SafeRelease();
                throw new ArgumentOutOfRangeException(nameof(task), "Unexpected task status: " + task.Status);
            }
        }
    }
}
// Commits `length` bytes that SslStream wrote into outputBuffer by advancing its writer index,
// then queues the buffer as one decoded message.
static void AddBufferToOutput(IByteBuffer outputBuffer, int length, List<object> output)
{
    Contract.Assert(length > 0);
    IByteBuffer message = outputBuffer.SetWriterIndex(outputBuffer.WriterIndex + length);
    output.Add(message);
}
// Starts an SslStream read that decrypts directly into outputBuffer's backing array, starting at its
// writer index. The writer index itself is not advanced here; callers do that with the byte count read.
Task<int> ReadFromSslStreamAsync(IByteBuffer outputBuffer, int outputBufferLength)
{
    ArraySegment<byte> target = outputBuffer.GetIoBuffer(outputBuffer.WriterIndex, outputBufferLength);
    byte[] array = target.Array;
    int start = target.Offset;
    int count = target.Count;
    return this.sslStream.ReadAsync(array, start, count);
}
/// <summary>
/// Passes the read request on. A request made before authentication completed is also recorded,
/// so the handshake completion can re-issue it when auto-read is off.
/// </summary>
public override void Read(IChannelHandlerContext context)
{
    TlsHandlerState current = this.state;
    bool handshakeOpen = !current.HasAny(TlsHandlerState.AuthenticationCompleted);
    if (handshakeOpen)
    {
        this.state = current | TlsHandlerState.ReadRequestedBeforeAuthenticated;
    }
    context.Read();
}
/// <summary>
/// Starts the TLS handshake on first call and reports whether authentication has succeeded.
/// </summary>
/// <returns>
/// False when the handshake was just started (its completion is handled by HandleHandshakeCompleted).
/// Otherwise, whether the Authenticated flag is set.
/// </returns>
bool EnsureAuthenticated()
{
    TlsHandlerState oldState = this.state;
    if (oldState.HasAny(TlsHandlerState.AuthenticationStarted))
    {
        return oldState.Has(TlsHandlerState.Authenticated);
    }

    this.state = oldState | TlsHandlerState.Authenticating;
    Task handshake;
    if (this.IsServer)
    {
        var serverSettings = (ServerTlsSettings)this.settings;
        handshake = this.sslStream.AuthenticateAsServerAsync(serverSettings.Certificate, serverSettings.NegotiateClientCertificate, serverSettings.EnabledProtocols, serverSettings.CheckCertificateRevocation);
    }
    else
    {
        var clientSettings = (ClientTlsSettings)this.settings;
        // Previously null was passed here, silently dropping certificates supplied through
        // Client(targetHost, clientCertificate).
        handshake = this.sslStream.AuthenticateAsClientAsync(clientSettings.TargetHost, CreateClientCertificates(clientSettings), clientSettings.EnabledProtocols, clientSettings.CheckCertificateRevocation);
    }
    handshake.ContinueWith(HandshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously);
    return false;
}

// Builds the client certificate collection for SslStream from the public Certificates of the settings.
// Returns null (the previous behavior) when no non-null certificate is configured.
// NOTE(review): assumes DotNetty's ClientTlsSettings exposes a public `Certificates` collection;
// its X509CertificateCollection property is internal and not reachable from this assembly.
static X509CertificateCollection CreateClientCertificates(ClientTlsSettings clientSettings)
{
    IReadOnlyCollection<X509Certificate> certificates = clientSettings.Certificates;
    if (certificates == null || certificates.Count == 0)
    {
        return null;
    }

    var collection = new X509CertificateCollection();
    foreach (X509Certificate certificate in certificates)
    {
        if (certificate != null)
        {
            collection.Add(certificate);
        }
    }
    return collection.Count == 0 ? null : collection;
}
/// <summary>
/// Queues a plaintext <see cref="IByteBuffer"/> for encryption on the next flush. Any other message
/// type yields a task faulted with UnsupportedMessageTypeException.
/// </summary>
public override Task WriteAsync(IChannelHandlerContext context, object message)
{
    if (message is IByteBuffer)
    {
        return this.pendingUnencryptedWrites.Add(message);
    }
    return TaskEx.FromException(new UnsupportedMessageTypeException(message, typeof(IByteBuffer)));
}
/// <summary>
/// Encrypts and flushes queued writes. While the handshake is still running, the flush is only
/// recorded, and the handshake completion performs it.
/// </summary>
public override void Flush(IChannelHandlerContext context)
{
    // Queue an empty buffer so Wrap always has something to process for this flush.
    if (this.pendingUnencryptedWrites.IsEmpty)
    {
        this.pendingUnencryptedWrites.Add(Unpooled.Empty);
    }

    bool authenticated = this.EnsureAuthenticated();
    if (authenticated)
    {
        this.WrapAndFlush(context);
        return;
    }
    this.state |= TlsHandlerState.FlushedBeforeHandshake;
}
// Encrypts all queued writes with FinishWrap in write-only mode, then flushes once. If FinishWrap
// switched to PendingFlush (it ran off the event loop), the flush is posted to the executor instead.
void WrapAndFlush(IChannelHandlerContext context)
{
    this.flushMode = FlushMode.NoFlush;
    try
    {
        this.Wrap(context);
    }
    finally
    {
        if (this.flushMode != FlushMode.NoFlush)
        {
            context.Executor.Execute(
                s =>
                {
                    var handler = (TlsHandler)s;
                    handler.flushMode = FlushMode.ForceFlush;
                    handler.capturedContext.Flush();
                },
                this);
        }
        else
        {
            // Flush even when Wrap threw: part of the data may already have been written.
            this.flushMode = FlushMode.ForceFlush;
            context.Flush();
        }
    }
}
/// <summary>
/// Encrypts each batch of queued plaintext writes by writing it through the SslStream. The SslStream
/// then calls FinishWrap. Each batch's promise is completed, or linked to the last resulting channel write.
/// </summary>
/// <param name="context">Must be the captured context.</param>
/// <remarks>Any exception releases the in-flight buffer, fails the handler via HandleFailure and is rethrown.</remarks>
void Wrap(IChannelHandlerContext context)
{
    Contract.Assert(context == this.capturedContext);

    IByteBuffer buf = null;
    try
    {
        while (true)
        {
            List<object> messages = this.pendingUnencryptedWrites.Current;
            if (messages == null || messages.Count == 0)
            {
                break;
            }

            if (messages.Count == 1)
            {
                buf = (IByteBuffer)messages[0];
            }
            else
            {
                // Coalesce the batch into one buffer so SslStream receives a single write.
                buf = context.Allocator.Buffer((int)this.pendingUnencryptedWrites.CurrentSize);
                foreach (IByteBuffer buffer in messages)
                {
                    buffer.ReadBytes(buf, buffer.ReadableBytes);
                    buffer.Release();
                }
            }

            buf.ReadBytes(this.sslStream, buf.ReadableBytes); // this leads to FinishWrap being called 0+ times
            buf.Release();
            // Forget the released buffer so the catch block cannot release it a second time.
            buf = null;

            TaskCompletionSource promise = this.pendingUnencryptedWrites.Remove();
            Task task = this.lastContextWriteTask;
            if (task != null)
            {
                task.LinkOutcome(promise);
                this.lastContextWriteTask = null;
            }
            else
            {
                promise.TryComplete();
            }
        }
    }
    catch (Exception ex)
    {
        buf?.SafeRelease();
        this.HandleFailure(ex);
        throw;
    }
}
// Receives bytes written synchronously by the SslStream (e.g. while Wrap pushes plaintext into it),
// copies them into a channel buffer and writes them, flushing only when flushMode is ForceFlush.
void FinishWrap(byte[] buffer, int offset, int count)
{
    // Mono (btls provider on Linux, possibly also the Apple provider) calls Write from another thread,
    // i.e. after WrapAndFlush has already flushed; ask it to post a flush instead.
    bool offLoopDuringWrap = this.flushMode == FlushMode.NoFlush && !this.capturedContext.Executor.InEventLoop;
    if (offLoopDuringWrap)
    {
        this.flushMode = FlushMode.PendingFlush;
    }

    IByteBuffer output;
    if (count != 0)
    {
        output = this.capturedContext.Allocator.Buffer(count);
        output.WriteBytes(buffer, offset, count);
    }
    else
    {
        output = Unpooled.Empty;
    }

    if (this.flushMode == FlushMode.ForceFlush)
    {
        this.lastContextWriteTask = this.capturedContext.WriteAndFlushAsync(output);
    }
    else
    {
        this.lastContextWriteTask = this.capturedContext.WriteAsync(output);
    }
}
// Receives bytes the SslStream writes through WriteAsync. The name suggests these are
// non-application-data records (e.g. handshake). They are wrapped without copying, written and
// flushed immediately, and another read is requested if needed.
Task FinishWrapNonAppDataAsync(byte[] buffer, int offset, int count)
{
    IByteBuffer wrapped = Unpooled.WrappedBuffer(buffer, offset, count);
    Task writeFuture = this.capturedContext.WriteAndFlushAsync(wrapped);
    this.ReadIfNeeded(this.capturedContext);
    return writeFuture;
}
/// <summary>
/// Marks the handler as closed, disposes the SslStream and closes the channel. The closed mark
/// lets IgnoreException suppress ObjectDisposedException raised afterwards.
/// </summary>
public override Task CloseAsync(IChannelHandlerContext context)
{
    this.closeFuture.TryComplete();
    this.sslStream.Dispose();
    Task closing = base.CloseAsync(context);
    return closing;
}
// Tears the handler down after an error: disposes both streams, drops a parked SslStream read,
// reports an unfinished handshake as failed, and fails every queued write with `cause`.
void HandleFailure(Exception cause)
{
    this.mediationStream.Dispose();
    try
    {
        this.sslStream.Dispose();
    }
    catch (Exception)
    {
        // Deliberately ignored: disposal is best-effort while failing the handler.
        // todo: consider debug-level logging here, as Netty does for "possible truncation attack"
        // errors (https://github.com/netty/netty/issues/1340).
    }

    this.pendingSslStreamReadBuffer?.SafeRelease();
    this.pendingSslStreamReadBuffer = null;
    this.pendingSslStreamReadFuture = null;

    this.NotifyHandshakeFailure(cause);
    this.pendingUnencryptedWrites.RemoveAndFailAll(cause);
}
// If the handshake has not completed yet, records it as failed, fires a failed
// TlsHandshakeCompletionEvent carrying `cause`, and closes the channel.
void NotifyHandshakeFailure(Exception cause)
{
    if (this.state.HasAny(TlsHandlerState.AuthenticationCompleted))
    {
        // Handshake outcome already recorded; nothing to report.
        return;
    }

    this.state = (this.state | TlsHandlerState.FailedAuthentication) & ~TlsHandlerState.Authenticating;
    this.capturedContext.FireUserEventTriggered(new TlsHandshakeCompletionEvent(cause));
    this.CloseAsync(this.capturedContext);
}
/// <summary>How FinishWrap handles flushing of the writes it issues.</summary>
enum FlushMode : byte
{
    /// <summary>Write only; WrapAndFlush flushes once after Wrap returns.</summary>
    NoFlush = 0,

    /// <summary>A flush is or will be posted to the IEventExecutor.</summary>
    PendingFlush = 1,

    /// <summary>FinishWrap writes and flushes each chunk immediately.</summary>
    ForceFlush = 2,
}
sealed class MediationStream : Stream | |
{ | |
// Handler that receives the bytes SslStream writes.
readonly TlsHandler owner;
// Guards the source-window fields: Mono may run BeginRead concurrently with ResetSource.
// Made readonly, since a lock object that can be reassigned breaks mutual exclusion.
readonly object sourceLock = new object();
// Unconsumed bytes carried over from earlier source windows (created lazily by ResetSource).
IByteBuffer ownBuffer;
// Current source window: bytes input[inputStartOffset + inputOffset, inputStartOffset + inputLength)
// are readable; inputLength grows through ExpandSource.
byte[] input;
int inputStartOffset;
int inputOffset;
int inputLength;
// A read SslStream is waiting on: its completion source and the destination buffer to fill.
TaskCompletionSource<int> readCompletionSource;
ArraySegment<byte> sslOwnedBuffer;
#if NETSTANDARD1_3
int readByteCount;
#else
SynchronousAsyncResult<int> syncReadResult;
AsyncCallback readCallback;
TaskCompletionSource writeCompletion;
AsyncCallback writeCallback;
#endif
/// <summary>Creates a stream that routes SslStream's writes to <paramref name="owner"/>.</summary>
/// <param name="owner">The handler this stream belongs to.</param>
public MediationStream(TlsHandler owner)
{
    this.owner = owner;
}
/// <summary>Bytes available to SslStream: carried-over bytes plus the unread part of the source window.</summary>
public int TotalReadableBytes
{
    get
    {
        IByteBuffer carried = this.ownBuffer;
        int carriedBytes = carried == null ? 0 : carried.ReadableBytes;
        return carriedBytes + this.SourceReadableBytes;
    }
}

/// <summary>Unread bytes in the current source window.</summary>
public int SourceReadableBytes
{
    get { return this.inputLength - this.inputOffset; }
}
/// <summary>
/// Points the stream at a new source array starting at <paramref name="offset"/>, after preserving
/// unread bytes of the previous one. The window starts empty; ExpandSource exposes bytes.
/// </summary>
public void SetSource(byte[] source, int offset, IByteBufferAllocator alloc)
{
    lock (this.sourceLock)
    {
        this.ResetSource(alloc);

        this.input = source;
        this.inputStartOffset = offset;
        this.inputOffset = 0;
        this.inputLength = 0;
    }
}
/// <summary>
/// Detaches the current source array. Unread bytes are first copied into the stream's own buffer;
/// the carry-over buffer is released once it is fully consumed.
/// </summary>
public void ResetSource(IByteBufferAllocator alloc)
{
    // Mono will run BeginRead in async and it's running with ResetSource at the same time.
    lock (this.sourceLock)
    {
        int remaining = this.SourceReadableBytes;
        IByteBuffer carry = this.ownBuffer;
        if (remaining <= 0)
        {
            // Nothing left in the window: tidy up the carry-over buffer, if any.
            if (carry != null)
            {
                if (carry.IsReadable())
                {
                    carry.DiscardSomeReadBytes();
                }
                else
                {
                    carry.SafeRelease();
                    this.ownBuffer = null;
                }
            }
        }
        else
        {
            // Keep the unread tail alive beyond the source array's lifetime.
            if (carry == null)
            {
                carry = alloc.Buffer(remaining);
                this.ownBuffer = carry;
            }
            else
            {
                carry.DiscardSomeReadBytes();
            }
            carry.WriteBytes(this.input, this.inputStartOffset + this.inputOffset, remaining);
        }

        this.input = null;
        this.inputStartOffset = 0;
        this.inputOffset = 0;
        this.inputLength = 0;
    }
}
/// <summary>
/// Makes the next <paramref name="count"/> bytes of the current source array readable. If SslStream
/// has a read parked (sslOwnedBuffer set by ReadAsync/BeginRead), that read is filled from the
/// input right away, and its completion source (and, outside NETSTANDARD1_3, its callback) is completed.
/// </summary>
/// <param name="count">Number of additional source bytes to expose.</param>
public void ExpandSource(int count) | |
{ | |
Contract.Assert(this.input != null); | |
lock (sourceLock) | |
{ | |
this.inputLength += count; | |
ArraySegment<byte> sslBuffer = this.sslOwnedBuffer; | |
if (sslBuffer.Array == null) | |
{ | |
// there is no pending read operation - keep for future | |
return; | |
} | |
// Consume the parked read request; it is completed below.
this.sslOwnedBuffer = default(ArraySegment<byte>); | |
#if NETSTANDARD1_3 | |
this.readByteCount = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); | |
// hack: this tricks SslStream's continuation to run synchronously instead of dispatching to TP. Remove once Begin/EndRead are available. | |
new Task( | |
ms => | |
{ | |
var self = (MediationStream)ms; | |
TaskCompletionSource<int> p = self.readCompletionSource; | |
self.readCompletionSource = null; | |
p.TrySetResult(self.readByteCount); | |
}, | |
this) | |
.RunSynchronously(TaskScheduler.Default); | |
#else | |
int read = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); | |
TaskCompletionSource<int> promise = this.readCompletionSource; | |
this.readCompletionSource = null; | |
promise.TrySetResult(read); | |
// Complete the APM pattern: BeginRead stored the callback and returned promise.Task as the IAsyncResult.
AsyncCallback callback = this.readCallback; | |
this.readCallback = null; | |
callback?.Invoke(promise.Task); | |
#endif | |
} | |
} | |
// Synchronous read over the async path. GetAwaiter().GetResult() rethrows the original exception
// instead of the AggregateException wrapper that Task.Result produces, matching Stream.Read's contract.
public override int Read(byte[] buffer, int offset, int count) => this.ReadAsync(buffer, offset, count).GetAwaiter().GetResult();
#if NETSTANDARD1_3 | |
/// <summary>
/// Reads available input synchronously. When none is available, parks SslStream's buffer and
/// returns a task that ExpandSource completes. The cancellation token is not observed.
/// </summary>
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
    if (this.TotalReadableBytes <= 0)
    {
        Contract.Assert(this.sslOwnedBuffer.Array == null);
        this.sslOwnedBuffer = new ArraySegment<byte>(buffer, offset, count);
        this.readCompletionSource = new TaskCompletionSource<int>();
        return this.readCompletionSource.Task;
    }

    int bytesRead = this.ReadFromInput(buffer, offset, count);
    return Task.FromResult(bytesRead);
}
#else | |
/// <summary>
/// APM read. It completes inline (invoking <paramref name="callback"/> immediately) when input is
/// available. Otherwise it parks SslStream's buffer; ExpandSource later fills it and invokes the callback.
/// </summary>
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
    if (this.TotalReadableBytes <= 0)
    {
        Contract.Assert(this.sslOwnedBuffer.Array == null);
        this.sslOwnedBuffer = new ArraySegment<byte>(buffer, offset, count);
        this.readCompletionSource = new TaskCompletionSource<int>(state);
        this.readCallback = callback;
        return this.readCompletionSource.Task;
    }

    int bytesRead = this.ReadFromInput(buffer, offset, count);
    IAsyncResult completed = this.PrepareSyncReadResult(bytesRead, state);
    callback?.Invoke(completed);
    return completed;
}
/// <summary>
/// Completes a BeginRead. It returns the cached synchronous result directly; otherwise it waits on
/// the read task and rethrows the original exception instead of the AggregateException wrapper.
/// </summary>
public override int EndRead(IAsyncResult asyncResult)
{
    SynchronousAsyncResult<int> cachedSyncResult = this.syncReadResult;
    if (ReferenceEquals(asyncResult, cachedSyncResult))
    {
        return cachedSyncResult.Result;
    }

    Debug.Assert(this.readCompletionSource == null || this.readCompletionSource.Task == asyncResult);
    var readTask = (Task<int>)asyncResult;
    Contract.Assert(!readTask.IsCanceled);
    try
    {
        return readTask.Result;
    }
    catch (AggregateException aggregate)
    {
        ExceptionDispatchInfo.Capture(aggregate.InnerException).Throw();
        throw; // unreachable
    }
}
// Returns the cached synchronous IAsyncResult, filled with this read's byte count and state.
// Reusing one instance is safe: callers have no handle through which they could keep it alive.
IAsyncResult PrepareSyncReadResult(int readBytes, object state)
{
    SynchronousAsyncResult<int> result = this.syncReadResult;
    if (result == null)
    {
        result = new SynchronousAsyncResult<int>();
        this.syncReadResult = result;
    }
    result.Result = readBytes;
    result.AsyncState = state;
    return result;
}
#endif | |
public override void Write(byte[] buffer, int offset, int count) => this.owner.FinishWrap(buffer, offset, count); | |
/// <summary>Asynchronous write: delegates to the owner's FinishWrapNonAppDataAsync.</summary>
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
    // Note: cancellationToken is not consulted; the owner's task alone determines completion.
    return this.owner.FinishWrapNonAppDataAsync(buffer, offset, count);
}
#if !NETSTANDARD1_3 | |
// Cached delegate so each pending write does not allocate a new continuation delegate.
static readonly Action<Task, object> WriteCompleteCallback = HandleChannelWriteComplete;

/// <summary>
/// APM write entry point. A write that already completed successfully is reported synchronously;
/// otherwise a proxy task is returned and settled later by HandleChannelWriteComplete.
/// </summary>
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
    Task writeTask = this.WriteAsync(buffer, offset, count);
    if (writeTask.Status == TaskStatus.RanToCompletion)
    {
        // write+flush finished inline and successfully: no need for a proxy task.
        var completed = new SynchronousAsyncResult<int> { AsyncState = state };
        callback?.Invoke(completed);
        return completed;
    }

    // Pending, faulted or canceled: the continuation transfers the outcome to the proxy.
    this.writeCallback = callback;
    Contract.Assert(this.writeCompletion == null);
    var proxy = new TaskCompletionSource(state);
    this.writeCompletion = proxy;
    writeTask.ContinueWith(WriteCompleteCallback, this, TaskContinuationOptions.ExecuteSynchronously);
    return proxy.Task;
}
/// <summary>
/// Continuation attached by BeginWrite to a pending write task. Copies the write's outcome
/// (success, cancellation or failure) onto the proxy promise that BeginWrite returned, then
/// invokes the APM callback that BeginWrite captured.
/// </summary>
/// <param name="writeTask">The completed task produced by WriteAsync.</param>
/// <param name="state">The owning MediationStream (passed as continuation state by BeginWrite).</param>
static void HandleChannelWriteComplete(Task writeTask, object state)
{
    var self = (MediationStream)state;

    // Clear the fields before completing, so a BeginWrite issued from inside the callback
    // sees writeCompletion == null (BeginWrite asserts this).
    AsyncCallback callback = self.writeCallback;
    self.writeCallback = null;
    var promise = self.writeCompletion;
    self.writeCompletion = null;

    switch (writeTask.Status)
    {
        case TaskStatus.RanToCompletion:
            promise.TryComplete();
            break;
        case TaskStatus.Canceled:
            promise.TrySetCanceled();
            break;
        case TaskStatus.Faulted:
            // writeTask.Exception is already an AggregateException. Passing it directly would nest it
            // inside a second AggregateException, and EndWrite's unwrap would surface the wrapper
            // instead of the real failure. Forward the inner exceptions instead.
            promise.TrySetException(writeTask.Exception.InnerExceptions);
            break;
        default:
            // ContinueWith only runs once the antecedent has completed, so this should be unreachable.
            // Use the (paramName, actualValue, message) overload: the single-string constructor
            // treats its argument as the parameter name, not the message.
            throw new ArgumentOutOfRangeException(nameof(writeTask), writeTask.Status, "Unexpected task status: " + writeTask.Status);
    }

    callback?.Invoke(promise.Task);
}
/// <summary>
/// Completes a write started by BeginWrite, rethrowing any failure without its AggregateException wrapper.
/// </summary>
public override void EndWrite(IAsyncResult asyncResult)
{
    if (asyncResult is SynchronousAsyncResult<int>)
    {
        // BeginWrite already finished this write successfully; nothing to wait on.
        return;
    }

    Debug.Assert(this.writeCompletion == null || this.writeCompletion.Task == asyncResult);
    var pendingWrite = (Task<int>)asyncResult;
    try
    {
        pendingWrite.Wait();
    }
    catch (AggregateException aggregate)
    {
        ExceptionDispatchInfo.Capture(aggregate.InnerException).Throw();
        throw; // unreachable: Throw() never returns
    }
}
#endif | |
/// <summary>
/// Copies up to <paramref name="destinationCapacity"/> bytes into <paramref name="destination"/>,
/// starting at <paramref name="destinationOffset"/>. Bytes readable in <c>ownBuffer</c> are taken
/// first; any remaining room is filled from the current <c>input</c> array window, advancing
/// <c>inputOffset</c>.
/// </summary>
/// <returns>The total number of bytes copied.</returns>
int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapacity)
{
    Contract.Assert(destination != null);
    lock (sourceLock)
    {
        int length = 0;

        IByteBuffer buf = this.ownBuffer;
        if (buf != null && buf.ReadableBytes > 0)
        {
            int fromOwn = Math.Min(buf.ReadableBytes, destinationCapacity);
            buf.ReadBytes(destination, destinationOffset, fromOwn);
            length += fromOwn;
            // Advance the write position; otherwise the copy from `input` below would overwrite
            // the bytes just taken from ownBuffer.
            destinationOffset += fromOwn;
            destinationCapacity -= fromOwn;
            if (destinationCapacity == 0)
            {
                return length;
            }
        }

        byte[] source = this.input;
        if (source != null)
        {
            int fromInput = this.SourceReadableBytes;
            if (fromInput > 0)
            {
                fromInput = Math.Min(fromInput, destinationCapacity);
                Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset, fromInput);
                length += fromInput;
                this.inputOffset += fromInput;
            }
        }

        return length;
    }
}
/// <summary>Intentionally a no-op; SslStream.Close invokes Flush on this stream.</summary>
public override void Flush()
{
}
/// <summary>
/// Disposes the stream. On an explicit dispose, any read still pending from BeginRead is completed
/// with 0 bytes so the waiting reader is released.
/// </summary>
protected override void Dispose(bool disposing)
{
    base.Dispose(disposing);
    if (!disposing)
    {
        return;
    }

    TaskCompletionSource<int> pendingRead = this.readCompletionSource;
    this.readCompletionSource = null;
    pendingRead?.TrySetResult(0);
}
#region plumbing | |
/// <summary>Seeking is not supported by this stream.</summary>
public override long Seek(long offset, SeekOrigin origin) { throw new NotSupportedException(); }
/// <summary>Resizing is not supported by this stream.</summary>
public override void SetLength(long value) { throw new NotSupportedException(); }
// Readable and writable, but not seekable.
public override bool CanRead
{
    get { return true; }
}
public override bool CanSeek
{
    get { return false; }
}
public override bool CanWrite
{
    get { return true; }
}
/// <summary>Not supported: this stream has no length.</summary>
public override long Length
{
    get
    {
        throw new NotSupportedException();
    }
}
/// <summary>Not supported in either direction: this stream is not seekable.</summary>
public override long Position
{
    get
    {
        throw new NotSupportedException();
    }
    set
    {
        throw new NotSupportedException();
    }
}
#endregion | |
#region sync result | |
/// <summary>
/// An <see cref="IAsyncResult"/> for operations that finished before Begin* returned.
/// It is always completed synchronously and exposes no wait handle.
/// </summary>
sealed class SynchronousAsyncResult<T> : IAsyncResult
{
    /// <summary>Value produced by the completed operation.</summary>
    public T Result { get; set; }

    /// <summary>Caller-supplied state passed to Begin*.</summary>
    public object AsyncState { get; set; }

    public bool IsCompleted
    {
        get { return true; }
    }

    public bool CompletedSynchronously
    {
        get { return true; }
    }

    /// <summary>Always throws: there is nothing to wait for.</summary>
    public WaitHandle AsyncWaitHandle
    {
        get { throw new InvalidOperationException("Cannot wait on a synchronous result."); }
    }
}
#endregion | |
} | |
} | |
/// <summary>
/// Bit flags describing the handler's TLS state. Combine values with <c>|</c> and test them
/// with the Has/HasAny extension methods.
/// </summary>
[Flags]
enum TlsHandlerState
{
    /// <summary>No flags set (the default value of the enum).</summary>
    None = 0,
    Authenticating = 1,
    Authenticated = 1 << 1,
    FailedAuthentication = 1 << 2,
    ReadRequestedBeforeAuthenticated = 1 << 3,
    FlushedBeforeHandshake = 1 << 4,
    /// <summary>Union of Authenticating, Authenticated and FailedAuthentication.</summary>
    AuthenticationStarted = Authenticating | Authenticated | FailedAuthentication,
    /// <summary>Union of Authenticated and FailedAuthentication.</summary>
    AuthenticationCompleted = Authenticated | FailedAuthentication
}
/// <summary>
/// Flag-test helpers for <see cref="TlsHandlerState"/>.
/// </summary>
static class TlsHandlerStateExtensions
{
    /// <summary>
    /// True when every flag set in <paramref name="testValue"/> is also set in <paramref name="value"/>
    /// (trivially true when <paramref name="testValue"/> has no flags).
    /// </summary>
    public static bool Has(this TlsHandlerState value, TlsHandlerState testValue)
    {
        TlsHandlerState overlap = value & testValue;
        return overlap == testValue;
    }

    /// <summary>
    /// True when at least one flag of <paramref name="testValue"/> is set in <paramref name="value"/>.
    /// </summary>
    public static bool HasAny(this TlsHandlerState value, TlsHandlerState testValue)
    {
        TlsHandlerState overlap = value & testValue;
        return overlap != 0;
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright (c) Microsoft. All rights reserved. | |
// Licensed under the MIT license. See LICENSE file in the project root for full license information. | |
namespace CustomMonoTlsHandler | |
{ | |
using System; | |
using System.Diagnostics.Contracts; | |
/// <summary>
/// User event fired through the pipeline when a TLS handshake finishes. It carries either
/// success (<see cref="Success"/>) or the exception that made the handshake fail.
/// </summary>
public sealed class TlsHandshakeCompletionEvent
{
    /// <summary>Shared instance signalling a successful handshake.</summary>
    public static readonly TlsHandshakeCompletionEvent Success = new TlsHandshakeCompletionEvent();

    readonly Exception exception;

    /// <summary>
    /// Creates a new event that indicates a successful handshake.
    /// </summary>
    TlsHandshakeCompletionEvent()
    {
        this.exception = null;
    }

    /// <summary>
    /// Creates a new event that indicates an unsuccessful handshake.
    /// Use <see cref="Success"/> to indicate a successful handshake.
    /// </summary>
    /// <param name="exception">The cause of the failure; must not be <c>null</c>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="exception"/> is <c>null</c>.</exception>
    public TlsHandshakeCompletionEvent(Exception exception)
    {
        // Contract.Requires is compiled out unless CONTRACTS_FULL is defined. It would let a null cause
        // through and produce a "failure" event whose IsSuccessful is true, so check explicitly.
        if (exception == null)
        {
            throw new ArgumentNullException(nameof(exception));
        }

        this.exception = exception;
    }

    /// <summary>
    /// <c>true</c> if the handshake was successful.
    /// </summary>
    public bool IsSuccessful => this.exception == null;

    /// <summary>
    /// The failure cause when <see cref="IsSuccessful"/> is <c>false</c>; otherwise <c>null</c>.
    /// </summary>
    public Exception Exception => this.exception;

    public override string ToString()
    {
        Exception ex = this.Exception;
        return ex == null ? "TlsHandshakeCompletionEvent(SUCCESS)" : $"TlsHandshakeCompletionEvent({ex})";
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright (c) Microsoft. All rights reserved. | |
// Licensed under the MIT license. See LICENSE file in the project root for full license information. | |
namespace CustomMonoTlsHandler | |
{ | |
using System; | |
using DotNetty.Buffers; | |
using DotNetty.Transport.Channels; | |
/// <summary>
/// Utilities for TLS packets.
/// </summary>
static class TlsUtils
{
    const int MAX_PLAINTEXT_LENGTH = 16 * 1024; // 2^14
    const int MAX_COMPRESSED_LENGTH = MAX_PLAINTEXT_LENGTH + 1024;
    const int MAX_CIPHERTEXT_LENGTH = MAX_COMPRESSED_LENGTH + 1024;

    /// <summary>
    /// Largest possible encrypted record:
    /// Header (5) + Data (2^14) + Compression (1024) + Encryption (1024) + MAC (20) + Padding (256).
    /// </summary>
    public const int MAX_ENCRYPTED_PACKET_LENGTH = MAX_CIPHERTEXT_LENGTH + 5 + 20 + 256;

    /// <summary>Record content type: change cipher spec.</summary>
    public const int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20;

    /// <summary>Record content type: alert.</summary>
    public const int SSL_CONTENT_TYPE_ALERT = 21;

    /// <summary>Record content type: handshake.</summary>
    public const int SSL_CONTENT_TYPE_HANDSHAKE = 22;

    /// <summary>Record content type: application data.</summary>
    public const int SSL_CONTENT_TYPE_APPLICATION_DATA = 23;

    /// <summary>Length of the SSL/TLS record header, in bytes.</summary>
    public const int SSL_RECORD_HEADER_LENGTH = 5;

    /// <summary>Marker value: not enough data in the buffer to parse the record length.</summary>
    public const int NOT_ENOUGH_DATA = -1;

    /// <summary>Marker value: the data is not encrypted.</summary>
    /// <remarks>
    /// <see cref="GetEncryptedPacketLength"/> does not return this value; it reports non-TLS data as <c>-1</c>.
    /// </remarks>
    public const int NOT_ENCRYPTED = -2;

    /// <summary>
    /// Returns the total length, header included, of the SSLv3/TLS record that starts at
    /// <paramref name="offset"/>. Only absolute getters are used, so the reader index of
    /// <paramref name="buffer"/> is not changed.
    /// </summary>
    /// <param name="buffer">
    /// The buffer to read from. It must hold at least <see cref="SSL_RECORD_HEADER_LENGTH"/> bytes starting at
    /// <paramref name="offset"/>; otherwise reading the header fails with whatever exception the buffer's
    /// accessors throw for an out-of-range index.
    /// </param>
    /// <param name="offset">Offset of the record start.</param>
    /// <returns>
    /// The record length including the 5-byte header, or <c>-1</c> when the bytes do not form an SSLv3/TLS
    /// record. That covers an unknown content type, a major version other than 3, or a zero-length record.
    /// Note that the literal <c>-1</c> is returned, not <see cref="NOT_ENCRYPTED"/>.
    /// </returns>
    public static int GetEncryptedPacketLength(IByteBuffer buffer, int offset)
    {
        // Byte 0: ContentType.
        switch (buffer.GetByte(offset))
        {
            case SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
            case SSL_CONTENT_TYPE_ALERT:
            case SSL_CONTENT_TYPE_HANDSHAKE:
            case SSL_CONTENT_TYPE_APPLICATION_DATA:
                break;
            default:
                // SSLv2 or bad data
                return -1;
        }

        // Byte 1: ProtocolVersion major. It is 3 for SSLv3 and all TLS versions; anything else is SSLv2 or bad data.
        int majorVersion = buffer.GetByte(offset + 1);
        if (majorVersion != 3)
        {
            return -1;
        }

        // Bytes 3-4: the record payload length. Add the header to get the full packet size.
        int packetLength = buffer.GetUnsignedShort(offset + 3) + SSL_RECORD_HEADER_LENGTH;

        // A record with an empty payload is treated as SSLv2 or bad data.
        return packetLength <= SSL_RECORD_HEADER_LENGTH ? -1 : packetLength;
    }

    /// <summary>
    /// Flushes pending output, fires a <see cref="TlsHandshakeCompletionEvent"/> carrying
    /// <paramref name="cause"/>, then starts closing the channel. The close task is not awaited.
    /// </summary>
    public static void NotifyHandshakeFailure(IChannelHandlerContext ctx, Exception cause)
    {
        // Some data may have been written before the exception was thrown, so always flush.
        // See https://github.com/netty/netty/issues/3900#issuecomment-172481830
        ctx.Flush();
        ctx.FireUserEventTriggered(new TlsHandshakeCompletionEvent(cause));
        ctx.CloseAsync();
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment