Skip to content

Instantly share code, notes, and snippets.

@dustyfresh
Created June 7, 2019 16:26
Show Gist options
  • Save dustyfresh/49fc6cf6c2b8824f09d9f231430e983b to your computer and use it in GitHub Desktop.
Save dustyfresh/49fc6cf6c2b8824f09d9f231430e983b to your computer and use it in GitHub Desktop.
DNS nameserver implemented in Python
#!/usr/bin/env python3
import binascii
import sys
import time
from datetime import datetime
from time import sleep

from dnslib import DNSLabel, QTYPE, RD, RR, RCODE
from dnslib import A, AAAA, CNAME, MX, NS, SOA, TXT
from dnslib.server import DNSServer
# Naive datetime for the Unix epoch (1970-01-01 00:00:00).
EPOCH = datetime(1970, 1, 1)
# Local wall-clock time at which this module was loaded, formatted for logs.
# NOTE: this value is fixed at import time; it is not "now" when read later.
TIMESTAMP = datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S')
# Default SOA serial: whole seconds since the Unix epoch at load time.
# int(time.time()) equals the former int((datetime.utcnow() - EPOCH)
# .total_seconds()) without calling datetime.utcnow(), which is deprecated
# since Python 3.12.
SERIAL = int(time.time())
# Maps each dnslib rdata class to the QTYPE value it is served under.
# Record uses this to derive a record's type from the class (or from the
# class of the rdata instance) it was constructed with.
TYPE_LOOKUP = dict([
    (A, QTYPE.A),
    (AAAA, QTYPE.AAAA),
    (CNAME, QTYPE.CNAME),
    (MX, QTYPE.MX),
    (NS, QTYPE.NS),
    (SOA, QTYPE.SOA),
    (TXT, QTYPE.TXT),
])
# Sub-domain labels that commonly expose interesting services, each with a
# short human-readable description.
# NOTE(review): not referenced by the resolver or logger in this file.
SUSPICIOUS_RECORDS = {
    'ftp': 'FTP service',  # was the truncated 'FTP ' (trailing space)
    'cpanel': 'cPanel hosting control panel',
    'admin': 'common record for admin control panels',
    'ssh': 'SSH service',
    'wordpress': 'common blogging platform',
    'store': 'common record for stores',
    'staging': 'common record for staging environments',
    'mail': 'common record for email',
}
class Record:
    """One resource record served for a zone name.

    The record is described either by a dnslib rdata class plus the
    positional arguments for its constructor (``Record(MX, 'mx.example.', 10)``)
    or by an already-built rdata instance (any ``RD`` instance).

    Keyword arguments:
        rtype: explicit QTYPE value; when truthy it overrides the type looked
            up in TYPE_LOOKUP.
        rname: owner name for the generated RR; when falsy, the queried name
            is used instead.
        ttl: record TTL; when None, :meth:`sensible_ttl` supplies it.
        **kwargs: extra keyword arguments forwarded to dnslib's ``RR``.
    """

    def __init__(self, rdata_type, *args, rtype=None, rname=None, ttl=None, **kwargs):
        if not isinstance(rdata_type, RD):
            # A class was given: resolve its QTYPE and build the rdata from args.
            self._rtype = TYPE_LOOKUP[rdata_type]
            if rdata_type == SOA and len(args) == 2:
                # Only mname and rname supplied: append default SOA timers.
                default_times = (
                    SERIAL,        # serial number
                    60 * 60 * 1,   # refresh
                    60 * 60 * 3,   # retry
                    60 * 60 * 24,  # expire
                    60 * 60 * 1,   # minimum
                )
                args = args + (default_times,)
            rdata = rdata_type(*args)
        else:
            # An rdata instance was given: derive the type from its class and
            # use the instance unchanged.
            self._rtype = TYPE_LOOKUP[rdata_type.__class__]
            rdata = rdata_type

        if rtype:
            self._rtype = rtype
        self._rname = rname

        if ttl is None:
            # Chosen after any rtype override, since the default depends on it.
            ttl = self.sensible_ttl()
        self.kwargs = dict(rdata=rdata, ttl=ttl, **kwargs)

    def try_rr(self, q):
        """Return an RR answering question *q*, or None when the type differs.

        A question matches when it asks for this record's type or for ANY.
        """
        if q.qtype not in (QTYPE.ANY, self._rtype):
            return None
        return self.as_rr(q.qname)

    def as_rr(self, alt_rname):
        """Build the dnslib RR, owned by our rname or else by *alt_rname*."""
        owner = self._rname or alt_rname
        return RR(rname=owner, rtype=self._rtype, **self.kwargs)

    def sensible_ttl(self):
        """Default TTL: one day for NS and SOA records, 300 seconds otherwise."""
        one_day = 60 * 60 * 24
        return one_day if self._rtype in (QTYPE.NS, QTYPE.SOA) else 300

    @property
    def is_soa(self):
        """True when this record is an SOA record."""
        return self._rtype == QTYPE.SOA

    def __str__(self):
        return '{} {}'.format(QTYPE[self._rtype], self.kwargs)
# Zone table: each key is an owner name and each value is the list of Record
# objects served for it. Resolver.resolve answers with the records whose type
# matches the query (all of them for ANY); names missing from this table fall
# through to its catch-all branch.
ZONES = {}

ZONES['example.com'] = [
    Record(A, '1.2.3.4'),
    # NOTE(review): RFC 1034 section 3.6.2 does not allow a CNAME to coexist
    # with other data at the same name, yet this one shares example.com with
    # A/MX/NS/TXT/SOA records. Left unchanged; confirm it is intentional.
    Record(CNAME, 'www.example.com.'),
    Record(MX, 'whatever.com.', 5),
    Record(MX, 'mx2.whatever.com.', 10),
    Record(MX, 'mx3.whatever.com.', 20),
    Record(NS, 'mx2.whatever.com.'),
    Record(NS, 'mx3.whatever.com.'),
    Record(TXT, 'hack the planet!!!'),
    # Two-argument SOA: Record appends serial/refresh/retry/expire/minimum.
    Record(SOA, 'ns1.example.com', 'dns.example.com'),
]

ZONES['testing.com'] = [
    Record(A, '185.92.221.142'),
    Record(TXT, 'hack the planet!!!'),
    Record(NS, 'ns1.lol.systems'),
    Record(NS, 'ns2.lol.systems'),
    Record(SOA, 'ns1.lol.systems', 'ns2.lol.systems'),
]

ZONES['www.testing.com'] = [
    Record(CNAME, 'testing.com.'),
]

ZONES['derp.testing.com'] = [
    Record(TXT, 'rekt lmao'),
]
class Resolver:
    """dnslib resolver serving ZONES, with a catch-all for unknown names.

    Names present in ZONES are answered from their Record list. Any other
    name gets a synthetic answer for A (1.3.3.7) and CNAME (testing.com.)
    queries, and a reply with no answers for every other query type; each
    such request is also reported on stdout.
    """

    def __init__(self):
        # Key the zone table by DNSLabel so it can be indexed directly with
        # the question name of an incoming request.
        self.zones = {DNSLabel(k): v for k, v in ZONES.items()}

    def resolve(self, request, handler):
        """Build and return the reply to *request*.

        Args:
            request: the incoming dnslib request (provides ``q`` and
                ``reply()``).
            handler: the dnslib request handler; only ``client_address`` is
                read, for logging.

        Returns:
            The reply object created by ``request.reply()``, with any
            matching answers added.
        """
        reply = request.reply()
        zone = self.zones.get(request.q.qname)
        if zone is not None:
            for record in zone:
                rr = record.try_rr(request.q)
                if rr is not None:
                    reply.add_answer(rr)
            return reply

        # CATCHALL: the queried name has no entry in ZONES.
        qtype_name = QTYPE[reply.q.qtype]
        catchall = self._catchall_record(qtype_name)
        if catchall is not None:
            reply.add_answer(catchall.try_rr(request.q))
        # Format the current time per request; the module-level TIMESTAMP is
        # frozen at import time and would stamp every line identically.
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print('{} -- {} requested an invalid {} record: {}'.format(
            now, handler.client_address[0], qtype_name, reply.q.qname))
        # NOTE: a disabled alternative answered unknown names with the SOA of
        # the closest enclosing zone instead:
        #     for zone_label, zone_records in self.zones.items():
        #         if request.q.qname.matchSuffix(zone_label):
        #             soa = next((r for r in zone_records if r.is_soa), None)
        #             if soa is not None:
        #                 reply.add_answer(soa.as_rr(zone_label))
        #                 break
        return reply

    @staticmethod
    def _catchall_record(qtype_name):
        """Return the synthetic Record for *qtype_name*, or None if none applies."""
        if qtype_name == 'CNAME':
            return Record(CNAME, 'testing.com.')
        if qtype_name == 'A':
            return Record(A, '1.3.3.7')
        # Any other type (AAAA, MX, TXT, ANY, ...) gets no synthetic answer.
        # Previously this path left the record name unbound and raised
        # UnboundLocalError instead of replying.
        return None
class DNSLogger:
    """Console logger with the hook methods dnslib's DNSServer calls.

    ``log`` is a comma-separated list of event names (recv, send, request,
    reply, truncated, error, data):
    - plain names replace the default set (request, reply, truncated, error);
    - ``+name`` adds an event to the set;
    - ``-name`` removes an event from it.
    All additions are applied before any removals. Each disabled event's
    ``log_<name>`` method is replaced on the instance by the no-op
    :meth:`log_pass`. When ``prefix`` is true, some messages start with a
    timestamp and the handler/resolver class names.
    """

    def __init__(self, log="", prefix=True):
        default = ["request", "reply", "truncated", "error"]
        # Ignore empty entries (e.g. "request,,reply") instead of failing on
        # entry[0].
        entries = [s for s in log.split(",") if s] if log else []
        enabled = set([s for s in entries if s[0] not in '+-'] or default)
        for entry in entries:
            if entry.startswith('+'):
                enabled.add(entry[1:])
        for entry in entries:
            if entry.startswith('-'):
                enabled.discard(entry[1:])
        for name in ['log_recv', 'log_send', 'log_request', 'log_reply',
                     'log_truncated', 'log_error', 'log_data']:
            if name[4:] not in enabled:
                setattr(self, name, self.log_pass)
        self.prefix = prefix

    def log_pass(self, *args):
        """No-op used in place of disabled log methods."""
        pass

    def log_prefix(self, handler):
        """Return the line prefix for *handler*, or "" when prefixes are off."""
        if self.prefix:
            return "%s [%s:%s] " % (time.strftime("%Y-%m-%d %X"),
                                    handler.__class__.__name__,
                                    handler.server.resolver.__class__.__name__)
        else:
            return ""

    def log_recv(self, handler, data):
        """Print the raw bytes received from a client, hex-encoded."""
        print("%sReceived: [%s:%d] (%s) <%d> : %s" % (
            self.log_prefix(handler),
            handler.client_address[0],
            handler.client_address[1],
            handler.protocol,
            len(data),
            binascii.hexlify(data)))

    def log_send(self, handler, data):
        """Print the raw bytes sent to a client, hex-encoded."""
        print("%sSent: [%s:%d] (%s) <%d> : %s" % (
            self.log_prefix(handler),
            handler.client_address[0],
            handler.client_address[1],
            handler.protocol,
            len(data),
            binascii.hexlify(data)))

    def log_request(self, handler, request):
        """Dump the request via log_data (a no-op unless "data" is enabled)."""
        self.log_data(request)

    def log_reply(self, handler, reply):
        """Summarise a NOERROR reply with a non-SOA answer, then dump it.

        The summary line is printed at most once per reply, however many
        matching records the reply carries.
        """
        if reply.header.rcode == RCODE.NOERROR and any(
                QTYPE[rr.rtype] != 'SOA' for rr in reply.rr):
            print('{} requested a {} record: {}'.format(
                handler.client_address[0], QTYPE[reply.q.qtype], reply.q.qname))
        self.log_data(reply)

    def log_truncated(self, handler, reply):
        """Report a reply that was truncated, listing its answer types."""
        print("%sTruncated Reply: [%s:%d] (%s) / '%s' (%s) / RRs: %s" % (
            self.log_prefix(handler),
            handler.client_address[0],
            handler.client_address[1],
            handler.protocol,
            reply.q.qname,
            QTYPE[reply.q.qtype],
            ",".join([QTYPE[a.rtype] for a in reply.rr])))
        self.log_data(reply)

    def log_error(self, handler, e):
        """Report a request the server could not handle, with the error *e*."""
        print("%sInvalid Request: [%s:%d] (%s) :: %s" % (
            self.log_prefix(handler),
            handler.client_address[0],
            handler.client_address[1],
            handler.protocol,
            e))

    def log_data(self, dnsobj):
        """Print *dnsobj* in zone-file form, surrounded by blank lines."""
        print("\n", dnsobj.toZone("    "), "\n", sep="")
# A single resolver and logger are shared by both listeners.
resolver = Resolver()
logger = DNSLogger()
# One DNSServer per transport on port 53 of all IPv4 interfaces: TCP first,
# then UDP.
# NOTE(review): these are created at import time; dnslib's DNSServer
# presumably binds its socket on construction — confirm before importing
# this module elsewhere.
servers = [
    DNSServer(resolver, port=53, address='0.0.0.0', tcp=use_tcp, logger=logger)
    for use_tcp in (True, False)
]
if __name__ == '__main__':
    # Run every listener in its own background thread.
    for srv in servers:
        srv.start_thread()
    try:
        # Keep the main thread idle until the user presses Ctrl-C.
        while True:
            sleep(0.1)
    except KeyboardInterrupt:
        pass
    finally:
        # Shut the listeners down however the idle loop ended.
        for srv in servers:
            srv.stop()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment