Skip to content

Instantly share code, notes, and snippets.

@altescy
Last active April 17, 2024 06:12
Show Gist options
  • Save altescy/0f01eb3d3c02161cc6fedfdabdee56d5 to your computer and use it in GitHub Desktop.
Save altescy/0f01eb3d3c02161cc6fedfdabdee56d5 to your computer and use it in GitHub Desktop.
import inspect
import io
import sys
import traceback
import typing
from collections import defaultdict
from functools import wraps
from typing import (Any, Callable, Dict, List, Literal, Optional, TypeVar,
Union, cast)
# ``types.UnionType`` (the runtime type of ``X | Y``) only exists on
# Python 3.10+.  Older interpreters get an empty stand-in class so that
# membership checks against ``UnionType`` stay valid and never match.
if sys.version_info < (3, 10):

    class UnionType:
        """Placeholder for ``types.UnionType`` on Python < 3.10."""

else:
    from types import UnionType
# Unconstrained type variable: ties the class handed to
# ``_construct_with_kwargs`` to the instance type it returns.
_T = TypeVar("_T")
# Type variable restricted to ``Registrable`` and its subclasses (forward
# reference, since the class is defined further down in this module).
_T_Registrable = TypeVar("_T_Registrable", bound="Registrable")
# Type variable restricted to ``FromConfig`` and its subclasses (forward
# reference, since the class is defined further down in this module).
_T_FromConfig = TypeVar("_T_FromConfig", bound="FromConfig")
class ConfigurationError(Exception):
    """Raised when a configuration cannot be turned into an object."""
def _indent(s: str, level: int = 1) -> str:
tabs = "\t" * level
return tabs + s.replace("\n", f"\n{tabs}")
def _remove_optional(annotation: Any) -> Any:
origin = typing.get_origin(annotation)
args = typing.get_args(annotation)
if origin == Union and len(args) == 2 and args[1] is type(None):
return cast(type, args[0])
return annotation
def _construct_with_kwargs(
    cls: type[_T], kwargs: dict[str, Any], key: tuple[str, ...]
) -> _T:
    """Instantiate ``cls`` from a configuration mapping.

    Each named parameter of ``cls.__init__`` (everything after the first
    parameter, which is assumed to be ``self``) is looked up in ``kwargs``
    and converted with ``_construct_from_config`` according to the
    parameter's annotation.  Parameters that have a default and are absent
    from ``kwargs`` are left to that default.

    Args:
        cls: Class to instantiate.
        kwargs: Configuration for ``cls``.  For ``Registrable`` subclasses
            an extra ``"type"`` entry is tolerated and not forwarded.
        key: Path of configuration keys leading to this object; used only
            to build error messages.

    Returns:
        The result of calling ``cls`` with the converted keyword arguments.

    Raises:
        ConfigurationError: If ``kwargs`` contains a key that is not a
            named constructor parameter, or a parameter without a default
            is absent from ``kwargs``.
    """
    signature = inspect.signature(cls.__init__)
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    # Drop the first parameter (assumed to be ``self``) and ``*args`` /
    # ``**kwargs``: the latter have no default, so treating them as regular
    # parameters would wrongly report them as missing -- which is what
    # happens for every class that inherits ``object.__init__``.
    params = [
        param
        for position, param in enumerate(signature.parameters.values())
        if position > 0 and param.kind not in variadic_kinds
    ]

    unknown = set(kwargs) - {param.name for param in params}
    if issubclass(cls, Registrable):
        # "type" selects the registered subclass; it is not a parameter.
        unknown -= {"type"}
    if unknown:
        raise ConfigurationError(
            f"[{'.'.join(key)}] Invalid keys in configuration: {unknown}"
        )

    constructed_kwargs: dict[str, Any] = {}
    for param in params:
        subkey = key + (param.name,)
        if param.name not in kwargs:
            # Identity check: ``!=`` would invoke ``__ne__`` on arbitrary
            # default objects, which need not return a plain bool.
            if param.default is not param.empty:
                continue
            raise ConfigurationError(
                f"Missing required parameter: {'.'.join(subkey)}",
            )
        # An unannotated parameter carries the ``empty`` sentinel; pass
        # ``None`` ("no annotation") so the value is forwarded untouched.
        annotation = (
            None if param.annotation is param.empty else param.annotation
        )
        constructed_kwargs[param.name] = _construct_from_config(
            kwargs[param.name],
            annotation,
            subkey,
        )
    return cls(**constructed_kwargs)
def _construct_from_union(
    config: Any,
    annotation: Any,
    options: tuple[Any, ...],
    key: tuple[str, ...],
) -> Any:
    """Build ``config`` as the first member of a union that accepts it.

    Members are tried in declaration order.  Failures are collected and,
    if no member succeeds, reported together in one ``ConfigurationError``.
    """
    param_name = ".".join(key)
    failures: list[tuple[Any, Exception, str]] = []
    for option in options:
        try:
            return _construct_from_config(config, option, key)
        except (
            ValueError,
            TypeError,
            # Raised by ``Registrable.by_name`` for a name that is not
            # registered under this member; the next member may know it.
            KeyError,
            ConfigurationError,
            AttributeError,
        ) as err:
            failures.append((option, err, traceback.format_exc()))

    details = [
        f"[{param_name}] Trying to construct {annotation} "
        f"with type {option}:\n{err}\n{tb}"
        for option, err, tb in failures
    ]
    raise ConfigurationError(
        "\n\n"
        + "\n".join(_indent(msg) for msg in details)
        + f"\n[{param_name}] Failed to construct object with "
        + f"type {annotation}"
    )


def _construct_from_config(
    config: Any,
    annotation: Any,
    key: tuple[str, ...],
) -> Any:
    """Convert a raw configuration value according to a type annotation.

    Handling by annotation:

    * no annotation (``None`` or ``inspect.Parameter.empty``) or ``Any``:
      ``config`` is returned unchanged;
    * a union that includes ``None``, with ``config`` being ``None``:
      ``None`` is returned;
    * ``list`` / ``List[X]``: every element is converted as ``X``;
    * ``dict`` / ``Dict[K, V]``: every value is converted as ``V``, keys
      are kept as they are;
    * ``Literal[...]``: ``config`` must be one of the listed values;
    * other unions: members are tried in order, first success wins;
    * a class: ``bool``/``int``/``float``/``str``/``None`` values must be
      of exactly that class; a dict is used as constructor arguments (for
      ``Registrable`` classes its ``"type"`` entry picks the subclass).

    Args:
        config: Raw configuration value.
        annotation: Annotation describing the expected value.
        key: Path of configuration keys leading to ``config``; used only
            to build error messages.

    Returns:
        The converted value.

    Raises:
        ConfigurationError: If ``config`` does not fit ``annotation``.
        KeyError: If a ``"type"`` name is not registered (propagated from
            ``Registrable.by_name`` outside of union resolution).
    """
    param_name = ".".join(key)

    # Honour an explicit ``None`` for an optional annotation before the
    # ``Optional`` wrapper is stripped; otherwise ``None`` would be
    # checked against the wrapped type and rejected.
    if (
        config is None
        and typing.get_origin(annotation) in (Union, UnionType)
        and type(None) in typing.get_args(annotation)
    ):
        return None

    annotation = _remove_optional(annotation)
    if (
        annotation is None
        or annotation is inspect.Parameter.empty
        or annotation == Any
    ):
        return config

    annotation_origin = typing.get_origin(annotation)
    annotation_args = typing.get_args(annotation)

    # ``get_origin`` is ``None`` for the bare builtins, so they are
    # matched on the annotation itself.
    if annotation is list or annotation_origin is list:
        if not isinstance(config, (list, tuple)):
            raise ConfigurationError(
                f"[{param_name}] Expected a list, got {config} instead."
            )
        item_annotation = annotation_args[0] if annotation_args else None
        return [
            _construct_from_config(item, item_annotation, key + (str(i),))
            for i, item in enumerate(config)
        ]

    if annotation is dict or annotation_origin is dict:
        if not isinstance(config, dict):
            raise ConfigurationError(
                f"[{param_name}] Expected a dictionary, got {config} instead."
            )
        value_annotation = annotation_args[1] if annotation_args else None
        return {
            k: _construct_from_config(v, value_annotation, key + (str(k),))
            for k, v in config.items()
        }

    if annotation_origin == Literal:
        if config not in annotation_args:
            raise ConfigurationError(
                f"[{param_name}] {config} is not a valid literal value."
            )
        return config

    if annotation_origin in (Union, UnionType):
        if not annotation_args:
            return config
        return _construct_from_union(
            config, annotation, annotation_args, key
        )

    cls: Optional[type] = None
    if isinstance(annotation, type):
        cls = annotation
    elif isinstance(annotation_origin, type):
        cls = annotation_origin
    if cls is None:
        # Annotation form this function does not interpret.
        return config

    if isinstance(config, (bool, int, float, str, type(None))):
        # Exact type match on purpose: ``bool`` must not pass as ``int``.
        if type(config) is not cls:
            raise ConfigurationError(
                f"[{param_name}] {config} is not an instance of {cls}"
            )
        return config

    if not isinstance(config, dict):
        raise ConfigurationError(
            f"[{param_name}] Expected a dictionary, got {config} instead."
        )

    if issubclass(cls, Registrable):
        name = config.get("type")
        if name is None:
            raise ConfigurationError(
                f"[{param_name}] Missing 'type' key in configuration."
            )
        cls = cls.by_name(name)

    return _construct_with_kwargs(cls, config, key)
class Registrable:
    """Base class that lets subclasses be looked up by a string name.

    Names are stored per registering class: ``Base.register("x")`` files
    the decorated class under ``Base``, and only ``Base.by_name("x")``
    finds it again (no lookup through parent or child classes).
    """

    # {registering class: {name: registered class}}.  Defined once here,
    # so every subclass shares this single mapping.
    _registry: dict[type["Registrable"], dict[str, Any]] = defaultdict(dict)

    @classmethod
    def register(
        cls: "type[_T_Registrable]",
        name: str,
    ) -> "Callable[[type[_T_Registrable]], type[_T_Registrable]]":
        """Return a class decorator that registers a class under ``name``.

        The decorated class is stored in the table of the class on which
        ``register`` was called and is returned unchanged.  Registering
        the same ``name`` twice replaces the earlier entry.
        """

        # No ``functools.wraps`` here: the decorator does not wrap ``cls``,
        # so copying the class's metadata onto it would be misleading.
        def decorator(
            subclass: "type[_T_Registrable]",
        ) -> "type[_T_Registrable]":
            cls._registry[cls][name] = subclass
            return subclass

        return decorator

    @classmethod
    def by_name(
        cls: "type[_T_Registrable]",
        name: str,
    ) -> "type[_T_Registrable]":
        """Return the class registered under ``name`` for ``cls``.

        Raises:
            KeyError: If ``name`` has not been registered on ``cls``.
        """
        # ``.get`` instead of indexing: indexing the defaultdict would
        # insert an empty table for every class that is merely queried.
        registered = cls._registry.get(cls, {})
        try:
            return registered[name]
        except KeyError:
            raise KeyError(
                f"{name!r} is not registered for {cls.__name__}; "
                f"available names: {sorted(registered)}"
            ) from None
class FromConfig:
    """Mixin providing a ``from_config`` alternate constructor."""

    @classmethod
    def from_config(
        cls: type[_T_FromConfig],
        config: dict[str, Any],
    ) -> _T_FromConfig:
        """Build an instance of ``cls`` from ``config``.

        Delegates to ``_construct_with_kwargs``, using the class name as
        the first element of the key path.
        """
        root_key = (cls.__name__,)
        return _construct_with_kwargs(cls, config, root_key)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment