Last active
April 17, 2024 06:12
-
-
Save altescy/0f01eb3d3c02161cc6fedfdabdee56d5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
import io | |
import sys | |
import traceback | |
import typing | |
from collections import defaultdict | |
from functools import wraps | |
from typing import (Any, Callable, Dict, List, Literal, Optional, TypeVar, | |
Union, cast) | |
if sys.version_info < (3, 10):

    class UnionType:
        """Placeholder for ``types.UnionType``, which only exists on Python 3.10+.

        Defined so that membership checks such as
        ``origin in (Union, UnionType)`` work on older interpreters.  It is a
        fresh, otherwise-unused class, so no annotation can have it as origin.
        """

else:
    from types import UnionType

# Type variables used by the generic signatures in this module.  The bounds
# are given as strings (forward references) because ``Registrable`` and
# ``FromConfig`` are defined further down.
_T = TypeVar("_T")
_T_Registrable = TypeVar("_T_Registrable", bound="Registrable")
_T_FromConfig = TypeVar("_T_FromConfig", bound="FromConfig")
class ConfigurationError(Exception):
    """Raised when a configuration value cannot be turned into the requested type."""
def _indent(s: str, level: int = 1) -> str:
    """Prefix every line of ``s`` with ``level`` tab characters.

    Lines are delimited by ``"\\n"``; the (possibly empty) text after a
    trailing newline is prefixed as well.
    """
    prefix = "\t" * level
    return "\n".join(prefix + line for line in s.split("\n"))
def _remove_optional(annotation: Any) -> Any:
    """Unwrap ``Optional[X]`` to ``X``; return any other annotation unchanged.

    Only a ``typing.Union`` with exactly two members whose second member is
    ``NoneType`` is unwrapped (the shape ``Optional[X]`` produces).
    ``Union[None, X]`` and unions with more members are returned as-is.
    Whether ``X | None`` is unwrapped depends on ``typing.get_origin``
    reporting it as ``typing.Union``.
    """
    if typing.get_origin(annotation) != Union:
        return annotation
    members = typing.get_args(annotation)
    if len(members) != 2 or members[1] is not type(None):
        return annotation
    return cast(type, members[0])
def _construct_with_kwargs(
    cls: type[_T], kwargs: dict[str, Any], key: tuple[str, ...]
) -> _T:
    """Instantiate ``cls`` by building each ``__init__`` argument from ``kwargs``.

    Each value is converted with ``_construct_from_config`` according to the
    annotation of the matching ``__init__`` parameter, so nested configuration
    dicts are turned into objects recursively.  Unannotated parameters accept
    any value.  String annotations (e.g. under
    ``from __future__ import annotations``) are resolved with
    ``typing.get_type_hints`` when possible.

    Args:
        cls: Class to instantiate.  The first ``__init__`` parameter is assumed
            to be ``self`` and is skipped.
        kwargs: Mapping from ``__init__`` parameter names to raw config values.
            For ``Registrable`` subclasses a ``"type"`` key is accepted and is
            not forwarded to ``__init__``.  If ``__init__`` takes ``**kwargs``,
            keys that match no named parameter are forwarded through it.
        key: Path of this object inside the configuration, for error messages.

    Returns:
        A new instance of ``cls``.

    Raises:
        ConfigurationError: On unknown keys (when ``__init__`` has no
            ``**kwargs``), missing required parameters, or values that cannot
            be converted.
    """
    signature = inspect.signature(cls.__init__)
    # Assume the first parameter is `self`.
    params = list(signature.parameters.values())[1:]

    # Parameters that can be filled by name.  *args can never be filled from
    # a mapping; **kwargs is handled separately below.
    named = {
        p.name: p
        for p in params
        if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    }
    var_keyword = next((p for p in params if p.kind is p.VAR_KEYWORD), None)

    extra_keys = set(kwargs) - set(named)
    if issubclass(cls, Registrable):
        # `type` selects a registered subclass; it is not an __init__ argument.
        extra_keys.discard("type")
    if extra_keys and var_keyword is None:
        raise ConfigurationError(
            f"[{'.'.join(key)}] Invalid keys in configuration: {extra_keys}"
        )

    # Only pay for (and risk) evaluating annotations when some are strings.
    hints: dict[str, Any] = {}
    if any(isinstance(p.annotation, str) for p in params):
        try:
            hints = typing.get_type_hints(cls.__init__)
        except (NameError, TypeError, AttributeError, SyntaxError):
            # Unresolvable forward references: keep the raw string annotations.
            hints = {}

    def resolve(param: inspect.Parameter) -> Any:
        # A missing annotation is `inspect.Parameter.empty` (itself a class);
        # treat it as "anything goes" rather than as a type to check against.
        if param.annotation is param.empty:
            return Any
        if isinstance(param.annotation, str):
            return hints.get(param.name, param.annotation)
        return param.annotation

    constructed_kwargs: dict[str, Any] = {}
    for name, param in named.items():
        subkey = key + (name,)
        if name not in kwargs:
            if param.default is param.empty:
                raise ConfigurationError(
                    f"Missing required parameter: {'.'.join(subkey)}",
                )
            continue
        # An explicitly given value (even None) is passed on for conversion,
        # so e.g. a required `Optional[...]` parameter may be set to None.
        constructed_kwargs[name] = _construct_from_config(
            kwargs[name],
            resolve(param),
            subkey,
        )

    if var_keyword is not None:
        extra_annotation = resolve(var_keyword)
        for name, value in kwargs.items():
            if name in extra_keys:
                constructed_kwargs[name] = _construct_from_config(
                    value,
                    extra_annotation,
                    key + (name,),
                )

    return cls(**constructed_kwargs)
def _construct_union(
    config: Any,
    annotation: Any,
    members: tuple[Any, ...],
    key: tuple[str, ...],
) -> Any:
    """Construct ``config`` as the first member of a union that accepts it.

    Members are tried in declaration order.  If every member fails, a single
    ``ConfigurationError`` is raised that embeds each member's error message
    and traceback.
    """
    param_name = ".".join(key)
    if not members:
        return config
    failures: list[tuple[Any, Exception, str]] = []
    for member in members:
        try:
            return _construct_from_config(config, member, key)
        except (
            ValueError,
            TypeError,
            ConfigurationError,
            AttributeError,
        ) as e:
            failures.append((member, e, traceback.format_exc()))
    error_messages = [
        f"[{param_name}] Trying to construct {annotation} "
        f"with type {member}:\n{error}\n{tb}"
        for member, error, tb in failures
    ]
    raise ConfigurationError(
        "\n\n"
        + "\n".join(_indent(msg) for msg in error_messages)
        + f"\n[{param_name}] Failed to construct object with "
        + f"type {annotation}"
    )


def _construct_from_config(
    config: Any,
    annotation: Any,
    key: tuple[str, ...],
) -> Any:
    """Build a value matching ``annotation`` from a raw configuration value.

    Supported annotations:

    * ``None``, ``Any`` or a missing annotation: ``config`` is returned as-is.
    * ``Optional[X]`` / any union containing ``None``: an explicit ``None`` is
      accepted; other values are built as the remaining member(s).
    * ``list``/``List[X]``: ``config`` must be a list or tuple; each item is
      built as ``X``.
    * ``dict``/``Dict[K, V]``: ``config`` must be a dict; keys are kept and
      each value is built as ``V``.
    * ``Literal[...]``: ``config`` must be one of the literal values.
    * ``Union[...]`` / ``X | Y``: members are tried in order.
    * Scalars (``bool``, ``int``, ``float``, ``str``, ``None``) are checked,
      not coerced: ``type(config)`` must be exactly the annotated class,
      except that an ``int`` is accepted where ``float`` is expected.
    * Any other class: ``config`` must be a dict of ``__init__`` arguments.
      For ``Registrable`` subclasses the dict's ``"type"`` key first selects
      the registered implementation.

    Args:
        config: Raw value (plain dicts/lists/scalars).
        annotation: Target type annotation.
        key: Path of this value inside the configuration, used in error
            messages.

    Returns:
        The constructed value.

    Raises:
        ConfigurationError: If ``config`` cannot be converted to
            ``annotation``.
    """
    param_name = ".".join(key)

    # An explicit None is valid for any union that lists None.  This must be
    # checked before `_remove_optional` strips the None member away.
    if config is None and typing.get_origin(annotation) in (Union, UnionType):
        if type(None) in typing.get_args(annotation):
            return None

    annotation = _remove_optional(annotation)
    if (
        annotation is None
        or annotation is Any
        or annotation is inspect.Parameter.empty
    ):
        return config

    annotation_origin = typing.get_origin(annotation)
    annotation_args = typing.get_args(annotation)

    # Bare `list` has no origin, so it is matched explicitly.
    if annotation_origin in (list, List) or annotation is list:
        if not isinstance(config, (list, tuple)):
            raise ConfigurationError(
                f"[{param_name}] Expected a list, got {config} instead."
            )
        item_annotation = annotation_args[0] if annotation_args else None
        return [
            _construct_from_config(item, item_annotation, key + (str(i),))
            for i, item in enumerate(config)
        ]

    # Bare `dict` has no origin, so it is matched explicitly.
    if annotation_origin in (dict, Dict) or annotation is dict:
        if not isinstance(config, dict):
            raise ConfigurationError(
                f"[{param_name}] Expected a dictionary, got {config} instead."
            )
        value_annotation = annotation_args[1] if annotation_args else None
        return {
            k: _construct_from_config(v, value_annotation, key + (str(k),))
            for k, v in config.items()
        }

    if annotation_origin == Literal:
        if config not in annotation_args:
            raise ConfigurationError(
                f"[{param_name}] {config} is not a valid literal value."
            )
        return config

    if annotation_origin in (Union, UnionType):
        return _construct_union(config, annotation, annotation_args, key)

    cls: Optional[type] = None
    if isinstance(annotation, type):
        cls = annotation
    elif isinstance(annotation_origin, type):
        cls = annotation_origin
    if cls is None:
        # Not something we know how to check (e.g. an unresolved string
        # annotation): pass the value through unchanged.
        return config

    if isinstance(config, (bool, int, float, str, type(None))):
        # Scalars are checked, never coerced.  Following PEP 484's numeric
        # tower, an int is also acceptable where a float is expected.
        if type(config) is cls or (cls is float and type(config) is int):
            return config
        raise ConfigurationError(
            f"[{param_name}] {config} is not an instance of {cls}"
        )

    if not isinstance(config, dict):
        raise ConfigurationError(
            f"[{param_name}] Expected a dictionary, got {config} instead."
        )

    if issubclass(cls, Registrable):
        name = config.get("type")
        if name is None:
            raise ConfigurationError(
                f"[{param_name}] Missing 'type' key in configuration."
            )
        try:
            cls = cls.by_name(name)
        except KeyError as err:
            # Surface as ConfigurationError so union fallback and callers
            # see a single, consistent error type.
            raise ConfigurationError(
                f"[{param_name}] Unknown type {name!r} for {cls}."
            ) from err
    return _construct_with_kwargs(cls, config, key)
class Registrable:
    """Base class for families of implementations selectable by name.

    Implementations are registered against the class whose ``register`` was
    called and looked up with ``by_name`` on that same class::

        @Base.register("foo")
        class Foo(Base): ...

        Base.by_name("foo")  # -> Foo
    """

    # One dict shared by the whole hierarchy, keyed by the base class on which
    # `register` was called; each value maps a name to the registered class.
    _registry: dict[type["Registrable"], dict[str, Any]] = defaultdict(dict)

    @classmethod
    def register(
        cls: type[_T_Registrable],
        name: str,
    ) -> Callable[[type[_T_Registrable]], type[_T_Registrable]]:
        """Return a class decorator that registers its target under ``name``.

        The decorated class is returned unchanged.  Registering another class
        under a name already in use for ``cls`` replaces the earlier entry.
        """

        # No functools.wraps here: this decorator does not wrap `cls`, and
        # copying its metadata onto the function would be misleading.
        def decorator(
            subclass: type[_T_Registrable],
        ) -> type[_T_Registrable]:
            cls._registry[cls][name] = subclass
            return subclass

        return decorator

    @classmethod
    def by_name(
        cls: type[_T_Registrable],
        name: str,
    ) -> type[_T_Registrable]:
        """Return the class registered under ``name`` for exactly this class.

        Raises:
            KeyError: If no class is registered under ``name`` for ``cls``.
        """
        # `.get` avoids creating empty entries in the defaultdict on lookup.
        registered = cls._registry.get(cls, {})
        try:
            return registered[name]
        except KeyError:
            raise KeyError(
                f"{name!r} is not registered for {cls.__name__}; "
                f"available names: {list(registered)}"
            ) from None
class FromConfig:
    """Mixin that adds a ``from_config`` alternate constructor.

    The configuration dict maps ``__init__`` parameter names to raw values,
    which are converted according to the parameters' annotations.
    """

    @classmethod
    def from_config(
        cls: type[_T_FromConfig],
        config: dict[str, Any],
    ) -> _T_FromConfig:
        """Instantiate ``cls`` from ``config``.

        ``cls.__name__`` is the root of the key path shown in error messages.
        This always constructs ``cls`` itself; a ``"type"`` key in ``config``
        is not used to select a registered subclass here.
        """
        root_key: tuple[str, ...] = (cls.__name__,)
        return _construct_with_kwargs(cls, config, root_key)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment