-
-
Save aodag/b9208764b181938f491bf11054f47ea6 to your computer and use it in GitHub Desktop.
dataclasses と sqlalchemy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime | |
from typing import Type, TypeVar, Generic | |
from sqlalchemy import ( | |
create_engine, | |
MetaData, | |
Integer, | |
Unicode, | |
DateTime, | |
Column, | |
Table, | |
) | |
from sqlalchemy.orm import mapper, sessionmaker, scoped_session | |
import dataclasses | |
# Type variable standing for the class that Query and DB.query are parameterized over.
T = TypeVar("T")
# Python annotation type -> SQLAlchemy column type; column() looks annotations up here,
# so only int, str and datetime fields are supported.
col_types = {int: Integer, str: Unicode, datetime: DateTime}
def table_name(cls: Type):
    """Return the table name used for *cls*: its class name in lower case."""
    class_name = cls.__name__
    return class_name.lower()
def table(meta: MetaData, cls: Type):
    """Build the Table for dataclass *cls*.

    The table is named by table_name(cls), constructed against *meta*, and
    given one column per dataclass field as produced by cols(cls).
    """
    name = table_name(cls)
    columns = cols(cls)
    return Table(name, meta, *columns)
def cols(cls):
    """Return a list with one column per dataclass field of *cls*.

    Each field's name and annotated type are handed to column(); any mapping
    stored under the "sqlalchemy" key of the field metadata is forwarded as
    keyword arguments.
    """
    built = []
    for field in dataclasses.fields(cls):
        options = field.metadata.get("sqlalchemy", {})
        built.append(column(field.name, field.type, **options))
    return built
def column(name: str, t, **kwargs):
    """Create a Column called *name* whose type is looked up from *t*.

    *t* must be a key of col_types; any other value raises KeyError.
    Extra keyword arguments are passed straight to Column.
    """
    return Column(name, col_types[t], **kwargs)
def register(meta, cls):
    """Build the table for *cls* on *meta*, print its columns, and map the class.

    Returns whatever mapper(cls, table) returns.
    """
    mapped_table = table(meta, cls)
    # Print each generated column before mapping.
    for col in mapped_table.columns:
        print(col)
    return mapper(cls, mapped_table)
def f(primary_key=False, **kw):
    """Declare a dataclass field that carries SQLAlchemy column options.

    Args:
        primary_key: Stored as the "primary_key" entry of the field's
            "sqlalchemy" metadata mapping.
        **kw: Forwarded to dataclasses.field (e.g. default, default_factory).
            An optional "metadata" mapping is merged in rather than replaced;
            an existing "sqlalchemy" mapping inside it keeps its other
            entries.

    Returns:
        The object returned by dataclasses.field.
    """
    # Pop instead of get: leaving "metadata" in kw would hand it to
    # dataclasses.field twice and raise TypeError.
    supplied = kw.pop("metadata", None)
    # Work on copies so the caller's mappings are never mutated.
    metadata = dict(supplied) if supplied else {}
    sa_options = dict(metadata.get("sqlalchemy", {}))
    sa_options["primary_key"] = primary_key
    metadata["sqlalchemy"] = sa_options
    return dataclasses.field(metadata=metadata, **kw)
class Query(Generic[T]):
    """Typed lookup helper bound to one mapped class and one session."""

    def __init__(self, t: Type[T], session):
        """Remember the mapped class *t* and the session used for queries."""
        self.t = t
        self.s = session

    def get(self, id: int) -> T:
        """Return the instance of the bound class whose primary key is *id*.

        Args:
            id: Primary-key value to look up.

        Returns:
            The matching instance of the class given to the constructor.

        Raises:
            sqlalchemy.orm.exc.NoResultFound: If no row has that key.
        """
        # Imported here so the fix needs no change to the file-level imports.
        from sqlalchemy.orm.exc import NoResultFound

        # Query the class this helper was built for and filter by the given
        # key; the previous code always fetched the sole Person row.
        found = self.s.query(self.t).get(id)
        if found is None:
            raise NoResultFound(
                f"No {self.t.__name__} row with primary key {id!r}"
            )
        return found
class DB:
    """Bundle an engine, a MetaData collection and a scoped session."""

    def __init__(self, engine):
        """Store *engine* and create the metadata and session bound to it."""
        self.engine = engine
        self.metadata = MetaData()
        self.session = scoped_session(sessionmaker(bind=engine))

    def create_all(self):
        """Create every table registered on this instance's metadata."""
        # Use the engine this DB was built with, not the module-level global.
        self.metadata.create_all(bind=self.engine)

    def register(self, cls):
        """Map dataclass *cls* onto a table in this DB's metadata.

        Returns the result of the module-level register() helper.
        """
        return register(self.metadata, cls)

    def query(self, cls: Type[T]) -> Query[T]:
        """Return a Query helper for *cls* that uses this DB's session."""
        return Query(cls, self.session)
# Module-level demo setup. The URL names no database file — presumably an
# in-memory SQLite database; confirm against the SQLAlchemy URL rules.
engine = create_engine("sqlite:///")
# Turn on the engine's echo flag.
engine.echo = True
db = DB(engine)
@dataclasses.dataclass
class Person:
    """Example record whose fields are all declared through f()."""

    id: int = f(primary_key=True)  # the only field flagged as primary key
    first: str = f()
    last: str = f()
    age: int = f()
# Map Person onto a table in db's metadata, then create the tables.
db.register(Person)
db.create_all()
# Add one Person to the session and flush it.
p = Person(1, "a", "b", 30)
db.session.add(p)
db.session.flush()
# Fetch through the query helper and compare with the object added above.
pp = db.query(Person).get(1)
pp.first = "a"
assert pp == p
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment