Created
May 8, 2023 13:05
-
-
Save GrovesD2/6f821b3dea1d0c66e14ef2cbb66d82a0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import yfinance as yf | |
from dtaidistance import dtw | |
from plotly import graph_objects as go | |
from plotly.subplots import make_subplots | |
import plotly.io as pio | |
pio.renderers.default='svg' | |
# The two ticker symbols whose closing prices are compared; the script reads
# entries 0 and 1, so the list must contain exactly two symbols
TICKERS = ['AAPL', 'MSFT']

# Number of observations in each rolling DTW window
ROLLING_SCORE_LENGTH = 30

# Number of most recent rows kept for the chart
DAYS_TO_PLOT = 100
def get_stock_data(ticker: str) -> pd.DataFrame:
    '''
    Download the stock data for a ticker symbol and return a DataFrame holding
    the date and closing price for that ticker.

    Parameters
    ----------
    ticker : str
        The ticker symbol to download data for.

    Returns
    -------
    pd.DataFrame
        A DataFrame with a 'Date' column and a closing-price column named
        after `ticker`.
    '''
    df = yf.download(ticker)

    # Recent yfinance releases return two-level (field, ticker) columns even
    # for a single symbol. Keep only the field level so that 'Date' and
    # 'Close' can be selected and renamed by their plain names below.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.get_level_values(0)

    # Move the date index into an ordinary 'Date' column
    df = df.reset_index()
    return df[['Date', 'Close']].rename(columns={'Close': ticker})
def z_scale(ts: np.ndarray) -> np.ndarray:
    '''
    Normalize a time series using z-score normalization.

    The series is shifted by its mean and divided by its (population)
    standard deviation as computed by `np.std`.

    Parameters
    ----------
    ts : np.ndarray
        The time series to normalize.

    Returns
    -------
    np.ndarray
        The normalized time series. A series whose standard deviation is
        exactly zero (every value identical) is returned as all zeros.
    '''
    centred = ts - np.mean(ts)
    spread = np.std(ts)

    # A constant series has zero spread; dividing would give 0/0 = NaN for
    # every element and make any distance computed from it NaN as well.
    if spread == 0:
        return np.zeros_like(centred, dtype=float)

    return centred / spread
def rolling_dtw_score(ts1: np.array, ts2: np.array, length: int) -> np.array:
    '''
    Compute the rolling DTW score between two time series.

    For every index `n >= length`, both series are sliced to the `length`
    observations immediately before `n` (index `n` itself is excluded), each
    slice is z-scaled, and the DTW distance between the two is stored at `n`.

    Parameters
    ----------
    ts1 : np.array
        The first time series.
    ts2 : np.array
        The second time series.
    length : int
        The length of the rolling window for computing DTW scores.

    Returns
    -------
    np.array
        An array shaped like `ts1`. Positions before `length` stay NaN because
        no full window precedes them.
    '''
    n_obs = ts1.shape[0]
    scores = np.full(ts1.shape, np.nan)

    for end in range(length, n_obs):
        start = end - length
        window_1 = z_scale(ts1[start:end])
        window_2 = z_scale(ts2[start:end])
        scores[end] = dtw.distance_fast(window_1, window_2)

    return scores
if __name__ == '__main__':

    # Download both tickers and keep only the dates present in both frames
    left = get_stock_data(TICKERS[0])
    right = get_stock_data(TICKERS[1])
    df = left.merge(right, on='Date', how='inner')

    first, second = TICKERS[0], TICKERS[1]

    # Rolling DTW score between the two closing-price series
    df['dtw_score'] = rolling_dtw_score(
        ts1=df[first].values,
        ts2=df[second].values,
        length=ROLLING_SCORE_LENGTH,
    )

    # Keep only the most recent rows for plotting
    recent = df.iloc[-DAYS_TO_PLOT:]

    # One subplot per series, stacked vertically on a shared date axis:
    # row 1 = first ticker, row 2 = second ticker, row 3 = rolling DTW score
    fig = make_subplots(rows=3, cols=1, shared_xaxes=True)
    panels = [
        (first, first),
        (second, second),
        ('dtw_score', 'Rolling DTW Score'),
    ]
    for row, (column, label) in enumerate(panels, start=1):
        trace = go.Scatter(x=recent['Date'], y=recent[column], name=label)
        fig.add_trace(trace, row=row, col=1)

    # Figure size, legend placement, margins and axis titles
    legend = {'x': 0, 'y': -0.1, 'orientation': 'h'}
    margin = {'l': 50, 'r': 50, 'b': 50, 't': 50}
    date_axis = {
        'tickformat': '%Y-%m-%d',  # Show ticks as ISO dates
        'tickangle': 20,           # Rotate the tick labels by 20 degrees
        'title': 'Date',
    }
    fig.update_layout(
        height=800,
        width=800,
        legend=legend,
        margin=margin,
        yaxis1={'title': first},
        yaxis2={'title': second},
        yaxis3={'title': 'Dynamic Time Warping Score'},
        xaxis3=date_axis,
    )

    # Render the chart with the configured plotly renderer
    fig.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment