exposing net_transformer_fun before add grad (pytorch#11003)
Summary:
Pull Request resolved: pytorch#11003

Need an interface to rewrite the graph after the net is built and before gradient ops are added.

Reviewed By: aazzolini, harouwu

Differential Revision: D9557827

fbshipit-source-id: 2e082f0321c0776e488a29e18047d950948e7c37
wat3rBro authored and facebook-github-bot committed Aug 29, 2018
1 parent bed9d41 commit dbce1c8
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions caffe2/python/data_parallel_model.py
@@ -44,6 +44,7 @@ def Parallelize(
param_update_builder_fun=None,
optimizer_builder_fun=None,
post_sync_builder_fun=None,
pre_grad_net_transformer_fun=None,
net_transformer_fun=None,
devices=None,
rendezvous=None,
@@ -91,6 +92,11 @@ def Parallelize(
Signature:
net_transformer_fun(
model, num_devices, device_prefix, device_type)
pre_grad_net_transformer_fun:
Optional function to transform the network similar to
net_transformer_fun, but applied before gradient ops
are added.
Signature: pre_grad_net_transformer_fun(model)
post_sync_builder_fun:
Function applied after initial parameter sync has been
completed, such as keeping multi-precision parameters
@@ -234,6 +240,9 @@ def Parallelize(
model_helper_obj._computed_param_names =\
list(viewkeys(computed_params_grouped))

if pre_grad_net_transformer_fun:
pre_grad_net_transformer_fun(model_helper_obj)

if has_parameter_updates:
log.info("Adding gradient operators")
_AddGradientOperators(devices, model_helper_obj, losses_by_gpu)
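
The call order this change introduces can be illustrated with a minimal, self-contained sketch. It does not use caffe2: `FakeModel`, `parallelize_sketch`, and `fuse_relu` are hypothetical stand-ins invented for illustration, and the op names are made up. The only things taken from the diff are the hook name, its one-argument signature, and the fact that it runs before gradient operators are added.

```python
# Hypothetical stand-ins only: none of these names are part of caffe2.
# They mirror just the ordering shown in the diff above.


class FakeModel:
    """Records op names in the order they are added to the 'net'."""

    def __init__(self):
        self.ops = []


def parallelize_sketch(model, forward_ops, pre_grad_net_transformer_fun=None):
    # 1. Build the forward net.
    model.ops.extend(forward_ops)

    # 2. New hook: rewrite the graph before any gradient ops exist.
    #    Signature from the docstring: pre_grad_net_transformer_fun(model)
    if pre_grad_net_transformer_fun:
        pre_grad_net_transformer_fun(model)

    # 3. Stand-in for _AddGradientOperators: one gradient op per forward op,
    #    in reverse order. Built as a list first so the list being extended
    #    is not iterated at the same time.
    grads = [op + "_grad" for op in reversed(model.ops)]
    model.ops.extend(grads)
    return model


def fuse_relu(model):
    """Example transformer: swap a forward op before gradients are derived."""
    model.ops = ["FusedRelu" if op == "Relu" else op for op in model.ops]


model = parallelize_sketch(
    FakeModel(), ["FC", "Relu"], pre_grad_net_transformer_fun=fuse_relu
)
print(model.ops)
```

Because the hook runs first, the gradient ops in this toy are derived from the rewritten forward ops rather than the original ones, which is the behavior a pre-gradient rewrite is meant to enable.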