Created
January 30, 2019 16:06
-
-
Save timothyrenner/4379b1fd901b6027fc1132aa14e2ee50 to your computer and use it in GitHub Desktop.
Pyspark Partition Definition
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
# We'll need this handy list more than once. It enforces the
# column order required by the model.
# NOTE(review): the trailing "..." is a literal Ellipsis placeholder from the
# gist — replace it with the model's actual remaining feature-column names
# before running; as written the model would receive Ellipsis as a column name.
FEATURES = ["feature1", "feature2", "feature3", ...]
def predict_partition(rows):
    """Run a vectorized model prediction over one Spark partition.

    Intended for use with ``rdd.mapPartitions`` (or similar): the whole
    partition is loaded into a pandas DataFrame so the model can predict
    in one vectorized call instead of row-by-row.

    Parameters
    ----------
    rows : Iterable[pyspark.sql.Row]
        The rows of the partition. Each row must expose the columns named
        in ``FEATURES`` (relied on via key-based ``asDict()`` lookup, which
        is safer than positional access).

    Returns
    -------
    List[pyspark.sql.Row]
        One output Row per input row: the original columns plus a
        ``prediction`` column.
    """
    # Key-based asDict() keeps us independent of the rows' column order.
    rows_df = pd.DataFrame.from_records([row.asDict() for row in rows])

    # Edge case: Spark can hand us an empty partition. Bail out before
    # touching the model, which may not accept a zero-row input.
    if rows_df.empty:
        return []

    # NOTE(review): `model` and `Row` are not defined in this file — they are
    # presumably a closure-captured/broadcast model and pyspark.sql.Row;
    # confirm both are in scope on the executors.
    # FEATURES enforces the column order the model was trained with.
    rows_df.loc[:, 'prediction'] = model.predict(rows_df[FEATURES].values)

    # Rebuild Rows from the augmented frame. to_dict("records") yields one
    # {column: value} dict per row, which maps directly onto Row(**kwargs) —
    # no iterrows() index bookkeeping needed. Returning a real list (rather
    # than a lazy, consume-once map object) matches the documented contract.
    return [Row(**record) for record in rows_df.to_dict("records")]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment