Created October 25, 2020 17:58
-
-
Save DrJackilD/9bd21869abc65469a0a946f9ad6e5a45 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Just a custom function, which allow you to use ARRAY_AGG aggregation function in TortoiseORM | |
""" | |
from typing import TYPE_CHECKING, Any, Optional, Type, cast, Iterable, Tuple | |
from pypika import Table | |
from pypika.functions import Function as PypikaFunction | |
from tortoise.exceptions import ConfigurationError | |
from tortoise.fields.relational import BackwardFKRelation, ForeignKeyFieldInstance, RelationalField | |
from tortoise.functions import Function | |
if TYPE_CHECKING: # pragma: nocoverage | |
from tortoise.models import Model | |
from tortoise.fields.base import Field | |
class ArrayAggTerm(PypikaFunction):
    """pypika ``ARRAY_AGG`` function term whose single argument is a tuple of columns."""

    def __init__(self, *fields, alias=None):
        # Callers pass one iterable of column terms as the sole positional
        # argument; star-unpacking it into ``tuple()`` copies those terms.
        # Any additional positional argument makes ``tuple()`` raise TypeError.
        aggregated_columns = tuple(*fields)
        super().__init__("ARRAY_AGG", aggregated_columns, alias=alias)
class ArrayAgg(Function):
    """
    Aggregate one or more fields into an array with ``ARRAY_AGG``.

    Every name in ``fields`` is resolved against the model (following ``__``
    relation paths and collecting the joins they need). All resolved columns
    are then handed together to ``database_func``, so the aggregate is taken
    over a tuple of those columns.

    :param fields: Field names to aggregate; ``__`` separates relation hops
        (e.g. ``"author__name"``). A single ``str`` is treated as one field
        name rather than as an iterable of characters.
    :param default_values: Extra positional parameters forwarded to
        ``database_func``. ``ArrayAggTerm`` only accepts the column tuple, so
        passing any value here fails with ``TypeError`` when the query is
        resolved.

    .. attribute:: database_func
        :annotation: pypika.terms.Function

        The pypika function this represents.

    .. attribute:: populate_field_object
        :annotation: bool = False

        Enable populate_field_object where we want to try and preserve the field type.
    """

    database_func = ArrayAggTerm

    def __init__(self, fields: Iterable[str], *default_values: Any) -> None:
        # A bare string is itself an Iterable[str]; without this guard every
        # character would be resolved as a separate field name.
        if isinstance(fields, str):
            fields = (fields,)
        # Materialise once so a generator/iterator argument is not exhausted
        # by the first ``resolve`` call (the query may be resolved again).
        self.fields: Tuple[str, ...] = tuple(fields)
        self.field_object: "Optional[Field]" = None
        self.default_values = default_values

    def _get_function_field(self, fields: Iterable[Any], *default_values):
        """Build the ``database_func`` term from the resolved pypika column terms."""
        return self.database_func(fields, *default_values)

    def _resolve_field(self, model: "Type[Model]", table: Table, field: str) -> Tuple[list, Any]:
        """
        Resolve a single, possibly ``__``-separated, field name to a pypika term.

        :param model: Model the lookup starts from.
        :param table: Table of ``model`` in the query (may be an alias).
        :param field: Field name, e.g. ``"name"`` or ``"author__name"``.
        :return: ``(joins, term)``. ``joins`` is a list of
            ``(table, field_name, relational_field)`` triples needed to reach
            the final table; ``term`` is the column to aggregate.
        :raises ConfigurationError: if an intermediate path segment is not a
            relational (fetchable) field of the current model.
        """
        joins = []
        path = field.split("__")

        # Every segment except the last must be a relation we can join through.
        for hop in path[:-1]:
            if hop not in model._meta.fetch_fields:
                raise ConfigurationError(f"{field} not resolvable")
            related_field = cast(RelationalField, model._meta.fields_map[hop])
            joins.append((table, hop, related_field))
            model = related_field.related_model
            related_table: Table = related_field.related_model._meta.basetable
            if isinstance(related_field, ForeignKeyFieldInstance):
                # Only FK's can be to same table, so we only auto-alias FK join tables
                related_table = related_table.as_(f"{table.get_table_name()}__{hop}")
            table = related_table

        last_field = path[-1]
        if last_field in model._meta.fetch_fields:
            # The path ends on a relation: join it and aggregate the related model's PK.
            related_field = cast(RelationalField, model._meta.fields_map[last_field])
            related_field_meta = related_field.related_model._meta
            joins.append((table, last_field, related_field))
            related_table = related_field_meta.basetable
            if isinstance(related_field, BackwardFKRelation) and table == related_table:
                # Self-referencing reverse FK: alias so the join does not collide with ``table``.
                related_table = related_table.as_(f"{table.get_table_name()}__{last_field}")
            return joins, related_table[related_field_meta.db_pk_column]

        # Plain data field: prefer its explicit DB column (``source_field``) when set.
        field_object = model._meta.fields_map[last_field]
        term = table[field_object.source_field or last_field]
        if self.populate_field_object:
            # Remember the field object (to preserve the field type). With several
            # fields, each call overwrites it, so the last field resolved wins.
            self.field_object = model._meta.fields_map.get(last_field, None)
            if self.field_object:  # pragma: nobranch
                func = self.field_object.get_for_dialect(
                    model._meta.db.capabilities.dialect, "function_cast"
                )
                if func:
                    term = func(self.field_object, term)
        return joins, term

    def _resolve_fields_for_model(
        self, model: "Type[Model]", table: Table, fields: Iterable[str]
    ) -> Tuple[list, list]:
        """
        Resolve every name in ``fields`` with ``_resolve_field``.

        :return: ``(joins, terms)``: the concatenated joins of all fields and
            the resolved column terms, in the order of ``fields``.
        """
        joins: list = []
        terms: list = []
        for field in fields:
            field_joins, term = self._resolve_field(model, table, field)
            joins.extend(field_joins)
            terms.append(term)
        return joins, terms

    def resolve(self, model: "Type[Model]", table: Table) -> dict:
        """
        Used to resolve the Function statement for SQL generation.

        :param model: Model the function is applied on to.
        :param table: ``pypika.Table`` to keep track of the virtual SQL table
            (to allow self referential joins)
        :return: Dict with keys ``"joins"`` (the join triples required by the
            aggregated fields) and ``"field"`` (the ``ARRAY_AGG`` term).
        """
        joins, terms = self._resolve_fields_for_model(model, table, self.fields)
        return {"joins": joins, "field": self._get_function_field(terms, *self.default_values)}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.