Skip to content

Instantly share code, notes, and snippets.

@Taytay
Created June 20, 2023 12:35
Show Gist options
  • Save Taytay/3afaea621f6fdc0f420357f0caabd1f6 to your computer and use it in GitHub Desktop.
Save Taytay/3afaea621f6fdc0f420357f0caabd1f6 to your computer and use it in GitHub Desktop.
params:
# Default params file.
# NOTE: at runtime DVC regenerates this file from the Hydra config, so manual edits to it are overwritten.
- params.yaml
stages:
# Each stage sets DVC_STAGE_NAME so the script can tell it is running under DVC, and which stage's params to load.
train:
cmd:
- DVC_STAGE_NAME=train python scripts/fine_tune/fine_tune_and_evaluate.py --hoist_params_path="trainer_task.train"
params:
- trainer_task.train
deps:
- scripts/fine_tune/fine_tune_and_evaluate.py
evaluate:
cmd:
- DVC_STAGE_NAME=evaluate python scripts/fine_tune/fine_tune_and_evaluate.py --hoist_params_path="trainer_task.eval"
params:
- trainer_task.eval
deps:
- scripts/fine_tune/fine_tune_and_evaluate.py
def get_dvc_stage_name():
    """Return the current DVC stage name taken from ``DVC_STAGE_NAME``.

    Returns:
        The raw (unstripped) value of the environment variable, or ``None``
        when the variable is unset or contains only whitespace.
    """
    raw_value = os.getenv("DVC_STAGE_NAME")
    # Treat "unset" and "set but blank" the same way: not running in DVC.
    if raw_value is None or not raw_value.strip():
        return None
    return raw_value
def hoist_params(params_dictionary, hoist_params_path):
    """Return the sub-tree of ``params_dictionary`` addressed by a dotted path.

    For example, ``hoist_params(params, "trainer_task.train")`` returns
    ``params["trainer_task"]["train"]``.

    Args:
        params_dictionary: Nested dictionary of parameters (in this file, the
            result of ``dvc.api.params_show``).
        hoist_params_path: Dot-separated key path, e.g. ``"a.b.c"``.

    Returns:
        The value at the end of the path. If a value along the path is
        ``None``, a warning is logged and ``None`` is returned.

    Raises:
        KeyError: A path segment is not a key of the current dictionary.
        TypeError: A path segment is looked up in a value that cannot be
            indexed by a string key (e.g. a scalar or a list).
    """
    node = params_dictionary
    for part in hoist_params_path.split("."):
        # Re-raise with the full path so a typo in --hoist_params_path is easy
        # to spot; the exception types are the same ones the bare lookup raised.
        try:
            node = node[part]
        except KeyError:
            raise KeyError(
                f"hoist_params_path = {hoist_params_path} is invalid. "
                f"Key '{part}' not found"
            ) from None
        except TypeError:
            raise TypeError(
                f"hoist_params_path = {hoist_params_path} is invalid. "
                f"Cannot look up '{part}' in a value of type {type(node).__name__}"
            ) from None
        if node is None:
            # Print warning:
            logger.warning(
                f"hoist_params_path = {hoist_params_path} is invalid. It's None at {part}"
            )
            break
    return node
def _parse_hoist_params_path():
    """Return the ``--hoist_params_path`` CLI value, or ``None`` if absent.

    Uses ``parse_known_args`` with ``add_help=False`` so that arguments (and
    ``-h``) meant for the host script's own CLI are ignored here instead of
    making this helper exit the process.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--hoist_params_path", required=False)
    args, _unknown_args = parser.parse_known_args()
    return args.hoist_params_path


def get_dvc_params_if_in_dvc():
    """Load this stage's DVC params when running inside a DVC stage.

    The stage name comes from ``get_dvc_stage_name()``. When it is set, the
    stage's parameters are loaded with ``dvc.api.params_show``. If the
    ``--hoist_params_path`` command-line option was given, only the sub-tree
    at that dotted path (see ``hoist_params``) is returned.

    Returns:
        The (possibly hoisted) params, or ``None`` when not running under DVC.
    """
    dvc_stage_name = get_dvc_stage_name()
    if dvc_stage_name is None:
        logger.info("Not being run by DVC. (No 'DVC_STAGE_NAME' env variable set) ")
        return None

    dvc_params = dvc.api.params_show(stages=dvc_stage_name)
    logger.info(f"DVC Stage: '{dvc_stage_name}'")

    hoist_params_path = _parse_hoist_params_path()
    if hoist_params_path is not None:
        logger.info(f"Hoisting parameters to top level from: {hoist_params_path}")
        dvc_params = hoist_params(dvc_params, hoist_params_path)

    # default=str is only for this log line: it keeps non-JSON values
    # (e.g. dates parsed from YAML) from crashing the logging call.
    logger.info(
        f"DVC Params being used by this script: {json.dumps(dvc_params, indent=4, default=str)}"
    )
    return dvc_params
# Params for the current DVC stage (narrowed to the sub-tree named by
# --hoist_params_path when that option is given), or None when not running
# under DVC. NOTE: this runs at module level, so it reads DVC_STAGE_NAME and
# parses the command line as soon as this code is executed.
dvc_params_dict = get_dvc_params_if_in_dvc()
# params.yaml: generated at runtime by DVC from the Hydra config.
trainer_task:
train:
do_train: true
do_eval: false
eval:
do_train: false
do_eval: true
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment