Using functools.partial to pass real arguments into Kedro nodes. Kedro node inputs must be dataset or parameter names, so fixed values such as part_num are bound via partial instead.

pipeline.py:

from functools import partial, update_wrapper
from kedro.pipeline import Pipeline, node

from .nodes import process_todo, DemoMerger


def create_wrapped_partial(func, *args, **kwargs):
    """
    Create a partial function and update its wrapper to preserve function metadata.
    """
    partial_func = partial(func, *args, **kwargs)
    update_wrapper(partial_func, func)
    return partial_func
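
# Why update_wrapper matters: a bare partial has no __name__ or __doc__, so
# Kedro cannot derive a readable name for the wrapped node function. E.g.
# (illustrative):
#   bare = partial(process_todo, part_num=0)
#   getattr(bare, "__name__", None)                    # -> None
#   wrapped = create_wrapped_partial(process_todo, part_num=0)
#   wrapped.__name__                                   # -> "process_todo"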


def create_single_pipeline(mode: str, part_num: int):
    """
    Create a single pipeline for a given mode and part number, using partial
    for fixed arguments and Kedro's params: handling for runtime configuration.
    """
    data_sources = {
        "todo.parquet": process_todo,  # Example processing function names
        # "clip_metrics": process_clip_metrics,
        # "mld_tags": process_mld_tags,
        # "pixiv_meta": process_pixiv_meta,
        # "dbr_meta": process_dbr_meta,
    }

    processing_nodes = [
        node(
            func=create_wrapped_partial(data_sources[src], part_num=part_num),
            inputs={"parts_upload_dir": f"params:{mode}.parts_upload_dir"},
            outputs=f"processed_{mode}_{src}_{part_num}",
            name=f"process_{src}_{part_num}"
        ) for src in data_sources
    ]

    merger_instance = DemoMerger()

    # Explicitly define the inputs to the merger node
    merger_node = node(
        func=merger_instance.merge,
        inputs={"todo_df": f"processed_{mode}_todo.parquet_{part_num}"},
        outputs=f"{mode}_merged_{part_num}",
        name=f"{mode}_merger_{part_num}"
    )

    return Pipeline(processing_nodes + [merger_node])


def create_pipeline(mode: str, part_num: int):
    """
    Wrapper function to instantiate a pipeline with specific mode and part number.
    """
    return create_single_pipeline(mode, part_num)


# from dataproc4.utils.data_utils import get_todo_ids_list

def create_many_pipelines(mode: str):
    """
    Creates one combined pipeline covering every part for a specified mode
    (the number of parts is hard-coded to 10 here).
    """
    num_parts = 10
    return Pipeline([create_single_pipeline(mode, part_num) for part_num in range(num_parts)])
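

For reference, here is a sketch of how these factories could be wired into a Kedro project's pipeline_registry.py (the module path and the "demo" mode name are assumptions, not part of this gist):

from kedro.pipeline import Pipeline

from .pipelines.demo import create_many_pipelines  # hypothetical module path


def register_pipelines() -> dict:
    # One entry per mode; "demo" expects params:demo.parts_upload_dir
    # to be defined in conf/base/parameters.yml.
    demo_pipeline = create_many_pipelines("demo")
    return {
        "demo": demo_pipeline,
        "__default__": demo_pipeline,
    }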
    

nodes.py:

import unibox as ub
import pandas as pd
from typing import Dict

def process_todo(parts_upload_dir: str, part_num: int) -> pd.DataFrame:
    """
    Example processing function: loads one part's todo parquet from the
    upload directory and returns it as a DataFrame.
    """
    todo_uri = f"{parts_upload_dir.rstrip('/')}/{part_num}.todo.parquet"
    df = ub.loads(todo_uri, debug_print=False)
    return df
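
# Example of the URI this builds (paths hypothetical):
#   process_todo("s3://my-bucket/parts/", part_num=3)
#   -> loads "s3://my-bucket/parts/3.todo.parquet" via unibox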

   

class DemoMerger:
    """
    Example merger class; different mergers take different arguments in merge().
    """
    def merge(self, todo_df: pd.DataFrame) -> pd.DataFrame:
        """
        Sample merge of DataFrames from different sources.

        For now it simply returns the input DataFrame unchanged.
        """
        return todo_df
    

def test_base_merger():
    test_data = {
        'a': [1, 2],
        'b': [3, 4],
    }
    merger = DemoMerger()
    result = merger.merge(
        todo_df=pd.DataFrame(test_data),
    )
    assert result.equals(pd.DataFrame({'a': [1, 2], 'b': [3, 4]}))
    print('test_base_merger passed')

if __name__ == '__main__':
    test_base_merger()
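
To smoke-test one part's pipeline outside a full Kedro project, the params: input can be fed through an in-memory catalog. A minimal sketch, assuming a recent Kedro version (DataCatalog.add_feed_dict, SequentialRunner) and a hypothetical bucket URI:

from kedro.io import DataCatalog
from kedro.runner import SequentialRunner

pipeline = create_single_pipeline(mode="demo", part_num=0)

catalog = DataCatalog()
# "params:<key>" node inputs resolve to catalog entries of the same name:
catalog.add_feed_dict({"params:demo.parts_upload_dir": "s3://my-bucket/parts"})

# Unregistered intermediate/output datasets become in-memory datasets;
# the run returns the pipeline's free outputs (here, demo_merged_0).
results = SequentialRunner().run(pipeline, catalog)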