Last active
April 27, 2024 16:56
-
-
Save daaniam/69b13544eb35207afbbbfb67e6e58e1d to your computer and use it in GitHub Desktop.
SQLAlchemy - Polymorphic association
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from sqlalchemy import ForeignKey, UniqueConstraint, select | |
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine | |
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, selectin_polymorphic | |
# Async engine for the local SQLite demo database (aiosqlite driver), plus a
# factory that produces AsyncSession objects bound to that engine.
DATABASE_URL = "sqlite+aiosqlite:///database2.db"
aengine = create_async_engine(DATABASE_URL)
asession = async_sessionmaker(bind=aengine)
class Base(DeclarativeBase):
    """Declarative base shared by every ORM model in this module."""
class Tenant(Base):
    """Top-level owner of accounts; deleting a tenant also deletes its accounts."""

    __tablename__ = "tenant"

    id: Mapped[int] = mapped_column(primary_key=True)

    # Tenant specific
    name: Mapped[str] = mapped_column(unique=True)

    # Relationships
    # Account.tenant_id is NOT NULL and declared with ondelete="CASCADE".
    # Without a "delete" cascade, deleting a Tenant through the ORM would try
    # to set tenant_id to NULL on each account, violating that constraint.
    # Adding "delete" to the default "save-update, merge" cascade makes the
    # ORM delete the accounts instead, matching the foreign key's intent.
    # NOTE: SQLite enforces the database-level ON DELETE CASCADE only with
    # PRAGMA foreign_keys=ON, which this module does not visibly set.
    accounts: Mapped[list["Account"]] = relationship(
        back_populates="tenant",
        cascade="save-update, merge, delete",
    )
class Account(Base):
    """A tenant-scoped login that points at one polymorphic Profile row."""

    __tablename__ = "accounts"
    # Usernames only have to be unique within a single tenant.
    __table_args__ = (UniqueConstraint("tenant_id", "username"),)

    id: Mapped[int] = mapped_column(primary_key=True)

    # Account specific
    username: Mapped[str] = mapped_column()

    # ForeignKeys
    profile_id: Mapped[int] = mapped_column(ForeignKey("profile_associations.id"))
    tenant_id: Mapped[int] = mapped_column(ForeignKey("tenant.id", ondelete="CASCADE"))

    # Relationships
    # Annotated with Mapped[...] (like Tenant.accounts) so type checkers see
    # account.profile / account.tenant as model instances rather than Any.
    # lazy="selectin" eager-loads both when an Account is queried; presumably
    # chosen because AsyncSession cannot perform implicit lazy loads.
    profile: Mapped["Profile"] = relationship("Profile", lazy="selectin")
    tenant: Mapped["Tenant"] = relationship("Tenant", back_populates="accounts", lazy="selectin")
class Profile(Base):
    """Parent row for every profile kind (joined-table inheritance).

    ``profile_type`` is the discriminator. Each subclass declares its own
    ``polymorphic_identity`` and stores its extra columns in a separate table
    whose primary key references ``profile_associations.id``.
    """

    __tablename__ = "profile_associations"
    __mapper_args__ = dict(polymorphic_on="profile_type")

    id: Mapped[int] = mapped_column(primary_key=True)

    # Discriminator: a bare Mapped[str] annotation maps a NOT NULL column,
    # same as mapped_column() with no arguments.
    profile_type: Mapped[str]
class Staff(Profile):
    """Staff profile; its own columns live in the ``staff`` table."""

    __tablename__ = "staff"
    # polymorphic_load="selectin": subclass columns are fetched with a
    # follow-up SELECT when Profile rows are loaded polymorphically.
    __mapper_args__ = dict(polymorphic_identity="staff", polymorphic_load="selectin")

    # Primary key doubling as the link to the parent profile_associations row.
    id: Mapped[int] = mapped_column(ForeignKey("profile_associations.id"), primary_key=True)

    # Staff specific
    staff_name: Mapped[str]
class Teacher(Profile):
    """Teacher profile; its own columns live in the ``teachers`` table."""

    __tablename__ = "teachers"
    # polymorphic_load="selectin": subclass columns are fetched with a
    # follow-up SELECT when Profile rows are loaded polymorphically.
    __mapper_args__ = dict(polymorphic_identity="teacher", polymorphic_load="selectin")

    # Primary key doubling as the link to the parent profile_associations row.
    id: Mapped[int] = mapped_column(ForeignKey("profile_associations.id"), primary_key=True)

    # Teacher specific
    teacher_name: Mapped[str]
class Student(Profile):
    """Student profile; its own columns live in the ``students`` table."""

    __tablename__ = "students"
    # polymorphic_load="selectin": subclass columns are fetched with a
    # follow-up SELECT when Profile rows are loaded polymorphically.
    __mapper_args__ = dict(polymorphic_identity="student", polymorphic_load="selectin")

    # Primary key doubling as the link to the parent profile_associations row.
    id: Mapped[int] = mapped_column(ForeignKey("profile_associations.id"), primary_key=True)

    # Student specific
    student_name: Mapped[str]
async def create_tables():
    """Create every table registered on ``Base.metadata`` in one transaction."""
    metadata = Base.metadata
    async with aengine.begin() as connection:
        # create_all is a synchronous API; run_sync calls it with a sync-facing connection.
        await connection.run_sync(metadata.create_all)
async def drop_tables():
    """Drop every table registered on ``Base.metadata`` in one transaction."""
    metadata = Base.metadata
    async with aengine.begin() as connection:
        # drop_all is a synchronous API; run_sync calls it with a sync-facing connection.
        await connection.run_sync(metadata.drop_all)
async def main():
    """Rebuild the schema, store one tenant with a teacher and a staff account,
    then reload each account and print its polymorphic profile.

    The session is closed and the engine disposed even if a step fails.
    """
    try:
        await drop_tables()
        await create_tables()

        # The context manager closes the session on success and on error alike.
        async with asession() as s:
            new_tenant = Tenant(name="tenant1")

            # New profiles
            new_teacher = Teacher(teacher_name="teachername1")
            new_staff = Staff(staff_name="staff1")

            # New accounts
            new_account1 = Account(username="username1")
            new_account2 = Account(username="username2")

            # Add Accounts to Tenant
            new_tenant.accounts.append(new_account1)
            new_tenant.accounts.append(new_account2)

            # Add Profiles to Accounts
            new_account1.profile = new_teacher
            new_account2.profile = new_staff

            # List every new object explicitly. The save-update cascade would also
            # reach them through new_tenant, but being explicit keeps intent clear.
            s.add_all([new_tenant, new_account1, new_account2, new_teacher, new_staff])
            await s.commit()

            print("Teacher account:")
            teacher_account = await s.scalar(select(Account).where(Account.username == "username1"))
            print(teacher_account.__dict__)
            print(teacher_account.profile.teacher_name)
            print(teacher_account.profile.profile_type)

            print("Staff account:")
            staff_account = await s.scalar(select(Account).where(Account.username == "username2"))
            print(staff_account.__dict__)
            print(staff_account.profile.staff_name)
            print(staff_account.profile.profile_type)
    finally:
        # Release the AsyncEngine's pooled connections before asyncio.run()
        # closes the event loop.
        await aengine.dispose()
if __name__ == "__main__":
    # Script entry point: run the async demo to completion on a fresh event loop.
    demo = main()
    asyncio.run(demo)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment