|
from sqlalchemy_continuum.plugins import Plugin |
|
from sqlalchemy_continuum import Operation |
|
from sqlalchemy.inspection import inspect |
|
|
|
|
|
class RelatedVersioningPlugin(Plugin):
    """
    Continuum plugin that adds foreign-key-referenced objects to the
    current versioning transaction.

    Before version objects are created, the target of every unprocessed
    operation is inspected. Each object it references through a foreign key
    is added to the unit of work with the same operation type. The same is
    done transitively for the objects those reference.

    A versioned model may list property names under the
    ``exclude_for_related`` key of its ``__versioned__`` dict. Foreign keys
    mapped to those properties are not followed.
    """

    # Upper bound on the depth of one foreign-key traversal; reaching it
    # raises RuntimeError (see _add_referenced_objects).
    max_allowed_level = 10

    def __init__(self):
        # Declarative classes of the model base; filled lazily from the first
        # operation target seen in before_create_version_objects().
        self.class_registry = []

    def get_model(self, table_name):
        """
        Get the declarative model mapped to a given table.

        :param table_name: table name to look up, compared against each
            registered class's ``__tablename__``
        :return: the first registered class whose ``__tablename__`` matches
        :raises AttributeError: if no registered class uses ``table_name``
        """
        for klass in self.class_registry:
            # The registry may hold entries without a ``__tablename__``;
            # those are skipped.
            if hasattr(klass, '__tablename__') and klass.__tablename__ == table_name:
                return klass
        raise AttributeError('Unknown model for table %s' % table_name)

    @staticmethod
    def _load_class_registry(target):
        """Return a list snapshot of the declarative class registry of ``target``."""
        registry = getattr(target, '_decl_class_registry', None)
        if registry is None:
            # NOTE(review): SQLAlchemy >= 1.4 keeps the registry on
            # ``Base.registry._class_registry`` instead; confirm for the
            # SQLAlchemy version in use.
            registry = type(target).registry._class_registry
        # SQLAlchemy's default registry is a WeakValueDictionary, whose
        # values() is a one-shot generator on Python 3. Materialise it so
        # repeated get_model() calls don't exhaust it.
        return list(registry.values())

    @staticmethod
    def _foreign_references(model):
        """Return (foreign_table_name, foreign_column, local_property_name)
        triples for every non-excluded foreign key of ``model``'s table."""
        table = model.__table__
        ignored_properties = model.__versioned__.get('exclude_for_related', [])
        mapper = inspect(model)
        known_tables = table.metadata.tables

        references = set()
        for column in table.columns:
            for fk in column.foreign_keys:
                property_name = mapper.get_property_by_column(column).key
                if property_name in ignored_properties:
                    continue
                for foreign_table_name, foreign_table in known_tables.items():
                    if fk.references(foreign_table):
                        references.add((foreign_table_name, fk.column, property_name))
        return references

    def _add_referenced_objects(self, uow, operation):
        """Breadth-first walk over FK references starting at ``operation.target``,
        adding every reached object to ``uow`` with ``operation.type``."""
        # ids of objects already walked in this traversal. Without this,
        # cycles (self-references, mutual foreign keys) re-queue the same
        # rows at every level until the depth limit raises.
        visited = {id(operation.target)}

        # Each entry is (remaining depth, objects of one model). The list
        # grows while it is iterated; that is how the walk proceeds.
        object_sets = [(self.max_allowed_level, [operation.target])]

        for level, objects in object_sets:
            if level <= 0:
                raise RuntimeError('Maximum level reached while collecting referenced objects')

            model = type(objects[0])
            for foreign_table_name, foreign_column, local_property_name in self._foreign_references(model):
                foreign_model = self.get_model(foreign_table_name)

                # Unversioned models cannot take part in a versioning
                # transaction, and descending into them would fail on the
                # missing ``__versioned__`` attribute.
                if not hasattr(foreign_model, '__versioned__'):
                    continue

                # Foreign-key values held by the current objects.
                values = [
                    value
                    for value in (getattr(obj, local_property_name) for obj in objects)
                    if value is not None
                ]
                if not values:
                    continue

                # NOTE(review): assumes a Flask-SQLAlchemy style ``query``
                # property on the model class.
                referenced_objects = foreign_model.query.filter(foreign_column.in_(values)).all()

                newly_reached = []
                for referenced_object in referenced_objects:
                    # Forward the same versioning operation type to the
                    # referenced object.
                    if referenced_object not in uow.operations:
                        uow.operations.add(Operation(referenced_object, operation.type))
                    if id(referenced_object) not in visited:
                        visited.add(id(referenced_object))
                        newly_reached.append(referenced_object)

                # Only descend into objects not yet walked.
                if newly_reached:
                    object_sets.append((level - 1, newly_reached))

    def before_create_version_objects(self, uow, session):
        """
        Add objects referenced by foreign keys to the versioning transaction.

        :param uow: unit of work whose ``operations`` are scanned and
            extended (presumably Continuum's UnitOfWork)
        :param session: unused here; part of the plugin hook signature
        :raises RuntimeError: if a reference chain is deeper than
            ``max_allowed_level``
        :raises AttributeError: if a referenced table has no registered model
        """
        # Iterate a snapshot: _add_referenced_objects() calls
        # uow.operations.add(), which may grow the mapping being iterated.
        for _, operation in list(uow.operations.items()):
            # Process each operation only once.
            if operation.processed:
                continue

            if not self.class_registry:
                self.class_registry = self._load_class_registry(operation.target)

            self._add_referenced_objects(uow, operation)