Skip to content

Commit

Permalink
[query] Use valid globals reference in MWZJ and TABK (hail-is#14246)
Browse files Browse the repository at this point in the history
CHANGELOG: Fix a bug, introduced in 0.2.114, in which
`Table.multi_way_zip_join` and `Table.aggregate_by_key` could throw
"NoSuchElementException: Ref with name `__iruid_...`" when one or more
of the tables had a number of partitions substantially different from
the desired number of output partitions.

Fixes hail-is#14245.

In both MultiWayZipJoin and TableAggregateByKey, we repartition the
child but neglect to use the new globals `Ref` from the repartitioned
child. As long as `repartitionNoShuffle` does not create a new
TableStage with new globals, this is fine, but that is not, in general,
true. It seems that recently, in lowered backends, when the repartition
cost is deemed "high" we generate a fresh TableStage with a fresh
globals ref.
  • Loading branch information
danking authored Feb 2, 2024
1 parent d261554 commit d4679eb
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 31 deletions.
17 changes: 17 additions & 0 deletions hail/python/test/hail/table/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,23 @@ def test_multi_way_zip_join_key_downcast2(self):
ht = hl.Table.multi_way_zip_join(vcfs, 'data', 'new_globals')
assert exp_count == ht._force_count()

def test_multi_way_zip_join_highly_unbalanced_partitions__issue_14245(self):
def import_vcf(file: str, partitions: int):
return (
hl.import_vcf(file, force_bgz=True, reference_genome='GRCh38', min_partitions=partitions)
.rows()
.select()
)

hl.Table.multi_way_zip_join(
[
import_vcf(resource('gvcfs/HG00096.g.vcf.gz'), 100),
import_vcf(resource('gvcfs/HG00268.g.vcf.gz'), 1),
],
'data',
'new_globals',
).write(new_temp_file(extension='ht'))

def test_index_maintains_count(self):
t1 = hl.Table.parallelize(
[{'a': 'foo', 'b': 1}, {'a': 'bar', 'b': 2}, {'a': 'bar', 'b': 2}],
Expand Down
65 changes: 34 additions & 31 deletions hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1197,39 +1197,39 @@ object LowerTableIR {

case TableAggregateByKey(child, expr) =>
val loweredChild = lower(child)

loweredChild.repartitionNoShuffle(
val repartitioned = loweredChild.repartitionNoShuffle(
ctx,
loweredChild.partitioner.coarsen(child.typ.key.length).strictify(),
)
.mapPartition(Some(child.typ.key)) { partition =>
Let(
FastSeq("global" -> loweredChild.globals),
mapIR(StreamGroupByKey(partition, child.typ.key, missingEqual = true)) { groupRef =>
StreamAgg(
groupRef,
"row",
bindIRs(
ArrayRef(
ApplyAggOp(
FastSeq(I32(1)),
FastSeq(SelectFields(Ref("row", child.typ.rowType), child.typ.key)),
AggSignature(Take(), FastSeq(TInt32), FastSeq(child.typ.keyType)),
),
I32(0),
), // FIXME: would prefer a First() agg op
expr,
) { case Seq(key, value) =>
MakeStruct(child.typ.key.map(k =>
(k, GetField(key, k))
) ++ expr.typ.asInstanceOf[TStruct].fieldNames.map { f =>
(f, GetField(value, f))
})
},
)
},
)
}

repartitioned.mapPartition(Some(child.typ.key)) { partition =>
Let(
FastSeq("global" -> repartitioned.globals),
mapIR(StreamGroupByKey(partition, child.typ.key, missingEqual = true)) { groupRef =>
StreamAgg(
groupRef,
"row",
bindIRs(
ArrayRef(
ApplyAggOp(
FastSeq(I32(1)),
FastSeq(SelectFields(Ref("row", child.typ.rowType), child.typ.key)),
AggSignature(Take(), FastSeq(TInt32), FastSeq(child.typ.keyType)),
),
I32(0),
), // FIXME: would prefer a First() agg op
expr,
) { case Seq(key, value) =>
MakeStruct(child.typ.key.map(k =>
(k, GetField(key, k))
) ++ expr.typ.asInstanceOf[TStruct].fieldNames.map { f =>
(f, GetField(value, f))
})
},
)
},
)
}

case TableDistinct(child) =>
val loweredChild = lower(child)
Expand Down Expand Up @@ -2155,7 +2155,10 @@ object LowerTableIR {
)
val repartitioned = lowered.map(_.repartitionNoShuffle(ctx, newPartitioner))
val newGlobals = MakeStruct(FastSeq(
globalName -> MakeArray(lowered.map(_.globals), TArray(lowered.head.globalType))
globalName -> MakeArray(
repartitioned.map(_.globals),
TArray(repartitioned.head.globalType),
)
))
val globalsRef = Ref(genUID(), newGlobals.typ)

Expand Down

0 comments on commit d4679eb

Please sign in to comment.