Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save CUDA Graph memory by reusing input and output tensors #1234

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 109 additions & 19 deletions transformer_engine/pytorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def _make_graphed_callables(
allow_unused_input: bool = False,
fp8_weight_caching: bool = False,
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
reuse_graph_inputs=False,
reuse_graph_outputs=False,
_order: Optional[List[int]] = None,
pool: Optional[Tuple[int, ...]] = None,
) -> SingleOrTuple[Callable]:
Expand Down Expand Up @@ -89,6 +91,17 @@ def _make_graphed_callables(
callables = (callables,)
sample_args = (sample_args,)
sample_kwargs = (sample_kwargs,)
if reuse_graph_inputs:
len_args = len(sample_args[0])
for arg in sample_args:
assert len_args == len(arg), f"Arguments must have same length and shape for reusing."
sample_args = list(sample_args)
len_kwargs = len(sample_kwargs[0])
for kwarg in sample_kwargs:
assert len_kwargs == len(
kwarg
), f"Keyword arguments must have same length and shape for reusing."
sample_kwargs = list(sample_kwargs)

# Check sizes of args
if _order is None:
Expand Down Expand Up @@ -280,22 +293,81 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks
for c_id in _order:
# Following variables are for input/output reusing to save memory.
fwd_order_recorder = {}
fwd_order_accu = 0
per_callable_fwd_idx_recorder = []
static_grad_outputs = None
static_grad_inputs = []
static_grad_inputs_exists = False
for idx, c_id in enumerate(_order):
if c_id > 0:
if reuse_graph_inputs or reuse_graph_outputs:
# Record the fwd order pattern for input data reusing.
if c_id in fwd_order_recorder:
fwd_order_recorder[c_id].append(fwd_order_accu)
else:
fwd_order_recorder[c_id] = [fwd_order_accu]
fwd_order_accu += 1
if idx > 1 and _order[idx - 1] < 0:
# It can use the tensor buffer of a previous one.
reuse_fwd_idx = fwd_order_recorder[abs(_order[idx - 1])].pop(0)

# Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
m_chunk = c_id - 1
for l_no in range(num_layers):
func = callables[m_chunk * num_layers + l_no]
per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
fwd_idx[m_chunk] * num_layers + l_no
)
if reuse_graph_inputs or reuse_graph_outputs:
per_callable_fwd_idx_recorder.append(per_callable_fwd_idx)
if idx > 1 and _order[idx - 1] < 0:
# It can use the tensor buffer of a previous one.
reuse_per_callable_fwd_idx = per_callable_fwd_idx_recorder[
reuse_fwd_idx * num_layers + l_no
]
if reuse_graph_inputs:
sample_args[per_callable_fwd_idx] = sample_args[
reuse_per_callable_fwd_idx
]
sample_kwargs[per_callable_fwd_idx] = sample_kwargs[
reuse_per_callable_fwd_idx
]
flatten_sample_args[per_callable_fwd_idx] = flatten_sample_args[
reuse_per_callable_fwd_idx
]
per_callable_static_input_surfaces[per_callable_fwd_idx] = (
per_callable_static_input_surfaces[reuse_per_callable_fwd_idx][
: len(flatten_sample_args[per_callable_fwd_idx])
]
+ per_callable_static_input_surfaces[per_callable_fwd_idx][
len(flatten_sample_args[per_callable_fwd_idx]) :
]
)
if reuse_graph_outputs:
static_outputs = per_callable_static_outputs[
reuse_per_callable_fwd_idx
]
detached_static_outputs = tuple(
so.detach() for so in static_outputs
)
args = sample_args[per_callable_fwd_idx]
kwargs = sample_kwargs[per_callable_fwd_idx]
fwd_graph = fwd_graphs[per_callable_fwd_idx]
with torch.cuda.graph(fwd_graph, pool=mempool):
outputs = func(*args, **kwargs)
flatten_outputs, spec = _tree_flatten(outputs)
per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
flatten_outputs, spec = _tree_flatten(outputs)
if reuse_graph_outputs and idx > 1 and _order[idx - 1] < 0:
for i, static_output in enumerate(detached_static_outputs):
static_output.copy_(flatten_outputs[i])
per_callable_static_outputs[per_callable_fwd_idx] = (
detached_static_outputs
)
else:
per_callable_static_outputs[per_callable_fwd_idx] = tuple(
flatten_outputs
)
per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec
graph_callables[per_callable_fwd_idx] = func
fwd_idx[m_chunk] += 1
Expand All @@ -310,9 +382,10 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
bwd_graph = bwd_graphs[per_callable_bwd_idx]
# For now, assumes all static_outputs require grad
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if not reuse_graph_inputs or static_grad_outputs is None:
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
with torch.cuda.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
Expand All @@ -321,21 +394,29 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
only_inputs=True,
allow_unused=allow_unused_input,
)
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern.
static_grad_inputs = []
grad_idx = 0
for arg in static_input_surface:
if arg.requires_grad:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
else:
static_grad_inputs.append(None) # type: ignore[arg-type]
static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment]
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern.
if not reuse_graph_outputs:
static_grad_inputs = []
grad_idx = 0
for input_idx, arg in enumerate(static_input_surface):
if arg.requires_grad:
if reuse_graph_outputs and static_grad_inputs_exists:
if static_grad_inputs[input_idx] is not None:
static_grad_inputs[input_idx].copy_(grad_inputs[grad_idx])
else:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
elif not reuse_graph_outputs or not static_grad_inputs_exists:
static_grad_inputs.append(None) # type: ignore[arg-type]
if reuse_graph_outputs:
static_grad_inputs_exists = True

per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs
per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs
per_callable_static_grad_inputs[per_callable_bwd_idx] = tuple(
static_grad_inputs
)
bwd_idx[m_chunk] += 1
else:
# Capture forward graphs
Expand Down Expand Up @@ -599,6 +680,8 @@ def make_graphed_callables(
num_warmup_iters: int = 3,
allow_unused_input: bool = False,
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
reuse_graph_inputs: bool = False,
reuse_graph_outputs: bool = False,
fp8_enabled: bool = False,
fp8_calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
Expand Down Expand Up @@ -629,6 +712,11 @@ def make_graphed_callables(
and outputs are disconnected in compute graph.
sample_kwargs: (tuple of) dict, optional
Keyword arguments to callable(s)
reuse_graph_inputs: bool, default = `False`
Whether or not to reuse input data buffer between graphs to save memory usage.
reuse_graph_outputs: bool, default = `False`
Whether or not to reuse output data buffer between graphs to save memory
usage. Reusing output data buffer will inevitably cause extra DtoD data copy.
pool: (tuple of) int, default = `None`, optional
An instance returned from function `torch.cuda.graph_pool_handle` that hints
this graph may share memory with the indicated pool.
Expand Down Expand Up @@ -714,6 +802,8 @@ def forward_func(*args, **kwargs):
allow_unused_input=allow_unused_input,
fp8_weight_caching=fp8_weight_caching,
sample_kwargs=sample_kwargs,
reuse_graph_inputs=reuse_graph_inputs,
reuse_graph_outputs=reuse_graph_outputs,
_order=_order,
pool=pool,
)
Expand Down