Batch padding for DDP (h2oai#611)
* implementation

* fix

* noqa

* f

* c

* c

* c
psinger authored Feb 8, 2024
1 parent a7050b3 commit 19d0ce2
Showing 2 changed files with 22 additions and 27 deletions.
@@ -3,4 +3,5 @@ Defines the padding quantile H2O LLM Studio uses to select the maximum token len
 - Lowering the quantile can significantly increase training runtime and reduce memory usage in unevenly distributed sequence lengths but can hurt performance
 - The setting depends on the batch size and should be adjusted accordingly
 - No padding is done in inference, and the selected **Max Length** is guaranteed
-- Setting to 0 disables padding
+- Setting to 0 disables padding
+- In case of distributed training, the quantile will be calculated across all GPUs
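For readers unfamiliar with the setting, here is a rough, illustrative sketch of how a padding quantile can turn a batch's real sequence lengths into a per-batch padding width. It is not the repo's `batch_padding` implementation (that code appears in the diff below); the helper name `quantile_padding_length` and the right-padded attention mask are assumptions made for the example.

```python
import torch


def quantile_padding_length(attention_mask: torch.Tensor, padding_quantile: float) -> int:
    """attention_mask: (batch, seq_len) with 1 for real tokens; right padding assumed."""
    if padding_quantile == 0 or padding_quantile >= 1.0:
        # 0 disables the feature; 1.0 keeps the full selected Max Length.
        return attention_mask.size(1)
    # Index of the last real token in every sample of the batch.
    lengths = torch.stack(
        [
            torch.where(attention_mask[i] == 1)[0].max()
            for i in range(attention_mask.size(0))
        ]
    ).float()
    # e.g. 0.9 picks a width covering ~90% of samples; the longest ~10% get trimmed.
    return int(torch.floor(torch.quantile(lengths, padding_quantile))) + 1


mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1], [1, 0, 0, 0, 0]])
print(quantile_padding_length(mask, 0.9))  # 4 instead of the full width 5
```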
46 changes: 20 additions & 26 deletions llm_studio/src/utils/data_utils.py
@@ -14,6 +14,7 @@
 
 from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler
 from llm_studio.src.utils.exceptions import LLMDataException
+from llm_studio.src.utils.gpu_utils import sync_across_processes
 from llm_studio.src.utils.utils import PatchedAttribute, set_seed
 
 logger = logging.getLogger(__name__)
@@ -656,33 +657,26 @@ def batch_padding(
         return batch
     elif training and cfg.tokenizer.padding_quantile < 1.0:
         if padding_side == "left":
-            idx = int(
-                torch.floor(
-                    torch.quantile(
-                        torch.stack(
-                            [
-                                torch.where(batch[mask_key][i] == 1)[0].min()
-                                for i in range(batch[mask_key].size(0))
-                            ]
-                        ).float(),
-                        1 - cfg.tokenizer.padding_quantile,
-                    )
-                )
-            )
+            lengths = torch.stack(
+                [
+                    torch.where(batch[mask_key][i] == 1)[0].min()
+                    for i in range(batch[mask_key].size(0))
+                ]
+            ).float()
+            quantile = 1 - cfg.tokenizer.padding_quantile
         else:
-            idx = int(
-                torch.ceil(
-                    torch.quantile(
-                        torch.stack(
-                            [
-                                torch.where(batch[mask_key][i] == 1)[0].max()
-                                for i in range(batch[mask_key].size(0))
-                            ]
-                        ).float(),
-                        cfg.tokenizer.padding_quantile,
-                    )
-                )
-            )
+            lengths = torch.stack(
+                [
+                    torch.where(batch[mask_key][i] == 1)[0].max()
+                    for i in range(batch[mask_key].size(0))
+                ]
+            ).float()
+            quantile = cfg.tokenizer.padding_quantile
+        if cfg.environment._distributed:
+            lengths = sync_across_processes(
+                lengths, cfg.environment._world_size
+            )  # type: ignore
+        idx = int(torch.floor(torch.quantile(lengths, quantile)))
     else:
         if padding_side == "left":
             idx = int(torch.where(batch[mask_key] == 1)[1].min())
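The new code path calls `sync_across_processes` from `llm_studio/src/utils/gpu_utils.py`, whose body is not part of this diff. As a hedged sketch only (not the repo's actual implementation; the function name `gather_lengths_across_ranks` and the equal-shape assumption are introduced here for illustration), such a helper could all-gather the per-sample lengths with `torch.distributed` so that every rank computes the quantile over the same global set of lengths:

```python
import torch
import torch.distributed as dist


def gather_lengths_across_ranks(lengths: torch.Tensor, world_size: int) -> torch.Tensor:
    """All-gather a 1-D tensor of per-sample lengths from every rank.

    Assumes each rank contributes a tensor of the same shape, which is the
    usual case when a DistributedSampler hands every GPU equally sized batches.
    """
    if world_size <= 1 or not dist.is_initialized():
        return lengths
    gathered = [torch.empty_like(lengths) for _ in range(world_size)]
    dist.all_gather(gathered, lengths)  # collective: every rank receives all chunks
    return torch.cat(gathered)
```

Gathering the raw lengths, rather than each rank's local quantile, means all GPUs evaluate `torch.quantile` over the identical pooled tensor and therefore trim their batches to a consistent width, matching the documentation line added above ("the quantile will be calculated across all GPUs").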
