Commit

[autochunk] refactor chunk memory estimation (hpcaitech#2762)
* refactor memory code

* don't log free var memory

* add memory align

* update chunk target

* update setting for new memory

* finish test

* update tracer

* fix typo

* update test
oahzxl authored Mar 8, 2023
1 parent b51bfec commit 2ca9728
Showing 12 changed files with 302 additions and 430 deletions.
42 changes: 24 additions & 18 deletions colossalai/autochunk/autochunk_codegen.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Iterable, List, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Tuple
 
 import torch
 
@@ -216,14 +216,13 @@ def _add_node_slice(
     return body
 
 
-def emit_code_with_chunk(
-    body: List[str],
-    nodes: Iterable[Node],
-    emit_node_func,
-    delete_unused_value_func,
-    search_chunk: SearchChunk,
-    chunk_infos: List,
-):
+def emit_code_with_chunk(body: List[str],
+                         nodes: Iterable[Node],
+                         emit_node_func: Callable,
+                         delete_unused_value_func: Callable,
+                         search_chunk: SearchChunk,
+                         chunk_infos: List,
+                         eval_mem: bool = False):
     """
     Emit code with chunk according to chunk_infos.
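For orientation, the chunking idea this function emits code for, shown as a toy sketch rather than the actual generated code: a memory-heavy op runs piecewise over slices of one dimension, so only one slice's intermediates are live at a time (`chunked_matmul` and `chunk_size` are illustrative names, not part of this codebase).

    import torch

    def chunked_matmul(a, b, chunk_size=128):
        # Process `a` in row slices so intermediate tensors stay small;
        # results are concatenated back along the chunked dimension.
        outs = []
        for i in range(0, a.shape[0], chunk_size):
            outs.append(a[i:i + chunk_size] @ b)
        return torch.cat(outs, dim=0)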
@@ -260,6 +259,9 @@ def emit_code_with_chunk(
     region_idx = 0
     within_chunk_region = False
 
+    if eval_mem:
+        body.append("init_memory = torch.cuda.memory_allocated() / 1024**2\n")
+
     while node_idx < len(node_list):
         node = node_list[node_idx]

@@ -289,10 +291,18 @@ def emit_code_with_chunk(
             body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
             body[-1] = "    " + body[-1]
             delete_unused_value_func(node, body, chunk_inputs_names)
+            if eval_mem:
+                body.append(
+                    "    if chunk_idx == 0:\n        print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
+                    % (node.name))
         else:
             emit_node_func(node, body)
             if node_idx not in chunk_inputs:
                 delete_unused_value_func(node, body, chunk_inputs_names)
+            if eval_mem:
+                body.append(
+                    "print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
+                    % (node.name))
 
         # generate chunk region end
         if node_idx in chunk_ends:
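The strings appended here make the generated forward code print per-node peak memory. A minimal standalone sketch of the same measurement pattern, assuming a CUDA device (the `linear1` node name is hypothetical):

    import torch

    # Baseline allocation before the region being measured.
    init_memory = torch.cuda.memory_allocated() / 1024**2

    linear1 = torch.nn.Linear(1024, 1024).cuda()
    x = torch.randn(64, 1024, device="cuda")
    y = linear1(x)

    # Peak memory above the baseline, in MB; resetting the peak
    # statistics isolates the next node's measurement.
    print("linear1", torch.cuda.max_memory_allocated() / 1024**2 - init_memory)
    torch.cuda.reset_peak_memory_stats()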
@@ -312,8 +322,10 @@ def __init__(self,
                  meta_graph,
                  max_memory: int = None,
                  print_mem: bool = False,
-                 print_progress: bool = False) -> None:
+                 print_progress: bool = False,
+                 eval_mem: bool = False) -> None:
         super().__init__()
+        self.eval_mem = eval_mem
         # find the chunk regions
         self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress)
         self.chunk_infos = self.search_chunk.search_region()
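With the flag stored on the codegen object, enabling per-node logging would look roughly like this; the `meta_graph` tracing and the fx wiring below are assumptions about the surrounding autochunk setup, not shown in this diff:

    # Sketch: meta_graph is assumed to be an fx graph traced elsewhere.
    codegen = AutoChunkCodeGen(meta_graph, max_memory=None, print_mem=False,
                               print_progress=False, eval_mem=True)
    graph.set_codegen(codegen)    # attach to the target torch.fx graph (assumed wiring)
    gm = torch.fx.GraphModule(model, graph)
    gm.recompile()    # the regenerated forward now prints per-node peak memory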
@@ -511,14 +523,8 @@ def emit_node(node: Node, body):
 
         # if any node has a list of labels for activation_checkpoint, we
         # will use nested type of activation checkpoint codegen
-        emit_code_with_chunk(
-            body,
-            nodes,
-            emit_node,
-            delete_unused_values,
-            self.search_chunk,
-            self.chunk_infos,
-        )
+        emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
+                             self.eval_mem)
 
         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body
