Skip to content

Instantly share code, notes, and snippets.

@gianm
Created January 3, 2018 01:52
Show Gist options
  • Save gianm/43d8bae20311877cd429f19676c9abb8 to your computer and use it in GitHub Desktop.
Save gianm/43d8bae20311877cd429f19676c9abb8 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# Copyright 2017 Imply Data, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import base64
import collections
import csv
import errno
import json
import numbers
import re
import ssl
import sys
import time
import unicodedata
import urllib2
class DruidSqlException(Exception):
    """Error raised when a Druid SQL query fails.

    The first constructor argument, if given, is the human-readable message
    that ``write_to`` shows to the user.
    """

    def write_to(self, f):
        """Write this error's message to stream *f* in red, plus a newline.

        Falls back to "Query failed" when no message, or an empty one, was
        given. The stream is flushed so the message appears immediately.
        """
        # BaseException.message is deprecated since Python 2.6 (PEP 352) and
        # removed in Python 3; take the message from args instead.
        message = self.args[0] if self.args else None
        f.write('\x1b[31m')  # ANSI escape: red foreground
        f.write(message if message else "Query failed")
        f.write('\x1b[0m')  # ANSI escape: reset attributes
        f.write('\n')
        f.flush()
def do_query(url, sql, context, timeout, user, password, ignore_ssl_verification, ca_file, ca_path):
    """Run *sql* against the Druid SQL endpoint at *url* and yield result rows.

    This is a generator: no request is sent until the first row is requested.
    The response body is expected to be a JSON array of objects. It is read
    incrementally in 8192-byte chunks, and each object is yielded as an
    ``OrderedDict`` so that column order is preserved.

    Args:
        url: full URL of the SQL endpoint.
        sql: query text.
        context: dict of query context parameters sent with the query.
        timeout: socket timeout in seconds; values <= 0 mean no timeout.
        user, password: HTTP basic auth credentials. The Authorization
            header is only sent when both are non-empty.
        ignore_ssl_verification: if true, skip certificate and hostname
            verification.
        ca_file, ca_path: CA file / directory passed to
            ``load_verify_locations`` for verifying the server certificate.

    Raises:
        DruidSqlException: on HTTP/URL errors (via ``raise_friendly_error``).
        ValueError: if the response body cannot be parsed as JSON.
    """
    json_decoder = json.JSONDecoder(object_pairs_hook=collections.OrderedDict)
    try:
        sql_json = json.dumps({'query': sql, 'context': context})

        # SSL: only build a custom context when the caller changed the
        # verification settings; otherwise urlopen uses its defaults.
        ssl_context = None
        if ignore_ssl_verification or ca_file is not None or ca_path is not None:
            ssl_context = ssl.create_default_context()
            if ignore_ssl_verification:
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
            else:
                ssl_context.load_verify_locations(cafile=ca_file, capath=ca_path)

        req = urllib2.Request(url, sql_json, {'Content-Type': 'application/json'})
        if timeout <= 0:
            timeout = None
        if user and password:
            basic_auth_encoding = base64.b64encode('%s:%s' % (user, password))
            req.add_header("Authorization", "Basic %s" % basic_auth_encoding)

        response = urllib2.urlopen(req, None, timeout, context=ssl_context)
        try:
            first_chunk = True
            eof = False
            buf = ''
            while not eof or len(buf) > 0:
                # Yield every complete object currently held in buf.
                while True:
                    try:
                        # Drop the ',' separators between array elements, plus
                        # any whitespace: raw_decode rejects leading whitespace.
                        buf = buf.lstrip(', \t\r\n')
                        obj, sz = json_decoder.raw_decode(buf)
                        yield obj
                        buf = buf[sz:]
                    except ValueError:
                        # Maybe invalid JSON, maybe a partial object;
                        # raw_decode cannot tell which.
                        if eof and buf.rstrip() == ']':
                            # Stream done and all objects read.
                            buf = ''
                            break
                        elif eof or len(buf) > 256 * 1024:
                            # At eof, or with more than 256KB buffered and
                            # still no complete object: report the parse error.
                            raise
                        else:
                            # Likely a partial object: read more from the stream.
                            break
                # Read more from the http stream
                if not eof:
                    chunk = response.read(8192)
                    if chunk:
                        buf = buf + chunk
                        if first_chunk:
                            # Remove the array's opening '[' (and any whitespace
                            # before it). Only the first chunk can contain it, so
                            # later element data is never stripped.
                            buf = buf.lstrip(' \t\r\n').lstrip('[')
                            first_chunk = False
                    else:
                        # Stream done. Keep reading objects out of buf though.
                        eof = True
        finally:
            # Release the connection even if the consumer stops iterating early.
            response.close()
    except urllib2.URLError as e:
        raise_friendly_error(e)
def raise_friendly_error(e):
    """Convert a urllib2 error into a DruidSqlException with a readable message.

    HTTP 500 responses whose body is a JSON object with an ``errorMessage``
    field produce a message built from the ``error``, ``errorClass``,
    ``errorMessage`` and ``host`` fields. Other HTTP errors report the status
    code, reason and raw response body. Non-HTTP URL errors use ``str(e)``.

    Always raises DruidSqlException; never returns.
    """
    if isinstance(e, urllib2.HTTPError):
        text = e.read().strip()
        error_obj = {}
        try:
            error_obj = dict(json.loads(text))
        except (ValueError, TypeError):
            # Body is not a JSON object; fall back to the generic message below.
            pass
        if e.code == 500 and 'errorMessage' in error_obj:
            error_text = ''
            # The remaining fields are optional. Use .get() so a partial error
            # object doesn't turn into a KeyError that hides the real failure.
            error = error_obj.get('error')
            if error and error != 'Unknown exception':
                error_text = error_text + error + ': '
            if error_obj.get('errorClass'):
                error_text = error_text + str(error_obj['errorClass']) + ': '
            error_text = error_text + str(error_obj['errorMessage'])
            if error_obj.get('host'):
                error_text = error_text + ' (' + str(error_obj['host']) + ')'
            raise DruidSqlException(error_text)
        else:
            raise DruidSqlException("HTTP Error {0}: {1}\n{2}".format(e.code, e.reason, text))
    else:
        raise DruidSqlException(str(e))
def to_utf8(value):
    """Render *value* as a byte string suitable for output.

    ``None`` becomes the empty string, unicode text is encoded as UTF-8, and
    any other value is converted with ``str()``.
    """
    if value is None:
        return ""
    if not isinstance(value, unicode):
        return str(value)
    return value.encode("utf-8")
def to_tsv(values, delimiter):
    """Join *values* into one delimiter-separated line.

    Each value is converted with ``to_utf8``. Any occurrence of *delimiter*
    inside a value is deleted so it cannot split the field.
    """
    fields = []
    for value in values:
        fields.append(to_utf8(value).replace(delimiter, ''))
    return delimiter.join(fields)
def print_csv(rows, header):
    """Write *rows* (mappings of column name to value) to stdout as CSV.

    When *header* is true, the keys of the first row are written as a header
    line before any data. Every cell is converted with ``to_utf8``.
    """
    writer = csv.writer(sys.stdout)
    header_pending = header
    for row in rows:
        if header_pending:
            writer.writerow([to_utf8(column) for column in row.keys()])
            header_pending = False
        writer.writerow([to_utf8(cell) for cell in row.values()])
def print_tsv(rows, header, tsv_delimiter):
    """Print *rows* to stdout as delimiter-separated lines built by ``to_tsv``.

    When *header* is true, the keys of the first row are printed first.
    """
    header_pending = header
    for row in rows:
        if header_pending:
            print(to_tsv(row.keys(), tsv_delimiter))
            header_pending = False
        print(to_tsv(list(row.values()), tsv_delimiter))
def print_json(rows):
    """Print each row as a JSON object on its own line."""
    for serialized in (json.dumps(row) for row in rows):
        print(serialized)
def table_to_printable_value(value):
    """Return *value* as trimmed unicode text for display in a table cell.

    ``None`` is shown as ``NULL``. Any other value is converted with
    ``to_utf8``, stripped of surrounding whitespace, and decoded from UTF-8.
    Code points 0-31 (ASCII control characters) are then deleted.
    """
    if value is not None:
        control_chars = dict.fromkeys(range(32))
        text = to_utf8(value).strip().decode('utf-8')
        return text.translate(control_chars)
    return u"NULL"
def table_compute_string_width(v):
    """Return the number of terminal columns needed to display unicode text *v*.

    The text is NFC-normalized first, so precomposable sequences such as
    'e' + COMBINING ACUTE count once. Then, per code point:
      * format characters (category Cf) and combining marks (Mn, Me) take no
        column, because they render on top of the preceding character;
      * East Asian fullwidth/wide characters (F, W) take two columns;
      * everything else takes one column.
    """
    normalized = unicodedata.normalize('NFC', v)
    width = 0
    for c in normalized:
        if unicodedata.category(c) in ('Cf', 'Mn', 'Me'):
            # Zero width: invisible formatting controls, and combining marks
            # that NFC could not fold into a precomposed character.
            continue
        if unicodedata.east_asian_width(c) in ('F', 'W'):
            # Double-wide character, prints in two columns
            width += 2
        else:
            width += 1
    return width
def table_compute_column_widths(row_buffer):
    """Return the display width each column needs to fit every row in *row_buffer*.

    Each element of *row_buffer* is a sequence of unicode cell values. The
    result holds, per column, the largest ``table_compute_string_width``
    seen in that column. Returns ``None`` when *row_buffer* is empty.
    """
    result = None
    for cells in row_buffer:
        cell_widths = [table_compute_string_width(cell) for cell in cells]
        if not result:
            result = cell_widths
            continue
        for index, cell_width in enumerate(cell_widths):
            result[index] = max(result[index], cell_width)
    return result
def table_print_row(values, column_widths, column_types):
    """Print one table row of unicode *values* between vertical bars.

    Each cell is padded with spaces to its entry in *column_widths*. Cells
    whose *column_types* entry is ``'n'`` are right-aligned. All other cells,
    and every cell when *column_types* is falsy, are left-aligned. Text is
    written UTF-8 encoded.
    """
    bar = u'\u2502'.encode('utf-8')
    for index, cell in enumerate(values):
        pad = ' ' * max(0, column_widths[index] - table_compute_string_width(cell))
        encoded = cell.encode('utf-8')
        if column_types and column_types[index] == 'n':
            cell_text = pad + encoded
        else:
            cell_text = encoded + pad
        print(bar + ' ' + cell_text + ' ', end="")
    print(bar)
def table_print_header(values, column_widths):
    """Print the table's top border, the header row *values*, and the separator rule.

    Borders use UTF-8 encoded box-drawing characters. Each column segment is
    two characters wider than its entry in *column_widths*, to cover the
    spaces around each cell.
    """
    def draw_rule(left, joint, right):
        # One horizontal rule: left end, per-column dashes split by joints, right end.
        dash = u'\u2500'.encode('utf-8')
        last = len(column_widths) - 1
        print(left, end="")
        for index, width in enumerate(column_widths):
            print(dash * max(0, width + 2), end="")
            if index < last:
                print(joint, end="")
        print(right)

    draw_rule(u'\u250C'.encode('utf-8'), u'\u252C'.encode('utf-8'), u'\u2510'.encode('utf-8'))
    table_print_row(values, column_widths, None)
    draw_rule(u'\u251C'.encode('utf-8'), u'\u253C'.encode('utf-8'), u'\u2524'.encode('utf-8'))
def table_print_bottom(column_widths):
    """Print the bottom border of a table whose columns have *column_widths*.

    Uses UTF-8 encoded box-drawing characters. Each segment is two
    characters wider than its column width, to cover the cell padding.
    """
    corner_left = u'\u2514'.encode('utf-8')
    corner_right = u'\u2518'.encode('utf-8')
    joint = u'\u2534'.encode('utf-8')
    dash = u'\u2500'.encode('utf-8')
    print(corner_left, end="")
    remaining = len(column_widths)
    for width in column_widths:
        remaining -= 1
        print(dash * max(0, width + 2), end="")
        if remaining:
            print(joint, end="")
    print(corner_right)
def table_print_row_buffer(row_buffer, column_widths, column_types):
    """Print buffered rows: the first as the table header, the rest as data rows."""
    for index, values in enumerate(row_buffer):
        if index == 0:
            table_print_header(values, column_widths)
        else:
            table_print_row(values, column_widths, column_types)
def print_table(rows):
    """Render *rows* as a box-drawn table on stdout, then print a summary line.

    A header built from the first row's keys, plus the first 500 data rows,
    are buffered so column widths can be sized to fit them. Later rows are
    printed as they arrive, using those widths. A column is right-aligned
    when its value in the first row is a number. Output ends with a
    "Retrieved N rows in X.XXs." line and a blank line. The timing starts
    before *rows* is iterated, so it covers producing the rows too.
    """
    started = time.time()
    row_count = 0
    buffer_capacity = 500
    pending = []        # header row plus buffered data rows not yet printed
    column_types = []   # 'n' for numeric columns, 's' otherwise
    column_widths = None
    for row in rows:
        row_count += 1
        if row_count == 1:
            pending.append([table_to_printable_value(key) for key in row.keys()])
            column_types = ['n' if isinstance(row[key], numbers.Number) else 's'
                            for key in row.keys()]
        cells = [table_to_printable_value(cell) for cell in row.values()]
        if buffer_capacity > 0:
            pending.append(cells)
            buffer_capacity -= 1
            continue
        if pending:
            # Buffer is full: size the columns and flush it once.
            column_widths = table_compute_column_widths(pending)
            table_print_row_buffer(pending, column_widths, column_types)
            del pending[:]
        table_print_row(cells, column_widths, column_types)
    if pending:
        column_widths = table_compute_column_widths(pending)
        table_print_row_buffer(pending, column_widths, column_types)
    if column_widths:
        table_print_bottom(column_widths)
    plural = 's' if row_count != 1 else ''
    elapsed = time.time() - started
    print("Retrieved {0:,d} row{1:s} in {2:.2f}s.".format(row_count, plural, elapsed))
    print("")
def display_query(url, sql, context, args):
    """Run *sql* and print its rows in the output format named by ``args.format``.

    The recognized formats are ``csv``, ``tsv``, ``json`` and ``table``; any
    other value is silently ignored. Connection options (timeout,
    credentials, SSL settings) and formatting options are read from *args*.
    """
    rows = do_query(url, sql, context, args.timeout, args.user, args.password,
                    args.ignore_ssl_verification, args.cafile, args.capath)
    printers = {
        'csv': lambda: print_csv(rows, args.header),
        'tsv': lambda: print_tsv(rows, args.header, args.tsv_delimiter),
        'json': lambda: print_json(rows),
        'table': lambda: print_table(rows),
    }
    printer = printers.get(args.format)
    if printer is not None:
        printer()
def sql_escape(s):
    """Quote *s* as a SQL Unicode string literal of the form ``U&'...'``.

    Letters and digits (Unicode categories L* and N*) and spaces are copied
    unchanged. Every other character, including quotes and backslashes, is
    written as a backslash followed by 4 hex digits. Code points above
    U+FFFF do not fit in 4 hex digits, so they are written as a UTF-16
    surrogate pair of two such escapes. ``None`` becomes the empty literal
    ``''``. Non-unicode input is converted with ``str()`` and, if that gives
    bytes, decoded as UTF-8.
    """
    if s is None:
        return "''"
    elif isinstance(s, type(u'')):
        ustr = s
    else:
        ustr = str(s)
        if isinstance(ustr, bytes):
            ustr = ustr.decode('utf-8')
    escaped = [u"U&'"]
    for c in ustr:
        ccategory = unicodedata.category(c)
        if ccategory.startswith('L') or ccategory.startswith('N') or c == ' ':
            escaped.append(c)
            continue
        code = ord(c)
        if code > 0xFFFF:
            # A 5+ digit escape would be read as 4 hex digits plus stray
            # text, so split into a UTF-16 surrogate pair instead.
            code -= 0x10000
            units = (0xD800 + (code >> 10), 0xDC00 + (code & 0x3FF))
        else:
            units = (code,)
        for unit in units:
            escaped.append(u'\\%04x' % unit)
    escaped.append(u"'")
    return u''.join(escaped)
def _build_broker_url(host):
    """Return the Druid SQL endpoint URL for *host* (a host:port or a base URL).

    Trailing slashes are removed before ``/druid/v2/sql/`` is appended.
    ``http://`` is prepended when the result has no http/https scheme.
    """
    url = host.rstrip('/') + '/druid/v2/sql/'
    if not url.startswith('http:') and not url.startswith('https:'):
        url = 'http://' + url
    return url


def _parse_context_options(options):
    """Turn a list of ``key=value`` strings into a query context dict.

    Values made only of decimal digits become integers; everything else stays
    a string. *options* may be ``None``, which yields an empty dict.

    Raises:
        ValueError: if an option contains no ``=``.
    """
    context = {}
    for opt in options or ():
        kv = opt.split("=", 1)
        if len(kv) != 2:
            raise ValueError('Invalid context option, should be key=value: ' + opt)
        if re.match(r"^\d+$", kv[1]):
            context[kv[0]] = long(kv[1])
        else:
            context[kv[0]] = kv[1]
    return context


def _print_help():
    """Print the list of interactive backslash commands."""
    print("Commands:")
    print(" \\d show tables")
    print(" \\dS show tables, including system tables")
    print(" \\d table_name describe table")
    print(" \\h show this help")
    print(" \\q exit this program")
    print("Or enter a SQL query ending with a semicolon (;).")


def main():
    """Entry point for the dsql command-line client.

    Parses command-line options. With ``--execute`` it runs that single query;
    otherwise it starts an interactive prompt that reads semicolon-terminated
    SQL statements and backslash commands. The prompt exits with status 0 on
    ``\\q`` and status 1 at end of input.
    """
    parser = argparse.ArgumentParser(description='Druid SQL command-line client.')
    parser.add_argument('--host', '-H', type=str, default='http://localhost:8082/', help='Broker host or url')
    parser.add_argument('--timeout', type=int, default=0, help='Timeout in seconds, 0 for no timeout')
    parser.add_argument('--format', type=str, default='table', choices=('csv', 'tsv', 'json', 'table'), help='Result format')
    parser.add_argument('--header', action='store_true', help='Include header row for formats "csv" and "tsv"')
    parser.add_argument('--tsv-delimiter', type=str, default='\t', help='Delimiter for format "tsv"')
    parser.add_argument('--context-option', '-c', type=str, action='append', help='Set context option for this connection')
    parser.add_argument('--execute', '-e', type=str, help='Execute single SQL query')
    parser.add_argument('--user', '-u', type=str, help='Username for HTTP basic auth')
    parser.add_argument('--password', '-p', type=str, help='Password for HTTP basic auth')
    parser.add_argument('--ignore-ssl-verification', '-k', action='store_true', default=False, help='Skip verification of SSL certificates.')
    parser.add_argument('--cafile', type=str, help='Path to SSL CA file for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser.add_argument('--capath', type=str, help='SSL CA path for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    args = parser.parse_args()

    url = _build_broker_url(args.host)
    context = _parse_context_options(args.context_option)

    if args.execute:
        # Drop an optional trailing ';' terminator, as interactive mode does.
        display_query(url, args.execute.strip().rstrip(';'), context, args)
        return

    # interactive mode
    print("Welcome to dsql, the command-line client for Druid SQL.")
    print("Type \"\\h\" for help.")
    while True:
        sql = ''
        # Accumulate input lines until the statement ends with ';', or until a
        # \d command generates a query.
        while not sql.endswith(';'):
            prompt = "dsql> " if sql == '' else 'more> '
            try:
                more_sql = raw_input(prompt)
            except EOFError:
                sys.stdout.write('\n')
                sys.exit(1)
            if sql == '' and more_sql.startswith('\\'):
                # backslash command
                dmatch = re.match(r'^\\d(S?)(\+?)(\s+.*?|)\s*$', more_sql)
                if dmatch:
                    include_system = dmatch.group(1)
                    # group(2) captures an optional '+', which is accepted but unused.
                    arg = dmatch.group(3).strip()
                    if arg:
                        sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = " + sql_escape(arg)
                        if not include_system:
                            sql = sql + " AND TABLE_SCHEMA = 'druid'"
                    else:
                        sql = "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES"
                        if not include_system:
                            sql = sql + " WHERE TABLE_SCHEMA = 'druid'"
                    # break to execute sql
                    break
                if re.match(r'^\\h\s*$', more_sql):
                    _print_help()
                    continue
                if re.match(r'^\\q\s*$', more_sql):
                    sys.exit(0)
                print("No such command: " + more_sql)
            else:
                sql = (sql + ' ' + more_sql).strip()
        try:
            display_query(url, sql.rstrip(';'), context, args)
        except DruidSqlException as e:
            e.write_to(sys.stdout)
        except KeyboardInterrupt:
            sys.stdout.write("Query interrupted\n")
            sys.stdout.flush()
sys.stdout.flush()
try:
main()
except DruidSqlException as e:
e.write_to(sys.stderr)
sys.exit(1)
except KeyboardInterrupt:
sys.exit(1)
except IOError as e:
if e.errno == errno.EPIPE:
sys.exit(1)
else:
raise e
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment