Created
March 12, 2024 17:36
-
-
Save WolfEYc/f6ce1e947372742bf1108290447e5b0a to your computer and use it in GitHub Desktop.
Polars Oracle Helpers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import re | |
from contextlib import asynccontextmanager | |
from functools import reduce | |
from itertools import chain | |
from operator import itemgetter | |
from typing import Any, Coroutine, Iterable, NamedTuple, Optional | |
import oracledb | |
import polars as pl | |
from mirouteapi.connections.oracle import ENV, MRTE_ORACLE_PWD | |
from mirouteapi.logger import LOGGER | |
class ListReplacements(NamedTuple):
    """Pair returned by ``replace_lists_and_query``: rewritten SQL plus binds."""

    # SQL text after the list-valued bind placeholders have been rewritten.
    query: str
    # Bind values by name; list values have been converted to Oracle array objects.
    new_kwargs: dict[str, Any]
def oracle_arraytype(lst: list):
    """Pick the Oracle collection type used to bind ``lst``.

    Only the first element is inspected: an ``int`` selects the "number"
    type and ``bytes`` selects the "bytes" type. Everything else, including
    an empty list, falls back to the "string" type.
    """
    first = lst[0] if len(lst) > 0 else None
    if isinstance(first, int):
        return ORACLE_TYPES["number"]
    if isinstance(first, bytes):
        return ORACLE_TYPES["bytes"]
    return ORACLE_TYPES["string"]
def python_list_to_oracle_array(lst: list):
    """Build an Oracle collection object from ``lst``.

    The collection type comes from ``oracle_arraytype`` and the object is
    created with that type's ``newobject``.
    """
    return oracle_arraytype(lst).newobject(lst)
def replace_query_list(query: str, list_key: str) -> str:
    """Rewrite every ``:list_key`` placeholder into a ``TABLE()`` sub-select.

    Each occurrence of the bind placeholder becomes
    ``(SELECT * FROM TABLE(:list_key))`` so a collection bound under that
    name can be used where the query expects a set of values.

    Only the exact bind name is rewritten. A longer bind name that merely
    starts with ``list_key`` (``:ids_extra`` when the key is ``ids``) is left
    untouched; a plain substring replace would corrupt it.

    Args:
        query: SQL text containing named bind placeholders.
        list_key: Name of the list-valued bind parameter, without the colon.

    Returns:
        The query with the matching placeholders rewritten.
    """
    # The negative lookahead stops the match from ending in the middle of a
    # longer identifier; re.escape keeps unusual key characters literal.
    placeholder = re.compile(rf":{re.escape(list_key)}(?!\w)")
    subquery = f"(SELECT * FROM TABLE(:{list_key}))"
    # A callable replacement avoids backslash/group expansion of the key text.
    return placeholder.sub(lambda _match: subquery, query)
def replace_query_lists(query: str, list_keys: Iterable[str]) -> str:
    """Apply ``replace_query_list`` to ``query`` once for each key, in order."""
    rewritten = query
    for list_key in list_keys:
        rewritten = replace_query_list(rewritten, list_key)
    return rewritten
def replace_lists_and_query(query: str, kwargs: dict[str, Any]) -> ListReplacements:
    """Prepare a query and its binds when some bind values are Python lists.

    Placeholders of list-valued binds are rewritten in the SQL text, and the
    list values themselves are converted to Oracle array objects. ``kwargs``
    is not modified; a new dict is returned with the non-list binds first,
    followed by the converted list binds.
    """
    list_items = [
        (name, value) for name, value in kwargs.items() if isinstance(value, list)
    ]
    rewritten_query = replace_query_lists(query, (name for name, _ in list_items))

    new_kwargs = {
        name: value for name, value in kwargs.items() if not isinstance(value, list)
    }
    for name, values in list_items:
        new_kwargs[name] = python_list_to_oracle_array(values)
    return ListReplacements(rewritten_query, new_kwargs)
def replace_lists(kwargs: dict[str, Any]):
    """Convert list values of ``kwargs`` to Oracle arrays, in place.

    The same dict object is returned after its list-valued entries have been
    overwritten with the result of ``python_list_to_oracle_array``.
    """
    # Overwriting the value of an existing key does not resize the dict, so
    # assigning while iterating over items() is safe here.
    for name, value in kwargs.items():
        if isinstance(value, list):
            kwargs[name] = python_list_to_oracle_array(value)
    return kwargs
# Matches "LIMIT <digits>" in any letter case and captures the row count.
LIMIT_REGEX = re.compile(r"LIMIT\s+(\d+)", re.IGNORECASE)


def limit_replacement(match: re.Match):
    """Return the ``FETCH NEXT n ROWS ONLY`` clause for one ``LIMIT n`` match."""
    return f"FETCH NEXT {match[1]} ROWS ONLY"


def replace_limit_sql(query: str) -> str:
    """Rewrite every ``LIMIT n`` in ``query`` as ``FETCH NEXT n ROWS ONLY``."""
    return re.sub(LIMIT_REGEX, limit_replacement, query)
async def cursor_to_df(cursor: oracledb.AsyncCursor, to_lower: bool) -> pl.DataFrame:
    """Drain ``cursor`` into a polars DataFrame and close the cursor.

    Args:
        cursor: An open cursor with a result set to fetch.
        to_lower: When true, column names are lower-cased.

    Returns:
        A DataFrame holding every row of the cursor, with the column names
        taken from ``cursor.description``.
    """
    try:
        columns = [column[0] for column in cursor.description]
        data = await cursor.fetchall()
    finally:
        # Close even when the fetch fails so the cursor is not leaked.
        cursor.close()
    if to_lower:
        columns = [name.lower() for name in columns]
    # Scan every row for schema inference, as fetch() does, so a column whose
    # leading rows are all NULL is still typed from its later values.
    return pl.DataFrame(
        data,
        schema=columns,
        orient="row",
        infer_schema_length=None,
    )
async def fetch(
    conn: oracledb.AsyncConnection,
    query: str,
    *,
    schema_overrides: Optional[dict] = None,
    to_lower: bool = True,
    **kwargs,
) -> pl.DataFrame:
    """Run ``query`` on ``conn`` and return all rows as a polars DataFrame.

    Before execution, list-valued bind parameters are handled by
    ``replace_lists_and_query`` and ``LIMIT n`` is rewritten by
    ``replace_limit_sql``.

    Args:
        conn: Connection used to open the cursor.
        query: SQL text with named bind placeholders.
        schema_overrides: Passed through to ``pl.DataFrame``.
        to_lower: When true, column names are lower-cased.
        **kwargs: Bind values by name.
    """
    replacements = replace_lists_and_query(query, kwargs)
    kwargs = replacements.new_kwargs
    query = replace_limit_sql(replacements.query)

    with conn.cursor() as cursor:
        LOGGER.debug(f"Querying Oracle with:\n{query}\nwith kwargs:\n{kwargs}")
        await cursor.execute(query, **kwargs)
        data = await cursor.fetchall()
        columns = [
            column[0].lower() if to_lower else column[0]
            for column in cursor.description
        ]

    return pl.DataFrame(
        data,
        schema=columns,
        orient="row",
        schema_overrides=schema_overrides,
        infer_schema_length=len(data),
    )
async def fetch_proc(
    conn: oracledb.AsyncConnection,
    proc: str,
    *,
    out_keys: dict[str, Any],
    to_lower: bool = True,
    **kwargs,
) -> dict[str, Any]:
    """Call stored procedure ``proc`` and collect its OUT parameters.

    Args:
        conn: Connection used to open the cursor.
        proc: Name of the stored procedure.
        out_keys: Maps each OUT parameter name to the type handed to
            ``cursor.var`` for it.
        to_lower: Forwarded to ``cursor_to_df`` for cursor-valued outputs.
        **kwargs: IN parameters by name; list values are converted by
            ``replace_lists``.

    Returns:
        A dict keyed like ``out_keys``. Values that come back as an
        ``oracledb.AsyncCursor`` are replaced by the DataFrame produced by
        ``cursor_to_df``; other values are returned as they are.
    """
    kwargs = replace_lists(kwargs)
    with conn.cursor() as cursor:
        out_vars = [cursor.var(var_type) for var_type in out_keys.values()]
        kwargs.update(zip(out_keys.keys(), out_vars))
        LOGGER.debug(f"Calling Oracle stored proc:\n{proc}\nwith kwargs:\n{kwargs}")
        await cursor.callproc(proc, keyword_parameters=kwargs)

        out_results = {}
        for name, out_var in zip(out_keys.keys(), out_vars):
            value = out_var.getvalue()
            if isinstance(value, oracledb.AsyncCursor):
                # Placeholder coroutine; awaited together with the others below.
                value = cursor_to_df(value, to_lower)
            out_results[name] = value

        pending = {
            name: value
            for name, value in out_results.items()
            if isinstance(value, Coroutine)
        }
        frames = await asyncio.gather(*pending.values())  # type: ignore
        out_results.update(zip(pending.keys(), frames))
        return out_results
def get_oracle_types(nice_names_to_oracle_type_names: dict[str, str]):
    """Resolve Oracle type names into type objects over a short-lived connection.

    Args:
        nice_names_to_oracle_type_names: Maps a short alias to the Oracle
            type name passed to ``con.gettype``.

    Returns:
        A dict with the same aliases mapped to the objects returned by
        ``con.gettype``.
    """
    with oracledb.connect(**ENV, password=MRTE_ORACLE_PWD) as con:
        # Every lookup must happen while the connection is still open.
        return {
            nice_name: con.gettype(oracle_type_name)
            for nice_name, oracle_type_name in nice_names_to_oracle_type_names.items()
        }
# Oracle collection types used to bind Python lists, keyed by the element kind
# chosen in oracle_arraytype(). NOTE: this runs at import time, and
# get_oracle_types() opens a database connection to resolve the types.
ORACLE_TYPES = get_oracle_types(
    {
        "number": "SYS.ODCINUMBERLIST",
        "string": "SYS.ODCIVARCHAR2LIST",
        "bytes": "SYS.ODCIRAWLIST",
    }
)
def gen_set_sql(col: str, include_nulls: bool = False):
    """Return the ``col = <bind>`` fragment for column ``col``.

    With ``include_nulls`` the bind value is assigned as it is. Otherwise the
    bind is wrapped in ``COALESCE`` with the column itself, so a NULL bind
    keeps the current column value.
    """
    if include_nulls:
        return f"{col} = :{col}"
    return f"{col} = COALESCE(:{col}, {col})"
async def update_many(
    conn: oracledb.AsyncConnection,
    df: pl.DataFrame,
    table: str,
    pkey_cols: set[str],
    include_nulls: bool = False,
):
    """Issue one ``UPDATE`` per row of ``df`` against ``table``.

    Columns named in ``pkey_cols`` form the ``WHERE`` clause; every other
    column of ``df`` is assigned in the ``SET`` clause. No commit is issued
    here.

    Args:
        conn: Connection the statement is executed on.
        df: Rows to apply; column names double as bind names.
        table: Target table name, interpolated into the SQL text.
        pkey_cols: Columns that identify the row to update.
        include_nulls: Forwarded to ``gen_set_sql`` for the ``SET`` columns.

    Returns:
        The same ``conn`` that was passed in.
    """
    non_pkey_cols = set(df.columns).difference(pkey_cols)
    set_sql = ", ".join(gen_set_sql(col, include_nulls) for col in non_pkey_cols)
    # Key columns are always compared against the bind value directly.
    where_sql = " AND ".join(gen_set_sql(col, True) for col in pkey_cols)
    update_sql = f"""--sql
    UPDATE {table}
    SET {set_sql}
    WHERE {where_sql}
    """
    rows = df.to_dicts()
    LOGGER.debug(f"Updating Oracle with:\n{update_sql}\nwith df:\n{df}")
    await conn.executemany(update_sql, rows)
    return conn
async def insert_many(
    conn: oracledb.AsyncConnection,
    df: pl.DataFrame,
    table: str,
):
    """Insert every row of ``df`` into ``table`` with a single ``executemany``.

    The column names of ``df`` are used both as the target column list and as
    the bind names. No commit is issued here.

    Returns:
        The same ``conn`` that was passed in.
    """
    columns_sql = ", ".join(df.columns)
    values_sql = ", ".join(f":{col}" for col in df.columns)
    insert_sql = f"""--sql
    INSERT INTO {table} ({columns_sql})
    VALUES ({values_sql})
    """
    LOGGER.debug(f"Inserting into Oracle with:\n{insert_sql}\nwith df:\n{df}")
    rows = df.to_dicts()
    await conn.executemany(insert_sql, rows)
    return conn
class PoolWrapper:
    """Convenience wrapper around an ``oracledb`` async connection pool.

    The pool is not created in the constructor; ``init`` must be awaited
    before any other method is used.
    """

    # Set by init(); absent until then.
    pool: oracledb.AsyncConnectionPool

    async def init(self):
        """Create the connection pool from ``ENV`` and ``MRTE_ORACLE_PWD``."""
        self.pool = oracledb.create_pool_async(password=MRTE_ORACLE_PWD, **ENV)

    @asynccontextmanager
    async def acquire(self):
        """Yield a pooled connection for the duration of the ``async with``."""
        async with self.pool.acquire() as conn:
            yield conn

    async def close(self):
        """Close the pool with ``force=True``."""
        await self.pool.close(force=True)

    async def fetch(
        self,
        query: str,
        *,
        schema_overrides: Optional[dict] = None,
        to_lower: bool = True,
        **kwargs,
    ) -> pl.DataFrame:
        """Run the module-level ``fetch`` on a pooled connection."""
        async with self.acquire() as conn:
            return await fetch(
                conn,
                query,
                schema_overrides=schema_overrides,
                to_lower=to_lower,
                **kwargs,
            )

    async def fetch_proc(
        self,
        proc: str,
        *,
        out_keys: dict[str, Any],
        to_lower: bool = True,
        **kwargs,
    ) -> dict[str, Any]:
        """Run the module-level ``fetch_proc`` on a pooled connection."""
        async with self.acquire() as conn:
            return await fetch_proc(
                conn, proc, out_keys=out_keys, to_lower=to_lower, **kwargs
            )

    @asynccontextmanager
    async def update_many(
        self,
        df: pl.DataFrame,
        table: str,
        pkey_cols: set[str],
        include_nulls: bool = False,
    ):
        """Run ``update_many`` and yield a ``ConnWrapper`` on the same connection.

        When the ``async with`` body finishes normally, the transaction is
        committed if one is still in progress. If the body raises, the
        exception propagates from the ``yield`` and no commit is issued here.
        """
        async with self.acquire() as conn:
            await update_many(conn, df, table, pkey_cols, include_nulls)
            yield ConnWrapper(conn)
            if conn.transaction_in_progress:
                await conn.commit()

    async def update_many_autocommit(
        self,
        df: pl.DataFrame,
        table: str,
        pkey_cols: set[str],
        include_nulls: bool = False,
    ):
        """Run ``update_many`` on a pooled connection with autocommit enabled."""
        async with self.acquire() as conn:
            conn.autocommit = True
            try:
                await update_many(conn, df, table, pkey_cols, include_nulls)
            finally:
                # Always switch autocommit back off, even when the statement
                # fails, before the connection goes back to the pool.
                conn.autocommit = False

    @asynccontextmanager
    async def insert_many(self, df: pl.DataFrame, table: str):
        """Run ``insert_many`` and yield a ``ConnWrapper`` on the same connection.

        When the ``async with`` body finishes normally, the transaction is
        committed if one is still in progress. If the body raises, the
        exception propagates from the ``yield`` and no commit is issued here.
        """
        async with self.acquire() as conn:
            await insert_many(conn, df, table)
            yield ConnWrapper(conn)
            if conn.transaction_in_progress:
                await conn.commit()

    async def insert_many_autocommit(self, df: pl.DataFrame, table: str):
        """Run ``insert_many`` on a pooled connection with autocommit enabled."""
        async with self.acquire() as conn:
            conn.autocommit = True
            try:
                await insert_many(conn, df, table)
            finally:
                # Always switch autocommit back off, even when the statement
                # fails, before the connection goes back to the pool.
                conn.autocommit = False
class ConnWrapper:
    """Binds the module-level helpers to one already-acquired connection.

    The mutating methods return ``self`` so calls can be chained; nothing is
    committed until ``commit`` is awaited.
    """

    # Connection every method of this wrapper operates on.
    conn: oracledb.AsyncConnection

    def __init__(self, conn: oracledb.AsyncConnection):
        self.conn = conn

    async def fetch(
        self,
        query: str,
        *,
        schema_overrides: Optional[dict] = None,
        to_lower: bool = True,
        **kwargs,
    ) -> pl.DataFrame:
        """Run the module-level ``fetch`` on the wrapped connection."""
        frame = await fetch(
            self.conn,
            query,
            schema_overrides=schema_overrides,
            to_lower=to_lower,
            **kwargs,
        )
        return frame

    async def fetch_proc(
        self,
        proc: str,
        *,
        out_keys: dict[str, Any],
        to_lower: bool = True,
        **kwargs,
    ) -> dict[str, Any]:
        """Run the module-level ``fetch_proc`` on the wrapped connection."""
        outputs = await fetch_proc(
            self.conn,
            proc,
            to_lower=to_lower,
            out_keys=out_keys,
            **kwargs,
        )
        return outputs

    async def update_many(
        self,
        df: pl.DataFrame,
        table: str,
        pkey_cols: set[str],
        include_nulls: bool = False,
    ):
        """Run the module-level ``update_many``; returns this wrapper."""
        connection = self.conn
        await update_many(connection, df, table, pkey_cols, include_nulls)
        return self

    async def insert_many(self, df: pl.DataFrame, table: str):
        """Run the module-level ``insert_many``; returns this wrapper."""
        connection = self.conn
        await insert_many(connection, df, table)
        return self

    async def commit(self):
        """Commit on the wrapped connection; returns this wrapper."""
        connection = self.conn
        await connection.commit()
        return self

    async def rollback(self):
        """Roll back on the wrapped connection; returns this wrapper."""
        connection = self.conn
        await connection.rollback()
        return self
# Module-level pool wrapper instance. The pool itself is only created when
# `await ORACLE.init()` is called, since PoolWrapper.init() assigns `pool`.
ORACLE = PoolWrapper()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment