Created
January 3, 2018 01:52
-
-
Save gianm/43d8bae20311877cd429f19676c9abb8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# Copyright 2017 Imply Data, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from __future__ import print_function | |
import argparse | |
import base64 | |
import collections | |
import csv | |
import errno | |
import json | |
import numbers | |
import re | |
import ssl | |
import sys | |
import time | |
import unicodedata | |
import urllib2 | |
class DruidSqlException(Exception):
    """Raised when a Druid SQL query fails; knows how to print itself in red."""

    def write_to(self, f):
        """Write this error's message to file-like *f* in red ANSI color and flush.

        Uses self.args rather than the deprecated Exception.message attribute
        (removed in Python 3, deprecated since Python 2.6).
        """
        message = self.args[0] if self.args else None
        f.write('\x1b[31m')  # ANSI: red foreground
        f.write(message if message else "Query failed")
        f.write('\x1b[0m')   # ANSI: reset
        f.write('\n')
        f.flush()
def do_query(url, sql, context, timeout, user, password, ignore_ssl_verification, ca_file, ca_path):
    """Generator: POST *sql* (with query *context*) to the Druid broker at *url*
    and yield result rows one at a time as OrderedDicts.

    The broker streams back a JSON array of objects; this incrementally parses
    that stream so rows are yielded before the full response has arrived.

    timeout <= 0 means no timeout. user/password enable HTTP basic auth.
    ignore_ssl_verification / ca_file / ca_path configure TLS verification.
    Raises DruidSqlException (via raise_friendly_error) on HTTP/URL errors.
    """
    # OrderedDict preserves the broker's column order in each yielded row.
    json_decoder = json.JSONDecoder(object_pairs_hook=collections.OrderedDict)
    try:
        sql_json = json.dumps({'query' : sql, 'context' : context})
        # SSL stuff
        ssl_context = None;
        if (ignore_ssl_verification or ca_file != None or ca_path != None):
            ssl_context = ssl.create_default_context()
            if (ignore_ssl_verification):
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
            else:
                ssl_context.load_verify_locations(cafile=ca_file, capath=ca_path)
        req = urllib2.Request(url, sql_json, {'Content-Type' : 'application/json'})
        if timeout <= 0:
            timeout = None
        if (user and password):
            basicAuthEncoding = base64.b64encode('%s:%s' % (user, password))
            req.add_header("Authorization", "Basic %s" % basicAuthEncoding)
        response = urllib2.urlopen(req, None, timeout, context=ssl_context)
        first_chunk = True
        eof = False
        buf = ''
        # Outer loop: alternate between draining complete JSON objects from
        # `buf` and reading more bytes from the HTTP response.
        while not eof or len(buf) > 0:
            while True:
                try:
                    # Remove starting ','
                    buf = buf.lstrip(',')
                    obj, sz = json_decoder.raw_decode(buf)
                    yield obj
                    buf = buf[sz:]
                except ValueError as e:
                    # Maybe invalid JSON, maybe partial object; it's hard to tell with this library.
                    if eof and buf.rstrip() == ']':
                        # Stream done and all objects read.
                        buf = ''
                        break
                    elif eof or len(buf) > 256 * 1024:
                        # If we read more than 256KB or if it's eof then report the parse error.
                        raise e
                    else:
                        # Stop reading objects, get more from the stream instead.
                        break
            # Read more from the http stream
            if not eof:
                chunk = response.read(8192)
                if chunk:
                    buf = buf + chunk
                    if first_chunk:
                        # Remove starting '['
                        # NOTE(review): first_chunk is never set to False, so
                        # every chunk gets a leading '[' stripped — this seems
                        # to rely on later chunks never starting with '['
                        # after the ',' lstrip above; confirm.
                        buf = buf.lstrip('[')
                else:
                    # Stream done. Keep reading objects out of buf though.
                    eof = True
    except urllib2.URLError as e:
        raise_friendly_error(e)
def raise_friendly_error(e):
    """Re-raise a urllib2 error as a DruidSqlException with a readable message.

    For HTTP 500 responses carrying a Druid-style JSON error body, builds
    "error: errorClass: errorMessage (host)". Other HTTP errors include the
    status code, reason, and raw body; non-HTTP URL errors use str(e).

    Always raises DruidSqlException; never returns.
    """
    if isinstance(e, urllib2.HTTPError):
        text = e.read().strip()
        error_obj = {}
        try:
            error_obj = dict(json.loads(text))
        except (ValueError, TypeError):
            # Body was not a JSON object; fall through to the generic message.
            pass
        if e.code == 500 and 'errorMessage' in error_obj:
            # Use .get() throughout: Druid error bodies do not always carry
            # the 'error', 'errorClass', and 'host' keys, and a missing key
            # must not turn into a KeyError while reporting another error.
            error_text = ''
            error = error_obj.get('error')
            if error and error != 'Unknown exception':
                error_text = error_text + str(error) + ': '
            if error_obj.get('errorClass'):
                error_text = error_text + str(error_obj['errorClass']) + ': '
            error_text = error_text + str(error_obj.get('errorMessage'))
            if error_obj.get('host'):
                error_text = error_text + ' (' + str(error_obj['host']) + ')'
            raise DruidSqlException(error_text)
        else:
            raise DruidSqlException("HTTP Error {0}: {1}\n{2}".format(e.code, e.reason, text))
    else:
        raise DruidSqlException(str(e))
def to_utf8(value):
    """Render *value* as a UTF-8 byte string; None becomes the empty string."""
    # None prints as an empty field.
    if value is None:
        return ""
    # Unicode text is encoded to UTF-8 bytes; anything else goes through str().
    return value.encode("utf-8") if isinstance(value, unicode) else str(value)
def to_tsv(values, delimiter):
    """Join *values* with *delimiter*, removing the delimiter from each field first."""
    fields = [to_utf8(v).replace(delimiter, '') for v in values]
    return delimiter.join(fields)
def print_csv(rows, header):
    """Write each row dict to stdout as CSV.

    If *header* is true, the first row's keys are emitted as a header line
    before any data rows.
    """
    writer = csv.writer(sys.stdout)
    emitted_header = False
    for row in rows:
        if header and not emitted_header:
            writer.writerow([to_utf8(k) for k in row.keys()])
        emitted_header = True
        writer.writerow([to_utf8(v) for k, v in row.iteritems()])
def print_tsv(rows, header, tsv_delimiter):
    """Write each row dict to stdout as delimiter-separated values.

    If *header* is true, the first row's keys are emitted as a header line
    before any data rows.
    """
    emitted_header = False
    for row in rows:
        if header and not emitted_header:
            print(to_tsv(row.keys(), tsv_delimiter))
        emitted_header = True
        print(to_tsv([v for k, v in row.iteritems()], tsv_delimiter))
def print_json(rows):
    """Print rows as newline-delimited JSON, one object per line."""
    for record in rows:
        print(json.dumps(record))
def table_to_printable_value(value):
    """Return *value* as a trimmed unicode string with ASCII control
    characters (code points 0-31) removed; None renders as u"NULL".
    """
    if value is None:
        return u"NULL"
    # Round-trip through UTF-8 to coerce any value to unicode, then strip
    # whitespace and drop control characters that would corrupt table layout.
    control_chars = dict.fromkeys(range(32))
    return to_utf8(value).strip().decode('utf-8').translate(control_chars)
def table_compute_string_width(v):
    """Return the number of terminal columns unicode string *v* occupies.

    East-Asian Fullwidth/Wide characters count as two columns; formatting
    controls (Cf) and combining marks (Mn/Me) count as zero; everything
    else counts as one. Input is NFC-normalized first so most accented
    characters collapse to a single precomposed code point.
    """
    normalized = unicodedata.normalize('NFC', v)
    width = 0
    for c in normalized:
        ccategory = unicodedata.category(c)
        cwidth = unicodedata.east_asian_width(c)
        if ccategory in ('Cf', 'Mn', 'Me'):
            # Zero width: formatting controls, and combining marks that did
            # not compose away under NFC (they overlay the previous glyph,
            # so counting them would misalign table columns).
            pass
        elif cwidth == 'F' or cwidth == 'W':
            # Double-wide character, prints in two columns
            width = width + 2
        else:
            # All other characters
            width = width + 1
    return width
def table_compute_column_widths(row_buffer):
    """Return per-column display widths (the max over all buffered rows),
    or None if *row_buffer* is empty.

    Assumes every row has the same number of columns.
    """
    widths = None
    for values in row_buffer:
        values_widths = [table_compute_string_width(v) for v in values]
        if widths is None:
            # Compare against None, not truthiness: a zero-column row would
            # yield [], which is falsy but is still a computed result.
            widths = values_widths
        else:
            widths = [max(w, vw) for w, vw in zip(widths, values_widths)]
    return widths
def table_print_row(values, column_widths, column_types):
    """Print one table row between vertical-bar separators.

    Columns typed 'n' (numeric) are right-aligned; all others are
    left-aligned. column_types may be None (e.g. for the header row).
    """
    vertical_line = u'\u2502'.encode('utf-8')
    for i, cell in enumerate(values):
        pad = ' ' * max(0, column_widths[i] - table_compute_string_width(cell))
        if column_types and column_types[i] == 'n':
            # Right-align numbers: padding goes before the value.
            print(vertical_line + ' ' + pad + cell.encode('utf-8') + ' ', end="")
        else:
            # Left-align text: padding goes after the value.
            print(vertical_line + ' ' + cell.encode('utf-8') + pad + ' ', end="")
    print(vertical_line)
def table_print_header(values, column_widths):
    """Print the table's top border, the header row, and the separator
    between the header and the body.
    """
    horizontal_line = u'\u2500'.encode('utf-8')

    def rule(left, tee, right):
        # One horizontal rule: left edge, a dashed run per column (width + 2
        # for the padding spaces), a tee between columns, then the right edge.
        print(left, end="")
        for i, w in enumerate(column_widths):
            print(horizontal_line * max(0, w + 2), end="")
            if i + 1 < len(column_widths):
                print(tee, end="")
        print(right)

    # Line 1: top border
    rule(u'\u250C'.encode('utf-8'), u'\u252C'.encode('utf-8'), u'\u2510'.encode('utf-8'))
    # Line 2: header cells (no numeric alignment for headers)
    table_print_row(values, column_widths, None)
    # Line 3: header/body separator
    rule(u'\u251C'.encode('utf-8'), u'\u253C'.encode('utf-8'), u'\u2524'.encode('utf-8'))
def table_print_bottom(column_widths): | |
left_corner = u'\u2514'.encode('utf-8') | |
right_corner = u'\u2518'.encode('utf-8') | |
bottom_tee = u'\u2534'.encode('utf-8') | |
horizontal_line = u'\u2500'.encode('utf-8') | |
print(left_corner, end="") | |
for i in xrange(0, len(column_widths)): | |
print(horizontal_line * max(0, column_widths[i] + 2), end="") | |
if i + 1 < len(column_widths): | |
print(bottom_tee, end="") | |
print(right_corner) | |
def table_print_row_buffer(row_buffer, column_widths, column_types):
    """Flush buffered rows to stdout: the first buffered row is the header,
    every subsequent row is a data row.
    """
    for index, values in enumerate(row_buffer):
        if index == 0:
            table_print_header(values, column_widths)
        else:
            table_print_row(values, column_widths, column_types)
def print_table(rows):
    """Print rows as an ASCII-art table with box-drawing borders, followed by
    a row count and elapsed time.

    Buffers up to 500 rows before printing so column widths can be computed
    from actual data; rows past the buffer are printed immediately using the
    widths frozen at flush time (so very long late values may overflow).
    """
    start = time.time()
    nrows = 0
    first = True
    # Buffer some rows before printing.
    rows_to_buffer = 500
    row_buffer = []
    column_types = []
    column_widths = None
    for row in rows:
        nrows = nrows + 1
        if first:
            # First row: queue the header (the row's keys) and classify each
            # column as numeric ('n', right-aligned) or string ('s').
            row_buffer.append([table_to_printable_value(k) for k in row.keys()])
            for k in row.keys():
                if isinstance(row[k], numbers.Number):
                    column_types.append('n')
                else:
                    column_types.append('s')
            first = False
        values = [table_to_printable_value(v) for k, v in row.iteritems()]
        if rows_to_buffer > 0:
            row_buffer.append(values)
            rows_to_buffer = rows_to_buffer - 1
        else:
            if row_buffer:
                # Buffer just filled: compute widths once and flush it.
                column_widths = table_compute_column_widths(row_buffer)
                table_print_row_buffer(row_buffer, column_widths, column_types)
                del row_buffer[:]
            table_print_row(values, column_widths, column_types)
    if row_buffer:
        # Fewer rows than the buffer size: flush whatever accumulated.
        column_widths = table_compute_column_widths(row_buffer)
        table_print_row_buffer(row_buffer, column_widths, column_types)
    if column_widths:
        table_print_bottom(column_widths)
    print("Retrieved {0:,d} row{1:s} in {2:.2f}s.".format(nrows, 's' if nrows != 1 else '', time.time() - start))
    print("")
def display_query(url, sql, context, args):
    """Execute *sql* against the broker at *url* and print the streamed
    results in the output format selected by args.format.
    """
    rows = do_query(url, sql, context, args.timeout, args.user, args.password, args.ignore_ssl_verification, args.cafile, args.capath)
    # Dispatch on the requested output format; an unknown format is a no-op,
    # matching the original if/elif chain (argparse restricts the choices).
    printers = {
        'csv': lambda: print_csv(rows, args.header),
        'tsv': lambda: print_tsv(rows, args.header, args.tsv_delimiter),
        'json': lambda: print_json(rows),
        'table': lambda: print_table(rows),
    }
    printer = printers.get(args.format)
    if printer is not None:
        printer()
def sql_escape(s):
    """Escape *s* as a SQL Unicode string literal (U&'...').

    Letters, digits, and spaces pass through unescaped; every other
    character is escaped as \\XXXX (4 hex digits) or, for code points
    beyond the Basic Multilingual Plane, \\+XXXXXX (6 hex digits) per the
    SQL standard's Unicode escape syntax. None escapes to ''.
    """
    if s is None:
        return "''"
    elif isinstance(s, unicode):
        ustr = s
    else:
        ustr = str(s).decode('utf-8')
    escaped = [u"U&'"]
    for c in ustr:
        ccategory = unicodedata.category(c)
        if ccategory.startswith('L') or ccategory.startswith('N') or c == ' ':
            escaped.append(c)
        elif ord(c) > 0xFFFF:
            # Astral code points need the 6-digit \+XXXXXX form: '%04x'
            # would emit 5+ hex digits, which SQL would misread as a
            # 4-digit escape followed by literal characters.
            escaped.append(u'\\+')
            escaped.append('%06x' % ord(c))
        else:
            escaped.append(u'\\')
            escaped.append('%04x' % ord(c))
    escaped.append("'")
    return ''.join(escaped)
def main():
    """Entry point: parse command-line arguments, then either execute a single
    query (--execute) or run an interactive read-eval-print loop that accepts
    semicolon-terminated SQL and psql-style backslash commands (\\d, \\h, \\q).
    """
    parser = argparse.ArgumentParser(description='Druid SQL command-line client.')
    parser.add_argument('--host', '-H', type=str, default='http://localhost:8082/', help='Broker host or url')
    parser.add_argument('--timeout', type=int, default=0, help='Timeout in seconds, 0 for no timeout')
    parser.add_argument('--format', type=str, default='table', choices=('csv', 'tsv', 'json', 'table'), help='Result format')
    parser.add_argument('--header', action='store_true', help='Include header row for formats "csv" and "tsv"')
    parser.add_argument('--tsv-delimiter', type=str, default='\t', help='Delimiter for format "tsv"')
    parser.add_argument('--context-option', '-c', type=str, action='append', help='Set context option for this connection')
    parser.add_argument('--execute', '-e', type=str, help='Execute single SQL query')
    parser.add_argument('--user', '-u', type=str, help='Username for HTTP basic auth')
    parser.add_argument('--password', '-p', type=str, help='Password for HTTP basic auth')
    parser.add_argument('--ignore-ssl-verification', '-k', action='store_true', default=False, help='Skip verification of SSL certificates.')
    parser.add_argument('--cafile', type=str, help='Path to SSL CA file for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    parser.add_argument('--capath', type=str, help='SSL CA path for validating server certificates. See load_verify_locations() in https://docs.python.org/2/library/ssl.html#ssl.SSLContext.')
    args = parser.parse_args()
    # Build broker URL
    url = args.host.rstrip('/') + '/druid/v2/sql/'
    if not url.startswith('http:') and not url.startswith('https:'):
        url = 'http://' + url
    # Build context
    context = {}
    if args.context_option:
        for opt in args.context_option:
            kv = opt.split("=", 1)
            if len(kv) != 2:
                raise ValueError('Invalid context option, should be key=value: ' + opt)
            # All-digit values are sent as integers, everything else as strings.
            if re.match(r"^\d+$", kv[1]):
                context[kv[0]] = long(kv[1])
            else:
                context[kv[0]] = kv[1]
    if args.execute:
        display_query(url, args.execute, context, args)
    else:
        # interactive mode
        print("Welcome to dsql, the command-line client for Druid SQL.")
        print("Type \"\h\" for help.")
        while True:
            # Accumulate input lines until a statement ends with ';'.
            sql = ''
            while not sql.endswith(';'):
                prompt = "dsql> " if sql == '' else 'more> '
                try:
                    more_sql = raw_input(prompt)
                except EOFError:
                    # Ctrl-D at the prompt exits the client.
                    sys.stdout.write('\n')
                    sys.exit(1)
                if sql == '' and more_sql.startswith('\\'):
                    # backslash command
                    # \d / \dS [table]: list tables, or describe one table by
                    # rewriting the command into an INFORMATION_SCHEMA query.
                    dmatch = re.match(r'^\\d(S?)(\+?)(\s+.*?|)\s*$', more_sql)
                    if dmatch:
                        include_system = dmatch.group(1)
                        extra_info = dmatch.group(2)
                        arg = dmatch.group(3).strip()
                        if arg:
                            sql = "SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = " + sql_escape(arg)
                            if not include_system:
                                sql = sql + " AND TABLE_SCHEMA = 'druid'"
                            # break to execute sql
                            break
                        else:
                            sql = "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES";
                            if not include_system:
                                sql = sql + " WHERE TABLE_SCHEMA = 'druid'"
                            # break to execute sql
                            break
                    hmatch = re.match(r'^\\h\s*$', more_sql)
                    if hmatch:
                        print("Commands:")
                        print(" \d show tables")
                        print(" \dS show tables, including system tables")
                        print(" \d table_name describe table")
                        print(" \h show this help")
                        print(" \q exit this program")
                        print("Or enter a SQL query ending with a semicolon (;).")
                        continue
                    qmatch = re.match(r'^\\q\s*$', more_sql)
                    if qmatch:
                        sys.exit(0)
                    print("No such command: " + more_sql)
                else:
                    sql = (sql + ' ' + more_sql).strip()
            try:
                # Strip the trailing ';' — the Druid SQL endpoint rejects it.
                display_query(url, sql.rstrip(';'), context, args)
            except DruidSqlException as e:
                e.write_to(sys.stdout)
            except KeyboardInterrupt:
                # Ctrl-C aborts the running query but keeps the REPL alive.
                sys.stdout.write("Query interrupted\n")
                sys.stdout.flush()
# Run only when executed as a script, not on import (allows importing this
# module, e.g. for testing, without starting the REPL).
if __name__ == '__main__':
    try:
        main()
    except DruidSqlException as e:
        e.write_to(sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        sys.exit(1)
    except IOError as e:
        if e.errno == errno.EPIPE:
            # Downstream pipe closed (e.g. output piped into `head`);
            # exit quietly instead of printing a traceback.
            sys.exit(1)
        else:
            # Bare `raise` preserves the original traceback (unlike `raise e`).
            raise
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment