Skip to content

Commit

Permalink
Reuse cudagraph input and output tensor memory
Browse files Browse the repository at this point in the history
Signed-off-by: Robin Zhang <[email protected]>
  • Loading branch information
buptzyb committed Nov 25, 2024
1 parent ae393e8 commit dabfa19
Showing 1 changed file with 109 additions and 19 deletions.
128 changes: 109 additions & 19 deletions transformer_engine/pytorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def _make_graphed_callables(
allow_unused_input: bool = False,
fp8_weight_caching: bool = False,
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
reuse_graph_inputs=False,
reuse_graph_outputs=False,
_order: Optional[List[int]] = None,
pool: Optional[Tuple[int, ...]] = None,
) -> SingleOrTuple[Callable]:
Expand Down Expand Up @@ -89,6 +91,17 @@ def _make_graphed_callables(
callables = (callables,)
sample_args = (sample_args,)
sample_kwargs = (sample_kwargs,)
if reuse_graph_inputs:
len_args = len(sample_args[0])
for arg in sample_args:
assert len_args == len(arg), f"Arguments must have same length and shape for reusing."
sample_args = list(sample_args)
len_kwargs = len(sample_kwargs[0])
for kwarg in sample_kwargs:
assert len_kwargs == len(
kwarg
), f"Keyword arguments must have same length and shape for reusing."
sample_kwargs = list(sample_kwargs)

# Check sizes of args
if _order is None:
Expand Down Expand Up @@ -280,22 +293,81 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks
for c_id in _order:
# The following variables track state for reusing input/output buffers to save memory.
fwd_order_recorder = {}
fwd_order_accu = 0
per_callable_fwd_idx_recorder = []
static_grad_outputs = None
static_grad_inputs = []
static_grad_inputs_exists = False
for idx, c_id in enumerate(_order):
if c_id > 0:
if reuse_graph_inputs or reuse_graph_outputs:
# Record the forward order pattern, used for input/output buffer reuse.
if c_id in fwd_order_recorder:
fwd_order_recorder[c_id].append(fwd_order_accu)
else:
fwd_order_recorder[c_id] = [fwd_order_accu]
fwd_order_accu += 1
if idx > 1 and _order[idx - 1] < 0:
# The previous entry in _order is negative (a backward step), so the buffers of
# that chunk's earliest recorded forward pass can be reused here.
reuse_fwd_idx = fwd_order_recorder[abs(_order[idx - 1])].pop(0)

# Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
m_chunk = c_id - 1
for l_no in range(num_layers):
func = callables[m_chunk * num_layers + l_no]
per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
fwd_idx[m_chunk] * num_layers + l_no
)
if reuse_graph_inputs or reuse_graph_outputs:
per_callable_fwd_idx_recorder.append(per_callable_fwd_idx)
if idx > 1 and _order[idx - 1] < 0:
# Look up the per-callable index of the earlier forward pass (same layer)
# whose tensor buffers can be reused.
reuse_per_callable_fwd_idx = per_callable_fwd_idx_recorder[
reuse_fwd_idx * num_layers + l_no
]
if reuse_graph_inputs:
sample_args[per_callable_fwd_idx] = sample_args[
reuse_per_callable_fwd_idx
]
sample_kwargs[per_callable_fwd_idx] = sample_kwargs[
reuse_per_callable_fwd_idx
]
flatten_sample_args[per_callable_fwd_idx] = flatten_sample_args[
reuse_per_callable_fwd_idx
]
per_callable_static_input_surfaces[per_callable_fwd_idx] = (
per_callable_static_input_surfaces[reuse_per_callable_fwd_idx][
: len(flatten_sample_args[per_callable_fwd_idx])
]
+ per_callable_static_input_surfaces[per_callable_fwd_idx][
len(flatten_sample_args[per_callable_fwd_idx]) :
]
)
if reuse_graph_outputs:
static_outputs = per_callable_static_outputs[
reuse_per_callable_fwd_idx
]
detached_static_outputs = tuple(
so.detach() for so in static_outputs
)
args = sample_args[per_callable_fwd_idx]
kwargs = sample_kwargs[per_callable_fwd_idx]
fwd_graph = fwd_graphs[per_callable_fwd_idx]
with torch.cuda.graph(fwd_graph, pool=mempool):
outputs = func(*args, **kwargs)
flatten_outputs, spec = _tree_flatten(outputs)
per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
flatten_outputs, spec = _tree_flatten(outputs)
if reuse_graph_outputs and idx > 1 and _order[idx - 1] < 0:
for i, static_output in enumerate(detached_static_outputs):
static_output.copy_(flatten_outputs[i])
per_callable_static_outputs[per_callable_fwd_idx] = (
detached_static_outputs
)
else:
per_callable_static_outputs[per_callable_fwd_idx] = tuple(
flatten_outputs
)
per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec
graph_callables[per_callable_fwd_idx] = func
fwd_idx[m_chunk] += 1
Expand All @@ -310,9 +382,10 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
bwd_graph = bwd_graphs[per_callable_bwd_idx]
# For now, assumes all static_outputs require grad
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if not reuse_graph_inputs or static_grad_outputs is None:
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
with torch.cuda.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
Expand All @@ -321,21 +394,29 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
only_inputs=True,
allow_unused=allow_unused_input,
)
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern.
static_grad_inputs = []
grad_idx = 0
for arg in static_input_surface:
if arg.requires_grad:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
else:
static_grad_inputs.append(None) # type: ignore[arg-type]
static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment]
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern.
if not reuse_graph_outputs:
static_grad_inputs = []
grad_idx = 0
for input_idx, arg in enumerate(static_input_surface):
if arg.requires_grad:
if reuse_graph_outputs and static_grad_inputs_exists:
if static_grad_inputs[input_idx] is not None:
static_grad_inputs[input_idx].copy_(grad_inputs[grad_idx])
else:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
elif not reuse_graph_outputs or not static_grad_inputs_exists:
static_grad_inputs.append(None) # type: ignore[arg-type]
if reuse_graph_outputs:
static_grad_inputs_exists = True

per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs
per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs
per_callable_static_grad_inputs[per_callable_bwd_idx] = tuple(
static_grad_inputs
)
bwd_idx[m_chunk] += 1
else:
# Capture forward graphs
Expand Down Expand Up @@ -599,6 +680,8 @@ def make_graphed_callables(
num_warmup_iters: int = 3,
allow_unused_input: bool = False,
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
reuse_graph_inputs: bool = False,
reuse_graph_outputs: bool = False,
fp8_enabled: bool = False,
fp8_calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
Expand Down Expand Up @@ -629,6 +712,11 @@ def make_graphed_callables(
and outputs are disconnected in compute graph.
sample_kwargs: (tuple of) dict, optional
Keyword arguments to callable(s)
reuse_graph_inputs: bool, default = `False`
Whether to reuse input data buffers across graphs to reduce memory usage.
reuse_graph_outputs: bool, default = `False`
Whether to reuse output data buffers across graphs to reduce memory usage.
Reusing output data buffers incurs extra device-to-device (DtoD) data copies.
pool: (tuple of) int, default = `None`, optional
An instance returned from function `torch.cuda.graph_pool_handle` that hints
this graph may share memory with the indicated pool.
Expand Down Expand Up @@ -714,6 +802,8 @@ def forward_func(*args, **kwargs):
allow_unused_input=allow_unused_input,
fp8_weight_caching=fp8_weight_caching,
sample_kwargs=sample_kwargs,
reuse_graph_inputs=reuse_graph_inputs,
reuse_graph_outputs=reuse_graph_outputs,
_order=_order,
pool=pool,
)
Expand Down

0 comments on commit dabfa19

Please sign in to comment.