Skip to content

Commit

Permalink
[BugFix] Fix wrong implementation of first_value/last_value/lead/lag …
Browse files Browse the repository at this point in the history
…with ignore nulls (StarRocks#18614)

Signed-off-by: liuyehcf <[email protected]>
  • Loading branch information
liuyehcf authored Mar 6, 2023
1 parent 4397070 commit 2839abe
Showing 1 changed file with 84 additions and 21 deletions.
105 changes: 84 additions & 21 deletions be/src/exprs/agg/window.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,16 +331,18 @@ class FirstValueWindowFunction final : public ValueWindowFunction<LT, FirstValue
return;
}

this->data(state).has_value = true;

size_t value_index =
!ignoreNulls ? frame_start : ColumnHelper::find_nonnull(columns[0], frame_start, frame_end);
if (value_index == frame_end || columns[0]->is_null(value_index)) {
this->data(state).is_null = true;
if (!ignoreNulls) {
this->data(state).has_value = true;
}
} else {
const Column* data_column = ColumnHelper::get_data_column(columns[0]);
const InputColumnType* column = down_cast<const InputColumnType*>(data_column);
this->data(state).is_null = false;
this->data(state).has_value = true;
AggDataTypeTraits<LT>::assign_value(this->data(state).value,
AggDataTypeTraits<LT>::get_row_ref(*column, value_index));
}
Expand All @@ -358,7 +360,8 @@ template <LogicalType LT, bool ignoreNulls, typename = guard::Guard>
struct LastValueState {
using T = AggDataValueType<LT>;
T value;
bool is_null = ignoreNulls;
bool is_null = false;
bool has_value = false;
};

template <LogicalType LT, bool ignoreNulls, typename T = RunTimeCppType<LT>>
Expand All @@ -367,7 +370,8 @@ class LastValueWindowFunction final : public ValueWindowFunction<LT, LastValueSt

void reset(FunctionContext* ctx, const Columns& args, AggDataPtr __restrict state) const override {
this->data(state).value = {};
this->data(state).is_null = ignoreNulls;
this->data(state).is_null = false;
this->data(state).has_value = false;
}

void update_batch_single_state_with_frame(FunctionContext* ctx, AggDataPtr __restrict state, const Column** columns,
Expand All @@ -381,11 +385,17 @@ class LastValueWindowFunction final : public ValueWindowFunction<LT, LastValueSt
size_t value_index =
!ignoreNulls ? frame_end - 1 : ColumnHelper::last_nonnull(columns[0], frame_start, frame_end);
if (value_index == frame_end || columns[0]->is_null(value_index)) {
this->data(state).is_null = true;
if (ignoreNulls) {
this->data(state).is_null = (!this->data(state).has_value);
} else {
this->data(state).is_null = true;
this->data(state).has_value = true;
}
} else {
const Column* data_column = ColumnHelper::get_data_column(columns[0]);
const InputColumnType* column = down_cast<const InputColumnType*>(data_column);
this->data(state).is_null = false;
this->data(state).has_value = true;
AggDataTypeTraits<LT>::assign_value(this->data(state).value,
AggDataTypeTraits<LT>::get_row_ref(*column, value_index));
}
Expand All @@ -412,9 +422,23 @@ template <LogicalType LT, bool ignoreNulls, bool isLag, typename T = RunTimeCppT
class LeadLagWindowFunction final : public ValueWindowFunction<LT, LeadLagState<LT>, T> {
using InputColumnType = typename ValueWindowFunction<LT, FirstValueState<LT>, T>::InputColumnType;

mutable int64_t offset = 0;

void reset(FunctionContext* ctx, const Columns& args, AggDataPtr __restrict state) const override {
this->data(state).value = {};
this->data(state).is_null = false;

// get offset
const Column* arg1 = args[1].get();
DCHECK(arg1->is_constant());
const auto* offset_column = down_cast<const ConstColumn*>(arg1);
if (offset_column->is_nullable()) {
offset = 0;
} else {
offset = ColumnHelper::get_const_value<LogicalType::TYPE_BIGINT>(arg1);
}

// get default value
const Column* arg2 = args[2].get();
DCHECK(arg2->is_constant());
const auto* default_column = down_cast<const ConstColumn*>(arg2);
Expand All @@ -440,24 +464,53 @@ class LeadLagWindowFunction final : public ValueWindowFunction<LT, LeadLagState<
return;
}

if (!columns[0]->is_null(frame_end - 1)) {
this->data(state).is_null = false;
const Column* data_column = ColumnHelper::get_data_column(columns[0]);
const InputColumnType* column = down_cast<const InputColumnType*>(data_column);
AggDataTypeTraits<LT>::assign_value(this->data(state).value,
AggDataTypeTraits<LT>::get_row_ref(*column, frame_end - 1));
} else {
if (!ignoreNulls) {
this->data(state).is_null = true;
return;
// for lead/lag, [peer_group_start, peer_group_end] equals to [partition_start, partition_end]
// when lead/lag called, the whole partitoin's data has already been here, so we can just check all the way to the begining or the end
if (ignoreNulls) {
// lead(v1 ignore nulls, <offset>) has window `ROWS BETWEEN UNBOUNDED PRECEDING AND <offset> FOLLOWING`
// frame_start = partition_start
// frame_end = current_row + <offset> + 1
// current_row = frame_end - 1 - <offset>
//
// lag(v1 ignore nulls, <offset>) has window `ROWS BETWEEN UNBOUNDED PRECEDING AND <offset> PRECEDING`
// frame_start = partition_start
// frame_end = current_row - <offset> + 1
// current_row = frame_end - 1 + <offset>
int64_t current_row = frame_end - 1 + (isLag ? offset : -offset);
if (current_row < peer_group_start) {
current_row = peer_group_start;
} else if (current_row >= peer_group_end) {
current_row = peer_group_end - 1;
}

int64_t cnt = offset;
size_t value_index = current_row;
if (isLag) {
// Look backward, find <offset>-th non-null value
while (value_index > peer_group_start && cnt > 0) {
int64_t next_index = ColumnHelper::last_nonnull(columns[0], peer_group_start, value_index);
if (next_index == value_index) {
break;
}
value_index = next_index;
DCHECK_GE(value_index, peer_group_start);
cnt--;
}
} else {
// Look forward, find <offset>-th non-null value
while (value_index < peer_group_end && cnt > 0) {
int64_t next_index = ColumnHelper::find_nonnull(columns[0], value_index + 1, peer_group_end);
if (next_index == peer_group_end) {
break;
}
value_index = next_index;
DCHECK_LE(value_index, peer_group_end);
cnt--;
}
}
// for lead/lag, [peer_group_start, peer_group_end] equals to [partition_start, partition_end]
// when lead/lag called, the whole partitoin's data has already been here, so we can just check all the way to the begining or the end
size_t value_index = isLag ? ColumnHelper::last_nonnull(columns[0], peer_group_start, frame_end - 1)
: ColumnHelper::find_nonnull(columns[0], frame_end, peer_group_end);
DCHECK_LE(value_index, peer_group_end);
DCHECK_GE(value_index, peer_group_start);
if (value_index == peer_group_end || columns[0]->is_null(value_index)) {
DCHECK_LE(value_index, peer_group_end);
if (cnt > 0 || value_index == peer_group_end || columns[0]->is_null(value_index)) {
this->data(state).is_null = true;
} else {
const Column* data_column = ColumnHelper::get_data_column(columns[0]);
Expand All @@ -466,6 +519,16 @@ class LeadLagWindowFunction final : public ValueWindowFunction<LT, LeadLagState<
AggDataTypeTraits<LT>::assign_value(this->data(state).value,
AggDataTypeTraits<LT>::get_row_ref(*column, value_index));
}
} else {
if (!columns[0]->is_null(frame_end - 1)) {
this->data(state).is_null = false;
const Column* data_column = ColumnHelper::get_data_column(columns[0]);
const InputColumnType* column = down_cast<const InputColumnType*>(data_column);
AggDataTypeTraits<LT>::assign_value(this->data(state).value,
AggDataTypeTraits<LT>::get_row_ref(*column, frame_end - 1));
} else {
this->data(state).is_null = true;
}
}
}

Expand Down

0 comments on commit 2839abe

Please sign in to comment.