Last active
July 21, 2021 01:57
-
-
Save hellovertex/b42be700300084a273d8507428e9ae7e to your computer and use it in GitHub Desktop.
tgcxz
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
# acronyms will be changed to lowercase for get-method calls in order to conform with pep-style | |
indicators = [
    'ABANDS',  # Acceleration Bands
    'AD',  # Accumulation/Distribution
    'ADX',  # Average Directional Movement
    'AMA',  # Adaptive Moving Average
    'APO',  # Absolute Price Oscillator
    'AR',  # Aroon
    'ARO',  # Aroon Oscillator
    'ATR',  # Average True Range
    'AVOL',  # Volume on the Ask
    'BAVOL',  # Volume on the Bid and Ask
    'BBANDS',  # Bollinger Bands
    'BVA',  # Bar Value Area
    'BVOL',  # Bid Volume
    'BW',  # Band Width
    'CCI',  # Commodity Channel Index
    'CMO',  # Chande Momentum Oscillator
    'DEMA',  # Double Exponential Moving Average
    'DMI',  # Directional Movement Indicators
    'EMA',  # Exponential Moving Average
    'FILL',  # Fill Indicator
    'ICH',  # Ichimoku
    'KC',  # Keltner Channel
    'LR',  # Linear Regression
    'LRA',  # Linear Regression Angle
    'LRI',  # Linear Regression Intercept
    'LRM',  # Linear Regression Slope
    'MACD',  # Moving Average Convergence Divergence
    'MAX',  # Max
    'MFI',  # Money Flow Index
    'MIDPNT',  # Midpoint (fix: missing comma silently concatenated to 'MIDPNTMIDPRI')
    'MIDPRI',  # Midprice
    'MIN',  # Min
    'MINMAX',  # MinMax
    'MOM',  # Momentum
    'NATR',  # Normalized True Average
    'OBV',  # On Balance Volume
    'PC',  # Price Channel
    'PLT',  # PLOT
    'PPO',  # Percent Price Oscillator
    'PVT',  # Price Volume Trend
    'ROC',  # Rate of Change
    'ROC100',  # Rate of Change 100
    'ROCP',  # Rate of Change Percentage
    'ROCR',  # Rate of Change Rate
    'RSI',  # Relative Strength Index
    'S_VOL',  # Session Volume
    'SAR',  # Parabolic Stop and Reverse
    'SAVOL',  # Session Cumulative Ask
    'SBVOL',  # Session Cumulative Bid
    'SMA',  # Simple Moving Average
    'STDDEV',  # Standard Deviation
    'STOCH',  # Stochastic
    'StochF',  # Stochastic Fast
    'T3',  # T3
    'TEMA',  # Triple Exponential Moving Average
    'TRIMA',  # Triangular Moving Average
    'TRIX',  # Triple Exponential Moving Average Oscillator
    'TSF',  # Time Series Forecast
    'TTCVD',  # TT Cumulative Vol Delta
    'ULTOSC',  # Ultimate Oscillator
    'VAP',  # Volume at Price
    'VOLUME',  # Volume
    'VolDel',  # Volume Delta
    'VWAP',  # Volume Weighted Average Price
    'WillR',  # Williams % R
    'WMA',  # Weighted Moving Average
    'WWS',  # Welles Wilders Smoothing Average
]
# Default parameter sets for the indicator getters below.
params_macd = dict(fast_length=12, slow_length=26, smoothing=9)
params_rsi = dict(length=14, source='closep')
params_sma = dict(length=9, source='closep')
params_stoch = dict(k=14, d=3, smooth=3)
params_willr = dict(length=14)
params_bbands = dict(length=20, source='closep', stddev=2.0, offset=0)
def get_rsi(ohlcv_data: pd.DataFrame, length: int = 14, source: str = 'closep') -> pd.DataFrame:
    """
    Computes Relative Strength Indices, c.f. https://www.tradingview.com/wiki/Relative_Strength_Index_(RSI).
    Uses an SMA seed followed by Wilder smoothing of the average gains/losses.
    Requires a DataFrame with a key equal to {source} and keys 'timestamp' and 'date'. Using a MultiIndexed df,
    will cause a KeyError. Assumes the dataframe to be sorted by timestamp, with the highest index
    corresponding to the latest entry.
    :param ohlcv_data: pandas Dataframe that needs to have keys 'timestamp' and 'source'. The latter storing the prices.
    :param length: window size for RSI calculation
    :param source: Determines which price is used for computation. Default is close price.
    :return: Dataframe with Index(['timestamp', 'date', f'RSI({length}, {source})']
    """
    # temporary accumulators: per-bar gains/losses and their smoothed averages
    up = []
    down = []
    timestamps = ohlcv_data['timestamp']
    dts = ohlcv_data['date']
    avg_up = []
    avg_down = []
    rs = []
    rsis = []
    # start at 1: the first bar has no previous price to diff against
    for i in range(1, len(ohlcv_data)):
        # calculate current price change; NOTE: integer-label lookup assumes a
        # default RangeIndex on ohlcv_data
        curr_price = float(ohlcv_data[source][i])
        prev_price = float(ohlcv_data[source][i - 1])
        d = curr_price - prev_price
        # split the change into gain (up) and loss (down) components; exactly
        # one of the two gets the magnitude, the other gets 0
        if d >= 0:
            up.append(d)
            down.append(0)
        else:
            down.append(-d)
            up.append(0)
        # append moving average
        if i == length:
            # First use simple moving average to seed the smoothed averages
            avg_up.append(float(sum(up) / length))
            avg_down.append(float(sum(down) / length))
        if i > length:
            # Then use previous moving average (Wilder smoothing)
            # set indices: `prev` into avg_up/avg_down (offset by length+1),
            # `curr` into up/down (offset by 1, since the loop starts at i=1)
            prev = i - length - 1
            curr = i - 1
            mvg_avg_up = (avg_up[prev] * (length - 1) + up[curr]) / length
            mvg_avg_down = (avg_down[prev] * (length - 1) + down[curr]) / length
            avg_up.append(mvg_avg_up)
            avg_down.append(mvg_avg_down)
            # calculate RSI
            if mvg_avg_down == 0:
                # avoid division by zero; yields an RSI close to (not exactly) 100
                mvg_avg_down = 0.001
            r = mvg_avg_up / mvg_avg_down
            rsi = 100 - (100 / (r + 1))
            rs.append(r)
            rsis.append(rsi)
    # to avoid the rsi to be set-off, we must cut the first length entries for which no rsi is available, due to how
    # it is calculated (the first RSI is produced at bar index length+1)
    return pd.DataFrame(list(zip(timestamps[length+1:], dts[length+1:], rsis)),
                        columns=['timestamp', 'date', f'RSI({length}, {source})'])
def get_sma(ohlcv_data: pd.DataFrame, length: int = 9, source: str = 'closep', offset: int = 0):
    """
    Computes Simple Moving Averages:
    c.f. https://www.tradingview.com/wiki/Moving_Average
    Requires a DataFrame with a key equal to {source} and keys 'timestamp' and 'date'. Using a MultiIndexed df,
    will cause a KeyError. Assumes the dataframe to be sorted by timestamp, with the highest index
    corresponding to the latest entry.

    Note: the SMA column is added to the input DataFrame in place.

    :param ohlcv_data: pandas Dataframe that needs to have keys 'timestamp', 'date' and {source}.
    :param length: time window for moving average. Default is 9
    :param source: Determines which price is used for computation. Default is close price.
    :param offset: Changing this number will move the Moving Average relative to the current market. 0 is the default.
    :return: Dataframe with Index(['timestamp', 'date', f'SMA({length}, {source})'])
    """
    col = f'SMA({length}, {source})'
    # Fix: `offset` was previously accepted but ignored; shifting the rolling
    # mean by `offset` rows displaces the MA relative to the current bar
    # (shift(0) is the identity, so default behavior is unchanged).
    ohlcv_data[col] = ohlcv_data.rolling(window=length)[source].mean().shift(offset)
    # drop the first `length` rows, for which no full window exists
    return ohlcv_data[length:][['timestamp', 'date', col]]
def get_willr(ohlcv_data: pd.DataFrame, length: int=14, high='highp', low='lowp', close='closep') -> pd.DataFrame:
    """
    Computes the Williams %R oscillator:
    c.f. https://www.tradingview.com/wiki/Williams_%25R_(%25R)#DEFINITION
    %R measures where the close sits inside the highest-high / lowest-low range
    of the last {length} bars, scaled to [-100, 0]. The indicator column is
    added to the input DataFrame in place. Assumes the DataFrame is sorted by
    timestamp (highest index = latest entry) and carries 'timestamp' and 'date'
    columns; a MultiIndexed df will cause a KeyError.
    :param ohlcv_data: pandas Dataframe that needs to have key 'timestamp'
    :param length: look-back period for highest highs and lowest lows
    :param high: Name of the column containing high price
    :param low: Name of the column containing low price
    :param close: Name of the column containing close price
    :return: Dataframe with Index(['date, timestamp', f'WILLIAMS%R({length})']
    """
    highest_high = ohlcv_data.rolling(window=length)[high].max()
    lowest_low = ohlcv_data.rolling(window=length)[low].min()
    willr_col = f'WILLIAMS%R({length})'
    ohlcv_data[willr_col] = (highest_high - ohlcv_data[close]) / (highest_high - lowest_low) * (-100)
    # the first `length` rows have no complete window, so trim them off
    return ohlcv_data[length:][['date', 'timestamp', willr_col]]
def get_ema(ohlcv_data: pd.DataFrame, length: int=9, source: str='closep', offset: int=0) -> pd.DataFrame:
    """
    Computes Exponential Moving Averages:
    c.f. https://www.tradingview.com/wiki/Moving_Average#Exponential_Moving_Average_.28EMA.29
    The series is seeded with the simple moving average of the first {length}
    prices; each later value blends the current price into the previous EMA
    using the standard multiplier 2 / (length + 1). Assumes the DataFrame has a
    default RangeIndex, is sorted by timestamp (highest index = latest entry)
    and carries 'timestamp' and 'date' columns next to the {source} column.
    :param ohlcv_data: pandas Dataframe that needs to have keys 'timestamp' and 'source'. The latter storing the prices.
    :param length: time window for moving average
    :param source: Determines which price is used for computation. Default is close price.
    :param offset: Changing this number will move the Moving Average relative to the current market. 0 is the default.
    :return: Dataframe with Index(['timestamp', 'date', f'EMA({length})']
    """
    # smoothing multiplier
    weight = 2.0 / (length + 1)
    # seed with the SMA of the first `length` prices
    ema_values = [ohlcv_data[source][:length].mean()]
    # fold every subsequent price into the running EMA
    for price in ohlcv_data[source][length:]:
        previous = ema_values[-1]
        ema_values.append((price - previous) * weight + previous)
    ts = ohlcv_data['timestamp'][length - 1 + offset:]
    dates = ohlcv_data['date'][length - 1 + offset:]
    return pd.DataFrame(list(zip(ts, dates, ema_values)),
                        columns=['timestamp', 'date', f'EMA({length})'])
def get_macd(ohlcv_data: pd.DataFrame, fast_length: int=12, slow_length=26, source: str='closep', smoothing: int=9) \
        -> pd.DataFrame:
    """
    Computes Moving Average Convergence Divergence.
    C.f. https://www.tradingview.com/wiki/MACD_(Moving_Average_Convergence/Divergence)
    MACD = fast EMA - slow EMA; Signal = EMA of the MACD line; histogram = MACD - Signal.
    Requires a DataFrame with a key equal to {source} and keys 'timestamp' and 'date'. Using a MultiIndexed df,
    will cause a KeyError. Assumes the dataframe to be sorted by timestamp, with the highest index
    corresponding to the latest entry.
    :param ohlcv_data: pandas Dataframe that needs to have keys 'timestamp' and 'source'. The latter storing the prices.
    :param fast_length: time window for fast moving average
    :param slow_length: time window for slow moving average
    :param source: Determines which price is used for computation. Default is close price.
    :param smoothing: The time period for the EMA of the MACD Line (Signal Line). 9 Days is the default.
    :return: Dataframe with
        Index(
            [
                'timestamp',
                'date',
                f'MACD({fast_length}, {slow_length})',
                f'Signal({smoothing})',
                'MACD_Histogram'
            ]
        )
    """
    # calculate Moving Averages
    fast_EMA = get_ema(ohlcv_data, length=fast_length, source=source)
    slow_EMA = get_ema(ohlcv_data, length=slow_length, source=source)
    # inner-merge on the shared 'timestamp'/'date' columns so rows align even
    # though the two EMA frames start at different offsets
    tmp = pd.merge(ohlcv_data, fast_EMA)
    merged = pd.merge(tmp, slow_EMA)
    # calculate MACD
    idx_fast = f'EMA({fast_length})'  # see return interface of get_ema(...)
    idx_slow = f'EMA({slow_length})'  # see return interface of get_ema(...)
    idx_MACD = f'MACD({fast_length}, {slow_length})'
    merged[idx_MACD] = merged[idx_fast] - merged[idx_slow]
    # calculate Signal: an EMA of the MACD column itself
    signal_raw = get_ema(merged, length=smoothing, source=f'MACD({fast_length}, {slow_length})')
    # NOTE(review): index=str stringifies the index — presumably incidental,
    # since the merge below joins on columns; confirm before removing.
    signal = signal_raw.rename(index=str, columns={f'EMA({smoothing})': f'Signal({smoothing})'})
    result = pd.merge(merged, signal)
    # calculate MACD-histogram
    idx_MACD = f'MACD({fast_length}, {slow_length})'
    idx_signal = f'Signal({smoothing})'
    result['MACD_Histogram'] = result[idx_MACD] - result[idx_signal]
    # remove unnecessary columns
    cols = [
        'timestamp',
        'date',
        f'MACD({fast_length}, {slow_length})',
        f'Signal({smoothing})',
        'MACD_Histogram'
    ]
    return result[cols]
def get_stoch(ohlcv_data: pd.DataFrame, k: int=14, d: int=3, smooth: int=3, high='highp', low='lowp',
              close='closep') -> pd.DataFrame:
    """
    Computes the Stochastic Oscillator:
    c.f. https://www.tradingview.com/wiki/Stochastic_(STOCH)#INPUTS
    Raw %K measures where the close sits inside the recent high-low range; it
    is smoothed into STOCH_K, whose moving average gives the STOCH_D signal
    line. The '%K', STOCH_K and STOCH_D columns are added to the input
    DataFrame in place. Assumes the DataFrame is sorted by timestamp (highest
    index = latest entry) and carries 'timestamp' and 'date' columns.
    :param ohlcv_data: pandas Dataframe that needs to have key 'timestamp'
    :param k: look-back window for the highest high / lowest low
    :param d: window of the moving average applied to STOCH_K (signal line)
    :param smooth: window of the moving average applied to the raw %K
    :param high: Name of the column containing high price
    :param low: Name of the column containing low price
    :param close: Name of the column containing close price
    :return: Dataframe with Index(['timestamp', 'date', f'STOCH_K={k}', f'STOCH_D={d}']
    """
    highest_high = ohlcv_data.rolling(window=k)[high].max()
    lowest_low = ohlcv_data.rolling(window=k)[low].min()
    # fast stochastics: position of the close within the recent range
    ohlcv_data['%K'] = (ohlcv_data[close] - lowest_low) / (highest_high - lowest_low) * 100
    k_col = f'STOCH_K={k}'
    d_col = f'STOCH_D={d}'
    ohlcv_data[k_col] = ohlcv_data.rolling(window=smooth)['%K'].mean()
    # slow stochastics
    ohlcv_data[d_col] = ohlcv_data.rolling(window=d)[k_col].mean()
    # trim the warm-up rows that have no complete windows
    return ohlcv_data[k + d:][['timestamp', 'date', k_col, d_col]]
def get_bbands(ohlcv_data: pd.DataFrame, length: int=20, source: str='closep', stddev: float=2.0, offset: int=0) -> \
        pd.DataFrame:
    """
    Computes Bollinger Bands:
    c.f. https://www.tradingview.com/wiki/Bollinger_Bands_(BB)#INPUTS
    Requires a DataFrame with a key equal to {source} and keys 'timestamp' and 'date'. Using a MultiIndexed df,
    will cause a KeyError. Assumes the dataframe to be sorted by timestamp, with the highest index
    corresponding to the latest entry.

    Note: the band columns (and a helper 'std_dev' column) are added to the
    input DataFrame in place.

    :param ohlcv_data: pandas Dataframe that needs to have keys 'timestamp' and 'source'. The latter storing the prices.
    :param length: time window for moving average
    :param source: Determines which price is used for computation. Default is close price.
    :param stddev: distance to the middle band. Default is 2.
    :param offset: Changing this number will move the Moving Average relative to the current market. 0 is the default.
    :return: Dataframe with Index(['timestamp', 'date', middleband, upperband, lowerband])
    """
    mid_col = f'MIDDLE_BBAND({length}, {source}, {stddev})'
    upper_col = f'UPPER_BBAND({length}, {source}, {stddev})'
    lower_col = f'LOWER_BBAND({length}, {source}, {stddev})'
    # middle band: simple moving average over the look-back period
    # Fix: the rolling window was hard-coded to 20, silently ignoring `length`.
    ohlcv_data[mid_col] = ohlcv_data.rolling(window=length)[source].mean()
    # band half-width: stddev multiples of the rolling standard deviation
    ohlcv_data['std_dev'] = stddev * ohlcv_data.rolling(window=length)[source].std()
    # upper/lower bands sit symmetrically around the middle band
    ohlcv_data[upper_col] = ohlcv_data[mid_col] + ohlcv_data['std_dev']
    ohlcv_data[lower_col] = ohlcv_data[mid_col] - ohlcv_data['std_dev']
    return ohlcv_data[length + offset:][['timestamp', 'date', mid_col, upper_col, lower_col]]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import string | |
import logging | |
import argparse | |
import time | |
from collections import Counter | |
import sys | |
from telegram.client import Telegram | |
import csv | |
from datetime import datetime | |
import numpy as np | |
import os | |
def timestamp_to_date(ts: int) -> str:
    """Convert a Unix timestamp (seconds) into a 'YYYY-MM-DD HH:MM:SS' UTC string."""
    # datetime.utcfromtimestamp is deprecated since Python 3.12; the
    # explicitly UTC-aware conversion below produces the identical string.
    from datetime import timezone
    return datetime.fromtimestamp(ts, tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
def setup_logging(level=logging.INFO):
    """Configure the root logger to emit formatted records to stdout.

    The root logger is set to `level`; the attached stream handler itself
    accepts everything down to DEBUG.
    """
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s'))
    root_logger = logging.getLogger()
    root_logger.setLevel(level)
    root_logger.addHandler(handler)
if __name__ == '__main__':
    setup_logging(level=logging.INFO)
    # CLI: Telegram API credentials (created at https://my.telegram.org/apps)
    parser = argparse.ArgumentParser()
    parser.add_argument('api_id', help='API id', type=int)  # https://my.telegram.org/apps
    parser.add_argument('api_hash', help='API hash')
    parser.add_argument('phone', help='Phone')
    parser.add_argument('--lib_path', help='path to libtdjson.so.1.6.0', type=str,
                        default='/home/hellovertex/Documents/github.com/tdlib/td/tdlib/lib/libtdjson.so.1.6.0')
    args = parser.parse_args()
    print(f'api_id = {args.api_id}, type={type(args.api_id)}')
    print(f'api_hash = {args.api_hash}, type={type(args.api_hash)}')
    print(f'phone = {args.phone}, type={type(args.phone)}')
    # tdlib-backed Telegram client; database_encryption_key protects the local tdlib db
    tg = Telegram(
        api_id=args.api_id,
        api_hash=args.api_hash,
        phone=args.phone,
        library_path=args.lib_path,
        database_encryption_key='hanabi',
    )
    # you must call login method before others
    tg.login()

    def get_all_chat_ids():
        """Page through tg.get_chats until no new ids arrive; return all chat ids."""
        # start the paging cursor at the maximum 64-bit order (2**63 - 1),
        # i.e. from the newest chat
        offset_order = 9223372036854775807
        offset_chat_id = 0
        chat_ids = list()
        old_len = -1
        while True:
            # get chat_ids from current offset_order, offset_chat_id
            res = tg.get_chats(offset_order=offset_order, offset_chat_id=offset_chat_id, limit=100)
            res.wait()
            chat_ids += res.update['chat_ids']
            # remove duplicates from chat_ids (dict.fromkeys preserves order)
            chat_ids = list(dict.fromkeys(chat_ids))
            # get id, order from last chat to advance the paging cursor
            res = tg.get_chat(chat_ids[-1])
            res.wait()
            last_chat = res.update
            # update offset_order, offset_chat_id
            offset_chat_id = last_chat['id']
            offset_order = last_chat['order']
            new_len = len(chat_ids)
            print(f'old_len = {old_len}, new_len = {new_len}')
            if old_len == new_len:
                # a full page added no new ids -> every chat has been seen
                print(f'Fetched {len(chat_ids)} chats.')
                break
            old_len = new_len
        return chat_ids

    def extract_supergroup_chats(chat_ids):
        """Return the chat objects that are supergroups and not broadcast channels.

        chat keys:
        dict_keys(
            ['@type', 'id', 'type', 'chat_list', 'title', 'photo', 'permissions', 'last_message', 'order',
             'is_pinned', 'is_marked_as_unread', 'is_sponsored', 'has_scheduled_messages',
             'can_be_deleted_only_for_self', 'can_be_deleted_for_all_users', 'can_be_reported',
             'default_disable_notification', 'unread_count', 'last_read_inbox_message_id',
             'last_read_outbox_message_id', 'unread_mention_count', 'notification_settings',
             'pinned_message_id', 'reply_markup_message_id', 'client_data', '@extra'])
        """
        crypto_chat_candidates = list()
        # NOTE: `id` shadows the builtin of the same name inside this loop
        for id in chat_ids:
            res = tg.get_chat(id)
            res.wait()
            chat = res.update
            if chat['type']['@type'] == 'chatTypeSupergroup':
                if not chat['type']['is_channel']:
                    crypto_chat_candidates.append(chat)
        return crypto_chat_candidates

    def get_chat_history(chat_obj):
        """Fetch the full message history of a qualifying supergroup, else None.

        supergroup keys:
        dict_keys(
            ['@type', 'description', 'member_count', 'administrator_count', 'restricted_count', 'banned_count',
             'linked_chat_id', 'slow_mode_delay', 'slow_mode_delay_expires_in', 'can_get_members',
             'can_set_username', 'can_set_sticker_set', 'can_set_location', 'can_view_statistics',
             'is_all_history_available', 'sticker_set_id', 'invite_link', 'upgraded_from_basic_group_id',
             'upgraded_from_max_message_id', '@extra'])
        """
        def _get_history(id):
            # Pages backwards through the chat history, 100 messages at a time.
            from_message_id = 0  # 0 = start from the most recent message
            offset = 0  # NOTE(review): unused; the call below passes the literal offset=0
            messages = list()
            completed = False
            old_ids = []
            # loop
            while not completed:
                res = tg.get_chat_history(id, limit=100, from_message_id=from_message_id, offset=0)
                res.wait()
                msgs = res.update['messages']
                if not msgs:
                    break
                first = msgs[0]
                last = msgs[-1]
                # continue the next page from the oldest message seen so far
                from_message_id = last['id']
                print(f'number of messages in buffer = {len(msgs)}')
                print(f'date 1st: {timestamp_to_date(first["date"])} -- date last = {timestamp_to_date(last["date"])}')
                ids = np.array([msg['id'] for msg in msgs])
                # NOTE(review): elementwise comparison against the previous page
                # assumes equal-length pages; differing lengths broadcast to a
                # scalar False (or warn/raise depending on numpy version) — confirm
                have_duplicates = bool(int(np.sum(ids == old_ids)))
                assert not have_duplicates
                old_ids = ids
                # todo consider writing directly to file instead
                messages += msgs
                print(f'total number of messages fetched = {len(messages)}')
                time.sleep(.5)  # presumably throttling for Telegram rate limits — confirm
            return messages
        # get supergroup id from chat_obj
        supergroup_id = chat_obj['type']['supergroup_id']
        is_channel = chat_obj['type']['is_channel']
        # get supergroup info by supergroup id
        res = tg.get_supergroup_full_info(supergroup_id)
        res.wait()
        group = res.update
        hist = None
        # only scrape large, non-channel groups with full history enabled
        if group['is_all_history_available'] and group['member_count'] >= 500 and not is_channel:
            # get chat_history
            hist = _get_history(chat_obj['id'])
        else:
            # NOTE(review): the two reason strings look swapped relative to the
            # condition they describe — confirm intended wording
            skip_reason = f'group member count is too low: {group["member_count"]}' if group[
                'is_all_history_available'] else f'history is not available or group is channel'
            print('...Skipping chat because ' + skip_reason)
        return hist

    chat_ids = get_all_chat_ids()
    crypto_chats = extract_supergroup_chats(chat_ids)
    for chat in crypto_chats:
        # If no csv file for current history exists yet
        if not os.path.isfile(f'{chat["title"]}.csv'):
            print(f'getting history for chat: {chat["title"]}...')
            hist = get_chat_history(chat)
            if hist:
                # blank nested fields of the first row — presumably because
                # csv.DictWriter cannot serialize nested dicts; NOTE(review):
                # only row 0 is blanked, later rows may still carry dicts
                hist[0]['forward_info'] = ''
                hist[0]['reply_markup'] = ''
                keys = hist[0].keys()
                # Create csv file for chat history
                with open(f'{chat["title"]}.csv', 'w') as output_file:
                    print(f'Writing history of {chat["title"]} to .csv file...')
                    dict_writer = csv.DictWriter(output_file, keys)
                    dict_writer.writeheader()
                    dict_writer.writerows(hist)
            else:
                print(f'...History for chat {chat["title"]} is None')
        else:
            print(f'Skipping group {chat["title"]} because .csv file exists already')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
from collections import namedtuple | |
import sqlite3 | |
import pandas as pd | |
import numpy as np | |
import enum | |
from src import database | |
from src.indicators import get_rsi as rsi, get_willr as willr, get_macd as macd, get_stoch as stoch | |
from typing import Dict, List, Tuple, Optional # , TypedDict | |
from src import env_wrapper | |
# Type aliases for readability of signatures below.
STR_SYMBOL = str  # a market symbol string, e.g. 'BTC/USDT'
DF_MARKET_DATA = pd.DataFrame  # refers to only-OHLCV data
DF_X10DMARKET_DATA = pd.DataFrame  # refers to OHLCV-data + Indicators
# Immutable record of a single executed trade.
Trade = namedtuple('Trade', ['t_step',  # environment time step at which the trade happened
                             'datestring',  # date string of the trade
                             'symbol',  # market symbol string
                             'symbol_idx',  # index of the instrument within the state
                             'type',  # ActionType (BUY or SELL)
                             'buy_amt',  # absolute amount bought (BUY trades)
                             'sell_percent',  # fraction of the balance sold (SELL trades)
                             'price',  # execution price
                             'net_value',  # cash value of the trade
                             'info'])  # extra information (unused in visible code)
class ActionType(enum.IntEnum):
    """Category of a trading action decoded from the flat integer action space."""
    NOOP = 0
    BUY = 1
    SELL = 2


class ActionParser(object):
    """Decodes a flat integer action into (ActionType, target instrument, amount).

    Layout of the flat action space:
      [0, n_buy_actions)      -> BUY actions, instrument-major
      [n_buy_actions, noop)   -> SELL actions, instrument-major
      noop                    -> the single trailing NOOP action

    Implicit encoding example (2 instruments, buy amounts [5, 10],
    sell percentages [25, 50, 100]):
    action  Instr in [A,B]  Buys = 2  Sells = 3  Action % 2  Action % 3  floordiv2  floordiv3
    0       A               5                    0                       0
    1       A               10                   1
    2       B               5                    0                       1
    3       B               10                   1
    4       A                         25                     0                      0
    5       A                         50                     1
    6       A                         100                    2
    7       B                         25                     0                      1
    8       B                         50                     1
    9       B                         100                    2
    10      NOOP
    """
    # type aliases used by the method annotations below
    INSTRUMENT = Optional[int]  # instrument index, or None for NOOP
    AMOUNT = int  # absolute buy amount, or sell percentage

    def __init__(self, n_instruments_tradeable,
                 n_possible_buy_amts,
                 n_possible_sell_amts,
                 buy_amts,
                 sell_amts_percent):
        """
        :param n_instruments_tradeable: number of instruments an action can target
        :param n_possible_buy_amts: number of distinct buy amounts
        :param n_possible_sell_amts: number of distinct sell percentages
        :param buy_amts: absolute buy amounts, indexed by buy-amount index
        :param sell_amts_percent: sell percentages, indexed by sell-amount index
        """
        self.n_buy_actions = n_instruments_tradeable * n_possible_buy_amts
        self.n_sell_actions = n_instruments_tradeable * n_possible_sell_amts
        # the single NOOP action sits at the very end of the action space
        self.noop = self.n_buy_actions + self.n_sell_actions
        self._max_actions = self.n_buy_actions + self.n_sell_actions + 1
        # prior
        self.n_instruments_tradeable = n_instruments_tradeable
        self.n_possible_buy_amts = n_possible_buy_amts
        self.n_possible_sell_amts = n_possible_sell_amts
        self.buy_amts = buy_amts
        self.sell_amts_percents = sell_amts_percent

    @property
    def max_actions(self):
        """Total size of the action space (all buys + all sells + NOOP)."""
        return self._max_actions

    def _get_action_type(self, action: int) -> ActionType:
        """Classify a flat action index as BUY, SELL or NOOP."""
        # one buy action per instrument and buy amount
        if 0 <= action < self.n_buy_actions:
            return ActionType.BUY
        elif self.n_buy_actions <= action < self.noop:
            return ActionType.SELL
        elif action == self.noop:
            return ActionType.NOOP
        else:
            raise ValueError(f'Action {action} is outside the action space [0, {self.noop}].')

    def _get_target_instr(self, action: int, action_type: int) -> INSTRUMENT:
        """Return the instrument index targeted by a BUY or SELL action."""
        if action_type == ActionType.BUY:
            return action // self.n_possible_buy_amts
        elif action_type == ActionType.SELL:
            return (action - self.n_buy_actions) // self.n_possible_sell_amts
        else:
            raise ValueError(f'Cannot resolve a target instrument for action type {action_type}.')

    def _get_target_amount(self, action: int, a_type: int) -> AMOUNT:
        """Return the buy amount or sell percentage encoded in the action."""
        if a_type == ActionType.BUY:
            return self.buy_amts[action % self.n_possible_buy_amts]
        elif a_type == ActionType.SELL:
            return self.sell_amts_percents[(action - self.n_buy_actions) % self.n_possible_sell_amts]
        else:
            raise ValueError(f'Cannot resolve a target amount for action type {a_type}.')

    def parse_int_action(self, action: int) -> Tuple[ActionType, INSTRUMENT, AMOUNT]:
        """Decode a flat integer action into (action type, instrument index, amount).

        For NOOP the instrument is None and the amount is 0. See the class
        docstring for the encoding layout.
        """
        assert 0 <= action <= self.noop, f'Expected action between 0 and {self.noop}. Got {action} instead.'
        a_type = self._get_action_type(action)
        if a_type == ActionType.NOOP:
            return ActionType.NOOP, None, 0
        else:
            target_instr = self._get_target_instr(action, a_type)
            target_amt = self._get_target_amount(action, a_type)
            return a_type, target_instr, target_amt
class CreditInfo(object):
    """Constants and index layout used by Credit's metadata bookkeeping."""
    INITIAL_CREDIT = 1000  # starting cash at the beginning of an episode
    INITIAL_TIMESTEP = 0  # starting time step
    # INDICES FOR METADATA (negative indices into Credit.metadata / Credit.data)
    REMAINING_CREDIT = -2  # position of the remaining-cash entry
    CURRENT_TIMESTEP = -1  # position of the current-time-step entry
    DEFAULT_BUY_AMT = 50  # default absolute buy amount (not referenced in visible code)
class Instrument(object):
    """ Container for trade history """

    def __init__(self, str_value):
        # symbol placeholder this instrument represents, e.g. 'symbol_3'
        self.str_value = str_value
        self._trade_history = []  # executed Trade tuples, append-only
        self._balance = 0  # units of the instrument currently held
        self._net_value = 0  # balance valued at the latest trade price
        self._net_gain = 0  # cumulative realized cash flow

    @property
    def trade_history(self):
        """List of executed Trade tuples."""
        return self._trade_history

    @property
    def balance(self):
        """Units of the instrument currently held."""
        return self._balance

    @balance.setter
    def balance(self, value):
        self._balance = value

    @property
    def net_value(self):
        """Current balance valued at the latest trade price."""
        return self._net_value

    @net_value.setter
    def net_value(self, value):
        self._net_value = value

    @property
    def net_gain(self):
        """Cumulative realized cash flow of this instrument."""
        return self._net_gain

    @net_gain.setter
    def net_gain(self, value):
        self._net_gain = value
class Credit(object):
    """Episode bookkeeping: cash, per-instrument holdings and trade-based rewards."""
    def __init__(self, symbols_per_state, taker_fee):
        # number of tradeable instruments represented in each observation
        self.symbols_per_state = symbols_per_state
        # NOTE(review): `columns` has symbols_per_state + 2 entries while the
        # `data` property below emits 3 * symbols_per_state + 2 values — confirm
        # which consumer relies on which layout.
        self.columns = [f'symbol_{i}' for i in range(symbols_per_state)] \
                       + ['remaining_credit'] + ['current_timestep']
        # instruments are mapped from shuffled symbols for variance reduction
        self.instruments = [Instrument(f'symbol_{i}') for i in range(symbols_per_state)]
        # metadata layout: [..., REMAINING_CREDIT (-2), CURRENT_TIMESTEP (-1)]
        self.metadata = [CreditInfo.INITIAL_CREDIT, CreditInfo.INITIAL_TIMESTEP]
        self.fee = taker_fee  # = 998/1000 = 0.2%
    def reset(self):
        # Restore the initial episode state: fresh instruments, initial credit.
        self.instruments = [Instrument(f'symbol_{i}') for i in range(self.symbols_per_state)]
        self.metadata = [CreditInfo.INITIAL_CREDIT, CreditInfo.INITIAL_TIMESTEP]
    @property
    def data(self):
        # Flat numeric episode state: balances, net values, net gains, metadata.
        return [pair.balance for pair in self.instruments] + \
               [pair.net_value for pair in self.instruments] + \
               [pair.net_gain for pair in self.instruments] + self.metadata
    def has_cash(self, amount):
        # True if the remaining credit covers `amount`.
        return self.metadata[CreditInfo.REMAINING_CREDIT] >= amount
    def is_available(self, target_symbol: int):
        # True if we hold a positive balance of the given instrument.
        return self.instruments[target_symbol].balance > 0
    def apply_trade_and_compute_reward(self, trade: Trade) -> float:
        """ Computes reward using simple gain per balance metric.
        More specifically, the reward is returned as the difference of the
        gain per balance before and after the action taken, if the action was a SELL action and 0 otherwise. """
        assert trade.type in [ActionType.BUY, ActionType.SELL]
        target = self.instruments[trade.symbol_idx]
        target.trade_history.append(trade)  # we dont make use of the history yet
        self.metadata[CreditInfo.CURRENT_TIMESTEP] = trade.t_step
        profit = 0
        if trade.type == ActionType.BUY:
            # cash shrinks by the trade's net value; holdings grow by the
            # bought amount scaled down by the taker fee
            self.metadata[CreditInfo.REMAINING_CREDIT] -= trade.net_value
            # (50$ / 11000$) * fee
            target.balance += trade.buy_amt * self.fee  # = 998/1000 = 0.2%
            target.net_gain -= trade.net_value
            target.net_value = target.balance * trade.price
        elif trade.type == ActionType.SELL:
            # for now, use simple gain per balance metric
            old_value = target.net_gain / target.balance
            sell_amt = trade.sell_percent * target.balance
            target.balance -= sell_amt
            gain = (trade.price * sell_amt) * self.fee  # = 998/1000 = 0.2%
            self.metadata[CreditInfo.REMAINING_CREDIT] += gain
            target.net_gain += gain
            target.net_value = target.balance * trade.price
            if trade.sell_percent != 1.:
                new_value = target.net_gain / target.balance
                # gain_per_balance_old - gain_per_balance_new [These are negative values]
                profit = old_value - new_value
            else:
                # when everything is sold, the gain_per_balance is equal to the net_gain
                profit = target.net_gain
        else:
            raise ValueError
        # print(trade)
        # print(f'pairs = {[(p.balance, p.net_gain, p.net_value) for p in self.pairs]}')
        # print(f'pairshistory = {[p.trade_history for p in self.pairs]}')
        # print(f'PROFIIIIIT = {profit}')
        return profit
class DefaultPreprocessor(object): | |
    def __init__(self, market_data: Dict[STR_SYMBOL, DF_MARKET_DATA],
                 trajectory_length: int, resolution: str, symbols: List[str]):
        """
        Prepare raw OHLCV tables for trajectory sampling by the environment.

        Args:
            market_data: dict mapping symbol strings (e.g. 'BTC/USDT') to OHLCV
                DataFrames with columns
                ['timestamp', 'date', 'openp', 'highp', 'lowp', 'closep', 'volume'].
            trajectory_length: number of observations with the given resolution
                per trajectory.
            resolution: candle resolution, one of
                ['1m','3m','5m','15m','30m','1h','2h','3h','4h','6h','8h','12h',
                 '1d','3d','1w','2w','1M'].
            symbols: tradeable symbol strings.
        """
        self.trajectory_length = trajectory_length
        self.resolution = resolution
        self.symbols = symbols
        # extend market_data with indicators and scale the volume using min-max scaling
        market_data_x10d = self.extend_tables_with_indicator_data(market_data)
        market_data_x10d_scaled = self.min_max_scale_volume_columns(market_data_x10d)
        # load preprocessed ohlcv-tables into memory,
        # step() and reset() calls from environment sample trajectories from these tables
        self.market_data = self.drop_nas(market_data_x10d_scaled)  # can be very large, depending on db
        assert self.market_data is not None
        # filled later: maps integer instrument indices to symbol strings
        self.int_instrument_to_str_key_symbol = dict()
@staticmethod | |
def extend_tables_with_indicator_data(market_data: Dict[STR_SYMBOL, DF_MARKET_DATA]) \ | |
-> Dict[STR_SYMBOL, DF_X10DMARKET_DATA]: | |
""" | |
Args: | |
market_data: see database.load_tables() docstring | |
dictionary mapping p symbols to pd.DataFrame with Index | |
Index(['timestamp', 'date', 'openp', 'highp', 'lowp', 'closep', 'volume',dtype='object') | |
Returns: | |
dictionary mapping p symbols to pd.DataFrame with Index | |
Index(['timestamp', 'date', 'openp', 'highp', 'lowp', 'closep', 'volume', | |
# EXTENDED WITH THE FOLLOWING INDICATOR DATA: | |
'RSI(14, closep)', 'MACD(12, 26)', 'Signal(9)', 'MACD_Histogram', | |
'WILLIAMS%R(14)', '%K', 'STOCH_K=14', 'STOCH_D=3'],dtype='object') | |
see src.indicators for computation of the new columns | |
{'BTC/USDT': timestamp date ... STOCH_K=14 STOCH_D=3 | |
0 1514899800000 2018-01-02 14:30:00 ... 87.876188 85.171782 | |
1 1514901600000 2018-01-02 15:00:00 ... 88.487763 87.249437 | |
2 1514903400000 2018-01-02 15:30:00 ... 79.756318 85.373423 | |
3 1514905200000 2018-01-02 16:00:00 ... 63.003012 77.082365 | |
4 1514907000000 2018-01-02 16:30:00 ... 57.811091 66.856807 | |
.. ... ... ... ... ... | |
416 1515654000000 2018-01-11 08:00:00 ... 55.866712 51.026282 | |
417 1515655800000 2018-01-11 08:30:00 ... 59.529662 54.726684 | |
418 1515657600000 2018-01-11 09:00:00 ... 67.167808 60.854727 | |
419 1515659400000 2018-01-11 09:30:00 ... 70.589146 65.762205 | |
420 1515661200000 2018-01-11 10:00:00 ... 73.594852 70.450602 | |
[421 rows x 15 columns], | |
'ADA/USDT': timestamp date openp ... %K STOCH_K=14 STOCH_D=3 | |
3 1524085200000 2018-04-18 23:00:00 0.26 ... 100.0 100.000000 100.000000 | |
4 1524087000000 2018-04-18 23:30:00 0.26 ... 100.0 100.000000 100.000000 | |
5 1524088800000 2018-04-19 00:00:00 0.26 ... 100.0 100.000000 100.000000 | |
6 1524090600000 2018-04-19 00:30:00 0.26 ... 100.0 100.000000 100.000000 | |
7 1524092400000 2018-04-19 01:00:00 0.26 ... 100.0 100.000000 100.000000 | |
.. ... ... ... ... ... ... ... | |
416 1524828600000 2018-04-27 13:30:00 0.29 ... 100.0 100.000000 100.000000 | |
417 1524830400000 2018-04-27 14:00:00 0.29 ... 100.0 100.000000 100.000000 | |
418 1524832200000 2018-04-27 14:30:00 0.30 ... 50.0 83.333333 94.444444 | |
419 1524834000000 2018-04-27 15:00:00 0.30 ... 50.0 66.666667 83.333333 | |
420 1524835800000 2018-04-27 15:30:00 0.30 ... 50.0 50.000000 66.666667 | |
[400 rows x 15 columns] | |
} | |
""" | |
# assume indicators = ['rsi', 'macd', 'willr', 'stoch'] fixed for now | |
for sym, df_ohlcv in market_data.items(): | |
df_ohlcv = pd.merge(df_ohlcv, rsi(df_ohlcv)) | |
df_ohlcv = pd.merge(df_ohlcv, macd(df_ohlcv)) | |
df_ohlcv = pd.merge(df_ohlcv, willr(df_ohlcv)) | |
df_ohlcv = pd.merge(df_ohlcv, stoch(df_ohlcv)) | |
market_data[sym] = df_ohlcv.dropna() | |
return market_data | |
@staticmethod | |
def min_max_scale_volume_columns(market_data: Dict[STR_SYMBOL, DF_MARKET_DATA], volume: str = 'volume') \ | |
-> Dict[STR_SYMBOL, DF_MARKET_DATA]: | |
""" Performs min-max scaling of volume column for all tables in market_data. | |
C.f https://en.wikipedia.org/wiki/Feature_scaling """ | |
for _, table in market_data.items(): | |
v = table[volume] | |
table[volume] = ((v - v.min()) / (v.max() - v.min())) | |
return market_data | |
@staticmethod | |
def drop_indices(market_data: Dict[STR_SYMBOL, DF_MARKET_DATA]): | |
for _, table in market_data.items(): | |
table.reset_index(drop=True, inplace=True) | |
@staticmethod | |
def drop_nas(market_data: Dict[STR_SYMBOL, pd.DataFrame]): | |
for symbol, table in market_data.items(): | |
market_data[symbol] = table.dropna() | |
return market_data | |
class _DefaultPreprocessorSDMI(DefaultPreprocessor):
    """SingleData/MultipleInstruments preprocessor.

    Per trajectory it samples `trajectory_length` consecutive rows for a
    random subset of instruments and concatenates the sampled tables
    column-wise into one observation matrix (one row per timestep).
    """

    def __init__(self, market_data: Dict[STR_SYMBOL, DF_MARKET_DATA],
                 trajectory_length: int, resolution: str, symbols):
        super().__init__(market_data, trajectory_length, resolution, symbols)

    def precompute_trajectory_states(self) -> Tuple[Dict[STR_SYMBOL, DF_MARKET_DATA], np.ndarray, Dict]:
        """Sample one complete trajectory.

        Returns:
            dict_states: per-symbol DataFrame slice of length trajectory_length
            np_states: column-wise concatenation of the standardized slices
            mapping: integer instrument index -> symbol string
        """
        dict_states, mapping = self._sample_states_for_current_trajectory()
        np_states = self._standardize_and_concat(dict_states)
        return dict_states, np_states, mapping

    def _sample_states_for_current_trajectory(self) -> Tuple[Dict[STR_SYMBOL, DF_MARKET_DATA], Dict]:
        """ Randomly pick trajectory index and sample from database table using index as row """
        states = dict()
        # NOTE(review): the number of sampled instruments is hard-coded to 2;
        # it should arguably come from the environment's 'symbols_per_state'
        # config entry — confirm before changing.
        rand_tables = np.random.choice(len(self.symbols), 2, replace=False)
        for i, table_index in enumerate(rand_tables):  # e.g. enumerate([3, 8])
            # table_index selects one entry of self.symbols, e.g. 3 -> 'BNB/USDT'
            symbol = self.symbols[table_index]
            max_row = self.market_data[symbol].shape[0]
            # pick last point of trajectory from which we normalize backwards,
            # for online evaluation
            traj_end = np.random.randint(self.trajectory_length, max_row)
            first = traj_end - self.trajectory_length
            states[symbol] = self.market_data[symbol][first:traj_end].reset_index(drop=True)
            # update integer-instrument -> string-symbol mapping
            self.int_instrument_to_str_key_symbol[i] = symbol
        return states, self.int_instrument_to_str_key_symbol

    def _standardize_and_concat(self, market_data: Dict[STR_SYMBOL, DF_MARKET_DATA]) -> np.ndarray:
        """Standardize each symbol's prices and merge all symbols column-wise.

        Prices are divided by the symbol's last open price (online-friendly,
        normalizes backwards from the trajectory end); 'timestamp' and 'date'
        are dropped; each remaining column is suffixed with its symbol.

        Returns:
            np.ndarray of shape (trajectory_length, n_symbols * n_feature_columns),
            i.e. the .values of the merged DataFrame.
        """
        # BUGFIX: the original bound `states` to the pd.DataFrame *class*
        # (missing parentheses). It only worked because the class-level
        # `empty` property object happens to be truthy; bind an instance.
        states = pd.DataFrame()
        to_drop = ['timestamp', 'date']
        n_symbols = n_cols = 0
        for symbol, table in market_data.items():
            assert table.shape[0] == self.trajectory_length
            # NaNs were already dropped in __init__
            state = table.drop(to_drop, axis=1)
            n_cols = len(state.columns)
            n_symbols += 1
            # standardize prices relative to the last open price
            refp = table['openp'].iloc[-1]
            state['openp'] /= refp
            state['highp'] /= refp
            state['lowp'] /= refp
            state['closep'] /= refp
            # suffix columns with the symbol so the merged frame stays unique
            state.columns = [str(col) + f'_{symbol}' for col in state.columns]
            # drop index for right join without nans
            state.reset_index(drop=True, inplace=True)
            # join multiple symbols' data into one state observation
            if states.empty:
                states = state
            else:
                states = pd.concat([states, state], axis=1)
        assert len(states.columns) == n_cols * n_symbols
        return states.values
class _DefaultPreprocessorMDMI(DefaultPreprocessor):
    # Preprocessor for MultipleData/MultipleInstruments observations: each
    # observation is meant to span `n_timeframes` consecutive rows per
    # instrument. NOTE(review): this class looks unfinished — see the inline
    # notes below before using it.
    def __init__(self, market_data, trajectory_length, resolution, n_timeframes):
        # NOTE(review): the SDMI sibling calls DefaultPreprocessor.__init__
        # with four arguments (market_data, trajectory_length, resolution,
        # symbols); `symbols` is missing here, so this call presumably raises
        # TypeError — confirm against DefaultPreprocessor's signature.
        super().__init__(market_data, trajectory_length, resolution)
        self.n_timeframes = n_timeframes  # timeframes bundled into one observation
        """
        resolution
        n_timeframes
        trajectory length
        """
    def precompute_trajectory_states(self):
        # NOTE(review): work in progress — `_sample_market_data` returns a
        # (states, mapping) tuple but is consumed as a dict by
        # `_standardize_samples`, and the method discards all intermediate
        # results and returns None.
        data_samples = self._sample_market_data()
        data_samples_std = self._standardize_samples(data_samples)
        dict_states = self._make_states(data_samples_std)  # rolling merge
        # np_states = self._clean_numpy(dict_states)  # drop cols and merge dict to df and return values
        return None
    def _sample_market_data(self):
        """ Randomly pick trajectory index and sample from database table using index as first row """
        # trajectory ranges from start to trajectory_length + n_timeframes
        # each observation will contain n_timeframes
        # but normalization will occur over all (trajectory_length + n_timeframes) data
        # not as in SDMI over each state
        states = dict()
        # rand_tables = np.random.randint(0, len(self.symbols), size=self.n_instruments_tradeable)
        # NOTE(review): number of sampled instruments is hard-coded to 2 here,
        # like in the SDMI preprocessor.
        rand_tables = np.random.choice(len(self.symbols), 2, replace=False)
        for i, table_index in enumerate(rand_tables):  # e,g, enumerate([3,8])
            symbol = self.symbols[table_index]  # e.g. table_index = 3 => symbol = 'BNB/USDT'
            max_row = self.market_data[symbol].shape[0]
            # pick last point of trajectory from which we normalize backwards, for online evaluation
            # e.g. trajectory_length = 48
            # n_timeframes = 12
            # min_row = 60 because for each of the 48 states you go 12 times back
            # so at state s_0 you will go back the last 12 timeframes
            backward_offset = self.trajectory_length + self.n_timeframes
            traj_end = np.random.randint(backward_offset, max_row)
            first = traj_end - backward_offset
            states_symbol_i = self.market_data[symbol][first:traj_end].reset_index(drop=True)
            # assert states_symbol_i is not None
            states[symbol] = states_symbol_i
            # update integer mapping to string symbol
            self.int_instrument_to_str_key_symbol[i] = symbol
        return states, self.int_instrument_to_str_key_symbol
    def _standardize_samples(self, data_samples):
        # Divide each sampled table's price columns by its last open price.
        # NOTE(review): placeholder return — this returns self.n_timeframes
        # instead of the standardized samples; the mutation happens in place,
        # but precompute_trajectory_states feeds this return value onwards.
        for symbol, table in data_samples.items():
            # standardize prices
            refp = table['openp'].iloc[-1]
            table['openp'] /= refp
            table['highp'] /= refp
            table['lowp'] /= refp
            table['closep'] /= refp
        return self.n_timeframes
    def _make_states(self, dict_samples):
        # now we have sampled 60 rows per symbol and want to make rolling
        # NOTE(review): placeholder — the rolling merge is not implemented yet.
        return self.n_timeframes
class SDMIEnvironment(object):
    """
    SD: SingleData refers to the number timeframes per observation.
    MI: MultipleInstruments refers to the number p of instruments per observation.
    This p is equal to the number of instruments that can be traded via the action space.
    """
    def __init__(self, config, market_data=None, preprocessor=_DefaultPreprocessorSDMI):
        r"""
        Creates a SingleDataMultipleInstruments environment with the given configuration.

        Args:
            config: dict with keys
                'databases': {'symbols': [...], 'db_files': [...], 'paths': [...]}
                'resolution': resolution of the data in each state, e.g. '30m'
                'trajectory_length': states per trajectory, e.g. 96
                'symbols_per_state': instruments observed (and tradeable) per timestep
                'buy_amts': absolute buy amounts, e.g. [50, 200]
                'sell_amts_percent': relative sell amounts, e.g. [1.]
                'taker_fee': e.g. 998.0 / 1000.0 for 0.2 percent
                'n_possible_buy_amts': len(config['buy_amts'])
                'n_possible_sell_amts': len(config['sell_amts_percent'])
            market_data: see database.load_tables() docstring — mapping of each
                symbol in config['databases']['symbols'] to its OHLCV DataFrame
                (columns: timestamp, date, openp, highp, lowp, closep, volume).
                Loaded from the configured database files when None.
            preprocessor: class used to turn market_data into trajectory states.
        """
        assert isinstance(config, dict), "Expected config to be of type dict."
        # config
        self.config = config
        self.paths = config['databases']['paths']
        self.symbols = config['databases']['symbols']
        self.resolution = config['resolution']
        self.trajectory_length = config['trajectory_length']
        self.n_instruments_tradeable = config['symbols_per_state']
        # data pipeline
        if market_data is None:
            market_data = database.load_tables(self.symbols, self.paths, self.resolution)
        # computes normalization and feature scaling
        self.preprocessor = preprocessor(market_data, self.trajectory_length, self.resolution, self.symbols)
        # stores state information to compute balances and credit info for tradeable instruments
        self.credit = Credit(self.n_instruments_tradeable, config['taker_fee'])  # maybe swap with account
        # market data observations
        self.dict_states = dict()  # Dict[STR_SYMBOL, DF_MARKET_DATA]
        self.np_states = None
        # used to parse integer actions
        # NOTE(review): reads config['n_possible_sell_amts'] although the
        # documented config example uses the key 'n_possible_sell_amts_percent'
        # — confirm which key callers actually provide.
        self.action_spec = ActionParser(self.n_instruments_tradeable,
                                        config['n_possible_buy_amts'],
                                        config['n_possible_sell_amts'],
                                        config['buy_amts'],
                                        config['sell_amts_percent'])
        self._max_actions = self.action_spec.max_actions
        self.dict_instr_to_symbol = dict()  # Dict[int, STR_SYMBOL], e.g. {0: 'BNB/USDT', 1: 'XLM/USDT'}
        # utils
        self.t_step = 0
        self.t_traj = 0
        # NOTE(review): observation shape is hard-coded; the commented call
        # computes it from a stub trajectory plus credit data — confirm (34,)
        # still matches the current feature count.
        self._np_obs_shape = (34,)  # self._get_np_obs_shape()
    def _get_np_obs_shape(self):
        # Derive the flat observation shape from one stub trajectory state
        # appended with the credit data vector.
        _, stub_states, _ = self.preprocessor.precompute_trajectory_states()
        # print(stub_states.shape, len(self.credit.data))
        stub_np_obs = np.append(stub_states[0], self.credit.data)
        # print(stub_np_obs.shape)
        del stub_states
        return stub_np_obs.shape
    @property
    def max_actions(self):
        # Size of the discrete action space, as computed by the ActionParser.
        return self._max_actions
    @property
    def np_obs_shape(self):
        # Shape of the flat numpy observation vector (market data + credit data).
        return self._np_obs_shape
    def _make_observation(self) -> Dict:
        """Build the observation dict {'state': np.float32 vector, 'mask': legal-action mask}
        for the current timestep."""
        # market data + credit data for current timestep
        obs = np.append(self.np_states[self.t_step], self.credit.data).astype(np.float32)
        assert obs.shape == self.np_obs_shape, f'expected shape={self.np_obs_shape}, got {obs.shape}, ' \
                                               f'dict_states={self.dict_states}, np_states = {self.np_states}, t_step = {self.t_step}'
        # legal moves mask
        legal_actions_as_int = self._legal_actions_as_int()
        # observation as returned by environment reset() and step() calls
        # observation = {'state': obs, 'mask': legal_actions_as_int, 'info': None}
        # NOTE(review): self.last_mask is (re)defined here rather than in
        # __init__; it is only read by step()'s assert message after the first
        # observation has been made.
        self.last_mask = legal_actions_as_int
        observation = {'state': obs, 'mask': legal_actions_as_int}
        return observation
    def _legal_actions_as_int(self):
        """ Mask is boolean, not logits """
        # Illegal actions carry float32 min (added to logits downstream),
        # legal actions carry 0.
        legal_moves_as_int = list()
        mask = np.full(self.max_actions, np.finfo(np.float32).min)
        # mask = np.full(self.max_actions, 0)
        for action in range(self.max_actions):
            action_type, target_instr, target_amount = self.action_spec.parse_int_action(action)
            if self._action_is_legal(action_type, target_instr, target_amount):
                legal_moves_as_int.append(action)
        mask[legal_moves_as_int] = 0
        # mask[legal_moves_as_int] = 1
        return mask
    def reset(self, config=None):
        """ Starts MC Rollout from OHLC-DB """
        # reset credit and states
        self.credit.reset()
        # reset step and increment traj counter
        self.t_step = 0
        self.t_traj += 1
        # generate all states in the trajectory a priori [we assume BUY,SELL dont effect the future states]
        dict_states, np_states, mapping = self.preprocessor.precompute_trajectory_states()
        # meta information of market_data for creating Trade objects
        self.dict_states = dict_states
        # numpy values of market_data as in observation
        self.np_states = np_states
        # mapping from random integer samples to symbols in self.symbols
        self.dict_instr_to_symbol = mapping
        # observation = market_data + credit_information
        observation = self._make_observation()
        self.t_step += 1
        return observation
    def _action_is_legal(self, action_type, target_symbol, target_amt):
        """Return True when the parsed action can be executed with the current credit state."""
        # buy only when credit available
        if action_type == ActionType.BUY:
            return self.credit.has_cash(target_amt)
        if action_type == ActionType.SELL:
            return self.credit.is_available(target_symbol)
        return True
    def _build_trade(self, action: Dict[str, int], refp: str = 'closep') -> Trade:
        """Translate an action dict {'type', 'target', 'amount'} into a Trade
        priced at the current timestep's *refp* column."""
        # n_instruments_per_observation = self.symbols_per_state
        # n_instruments_tradeable = n_instruments_per_observation
        # will bound action range
        # {0: 'BNB/USDT', 1: 'XLM/USDT'} -> 'XLM/USDT'
        assert action['type'] != ActionType.NOOP
        symbol = self.dict_instr_to_symbol[action['target']]
        market_data = self.dict_states[symbol].iloc[self.t_step, :]  # pd.Series for current timestep
        # BUY: amount is cash, convert to units at the reference price;
        # SELL: amount is a percentage of the held position
        buy_amt = None if action['type'] == ActionType.SELL else action['amount'] / market_data[refp]
        sell_percent = None if action['type'] == ActionType.BUY else action['amount']
        net_val = None if action['type'] == ActionType.SELL else action['amount']
        return Trade(t_step=self.t_step,
                     datestring=market_data['date'],
                     symbol=symbol,
                     symbol_idx=action['target'],
                     type=action['type'],
                     buy_amt=buy_amt,
                     sell_percent=sell_percent,
                     price=market_data[refp],
                     net_value=net_val,
                     info=market_data)
    def step(self, action):
        """Apply the integer *action*, advance one timestep and return
        (observation, reward, done, info)."""
        # reset at end of trajectory
        # NOTE(review): this early return hands back reset()'s observation
        # dict only — not the usual 4-tuple — so callers must handle both
        # return shapes. Confirm whether this is intended.
        if self.t_step > self.trajectory_length - 1:
            return self.reset(self.config)
        # parse integer action
        action_type, target_instr, target_amount = self.action_spec.parse_int_action(action)
        assert self._action_is_legal(action_type, target_instr, target_amount), f'action_type={action_type} \n ' \
                                                                                f'target_instr={target_instr} \n' \
                                                                                f'target_amount={target_amount} \n' \
                                                                                f'action={action}\n' \
                                                                                f'credit.data = {self.credit.data} \n' \
                                                                                f'last_mask = {self.last_mask} \n' \
                                                                                f'self.t_step = {self.t_step}'
        reward = 0
        # update credit balances and time
        if action_type != ActionType.NOOP:
            action_dict = {'type': action_type, 'target': target_instr, 'amount': target_amount}
            trade = self._build_trade(action_dict)
            reward = self.credit.apply_trade_and_compute_reward(trade=trade)
        # observation is gotten from self.states regardless of action taken
        observation = self._make_observation()
        self.t_step += 1
        done = False
        info = None
        if self.t_step == self.trajectory_length:
            done = True
        # print(f''
        #       f'action = {action}, \n '
        #       f'action_type ={action_type}, \n'
        #       f'target_instr = {target_instr}, \n'
        #       f'target_amount = {target_amount}, \n'
        #       f'observation = {observation} ')
        return observation, reward, done, info
    def close(self):
        """ Deletes self reference so that this instance will be garbage collected """
        # NOTE(review): `del self` only unbinds the local name inside this
        # method; it does not delete the instance or force garbage collection.
        del self
        return
class MDMIEnvironment(object):
    """
    MD: MultipleData — each observation spans several timeframes, which makes
    it possible to feed e.g. a whole day of 30-minute candles to an RNN and
    look for cross-correlation patterns.
    MI: MultipleInstruments — each observation covers p instruments, and the
    same p instruments are tradeable through the action space.
    """
    def __init__(self):
        # Not implemented yet.
        raise NotImplementedError
--- End of first gist file (trading environment / preprocessors). The remainder is a second file: SQLite database helpers. ---
import sqlite3 | |
from pathlib import Path | |
import src.utils as utils | |
from typing import Optional | |
from sqlite3 import Connection | |
from typing import List, Dict | |
import pandas as pd | |
# Type aliases used in signatures throughout this module.
STR_SYMBOL = str  # a market symbol string, e.g. 'BTC/USDT'
DF_MARKET_DATA = pd.DataFrame  # OHLCV table: timestamp, date, openp, highp, lowp, closep, volume
def load_tables(symbols: List[str], paths: List[str], resolution: str, max_len=None) \
        -> Dict[STR_SYMBOL, DF_MARKET_DATA]:
    """Load one OHLCV table per symbol from its sqlite database file.

    Args:
        symbols: market symbols, e.g. ['BTC/USDT', 'ADA/USDT', ...]
        paths: database file paths, aligned index-wise with `symbols`
        resolution: candle-resolution suffix of the table name,
            e.g. '30m' reads from table 'ohlcv_30m'
        max_len: if given, keep only the first max_len rows per table
            ([:None] returns all rows)

    Returns:
        dictionary mapping each of the p symbols to a pd.DataFrame with Index
        (['timestamp', 'date', 'openp', 'highp', 'lowp', 'closep', 'volume']),
        where p = len(symbols); rows containing NaNs are dropped.

    Required database table layout:
        (timestamp INT NOT NULL PRIMARY KEY,
        date TEXT, openp REAL, highp REAL, lowp REAL, closep REAL, volume REAL)
    """
    tables = dict()
    for sym, path in zip(symbols, paths):
        # BUGFIX: close every connection — the original leaked one open
        # sqlite connection per symbol.
        conn = sqlite3.connect(path)
        try:
            df_ohlcv = pd.read_sql_query(f"SELECT * FROM {'ohlcv_' + resolution}", conn)[:max_len]
        finally:
            conn.close()
        tables[sym] = df_ohlcv.dropna()
    return tables
""" BEFORE REWORK """ | |
def create_database(path, db_filename):
    """
    Creates a .db file inside *path* and returns a connection to it.
    :param path: directory that should contain the database file
    :param db_filename: e.g. "BTC_USD.db"
    :return: sqlite3 Connection via create_connection(), or None on error
    """
    db_path = Path(path) / db_filename
    return create_connection(db_path)
def create_connection(path):
    """ create a database connection to the SQLite database
        specified by path
    :param path: database file
    :return: Connection object, or None if the connection failed
    """
    try:
        return sqlite3.connect(str(path))
    except sqlite3.Error as e:
        print(e)
        return None
def create_table_trades(conn):
    """Create the 'trades' table on *conn*; sqlite errors are printed, not raised."""
    query = 'CREATE TABLE trades(id INT NOT NULL PRIMARY KEY, ' \
            'timestamp INT, ' \
            'datetime TEXT, ' \
            'symbol TEXT, ' \
            'side TEXT,' \
            'm BLOB,' \
            'price REAL,' \
            'amount REAL,' \
            'cost REAL)'
    try:
        conn.cursor().execute(query)
        # conn.commit()
    except sqlite3.Error as e:
        print("Error in function {}(...): {}".format(create_table_trades.__name__, e))
def create_ohlcv_table_from_timeframe(conn, str_timeframe):
    """ create a table corresponding to the given timeframe
    :param conn: Connection object
    :param str_timeframe: used to make CREATE TABLE statement
    :return: None; sqlite errors are printed, not raised
    """
    # create_table_sql = queries.create_ohlcv_table(str_timeframe)
    create_table_sql = _sql_ohlcv_table_from_timeframe(str_timeframe)
    try:
        conn.cursor().execute(create_table_sql)
        # conn.commit()
    except sqlite3.Error as e:
        print("Error in function {}(...): {}".format(create_ohlcv_table_from_timeframe.__name__, e))
def create_indicator_table_from_timeframe(conn, str_indicator, str_timeframe):
    """ create a table corresponding to the given timeframe
    :param conn: Connection object
    :param str_indicator: First part of table name
    :param str_timeframe: Second part of table name
    :return: Creates a table in the db given by conn
    """
    # NOTE(review): `str_indicator` is never used — the statement below is
    # hard-wired to the RSI table builder, so every indicator gets an
    # RSI-shaped table. Confirm whether a per-indicator SQL builder was
    # intended here.
    create_table_sql = _sql_rsi_table_from_timeframe(str_timeframe)
    try:
        c = conn.cursor()
        c.execute(create_table_sql)
    except sqlite3.Error as e:
        print("Error in function {}(...): {}".format(create_indicator_table_from_timeframe.__name__, e))
def remove_table(conn, table):
    """Drop *table* from the database; sqlite errors are printed, not raised."""
    sql = "DROP TABLE {}".format(table)
    try:
        conn.cursor().executescript(sql)
        conn.commit()
    except sqlite3.Error as e:
        print("Error in function {}(...): {}".format(remove_table.__name__, e))
def get_tables(conn):
    """Return the names of all tables in the database, or None when conn is None."""
    if conn is None:
        return None
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    return [row[0] for row in cursor.fetchall()]
# returns true on success, false otherwise
def insert_data_ohlcv(conn, ohlcv_table, to_insert, mode):
    """
    :param conn: database connection
    :param ohlcv_table: table to operate on
    :param to_insert: array of ohlcv-tuples (7 values each)
    :param mode: "ignore" skips rows whose key already exists,
                 "replace" overwrites them
    :return: True on success, False on sqlite error
    """
    cursor = conn.cursor()
    statements = {
        "ignore": "INSERT OR IGNORE INTO {} VALUES (?,?,?,?,?,?,?)",
        "replace": "INSERT OR REPLACE INTO {} VALUES (?,?,?,?,?,?,?)",
    }
    try:
        if mode in statements:
            cursor.executemany(statements[mode].format(ohlcv_table), to_insert)
        conn.commit()
        return True
    except sqlite3.Error as e:
        print(e)
        print('Function insert_data_ohlcv() returned False')
        return False
def insert_data_rsi(conn, rsi_table, to_insert, mode):
    """
    :param conn: database connection
    :param rsi_table: indicator table to operate on
    :param to_insert: array of 6-value rsi tuples
    :param mode: "ignore" skips rows whose key already exists,
                 "replace" overwrites them
    :return: True on success, False on sqlite error
    """
    cursor = conn.cursor()
    statements = {
        "ignore": "INSERT OR IGNORE INTO {} VALUES (?,?,?,?,?,?)",
        "replace": "INSERT OR REPLACE INTO {} VALUES (?,?,?,?,?,?)",
    }
    try:
        if mode in statements:
            cursor.executemany(statements[mode].format(rsi_table), to_insert)
        conn.commit()
        return True
    except sqlite3.Error as e:
        print(e)
        print('Function insert_data_rsi() returned False')
        return False
def get_all_rows(conn, table):
    """
    :param conn: database connection
    :param table: table name for which the data is returned
    :return: every row of the table as a list of tuples
    """
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM {}".format(table))
    return cursor.fetchall()
def get_all_rows_from_timestamp(conn, table, timestamp):
    """Return all rows of *table* whose timestamp column is >= *timestamp*."""
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM {} WHERE timestamp >= {}".format(table, timestamp))
    return cursor.fetchall()
def get_rows_from_timestamp_to_timestamp(conn, table, from_timestamp, to_timestamp):
    """Return all rows of *table* with from_timestamp <= timestamp <= to_timestamp."""
    cursor = conn.cursor()
    cursor.execute(
        "SELECT * FROM {} WHERE timestamp >= {} AND timestamp <= {}".format(table, from_timestamp, to_timestamp))
    return cursor.fetchall()
def get_n_rows_by_from_timestamp(conn, table, timestamp, n):
    """
    :param conn: database connection
    :param table: table name
    :param timestamp: lower bound (inclusive) on the timestamp column
    :param n: maximum number of rows to be returned
    :return: up to n rows with timestamp >= *timestamp*, ordered newest first
    """
    # NOTE(review): because of ORDER BY timestamp DESC this returns the
    # *latest* n matching rows in descending order — not "n consecutive rows,
    # starting at timestamp" as the original docstring claimed. Confirm which
    # behavior callers expect before changing the query.
    cursor = conn.cursor()
    cursor.execute(
        "select * from {} where timestamp>={} order by timestamp desc limit {}".format(table, timestamp, n))
    data = cursor.fetchall()
    return data
def get_earliest_timestamp_as_int(conn: Connection, table: str) -> Optional[int]:
    """Return the smallest timestamp in *table*, or None if the table is empty."""
    row = conn.cursor().execute("SELECT MIN(timestamp) FROM {}".format(table)).fetchall()[0]
    return row[0]
# for debugging
def get_earliest_timestamp_as_str(conn: Connection, table: str) -> Optional[str]:
    """Readable variant of get_earliest_timestamp_as_int, converted via
    utils.msec_to_localtime; returns None for an empty table."""
    earliest = get_earliest_timestamp_as_int(conn, table)
    if earliest is None:
        return None
    return utils.msec_to_localtime(earliest)
def get_latest_timestamp_as_int(conn: Connection, table: str) -> Optional[int]:
    """Return the largest timestamp in *table*, or None if the table is empty.

    :param conn: open sqlite3 database connection
    :param table: table name (trusted identifier)
    :return: maximum of the timestamp column as INT milliseconds, or None
    """
    # MAX(...) aggregates to a single row; its value is NULL for an empty table.
    row = conn.cursor().execute("SELECT MAX(timestamp) FROM {}".format(table)).fetchone()
    return row[0]
# for debugging | |
def get_latest_timestamp_as_str(conn: Connection, table: str) -> Optional[str]:
    """Human-readable variant of get_latest_timestamp_as_int (debugging aid).

    :param conn: open sqlite3 database connection
    :param table: table name (trusted identifier)
    :return: latest timestamp converted to a local-time string via
        utils.msec_to_localtime, or None when the table is empty
    """
    latest = get_latest_timestamp_as_int(conn, table)
    if latest is None:
        return None
    return utils.msec_to_localtime(latest)
def table_is_empty(conn, table):
    """Return True if *table* contains no rows, False otherwise.

    :param conn: open sqlite3 database connection
    :param table: table name (trusted identifier)
    """
    cursor = conn.cursor()
    # EXISTS yields 1 as soon as a single row is found, 0 otherwise --
    # cheaper than counting every row.
    cursor.execute("SELECT exists(select 1 from {})".format(table))
    return cursor.fetchone()[0] == 0
def has_row(conn, str_table, timestamp):
    """Return True if a row with the given timestamp exists in *str_table*.

    :param conn: open sqlite3 database connection
    :param str_table: table name (trusted identifier; cannot be parameterized)
    :param timestamp: primary key, INT milliseconds
    :return: True if the row exists, False otherwise
    """
    cursor = conn.cursor()
    # Bind the timestamp as a query parameter instead of string-formatting
    # it into the SQL text (avoids injection and quoting issues).
    cursor.execute(
        'SELECT EXISTS(SELECT 1 FROM {} WHERE timestamp=?)'.format(str_table),
        (timestamp,))
    return cursor.fetchone()[0] >= 1
def get_row_by_timestamp(conn, str_table, timestamp):
    """
    :param conn: Database connection
    :param str_table: Name of table (trusted identifier; cannot be parameterized)
    :param timestamp: Primary key INT milliseconds
    :return: Row with given timestamp if it exists, else None.
    """
    cursor = conn.cursor()
    # Bind the timestamp as a query parameter instead of string-formatting
    # it into the SQL text (avoids injection and quoting issues).
    cursor.execute('SELECT * FROM {} WHERE timestamp=?'.format(str_table),
                   (timestamp,))
    data = cursor.fetchall()
    # Empty result list means no such row; signal that with None.
    return data if data else None
def get_pair(conn):
    """Derive the trading-pair string from the database's file name.

    The database file is expected to be named like ``BTC_USDT.db``; the
    underscore in the stem is translated back into a slash.

    :param conn: open sqlite3 database connection
    :return: pair string associated with the database, e.g. 'BTC/USDT'
    """
    cursor = conn.cursor()
    # database_list rows are (seq, name, file); [0] is the main database.
    cursor.execute('PRAGMA database_list')
    db_file = cursor.fetchall()[0][2]
    return Path(db_file).stem.replace('_', '/')
def get_exchange(conn):
    """Derive the exchange name from the database file's parent directory.

    Databases are expected to live under a directory named after their
    exchange, e.g. ``binance/BTC_USDT.db``.

    :param conn: open sqlite3 database connection
    :return: exchange string associated with the database, e.g. 'binance'
    """
    cursor = conn.cursor()
    # database_list rows are (seq, name, file); [0] is the main database.
    cursor.execute('PRAGMA database_list')
    db_file = cursor.fetchall()[0][2]
    return Path(db_file).parent.name
def is_coherent_ohlcv_data(conn, ohlcv_table, db_path=None):
    """Check that the OHLCV table has no missing candles.

    Walks from the earliest to the latest stored timestamp in steps of one
    candle interval and verifies that every expected row is present. The
    number of missing entries is printed for diagnostics.

    :param conn: database connection (may be None when db_path is given)
    :param ohlcv_table: OHLCV table name, e.g. 'ohlcv_1m'
    :param db_path: Can be passed when calling from another Thread as sqlite3
        does not support passing connection objects across Threads; if given,
        a fresh connection is opened here and closed before returning.
    :return: True, if no candle is missing, False otherwise.
    """
    assert ohlcv_table is not None
    if conn is None:
        assert db_path is not None
    owns_connection = db_path is not None
    if owns_connection:
        conn = create_connection(db_path)
    # Candle interval in msec, derived from the timeframe suffix of the
    # table name ('ohlcv_<timeframe>').
    step_msec = utils.timeframe_to_int(ohlcv_table.split('_')[1]) * 60000  # minute in msec
    missing = 0
    ts = get_earliest_timestamp_as_int(conn, ohlcv_table)
    end_ts = get_latest_timestamp_as_int(conn, ohlcv_table)
    # NOTE(review): assumes the table is non-empty -- MIN/MAX return None
    # otherwise and the comparison below would raise; confirm callers
    # guarantee this.
    while ts < end_ts:
        if not has_row(conn, ohlcv_table, ts):
            missing += 1
        ts += step_msec
    print("Missing entries: {}".format(missing))
    if owns_connection:
        conn.close()
    return missing == 0
def clear_contents(conn, ohlcv_table):
    """Delete every row from *ohlcv_table* and commit.

    Database errors are printed rather than raised, matching the module's
    best-effort error handling style.

    :param conn: open sqlite3 database connection
    :param ohlcv_table: table name (trusted identifier)
    """
    try:
        conn.cursor().execute('DELETE FROM {}'.format(ohlcv_table))
        conn.commit()
    except sqlite3.Error as err:
        print(err)
# #################################################### # | |
# ###################### QUERIES ##################### # | |
# #################################################### # | |
def _sql_ohlcv_table_from_timeframe(timeframe): | |
sql = 'CREATE TABLE ohlcv_{}(timestamp INT NOT NULL PRIMARY KEY, date TEXT, ''openp REAL, ''highp REAL, ''lowp REAL, ''closep REAL, ''volume REAL)'.format( | |
timeframe) | |
return sql | |
def _sql_rsi_table_from_timeframe(timeframe): | |
""" | |
:param timeframe: timeframe of the candles used for calculation | |
:return: sql query for creation of table corresponding to params. | |
Call with cursor.execute(sql_query) | |
""" | |
sql_query = 'CREATE TABLE RSI_{}(timestamp INT NOT NULL PRIMARY KEY, date TEXT, avg_up REAL, avg_down REAL, rs REAL, rsi REAL)'.format( | |
timeframe) | |
return sql_query | |
def _sql_max_timestamp(table): | |
sql = "SELECT MAX(timestamp) FROM {}".format(table) | |
return sql | |
def _sql_min_timestamp(table): | |
sql = "SELECT MAX(timestamp) FROM {}".format(table) | |
return sql | |
def _sql_drop_table(table): | |
sql = "DROP TABLE {}".format(table) | |
return sql | |
def _sql_insert_ohlcv(table): | |
sql = 'INSERT INTO {} VALUES (?,?,?,?,?,?,?)'.format(table) | |
return sql | |
def _sql_exists_entry(table): | |
sql = "SELECT exists(select 1 from {})".format(table) | |
return sql |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment