Skip to content

Commit

Permalink
Fix type errors with current mypy and NumPy.
Browse files Browse the repository at this point in the history
Enable type stubs for jaxlib.

Fix a nondeterminism problem in jax2tf tests.
  • Loading branch information
hawkinsp committed Jun 24, 2021
1 parent b7e9a0b commit d658108
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 28 deletions.
2 changes: 1 addition & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
return concatenate_p.bind(*operands, dimension=dimension)

Precision = xla_client.PrecisionConfig.Precision
Precision.__str__ = lambda precision: precision.name
Precision.__str__ = lambda precision: precision.name # type: ignore
PrecisionType = Any
PrecisionLike = Union[None, str, PrecisionType, Tuple[str, str],
Tuple[PrecisionType, PrecisionType]]
Expand Down
40 changes: 21 additions & 19 deletions jax/_src/scipy/optimize/_lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""
from typing import Callable, NamedTuple, Optional, Union
from typing import Any, Callable, NamedTuple, Optional, Union
from functools import partial

import jax
Expand All @@ -23,6 +23,8 @@
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)


Array = Any

class LBFGSResults(NamedTuple):
"""Results from L-BFGS optimization
Expand All @@ -47,32 +49,32 @@ class LBFGSResults(NamedTuple):
5 = line search failed
ls_status: integer describing the end status of the last line search
"""
converged: Union[bool, jnp.ndarray]
failed: Union[bool, jnp.ndarray]
k: Union[int, jnp.ndarray]
nfev: Union[int, jnp.ndarray]
ngev: Union[int, jnp.ndarray]
x_k: jnp.ndarray
f_k: jnp.ndarray
g_k: jnp.ndarray
s_history: jnp.ndarray
y_history: jnp.ndarray
rho_history: jnp.ndarray
gamma: Union[float, jnp.ndarray]
status: Union[int, jnp.ndarray]
ls_status: Union[int, jnp.ndarray]
converged: Union[bool, Array]
failed: Union[bool, Array]
k: Union[int, Array]
nfev: Union[int, Array]
ngev: Union[int, Array]
x_k: Array
f_k: Array
g_k: Array
s_history: Array
y_history: Array
rho_history: Array
gamma: Union[float, Array]
status: Union[int, Array]
ls_status: Union[int, Array]


def _minimize_lbfgs(
fun: Callable,
x0: jnp.ndarray,
maxiter: Optional[int] = None,
x0: Array,
maxiter: Optional[float] = None,
norm=jnp.inf,
maxcor: int = 10,
ftol: float = 2.220446049250313e-09,
gtol: float = 1e-05,
maxfun: Optional[int] = None,
maxgrad: Optional[int] = None,
maxfun: Optional[float] = None,
maxgrad: Optional[float] = None,
maxls: int = 20,
):
"""
Expand Down
6 changes: 3 additions & 3 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,9 +877,9 @@ def _outside_call_translation_rule(
[array_sharding_proto] * len(non_empty_flat_results_aval) +
[token_sharding_proto])

shape = tuple(shape.with_major_to_minor_layout_if_absent()
for x in non_empty_flat_results_aval
for shape in xla.aval_to_xla_shapes(x))
shape = [shape.with_major_to_minor_layout_if_absent()
for x in non_empty_flat_results_aval
for shape in xla.aval_to_xla_shapes(x)]

build_infeed = functools.partial(xops.InfeedWithToken,
after_outfeed_itoken,
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1599,6 +1599,7 @@ def gen_conv(lhs, rhs, preferred_element_type: Optional[DType]):
# Follow the lowering for complex convolutions from
# lax._conv_general_dilated_translation. We can use the same conversion on all
# platforms because on XLA:TPU the compiler does the same as a rewrite.
preferred_float_et: Optional[Any]
if np.issubdtype(_in_avals[0].dtype, np.complexfloating):
if preferred_element_type is not None:
# Convert complex dtype to types used for real and imaginary parts
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def _process_dim(i: int, dim_spec: Union[str, int]):
raise ValueError(msg)
return dim_size
# We have a dimension polynomial for a known dimension.
dim_var = dim_poly.to_var()
dim_var = dim_poly.to_var() # type: ignore
if dim_var is not None:
shape_var_map[dim_spec].add(dim_size) # type: ignore
return dim_poly
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/primitive_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def _make_integer_pow_harness(name, *, shape=(20, 30), dtype=np.int32, y=3):
y=y)


for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
for dtype in [d for d in jtu.dtypes.all if d not in jtu.dtypes.boolean]:
# Validate dtypes and y values for some special cases.
for y in range(-3, 5):
if np.issubdtype(dtype, np.integer) and y < 0:
Expand Down
2 changes: 1 addition & 1 deletion jax/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _check_jaxlib_version():
try:
from jaxlib import tpu_client as tpu_driver_client # pytype: disable=import-error
except:
tpu_driver_client = None
tpu_driver_client = None # type: ignore


cuda_path: Optional[str]
Expand Down
4 changes: 2 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ disable_error_code = attr-defined

[mypy-absl.*]
ignore_missing_imports = True
[mypy-jaxlib.*]
ignore_missing_imports = True
[mypy-numpy.*]
ignore_missing_imports = True
[mypy-opt_einsum.*]
Expand All @@ -18,3 +16,5 @@ ignore_errors = True
ignore_errors = True
[mypy-jax.experimental.jax2tf.tests.primitive_harness]
ignore_errors = True
[mypy-libtpu.*]
ignore_missing_imports = True

0 comments on commit d658108

Please sign in to comment.