diff --git a/spinoffs/oryx/oryx/core/interpreters/propagate.py b/spinoffs/oryx/oryx/core/interpreters/propagate.py index 4f93ebdc3e..23d2b3c827 100644 --- a/spinoffs/oryx/oryx/core/interpreters/propagate.py +++ b/spinoffs/oryx/oryx/core/interpreters/propagate.py @@ -164,7 +164,7 @@ def write(self, var: VarOrLiteral, cell: Cell) -> Cell: if isinstance(var, jax_core.Literal): return cell cur_cell = self.read(var) - if var is jax_core.dropvar: + if isinstance(var, jax_core.DropVar): return cur_cell self.env[var] = cur_cell.join(cell) return self.env[var]