Merge branch 'keras-team:master' into master
doncarlos999 authored Jan 9, 2025
2 parents 4007769 + 97c1c00 commit 34954b6
Showing 23 changed files with 502 additions and 121 deletions.
8 changes: 6 additions & 2 deletions README.md
@@ -1,6 +1,6 @@
# Keras 3: Deep Learning for Humans

-Keras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlow, and PyTorch.
+Keras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlow, PyTorch, and OpenVINO (for inference-only).
Effortlessly build and train models for computer vision, natural language processing, audio processing,
timeseries forecasting, recommender systems, etc.

@@ -73,7 +73,7 @@ python pip_build.py --install
## Configuring your backend

You can export the environment variable `KERAS_BACKEND` or you can edit your local config file at `~/.keras/keras.json`
-to configure your backend. Available backend options are: `"tensorflow"`, `"jax"`, `"torch"`. Example:
+to configure your backend. Available backend options are: `"tensorflow"`, `"jax"`, `"torch"`, `"openvino"`. Example:

```
export KERAS_BACKEND="jax"
@@ -91,6 +91,10 @@ import keras
**Note:** The backend must be configured before importing `keras`, and the backend cannot be changed after
the package has been imported.

+**Note:** The OpenVINO backend is an inference-only backend, meaning it is designed only for running model
+predictions using the `model.predict()` method.
+To use the `openvino` backend, install the required dependencies from the `requirements-openvino.txt` file.

## Backwards compatibility

Keras 3 is intended to work as a drop-in replacement for `tf.keras` (when using the TensorFlow backend). Just take your
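As an aside, here is a minimal sketch of the workflow the new OpenVINO note describes. The model file name and input shape are illustrative; it assumes the `requirements-openvino.txt` dependencies are installed and a trained Keras model is available.

```
import os

os.environ["KERAS_BACKEND"] = "openvino"  # must be set before importing keras

import numpy as np
import keras

# Load a previously trained model and run inference only; training APIs such
# as fit() are not supported on the inference-only OpenVINO backend.
model = keras.saving.load_model("my_model.keras")  # illustrative file name
preds = model.predict(np.random.rand(1, 28, 28))   # shape must match the model
print(preds.shape)
```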
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
@@ -152,6 +152,9 @@
    MaxNumBoundingBoxes,
)
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
+from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
+    RandAugment,
+)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
    RandomBrightness,
)
3 changes: 3 additions & 0 deletions keras/api/layers/__init__.py
@@ -152,6 +152,9 @@
    MaxNumBoundingBoxes,
)
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
+from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
+    RandAugment,
+)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
    RandomBrightness,
)
12 changes: 8 additions & 4 deletions keras/src/backend/jax/distribution_lib.py
@@ -100,7 +100,7 @@ def distribute_tensor(tensor, layout):
    return global_value


-def distribute_data_input(per_process_batch, layout):
+def distribute_data_input(per_process_batch, layout, batch_dim_name):
    """Distribute the input data with the corresponding layout.

    Note that the inputs here is a local worker batch. Within the local worker,
@@ -117,9 +117,13 @@ def distribute_data_input(per_process_batch, layout):
    if not isinstance(layout, jax.sharding.Sharding):
        layout = _to_jax_layout(layout)

-    mesh_shape = list(layout.mesh.shape.values())
-    num_model_replicas_total = mesh_shape[0]  # batch dimension of the mesh
-    mesh_model_dim_size = mesh_shape[1] if len(mesh_shape) > 1 else 1
+    num_model_replicas_total = layout.mesh.shape[batch_dim_name]
+
+    mesh_model_dim_size = 1
+    for name, dim_size in layout.mesh.shape.items():
+        if not name == batch_dim_name:
+            mesh_model_dim_size *= dim_size

    num_model_replicas_per_process = num_model_replicas_total / num_processes()
    per_process_batch_size = per_process_batch.shape[0]
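For intuition, the new replica computation above can be illustrated with a plain dictionary standing in for `layout.mesh.shape` (hypothetical values, not part of the diff): the batch axis gives the total number of model replicas, and the product of every other mesh axis gives the model-parallel dimension size.

```
# Hypothetical 8-device mesh: 4-way data parallel, 2-way model parallel.
mesh_shape = {"batch": 4, "model": 2}  # stand-in for layout.mesh.shape
batch_dim_name = "batch"

num_model_replicas_total = mesh_shape[batch_dim_name]  # -> 4

mesh_model_dim_size = 1
for name, dim_size in mesh_shape.items():
    if name != batch_dim_name:
        mesh_model_dim_size *= dim_size  # -> 2

print(num_model_replicas_total, mesh_model_dim_size)  # 4 2
```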
4 changes: 3 additions & 1 deletion keras/src/backend/jax/distribution_lib_test.py
@@ -337,7 +337,9 @@ def test_distribute_data_input(self):
            mesh, jax.sharding.PartitionSpec("batch", None)
        )

-        result = backend_dlib.distribute_data_input(per_process_batch, layout)
+        result = backend_dlib.distribute_data_input(
+            per_process_batch, layout, "batch"
+        )

        # Check the shape of the global batch array
        self.assertEqual(
2 changes: 1 addition & 1 deletion keras/src/backend/jax/export.py
@@ -119,7 +119,7 @@ def stateful_fn(*args, **kwargs):
                self._tf_trackable.non_trainable_variables,
                non_trainable_variables,
            ):
-                var.assign(new_value)
+                var.assign(tf.cast(new_value, var.dtype))
            return output

        stateful_fn.__signature__ = inspect.Signature(
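For context, a small standalone sketch (not from the diff) of why the added cast helps: assigning a tensor of a different dtype to a `tf.Variable` typically raises a dtype-mismatch error, so casting to the variable's own dtype first makes the assignment safe.

```
import tensorflow as tf

var = tf.Variable(tf.zeros((2,), dtype=tf.float16))
new_value = tf.ones((2,), dtype=tf.float32)  # e.g. state returned in another precision

# var.assign(new_value) would typically fail with a dtype mismatch;
# casting first, as the patched line does, succeeds.
var.assign(tf.cast(new_value, var.dtype))
print(var.numpy())
```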
8 changes: 6 additions & 2 deletions keras/src/backend/jax/trainer.py
@@ -1,5 +1,6 @@
import collections
import itertools
+from functools import partial

import jax
import numpy as np
@@ -988,15 +989,18 @@ def _get_jax_state(

def _distribute_data(data, layouts=None):
    distribution = distribution_lib.distribution()
+
    if distribution is not None:
        if layouts is None:
            layouts = tree.map_structure(
                lambda d: distribution.get_data_layout(d.shape),
                data,
            )
-        return tree.map_structure(
-            jax_distribution_lib.distribute_data_input, data, layouts
+        jax_dist_data_input = partial(
+            jax_distribution_lib.distribute_data_input,
+            batch_dim_name=distribution.batch_dim_name,
        )
+        return tree.map_structure(jax_dist_data_input, data, layouts)

    return tree.map_structure(jax.device_put, data)

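The `functools.partial` usage above follows a common pattern: `tree.map_structure` calls its function with one leaf from each structure, so the extra `batch_dim_name` argument is bound ahead of time. A toy sketch (hypothetical stand-ins, not Keras code):

```
from functools import partial

def distribute_data_input(per_process_batch, layout, batch_dim_name):
    # stand-in for jax_distribution_lib.distribute_data_input
    return f"shard {per_process_batch} with {layout} along '{batch_dim_name}'"

dist_fn = partial(distribute_data_input, batch_dim_name="batch")

data = ["x0", "x1"]
layouts = ["layout0", "layout1"]
# zip here stands in for tree.map_structure over matching structures
print([dist_fn(d, l) for d, l in zip(data, layouts)])
```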
20 changes: 12 additions & 8 deletions keras/src/distribution/distribution_lib.py
@@ -287,8 +287,9 @@ class Distribution:
        device_mesh: A `DeviceMesh` instance.
    """

-    def __init__(self, device_mesh):
+    def __init__(self, device_mesh, batch_dim_name=None):
        self._device_mesh = device_mesh
+        self._batch_dim_name = batch_dim_name

    def get_data_layout(self, data_shape):
        """Retrieve the `TensorLayout` for the input data.
@@ -341,6 +342,10 @@ def scope(self):
    def device_mesh(self):
        return self._device_mesh

+    @property
+    def batch_dim_name(self):
+        return self._batch_dim_name
+
    def distribute_dataset(self, dataset):
        """Create a distributed dataset instance from the original user dataset.
@@ -395,7 +400,6 @@ def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True):
        else:
            self._initialize_mesh_from_list_devices()

-        self._batch_dim_name = self.device_mesh.axis_names[0]
        # Those following attributes might get convert to public methods.
        self._num_process = distribution_lib.num_processes()
        self._process_id = distribution_lib.process_id()
@@ -408,7 +412,7 @@ def _initialize_with_device_mesh(self, device_mesh):
                "Expect `mesh` to be an instance of `DeviceMesh`. "
                f"Received: mesh={device_mesh} (of type {type(device_mesh)})"
            )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, device_mesh.axis_names[0])
        if self.device_mesh.devices.ndim != 1:
            warnings.warn(
                "Expect the input mesh to be 1D, but received "
@@ -424,7 +428,7 @@ def _initialize_mesh_from_devices(self, devices):
            axis_names=[DEFAULT_BATCH_DIM_NAME],
            devices=devices,
        )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)

    def _initialize_mesh_from_list_devices(self):
        devices = np.array(list_devices())
@@ -433,11 +437,11 @@ def _initialize_mesh_from_list_devices(self):
            axis_names=[DEFAULT_BATCH_DIM_NAME],
            devices=devices,
        )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)

    def get_data_layout(self, data_shape):
        data_shard_spec = [None] * len(data_shape)
-        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim
+        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim
        return TensorLayout(data_shard_spec, self.device_mesh)

    def get_variable_layout(self, variable):
@@ -590,7 +594,7 @@ def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs):

    def get_data_layout(self, data_shape):
        data_shard_spec = [None] * len(data_shape)
-        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim
+        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim
        return TensorLayout(data_shard_spec, self.device_mesh)

    def get_variable_layout(self, variable):
@@ -631,7 +635,7 @@ def distribute_dataset(self, dataset):
        # Note that this might be smaller than one if model replicas are sharded
        # across multiple processes.
        mesh_batch_dim_index = self.device_mesh.axis_names.index(
-            self._batch_dim_name
+            self.batch_dim_name
        )
        num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index]
        if num_model_replicas == 1:
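To tie the `distribution_lib.py` hunks together: the batch dimension name now lives on the `Distribution` base class as a read-only property, and subclasses pass it up from their mesh instead of keeping a private `_batch_dim_name`. A toy sketch with hypothetical, simplified classes (not the real Keras constructors):

```
class Distribution:  # simplified stand-in
    def __init__(self, device_mesh, batch_dim_name=None):
        self._device_mesh = device_mesh
        self._batch_dim_name = batch_dim_name

    @property
    def batch_dim_name(self):
        return self._batch_dim_name


class DataParallel(Distribution):  # simplified stand-in
    def __init__(self, device_mesh, axis_names):
        # the first mesh axis is treated as the batch dimension
        super().__init__(device_mesh, axis_names[0])


dp = DataParallel(device_mesh=object(), axis_names=["batch"])
print(dp.batch_dim_name)  # -> "batch"
```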
6 changes: 3 additions & 3 deletions keras/src/distribution/distribution_lib_test.py
@@ -186,7 +186,7 @@ def test_create_with_device_mesh(self):
        device_mesh = distribution.device_mesh
        self.assertEqual(len(device_mesh.devices), 8)
        self.assertEqual(device_mesh.axis_names, ["data"])
-        self.assertEqual(distribution._batch_dim_name, "data")
+        self.assertEqual(distribution.batch_dim_name, "data")

        self.assertFalse(distribution._is_multi_process)
        self.assertEqual(distribution._process_id, 0)
@@ -197,7 +197,7 @@ def test_create_with_devices(self):
        device_mesh = distribution.device_mesh
        self.assertEqual(len(device_mesh.devices), 8)
        self.assertEqual(device_mesh.axis_names, ["batch"])
-        self.assertEqual(distribution._batch_dim_name, "batch")
+        self.assertEqual(distribution.batch_dim_name, "batch")

    @mock.patch.object(
        distribution_lib,
@@ -211,7 +211,7 @@ def test_create_with_list_devices(self, mock_list_devices):
        device_mesh = distribution.device_mesh
        self.assertEqual(len(device_mesh.devices), 8)
        self.assertEqual(device_mesh.axis_names, ["batch"])
-        self.assertEqual(distribution._batch_dim_name, "batch")
+        self.assertEqual(distribution.batch_dim_name, "batch")

    def test_get_data_layout(self):
        distribution = distribution_lib.DataParallel(
3 changes: 3 additions & 0 deletions keras/src/layers/__init__.py
@@ -96,6 +96,9 @@
    MaxNumBoundingBoxes,
)
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
+from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
+    RandAugment,
+)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
    RandomBrightness,
)
10 changes: 4 additions & 6 deletions keras/src/layers/layer.py
@@ -85,12 +85,10 @@ class Layer(BackendLayer, Operation, KerasSaveable):
        trainable: Boolean, whether the layer's variables should be trainable.
        name: String name of the layer.
        dtype: The dtype of the layer's computations and weights. Can also be a
-            `keras.DTypePolicy`,
-            which allows the computation and
-            weight dtype to differ. Defaults to `None`. `None` means to use
-            `keras.config.dtype_policy()`,
-            which is a `float32` policy unless set to different value
-            (via `keras.config.set_dtype_policy()`).
+            `keras.DTypePolicy`, which allows the computation and weight dtype
+            to differ. Defaults to `None`. `None` means to use
+            `keras.config.dtype_policy()`, which is a `float32` policy unless
+            set to different value (via `keras.config.set_dtype_policy()`).

    Attributes:
        name: The name of the layer (string).
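As a side note on the reflowed docstring, a short sketch of the dtype-policy behavior it describes (assumes a working Keras 3 installation with any backend):

```
import keras

# With no explicit dtype, a layer follows the global policy (float32 by default).
layer = keras.layers.Dense(4)
print(keras.config.dtype_policy().name)  # "float32" unless changed globally

# A per-layer policy lets the compute and variable dtypes differ.
mixed = keras.layers.Dense(4, dtype="mixed_float16")
print(mixed.compute_dtype, mixed.variable_dtype)  # float16 float32
```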