Skip to content

Instantly share code, notes, and snippets.

Created May 5, 2022 17:12
Show Gist options
  • Save davidglassborow/86cd15bca3ac908667824e347e355b00 to your computer and use it in GitHub Desktop.
Save davidglassborow/86cd15bca3ac908667824e347e355b00 to your computer and use it in GitHub Desktop.
Generic TCP proxy from the Microsoft SQL Server Management Objects (SMO) project
using System;
using System.Data.SqlClient;
using System.Net.Sockets;
using System.Net;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.SqlServer.SmoSamples
{
    /// <summary>
    /// Provides an in-memory proxy with callbacks that allow tests to run code before transmission and after receipt of
    /// data on the wire
    /// </summary>
    class GenericSqlProxy : IDisposable
    {
        // We pick a buffer size that's large enough to hold most single replies so we don't over-inject latency
        private const int BufferSizeBytes = 128 * 1024;

        // Default SQL Server TCP port, used when the data source does not name one.
        private const int DefaultSqlPort = 1433;

        readonly string connectionString;

        // Set by Dispose; volatile because the accept and forwarding loops poll it from other threads.
        volatile bool disposed;
        private TcpListener listener = null;

        // Cancelled by Dispose to unblock pending accepts, connects and reads.
        // It is deliberately never disposed: the background loops may still read its Token after cancellation.
        private readonly CancellationTokenSource tokenSource = new CancellationTokenSource();

        /// <summary>
        /// Constructs a GenericSqlProxy for the local default sql instance
        /// </summary>
        public GenericSqlProxy() : this("Data Source=.")
        {
            // A bare "." is not a key=value connection string, so the data source keyword is spelled out here.
        }

        /// <summary>
        /// Construct a new GenericSqlProxy for the given connection string
        /// </summary>
        /// <param name="connectionString">Connection string whose data source identifies the server to proxy</param>
        public GenericSqlProxy(string connectionString)
        {
            this.connectionString = connectionString;
        }

        /// <summary>
        /// The local port the proxy listens on. Assigned by <see cref="Initialize"/>.
        /// </summary>
        public int Port { get; private set; }

        /// <summary>
        /// Initializes the proxy by opening the TCP listener and copying data between client and server
        /// </summary>
        /// <param name="localPort">local port number to use. 0 will use a random port</param>
        /// <returns>The connection string to use for the SqlConnection</returns>
        /// <exception cref="InvalidOperationException">The data source contains more than one comma</exception>
        public string Initialize(int localPort = 0)
        {
            var builder = new SqlConnectionStringBuilder(connectionString);
            GetTcpInfoFromDataSource(builder.DataSource, out string hostName, out int port);

            listener = new TcpListener(IPAddress.Loopback, localPort);
            listener.Server.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.DontLinger, true);
            listener.Server.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);

            // The listener must be started before LocalEndpoint reports the bound port (relevant when
            // localPort is 0) and before connections can be accepted.
            listener.Start();
            Port = ((IPEndPoint) listener.LocalEndpoint).Port;
            Trace.TraceInformation($"Starting TcpListener on port {Port}");

            // Accept connections on a background task so the caller gets the connection string immediately.
            Task.Factory.StartNew(() => { AsyncInit(listener, hostName, port); });

            // Same settings as the input connection string, with the data source redirected to the proxy port.
            return new SqlConnectionStringBuilder(builder.ConnectionString)
            {
                DataSource = $"tcp:,{Port}"
            }.ConnectionString;
        }

        /// <summary>
        /// Accept loop: for each accepted client, connects to the real server and starts one forwarding
        /// task per direction. Runs until the proxy is disposed.
        /// </summary>
        /// <param name="tcpListener">The started listener to accept clients from</param>
        /// <param name="hostName">Host name of the server being proxied</param>
        /// <param name="port">TCP port of the server being proxied</param>
        private void AsyncInit(TcpListener tcpListener, string hostName, int port)
        {
            while (!disposed)
            {
                var accept = tcpListener.AcceptTcpClientAsync();
                try
                {
                    // Wait with a timeout so the loop re-checks the disposed flag at least once per second.
                    if (accept.Wait(1000, tokenSource.Token) && !tokenSource.IsCancellationRequested)
                    {
                        var localClient = accept.GetAwaiter().GetResult();
                        OnConnect?.Invoke(this, new ProxyConnectionEventArgs(localClient));
                        var remoteClient = new TcpClient() {NoDelay = true};

                        // Closing both sockets on cancellation unblocks any read or write in the forwarding loops.
                        tokenSource.Token.Register(() =>
                        {
                            localClient.Dispose();
                            remoteClient.Dispose();
                        });

                        remoteClient.ConnectAsync(hostName, port).Wait(tokenSource.Token);
                        if (!tokenSource.IsCancellationRequested)
                        {
                            Task.Factory.StartNew(() => { ForwardToSql(localClient, remoteClient); });
                            Task.Factory.StartNew(() => { ForwardToClient(localClient, remoteClient); });
                        }
                    }
                }
                catch (OperationCanceledException)
                {
                    Trace.TraceInformation("AsyncInit aborted due to cancellation token set");
                }
            }
        }

        /// <summary>
        /// Fires before the proxy writes a buffer to the host
        /// </summary>
        public event EventHandler<StreamWriteEventArgs> OnWriteHost;

        /// <summary>
        /// Fires before the proxy writes a buffer to the client
        /// </summary>
        public event EventHandler<StreamWriteEventArgs> OnWriteClient;

        /// <summary>
        /// Fires when a new connection to the proxy's port is accepted
        /// </summary>
        public event EventHandler<ProxyConnectionEventArgs> OnConnect;

        /// <summary>
        /// Copies data from the proxy's client to the server, raising <see cref="OnWriteHost"/> before each write.
        /// </summary>
        /// <param name="ourClient">Connection accepted from the local client</param>
        /// <param name="sqlClient">Connection to the server being proxied</param>
        private void ForwardToSql(TcpClient ourClient, TcpClient sqlClient)
        {
            Forward(ourClient, sqlClient, args => OnWriteHost?.Invoke(this, args));
            Trace.TraceInformation("ForwardToSql exiting");
        }

        /// <summary>
        /// Copies data from the server to the proxy's client, raising <see cref="OnWriteClient"/> before each write.
        /// </summary>
        /// <param name="ourClient">Connection accepted from the local client</param>
        /// <param name="sqlClient">Connection to the server being proxied</param>
        private void ForwardToClient(TcpClient ourClient, TcpClient sqlClient)
        {
            Forward(sqlClient, ourClient, args => OnWriteClient?.Invoke(this, args));
            Trace.TraceInformation("ForwardToClient exiting");
        }

        /// <summary>
        /// Shared copy loop for both directions. Reads from <paramref name="source"/> and writes to
        /// <paramref name="destination"/> until the proxy is disposed or the source reaches end-of-stream.
        /// </summary>
        /// <param name="source">Connection to read from</param>
        /// <param name="destination">Connection to write to</param>
        /// <param name="beforeWrite">Callback raised with each buffer just before it is written</param>
        private void Forward(TcpClient source, TcpClient destination, Action<StreamWriteEventArgs> beforeWrite)
        {
            long index = 0;
            try
            {
                while (!disposed)
                {
                    // A fresh buffer per read: event subscribers receive the array and may keep a reference to it.
                    byte[] buffer = new byte[BufferSizeBytes];
                    int bytesRead = source.GetStream().ReadAsync(buffer, 0, buffer.Length, tokenSource.Token).Result;

                    // Zero bytes means the peer closed its side; without this check the loop would spin
                    // on an endless stream of empty reads.
                    if (bytesRead == 0)
                    {
                        break;
                    }

                    if (!tokenSource.Token.IsCancellationRequested)
                    {
                        beforeWrite(new StreamWriteEventArgs(index++, buffer, bytesRead));
                        destination.GetStream().Write(buffer, 0, bytesRead);
                    }
                }
            }
            catch (Exception)
            {
                // Failures after disposal are expected because Dispose closes the sockets under the loop;
                // anything earlier is a real error and is surfaced on the forwarding task.
                if (!disposed)
                {
                    throw;
                }
            }
        }

        /// <summary>
        /// Splits a SqlClient data source of the form "[tcp:]host[,port]" into host name and port.
        /// </summary>
        /// <param name="dataSource">The data source value from the connection string</param>
        /// <param name="hostName">Receives the host name with any "tcp:" prefix removed</param>
        /// <param name="port">Receives the port, or 1433 when the data source does not specify one</param>
        /// <exception cref="InvalidOperationException">The data source contains more than one comma</exception>
        private static void GetTcpInfoFromDataSource(string dataSource, out string hostName, out int port)
        {
            string[] dataSourceParts = dataSource.Split(',');
            if (dataSourceParts.Length == 1)
            {
                port = DefaultSqlPort;
            }
            else if (dataSourceParts.Length == 2)
            {
                port = int.Parse(dataSourceParts[1]);
            }
            else
            {
                throw new InvalidOperationException("TCP Connection String not in correct format!");
            }

            hostName = dataSourceParts[0].Replace("tcp:", "");

            // "." and "(local)" are SQL Server aliases for the local machine, not resolvable host names,
            // so map them to something TcpClient can connect to.
            if (hostName == "." || string.Equals(hostName, "(local)", StringComparison.OrdinalIgnoreCase))
            {
                hostName = "localhost";
            }
        }

        /// <summary>
        /// Stops the proxy: cancels pending operations (which closes every proxied connection) and stops
        /// the listener. Safe to call more than once and before <see cref="Initialize"/>.
        /// </summary>
        public void Dispose()
        {
            if (disposed)
            {
                return;
            }

            disposed = true;

            // Cancelling runs the per-connection callbacks that dispose the client and server sockets.
            tokenSource.Cancel();
            Trace.TraceInformation("Disposing TcpListener on port {0}", Port);

            // listener is null when Initialize was never called.
            listener?.Stop();
        }
    }

    /// <summary>
    /// Event data describing a buffer the proxy is about to write to one side of a connection.
    /// </summary>
    public class StreamWriteEventArgs : EventArgs
    {
        /// <summary>
        /// Creates event data for one write.
        /// </summary>
        /// <param name="index">Zero-based sequence number of the write within its direction of a connection</param>
        /// <param name="buffer">The buffer holding the data; only the first <paramref name="bytesWritten"/> bytes are valid</param>
        /// <param name="bytesWritten">Number of bytes from the start of the buffer that will be written</param>
        public StreamWriteEventArgs(long index, byte[] buffer, int bytesWritten)
        {
            Index = index;
            Buffer = buffer;
            BytesWritten = bytesWritten;
        }

        /// <summary>Sequence number of this write, starting at 0 for each direction of each connection.</summary>
        public long Index;

        /// <summary>The full read buffer; it is larger than the data whenever fewer bytes were read than its size.</summary>
        public byte[] Buffer;

        /// <summary>Number of valid bytes at the start of <see cref="Buffer"/>.</summary>
        public int BytesWritten;
    }

    /// <summary>
    /// Event data for a client connection newly accepted by the proxy.
    /// </summary>
    public class ProxyConnectionEventArgs : EventArgs
    {
        /// <summary>
        /// Creates event data for an accepted connection.
        /// </summary>
        /// <param name="client">The accepted client connection</param>
        public ProxyConnectionEventArgs(TcpClient client)
        {
            Client = client;
        }

        /// <summary>The accepted client connection.</summary>
        public TcpClient Client;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment