Skip to content

Instantly share code, notes, and snippets.

@bsweger
Created March 14, 2024 22:21
Show Gist options
  • Save bsweger/f2da4cf8a0e1adb45fca7cb5019fb807 to your computer and use it in GitHub Desktop.
Save bsweger/f2da4cf8a0e1adb45fca7cb5019fb807 to your computer and use it in GitHub Desktop.
Generate sample output data for a variant hub.
from itertools import product
from typing import Optional

import numpy as np
import pandas as pd
def make_sample(
    n_samples: int = 2,
    n_horizons: int = 3,
    n_variants: int = 3,
    n_locations: int = 2,
    samples_joint_across: Optional[list[str]] = None,
) -> pd.DataFrame:
    """Build a sample-type model-output DataFrame for a variant hub.

    One row is generated for every combination of sample, horizon, location
    and variant. ``samples_joint_across`` controls which dimensions share a
    sample trajectory: rows that differ only in those dimensions receive the
    same ``traj`` number, and each listed dimension is appended to
    ``output_type_id`` as ``_<initial><value>`` (e.g. ``S3_H-1``).

    Args:
        n_samples: Number of sample labels, ``'0'`` .. ``str(n_samples - 1)``.
        n_horizons: Number of horizons; horizons run from -1 to
            ``n_horizons - 2`` inclusive.
        n_variants: Number of variant labels, ``'0'`` .. ``str(n_variants - 1)``.
        n_locations: Number of location labels, ``'0'`` .. ``str(n_locations - 1)``.
        samples_joint_across: Iterable of distinct names taken from
            ``'horizon'``, ``'variant'`` and ``'location'``. ``None`` or an
            empty iterable means no joint structure.

    Returns:
        DataFrame with columns ``sample``, ``horizon`` (int), ``location``,
        ``variant``, ``traj``, ``output_type_id``, ``output_type``,
        ``nowcast_date`` and ``target_date``. No ``value`` column is
        generated yet.

    Raises:
        ValueError: if ``samples_joint_across`` has more than three entries,
            contains duplicates, or names an unknown dimension.
    """
    # Copy into a fresh list so tuples and other iterables work with the list
    # concatenation below, and the caller's object is never aliased.
    samples_joint_across = list(samples_joint_across) if samples_joint_across else []

    if len(samples_joint_across) > 3:
        raise ValueError('Too many samples_joint_across!')
    if len(samples_joint_across) != len(set(samples_joint_across)):
        raise ValueError('Duplicate samples_joint_across!')
    for item in samples_joint_across:
        if item not in ['horizon', 'variant', 'location']:
            raise ValueError('Sample joint across must be horizon, variant, or location')

    # assume days for this exercise
    temp_scale = 'd'

    samples = np.arange(n_samples).astype('str')
    horizons = np.arange(-1, n_horizons - 1).astype('str')
    variants = np.arange(0, n_variants).astype('str')
    locations = np.arange(0, n_locations).astype('str')

    df = pd.DataFrame.from_records(
        list(product(samples, horizons, locations, variants)),
        columns=['sample', 'horizon', 'location', 'variant'],
    )

    # Every dimension label is a numeric string. Sort on the integer value so
    # that, e.g., sample '2' precedes sample '10' (a plain string sort would not).
    def by_int(col: pd.Series) -> pd.Series:
        return col.astype(int)

    if not samples_joint_across:
        # No joint structure: every row is its own trajectory, numbered 1..len(df).
        df.sort_values(
            ['sample', 'horizon', 'location', 'variant'], key=by_int, inplace=True
        )
        df['traj'] = range(1, len(df) + 1)
    else:
        # The multi-column sort is stable. Within each group of joint values,
        # rows are therefore ordered by sample, then by the original product
        # order of the remaining dimensions. cumcount then gives the same
        # number to rows that differ only in the joint dimensions.
        # NOTE(review): numbering here starts at 0 while the branch above
        # starts at 1 -- confirm which convention is intended.
        df.sort_values(['sample'] + samples_joint_across, key=by_int, inplace=True)
        df['traj'] = df.groupby(samples_joint_across).cumcount()

    # Generate the output_type_id: 'S<traj>' plus one '_<X><value>' suffix per
    # joint dimension, where X is the dimension's upper-cased initial.
    df['output_type_id'] = 'S' + df['traj'].astype('str')
    for item in samples_joint_across:
        df['output_type_id'] = df['output_type_id'] + '_' + item.upper()[0] + df[item]

    # Add the constant columns
    df['output_type'] = 'sample'
    df['nowcast_date'] = pd.to_datetime('2024-01-26')
    df['horizon'] = df['horizon'].astype('int')  # horizon needs to be int for date math
    df['target_date'] = df['nowcast_date'] + pd.to_timedelta(df['horizon'], temp_scale)

    # Deal with values later
    # df['value'] = np.random.standard_normal(df.shape[0]) # current state 180
    return df
def _report(df: pd.DataFrame, csv_path: str, label=None) -> None:
    """Print a summary of *df* and save it as CSV.

    If *label* is given it is printed first. Then the shape, the rows ordered
    by ``sample`` and ``traj``, and a blank line are printed. Finally *df* is
    written to *csv_path* without its index.
    """
    if label is not None:
        print(label)
    print(df.shape)
    print(df.sort_values(['sample', 'traj']))
    print()
    df.to_csv(csv_path, index=False)


# No joint structure.
df_none = make_sample()
_report(df_none, 'sja_none.csv')

# Samples joint across horizon.
sja = ['horizon']
df_h = make_sample(samples_joint_across=sja)
_report(df_h, 'sja_horizon.csv', label=sja)

# Samples joint across horizon and variant.
sja = ['horizon', 'variant']
df_hv = make_sample(samples_joint_across=sja)
_report(df_hv, 'sja_horizon_variant.csv', label=sja)

# Samples joint across horizon, variant and location.
sja = ['horizon', 'variant', 'location']
df_hvl = make_sample(samples_joint_across=sja)
_report(df_hvl, 'sja_horizon_variant_location.csv', label=sja)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment