Skip to content

Instantly share code, notes, and snippets.

@yalayabeeb
Last active May 2, 2017 04:59
Show Gist options
Save yalayabeeb/3492c67a5d9e9de48093 to your computer and use it in GitHub Desktop.
A simple wrapper around the Socket class, with a usage example. Updated with encryption and compression/decompression. Fixed a crash when sending data too quickly.
using System;
using System.IO;
using System.IO.Compression;
using System.Net;
using System.Net.Sockets;
using System.Runtime.Serialization.Formatters.Binary;
using System.Security.Cryptography;
using System.Text;
public class Sock
{
#region Server
public class Server
{
    #region Delegates
    public delegate void ClientAcceptedEventHandler(Sock.Client client);
    public delegate void DataReceivedEventHandler(Sock.Client client, object[] data);
    public delegate void ClientDisconnectedEventHandler(Sock.Client client);
    #endregion
    #region Events
    /// <summary>Raised after a connection is accepted, before its receive loop starts.</summary>
    public event ClientAcceptedEventHandler OnClientAccepted;
    /// <summary>Raised once for every complete length-prefixed packet received from a client.</summary>
    public event DataReceivedEventHandler OnDataReceived;
    /// <summary>Raised when a client closes or resets its connection, or sends data that cannot be decoded.</summary>
    public event ClientDisconnectedEventHandler OnClientDisconnect;
    #endregion
    /// <summary>The listening socket (or, when wrapped by <see cref="Sock.Client"/>, the connected socket).</summary>
    public Socket ServerSocket { get; set; }
    // Only its length is used: it sizes the read buffer allocated for each accepted client.
    private byte[] DataBuffer { get; set; }
    public SocketEncryption EncryptionSettings { get; set; } = new SocketEncryption();
    public string EncryptionKey { get; set; } = "key";
    public bool UseEncryption { get; set; } = false;
    public object Tag { get; set; }
    /// <summary>
    /// Size in bytes of the read buffer given to each accepted client.
    /// Assignments are ignored while <see cref="ServerSocket"/> reports itself connected.
    /// </summary>
    public int BufferSize
    {
        get
        {
            return DataBuffer.Length;
        }
        set
        {
            if (!ServerSocket.Connected)
                DataBuffer = new byte[value];
        }
    }
    public Server()
    {
        ServerSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        DataBuffer = new byte[1024];
    }
    public Server(Socket sock)
    {
        ServerSocket = sock;
        DataBuffer = new byte[1024];
    }
    public Server(Socket sock, int bufferSize)
    {
        ServerSocket = sock;
        DataBuffer = new byte[bufferSize];
    }
    #region Socket Functions
    /// <summary>Binds to <paramref name="port"/> on all interfaces and starts accepting with a backlog of 10.</summary>
    public void Start(int port)
    {
        Start(port, 10);
    }
    /// <summary>Binds to <paramref name="port"/> on all interfaces and starts accepting clients asynchronously.</summary>
    /// <param name="port">Local TCP port to listen on.</param>
    /// <param name="backlog">Maximum length of the pending-connections queue.</param>
    public void Start(int port, int backlog)
    {
        if (!ServerSocket.Connected)
        {
            ServerSocket.Bind(new IPEndPoint(IPAddress.Any, port));
            ServerSocket.Listen(backlog);
            ServerSocket.BeginAccept(new AsyncCallback(OnAccept), ServerSocket);
        }
    }
    #endregion
    #region CallBacks
    private void OnAccept(IAsyncResult ar)
    {
        Socket sock;
        try
        {
            sock = ServerSocket.EndAccept(ar);
        }
        catch (ObjectDisposedException)
        {
            // The listening socket was closed: the accept loop ends here.
            return;
        }
        catch (SocketException)
        {
            // The pending connection failed before it could be accepted; keep listening.
            AcceptNext();
            return;
        }
        try
        {
            Sock.Client client = new Sock.Client(sock, DataBuffer.Length);
            client.EncryptionKey = EncryptionKey;
            client.EncryptionSettings = EncryptionSettings;
            client.UseEncryption = UseEncryption;
            OnClientAccepted?.Invoke(client);
            // Each client gets its own read buffer and reassembly state, so concurrent
            // receives from different clients cannot overwrite each other's bytes.
            ReceiveNext(new ReceiveState(client, DataBuffer.Length));
        }
        finally
        {
            // Re-arm even if setting up this client failed, so the server keeps accepting.
            AcceptNext();
        }
    }
    private void OnReceive(IAsyncResult ar)
    {
        var state = (ReceiveState)ar.AsyncState;
        bool disconnected;
        bool dropConnection = false;
        try
        {
            int receivedLength = state.Client.ClientSocket.EndReceive(ar);
            if (receivedLength > 0)
            {
                state.Pending.Write(state.ReadBuffer, 0, receivedLength);
                DispatchCompletePackets(state);
            }
            // A zero-byte read means the remote side shut the connection down gracefully.
            disconnected = receivedLength == 0;
        }
        catch (ObjectDisposedException)
        {
            // The socket was closed locally; there is nothing left to receive.
            return;
        }
        catch (SocketException)
        {
            // e.g. SocketError.ConnectionReset ("An existing connection was forcibly closed by
            // the remote host"). Matched by type rather than by localized message text.
            disconnected = true;
        }
        catch (Exception)
        {
            // Bad length prefix, undecryptable/undeserializable payload, or a throwing
            // OnDataReceived handler. The stream framing can no longer be trusted, so end the
            // connection instead of leaving it silently stalled with no receive outstanding.
            disconnected = true;
            dropConnection = true;
        }

        if (!disconnected)
        {
            // Exactly one receive is outstanding per connection, so Pending is never shared.
            ReceiveNext(state);
            return;
        }
        try
        {
            OnClientDisconnect?.Invoke(state.Client);
        }
        finally
        {
            // Closed after the handler so it can still inspect the socket (e.g. RemoteEndPoint).
            if (dropConnection)
                state.Client.ClientSocket.Close();
        }
    }
    #endregion
    /// <summary>Queues the next asynchronous accept; a closed listener simply ends the loop.</summary>
    private void AcceptNext()
    {
        try
        {
            ServerSocket.BeginAccept(new AsyncCallback(OnAccept), ServerSocket);
        }
        catch (ObjectDisposedException)
        {
            // The listening socket has been closed.
        }
    }
    /// <summary>Queues the next asynchronous read for one client.</summary>
    private void ReceiveNext(ReceiveState state)
    {
        try
        {
            state.Client.ClientSocket.BeginReceive(state.ReadBuffer, 0, state.ReadBuffer.Length, SocketFlags.None, new AsyncCallback(OnReceive), state);
        }
        catch (ObjectDisposedException)
        {
            // The client socket was closed locally; nothing more to receive.
        }
        catch (SocketException)
        {
            OnClientDisconnect?.Invoke(state.Client);
        }
    }
    /// <summary>
    /// Extracts every complete packet from the client's reassembly stream, hands each to
    /// <see cref="HandlePacket"/>, and keeps any trailing partial header/body for the next read.
    /// Wire format (as written by <see cref="Sock.Client.Send"/>): Int32 little-endian length, then that many bytes.
    /// </summary>
    /// <exception cref="InvalidDataException">A negative length prefix was received.</exception>
    private void DispatchCompletePackets(ReceiveState state)
    {
        MemoryStream pending = state.Pending;
        // GetBuffer avoids copying; only bytes [0, pending.Length) are valid.
        byte[] data = pending.GetBuffer();
        int available = (int)pending.Length;
        int offset = 0;
        while (available - offset >= 4)
        {
            int packetLength = data[offset] | (data[offset + 1] << 8) | (data[offset + 2] << 16) | (data[offset + 3] << 24);
            if (packetLength < 0)
                throw new InvalidDataException("Received a negative packet length.");
            if (available - offset - 4 < packetLength)
                break; // Body not fully received yet; wait for more bytes.
            byte[] packet = new byte[packetLength];
            Buffer.BlockCopy(data, offset + 4, packet, 0, packetLength);
            offset += 4 + packetLength;
            HandlePacket(state.Client, packet);
        }
        if (offset > 0)
        {
            // Keep only the unconsumed tail for the next receive.
            byte[] remainder = new byte[available - offset];
            Buffer.BlockCopy(data, offset, remainder, 0, remainder.Length);
            pending.SetLength(0);
            pending.Write(remainder, 0, remainder.Length);
        }
    }
    /// <summary>Reverses the sender's pipeline (optional decrypt, decompress, deserialize) and raises <see cref="OnDataReceived"/>.</summary>
    private void HandlePacket(Sock.Client client, byte[] packet)
    {
        DataReceivedEventHandler handler = OnDataReceived;
        if (handler == null)
            return;
        byte[] payload = UseEncryption ? EncryptionSettings.Decrypt(packet, EncryptionKey) : packet;
        handler(client, DataFormatter.Deserialize<object[]>(Decompress(payload)));
    }
    /// <summary>Per-connection receive state: a private read buffer plus a stream holding not-yet-complete packet bytes.</summary>
    private sealed class ReceiveState
    {
        public ReceiveState(Sock.Client client, int bufferSize)
        {
            Client = client;
            ReadBuffer = new byte[bufferSize];
        }
        public Sock.Client Client { get; }
        public byte[] ReadBuffer { get; }
        public MemoryStream Pending { get; } = new MemoryStream();
    }
}
#endregion
#region Client
public class Client
{
    #region Delegates
    public delegate void ConnectedEventHandler(bool connected);
    public delegate void DataReceivedEventHandler(Sock.Server sender, object[] data);
    // NOTE(review): "Hanlder" is a typo, kept because the delegate type is public API.
    public delegate void ServerDisconnectedEventHanlder();
    #endregion
    #region Events
    /// <summary>Raised exactly once per <see cref="Connect"/> call with whether the socket is connected.</summary>
    public event ConnectedEventHandler OnConnect;
    /// <summary>Raised once for every complete length-prefixed packet received from the server.</summary>
    public event DataReceivedEventHandler OnDataReceived;
    /// <summary>Raised when the server closes or resets the connection, or sends data that cannot be decoded.</summary>
    public event ServerDisconnectedEventHanlder OnServerDisconnect;
    #endregion
    public Socket ClientSocket { get; set; }
    // Read buffer used by this client's receive loop.
    private byte[] DataBuffer { get; set; }
    public SocketEncryption EncryptionSettings { get; set; } = new SocketEncryption();
    public string EncryptionKey { get; set; } = "key";
    public bool UseEncryption { get; set; } = false;
    public object Tag { get; set; }
    /// <summary>Size in bytes of the read buffer. Assignments are ignored while connected.</summary>
    public int BufferSize
    {
        get
        {
            return DataBuffer.Length;
        }
        set
        {
            if (!ClientSocket.Connected)
                DataBuffer = new byte[value];
        }
    }
    public Client()
    {
        ClientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        DataBuffer = new byte[1024];
    }
    public Client(Socket sock)
    {
        ClientSocket = sock;
        DataBuffer = new byte[1024];
    }
    public Client(Socket sock, int bufferSize)
    {
        ClientSocket = sock;
        DataBuffer = new byte[bufferSize];
    }
    #region Socket Functions
    /// <summary>
    /// Connects synchronously to <paramref name="ip"/>:<paramref name="port"/>. On success, it starts the
    /// asynchronous receive loop. It then raises <see cref="OnConnect"/> once with the connection state.
    /// Connection failures (<see cref="SocketException"/>) are reported through the event, not thrown.
    /// </summary>
    public void Connect(IPAddress ip, int port)
    {
        bool connectSucceeded;
        try
        {
            ClientSocket.Connect(new IPEndPoint(ip, port));
            connectSucceeded = true;
        }
        catch (SocketException)
        {
            // e.g. "No connection could be made because the target machine actively refused it".
            // Reported once below; the original raised OnConnect(false) twice for this case.
            connectSucceeded = false;
        }
        // Only a successful Connect starts a receive loop; this avoids a second concurrent loop
        // when Connect is called again on an already-connected socket.
        if (connectSucceeded)
        {
            var sock = new Sock.Server(ClientSocket, DataBuffer.Length);
            ReceiveNext(new ReceiveState(sock, DataBuffer));
        }
        OnConnect?.Invoke(ClientSocket.Connected);
    }
    /// <summary>
    /// Serializes, compresses, optionally encrypts, and length-prefixes <paramref name="data"/>, then
    /// starts an asynchronous send.
    /// </summary>
    /// <returns><c>false</c> if the socket is not connected; otherwise <c>true</c> once the send is queued.</returns>
    public bool Send(object[] data)
    {
        if (!ClientSocket.Connected)
            return false;
        byte[] dataPacket = Compress(DataFormatter.Serialize(data));
        if (UseEncryption)
        {
            dataPacket = EncryptionSettings.Encrypt(dataPacket, EncryptionKey);
        }
        byte[] finalPacket;
        using (var ms = new MemoryStream(dataPacket.Length + 4))
        {
            using (var bw = new BinaryWriter(ms))
            {
                // BinaryWriter writes Int32 little-endian; the receive loop parses the prefix the same way.
                bw.Write(dataPacket.Length);
                bw.Write(dataPacket);
            }
            finalPacket = ms.ToArray();
        }
        ClientSocket.BeginSend(finalPacket, 0, finalPacket.Length, SocketFlags.None, new AsyncCallback(OnSend), ClientSocket);
        return true;
    }
    #endregion
    #region CallBacks
    private void OnReceive(IAsyncResult ar)
    {
        var state = (ReceiveState)ar.AsyncState;
        bool disconnected;
        bool dropConnection = false;
        try
        {
            int receivedLength = state.Sender.ServerSocket.EndReceive(ar);
            if (receivedLength > 0)
            {
                state.Pending.Write(state.ReadBuffer, 0, receivedLength);
                DispatchCompletePackets(state);
            }
            // A zero-byte read means the server shut the connection down gracefully.
            disconnected = receivedLength == 0;
        }
        catch (ObjectDisposedException)
        {
            // The socket was closed locally; there is nothing left to receive.
            return;
        }
        catch (SocketException)
        {
            // e.g. SocketError.ConnectionReset ("An existing connection was forcibly closed by
            // the remote host"). Matched by type rather than by localized message text.
            disconnected = true;
        }
        catch (Exception)
        {
            // Bad length prefix, undecryptable/undeserializable payload, or a throwing
            // OnDataReceived handler. The framing can no longer be trusted, so end the connection.
            disconnected = true;
            dropConnection = true;
        }

        if (!disconnected)
        {
            // Exactly one receive is outstanding at a time, so Pending is never shared.
            ReceiveNext(state);
            return;
        }
        try
        {
            OnServerDisconnect?.Invoke();
        }
        finally
        {
            if (dropConnection)
                state.Sender.ServerSocket.Close();
        }
    }
    private void OnSend(IAsyncResult ar)
    {
        var sock = (Socket)ar.AsyncState;
        try
        {
            sock.EndSend(ar);
        }
        catch (ObjectDisposedException)
        {
            // The socket was closed while the send was in flight.
        }
        catch (SocketException)
        {
            // The connection dropped mid-send. Exceptions thrown here would crash the process
            // (threadpool callback). The receive loop for this connection, if any, reports the disconnect.
        }
    }
    #endregion
    /// <summary>Queues the next asynchronous read from the server.</summary>
    private void ReceiveNext(ReceiveState state)
    {
        try
        {
            state.Sender.ServerSocket.BeginReceive(state.ReadBuffer, 0, state.ReadBuffer.Length, SocketFlags.None, new AsyncCallback(OnReceive), state);
        }
        catch (ObjectDisposedException)
        {
            // The socket was closed locally; nothing more to receive.
        }
        catch (SocketException)
        {
            OnServerDisconnect?.Invoke();
        }
    }
    /// <summary>
    /// Extracts every complete packet from the reassembly stream, hands each to
    /// <see cref="HandlePacket"/>, and keeps any trailing partial header/body for the next read.
    /// Wire format (see <see cref="Send"/>): Int32 little-endian length, then that many bytes.
    /// </summary>
    /// <exception cref="InvalidDataException">A negative length prefix was received.</exception>
    private void DispatchCompletePackets(ReceiveState state)
    {
        MemoryStream pending = state.Pending;
        // GetBuffer avoids copying; only bytes [0, pending.Length) are valid.
        byte[] data = pending.GetBuffer();
        int available = (int)pending.Length;
        int offset = 0;
        while (available - offset >= 4)
        {
            int packetLength = data[offset] | (data[offset + 1] << 8) | (data[offset + 2] << 16) | (data[offset + 3] << 24);
            if (packetLength < 0)
                throw new InvalidDataException("Received a negative packet length.");
            if (available - offset - 4 < packetLength)
                break; // Body not fully received yet; wait for more bytes.
            byte[] packet = new byte[packetLength];
            Buffer.BlockCopy(data, offset + 4, packet, 0, packetLength);
            offset += 4 + packetLength;
            HandlePacket(state.Sender, packet);
        }
        if (offset > 0)
        {
            // Keep only the unconsumed tail for the next receive.
            byte[] remainder = new byte[available - offset];
            Buffer.BlockCopy(data, offset, remainder, 0, remainder.Length);
            pending.SetLength(0);
            pending.Write(remainder, 0, remainder.Length);
        }
    }
    /// <summary>Reverses <see cref="Send"/>'s pipeline (optional decrypt, decompress, deserialize) and raises <see cref="OnDataReceived"/>.</summary>
    private void HandlePacket(Sock.Server sender, byte[] packet)
    {
        DataReceivedEventHandler handler = OnDataReceived;
        if (handler == null)
            return;
        byte[] payload = UseEncryption ? EncryptionSettings.Decrypt(packet, EncryptionKey) : packet;
        handler(sender, DataFormatter.Deserialize<object[]>(Decompress(payload)));
    }
    /// <summary>Receive-loop state: the sender wrapper passed to handlers, the read buffer, and not-yet-complete packet bytes.</summary>
    private sealed class ReceiveState
    {
        public ReceiveState(Sock.Server sender, byte[] readBuffer)
        {
            Sender = sender;
            ReadBuffer = readBuffer;
        }
        public Sock.Server Sender { get; }
        public byte[] ReadBuffer { get; }
        public MemoryStream Pending { get; } = new MemoryStream();
    }
}
#endregion
#region Data Formatting
// Converts object graphs to and from bytes for the wire.
// SECURITY NOTE(review): BinaryFormatter is used here on bytes that arrive from the network.
// Deserializing untrusted input with BinaryFormatter can lead to remote code execution, and the
// type is obsolete (removed in .NET 9). Consider a data-only format such as System.Text.Json.
private class DataFormatter
{
    /// <summary>Serializes <paramref name="obj"/> with BinaryFormatter and returns the bytes.</summary>
    public static byte[] Serialize(object obj)
    {
        var formatter = new BinaryFormatter();
        using (var output = new MemoryStream())
        {
            formatter.Serialize(output, obj);
            byte[] bytes = output.ToArray();
            return bytes;
        }
    }
    /// <summary>Deserializes <paramref name="arrBytes"/> with BinaryFormatter and casts the result to <typeparamref name="T"/>.</summary>
    public static T Deserialize<T>(byte[] arrBytes)
    {
        var formatter = new BinaryFormatter();
        int length = arrBytes.Length;
        using (var input = new MemoryStream(length))
        {
            input.Write(arrBytes, 0, length);
            input.Position = 0;
            object graph = formatter.Deserialize(input);
            return (T)graph;
        }
    }
}
#endregion
#region Encryption
/// <summary>Pluggable payload encryption; delegates to <see cref="Method"/> (AES by default).</summary>
public class SocketEncryption
{
    /// <summary>The algorithm used by <see cref="Encrypt"/> and <see cref="Decrypt"/>.</summary>
    public EncryptionMethod Method { get; set; }
    public SocketEncryption()
    {
        Method = new DefaultEncryption();
    }
    public SocketEncryption(EncryptionMethod method)
    {
        Method = method;
    }
    /// <summary>Returns a new random GUID string suitable for use as a key passphrase.</summary>
    public string GenerateKey()
    {
        return Guid.NewGuid().ToString();
    }
    public byte[] Encrypt(byte[] input, string key)
    {
        return Method.Encrypt(input, key);
    }
    public byte[] Decrypt(byte[] input, string key)
    {
        return Method.Decrypt(input, key);
    }
}
/// <summary>Symmetric transform over a byte payload keyed by a passphrase string.</summary>
public interface EncryptionMethod
{
    byte[] Encrypt(byte[] input, string key);
    byte[] Decrypt(byte[] input, string key);
}
/// <summary>
/// AES-128-CBC with PKCS7 padding. The key is the MD5 hash of the UTF-8 passphrase.
/// SECURITY NOTE(review): the IV equals the key, so identical plaintexts yield identical
/// ciphertexts. The KDF is an unsalted MD5, and there is no MAC, so ciphertexts are malleable.
/// Fixing this (random IV prepended, PBKDF2, HMAC/AES-GCM) changes the wire format and must be
/// rolled out to both peers together, so it is deliberately left as-is here.
/// </summary>
private class DefaultEncryption : EncryptionMethod
{
    public byte[] Encrypt(byte[] input, string key)
    {
        return Transform(input, key, true);
    }
    public byte[] Decrypt(byte[] input, string key)
    {
        return Transform(input, key, false);
    }
    // Shared implementation. Aes/MD5.Create() produce output byte-identical to the obsolete
    // RijndaelManaged (128-bit block) / MD5CryptoServiceProvider previously used.
    private static byte[] Transform(byte[] input, string key, bool encrypt)
    {
        byte[] keyBytes = Encoding.UTF8.GetBytes(key);
        byte[] aesKey;
        using (MD5 md5 = MD5.Create())
        {
            aesKey = md5.ComputeHash(keyBytes, 0, keyBytes.Length);
        }
        using (Aes aes = Aes.Create())
        {
            aes.Key = aesKey;
            aes.IV = aesKey;
            aes.Mode = CipherMode.CBC;
            aes.Padding = PaddingMode.PKCS7;
            // The transform is IDisposable and was previously leaked.
            using (ICryptoTransform transform = encrypt ? aes.CreateEncryptor() : aes.CreateDecryptor())
            using (var ms = new MemoryStream())
            {
                using (var cs = new CryptoStream(ms, transform, CryptoStreamMode.Write))
                {
                    cs.Write(input, 0, input.Length);
                }
                // MemoryStream.ToArray still works after CryptoStream has closed the stream.
                return ms.ToArray();
            }
        }
    }
}
#endregion
#region Compression / Decompression
/// <summary>Returns the GZip-compressed form of <paramref name="input"/>.</summary>
public static byte[] Compress(byte[] input)
{
    using (var compressed = new MemoryStream())
    {
        // leaveOpen: the GZip footer is flushed on dispose, after which the buffer is read.
        using (var zipper = new GZipStream(compressed, CompressionMode.Compress, true))
        {
            zipper.Write(input, 0, input.Length);
        }
        byte[] result = compressed.ToArray();
        return result;
    }
}
/// <summary>Inflates a GZip stream produced by <c>Compress</c> back to the original bytes.</summary>
public static byte[] Decompress(byte[] input)
{
    // GZipStream owns (and disposes) the source MemoryStream.
    using (var unzipper = new GZipStream(new MemoryStream(input), CompressionMode.Decompress))
    using (var inflated = new MemoryStream())
    {
        unzipper.CopyTo(inflated);
        return inflated.ToArray();
    }
}
#endregion
}
using System;
using System.Collections.Generic;
using System.Net;
namespace Client
{
    class Program
    {
        // Must match the port the demo server passes to Sock.Server.Start.
        private const int ServerPort = 100;

        // Server to connect to: the first command-line argument if given, otherwise 127.0.0.1.
        private static IPAddress s_serverAddress = IPAddress.Loopback;

        /// <summary>Usage: Client [server-ip]. Connects, sends one message, then waits for Enter.</summary>
        static void Main(string[] args)
        {
            if (args.Length > 0 && !IPAddress.TryParse(args[0], out s_serverAddress))
            {
                Console.WriteLine("Invalid IP address: {0}", args[0]);
                return;
            }
            SocketSetup();
            Console.ReadLine();
        }
        // Creates a fresh client, wires its events, connects and sends a greeting.
        private static void SocketSetup()
        {
            Sock.Client client = new Sock.Client();
            client.OnConnect += Client_OnConnect;
            client.OnDataReceived += Client_OnDataReceived;
            client.OnServerDisconnect += Client_OnServerDisconnect;
            client.Connect(s_serverAddress, ServerPort);
            string message = "Hello from Client";
            if (client.Send(new object[] { message }))
                Console.WriteLine("Message Sent: {0}\n", message);
        }
        private static void Client_OnConnect(bool connected)
        {
            Console.WriteLine("Connected: {0}\n", connected);
            if (!connected)
            {
                Console.WriteLine("Do you want to attempt to reconnect? (y/n)");
                string input = Console.ReadLine();
                if (string.Equals(input, "y"))
                    SocketSetup();
            }
        }
        private static void Client_OnDataReceived(Sock.Server sender, object[] data)
        {
            Console.WriteLine("Data Received: {0}", data[0]);
            Console.WriteLine("From Server: {0}\n", sender.ServerSocket.RemoteEndPoint.ToString());
        }
        private static void Client_OnServerDisconnect()
        {
            Console.WriteLine("Server Disconnected");
        }
    }
}
using System;
namespace Server
{
    class Program
    {
        /// <summary>Starts the demo echo-style server and runs until Enter is pressed.</summary>
        static void Main(string[] args)
        {
            Sock.Server server = new Sock.Server();
            server.OnClientAccepted += Server_ClientAccepted;
            server.OnDataReceived += Server_DataReceived;
            server.OnClientDisconnect += Server_OnClientDisconnect;
            // NOTE: binding a port below 1024 typically requires elevated privileges on Linux/macOS.
            server.Start(100);
            Console.ReadLine();
        }
        private static void Server_ClientAccepted(Sock.Client client)
        {
            Console.WriteLine("Connected: {0}\n", client.ClientSocket.Connected);
        }
        // Logs the incoming message and replies to the same client.
        private static void Server_DataReceived(Sock.Client client, object[] data)
        {
            Console.WriteLine("Data Received: {0}", data[0]);
            Console.WriteLine("From Client: {0}\n", client.ClientSocket.RemoteEndPoint.ToString());
            string messageToSend = "Reply from Server";
            // Send returns false when the client is no longer connected; only report what was sent
            // (matches the client demo's handling of Send's result).
            if (client.Send(new object[] { messageToSend }))
                Console.WriteLine("Message Sent: {0}", messageToSend);
        }
        private static void Server_OnClientDisconnect(Sock.Client client)
        {
            Console.WriteLine("\nClient Disconnected");
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment