Skip to content

Instantly share code, notes, and snippets.

@TobeTek
Created June 23, 2024 20:13
Show Gist options
  • Save TobeTek/74b8eb75900c261466ed30eaeb7b5070 to your computer and use it in GitHub Desktop.
Save TobeTek/74b8eb75900c261466ed30eaeb7b5070 to your computer and use it in GitHub Desktop.
A Django management command that automatically creates migrations for all models with a Postgres SearchVectorField
import os
import string
from collections import defaultdict
from django.core.management.base import BaseCommand, CommandError
from django.db import migrations
from django.db.migrations.writer import MigrationWriter
from django.db.models import Model
# Tag embedded in the name of every migration this command generates; it is
# also what the command looks for in existing file names to decide whether a
# model already has such a migration.
MIGRATION_FILE_NAME = "searchvectortrigger"
class Command(BaseCommand):
    """
    Create migrations that set up database-maintained search vector columns.

    For every configured model, one migration is written containing the
    ``RunSQL`` operations returned by ``generate_search_vector_sql``.
    Should be invoked after `./manage.py makemigrations`
    """

    help = "Creates new migration(s) to create triggers for search vectors in models."
    include_header = True

    @property
    def log_output(self):
        """Return the stream that :meth:`log` writes to (the command's stdout)."""
        return self.stdout

    def log(self, msg):
        """Write ``msg`` to the command's output stream."""
        self.log_output.write(msg)

    def handle(self, *app_labels, **options):
        """
        Build the search vector operations for each configured model and write
        them out as migration files, grouped by the model's app label.
        """
        # Project models are imported inside the method, so they are only
        # loaded when the command actually runs.
        from articles.models.articles import Article
        from articles.models.categories import Category
        from community.models import Post
        from tools_and_settings.models import AppTool

        # Maps each model to its search vector columns and the source columns
        # each vector is computed from.
        SEARCH_VECTOR_FIELDS = {
            Article: [
                {
                    "vector_column": "english_fts_vector",
                    "trigger_columns": ["escaped_content_html", "topic", "title"],
                },
            ],
            Category: [
                {
                    "vector_column": "english_fts_vector",
                    "trigger_columns": ["escaped_content_html", "topic", "title"],
                },
            ],
            Post: [
                {
                    "vector_column": "english_fts_vector",
                    "trigger_columns": ["escaped_content_html", "topic", "title"],
                },
            ],
            AppTool: [
                {
                    "vector_column": "english_fts_vector",
                    "trigger_columns": ["escaped_content_html", "topic", "title"],
                },
                {
                    "vector_column": "description_fts_vector",
                    "trigger_columns": ["description"],
                },
            ],
        }

        search_trigger_migrations = defaultdict(list)
        for model, search_vectors in SEARCH_VECTOR_FIELDS.items():
            operations = [
                generate_search_vector_sql(
                    model=model,
                    vector_column=search_vector["vector_column"],
                    trigger_columns=search_vector["trigger_columns"],
                )
                for search_vector in search_vectors
            ]
            search_trigger_migrations[model._meta.app_label].append(
                [
                    model._meta.model_name,
                    operations,
                ]
            )

        self.write_migration_files(search_trigger_migrations)

    def write_migration_files(self, changes):
        """
        Take a changes dict and write them out as migration files.

        ``changes`` maps an app label to a list of ``[model_name, operations]``
        pairs. A model that already has a search vector migration in the app's
        migrations folder is skipped. Otherwise the new migration is numbered
        one past the most recent existing migration and depends on it; when
        the app has no migrations yet it is numbered ``0001`` with no
        dependencies.
        """
        for app_label, model_migrations in changes.items():
            for model_name, operations in model_migrations:
                # Start with the "no previous migration" variant; its writer
                # also tells us where the app's migrations live.
                writer = self._build_writer(
                    app_label,
                    f"0001_{MIGRATION_FILE_NAME}_{model_name}",
                    operations,
                    dependencies=[],
                )
                migrations_directory = os.path.dirname(writer.path)

                if self.has_search_vector_migration(migrations_directory, model_name):
                    continue

                # Add dependency migrations if they exist
                if dependency := self.get_most_recent_migration(migrations_directory):
                    dependency_migration_no, _ = dependency.split("_", 1)
                    new_migration_no = int(dependency_migration_no) + 1
                    writer = self._build_writer(
                        app_label,
                        f"{new_migration_no:0>4}_{MIGRATION_FILE_NAME}_{model_name}",
                        operations,
                        dependencies=[(app_label, dependency)],
                    )

                # Make sure the target is an importable package before writing.
                os.makedirs(migrations_directory, exist_ok=True)
                init_path = os.path.join(migrations_directory, "__init__.py")
                if not os.path.isfile(init_path):
                    with open(init_path, "w", encoding="utf-8"):
                        pass

                migration_string = writer.as_string()
                with open(writer.path, "w", encoding="utf-8") as fh:
                    fh.write(migration_string)

    def _build_writer(self, app_label, name, operations, dependencies):
        """Return a ``MigrationWriter`` for a new migration with the given content."""
        subclass = type(
            "Migration",
            (migrations.Migration,),
            {
                "dependencies": dependencies,
                "operations": operations,
            },
        )
        migration = subclass(name=name, app_label=app_label)
        return MigrationWriter(migration, self.include_header)

    def has_search_vector_migration(self, app_migrations_folder: str, model_name: str):
        """
        Return True if ``app_migrations_folder`` already contains a migration
        generated by this command for ``model_name``.

        The file name must end with ``_<MIGRATION_FILE_NAME>_<model_name>.py``;
        a plain substring test would confuse e.g. ``post`` with
        ``postcomment``. A folder that does not exist yet has no migrations.
        """
        if not os.path.isdir(app_migrations_folder):
            return False
        expected_suffix = f"_{MIGRATION_FILE_NAME}_{model_name}.py"
        return any(
            filename.endswith(expected_suffix)
            for filename in os.listdir(app_migrations_folder)
        )

    def get_most_recent_migration(self, app_migrations_folder: str):
        """
        Return the name (without ``.py``) of the highest-numbered migration in
        ``app_migrations_folder``, or ``None`` if there is none.

        Only ``<digits>_<anything>.py`` files are considered, so entries such
        as ``__init__.py`` or ``__pycache__`` can neither be picked as the
        latest migration nor break the number parsing done by the caller.
        """
        if not os.path.isdir(app_migrations_folder):
            return None

        numbered_migrations = []
        for filename in os.listdir(app_migrations_folder):
            stem, extension = os.path.splitext(filename)
            number, separator, _ = stem.partition("_")
            if extension == ".py" and separator and number.isdecimal():
                numbered_migrations.append((int(number), stem))

        if not numbered_migrations:
            return None
        # Highest number wins; equal numbers fall back to name order.
        return max(numbered_migrations)[1]
def generate_search_vector_sql(
    model: type[Model], vector_column: str, trigger_columns: list[str]
):
    """
    Build the ``RunSQL`` operation that (re)creates ``vector_column`` on the
    model's table as a stored, generated ``tsvector`` column.

    The vector is the concatenation of one weighted ``to_tsvector('english', ...)``
    per entry in ``trigger_columns``: the first column gets weight ``A``, the
    second ``B``, and so on. The reverse operation drops the column.

    Table and column names are interpolated into the SQL verbatim, so they
    must be trusted identifiers.

    Raises:
        CommandError: if ``trigger_columns`` is empty or has more entries than
            there are weight labels.
    """
    CREATE_TRIGGER_SQL = """ALTER TABLE {model_table} DROP COLUMN IF EXISTS {vector_column};
ALTER TABLE {model_table} ADD COLUMN {vector_column} tsvector GENERATED ALWAYS AS ({setweight_stmts}) STORED;"""
    REVERSE_CREATE_TRIGGER_SQL = (
        """ALTER TABLE {model_table} DROP COLUMN {vector_column};"""
    )
    # PostgreSQL's setweight() only accepts the labels A, B, C and D.
    weight_labels = "ABCD"

    if not trigger_columns:
        raise CommandError("At least one trigger column is required for search vector")
    if len(trigger_columns) > len(weight_labels):
        raise CommandError("Maximum number of trigger columns exceeded for search vector")

    db_table = model._meta.db_table
    # The column must be referenced as an identifier, not a quoted string
    # literal; otherwise the vector is built from the constant text
    # "table.column" rather than from the row's content.
    setweight_stmts = " || ".join(
        f"setweight(to_tsvector('english', coalesce({column}, '')), '{weight}')"
        for column, weight in zip(trigger_columns, weight_labels)
    )

    return migrations.RunSQL(
        sql=CREATE_TRIGGER_SQL.format(
            model_table=db_table,
            vector_column=vector_column,
            setweight_stmts=setweight_stmts,
        ).replace("\n", " "),
        reverse_sql=REVERSE_CREATE_TRIGGER_SQL.format(
            model_table=db_table, vector_column=vector_column
        ).replace("\n", " "),
    )
@TobeTek
Copy link
Author

TobeTek commented Aug 5, 2024

This command is no longer necessary now that Django 5.0 provides the new GeneratedField.

This is how I would go about this now:

from django.contrib.postgres.search import (
    SearchVectorField,
    SearchVector,
)
from django.db import models
from django.db.models import GeneratedField

class FullTextSearchVectorFields(models.Model):
    """
    Mixin model that assumes the database/model has `topic` and `title` fields as well
    """

    # NOTE(review): `_` is not imported in this snippet; presumably it is
    # django.utils.translation.gettext_lazy - confirm in the real module.
    escaped_content_html = models.TextField(
        verbose_name=_("escaped content html"),
        default="",
        blank=True,
        help_text=_(
            "The content used to build the full text search index. If original content is html, strip the HTML tags"
        ),
    )
    # Search vector computed from title, topic and the escaped content, and
    # persisted as a generated column (db_persist=True).
    english_fts_vector = GeneratedField(
        verbose_name=_("english full text search vector"),
        expression=SearchVector("title", config="english")
        + SearchVector("topic", config="english")
        + SearchVector("escaped_content_html", config="english"),
        output_field=SearchVectorField(),
        db_persist=True,
        null=True,
        editable=False,
        help_text=_("This field generated by the db"),
    )

    class Meta:
        # A mixin must be abstract: the generated field references `title`
        # and `topic`, which only exist on the concrete models that inherit
        # from this class.
        abstract = True

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment