PySpark: grouping each ID's events into clusters, where an event joins the previous event's cluster if it occurs fewer than 90 days later.
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql import Window
sc = SparkContext()
sqlc = SQLContext(sc)
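# Note: this gist predates the now-standard SparkSession entry point.
# A sketch of the equivalent modern setup (not part of the original gist):
#   from pyspark.sql import SparkSession
#   spark = SparkSession.builder.getOrCreate()
#   df = spark.createDataFrame(...)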
# Test data
# e.g. IDs and server log-in days
L = '''
a,2016-01-01
a,2016-06-01
a,2016-02-01
a,2016-05-01
b,2016-10-01
b,2016-11-01
a,2016-11-01
a,2017-01-01
b,2017-06-01
a,2017-04-01
b,2017-04-01
'''.strip().split('\n')
L = [x.split(',') for x in L]
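# L is now a list of [ID, date-string] pairs, e.g. ['a', '2016-01-01']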
## DataFrame
df = sqlc.createDataFrame(
    L,
    schema=T.StructType([
        T.StructField('ID', T.StringType()),
        T.StructField('DATE', T.StringType()),
    ])
).withColumn('DATE', F.to_date('DATE'))  # to_date parses 'yyyy-MM-dd' by default
# df.collect()
# [Row(ID='a', DATE=datetime.date(2016, 1, 1)),
#  Row(ID='a', DATE=datetime.date(2016, 6, 1)),
#  Row(ID='a', DATE=datetime.date(2016, 2, 1)),
#  Row(ID='a', DATE=datetime.date(2016, 5, 1)),
#  Row(ID='b', DATE=datetime.date(2016, 10, 1)),
#  Row(ID='b', DATE=datetime.date(2016, 11, 1)),
#  Row(ID='a', DATE=datetime.date(2016, 11, 1)),
#  Row(ID='a', DATE=datetime.date(2017, 1, 1)),
#  Row(ID='b', DATE=datetime.date(2017, 6, 1)),
#  Row(ID='a', DATE=datetime.date(2017, 4, 1)),
#  Row(ID='b', DATE=datetime.date(2017, 4, 1))]
### Partition the data by ID (e.g. user)
## Attach an integer label to each cluster of events within an ID's partition,
## splitting that partition into clusters. An event belongs to the same cluster
## as the previous event if it occurs fewer than 90 days after it; a gap of
## 90 or more days starts a new cluster.
wind = Window.partitionBy('ID').orderBy('DATE')
with_diffs = (
    df
    # days since the previous event for the same ID (null for the first event)
    .withColumn('time_between',
                F.datediff(F.col('DATE'), F.lag(F.col('DATE'), 1).over(wind)))
    # the first event of each ID has no previous event; treat its gap as 0
    .fillna(0, subset='time_between')
    # flag events that start a new cluster (gap of 90+ days)
    .withColumn('change', F.when(F.col('time_between') < 90, 0).otherwise(1))
    # running sum of the flags labels each cluster within an ID
    .withColumn('cumsum', F.sum(F.col('change')).over(wind))
)
# keeps all the intermediate columns
# with_diffs.collect()
# [Row(ID='b', DATE=datetime.date(2016, 10, 1), time_between=0, change=0, cumsum=0),
#  Row(ID='b', DATE=datetime.date(2016, 11, 1), time_between=31, change=0, cumsum=0),
#  Row(ID='b', DATE=datetime.date(2017, 4, 1), time_between=151, change=1, cumsum=1),
#  Row(ID='b', DATE=datetime.date(2017, 6, 1), time_between=61, change=0, cumsum=1),
#  Row(ID='a', DATE=datetime.date(2016, 1, 1), time_between=0, change=0, cumsum=0),
#  Row(ID='a', DATE=datetime.date(2016, 2, 1), time_between=31, change=0, cumsum=0),
#  Row(ID='a', DATE=datetime.date(2016, 5, 1), time_between=90, change=1, cumsum=1),
#  Row(ID='a', DATE=datetime.date(2016, 6, 1), time_between=31, change=0, cumsum=1),
#  Row(ID='a', DATE=datetime.date(2016, 11, 1), time_between=153, change=1, cumsum=2),
#  Row(ID='a', DATE=datetime.date(2017, 1, 1), time_between=61, change=0, cumsum=2),
#  Row(ID='a', DATE=datetime.date(2017, 4, 1), time_between=90, change=1, cumsum=3)]
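# One way to use the labels (an illustrative sketch, not part of the original
# gist): collapse each (ID, cumsum) cluster into a single session row with its
# start date, end date, and event count.
sessions = (
    with_diffs
    .groupBy('ID', 'cumsum')
    .agg(
        F.min('DATE').alias('session_start'),
        F.max('DATE').alias('session_end'),
        F.count('*').alias('n_events'),
    )
)
# e.g. for ID='a', cumsum=0: session_start=2016-01-01,
# session_end=2016-02-01, n_events=2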