Skip to content

Instantly share code, notes, and snippets.

@lmmx
Last active August 16, 2024 11:04
Show Gist options
  • Save lmmx/546b1b6ffa8696976941ad0889762af2 to your computer and use it in GitHub Desktop.
Save lmmx/546b1b6ffa8696976941ad0889762af2 to your computer and use it in GitHub Desktop.
Controlled data access: a Router specifies named Routes (the paths at which the data is accessed). A "before" model validator on a Routee populates each field with the value found at its path, or with a ValueError if the path was missing/partial or if a Route in the Router was extra (not modelled in the Routee). The generic type argument lets the Routee 'mirror' the interface of the Router.
from pathlib import Path
from pydantic import BaseModel, BeforeValidator, TypeAdapter, model_validator
from typing import Annotated, Union, Any, Generic, TypeVar
# One step along a route: a str is used as a mapping key and an int as a
# list index when the route is followed by extract_subpath.
RoutePart = Union[int, str]
# Adapter used by split_route to turn each dotted-string segment into a
# RoutePart via validate_strings.
rp_ta = TypeAdapter(RoutePart)
def split_route(route: str) -> list[RoutePart]:
match route:
case str():
return list(map(rp_ta.validate_strings, route.split(".")))
case list():
return route
case _:
raise ValueError(f"Invalid route: {route}")
# A route is a list of RouteParts. split_route runs before validation, so a
# dotted string such as "a.aa.aaa.0" is accepted as well as a list.
Route = Annotated[list[RoutePart], BeforeValidator(split_route)]
class Router(BaseModel, validate_default=True):
    """Named routes: each field holds the path at which a value is looked up.

    The defaults are written as dotted strings. ``validate_default=True`` is
    set so that the defaults are validated too, which sends them through the
    ``Route`` before-validator (``split_route``).
    """

    # Paths into the nested input data, one per routed field.
    field_1: Route = "a.aa.aaa.0"
    field_2: Route = "a.aa.aaa.1"
    field_3: Route = "b.bb.2"
# Type variable for the Router (sub)class used as the generic argument of Routed.
R = TypeVar("R", bound=Router)
def extract_subpath(path: Route, data: dict) -> Any:
"""Extract a subpath or else an error reporter fallback indicating where it failed."""
for part_idx, part in enumerate(path):
reporter = ValueError(f"Missing {part=} on {path}")
match part:
case str() as key:
data = data.get(key, reporter)
case int() as idx:
data = (
data[idx]
if isinstance(data, list) and 0 <= idx < len(data)
else reporter
)
if data is reporter:
break
return data
class Routed(BaseModel, Generic[R], extra="forbid"):
    """Base model whose fields are filled in by following a Router's routes.

    The Router class is taken from the generic argument (``Routed[SomeRouter]``
    or a parametrised subclass). Extra keys are forbidden, so a route whose
    name is not a field of the model is rejected by pydantic.
    """

    @model_validator(mode="before")
    @classmethod
    def supply(cls, data):
        """Replace the raw input with one value per routed field.

        Args:
            data: The raw nested input, before any field validation.

        Returns:
            A dict mapping every route name on the Router to the value
            extracted from ``data`` (or to the ``ValueError`` returned by
            ``extract_subpath`` when the path could not be followed). Model
            fields that have no route are mapped to a ``ValueError``
            placeholder, which is expected to fail that field's validation.

        Raises:
            ValueError: If the model was not parametrised with exactly one
                Router class.
        """
        args = cls.__pydantic_generic_metadata__["args"]
        if len(args) != 1:
            # Give a readable error instead of an opaque tuple-unpack failure
            # when the model is used without its generic argument.
            raise ValueError(
                f"{cls.__name__} must be parametrised with exactly one Router "
                f"class, got {len(args)} generic argument(s)"
            )
        (route_model,) = args
        router = route_model()
        values = {
            field: extract_subpath(path=route, data=data)
            for field, route in router.model_dump().items()
        }
        # model_fields is read from the classes; instance access is deprecated
        # in recent pydantic releases.
        for unrouted_field in set(cls.model_fields) - set(route_model.model_fields):
            values[unrouted_field] = ValueError(f"No route for {unrouted_field}")
        return values
class Routee(Routed[R]):
    """Model populated through ``Routed``; declares the type of each value.

    The field names here match the route names declared on ``Router``.
    """

    field_1: str
    field_2: str
    field_3: int
# Demo: read the sample document, validate it through the routed model and
# print the extracted fields. JSON is UTF-8, so decode explicitly rather than
# relying on the platform's locale encoding.
json_str = Path("data.json").read_text(encoding="utf-8")
result = Routee[Router].model_validate_json(json_str)
print(result.model_dump())
@lmmx
Copy link
Author

lmmx commented Aug 16, 2024

data.json:

{
  "a": {
    "aa": {
      "aaa": [
        "value_1",
        "value_2"
      ]
    }
  },
  "b": {
    "bb": [
      3,
      4,
      5
    ]
  }
}

python hint_router.py:

{'field_1': 'value_1', 'field_2': 'value_2', 'field_3': 5}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment