Created January 11, 2022 03:47
-
-
Save ThirVondukr/48a7a19871497ff7d7bccb282895d711 to your computer and use it in GitHub Desktop.
Querying relationship count using SQLAlchemy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sqlalchemy import Table, Integer, Column, ForeignKey, String, create_engine, select, func | |
from sqlalchemy.orm import declarative_base, relationship, sessionmaker, column_property, undefer | |
# In-memory SQLite database; echo=True makes the engine log every SQL
# statement it emits.
engine = create_engine("sqlite://", echo=True)

# Declarative base that every mapped class (and the association table's
# metadata) in this example hangs off.
Base = declarative_base()

# Session factory bound to the engine above, using 2.0-style ("future") APIs.
Session = sessionmaker(bind=engine, future=True)
# Association table for the many-to-many Tag <-> Article relationship.
#
# The composite primary key (tag_id, article_id) already provides an index
# whose leading column is tag_id, so lookups and counts by tag_id are served
# by it. A separate index on tag_id would be redundant. article_id is not the
# leading column of that key, so it gets its own index. Loading Article.tags
# filters the association table on article_id.
article_tags = Table(
    "article__tags",
    Base.metadata,
    Column(
        "tag_id",
        Integer,
        ForeignKey("tags.id"),
        primary_key=True,
    ),
    Column(
        "article_id",
        Integer,
        ForeignKey("articles.id"),
        index=True,
        primary_key=True,
    ),
)
class Tag(Base):
    """A uniquely named tag that can be attached to many articles.

    ``articles_count`` is a deferred column property. Its COUNT subquery is
    only added to a SELECT when explicitly requested, e.g. with
    ``select(Tag).options(undefer(Tag.articles_count))``.
    """

    __tablename__ = "tags"

    id = Column(Integer, primary_key=True)
    name = Column(String(255), nullable=False, unique=True)

    articles = relationship("Article", secondary=article_tags, back_populates="tags")

    # Number of association rows (i.e. linked articles) for this tag, computed
    # as a scalar subquery over article__tags.
    #
    # correlate_except(article_tags) makes the correlation explicit: "tags" is
    # taken from the enclosing SELECT, while article__tags always stays in the
    # subquery's own FROM list. This is the pattern the SQLAlchemy docs
    # recommend for column_property subqueries. With auto-correlation alone,
    # article__tags could be correlated away when the enclosing query also
    # selects from / joins through article__tags.
    articles_count = column_property(
        select(func.count(article_tags.c.tag_id))
        .where(article_tags.c.tag_id == id)
        .correlate_except(article_tags)
        .scalar_subquery(),
        deferred=True,
    )
class Article(Base):
    """An article that may carry any number of tags via the article__tags table."""

    __tablename__ = "articles"

    id = Column(Integer, primary_key=True)
    title = Column(String(255), nullable=False)

    # Inverse side of Tag.articles; both sides go through the same
    # association table.
    tags = relationship(
        "Tag",
        back_populates="articles",
        secondary=article_tags,
    )
def main():
    """Create the schema, seed five articles and show three ways to count
    articles per tag.

    The column_property and hand-written subquery variants only print their
    generated SQL. The GROUP BY variant is executed and its rows are printed.
    """
    Base.metadata.create_all(bind=engine)

    tag_py = Tag(name="Python")
    tag_db = Tag(name="Databases")

    # (title, tags) for each article: "Python" is on all five,
    # "Databases" only on the last two.
    seed = (
        ("Article 1", [tag_py]),
        ("Article 2", [tag_py]),
        ("Article 3", [tag_py]),
        ("Article 4", [tag_py, tag_db]),
        ("Article 5", [tag_py, tag_db]),
    )

    with Session.begin() as session:
        # Build and add one article at a time, in seed order.
        for title, tags in seed:
            session.add(Article(title=title, tags=tags))

        # 1) column_property: articles_count is deferred, so undefer() is what
        #    pulls its subquery into the SELECT.
        with_property = select(Tag).options(undefer(Tag.articles_count))
        print(with_property)

        # 2) The same kind of scalar subquery written by hand. The SQL is
        #    essentially the same, but the count is returned as a separate
        #    result column instead of being available as Tag.articles_count.
        count_subq = (
            select(func.count(article_tags.c.tag_id))
            .filter(Tag.id == article_tags.c.tag_id)
            .scalar_subquery()
        )
        print(select(count_subq, Tag))

        # 3) GROUP BY directly on the association table. Tags with no
        #    articles have no rows there, so they do not appear in the result.
        #    Executing it autoflushes the pending objects added above.
        per_tag = (
            select(article_tags.c.tag_id, func.count(article_tags.c.article_id))
            .group_by(article_tags.c.tag_id)
        )
        for tag_id, n_articles in session.execute(per_tag).all():
            print(tag_id, n_articles)
        # Prints
        # 1 5
        # 2 2
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.