Skip to content

Commit

Permalink
Support TF 2.12, drop Python 3.7 (adriangb#302)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb authored Jun 13, 2023
1 parent 32fc6af commit 427aab2
Show file tree
Hide file tree
Showing 28 changed files with 129 additions and 172 deletions.
18 changes: 11 additions & 7 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ jobs:
- uses: pre-commit/[email protected]

TestStable:
name: Ubuntu / Python ${{ matrix.python-version }} / TensorFlow Stable / Scikit-Learn Stable
name: Ubuntu / Python ${{ matrix.python-version }} / TensorFlow ${{ matrix.tensorflow-version }} / Scikit-Learn Stable
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
tensorflow-version: ["2.11", "2.12"]
fail-fast: false

steps:
Expand All @@ -52,7 +53,8 @@ jobs:
- name: Install Dependencies
run: |
poetry install -E tensorflow
poetry install
poetry run python -m pip install -U "tensorflow~=${{ matrix.tensorflow-version }}"
- name: Test with pytest
run: |
Expand Down Expand Up @@ -97,6 +99,7 @@ jobs:
- name: Test with pytest
if: always()
continue-on-error: true
run: |
poetry run python -m pip freeze
poetry run python -m pytest -v --cov=scikeras --cov-report xml --color=yes
Expand Down Expand Up @@ -145,8 +148,9 @@ jobs:
runs-on: ${{ matrix.os }}-latest
strategy:
matrix:
os: [MacOS, Windows] # test all OSs (except Ubuntu, which is already running other tests)
python-version: ["3.7", "3.10"] # test only the two extremes of supported Python versions
os: [MacOS, Windows] # test all OSs except Ubuntu, which is already running other tests
python-version: ["3.8", "3.11"] # test only the two extremes of supported Python versions
tensorflow-version: ["2.11", "2.12"] # test only the two extremes of supported TF versions
fail-fast: false

steps:
Expand All @@ -166,12 +170,12 @@ jobs:
- name: Install Dependencies
# TF is sorta dropping support for Windows
# At the very least they are no longer suppporting GPU
# See https://github.com/tensorflow/tensorflow/releases/tag/v2.11.0
# See https://github.com/tensorflow/tensorflow/releases/tag/v2.12.0
# and
# https://www.tensorflow.org/install/pip#windows-native
run: |
poetry install
poetry run pip install -U tensorflow-cpu
poetry run python -m pip install -U "tensorflow~=${{ matrix.tensorflow-version }}"
- name: Test with pytest
run: |
Expand Down
10 changes: 4 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
repos:
- repo: https://github.com/psf/black
rev: 22.3.0
rev: "23.3.0"
hooks:
- id: black

- repo: https://github.com/timothycrosley/isort
rev: 5.10.1
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.272
hooks:
- id: isort
additional_dependencies: [toml]
exclude: ^.*/?setup\.py$
- id: ruff
11 changes: 4 additions & 7 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
import os
import sys


sys.path.insert(0, "../../scikeras")

# -- Project information -----------------------------------------------------

from scikeras import __version__
from scikeras import __version__ # noqa: 402


project = "SciKeras"
Expand Down Expand Up @@ -109,11 +108,9 @@ def setup(app):

# Functionality to build github source URI, taken from sklearn.

import inspect
import subprocess

from operator import attrgetter

import inspect # noqa: 402
import subprocess # noqa: 402
from operator import attrgetter # noqa: E402

REVISION_CMD = "git rev-parse --short HEAD"

Expand Down
38 changes: 21 additions & 17 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ license = "MIT"
name = "scikeras"
readme = "README.md"
repository = "https://github.com/adriangb/scikeras"
version = "0.10.0"
version = "0.11.0"

[tool.poetry.dependencies]
importlib-metadata = {version = ">=3", python = "<3.8"}
python = ">=3.7.0,<3.11.0"
python = ">=3.8.0,<3.12.0"
scikit-learn = ">=1.0.0"
packaging = ">=0.21"
tensorflow = {version = ">=2.11.0", optional = true}
Expand All @@ -43,6 +43,11 @@ grpcio = { version = "<1.50.0", markers = "python_version >= '3.10' and sys_plat
tensorflow = ["tensorflow"]
tensorflow-cpu = ["tensorflow-cpu"]

[tool.poetry.dependencies.tensorflow-io-gcs-filesystem]
# see https://github.com/tensorflow/tensorflow/issues/60202
version = ">=0.23.1,<0.32"
markers = "sys_platform == 'win32'"

[tool.poetry.dev-dependencies]
tensorflow = ">=2.11.0"
coverage = {extras = ["toml"], version = ">=6.4.2"}
Expand All @@ -58,25 +63,24 @@ pytest = ">=7.1.2"
pytest-cov = ">=3.0.0"
sphinx = ">=5.0.2"

[tool.isort]
atomic = true
filter_files = true
include_trailing_comma = true
known_first_party = "scikeras"
known_third_party = [
"tensorflow",
"sklearn",
[tool.ruff]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"C", # flake8-comprehensions
"B", # flake8-bugbear
]
ignore = [
"E501", # line too long, handled by black
"C901", # too complex
"B905", # strict argument to zip
]
line_length = 88
lines_after_imports = 2
lines_between_types = 1
multi_line_output = 3
skip_glob = ["*/setup.py"]
use_parentheses = true

[tool.black]
line-length = 88
target-version = ['py36', 'py38']
target-version = ['py38']

[tool.coverage.run]
source = ["scikeras/"]
Expand Down
12 changes: 5 additions & 7 deletions scikeras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,19 @@
MIN_TF_VERSION = "2.7.0"
TF_VERSION_ERR = f"SciKeras requires TensorFlow >= {MIN_TF_VERSION}."

from packaging import version

from packaging import version # noqa: E402

try:
from tensorflow import __version__ as tf_version
except ImportError: # pragma: no cover
raise ImportError("TensorFlow is not installed. " + TF_VERSION_ERR)
raise ImportError("TensorFlow is not installed. " + TF_VERSION_ERR) from None
else:
if version.parse(tf_version) < version.parse(MIN_TF_VERSION): # pragma: no cover
raise ImportError(TF_VERSION_ERR)

import tensorflow.keras as _keras
raise ImportError(TF_VERSION_ERR) from None

from scikeras import _saving_utils
import tensorflow.keras as _keras # noqa: E402

from scikeras import _saving_utils # noqa: E402

_keras.Model.__reduce__ = _saving_utils.pack_keras_model
_keras.Model.__deepcopy__ = _saving_utils.deepcopy_model
Expand Down
4 changes: 1 addition & 3 deletions scikeras/_saving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
import shutil
import tarfile
import tempfile

from contextlib import contextmanager
from io import BytesIO
from typing import Any, Callable, ContextManager, Dict, Hashable, Iterator, List, Tuple
from typing import Any, Callable, Dict, Hashable, Iterator, List, Tuple
from uuid import uuid4

import numpy as np
import tensorflow.keras as keras

from tensorflow import io as tf_io
from tensorflow.keras.models import load_model

Expand Down
12 changes: 5 additions & 7 deletions scikeras/_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import inspect

from types import FunctionType
from typing import Any, Callable, Dict, Iterable, Mapping, Sequence, Type, Union

from tensorflow.keras import losses as losses_mod
from tensorflow.keras import metrics as metrics_mod
from tensorflow.keras import optimizers as optimizers_mod


DIGITS = frozenset(str(i) for i in range(10))


Expand Down Expand Up @@ -37,7 +35,7 @@ def route_params(
Dict[str, Any]
Filtered parameters, with any routing prefixes removed.
"""
res = dict()
res = {}
routed = {k: v for k, v in params.items() if "__" in k}
non_routed = {k: params[k] for k in (params.keys() - routed.keys())}
for key, val in non_routed.items():
Expand Down Expand Up @@ -92,7 +90,7 @@ def unflatten_params(items, params, base_params=None):
if inspect.isclass(items):
item = items
new_base_params = {p: v for p, v in params.items() if "__" not in p}
base_params = base_params or dict()
base_params = base_params or {}
args_and_kwargs = {**base_params, **new_base_params}
for p, v in args_and_kwargs.items():
args_and_kwargs[p] = unflatten_params(
Expand All @@ -110,7 +108,7 @@ def unflatten_params(items, params, base_params=None):
return item(*args, **kwargs)
if isinstance(items, (list, tuple)):
iter_type_ = type(items)
res = list()
res = []
new_base_params = {p: v for p, v in params.items() if "__" not in p}
for idx, item in enumerate(items):
item_params = route_params(
Expand All @@ -126,7 +124,7 @@ def unflatten_params(items, params, base_params=None):
)
return iter_type_(res)
if isinstance(items, (dict,)):
res = dict()
res = {}
new_base_params = {p: v for p, v in params.items() if "__" not in p}
for key, item in items.items():
item_params = route_params(
Expand All @@ -144,7 +142,7 @@ def unflatten_params(items, params, base_params=None):
# non-compilable item, check if it has any routed parameters
item = items
new_base_params = {p: v for p, v in params.items() if "__" not in p}
base_params = base_params or dict()
base_params = base_params or {}
kwargs = {**base_params, **new_base_params}
if kwargs:
raise TypeError(
Expand Down
8 changes: 6 additions & 2 deletions scikeras/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def loss_name(loss: Union[str, Loss, Callable]) -> str:
fn_or_cls = keras_loss_get(loss)
if isinstance(fn_or_cls, Loss):
return _camel2snake(fn_or_cls.__class__.__name__)
return fn_or_cls.__name__
if hasattr(fn_or_cls, "__name__"):
return fn_or_cls.__name__
return fn_or_cls


def metric_name(metric: Union[str, Metric, Callable]) -> str:
Expand Down Expand Up @@ -109,4 +111,6 @@ def metric_name(metric: Union[str, Metric, Callable]) -> str:
fn_or_cls = keras_metric_get(metric)
if isinstance(fn_or_cls, Metric):
return _camel2snake(fn_or_cls.__class__.__name__)
return fn_or_cls.__name__
if hasattr(fn_or_cls, "__name__"):
return fn_or_cls.__name__
return fn_or_cls
3 changes: 0 additions & 3 deletions scikeras/utils/random_state.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import os
import random

from contextlib import contextmanager
from typing import Generator

import numpy as np
import tensorflow as tf

from tensorflow.python.eager import context
from tensorflow.python.framework import config, ops


DIGITS = frozenset(str(i) for i in range(10))


Expand Down
1 change: 0 additions & 1 deletion scikeras/utils/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import tensorflow as tf

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.exceptions import NotFittedError
from sklearn.pipeline import make_pipeline
Expand Down
11 changes: 4 additions & 7 deletions scikeras/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
import functools
import inspect
import warnings

from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Mapping, Set, Tuple, Type, Union

import numpy as np
import tensorflow as tf

from scipy.sparse import isspmatrix, lil_matrix
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.exceptions import NotFittedError
Expand Down Expand Up @@ -219,7 +217,6 @@ def __init__(
epochs: int = 1,
**kwargs,
):

# Parse hardcoded params
self.model = model
self.build_fn = build_fn
Expand Down Expand Up @@ -300,7 +297,8 @@ def _check_model_param(self):
model = build_fn
warnings.warn(
"``build_fn`` will be renamed to ``model`` in a future release,"
" at which point use of ``build_fn`` will raise an Error instead."
" at which point use of ``build_fn`` will raise an Error instead.",
stacklevel=4,
)
if model is None:
# no model, use this class' _keras_build_fn
Expand Down Expand Up @@ -515,7 +513,7 @@ def _fit_keras_model(
except AttributeError:
raise ValueError(
f"`{bs_kwarg}=-1` requires that `X` implement `shape`"
)
) from None
fit_args = {k: v for k, v in fit_args.items() if not k.startswith("callbacks")}
fit_args["callbacks"] = self._fit_callbacks

Expand Down Expand Up @@ -828,7 +826,6 @@ def initialize(destination: str):
def _initialize(
self, X: np.ndarray, y: Union[np.ndarray, None] = None
) -> Tuple[np.ndarray, np.ndarray]:

# Handle random state
if isinstance(self.random_state, np.random.RandomState):
# Keras needs an integer
Expand Down Expand Up @@ -1016,7 +1013,7 @@ def _predict_raw(self, X, **kwargs):
except AttributeError:
raise ValueError(
"`batch_size=-1` requires that `X` implement `shape`"
)
) from None

# predict with Keras model
y_pred = self.model_.predict(x=X, **pred_args)
Expand Down
1 change: 0 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pytest
import tensorflow as tf


# Force data conversion warnings to be come errors
pytestmark = pytest.mark.filterwarnings(
"error::sklearn.exceptions.DataConversionWarning"
Expand Down
2 changes: 0 additions & 2 deletions tests/mlp_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

from scikeras.wrappers import KerasRegressor


def dynamic_classifier(
hidden_layer_sizes=(10,),
Expand Down
Loading

0 comments on commit 427aab2

Please sign in to comment.