Recursive instantiation usage prototype
my_app.py:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from dataclasses import dataclass, field
from typing import Any, List

from omegaconf import MISSING, II, OmegaConf

import hydra
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate


# library code
class Optimizer:
    def __init__(self, lr: float) -> None:
        self.lr = lr


# demoing a case of two different optimizer implementations
class Adam(Optimizer):
    def __init__(self, lr: float, beta: float) -> None:
        self.lr = lr
        self.beta = beta


class SGD(Optimizer):
    def __init__(self, lr: float) -> None:
        self.lr = lr


class Dataset:
    def __init__(self, path: str, batch_size: int) -> None:
        self.path = path
        self.batch_size = batch_size


class Trainer:
    def __init__(
        self,
        optimizer: Optimizer,
        dataset: Dataset,
        batch_size: int,
    ) -> None:
        # currently these are config objects, not the real objects.
        # recursive instantiation will fix that.
        print("provided optimizer :", optimizer)
        print("provided dataset   :", dataset)
        self.optimizer = optimizer
        self.dataset = dataset
        self.batch_size = batch_size


# config hierarchy (will be possible to code-gen this in the future)
# Since recursive instantiation is not really supported yet, I can't commit to this.
@dataclass
class OptimizerConf:
    _target_: str = "my_app.Optimizer"
    lr: float = MISSING


@dataclass
class AdamConf(OptimizerConf):
    _target_: str = "my_app.Adam"
    lr: float = MISSING
    beta: float = MISSING


@dataclass
class SGDConf(OptimizerConf):
    _target_: str = "my_app.SGD"
    lr: float = MISSING


@dataclass
class DatasetConf:
    _target_: str = "my_app.Dataset"
    path: str = MISSING
    batch_size: int = MISSING


@dataclass
class TrainerConf:
    _target_: str = "my_app.Trainer"
    batch_size: int = MISSING
    # not populated here, we will choose the right optimizer with config composition
    optimizer: OptimizerConf = MISSING
    # if there is only one option we can just inline it here.
    dataset: DatasetConf = DatasetConf()


@dataclass
class Config:
    trainer: TrainerConf = TrainerConf()
    defaults: List[Any] = field(
        default_factory=lambda: [
            # by default, compose adam
            {"optimizer": "adam"},
            # populate the rest from user_config; as an example this will be a yaml
            "user_config",
        ]
    )


cs = ConfigStore.instance()
cs.store(name="config", node=Config)
cs.store(group="optimizer", name="adam", node=AdamConf, package="trainer.optimizer")
cs.store(group="optimizer", name="sgd", node=SGDConf, package="trainer.optimizer")


@hydra.main(config_name="config")
def my_app(cfg: Config) -> None:
    print(OmegaConf.to_yaml(cfg))
    # once recursive instantiation is supported, optimizer and dataset will be the actual objects;
    # currently they are the config nodes.
    trainer = instantiate(cfg.trainer)


if __name__ == "__main__":
    my_app()
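Until recursive instantiation lands, one possible workaround (not part of the gist; a sketch that assumes hydra.utils.instantiate lets keyword arguments override values from the config node, as it does in Hydra 1.0) is to instantiate the nested nodes by hand inside my_app() and pass the resulting objects to the outer call:

    # hypothetical workaround inside my_app(): instantiate the inner nodes first,
    # then override the trainer node's optimizer/dataset fields with the real objects
    optimizer = instantiate(cfg.trainer.optimizer)  # e.g. Adam(lr=0.1, beta=0.9)
    dataset = instantiate(cfg.trainer.dataset)      # e.g. Dataset(path="/foo/bar", batch_size=32)
    trainer = instantiate(cfg.trainer, optimizer=optimizer, dataset=dataset)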
user_config.yaml (referenced by the defaults list above):
# @package _global_
trainer:
  batch_size: 32
  # in a more interesting scenario we will also compose the optimizer from multiple yaml files
  optimizer:
    lr: 0.1
    beta: 0.9
  dataset:
    path: /foo/bar
    batch_size: ${trainer.batch_size}
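To compose the SGD variant instead of Adam, the optimizer group can be selected from the command line with standard Hydra override syntax (not shown in the gist), for example: python my_app.py optimizer=sgd trainer.optimizer.lr=0.01. The package="trainer.optimizer" argument in the ConfigStore entries is what places the selected node under trainer.optimizer in the composed config.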
Example output:
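Running python my_app.py with the files above should print roughly the following composed config (a sketch, assuming the default adam composition and that OmegaConf.to_yaml leaves interpolations unresolved), followed by the two "provided optimizer/dataset" prints from Trainer showing the raw config nodes, since instantiation is not yet recursive:

trainer:
  _target_: my_app.Trainer
  batch_size: 32
  optimizer:
    _target_: my_app.Adam
    lr: 0.1
    beta: 0.9
  dataset:
    _target_: my_app.Dataset
    path: /foo/bar
    batch_size: ${trainer.batch_size}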