Last active
April 26, 2023 15:14
-
-
Save hellovertex/6974fef5c248d30fa5c1161a309cfd86 to your computer and use it in GitHub Desktop.
Faust App that redirects websocket traffic via Kafka broker + fin indicators
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
# Acronyms of all known indicators. They are lowercased for the matching
# get_<acronym>() calls in order to conform with PEP 8 naming.
# NOTE: every entry needs its trailing comma; adjacent string literals are
# silently concatenated by Python otherwise.
indicators = [
    'ABANDS',  # Acceleration Bands
    'AD',  # Accumulation/Distribution
    'ADX',  # Average Directional Movement
    'AMA',  # Adaptive Moving Average
    'APO',  # Absolute Price Oscillator
    'AR',  # Aroon
    'ARO',  # Aroon Oscillator
    'ATR',  # Average True Range
    'AVOL',  # Volume on the Ask
    'BAVOL',  # Volume on the Bid and Ask
    'BBANDS',  # Bollinger Bands
    'BVA',  # Bar Value Area
    'BVOL',  # Bid Volume
    'BW',  # Band Width
    'CCI',  # Commodity Channel Index
    'CMO',  # Chande Momentum Oscillator
    'DEMA',  # Double Exponential Moving Average
    'DMI',  # Directional Movement Indicators
    'EMA',  # Exponential Moving Average
    'FILL',  # Fill Indicator
    'ICH',  # Ichimoku
    'KC',  # Keltner Channel
    'LR',  # Linear Regression
    'LRA',  # Linear Regression Angle
    'LRI',  # Linear Regression Intercept
    'LRM',  # Linear Regression Slope
    'MACD',  # Moving Average Convergence Divergence
    'MAX',  # Max
    'MFI',  # Money Flow Index
    'MIDPNT',  # Midpoint
    'MIDPRI',  # Midprice
    'MIN',  # Min
    'MINMAX',  # MinMax
    'MOM',  # Momentum
    'NATR',  # Normalized Average True Range
    'OBV',  # On Balance Volume
    'PC',  # Price Channel
    'PLT',  # PLOT
    'PPO',  # Percent Price Oscillator
    'PVT',  # Price Volume Trend
    'ROC',  # Rate of Change
    'ROC100',  # Rate of Change 100
    'ROCP',  # Rate of Change Percentage
    'ROCR',  # Rate of Change Rate
    'RSI',  # Relative Strength Index
    'S_VOL',  # Session Volume
    'SAR',  # Parabolic Stop and Reverse
    'SAVOL',  # Session Cumulative Ask
    'SBVOL',  # Session Cumulative Bid
    'SMA',  # Simple Moving Average
    'STDDEV',  # Standard Deviation
    'STOCH',  # Stochastic
    'StochF',  # Stochastic Fast
    'T3',  # T3
    'TEMA',  # Triple Exponential Moving Average
    'TRIMA',  # Triangular Moving Average
    'TRIX',  # Triple Exponential Moving Average Oscillator
    'TSF',  # Time Series Forecast
    'TTCVD',  # TT Cumulative Vol Delta
    'ULTOSC',  # Ultimate Oscillator
    'VAP',  # Volume at Price
    'VOLUME',  # Volume
    'VolDel',  # Volume Delta
    'VWAP',  # Volume Weighted Average Price
    'WillR',  # Williams % R
    'WMA',  # Weighted Moving Average
    'WWS',  # Welles Wilders Smoothing Average
]
# Default keyword arguments for the indicators implemented below; each dict
# can be splatted into the matching get_<indicator>() function.
params_macd = dict(fast_length=12, slow_length=26, smoothing=9)
params_rsi = dict(length=14, source='closep')
params_sma = dict(length=9, source='closep')
params_stoch = dict(k=14, d=3, smooth=3)
params_willr = dict(length=14)
params_bbands = dict(length=20, source='closep', stddev=2.0, offset=0)
def get_rsi(ohlcv_data: pd.DataFrame, length: int = 14, source: str = 'closep') -> pd.DataFrame:
    """
    Computes Relative Strength Indices, c.f. https://www.tradingview.com/wiki/Relative_Strength_Index_(RSI).

    Requires a DataFrame with a column equal to {source} and columns 'timestamp' and 'date'. Rows are
    accessed by position, so any (non-MultiIndex) row index works. Assumes the dataframe to be sorted by
    timestamp, with the last row corresponding to the latest entry.

    The first {length} price changes seed a simple average of gains and losses; every later row updates
    these averages with Wilder's smoothing ((prev * (length - 1) + current) / length) and yields one RSI value.
    Consequently the first {length} + 1 rows of the input have no RSI and are not part of the result.
    When the average loss is exactly zero the RSI is defined as 100.

    :param ohlcv_data: pandas Dataframe that needs to have keys 'timestamp', 'date' and {source}.
    :param length: window size for RSI calculation, must be >= 1
    :param source: Determines which price is used for computation. Default is close price.
    :return: Dataframe with Index(['timestamp', 'date', f'RSI({length}, {source})']
    :raises ValueError: if length is smaller than 1
    """
    if length < 1:
        raise ValueError(f'length must be a positive integer, got {length}')

    # positional access: independent of the labels of the row index
    prices = [float(price) for price in ohlcv_data[source].to_numpy()]
    timestamps = ohlcv_data['timestamp'].iloc[length + 1:]
    dts = ohlcv_data['date'].iloc[length + 1:]

    # split every price change into its gain and loss component
    gains = []
    losses = []
    for prev_price, curr_price in zip(prices, prices[1:]):
        d = curr_price - prev_price
        gains.append(d if d >= 0 else 0)
        losses.append(-d if d < 0 else 0)

    rsis = []
    if len(gains) >= length:
        # seed with the simple average of the first {length} changes
        avg_up = float(sum(gains[:length]) / length)
        avg_down = float(sum(losses[:length]) / length)
        for gain, loss in zip(gains[length:], losses[length:]):
            # then smooth using the previous averages
            avg_up = (avg_up * (length - 1) + gain) / length
            avg_down = (avg_down * (length - 1) + loss) / length
            if avg_down == 0:
                # no losses in the window: RSI is 100 regardless of the price scale
                rsis.append(100.0)
            else:
                r = avg_up / avg_down
                rsis.append(100 - (100 / (r + 1)))

    return pd.DataFrame(list(zip(timestamps, dts, rsis)),
                        columns=['timestamp', 'date', f'RSI({length}, {source})'])
def get_sma(ohlcv_data: pd.DataFrame, length: int = 9, source: str = 'closep', offset: int = 0):
    """
    Computes Simple Moving Averages:
    c.f. https://www.tradingview.com/wiki/Moving_Average

    Requires a DataFrame with a column equal to {source} and columns 'timestamp' and 'date'.
    Assumes the dataframe to be sorted by timestamp, with the last row corresponding to the latest entry.

    Side effect: the result column f'SMA({length}, {source})' is written into {ohlcv_data} itself.

    :param ohlcv_data: pandas Dataframe that needs to have keys 'timestamp', 'date' and {source}.
    :param length: time window for moving average. Default is 9
    :param source: Determines which price is used for computation. Default is close price.
    :param offset: accepted for interface compatibility; currently not used by the computation.
    :return: Dataframe with Index(['timestamp', 'date', f'SMA({length}, {source})'], rows from position
             {length} onwards
    """
    sma_column = f'SMA({length}, {source})'
    # rolling mean over the price column; stored on the caller's frame
    ohlcv_data[sma_column] = ohlcv_data[source].rolling(window=length).mean()
    wanted = ['timestamp', 'date', sma_column]
    trimmed = ohlcv_data[length:]
    return trimmed[wanted]
def get_willr(ohlcv_data: pd.DataFrame, length: int = 14, high='highp', low='lowp', close='closep') -> pd.DataFrame:
    """
    Computes the Williams %R indicator:
    c.f. https://www.tradingview.com/wiki/Williams_%25R_(%25R)#DEFINITION

    %R = (highest high - close) / (highest high - lowest low) * -100, where highest high and lowest low
    are taken over a rolling window of {length} rows.

    Requires a DataFrame with the columns {high}, {low}, {close}, 'timestamp' and 'date'.
    Assumes the dataframe to be sorted by timestamp, with the last row corresponding to the latest entry.

    Side effect: the result column f'WILLIAMS%R({length})' is written into {ohlcv_data} itself.

    :param ohlcv_data: pandas Dataframe that needs to have keys 'timestamp' and 'date'
    :param length: look-back period for highest highs and lowest lows
    :param high: Name of the column containing high price
    :param low: Name of the column containing low price
    :param close: Name of the column containing close price
    :return: Dataframe with Index(['date', 'timestamp', f'WILLIAMS%R({length})'], rows from position
             {length} onwards
    """
    willr_column = f'WILLIAMS%R({length})'
    window = ohlcv_data.rolling(window=length)
    highest_high = window[high].max()
    lowest_low = window[low].min()
    distance_to_high = highest_high - ohlcv_data[close]
    price_range = highest_high - lowest_low
    ohlcv_data[willr_column] = (distance_to_high / price_range) * (-100)
    return ohlcv_data[length:][['date', 'timestamp', willr_column]]
def get_ema(ohlcv_data: pd.DataFrame, length: int = 9, source: str = 'closep', offset: int = 0) -> pd.DataFrame:
    """
    Computes Exponential Moving Averages:
    c.f. https://www.tradingview.com/wiki/Moving_Average#Exponential_Moving_Average_.28EMA.29

    Requires a DataFrame with a column equal to {source} and columns 'timestamp' and 'date'. Rows are
    accessed by position, so any (non-MultiIndex) row index works. Assumes the dataframe to be sorted by
    timestamp, with the last row corresponding to the latest entry.

    The first EMA value is the simple mean of the first {length} prices; each following value is
    (price - previous_ema) * 2 / (length + 1) + previous_ema.

    :param ohlcv_data: pandas Dataframe that needs to have keys 'timestamp', 'date' and {source}.
    :param length: time window for moving average
    :param source: Determines which price is used for computation. Default is close price.
    :param offset: shifts the timestamps/dates that are paired with the EMA values by this many rows.
                   0 is the default.
    :return: Dataframe with Index(['timestamp', 'date', f'EMA({length})']
    """
    # 0. timestamps/dates that the EMA values are paired with (positional slice)
    first_row = length - 1 + offset
    timestamps = ohlcv_data['timestamp'].iloc[first_row:]
    datetimes = ohlcv_data['date'].iloc[first_row:]
    prices = ohlcv_data[source]

    # 1. Seed with the SMA of the first {length} prices
    ema = [prices.iloc[:length].mean()]

    # 2. Calculate the Multiplier
    mul = 2.0 / (length + 1)

    # 3. Calculate the EMA for every remaining price, by position rather than by index label
    for curr_p in prices.to_numpy()[length:]:
        prev_ema = ema[-1]
        ema.append((curr_p - prev_ema) * mul + prev_ema)

    # zip() stops at the shortest input, which keeps the three columns aligned
    return pd.DataFrame(list(zip(timestamps, datetimes, ema)), columns=['timestamp', 'date', f'EMA({length})'])
def get_macd(ohlcv_data: pd.DataFrame, fast_length: int = 12, slow_length=26, source: str = 'closep',
             smoothing: int = 9) -> pd.DataFrame:
    """
    Computes Moving Average Convergence Divergence.
    C.f. https://www.tradingview.com/wiki/MACD_(Moving_Average_Convergence/Divergence)

    MACD line = fast EMA - slow EMA, signal line = EMA of the MACD line, histogram = MACD - signal.
    The intermediate frames are combined with pd.merge, i.e. joined on their common columns.

    Requires a DataFrame with a column equal to {source} and columns 'timestamp' and 'date'.
    Assumes the dataframe to be sorted by timestamp, with the last row corresponding to the latest entry.

    :param ohlcv_data: pandas Dataframe that needs to have keys 'timestamp', 'date' and {source}.
    :param fast_length: time window for fast moving average
    :param slow_length: time window for slow moving average
    :param source: Determines which price is used for computation. Default is close price.
    :param smoothing: The time period for the EMA of the MACD Line (Signal Line). 9 is the default.
    :return: Dataframe with
        Index(
        [
            'timestamp',
            'date',
            f'MACD({fast_length}, {slow_length})',
            f'Signal({smoothing})',
            'MACD_Histogram'
        ]
    """
    # column names, see return interface of get_ema(...)
    fast_col = f'EMA({fast_length})'
    slow_col = f'EMA({slow_length})'
    macd_col = f'MACD({fast_length}, {slow_length})'
    signal_col = f'Signal({smoothing})'

    # moving averages of the price, joined back onto the input rows
    fast_ema = get_ema(ohlcv_data, length=fast_length, source=source)
    slow_ema = get_ema(ohlcv_data, length=slow_length, source=source)
    merged = pd.merge(pd.merge(ohlcv_data, fast_ema), slow_ema)

    # MACD line
    merged[macd_col] = merged[fast_col] - merged[slow_col]

    # signal line: EMA over the MACD line, renamed to its public column name
    signal = get_ema(merged, length=smoothing, source=macd_col).rename(
        index=str, columns={f'EMA({smoothing})': signal_col})
    result = pd.merge(merged, signal)

    # histogram
    result['MACD_Histogram'] = result[macd_col] - result[signal_col]

    # drop the helper columns
    return result[['timestamp', 'date', macd_col, signal_col, 'MACD_Histogram']]
def get_stoch(ohlcv_data: pd.DataFrame, k: int = 14, d: int = 3, smooth: int = 3, high='highp', low='lowp',
              close='closep') -> pd.DataFrame:
    """
    Computes Stochastic Oscillator Indicator:
    c.f. https://www.tradingview.com/wiki/Stochastic_(STOCH)#INPUTS

    raw %K = (close - lowest low) / (highest high - lowest low) * 100 over a rolling window of {k} rows,
    STOCH_K = rolling mean of raw %K over {smooth} rows, STOCH_D = rolling mean of STOCH_K over {d} rows.

    Requires a DataFrame with the columns {high}, {low}, {close}, 'timestamp' and 'date'.
    Assumes the dataframe to be sorted by timestamp, with the last row corresponding to the latest entry.

    Side effect: the columns '%K', f'STOCH_K={k}' and f'STOCH_D={d}' are written into {ohlcv_data} itself.

    :param ohlcv_data: pandas Dataframe that needs to have keys 'timestamp' and 'date'
    :param k: look-back period for highest highs and lowest lows
    :param d: window of the moving average applied to the smoothed %K line
    :param smooth: window of the moving average that smooths the raw %K line
    :param high: Name of the column containing high price
    :param low: Name of the column containing low price
    :param close: Name of the column containing close price
    :return: Dataframe with Index(['timestamp', 'date', f'STOCH_K={k}', f'STOCH_D={d}'], rows from position
             {k} + {d} onwards
    """
    k_column = f'STOCH_K={k}'
    d_column = f'STOCH_D={d}'

    look_back = ohlcv_data.rolling(window=k)
    highest_high = look_back[high].max()
    lowest_low = look_back[low].min()

    # fast stochastics
    above_low = ohlcv_data[close] - lowest_low
    price_range = highest_high - lowest_low
    ohlcv_data['%K'] = (above_low / price_range) * 100
    ohlcv_data[k_column] = ohlcv_data['%K'].rolling(window=smooth).mean()

    # slow stochastics
    ohlcv_data[d_column] = ohlcv_data[k_column].rolling(window=d).mean()

    return ohlcv_data[k + d:][['timestamp', 'date', k_column, d_column]]
def get_bbands(ohlcv_data: pd.DataFrame, length: int = 20, source: str = 'closep', stddev: float = 2.0,
               offset: int = 0) -> pd.DataFrame:
    """
    Computes Bollinger Bands:
    c.f. https://www.tradingview.com/wiki/Bollinger_Bands_(BB)#INPUTS

    Middle band = rolling mean of {source} over {length} rows, upper/lower band = middle band +/-
    {stddev} times the rolling standard deviation over the same window.

    Requires a DataFrame with a column equal to {source} and columns 'timestamp' and 'date'.
    Assumes the dataframe to be sorted by timestamp, with the last row corresponding to the latest entry.

    Side effect: the three band columns and a helper column 'std_dev' are written into {ohlcv_data} itself.

    :param ohlcv_data: pandas Dataframe that needs to have keys 'timestamp', 'date' and {source}.
    :param length: time window for moving average and standard deviation
    :param source: Determines which price is used for computation. Default is close price.
    :param stddev: distance to the middle band in standard deviations. Default is 2.
    :param offset: number of additional leading rows to drop from the result. 0 is the default.
    :return: Dataframe with Index(['timestamp', 'date', middleband, upperband, lowerband]), rows from
             position {length} + {offset} onwards
    """
    suffix = f'({length}, {source}, {stddev})'
    middle = f'MIDDLE_BBAND{suffix}'
    upper = f'UPPER_BBAND{suffix}'
    lower = f'LOWER_BBAND{suffix}'

    # the window follows {length}; it must not be fixed to the default of 20
    window = ohlcv_data[source].rolling(window=length)
    # mean in look back period
    ohlcv_data[middle] = window.mean()
    # scaled standard deviation in look back period
    ohlcv_data['std_dev'] = stddev * window.std()
    # compute bollinger bands
    ohlcv_data[upper] = ohlcv_data[middle] + ohlcv_data['std_dev']
    ohlcv_data[lower] = ohlcv_data[middle] - ohlcv_data['std_dev']

    cols = ['timestamp', 'date', middle, upper, lower]
    return ohlcv_data[length + offset:][cols]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import string | |
import logging | |
import argparse | |
import time | |
from collections import Counter | |
import sys | |
from telegram.client import Telegram | |
import csv | |
from datetime import datetime | |
import numpy as np | |
import os | |
def timestamp_to_date(ts: int) -> str:
    """Format a Unix timestamp (seconds) as a 'YYYY-MM-DD HH:MM:SS' string in UTC."""
    # local import: the module only imports `datetime` from the datetime package
    from datetime import timezone
    # timezone-aware replacement for the deprecated datetime.utcfromtimestamp()
    return datetime.fromtimestamp(ts, tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
def setup_logging(level=logging.INFO):
    """Set the root logger to *level* and attach a stdout handler to it.

    The handler itself accepts everything from DEBUG upwards, so the effective
    verbosity is controlled by the root logger's level. Each call adds one
    more handler to the root logger.
    """
    console = logging.StreamHandler(sys.stdout)
    console.setLevel(logging.DEBUG)
    console.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s'))

    root_logger = logging.getLogger()
    root_logger.setLevel(level)
    root_logger.addHandler(console)
if __name__ == '__main__':
    # Script entry: configure logging, read the Telegram credentials from the
    # command line and open a logged-in client that the functions below use.
    setup_logging(level=logging.INFO)

    cli = argparse.ArgumentParser()
    cli.add_argument('api_id', help='API id', type=int)  # https://my.telegram.org/apps
    cli.add_argument('api_hash', help='API hash')
    cli.add_argument('phone', help='Phone')
    cli.add_argument('--lib_path', help='path to libtdjson.so.1.6.0', type=str,
                     default='/home/hellovertex/Documents/github.com/tdlib/td/tdlib/lib/libtdjson.so.1.6.0')
    args = cli.parse_args()

    # echo the parsed credentials together with their types
    for field in ('api_id', 'api_hash', 'phone'):
        value = getattr(args, field)
        print(f'{field} = {value}, type={type(value)}')

    tg = Telegram(
        api_id=args.api_id,
        api_hash=args.api_hash,
        phone=args.phone,
        library_path=args.lib_path,
        database_encryption_key='hanabi',
    )
    # login must be called before any other client method
    tg.login()
def get_all_chat_ids():
    """Page through the chat list of the module-level client `tg` and return all chat ids.

    Chats are requested in pages of up to 100, starting from the maximum
    possible order value. After each page the order and id of the last known
    chat become the offset of the next request. Paging stops as soon as a
    request yields no previously unseen chat id.

    :return: list of unique chat ids in the order they were received; empty if there are no chats
    """
    offset_order = 9223372036854775807
    offset_chat_id = 0
    chat_ids = list()
    old_len = -1
    while True:
        # get chat_ids from current offset_order, offset_chat_id
        res = tg.get_chats(offset_order=offset_order, offset_chat_id=offset_chat_id, limit=100)
        res.wait()
        # append the new page and remove duplicates, keeping first-seen order
        chat_ids = list(dict.fromkeys(chat_ids + res.update['chat_ids']))
        new_len = len(chat_ids)
        print(f'old_len = {old_len}, new_len = {new_len}')
        # stop when nothing new arrived; also covers an account without any
        # chats, where there is no last chat to page from
        if old_len == new_len or not chat_ids:
            print(f'Fetched {len(chat_ids)} chats.')
            break
        old_len = new_len
        # get id, order from last chat to position the next request
        res = tg.get_chat(chat_ids[-1])
        res.wait()
        last_chat = res.update
        offset_chat_id = last_chat['id']
        offset_order = last_chat['order']
    return chat_ids
def extract_supergroup_chats(chat_ids):
    """Fetch every chat in *chat_ids* via the module-level client `tg` and keep the supergroups
    that are not channels.

    Keys observed on a chat object:
    dict_keys(
        ['@type', 'id', 'type', 'chat_list', 'title', 'photo', 'permissions', 'last_message', 'order',
         'is_pinned', 'is_marked_as_unread', 'is_sponsored', 'has_scheduled_messages',
         'can_be_deleted_only_for_self', 'can_be_deleted_for_all_users', 'can_be_reported',
         'default_disable_notification', 'unread_count', 'last_read_inbox_message_id',
         'last_read_outbox_message_id', 'unread_mention_count', 'notification_settings',
         'pinned_message_id', 'reply_markup_message_id', 'client_data', '@extra'])

    :param chat_ids: iterable of chat ids
    :return: list of chat objects (dicts) of type 'chatTypeSupergroup' with 'is_channel' falsy
    """
    supergroups = []
    for chat_id in chat_ids:
        response = tg.get_chat(chat_id)
        response.wait()
        chat = response.update
        chat_type = chat['type']
        # 'is_channel' is only looked up for supergroup chats
        if chat_type['@type'] == 'chatTypeSupergroup' and not chat_type['is_channel']:
            supergroups.append(chat)
    return supergroups
def get_chat_history(chat_obj):
    """Download the message history of a supergroup chat via the module-level client `tg`.

    The history is only fetched when the supergroup reports that its full
    history is available, has at least 500 members and is not a channel;
    otherwise the reason is printed and None is returned.

    Keys observed on the supergroup full info object:
    dict_keys(
        ['@type', 'description', 'member_count', 'administrator_count', 'restricted_count', 'banned_count',
         'linked_chat_id', 'slow_mode_delay', 'slow_mode_delay_expires_in', 'can_get_members',
         'can_set_username', 'can_set_sticker_set', 'can_set_location', 'can_view_statistics',
         'is_all_history_available', 'sticker_set_id', 'invite_link', 'upgraded_from_basic_group_id',
         'upgraded_from_max_message_id', '@extra'])

    :param chat_obj: chat dict with the keys 'id' and 'type' ('supergroup_id', 'is_channel')
    :return: list of message dicts, or None if the chat was skipped
    """

    def _get_history(chat_id):
        """Fetch messages in batches of up to 100 until an empty batch is returned."""
        from_message_id = 0
        messages = list()
        old_ids = []
        while True:
            res = tg.get_chat_history(chat_id, limit=100, from_message_id=from_message_id, offset=0)
            res.wait()
            msgs = res.update['messages']
            if not msgs:
                break
            first = msgs[0]
            last = msgs[-1]
            # the next request continues from the last message of this batch
            from_message_id = last['id']
            print(f'number of messages in buffer = {len(msgs)}')
            print(f'date 1st: {timestamp_to_date(first["date"])} -- date last = {timestamp_to_date(last["date"])}')
            # sanity check: no position may hold the same message id as in the
            # previous batch. zip() copes with batches of different lengths,
            # which an element-wise array comparison does not.
            ids = [msg['id'] for msg in msgs]
            have_duplicates = any(new == old for new, old in zip(ids, old_ids))
            assert not have_duplicates, 'received a batch that repeats message ids of the previous batch'
            old_ids = ids
            # todo consider writing directly to file instead
            messages += msgs
            print(f'total number of messages fetched = {len(messages)}')
            # short pause between two requests
            time.sleep(.5)
        return messages

    # get supergroup id from chat_obj
    supergroup_id = chat_obj['type']['supergroup_id']
    is_channel = chat_obj['type']['is_channel']
    # get supergroup info by supergroup id
    res = tg.get_supergroup_full_info(supergroup_id)
    res.wait()
    group = res.update
    hist = None
    if group['is_all_history_available'] and group['member_count'] >= 500 and not is_channel:
        # get chat_history
        hist = _get_history(chat_obj['id'])
    else:
        skip_reason = f'group member count is too low: {group["member_count"]}' if group[
            'is_all_history_available'] else f'history is not available or group is channel'
        print('...Skipping chat because ' + skip_reason)
    return hist
if __name__ == '__main__':
    # Dump the history of every eligible supergroup into '<chat title>.csv' in
    # the working directory. Chats whose file already exists are skipped.
    chat_ids = get_all_chat_ids()
    crypto_chats = extract_supergroup_chats(chat_ids)
    for chat in crypto_chats:
        csv_path = f'{chat["title"]}.csv'
        # If a csv file for the current history exists already, keep it
        if os.path.isfile(csv_path):
            print(f'Skipping group {chat["title"]} because .csv file exists already')
            continue
        print(f'getting history for chat: {chat["title"]}...')
        hist = get_chat_history(chat)
        if not hist:
            print(f'...History for chat {chat["title"]} is None')
            continue
        # make sure these two columns exist without overwriting real values
        hist[0].setdefault('forward_info', '')
        hist[0].setdefault('reply_markup', '')
        # header = union of the keys of all messages in first-seen order, so a
        # later message with an additional key cannot make DictWriter raise;
        # keys a message lacks are written as empty cells
        keys = list(dict.fromkeys(key for msg in hist for key in msg))
        # newline='' is required by the csv module; utf-8 keeps non-ASCII text intact
        with open(csv_path, 'w', newline='', encoding='utf-8') as output_file:
            print(f'Writing history of {chat["title"]} to .csv file...')
            dict_writer = csv.DictWriter(output_file, keys)
            dict_writer.writeheader()
            dict_writer.writerows(hist)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
1. Launch Kafka: | |
- $KAFKA_HOME/bin/zookeeper-server-start.sh $KAFKA_HOME/config/zookeeper.properties | |
- $KAFKA_HOME/bin/kafka-server-start.sh $KAFKA_HOME/config/server.properties | |
run via faust -A <filename> worker -l info | |
2. Faust library: | |
faust.App.agent: - main processing actor in Faust App | |
""" | |
import faust | |
from database import DatabaseHandler | |
from config import DB_PATH | |
# Open (or create) the SQLite database located at DB_PATH.
dbh = DatabaseHandler(DB_PATH)
# Make sure the 'trades' table is available before any tick arrives.
dbh.create_table_trades()

# todo: serialize incoming data and write to database

# Faust application 'client': reads raw (undecoded) values from the local Kafka broker.
app = faust.App(
    'client',
    broker='kafka://localhost:9092',
    value_serializer='raw',
    autodiscover=False,
    # further settings that are currently disabled:
    # topic_partitions=3,
    # broker_commit_every=100,
    # stream_buffer_maxsize=65536,
)

# Kafka topic the agent below subscribes to.
topic = app.topic('ticks')
@app.agent(topic)
async def echo(ticks: faust.Stream):
    """Agent on the 'ticks' topic: for each tick, run a database insert and print the tick.

    The rows handed to the database are still an empty placeholder list; the
    received tick itself is only printed.
    """
    async for tick in ticks:
        placeholder_rows = []
        dbh.insert_many('trades', placeholder_rows)
        print('echo from app 1 that wrote to db')
        print(f'also {tick} has been written')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sqlite3 | |
from typing import Optional | |
import sys, getopt, os | |
class DatabaseHandler:
    """Thin wrapper around one sqlite3 connection that stores trade events.

    sqlite3 errors raised by the table and insert operations are printed
    instead of propagated; `insert_many` additionally reports success as a bool.
    """

    # conflict strategies accepted by insert_many
    _INSERT_MODES = ('ignore', 'replace')

    def __init__(self, abs_path: str):
        """ Creates or uses existing database located in {abs_path} and stores its connection """
        try:
            self.conn = sqlite3.connect(str(abs_path))
        except sqlite3.Error as e:
            # database could not be created (or opened) -> abort
            print(e)
            sys.exit(1)

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return {name} as a double-quoted SQL identifier (each dot-separated part quoted separately).

        Identifiers cannot be bound as '?' parameters, so they are quoted with
        embedded double quotes doubled; this keeps a table name from being
        interpreted as additional SQL.
        """
        parts = str(name).split('.')
        return '.'.join('"{}"'.format(part.replace('"', '""')) for part in parts)

    def create_table_trades(self) -> None:
        """Create the table 'trades' unless it exists already; prints the error if creation fails."""
        sql = 'CREATE TABLE IF NOT EXISTS trades(event_type TEXT, ' \
              'event_time INT, ' \
              'symbol TEXT, ' \
              'trade_id INT NOT NULL PRIMARY KEY,' \
              'price REAL,' \
              'quantity REAL,' \
              'buyer_order_id INT,' \
              'seller_order_id INT,' \
              'is_market_maker BOOLEAN CHECK (is_market_maker IN (0, 1)),' \
              'ignore BOOLEAN CHECK (ignore IN (0, 1)))'
        try:
            self.conn.cursor().execute(sql)
            self.conn.commit()
        except sqlite3.Error as e:
            print('creation failed:')
            print(e)

    def remove_table(self, tablename: str) -> None:
        """Drop the table {tablename}; prints the error if it cannot be dropped."""
        sql = "DROP TABLE {}".format(self._quote_identifier(tablename))
        try:
            # a single statement: execute() suffices and cannot run trailing SQL
            self.conn.cursor().execute(sql)
            self.conn.commit()
        except sqlite3.Error as e:
            print('removing failed:')
            print(e)

    def insert_many(self, tablename, data, mode='replace') -> bool:
        """Insert the rows in {data} (sequences of 10 values each) into {tablename}.

        :param tablename: name of the target table
        :param data: iterable of rows, each holding one value per column of the trades table
        :param mode: 'replace' overwrites rows with a conflicting key, 'ignore' keeps the stored row
        :return: True if the rows were written and committed, False if sqlite3 reported an error
        :raises ValueError: if mode is neither 'ignore' nor 'replace'
        """
        if mode not in self._INSERT_MODES:
            # explicit check instead of assert, which is stripped under `python -O`
            raise ValueError("mode must be 'ignore' or 'replace', got {!r}".format(mode))
        sql = "INSERT OR {} INTO {} VALUES (?,?,?,?,?,?,?,?,?,?)".format(
            mode.upper(), self._quote_identifier(tablename))
        try:
            self.conn.cursor().executemany(sql, data)
            self.conn.commit()
            return True
        except sqlite3.Error as e:
            print('insertion failed:')
            print(e)
            return False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
1. Launch Kafka: | |
- $KAFKA_HOME/bin/zookeeper-server-start.sh $KAFKA_HOME/config/zookeeper.properties | |
- $KAFKA_HOME/bin/kafka-server-start.sh $KAFKA_HOME/config/server.properties | |
run via faust -A <filename> worker -l info | |
consumer1 run via faust -A <filename> worker -l info -p 6066 | |
consumer2 run via faust -A <filename> worker -l info -p 6067 | |
etc... | |
2. Faust library: | |
faust.App.agent: - main processing actor in Faust App | |
- unary async function - receives stream as its argument | |
faust.Stream: - async python generator | |
- abstractions over a kafka topic | |
- can apply operations on the stream (filter(), take(5)) | |
faust.Record: - data transfer object: Represents events via python classes inheriting it | |
- serialization, deserialization | |
From: https://faust.readthedocs.io/en/latest/userguide/settings.html#guide-settings | |
app configuration: | |
broker: - only supported production transport is kafka://, | |
- uses the aiokafka client under the hood, for consuming and producing messages | |
- can specify multiple hosts, e.g. broker='kafka://kafka1.example.com:9092;kafka2.example.com:9092' [fault tolerance] | |
store: - default is memory:// | |
- production should use rocksdb:// | |
processing_guarantee: “at_least_once” (default) and “exactly_once”. | |
Note that if exactly-once processing is enabled consumers are configured with isolation.level="read_committed" | |
and producers are configured with retries=Integer.MAX_VALUE and enable.idempotence=true per default. | |
Note that by default exactly-once processing requires a cluster of at least three brokers, which is the recommended setting for production.
For development you can change this, by adjusting broker setting transaction.state.log.replication.factor to the number of brokers you want to use. | |
autodiscover: set to false, see https://faust.readthedocs.io/en/latest/userguide/settings.html#autodiscover | |
""" | |
from kafka import KafkaProducer | |
from websockets.client import WebSocketClientProtocol | |
from websockets.exceptions import ConnectionClosed | |
from mode import Service | |
import websockets | |
import faust | |
from faust import App | |
import random | |
import numpy as np | |
import asyncio | |
import websocket | |
import time | |
import logging | |
# Binance trade stream for the BTC/USDT pair
sock_addr = "wss://stream.binance.com:9443/ws/btcusdt@trade"


class WebSocketClient():
    """Forwards every message received on a websocket to the Kafka topic 'ticks'."""

    def __init__(self, sock_addr, **kwargs):
        """Create the Kafka producer and remember the websocket address.

        :param sock_addr: websocket URL to connect to
        :param kwargs: accepted but not used
        """
        self.producer = KafkaProducer()
        self.sock_addr = sock_addr

    # need to implement as service when we want to gracefully shutdown
    async def connect(self):
        '''
        Opens the websocket and stores it in self.connection.
        Returns the connection (used to send and receive messages) if it
        reports itself as open, otherwise None.
        '''
        self.connection = await websockets.client.connect(self.sock_addr)
        if not self.connection.open:
            return None
        print('Connection stablished. Client correcly connected')
        return self.connection

    async def sendMessage(self, message):
        """Send {message} over the connection opened by connect()."""
        await self.connection.send(message)

    async def receiveMessage(self, connection: WebSocketClientProtocol):
        """Publish every message received on {connection} to Kafka until the connection closes."""
        while True:
            try:
                msg = await connection.recv()
                payload = str(msg).encode()
                self.producer.send(topic='ticks', value=payload)
            except websockets.exceptions.ConnectionClosed:
                print('Connection with server closed')
                # todo: handle 24h disconnects
                return
if __name__ == '__main__':
    async def _run_client():
        """Connect to the trade stream and forward messages until the connection closes."""
        client = WebSocketClient("wss://stream.binance.com:9443/ws/btcusdt@trade")
        # Start connection and get client connection protocol
        connection = await client.connect()
        if connection is None:
            # connect() returns nothing when the socket did not report itself as open
            return
        # Start listener; an exception raised inside it now reaches the caller
        await client.receiveMessage(connection)

    # asyncio.run() creates and closes the event loop itself; calling
    # asyncio.get_event_loop() without a running loop is deprecated
    asyncio.run(_run_client())
# producer = KafkaProducer() | |
# i = 0 | |
# while True: | |
# producer.send(topic='ticks', value=str(i).encode()) | |
# producer.send(topic='ticks_ml', value=str(i).encode()) | |
# print(f'sent both msgs at iteration {i}') | |
# i+=1 | |
# time.sleep(1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
from collections import namedtuple | |
import sqlite3 | |
import pandas as pd | |
import numpy as np | |
import enum | |
from src import database | |
from src.indicators import get_rsi as rsi, get_willr as willr, get_macd as macd, get_stoch as stoch | |
from typing import Dict, List, Tuple, Optional # , TypedDict | |
from src import env_wrapper | |
# Type aliases used in annotations
STR_SYMBOL = str
DF_MARKET_DATA = pd.DataFrame  # refers to only-OHLCV data
DF_X10DMARKET_DATA = pd.DataFrame  # refers to OHLCV-data + Indicators

# Immutable record of a single trade; fields are accessed by name or position.
Trade = namedtuple(
    'Trade',
    't_step datestring symbol symbol_idx type buy_amt sell_percent price net_value info',
)
class ActionType(enum.IntEnum):
    """Kind of action encoded by an integer action; members compare equal to their int value."""

    NOOP = 0  # do nothing
    BUY = 1
    SELL = 2
class ActionParser(object):
    """Decodes a flat integer action into (action type, instrument index, amount).

    Implicit encoding, shown for instruments [A, B], two buy amounts
    (5, 10) and three sell percentages (25, 50, 100):

        action  instrument  buy amount  sell percent
        0       A           5
        1       A           10
        2       B           5
        3       B           10
        4       A                       25
        5       A                       50
        6       A                       100
        7       B                       25
        8       B                       50
        9       B                       100
        10      NOOP

    Buy actions come first (grouped by instrument), sell actions follow, and
    the last index is the no-op.
    """
    # NOTE: flagged as "Needs rework" by the original author
    INSTRUMENT = Optional[int]
    AMOUNT = int

    def __init__(self, n_instruments_tradeable,
                 n_possible_buy_amts,
                 n_possible_sell_amts,
                 buy_amts,
                 sell_amts_percent):
        """
        :param n_instruments_tradeable: number of instruments that can be traded
        :param n_possible_buy_amts: number of distinct buy amounts per instrument
        :param n_possible_sell_amts: number of distinct sell percentages per instrument
        :param buy_amts: sequence of buy amounts, indexed by action % n_possible_buy_amts
        :param sell_amts_percent: sequence of sell percentages, indexed by the sell action's
                                  offset % n_possible_sell_amts
        """
        # inputs
        self.n_instruments_tradeable = n_instruments_tradeable
        self.n_possible_buy_amts = n_possible_buy_amts
        self.n_possible_sell_amts = n_possible_sell_amts
        self.buy_amts = buy_amts
        self.sell_amts_percents = sell_amts_percent
        # derived layout of the action space
        self.n_buy_actions = n_instruments_tradeable * n_possible_buy_amts
        self.n_sell_actions = n_instruments_tradeable * n_possible_sell_amts
        self.noop = self.n_buy_actions + self.n_sell_actions
        self._max_actions = self.noop + 1

    @property
    def max_actions(self):
        """Total number of actions: all buy actions, all sell actions and the no-op."""
        return self._max_actions

    def _get_action_type(self, action: int) -> ActionType:
        """Map an action index to BUY, SELL or NOOP; raises ValueError outside [0, noop]."""
        if action == self.noop:
            return ActionType.NOOP
        # one buy action per instrument and buy amount
        if 0 <= action < self.n_buy_actions:
            return ActionType.BUY
        # sell actions fill the range between the buy actions and the no-op
        if self.n_buy_actions <= action < self.noop:
            return ActionType.SELL
        raise ValueError

    def _get_target_instr(self, action: int, action_type: int) -> INSTRUMENT:
        """Return the instrument index addressed by a buy or sell action; ValueError otherwise."""
        if action_type == ActionType.BUY:
            return action // self.n_possible_buy_amts
        if action_type == ActionType.SELL:
            sell_offset = action - self.n_buy_actions
            return sell_offset // self.n_possible_sell_amts
        raise ValueError

    def _get_target_amount(self, action: int, a_type: int) -> AMOUNT:
        """Return the buy amount or sell percentage selected by the action; ValueError otherwise."""
        if a_type == ActionType.BUY:
            return self.buy_amts[action % self.n_possible_buy_amts]
        if a_type == ActionType.SELL:
            sell_offset = action - self.n_buy_actions
            return self.sell_amts_percents[sell_offset % self.n_possible_sell_amts]
        raise ValueError

    def parse_int_action(self, action: int) -> Tuple[ActionType, INSTRUMENT, AMOUNT]:
        """
        Decode an action index according to the encoding described on the class.

        :param action: integer in [0, noop]
        :return: (ActionType.NOOP, None, 0) for the no-op, otherwise
                 (action type, instrument index, buy amount or sell percentage)
        """
        assert 0 <= action <= self.noop, f'Expected action between 0 and {self.noop}. Got {action} instead.'
        a_type = self._get_action_type(action)
        if a_type == ActionType.NOOP:
            return ActionType.NOOP, None, 0
        instrument = self._get_target_instr(action, a_type)
        amount = self._get_target_amount(action, a_type)
        return a_type, instrument, amount
class CreditInfo(object):
    """Constants for the account's starting state and the layout of its metadata list."""

    # values the two metadata slots start with
    INITIAL_CREDIT = 1000
    INITIAL_TIMESTEP = 0

    # INDICES FOR METADATA
    REMAINING_CREDIT = -2
    CURRENT_TIMESTEP = -1

    DEFAULT_BUY_AMT = 50
class Instrument(object):
    """ Container for trade history """

    def __init__(self, str_value):
        self.str_value = str_value
        self._trade_history = []  # of Trade tuples
        # all three numeric fields start at zero
        self._balance = self._net_value = self._net_gain = 0

    @property
    def trade_history(self):
        """The list of recorded trades (read-only attribute, mutable list)."""
        return self._trade_history

    @property
    def balance(self):
        """Held balance of this instrument."""
        return self._balance

    @balance.setter
    def balance(self, value):
        self._balance = value

    @property
    def net_value(self):
        """Stored net value of the position."""
        return self._net_value

    @net_value.setter
    def net_value(self, value):
        self._net_value = value

    @property
    def net_gain(self):
        """Stored net gain of the position."""
        return self._net_gain

    @net_gain.setter
    def net_gain(self, value):
        self._net_gain = value
class Credit(object):
    """Account state: one Instrument position per tradeable slot plus cash and time step.

    ``metadata`` is the two-element list ``[remaining_credit, current_timestep]``;
    it is indexed with ``CreditInfo.REMAINING_CREDIT`` and ``CreditInfo.CURRENT_TIMESTEP``.
    """

    def __init__(self, symbols_per_state, taker_fee):
        """
        Args:
            symbols_per_state: number of instrument slots to track
            taker_fee: multiplicative factor applied to traded amounts,
                e.g. 998/1000 for a 0.2% fee
        """
        self.symbols_per_state = symbols_per_state
        # NOTE(review): these names do not line up with the layout of `data`
        # (which has 3 * symbols_per_state + 2 entries) -- confirm intended use.
        self.columns = [f'symbol_{i}' for i in range(symbols_per_state)] \
            + ['remaining_credit'] + ['current_timestep']
        # instruments are mapped from shuffled symbols for variance reduction
        self.instruments = [Instrument(f'symbol_{i}') for i in range(symbols_per_state)]
        self.metadata = [CreditInfo.INITIAL_CREDIT, CreditInfo.INITIAL_TIMESTEP]
        self.fee = taker_fee  # = 998/1000 = 0.2%

    def reset(self):
        """Discard all positions and restore the initial credit and time step."""
        self.instruments = [Instrument(f'symbol_{i}') for i in range(self.symbols_per_state)]
        self.metadata = [CreditInfo.INITIAL_CREDIT, CreditInfo.INITIAL_TIMESTEP]

    @property
    def data(self):
        """Flat list: all balances, then all net values, then all net gains, then metadata."""
        balances = [pair.balance for pair in self.instruments]
        net_values = [pair.net_value for pair in self.instruments]
        net_gains = [pair.net_gain for pair in self.instruments]
        return balances + net_values + net_gains + self.metadata

    def has_cash(self, amount):
        """Return True when the remaining credit covers `amount`."""
        return self.metadata[CreditInfo.REMAINING_CREDIT] >= amount

    def is_available(self, target_symbol: int):
        """Return True when the instrument at index `target_symbol` has a positive balance."""
        return self.instruments[target_symbol].balance > 0

    def apply_trade_and_compute_reward(self, trade: Trade) -> float:
        """ Applies `trade` to the account and computes the reward using a simple
        gain per balance metric.

        More specifically, the reward is returned as the difference of the
        gain per balance before and after the action taken, if the action was a
        SELL action and 0 otherwise. When the whole balance is sold, the reward
        is the instrument's net gain.

        Raises:
            ValueError: if `trade` is a SELL on an instrument without balance
                (raised before any state is modified), or if the trade type is
                neither BUY nor SELL.
        """
        assert trade.type in [ActionType.BUY, ActionType.SELL]
        target = self.instruments[trade.symbol_idx]
        # Validate before mutating anything: dividing by an empty balance would
        # either raise half-way through or silently produce inf/nan rewards.
        if trade.type == ActionType.SELL and target.balance <= 0:
            raise ValueError(f'Cannot SELL instrument {trade.symbol_idx}: '
                             f'balance is {target.balance}.')
        target.trade_history.append(trade)  # we dont make use of the history yet
        self.metadata[CreditInfo.CURRENT_TIMESTEP] = trade.t_step
        profit = 0
        if trade.type == ActionType.BUY:
            self.metadata[CreditInfo.REMAINING_CREDIT] -= trade.net_value
            # the fee is taken from the purchased amount, e.g. (50$ / 11000$) * fee
            target.balance += trade.buy_amt * self.fee
            target.net_gain -= trade.net_value
            target.net_value = target.balance * trade.price
        elif trade.type == ActionType.SELL:
            # for now, use simple gain per balance metric
            old_value = target.net_gain / target.balance
            sell_amt = trade.sell_percent * target.balance
            target.balance -= sell_amt
            gain = (trade.price * sell_amt) * self.fee  # proceeds after fee
            self.metadata[CreditInfo.REMAINING_CREDIT] += gain
            target.net_gain += gain
            target.net_value = target.balance * trade.price
            if trade.sell_percent != 1.:
                new_value = target.net_gain / target.balance
                # gain_per_balance_old - gain_per_balance_new [These are negative values]
                profit = old_value - new_value
            else:
                # when everything is sold, the gain_per_balance is equal to the net_gain
                profit = target.net_gain
        else:
            raise ValueError
        return profit
class DefaultPreprocessor(object):
    """Base preprocessor: adds indicator columns to OHLCV tables, min-max scales
    the volume column and drops rows containing NaNs.

    The resulting tables are kept in ``self.market_data``.
    """

    def __init__(self, market_data: Dict[STR_SYMBOL, DF_MARKET_DATA],
                 trajectory_length: int, resolution: str, symbols: List[str]):
        """
        Args:
            market_data: maps each symbol to its OHLCV table, e.g.
                {'BTC/USDT':      timestamp                 date     openp  ...      lowp    closep    volume
                    0    1514757600000  2017-12-31 23:00:00  13930.00  ...  13869.14  13898.98  130.5177
                    1    1514759400000  2017-12-31 23:30:00  13899.00  ...  13782.00  13820.50  155.7665
                    ..             ...                  ...       ...  ...       ...       ...       ...
                    499  1515661200000  2018-01-11 10:00:00  13656.00  ...  13554.18  13660.38  427.2661
                    [max_len rows x 7 columns],
                 ...
                 'ZRX/USDT':      timestamp                 date  openp  highp  lowp  closep      volume
                    0    1523937600000  2018-04-17 06:00:00   0.26   0.29  0.26    0.26  5478159.15
                    ..             ...                  ...    ...    ...   ...     ...         ...
                    499  1524835800000  2018-04-27 15:30:00   0.30   0.30  0.29    0.30  1919637.17
                    [max_len rows x 7 columns]
                }
                NOTE: this dict is modified in place.
            trajectory_length:
                number of observations with given resolution per trajectory
            resolution: one of
                # Intradays
                '1m','3m','5m','15m','30m','1h','2h','3h','4h','6h','8h','12h',
                # macroscale
                '1d','3d','1w','2w','1M'
            symbols: the symbols (keys of market_data) trajectories are sampled from
        """
        self.trajectory_length = trajectory_length
        self.resolution = resolution
        self.symbols = symbols
        # extend market_data with indicators and scale the volume using min-max scaling
        market_data_x10d = self.extend_tables_with_indicator_data(market_data)
        market_data_x10d_scaled = self.min_max_scale_volume_columns(market_data_x10d)
        # load preprocessed ohlcv-tables into memory,
        # step() and reset() calls from environment sample trajectories from these tables
        self.market_data = self.drop_nas(market_data_x10d_scaled)  # can be very large, depending on db
        assert self.market_data is not None
        # filled by subclasses while sampling: instrument index -> symbol string
        self.int_instrument_to_str_key_symbol = dict()

    @staticmethod
    def extend_tables_with_indicator_data(market_data: Dict[STR_SYMBOL, DF_MARKET_DATA]) \
            -> Dict[STR_SYMBOL, DF_X10DMARKET_DATA]:
        """ Merges RSI, MACD, Williams %R and stochastic indicator columns into every table.

        Args:
            market_data: see database.load_tables() docstring
                dictionary mapping p symbols to pd.DataFrame with Index
                Index(['timestamp', 'date', 'openp', 'highp', 'lowp', 'closep', 'volume'], dtype='object')
        Returns:
            the same dictionary object, each table replaced by its extended version with Index
                Index(['timestamp', 'date', 'openp', 'highp', 'lowp', 'closep', 'volume',
                # EXTENDED WITH THE FOLLOWING INDICATOR DATA:
                'RSI(14, closep)', 'MACD(12, 26)', 'Signal(9)', 'MACD_Histogram',
                'WILLIAMS%R(14)', '%K', 'STOCH_K=14', 'STOCH_D=3'], dtype='object')
            see src.indicators for computation of the new columns, e.g.
                {'BTC/USDT':      timestamp                 date  ...  STOCH_K=14  STOCH_D=3
                    0    1514899800000  2018-01-02 14:30:00  ...   87.876188  85.171782
                    ..             ...                  ...  ...         ...        ...
                    420  1515661200000  2018-01-11 10:00:00  ...   73.594852  70.450602
                    [421 rows x 15 columns],
                 ...
                }
            Rows containing NaN after the merges are dropped.
        """
        # assume indicators = ['rsi', 'macd', 'willr', 'stoch'] fixed for now
        for sym, df_ohlcv in market_data.items():
            for indicator in (rsi, macd, willr, stoch):
                # each indicator is computed on the table extended so far
                df_ohlcv = pd.merge(df_ohlcv, indicator(df_ohlcv))
            market_data[sym] = df_ohlcv.dropna()
        return market_data

    @staticmethod
    def min_max_scale_volume_columns(market_data: Dict[STR_SYMBOL, DF_MARKET_DATA], volume: str = 'volume') \
            -> Dict[STR_SYMBOL, DF_MARKET_DATA]:
        """ Performs min-max scaling of volume column for all tables in market_data.

        The tables are modified in place and the same dictionary is returned.
        A column whose values are all equal is mapped to 0.0 instead of being
        divided by a zero range (which would turn the whole column into NaN).
        C.f https://en.wikipedia.org/wiki/Feature_scaling """
        for table in market_data.values():
            v = table[volume]
            v_min = v.min()
            span = v.max() - v_min
            if span != 0:
                table[volume] = (v - v_min) / span
            else:
                # constant column: every value sits at the minimum
                table[volume] = (v - v_min).astype(float)
        return market_data

    @staticmethod
    def drop_indices(market_data: Dict[STR_SYMBOL, DF_MARKET_DATA]):
        """ Resets the index of every table in place (old index discarded). Returns None. """
        for table in market_data.values():
            table.reset_index(drop=True, inplace=True)

    @staticmethod
    def drop_nas(market_data: Dict[STR_SYMBOL, pd.DataFrame]):
        """ Replaces every table by a copy without NaN rows; returns the same dictionary. """
        for symbol, table in market_data.items():
            market_data[symbol] = table.dropna()
        return market_data
class _DefaultPreprocessorSDMI(DefaultPreprocessor):
    """Preprocessor for the SDMI environment: one row per symbol and time step.

    Each call to precompute_trajectory_states() samples a random window of
    ``trajectory_length`` consecutive rows for ``n_instruments`` randomly
    chosen symbols.
    """

    def __init__(self, market_data: Dict[STR_SYMBOL, DF_MARKET_DATA],
                 trajectory_length: int, resolution: str, symbols,
                 n_instruments: int = 2):
        """
        Args:
            market_data, trajectory_length, resolution, symbols: see DefaultPreprocessor
            n_instruments: number of distinct symbols sampled per trajectory
                (defaults to 2, the previously hard-coded value)
        """
        super().__init__(market_data, trajectory_length, resolution, symbols)
        self.n_instruments = n_instruments

    def precompute_trajectory_states(self) -> Tuple[Dict[STR_SYMBOL, DF_MARKET_DATA], np.ndarray, Dict]:
        """ Samples a trajectory and returns it in three forms.

        Returns:
            dict_states: symbol -> sampled table (still containing timestamp and date)
            np_states: standardized, column-wise concatenated values of all sampled tables
            mapping: instrument index -> symbol, e.g. {0: 'BNB/USDT', 1: 'XLM/USDT'}
        """
        dict_states, mapping = self._sample_states_for_current_trajectory()
        np_states = self._standardize_and_concat(dict_states)
        return dict_states, np_states, mapping

    def _sample_states_for_current_trajectory(self) -> Tuple[Dict[STR_SYMBOL, DF_MARKET_DATA], Dict]:
        """ Randomly picks symbols and, per symbol, a window of trajectory_length rows.

        With e.g.
            self.symbols = ['BTC/USDT', 'ADA/USDT', 'BCH/USDT', 'BNB/USDT', ...]
        and rand_tables = [3, 8], the tables of self.symbols[3] and
        self.symbols[8] are sampled; len(self.market_data.keys()) == len(self.symbols)
        is expected.

        Raises:
            ValueError: if a chosen table has fewer than trajectory_length rows.
        """
        states = dict()
        # distinct table indices, e.g. [3, 8]
        rand_tables = np.random.choice(len(self.symbols), self.n_instruments, replace=False)
        for i, table_index in enumerate(rand_tables):
            symbol = self.symbols[table_index]  # e.g. table_index = 3 => symbol = 'BNB/USDT'
            table = self.market_data[symbol]
            max_row = table.shape[0]
            if max_row < self.trajectory_length:
                raise ValueError(f'Table for symbol {symbol} has {max_row} rows, '
                                 f'need at least {self.trajectory_length}.')
            # pick last point of trajectory from which we normalize backwards, for online evaluation.
            # randint excludes its upper bound and traj_end is itself an exclusive
            # slice end, so max_row + 1 is needed for the last row to be reachable.
            traj_end = np.random.randint(self.trajectory_length, max_row + 1)
            first = traj_end - self.trajectory_length
            states[symbol] = table[first:traj_end].reset_index(drop=True)
            # update integer mapping to string symbol
            self.int_instrument_to_str_key_symbol[i] = symbol
        return states, self.int_instrument_to_str_key_symbol

    def _standardize_and_concat(self, market_data: Dict[STR_SYMBOL, DF_MARKET_DATA]) -> np.ndarray:
        """
        Standardizes market data for each symbol. Merges all symbols data at timestamp.
        So each row has p * len(Index) many columns, where p = len(market_data).

        Per symbol, the 'timestamp' and 'date' columns are removed, the four
        price columns are divided by the open price of the table's last row,
        and every column name gets the suffix '_<symbol>'.

        Returns df.values of the column-wise concatenation, e.g. for two symbols:
               openp_<s0>  highp_<s0>  lowp_<s0>  ...  %K_<s1>  STOCH_K=14_<s1>  STOCH_D=3_<s1>
            0    0.977301    0.987676   0.973552  ...  20.335293      45.119614      48.331698
            ...
            47   1.000000    1.000722   0.992799  ...  87.598566      75.686977      69.744754
            [48 rows x 26 columns]
        An empty input yields an empty array.
        """
        to_drop = ['timestamp', 'date']
        price_columns = ['openp', 'highp', 'lowp', 'closep']
        per_symbol = []
        n_cols = 0
        for symbol, table in market_data.items():
            assert table.shape[0] == self.trajectory_length
            # drop() returns a new frame, so the sampled table stays untouched
            state = table.drop(to_drop, axis=1)
            n_cols = len(state.columns)
            # standardize prices
            refp = table['openp'].iloc[-1]
            for col in price_columns:
                state[col] = state[col] / refp
            # append column suffix
            state.columns = [str(col) + f'_{symbol}' for col in state.columns]
            # drop index for right join without nans
            state.reset_index(drop=True, inplace=True)
            per_symbol.append(state)
        # join multiple symbol data to one state observation
        states = pd.concat(per_symbol, axis=1) if per_symbol else pd.DataFrame()
        assert len(states.columns) == n_cols * len(per_symbol)
        return states.values
class _DefaultPreprocessorMDMI(DefaultPreprocessor):
    """Work-in-progress preprocessor for observations spanning several timeframes.

    Sampling and price standardization are in place; building the rolling
    states (_make_states) is still a stub, so precompute_trajectory_states()
    currently returns None.
    """

    def __init__(self, market_data, trajectory_length, resolution, n_timeframes,
                 symbols=None, n_instruments=2):
        """
        Args:
            market_data, trajectory_length, resolution: see DefaultPreprocessor
            n_timeframes: number of past rows each observation looks back on
            symbols: symbols to sample from; defaults to the keys of market_data
            n_instruments: number of distinct symbols sampled per trajectory
                (defaults to 2, the previously hard-coded value)
        """
        if symbols is None:
            # the base class requires the symbol list
            symbols = list(market_data)
        super().__init__(market_data, trajectory_length, resolution, symbols)
        self.n_timeframes = n_timeframes
        self.n_instruments = n_instruments

    def precompute_trajectory_states(self):
        """ Samples and standardizes market data; returns None until _make_states is implemented. """
        data_samples, _mapping = self._sample_market_data()
        data_samples_std = self._standardize_samples(data_samples)
        dict_states = self._make_states(data_samples_std)  # rolling merge
        # TODO: drop cols, merge dict_states to a frame and return its values
        return None

    def _sample_market_data(self):
        """ Randomly picks symbols and, per symbol, a window of
        trajectory_length + n_timeframes consecutive rows.

        The trajectory ranges over trajectory_length + n_timeframes rows because
        each observation will contain n_timeframes rows; normalization occurs
        over the whole window, not per state as in SDMI.

        Returns:
            (symbol -> sampled table, instrument index -> symbol)
        Raises:
            ValueError: if a chosen table has fewer rows than the window needs.
        """
        states = dict()
        rand_tables = np.random.choice(len(self.symbols), self.n_instruments, replace=False)
        # e.g. trajectory_length = 48, n_timeframes = 12 => 60 rows, because for
        # each of the 48 states you go 12 rows back
        backward_offset = self.trajectory_length + self.n_timeframes
        for i, table_index in enumerate(rand_tables):  # e.g. enumerate([3, 8])
            symbol = self.symbols[table_index]  # e.g. table_index = 3 => symbol = 'BNB/USDT'
            table = self.market_data[symbol]
            max_row = table.shape[0]
            if max_row < backward_offset:
                raise ValueError(f'Table for symbol {symbol} has {max_row} rows, '
                                 f'need at least {backward_offset}.')
            # pick last point of trajectory from which we normalize backwards, for online evaluation.
            # randint excludes its upper bound and traj_end is an exclusive slice
            # end, so max_row + 1 is needed for the last row to be reachable.
            traj_end = np.random.randint(backward_offset, max_row + 1)
            first = traj_end - backward_offset
            states[symbol] = table[first:traj_end].reset_index(drop=True)
            # update integer mapping to string symbol
            self.int_instrument_to_str_key_symbol[i] = symbol
        return states, self.int_instrument_to_str_key_symbol

    def _standardize_samples(self, data_samples):
        """ Divides the four price columns of every table by the open price of its
        last row. Tables are modified in place; the same dictionary is returned. """
        for table in data_samples.values():
            # standardize prices
            refp = table['openp'].iloc[-1]
            for col in ('openp', 'highp', 'lowp', 'closep'):
                table[col] = table[col] / refp
        return data_samples

    def _make_states(self, dict_samples):
        """ Stub: meant to build rolling windows from the sampled rows. """
        # now we have sampled trajectory_length + n_timeframes rows per symbol and want to make rolling
        # TODO: not implemented yet; returns a placeholder value
        return self.n_timeframes
class SDMIEnvironment(object):
    """
    SD: SingleData refers to the number timeframes per observation.
    MI: MultipleInstruments refers to the number p of instruments per observation.
    This p is equal to the number of instruments that can be traded via the action space.
    """

    def __init__(self, config, market_data=None, preprocessor=_DefaultPreprocessorSDMI):
        r"""
        Creates an SingleDataMultipleInstruments environment with the given configuration.

        Args:
            config: e.g.
                {
                'databases': {
                    'symbols': ['BTC/USDT', 'ADA/USDT', 'BCH/USDT', 'BNB/USDT', 'EOS/USDT', 'ETC/USDT',
                                'ETH/USDT', 'IOTA/USDT', 'LTC/USDT', 'NEO/USDT', 'QTUM/USDT', 'XLM/USDT',
                                'XRP/USDT'],
                    'db_files': ['BTC_USDT.db', 'ADA_USDT.db', 'BCH_USDT.db', ...],
                    'paths': ['./database/binance/BTC_USDT.db', './database/binance/ADA_USDT.db', ...]
                },
                'resolution': '30m',  # resolution of data in each state
                'trajectory_length': 96,  # states per trajectory
                'symbols_per_state': 2,  # symbols corresponding to instruments observed at each timestep
                'buy_amts': [50, 200],  # absolute
                'sell_amts_percent': [1.],  # percent
                'taker_fee': 998.0 / 1000.0,  # 0.2 percent
                'n_possible_buy_amts': 2,  # optional, defaults to len(config['buy_amts'])
                'n_possible_sell_amts': 1  # optional ('n_possible_sell_amts_percent' is accepted
                                           # too), defaults to len(config['sell_amts_percent'])
                }
            market_data: see database.load_tables() docstring
                Mapping p symbols to pd.DataFrame (storing their database table for given resolution),
                where p = len(symbols):
                {'BTC/USDT':      timestamp                 date     openp  ...      lowp    closep    volume
                    0    1514757600000  2017-12-31 23:00:00  13930.00  ...  13869.14  13898.98  130.5177
                    ..             ...                  ...       ...  ...       ...       ...       ...
                    499  1515661200000  2018-01-11 10:00:00  13656.00  ...  13554.18  13660.38  427.2661
                    [max_len rows x 7 columns],
                 ...
                } len(self.market_data.keys()) == len(self.symbols)
                Loaded via database.load_tables() when None.
            preprocessor: class called as
                preprocessor(market_data, trajectory_length, resolution, symbols)
                whose instances provide precompute_trajectory_states()
        """
        assert isinstance(config, dict), "Expected config to be of type dict."
        # config
        self.config = config
        self.paths = config['databases']['paths']
        self.symbols = config['databases']['symbols']
        self.resolution = config['resolution']
        self.trajectory_length = config['trajectory_length']
        self.n_instruments_tradeable = config['symbols_per_state']
        # data pipeline
        if market_data is None:
            market_data = database.load_tables(self.symbols, self.paths, self.resolution)
        # computes normalization and feature scaling
        self.preprocessor = preprocessor(market_data, self.trajectory_length, self.resolution, self.symbols)
        # stores state information to compute balances and credit info for tradeable instruments
        self.credit = Credit(self.n_instruments_tradeable, config['taker_fee'])  # maybe swap with account
        # market data observations
        self.dict_states = dict()  # Dict[STR_SYMBOL, DF_MARKET_DATA]
        self.np_states = None
        # The amount counts are redundant with the amount lists; accept either
        # documented spelling of the sell key and fall back to the list lengths.
        n_possible_buy_amts = config.get('n_possible_buy_amts', len(config['buy_amts']))
        n_possible_sell_amts = config.get(
            'n_possible_sell_amts',
            config.get('n_possible_sell_amts_percent', len(config['sell_amts_percent'])))
        # used to parse integer actions
        self.action_spec = ActionParser(self.n_instruments_tradeable,
                                        n_possible_buy_amts,
                                        n_possible_sell_amts,
                                        config['buy_amts'],
                                        config['sell_amts_percent'])
        self._max_actions = self.action_spec.max_actions
        self.dict_instr_to_symbol = dict()  # Dict[int, STR_SYMBOL], e.g. {0: 'BNB/USDT', 1: 'XLM/USDT'}
        # utils
        self.t_step = 0
        self.t_traj = 0
        # mask handed out with the latest observation; None until the first reset()
        self.last_mask = None
        # Hard-coded observation length; _get_np_obs_shape() can compute it but
        # has to sample a trajectory to do so.
        # NOTE(review): 34 presumably = 26 market columns + 8 credit entries for
        # two symbols -- confirm before changing symbols_per_state.
        self._np_obs_shape = (34,)

    def _get_np_obs_shape(self):
        """ Derives the observation shape from one freshly sampled trajectory. """
        _, stub_states, _ = self.preprocessor.precompute_trajectory_states()
        stub_np_obs = np.append(stub_states[0], self.credit.data)
        del stub_states
        return stub_np_obs.shape

    @property
    def max_actions(self):
        """ Number of distinct integer actions, NOOP included. """
        return self._max_actions

    @property
    def np_obs_shape(self):
        """ Shape of the 'state' array of an observation. """
        return self._np_obs_shape

    def _make_observation(self) -> Dict:
        """ Builds {'state': float32 array, 'mask': action mask} for the current t_step. """
        # market data + credit data for current timestep
        obs = np.append(self.np_states[self.t_step], self.credit.data).astype(np.float32)
        assert obs.shape == self.np_obs_shape, f'expected shape={self.np_obs_shape}, got {obs.shape}, ' \
            f'dict_states={self.dict_states}, np_states = {self.np_states}, t_step = {self.t_step}'
        # legal moves mask
        legal_actions_as_int = self._legal_actions_as_int()
        # remembered for the diagnostics in step()
        self.last_mask = legal_actions_as_int
        # observation as returned by environment reset() and step() calls
        return {'state': obs, 'mask': legal_actions_as_int}

    def _legal_actions_as_int(self):
        """ Returns an additive action mask: 0 for every legal action and the
        smallest float32 value for every illegal one (so not a boolean mask). """
        mask = np.full(self.max_actions, np.finfo(np.float32).min)
        for action in range(self.max_actions):
            action_type, target_instr, target_amount = self.action_spec.parse_int_action(action)
            if self._action_is_legal(action_type, target_instr, target_amount):
                mask[action] = 0
        return mask

    def reset(self, config=None):
        """ Starts MC Rollout from OHLC-DB

        Resets the credit, samples a new trajectory and returns the first
        observation. `config` is accepted but not used.
        """
        # reset credit and states
        self.credit.reset()
        # reset step and increment traj counter
        self.t_step = 0
        self.t_traj += 1
        # generate all states in the trajectory a priori [we assume BUY,SELL dont effect the future states]
        dict_states, np_states, mapping = self.preprocessor.precompute_trajectory_states()
        # meta information of market_data for creating Trade objects
        self.dict_states = dict_states
        # numpy values of market_data as in observation
        self.np_states = np_states
        # mapping from random integer samples to symbols in self.symbols
        self.dict_instr_to_symbol = mapping
        # observation = market_data + credit_information
        observation = self._make_observation()
        self.t_step += 1
        return observation

    def _action_is_legal(self, action_type, target_symbol, target_amt):
        """ BUY needs enough remaining credit, SELL needs a positive balance; NOOP is always legal. """
        # buy only when credit available
        if action_type == ActionType.BUY:
            return self.credit.has_cash(target_amt)
        if action_type == ActionType.SELL:
            return self.credit.is_available(target_symbol)
        return True

    def _build_trade(self, action: Dict[str, int], refp: str = 'closep') -> Trade:
        """ Creates the Trade for a parsed BUY or SELL action at the current t_step.

        Args:
            action: {'type': ActionType, 'target': instrument index, 'amount': buy amount or sell percent}
            refp: name of the price column the trade is executed at
        """
        assert action['type'] != ActionType.NOOP
        # {0: 'BNB/USDT', 1: 'XLM/USDT'} -> 'XLM/USDT'
        symbol = self.dict_instr_to_symbol[action['target']]
        market_data = self.dict_states[symbol].iloc[self.t_step, :]  # pd.Series for current timestep
        is_buy = action['type'] == ActionType.BUY
        is_sell = action['type'] == ActionType.SELL
        # for a BUY, 'amount' is cash and is converted to units at the reference price
        buy_amt = None if is_sell else action['amount'] / market_data[refp]
        sell_percent = None if is_buy else action['amount']
        net_val = None if is_sell else action['amount']
        return Trade(t_step=self.t_step,
                     datestring=market_data['date'],
                     symbol=symbol,
                     symbol_idx=action['target'],
                     type=action['type'],
                     buy_amt=buy_amt,
                     sell_percent=sell_percent,
                     price=market_data[refp],
                     net_value=net_val,
                     info=market_data)

    def step(self, action):
        """ Applies the integer `action` and advances one time step.

        Returns:
            (observation, reward, done, info). When called after the trajectory
            has ended, the environment is reset, `action` is ignored and
            (first observation of the new trajectory, 0, False, None) is returned.
        """
        # reset at end of trajectory; keep the 4-tuple shape so callers can unpack uniformly
        if self.t_step > self.trajectory_length - 1:
            return self.reset(self.config), 0, False, None
        # parse integer action
        action_type, target_instr, target_amount = self.action_spec.parse_int_action(action)
        assert self._action_is_legal(action_type, target_instr, target_amount), f'action_type={action_type} \n ' \
            f'target_instr={target_instr} \n' \
            f'target_amount={target_amount} \n' \
            f'action={action}\n' \
            f'credit.data = {self.credit.data} \n' \
            f'last_mask = {self.last_mask} \n' \
            f'self.t_step = {self.t_step}'
        reward = 0
        # update credit balances and time
        if action_type != ActionType.NOOP:
            action_dict = {'type': action_type, 'target': target_instr, 'amount': target_amount}
            trade = self._build_trade(action_dict)
            reward = self.credit.apply_trade_and_compute_reward(trade=trade)
        # observation is gotten from self.states regardless of action taken
        observation = self._make_observation()
        self.t_step += 1
        done = self.t_step == self.trajectory_length
        info = None
        return observation, reward, done, info

    def close(self):
        """ No-op kept for API compatibility.

        `del self` only unbinds the local name; the instance is garbage
        collected once the caller drops its own references.
        """
        del self
        return
class MDMIEnvironment(object):
    """
    Placeholder for the MultipleData / MultipleInstruments environment.

    MD: MultipleData refers to the number timeframes per observation.
    Increasing this, we can try to catch cross-correlation patterns with RNNs, e.g. by making
    one observation consist of a whole day of 30 minute data.
    MI: MultipleInstruments refers to the number p of instruments per observation.
    This p is equal to the number of instruments that can be traded via the action space.
    """

    def __init__(self):
        # not implemented yet: instantiation always fails
        raise NotImplementedError()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment