Skip to content

Commit

Permalink
Minor simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed May 14, 2023
1 parent c56cafb commit 8cc8ed3
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 42 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ Keras Core is a new multi-backend implementation of the Keras API, with support

Keras Core is intend to work as a drop-in replacement for `tf.keras` (when using the TensorFlow backend).

In addition, Keras models can consume datasets in any format, regardless of the backend you're using:
you can train your models with your existing tf.data.Dataset pipelines or Torch DataLoaders.

## Why use Keras Core?

- Write custom components (e.g. layers, models, metrics) that you can move across framework boundaries.
- Make your code future-proof by avoiding framework lock-in.
- As a JAX user: get access to a fully-featured modeling and training library.
- As a PyTorch user: get access to the real Keras, at last!
- As a PyTorch user: get access to power of Keras, at last!
- As a JAX user: get access to a fully-featured, battle-tested modeling and training library.

## Credits

Expand Down
31 changes: 3 additions & 28 deletions keras_core/layers/merging/dot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.merging.base_merge import Merge
from keras_core.utils.numerical_utils import normalize


def batch_dot(x, y, axes=None):
Expand Down Expand Up @@ -195,32 +196,6 @@ def batch_dot(x, y, axes=None):
return result


def l2_normalize(x, axis=None, epsilon=1e-9):
"""Normalizes along dimension `axis` using an L2 norm.
For a 1-D tensor with `axis = 0`, computes
`output = x / sqrt(max(sum(x**2), epsilon))`
For `x` with more dimensions, independently normalizes each
1-D slice along dimension `axis`.
Args:
x: Input tensor.
axis: Dimension along which to normalize. A scalar or a
vector of integers.
epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)`
as the divisor if `norm < sqrt(epsilon)`.
Returns:
A normalized tensor with the same shape as input.
"""

square_sum = ops.sum(ops.square(x), axis=axis, keepdims=True)
x_inv_norm = 1.0 / ops.sqrt(ops.maximum(square_sum, epsilon))
return ops.multiply(x, x_inv_norm)


@keras_core_export("keras_core.layers.Dot")
class Dot(Merge):
"""Computes element-wise dot product of two tensors.
Expand Down Expand Up @@ -340,8 +315,8 @@ def _merge_function(self, inputs):
axes.append(self.axes[i])

if self.normalize:
x1 = l2_normalize(x1, axis=axes[0])
x2 = l2_normalize(x2, axis=axes[1])
x1 = normalize(x1, axis=axes[0])
x2 = normalize(x2, axis=axes[1])
output = batch_dot(x1, x2, axes)
return output

Expand Down
13 changes: 5 additions & 8 deletions keras_core/layers/normalization/spectral_normalization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from keras_core import backend
from keras_core import initializers
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers import Wrapper
from keras_core.layers.input_spec import InputSpec
from keras_core.utils.numerical_utils import normalize


@keras_core_export("keras_core.layers.SpectralNormalization")
Expand Down Expand Up @@ -96,10 +98,10 @@ def normalize_weights(self):
# check for zeroes weights
if not all([w == 0.0 for w in weights]):
for _ in range(self.power_iterations):
vector_v = self._l2_normalize(
ops.matmul(vector_u, ops.transpose(weights))
vector_v = normalize(
ops.matmul(vector_u, ops.transpose(weights)), axis=None
)
vector_u = self._l2_normalize(ops.matmul(vector_v, weights))
vector_u = normalize(ops.matmul(vector_v, weights), axis=None)
# vector_u = tf.stop_gradient(vector_u)
# vector_v = tf.stop_gradient(vector_v)
sigma = ops.matmul(
Expand All @@ -113,11 +115,6 @@ def normalize_weights(self):
)
)

def _l2_normalize(self, x):
square_sum = ops.sum(ops.square(x), keepdims=True)
x_inv_norm = 1 / ops.sqrt(ops.maximum(square_sum, 1e-12))
return ops.multiply(x, x_inv_norm)

def get_config(self):
config = {"power_iterations": self.power_iterations}
base_config = super().get_config()
Expand Down
5 changes: 1 addition & 4 deletions pip_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
package = "keras_core"
build_directory = "tmp_build_dir"
dist_directory = "dist"
to_copy = [
"setup.py",
"README.md"
]
to_copy = ["setup.py", "README.md"]


def ignore_files(_, filenames):
Expand Down

0 comments on commit 8cc8ed3

Please sign in to comment.