Skip to content

Instantly share code, notes, and snippets.

@DrJackilD
Created October 25, 2020 17:58
Show Gist options
  • Save DrJackilD/9bd21869abc65469a0a946f9ad6e5a45 to your computer and use it in GitHub Desktop.
Save DrJackilD/9bd21869abc65469a0a946f9ad6e5a45 to your computer and use it in GitHub Desktop.
"""
A custom Tortoise ORM function that allows you to use the ARRAY_AGG aggregate function in queries.
"""
from typing import TYPE_CHECKING, Any, Optional, Type, cast, Iterable, Tuple
from pypika import Table
from pypika.functions import Function as PypikaFunction
from tortoise.exceptions import ConfigurationError
from tortoise.fields.relational import BackwardFKRelation, ForeignKeyFieldInstance, RelationalField
from tortoise.functions import Function
if TYPE_CHECKING: # pragma: nocoverage
from tortoise.models import Model
from tortoise.fields.base import Field
class ArrayAggTerm(PypikaFunction):
    """
    pypika term for ``ARRAY_AGG`` over a group of columns.

    The first positional argument is an iterable of columns/terms; they are
    passed to pypika as ONE tuple argument. Any further positional arguments
    are forwarded unchanged as extra function arguments.
    """

    def __init__(self, *fields, alias=None):
        # Split explicitly instead of ``tuple(*fields)``: ``tuple()`` accepts at
        # most one argument, so unpacking every positional into it raised
        # TypeError as soon as any extra argument was supplied.
        columns, extra_args = (fields[0], fields[1:]) if fields else ((), ())
        super().__init__("ARRAY_AGG", tuple(columns), *extra_args, alias=alias)
class ArrayAgg(Function):
    """
    ``ARRAY_AGG`` aggregate over one or more (possibly related) fields.

    Each entry of ``fields`` is a Tortoise field path (``"name"``,
    ``"author__name"``, ...). Every path is resolved to a pypika term, and all
    terms are handed together to :attr:`database_func` (``ArrayAggTerm``) as a
    single tuple argument.

    :param fields: Field names to aggregate. A single ``str`` is treated as one
        field name rather than being iterated character by character.
    :param default_values: Extra parameters forwarded to the database function.

    .. attribute:: database_func
        :annotation: pypika.terms.Function

        The pypika function this represents.

    .. attribute:: populate_field_object
        :annotation: bool

        Inherited from ``Function``. When enabled, try to preserve the field
        type by applying the field's dialect-specific ``function_cast``.
    """

    database_func = ArrayAggTerm

    def __init__(self, fields: Iterable[str], *default_values: Any) -> None:
        # Materialise once: a generator would be exhausted after the first
        # resolve(), and a bare string would otherwise be split into characters.
        self.fields: Tuple[str, ...] = (fields,) if isinstance(fields, str) else tuple(fields)
        self.field_object: "Optional[Field]" = None
        self.default_values = default_values

    def _get_function_field(self, fields: list, *default_values):
        """Build the ``ARRAY_AGG`` pypika term from the resolved column terms."""
        return self.database_func(fields, *default_values)

    def _resolve_field(self, model: "Type[Model]", table: Table, field: str) -> Tuple[list, Any]:
        """
        Resolve one ``__``-separated field path against ``model``.

        :param model: Model the path starts from.
        :param table: pypika table for ``model`` (may be an alias).
        :param field: Field path such as ``"name"`` or ``"author__name"``.
        :return: ``(joins, term)``. ``joins`` is a list of
            ``(table, field_name, related_field)`` tuples needed to reach the
            column, and ``term`` is the pypika term for that column.
        :raises ConfigurationError: If an intermediate path segment is not a
            relational (fetch) field of its model.
        """
        joins = []
        parts = field.split("__")

        for part in parts[:-1]:
            if part not in model._meta.fetch_fields:
                raise ConfigurationError(f"{field} not resolvable")
            related_field = cast(RelationalField, model._meta.fields_map[part])
            joins.append((table, part, related_field))
            model = related_field.related_model
            related_table: Table = related_field.related_model._meta.basetable
            if isinstance(related_field, ForeignKeyFieldInstance):
                # Only FK's can be to same table, so we only auto-alias FK join tables
                related_table = related_table.as_(f"{table.get_table_name()}__{part}")
            table = related_table

        last_field = parts[-1]
        if last_field in model._meta.fetch_fields:
            # Relational last segment: join it and aggregate the related model's PK.
            related_field = cast(RelationalField, model._meta.fields_map[last_field])
            related_field_meta = related_field.related_model._meta
            joins.append((table, last_field, related_field))
            related_table = related_field_meta.basetable
            if isinstance(related_field, BackwardFKRelation) and table == related_table:
                # Self-referential reverse FK: alias so it does not clash with `table`.
                related_table = related_table.as_(f"{table.get_table_name()}__{last_field}")
            term = related_table[related_field_meta.db_pk_column]
        else:
            field_object = model._meta.fields_map[last_field]
            term = table[field_object.source_field or last_field]
            if self.populate_field_object:
                # NOTE(review): with several fields, only the last resolved
                # non-relational field's object is kept here.
                self.field_object = field_object
                if self.field_object:  # pragma: nobranch
                    func = self.field_object.get_for_dialect(
                        model._meta.db.capabilities.dialect, "function_cast"
                    )
                    if func:
                        term = func(self.field_object, term)
        return joins, term

    def _resolve_fields_for_model(self, model: "Type[Model]", table: Table, fields: Iterable[str]) -> tuple:
        """Resolve every field path; return ``(all_joins, terms)`` with terms in input order."""
        joins: list = []
        terms: list = []
        for field in fields:
            field_joins, term = self._resolve_field(model, table, field)
            joins.extend(field_joins)
            terms.append(term)
        return joins, terms

    def resolve(self, model: "Type[Model]", table: Table) -> dict:
        """
        Used to resolve the Function statement for SQL generation.

        :param model: Model the function is applied on to.
        :param table: ``pypika.Table`` to keep track of the virtual SQL table
            (to allow self referential joins)
        :return: Dict with keys ``"joins"`` and ``"field"``
        """
        joins, terms = self._resolve_fields_for_model(model, table, self.fields)
        return {"joins": joins, "field": self._get_function_field(terms, *self.default_values)}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment