Skip to content

Commit

Permalink
Seq num sort (databrickslabs#42)
Browse files Browse the repository at this point in the history
* asofjoin + associated unit tests refactoring

* Added "other cols" to init, to optionally select fewer columns. Addressed PR comments.

* add logic to sort by sequence number

* adding unit tests

* removing other_cols

* added comment for new constructor value

* adding sequence sort to all windows (the as-of join needs the sequence number from the right-hand DF)

Co-authored-by: MaxDBX <[email protected]>
  • Loading branch information
rportilla-databricks and MaxDBX authored Nov 2, 2020
1 parent 494e786 commit 4b986e0
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 16 deletions.
2 changes: 2 additions & 0 deletions datasource/refdata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from tempo.tsdf import TSDF

54 changes: 38 additions & 16 deletions tempo/tsdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@
from pyspark.sql.window import Window

class TSDF:
def __init__(self, df, ts_col="event_ts", partition_cols=None, other_select_cols=None):

def __init__(self, df, ts_col="event_ts", partition_cols=None, sequence_col = None):
"""
Constructor
:param df:
:param ts_col:
:param partitionCols:
:sequence_col every tsdf allows for a tie-breaker secondary sort key
"""
self.ts_col = self.__validated_column(df, ts_col)
self.partitionCols = [] if partition_cols is None else self.__validated_columns(df, partition_cols)
self.other_cols = ([col for col in df.columns if col not in self.partitionCols + [self.ts_col]]
if other_select_cols is None else other_select_cols)

self.df = df.select([self.ts_col] + self.partitionCols + self.other_cols)
self.df = df
self.sequence_col = '' if sequence_col is None else sequence_col
"""
Make sure DF is ordered by its respective ts_col and partition columns.
"""
Expand Down Expand Up @@ -59,7 +58,8 @@ def __addPrefixToColumns(self,col_list,prefix):
range(len(col_list)), self.df)

ts_col = '_'.join([prefix, self.ts_col])
return TSDF(df, ts_col, self.partitionCols)
seq_col = '_'.join([prefix, self.sequence_col]) if self.sequence_col else self.sequence_col
return TSDF(df, ts_col, self.partitionCols, sequence_col = seq_col)

def __addColumnsFromOtherDF(self, other_cols):
"""
Expand All @@ -77,14 +77,17 @@ def __combineTSDF(self, ts_df_right, combined_ts_col):

return TSDF(combined_df, combined_ts_col, self.partitionCols)

def __getLastRightRow(self, left_ts_col, right_cols):
def __getLastRightRow(self, left_ts_col, right_cols, sequence_col):
from functools import reduce
"""Get last right value of each right column (inc. right timestamp) for each self.ts_col value
self.ts_col, which is the combined time-stamp column of both left and right dataframe, is dropped at the end
since it is no longer used in subsequent methods.
"""
window_spec = Window.partitionBy(self.partitionCols).orderBy(self.ts_col)
ptntl_sort_keys = [self.ts_col, sequence_col]
sort_keys = [f.col(col_name) for col_name in ptntl_sort_keys if col_name != '']

window_spec = Window.partitionBy(self.partitionCols).orderBy(sort_keys)
df = reduce(lambda df, idx: df.withColumn(right_cols[idx], f.last(right_cols[idx], True).over(window_spec)),
range(len(right_cols)), self.df)

Expand Down Expand Up @@ -126,18 +129,33 @@ def asofJoin(self, right_tsdf, left_prefix=None, right_prefix="right", tsPartiti
time brackets, which can help alleviate skew.
NOTE: partition cols have to be the same for both Dataframes.
Parameters
:param right_tsdf - right-hand data frame containing columns to merge in
:param left_prefix - optional prefix for base data frame
:param right_prefix - optional prefix for right-hand data frame
:param tsPartitionVal - value to break up each partition into time brackets
:param fraction - overlap fraction
"""
# Check whether partition columns have same name in both dataframes
self.__checkPartitionCols(right_tsdf)

# prefix non-partition columns, to avoid duplicated columns.
left_tsdf = ((self.__addPrefixToColumns([self.ts_col] + self.other_cols, left_prefix))
left_df = self.df
right_df = right_tsdf.df

orig_left_col_diff = list(set(left_df.columns).difference(set(self.partitionCols)))
orig_right_col_diff = list(set(right_df.columns).difference(set(self.partitionCols)))

left_tsdf = ((self.__addPrefixToColumns([self.ts_col] + orig_left_col_diff, left_prefix))
if left_prefix is not None else self)
right_tsdf = right_tsdf.__addPrefixToColumns([right_tsdf.ts_col] + right_tsdf.other_cols, right_prefix)
right_tsdf = right_tsdf.__addPrefixToColumns([right_tsdf.ts_col] + orig_right_col_diff, right_prefix)

left_nonpartition_cols = list(set(left_tsdf.df.columns).difference(set(self.partitionCols)))
right_nonpartition_cols = list(set(right_tsdf.df.columns).difference(set(self.partitionCols)))

# For both dataframes get all non-partition columns (including ts_col)
left_columns = [left_tsdf.ts_col] + left_tsdf.other_cols
right_columns = [right_tsdf.ts_col] + right_tsdf.other_cols
left_columns = [left_tsdf.ts_col] + left_nonpartition_cols
right_columns = [right_tsdf.ts_col] + right_nonpartition_cols

# Union both dataframes, and create a combined TS column
combined_ts_col = "combined_ts"
Expand All @@ -148,10 +166,10 @@ def asofJoin(self, right_tsdf, left_prefix=None, right_prefix="right", tsPartiti

# perform asof join.
if tsPartitionVal is None:
asofDF = combined_df.__getLastRightRow(left_tsdf.ts_col, right_columns)
asofDF = combined_df.__getLastRightRow(left_tsdf.ts_col, right_columns, right_tsdf.sequence_col)
else:
tsPartitionDF = combined_df.__getTimePartitions(tsPartitionVal, fraction=fraction)
asofDF = tsPartitionDF.__getLastRightRow(left_tsdf.ts_col, right_columns)
asofDF = tsPartitionDF.__getLastRightRow(left_tsdf.ts_col, right_columns, right_tsdf.sequence_col)

# Get rid of overlapped data and the extra columns generated from timePartitions
df = asofDF.df.filter(f.col("is_original") == 1).drop("ts_partition","is_original")
Expand All @@ -162,7 +180,11 @@ def asofJoin(self, right_tsdf, left_prefix=None, right_prefix="right", tsPartiti


def __baseWindow(self):
    """Build the TSDF's default window specification.

    Rows are ordered by the timestamp column cast to long; when the TSDF
    was constructed with a sequence column, that column is appended as a
    secondary sort key so ties on the timestamp are broken
    deterministically. If partition columns are configured, the window is
    additionally partitioned by them.
    """
    # Time is the primary key; a non-empty sequence_col breaks ties.
    order_cols = []
    for candidate in (self.ts_col, self.sequence_col):
        if candidate != '':
            order_cols.append(f.col(candidate).cast("long"))

    window_spec = Window().orderBy(order_cols)
    if self.partitionCols:
        partition_exprs = [f.col(name) for name in self.partitionCols]
        window_spec = window_spec.partitionBy(partition_exprs)
    return window_spec
Expand Down
52 changes: 52 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,58 @@ def test_asof_join(self):
# joined dataframe should equal the expected dataframe
self.assertDataFramesEqual(joined_df, dfExpected)

def test_sequence_number_sort(self):
    """As-of join using a sequence-number tie-breaker on the right TSDF.

    The right-hand frame contains two quotes with the identical timestamp
    2020-08-01 00:01:05 (seq_nb 2 and 3); with ``sequence_col="seq_nb"``
    the join must resolve the tie to the row with the higher sequence
    number (seq_nb 3, ask_pr 1000.13), as encoded in ``expected_data``.
    """
    # Left (trades) schema: no sequence column.
    leftSchema = StructType([StructField("symbol", StringType()),
                             StructField("event_ts", StringType()),
                             StructField("trade_pr", FloatType()),
                             StructField("trade_id", IntegerType())])

    # Right (quotes) schema: carries seq_nb, the tie-breaking sort key.
    rightSchema = StructType([StructField("symbol", StringType()),
                              StructField("event_ts", StringType()),
                              StructField("bid_pr", FloatType()),
                              StructField("ask_pr", FloatType()),
                              StructField("seq_nb", LongType())])

    # Expected output: left columns plus right-prefixed quote columns.
    expectedSchema = StructType([StructField("symbol", StringType()),
                                 StructField("event_ts", StringType()),
                                 StructField("trade_pr", FloatType()),
                                 StructField("trade_id", IntegerType()),
                                 StructField("right_event_ts", StringType()),
                                 StructField("right_bid_pr", FloatType()),
                                 StructField("right_ask_pr", FloatType()),
                                 StructField("right_seq_nb", LongType())])

    left_data = [["S1", "2020-08-01 00:00:10", 349.21, 1],
                 ["S1", "2020-08-01 00:01:12", 351.32, 2],
                 ["S1", "2020-09-01 00:02:10", 361.1, 3],
                 ["S1", "2020-09-01 00:19:12", 362.1, 4]]

    # Note: rows 2 and 3 share event_ts 00:01:05 and are deliberately
    # listed out of sequence order (seq_nb 3 before 2) to exercise the sort.
    right_data = [["S1", "2020-08-01 00:00:01", 345.11, 351.12, 1],
                  ["S1", "2020-08-01 00:01:05", 348.10, 1000.13, 3],
                  ["S1", "2020-08-01 00:01:05", 348.10, 100.13, 2],
                  ["S1", "2020-09-01 00:02:01", 358.93, 365.12, 4],
                  ["S1", "2020-09-01 00:15:01", 359.21, 365.31, 5]]

    # Second row must pick seq_nb 3 (ask_pr 1000.13), not seq_nb 2.
    expected_data = [
        ["S1", "2020-08-01 00:00:10", 349.21, 1, "2020-08-01 00:00:01", 345.11, 351.12, 1],
        ["S1", "2020-08-01 00:01:12", 351.32, 2, "2020-08-01 00:01:05", 348.10, 1000.13, 3],
        ["S1", "2020-09-01 00:02:10", 361.1, 3, "2020-09-01 00:02:01", 358.93, 365.12, 4],
        ["S1", "2020-09-01 00:19:12", 362.1, 4, "2020-09-01 00:15:01", 359.21, 365.31, 5]]

    # construct dataframes
    dfLeft = self.buildTestDF(leftSchema, left_data)
    dfRight = self.buildTestDF(rightSchema, right_data)
    dfExpected = self.buildTestDF(expectedSchema, expected_data, ["right_event_ts", "event_ts"])

    # perform the join
    tsdf_left = TSDF(dfLeft, partition_cols=["symbol"])
    tsdf_right = TSDF(dfRight, partition_cols=["symbol"], sequence_col="seq_nb")
    joined_df = tsdf_left.asofJoin(tsdf_right, right_prefix='right').df

    # joined dataframe should equal the expected dataframe
    self.assertDataFramesEqual(joined_df, dfExpected)

def test_partitioned_asof_join(self):
"""AS-OF Join with a time-partition"""
leftSchema = StructType([StructField("symbol", StringType()),
Expand Down

0 comments on commit 4b986e0

Please sign in to comment.