Created
July 5, 2018 08:06
-
-
Save zhouchangxun/5750b4636cc070ac01385d89946e0a7b to your computer and use it in GitHub Desktop.
A simple load balancer implementation in Python.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import socket | |
import select | |
import random | |
from itertools import cycle | |
# Backend option 1: dumb netcat servers (short TCP connections), started with
#   $ ~ while true ; do nc -l 8888 < server1.html; done
#   $ ~ while true ; do nc -l 9999 < server2.html; done
SERVER_POOL = [
    ('10.157.0.238', 8888),
]

# Backend option 2: dumb python socket echo server (long TCP connection),
# started with
#   $ ~ while python server.py
# SERVER_POOL = [('localhost', 6666)]

# Endless iterator over SERVER_POOL; the 'round robin' algorithm draws from it.
ITER = cycle(SERVER_POOL)
def round_robin(iter):
    """Return the next item produced by the iterator *iter*.

    With a cycling iterator over [A, B, C, D], successive calls give
    A B C D A B C D A B C D ...
    """
    upcoming = next(iter)
    return upcoming
class LoadBalancer(object):
    """Socket implementation of a TCP load balancer.

    Flow Diagram:
    +---------------+      +-----------------------------------------+      +---------------+
    | client socket | <==> | client-side socket | server-side socket | <==> | server socket |
    |   <client>    |      |            < load balancer >            |      |   <server>    |
    +---------------+      +-----------------------------------------+      +---------------+

    Attributes:
        ip (str): virtual server's ip; client-side socket's ip
        port (int): virtual server's port; client-side socket's port
        algorithm (str): algorithm used to select a server
            ('random' or 'round robin')
        flow_table (dict): mapping of client socket obj <==> server-side
            socket obj (stored in both directions)
        sockets (list): current connected and open socket obj, including
            the listening client-side socket; this list is handed to select()
    """

    # Class-level defaults are kept so the names still exist on the class,
    # but __init__ rebinds fresh containers on every instance so that two
    # balancers never share connection state.
    flow_table = dict()
    sockets = list()

    def __init__(self, ip, port, algorithm='random'):
        """Create the listening (client-side) socket bound to (ip, port).

        Args:
            ip (str): address to listen on.
            port (int): port to listen on.
            algorithm (str): server selection algorithm, see select_server().

        Raises:
            socket.error: if the listening socket cannot be bound; the
                socket is closed before the error propagates.
        """
        self.ip = ip
        self.port = port
        self.algorithm = algorithm
        self.flow_table = {}
        self.sockets = []

        # init a client-side socket
        self.cs_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # the SO_REUSEADDR flag tells the kernel to reuse a local socket
            # in TIME_WAIT state, without waiting for its natural timeout
            # to expire.
            self.cs_socket.setsockopt(
                socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.cs_socket.bind((self.ip, self.port))
            print('init client-side socket: %s'
                  % (self.cs_socket.getsockname(),))
            self.cs_socket.listen(10)  # backlog of not-yet-accepted connections
        except socket.error:
            self.cs_socket.close()
            raise
        self.sockets.append(self.cs_socket)

    def start(self):
        """Run the select() event loop forever.

        Accepts new clients on the listening socket and relays data between
        the two sockets of every established flow. Whenever the set of open
        sockets changes (a flow was accepted or closed) the remaining ready
        sockets are skipped and select() is called again, because they may
        refer to sockets that were just closed.
        """
        while True:
            read_list, _, _ = select.select(self.sockets, [], [])
            for sock in read_list:
                # new connection
                if sock is self.cs_socket:
                    print('=' * 40 + 'flow start' + '=' * 39)
                    self.on_accept()
                    break
                # incoming message from a client or server socket
                if not self._relay(sock):
                    break

    def _relay(self, sock):
        """Read once from *sock* and forward the data to its peer.

        Returns:
            bool: True if the flow is still open, False if it was closed
            (peer sent EOF, or a socket error occurred while receiving or
            forwarding).
        """
        try:
            # In Windows, sometimes when a TCP program closes abruptly,
            # a "Connection reset by peer" exception will be thrown
            data = sock.recv(4096)  # buffer size: 2^n
            if data:
                self.on_recv(sock, data)
                return True
        except socket.error as err:
            print('socket error, closing flow: %s' % (err,))
        self.on_close(sock)
        return False

    def on_accept(self):
        """Accept a client and pair it with a freshly connected server socket.

        If the selected server cannot be reached, both the client socket and
        the half-built server-side socket are closed and no flow is recorded.
        """
        client_socket, client_addr = self.cs_socket.accept()
        print('client connected: %s <==> %s'
              % (client_addr, self.cs_socket.getsockname()))

        # select a server that forwards packets to
        server_ip, server_port = self.select_server(SERVER_POOL, self.algorithm)

        # init a server-side socket
        ss_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            ss_socket.connect((server_ip, server_port))
            print('init server-side socket: %s' % (ss_socket.getsockname(),))
            print('server connected: %s <==> %s'
                  % (ss_socket.getsockname(),
                     (socket.gethostbyname(server_ip), server_port)))
        except socket.error:
            print("Can't establish connection with remote server, err: %s"
                  % sys.exc_info()[0])
            print("Closing connection with client socket %s" % (client_addr,))
            ss_socket.close()  # do not leak the unused server-side socket
            client_socket.close()
            return

        self.sockets.append(client_socket)
        self.sockets.append(ss_socket)

        self.flow_table[client_socket] = ss_socket
        self.flow_table[ss_socket] = client_socket

    def on_recv(self, sock, data):
        """Forward *data* received on *sock* to the other end of its flow.

        Raises:
            socket.error: if the peer socket cannot be written to.
        """
        print('recving packets: %-20s ==> %-20s, data: %s'
              % (sock.getpeername(), sock.getsockname(), [data]))
        # data can be modified before forwarding to server
        # lots of add-on features can be added here
        remote_socket = self.flow_table[sock]
        # sendall() keeps writing until the whole buffer is sent; a plain
        # send() may transmit only part of it and silently drop the rest.
        remote_socket.sendall(data)
        print('sending packets: %-20s ==> %-20s, data: %s'
              % (remote_socket.getsockname(), remote_socket.getpeername(),
                 [data]))

    def on_close(self, sock):
        """Tear down the flow that *sock* belongs to.

        Closes *sock* and its peer and removes both from `sockets` and
        `flow_table`. Safe to call for a socket that is already reset or
        closed, or that is no longer tracked.
        """
        try:
            peer_name = sock.getpeername()
        except socket.error:
            # the connection is already gone, so the peer address is lost
            peer_name = '<unknown>'
        print('client %s has disconnected' % (peer_name,))
        print('=' * 41 + 'flow end' + '=' * 40)

        ss_socket = self.flow_table.pop(sock, None)
        for end in (sock, ss_socket):
            if end is None:
                continue
            self.flow_table.pop(end, None)
            if end in self.sockets:
                self.sockets.remove(end)
            end.close()

    def select_server(self, server_list, algorithm):
        """Pick the (ip, port) of the server that should handle a new flow.

        Args:
            server_list (list): candidate (ip, port) tuples.
            algorithm (str): 'random' or 'round robin'.

        Raises:
            ValueError: if *algorithm* is not a known algorithm name.
        """
        if algorithm == 'random':
            return random.choice(server_list)
        if algorithm == 'round robin':
            # NOTE: round robin draws from the module-level ITER iterator
            # rather than from server_list.
            return round_robin(ITER)
        raise ValueError('unknown algorithm: %s' % algorithm)
if __name__ == '__main__':
    # Script entry point: listen on localhost:5555 and balance incoming
    # connections with the 'round robin' algorithm until Ctrl-C is pressed.
    try:
        balancer = LoadBalancer('localhost', 5555, 'round robin')
        balancer.start()
    except KeyboardInterrupt:
        print("Ctrl C - Stopping load_balancer")
        sys.exit(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, I'm a university student currently working on a software development project as an assignment related to networking features. I was looking at load balancers implemented in Python, and although your code is in Python 2, it looks like a very good reference for me to implement one in Python 3. The gist doesn't say anything about referencing the code, so I wanted to ask whether I could use part of the code for my project.