Skip to content

Commit

Permalink
Change jax.core.DropVar to be a non-singleton.
Browse files Browse the repository at this point in the history
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.

Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.

PiperOrigin-RevId: 404071001
  • Loading branch information
hawkinsp authored and tensorflower-gardener committed Oct 18, 2021
1 parent e2ddad5 commit 03c21d1
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion spinoffs/oryx/oryx/core/interpreters/propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def write(self, var: VarOrLiteral, cell: Cell) -> Cell:
if isinstance(var, jax_core.Literal):
return cell
cur_cell = self.read(var)
if var is jax_core.dropvar:
if isinstance(var, jax_core.DropVar):
return cur_cell
self.env[var] = cur_cell.join(cell)
return self.env[var]
Expand Down

0 comments on commit 03c21d1

Please sign in to comment.