forked from keras-team/keras
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Experiment new layout map API for keras models.
Please see the layout_map docstring for how the new API works. The new approach give user just a scope, which user should use to create their DTensor model. There is some assumptions for the code happens in the context: 1. All the weights will be convert to a lazyinitVariable. They will be replaced to DVariable at different time by different type of model. 2. Any model created within the scope will have the layout_map attached to them, and the map will be used to create DVariable for the model. If there is no model created in scope, then the LazyInitVariable will never be converted. 3. For subclass model, since the weights are created when the first time __call__ is invoked, we inject the __call__ to first init the variable with lazyinitVariable, and then map to DVariable. In this case, the layout_map_scope actually does nothing when user create the subclass model, since the weights are not created yet. The scope only allow the model fetch the layout_map, which can be inject to model.__init__ as well. But for API simplification purpose, we consolidate into just one API. 4. For functional/sequential model, since their weights are created eagerly, the DVariable creation happens at the init_graph_network. The scope approach is mostly used for this case, since the variable creation happens before functional.__init__. It will be too late if we inject the layout_map at __init__. 5. The DVariable creation has some special logic for disabling lazy_variable_scope, which was causing issue for functional model. The variable initializer usually uses tf.random.Generator under the hood. It will create the stateVar when init, and will be convert to a LazyInitVariable if the init happens in the scope. We would like to disable the scope for that case, since the init should happen with the tf.function on a dtensor device scope. The stateVar will be created as DVariable. PiperOrigin-RevId: 432248522
- Loading branch information
1 parent
ee8117b
commit 2e664be
Showing
8 changed files
with
140 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.