Convert Keras Core to Keras 3.
fchollet committed Sep 22, 2023
1 parent 10d0349 commit b9be76a
Showing 700 changed files with 5,823 additions and 5,823 deletions.
22 changes: 11 additions & 11 deletions README.md
@@ -1,9 +1,9 @@
-[![](https://github.com/keras-team/keras-core/workflows/Tests/badge.svg?branch=main)](https://github.com/keras-team/keras-core/actions?query=workflow%3ATests+branch%3Amain)
-[![](https://badge.fury.io/py/keras-core.svg)](https://badge.fury.io/py/keras-core)
+[![](https://github.com/keras-team/keras/workflows/Tests/badge.svg?branch=main)](https://github.com/keras-team/keras/actions?query=workflow%3ATests+branch%3Amain)
+[![](https://badge.fury.io/py/keras.svg)](https://badge.fury.io/py/keras)

-# Keras Core: A new multi-backend Keras
+# Keras 3: A new multi-backend Keras

-Keras Core is a new multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
+Keras 3 is a new multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.

**WARNING:** At this time, this package is experimental.
It has rough edges and not everything might work as expected.
@@ -13,7 +13,7 @@ Once ready, this package will become Keras 3.0 and subsume `tf.keras`.

## Local installation

-Keras Core is compatible with Linux and MacOS systems. To install a local development version:
+Keras 3 is compatible with Linux and MacOS systems. To install a local development version:

1. Install dependencies:

@@ -28,7 +28,7 @@ python pip_build.py --install
```

You should also install your backend of choice: `tensorflow`, `jax`, or `torch`.
-Note that `tensorflow` is required for using certain Keras Core features: certain preprocessing layers as
+Note that `tensorflow` is required for using certain Keras 3 features: certain preprocessing layers as
well as `tf.data` pipelines.

## Configuring your backend
@@ -46,16 +46,16 @@ In Colab, you can do:
import os
os.environ["KERAS_BACKEND"] = "jax"

-import keras_core as keras
+import keras as keras
```

-**Note:** The backend must be configured before importing `keras_core`, and the backend cannot be changed after
+**Note:** The backend must be configured before importing `keras`, and the backend cannot be changed after
the package has been imported.

## Backwards compatibility

-Keras Core is intended to work as a drop-in replacement for `tf.keras` (when using the TensorFlow backend). Just take your
-existing `tf.keras` code, change the `keras` imports to `keras_core`, make sure that your calls to `model.save()` are using
+Keras 3 is intended to work as a drop-in replacement for `tf.keras` (when using the TensorFlow backend). Just take your
+existing `tf.keras` code, change the `keras` imports to `keras`, make sure that your calls to `model.save()` are using
the up-to-date `.keras` format, and you're done.

If your `tf.keras` model does not include custom components, you can start running it on top of JAX or PyTorch immediately.
@@ -66,7 +66,7 @@ to a backend-agnostic implementation in just a few minutes.
In addition, Keras models can consume datasets in any format, regardless of the backend you're using:
you can train your models with your existing `tf.data.Dataset` pipelines or PyTorch `DataLoaders`.

-## Why use Keras Core?
+## Why use Keras 3?

- Run your high-level Keras workflows on top of any framework -- benefiting at will from the advantages of each framework,
e.g. the scalability and performance of JAX or the production ecosystem options of TensorFlow.
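To make the migration concrete, here is a minimal end-to-end sketch (not part of this commit) combining the backend-configuration and drop-in-replacement steps from the README hunks above; the toy model and data are hypothetical:

```python
import os

# The backend must be set before `keras` is imported:
# "tensorflow", "jax", or "torch".
os.environ["KERAS_BACKEND"] = "jax"

import numpy as np

import keras  # previously: import keras_core as keras

# Hypothetical toy model: existing tf.keras-style code runs unchanged
# once the import is switched.
model = keras.Sequential(
    [
        keras.layers.Input(shape=(4,)),
        keras.layers.Dense(8, activation="relu"),
        keras.layers.Dense(1),
    ]
)
model.compile(loss="mse", optimizer="sgd")
model.fit(np.random.rand(32, 4), np.random.rand(32, 1), epochs=1)
model.save("model.keras")  # the up-to-date `.keras` format
```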
2 changes: 1 addition & 1 deletion benchmarks/layer_benchmark/README.md
@@ -1,7 +1,7 @@
# Benchmark the layer performance

This directory contains benchmarks to compare the performance of
-`keras_core.layers.XXX` and `tf.keras.layers.XXX`. We compare the performance of
+`keras.layers.XXX` and `tf.keras.layers.XXX`. We compare the performance of
both the forward pass and train step (forward & backward pass).

To run the benchmark, use the command below and change the flags according to
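The example command itself is collapsed in this view. Purely as an illustration, the harness defined in base_benchmark.py (next file) can also be driven directly from Python; the `Dense` layer and its arguments here are hypothetical:

```python
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark

# Compare keras.layers.Dense against tf.keras.layers.Dense.
benchmark = LayerBenchmark(
    layer_name="Dense",
    init_args={"units": 64},
    input_shape=[256],
    jit_compile=True,
)
benchmark.benchmark_predict(num_samples=2048, batch_size=256)  # forward pass
benchmark.benchmark_train(num_samples=2048, batch_size=256)  # forward & backward
```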
56 changes: 28 additions & 28 deletions benchmarks/layer_benchmark/base_benchmark.py
@@ -4,7 +4,7 @@
import tensorflow as tf
from absl import flags

-import keras_core
+import keras

FLAGS = flags.FLAGS

@@ -66,7 +66,7 @@ def on_predict_batch_end(self, batch, logs=None):
self.state["throughput"] = throughput


-class KerasCoreBenchmarkMetricsCallback(keras_core.callbacks.Callback):
+class KerasCoreBenchmarkMetricsCallback(keras.callbacks.Callback):
def __init__(self, start_batch=1, stop_batch=None):
self._callback = BenchmarkMetricsCallback(start_batch, stop_batch)

@@ -108,36 +108,36 @@ def __init__(
input_shape,
flat_call_inputs=True,
jit_compile=True,
-keras_core_layer=None,
+keras_layer=None,
tf_keras_layer=None,
):
self.layer_name = layer_name
-_keras_core_layer_class = getattr(keras_core.layers, layer_name)
+_keras_layer_class = getattr(keras.layers, layer_name)
_tf_keras_layer_class = getattr(tf.keras.layers, layer_name)

-if keras_core_layer is None:
-# Sometimes you want to initialize the keras_core layer and tf_keras
+if keras_layer is None:
+# Sometimes you want to initialize the keras layer and tf_keras
# layer in a different way. For example, `Bidirectional` layer,
-# which takes in `keras_core.layers.Layer` and
+# which takes in `keras.layers.Layer` and
# `tf.keras.layer.Layer` separately.
-self._keras_core_layer = _keras_core_layer_class(**init_args)
+self._keras_layer = _keras_layer_class(**init_args)
else:
-self._keras_core_layer = keras_core_layer
+self._keras_layer = keras_layer

if tf_keras_layer is None:
self._tf_keras_layer = _tf_keras_layer_class(**init_args)
else:
self._tf_keras_layer = tf_keras_layer

self.input_shape = input_shape
-self._keras_core_model = self._build_keras_core_model(
+self._keras_model = self._build_keras_model(
input_shape, flat_call_inputs
)
self._tf_keras_model = self._build_tf_keras_model(
input_shape, flat_call_inputs
)

-self._keras_core_model.compile(
+self._keras_model.compile(
loss="mse", optimizer="sgd", jit_compile=jit_compile
)
self._tf_keras_model.compile(
@@ -148,19 +148,19 @@ def __init__(
self.jit_compile = jit_compile
self.input_shape = input_shape

-def _build_keras_core_model(self, input_shape, flat_call_inputs=True):
+def _build_keras_model(self, input_shape, flat_call_inputs=True):
inputs = []
if not isinstance(input_shape[0], (tuple, list)):
input_shape = [input_shape]

for shape in input_shape:
-inputs.append(keras_core.Input(shape=shape))
+inputs.append(keras.Input(shape=shape))

if flat_call_inputs:
-outputs = self._keras_core_layer(*inputs)
+outputs = self._keras_layer(*inputs)
else:
-outputs = self._keras_core_layer(inputs)
-return keras_core.Model(inputs=inputs, outputs=outputs)
+outputs = self._keras_layer(inputs)
+return keras.Model(inputs=inputs, outputs=outputs)

def _build_tf_keras_model(self, input_shape, flat_call_inputs=True):
inputs = []
@@ -195,7 +195,7 @@ def benchmark_predict(self, num_samples, batch_size, data=None):
stop_batch=num_iterations
)

-self._keras_core_model.predict(
+self._keras_model.predict(
data,
batch_size=batch_size,
callbacks=[callback],
@@ -207,15 +207,15 @@ def benchmark_predict(self, num_samples, batch_size, data=None):
callbacks=[tf_keras_callback],
)

-keras_core_throughput = (
+keras_throughput = (
callback._callback.state["throughput"] * batch_size
)
tf_keras_throughput = (
tf_keras_callback._callback.state["throughput"] * batch_size
)
print(
f"Keras Core throughput of forward pass of {self.layer_name}: "
f"{keras_core_throughput:.2f} samples/sec."
f"Keras 3 throughput of forward pass of {self.layer_name}: "
f"{keras_throughput:.2f} samples/sec."
)
print(
f"TF Keras throughput of forward pass of {self.layer_name}: "
@@ -240,15 +240,15 @@ def benchmark_train(self, num_samples, batch_size, data=None, label=None):
if self.flat_call_inputs:
# Scale by a small factor to avoid zero gradients.
label = (
-keras_core.backend.convert_to_numpy(
-self._keras_core_layer(*data)
+keras.backend.convert_to_numpy(
+self._keras_layer(*data)
)
* 1.001
)
else:
label = (
-keras_core.backend.convert_to_numpy(
-self._keras_core_layer(data)
+keras.backend.convert_to_numpy(
+self._keras_layer(data)
)
* 1.001
)
@@ -259,7 +259,7 @@ def benchmark_train(self, num_samples, batch_size, data=None, label=None):
stop_batch=num_iterations
)

-self._keras_core_model.fit(
+self._keras_model.fit(
data,
label,
batch_size=batch_size,
@@ -272,15 +272,15 @@ def benchmark_train(self, num_samples, batch_size, data=None, label=None):
callbacks=[tf_keras_callback],
)

-keras_core_throughput = (
+keras_throughput = (
callback._callback.state["throughput"] * batch_size
)
tf_keras_throughput = (
tf_keras_callback._callback.state["throughput"] * batch_size
)
print(
f"Keras Core throughput of forward & backward pass of "
f"{self.layer_name}: {keras_core_throughput:.2f} samples/sec."
f"Keras 3 throughput of forward & backward pass of "
f"{self.layer_name}: {keras_throughput:.2f} samples/sec."
)
print(
f"TF Keras throughput of forward & backward pass of "
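The `__init__` comment above notes that wrapper layers such as `Bidirectional` need their Keras and tf.keras sublayers built separately; here is a minimal sketch of that case, mirroring rnn_benchmark.py below:

```python
import tensorflow as tf

import keras
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark

# The wrapped sublayer must come from the matching framework, so both
# wrappers are constructed explicitly instead of via init_args.
benchmark = LayerBenchmark(
    "Bidirectional",
    init_args={},
    input_shape=[256, 256],
    keras_layer=keras.layers.Bidirectional(keras.layers.LSTM(32)),
    tf_keras_layer=tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
)
```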
14 changes: 7 additions & 7 deletions benchmarks/layer_benchmark/rnn_benchmark.py
@@ -16,7 +16,7 @@
from absl import app
from absl import flags

-import keras_core
+import keras
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark

FLAGS = flags.FLAGS
@@ -194,16 +194,16 @@ def benchmark_bidirectional(
):
layer_name = "Bidirectional"
init_args = {}
-keras_core_layer = keras_core.layers.Bidirectional(
-keras_core.layers.LSTM(32)
+keras_layer = keras.layers.Bidirectional(
+keras.layers.LSTM(32)
)
tf_keras_layer = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32))
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
-keras_core_layer=keras_core_layer,
+keras_layer=keras_layer,
tf_keras_layer=tf_keras_layer,
)

@@ -225,8 +225,8 @@ def benchmark_time_distributed(
):
layer_name = "TimeDistributed"
init_args = {}
-keras_core_layer = keras_core.layers.TimeDistributed(
-keras_core.layers.Conv2D(16, (3, 3))
+keras_layer = keras.layers.TimeDistributed(
+keras.layers.Conv2D(16, (3, 3))
)
tf_keras_layer = tf.keras.layers.TimeDistributed(
tf.keras.layers.Conv2D(16, (3, 3))
@@ -236,7 +236,7 @@ def benchmark_time_distributed(
init_args,
input_shape=[10, 32, 32, 3],
jit_compile=jit_compile,
-keras_core_layer=keras_core_layer,
+keras_layer=keras_layer,
tf_keras_layer=tf_keras_layer,
)

4 changes: 2 additions & 2 deletions benchmarks/model_benchmark/benchmark_utils.py
@@ -1,9 +1,9 @@
import time

-import keras_core
+import keras


-class BenchmarkMetricsCallback(keras_core.callbacks.Callback):
+class BenchmarkMetricsCallback(keras.callbacks.Callback):
def __init__(self, start_batch=1, stop_batch=None):
self.start_batch = start_batch
self.stop_batch = stop_batch
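The class body is collapsed in this view; the following is only a plausible reconstruction (an assumption, not this commit's code) of how the `state["throughput"]` value consumed in base_benchmark.py could be produced:

```python
import time

import keras


class BenchmarkMetricsCallback(keras.callbacks.Callback):
    """Assumed implementation: batches/sec between start_batch and stop_batch."""

    def __init__(self, start_batch=1, stop_batch=None):
        self.start_batch = start_batch
        self.stop_batch = stop_batch
        self.state = {}

    def on_predict_batch_begin(self, batch, logs=None):
        if batch == self.start_batch:
            self.state["start_time"] = time.time()

    def on_predict_batch_end(self, batch, logs=None):
        if batch == self.stop_batch:
            elapsed = time.time() - self.state["start_time"]
            # Batches/sec; callers multiply by batch_size for samples/sec.
            num_batches = self.stop_batch - self.start_batch + 1
            self.state["throughput"] = num_batches / elapsed
```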
2 changes: 1 addition & 1 deletion benchmarks/model_benchmark/bert_benchmark.py
@@ -21,7 +21,7 @@
from absl import logging
from model_benchmark.benchmark_utils import BenchmarkMetricsCallback

-import keras_core as keras
+import keras as keras

flags.DEFINE_string("model_size", "small", "The size of model to benchmark.")
flags.DEFINE_string(
@@ -27,7 +27,7 @@
from absl import logging
from model_benchmark.benchmark_utils import BenchmarkMetricsCallback

-import keras_core as keras
+import keras as keras

flags.DEFINE_string("model", "EfficientNetV2B0", "The model to benchmark.")
flags.DEFINE_integer("epochs", 1, "The number of epochs.")
12 changes: 6 additions & 6 deletions benchmarks/torch_ctl_benchmark/conv_model_benchmark.py
@@ -9,9 +9,9 @@
import torch.nn as nn
import torch.optim as optim

-import keras_core
+import keras
from benchmarks.torch_ctl_benchmark.benchmark_utils import train_loop
-from keras_core import layers
+from keras import layers

num_classes = 2
input_shape = (3, 256, 256)
@@ -55,8 +55,8 @@ def forward(self, x):
return x


-def run_keras_core_custom_training_loop():
-keras_model = keras_core.Sequential(
+def run_keras_custom_training_loop():
+keras_model = keras.Sequential(
[
layers.Input(shape=input_shape),
layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
@@ -74,7 +74,7 @@ def run_keras_core_custom_training_loop():
num_epochs=num_epochs,
optimizer=optimizer,
loss_fn=loss_fn,
framework="keras_core",
framework="keras",
)


@@ -93,5 +93,5 @@ def run_torch_custom_training_loop():


if __name__ == "__main__":
-run_keras_core_custom_training_loop()
+run_keras_custom_training_loop()
run_torch_custom_training_loop()
12 changes: 6 additions & 6 deletions benchmarks/torch_ctl_benchmark/dense_model_benchmark.py
@@ -9,9 +9,9 @@
import torch.nn as nn
import torch.optim as optim

-import keras_core
+import keras
from benchmarks.torch_ctl_benchmark.benchmark_utils import train_loop
-from keras_core import layers
+from keras import layers

num_classes = 2
input_shape = (8192,)
@@ -55,8 +55,8 @@ def forward(self, x):
return x


-def run_keras_core_custom_training_loop():
-keras_model = keras_core.Sequential(
+def run_keras_custom_training_loop():
+keras_model = keras.Sequential(
[
layers.Input(shape=input_shape),
layers.Dense(64, activation="relu"),
@@ -73,7 +73,7 @@ def run_keras_core_custom_training_loop():
num_epochs=num_epochs,
optimizer=optimizer,
loss_fn=loss_fn,
framework="keras_core",
framework="keras",
)


@@ -92,5 +92,5 @@ def run_torch_custom_training_loop():


if __name__ == "__main__":
-run_keras_core_custom_training_loop()
+run_keras_custom_training_loop()
run_torch_custom_training_loop()
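`train_loop` itself is imported from benchmarks/torch_ctl_benchmark/benchmark_utils.py and is not shown here; below is a plausible sketch (an assumption, not this commit's code) that relies on the fact that, with the torch backend, a Keras model also behaves as a `torch.nn.Module`, so a single loop can serve both frameworks:

```python
def train_loop(model, data, num_epochs, optimizer, loss_fn, framework):
    # Hypothetical signature and body; `data` is assumed to yield
    # (inputs, targets) batches of torch tensors.
    for _ in range(num_epochs):
        for inputs, targets in data:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
    print(f"[{framework}] final loss: {float(loss):.4f}")
```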
