Skip to content

Commit

Permalink
update (modelscope#1689)
Browse files Browse the repository at this point in the history
  • Loading branch information
hjh0119 authored Aug 13, 2024
1 parent bb208d2 commit 5297d8e
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 10 deletions.
5 changes: 4 additions & 1 deletion docs/source/LLM/ORPO算法最佳实践.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ swift内置了处理方法将`answer_zh`作为`response`,将`answer_en`作为`re
# Experimental environment: A100
# DDP + MP
# Memory usage: 4*24G
nproc_per_node=2

CUDA_VISIBLE_DEVICES=0,1,2,3 \
NPROC_PER_NODE=2 \
NPROC_PER_NODE=$nproc_per_node \
MASTER_PORT=29500 \
swift rlhf \
--rlhf_type orpo \
--model_type llama3-8b-instruct \
Expand Down
5 changes: 4 additions & 1 deletion docs/source/Multi-Modal/人类偏好对齐训练文档.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,11 @@ swift rlhf \
--save_total_limit 2

# DDP + MP
nproc_per_node=2

CUDA_VISIBLE_DEVICES=0,1,2,3 \
NPROC_PER_NODE=2 \
NPROC_PER_NODE=$nproc_per_node \
MASTER_PORT=29500 \
swift rlhf \
--rlhf_type dpo \
--model_type llava1_6-mistral-7b-instruct \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,11 @@ swift rlhf \

# DDP + MP
# Memory usage: 4*24G
nproc_per_node=2

CUDA_VISIBLE_DEVICES=0,1,2,3 \
NPROC_PER_NODE=2 \
NPROC_PER_NODE=$nproc_per_node \
MASTER_PORT=29500 \
swift rlhf \
--rlhf_type dpo \
--model_type llama3-8b-instruct \
Expand Down
5 changes: 4 additions & 1 deletion docs/source_en/LLM/ORPO.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ Swift has built-in methods for processing this dataset, using `answer_zh` as `re
# Experimental environment: A100
# DDP + MP
# Memory usage: 4*24G
nproc_per_node=2

CUDA_VISIBLE_DEVICES=0,1,2,3 \
NPROC_PER_NODE=2 \
NPROC_PER_NODE=$nproc_per_node \
MASTER_PORT=29500 \
swift rlhf \
--rlhf_type orpo \
--model_type llama3-8b-instruct \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,11 @@ swift rlhf \
--save_total_limit 2

# DDP + MP
nproc_per_node=2

CUDA_VISIBLE_DEVICES=0,1,2,3 \
NPROC_PER_NODE=2 \
NPROC_PER_NODE=$nproc_per_node \
MASTER_PORT=29500 \
swift rlhf \
--rlhf_type dpo \
--model_type llava1_6-mistral-7b-instruct \
Expand Down
2 changes: 1 addition & 1 deletion swift/trainers/cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from .utils import build_tokenized_answer, patch_trl, sort_by_max_length

logger = get_logger()
patch_trl()


class CPOTrainer(PushToMsHubMixin, SwiftMixin, HFCPOTrainer):
Expand All @@ -21,6 +20,7 @@ def __init__(self, *args, template: Template, test_oom_error=False, **kwargs):
kwargs.pop('gamma', None)
self.streaming = kwargs.pop('streaming')
is_vision = kwargs.pop('is_vision')
patch_trl(is_vision)
self.keys = [] # keys appears in tokenize_row
self.column_names = list(next(iter(kwargs.get('train_dataset'))).keys())
self.need_filter: bool = False
Expand Down
2 changes: 1 addition & 1 deletion swift/trainers/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from .utils import build_tokenized_answer, patch_trl, sort_by_max_length

logger = get_logger()
patch_trl()


class DPOTrainer(PushToMsHubMixin, SwiftMixin, HFDPOTrainer):
Expand All @@ -22,6 +21,7 @@ def __init__(self, *args, template: Template, sft_beta=0., test_oom_error=False,
self.sft_beta = sft_beta
self.streaming = kwargs.pop('streaming')
is_vision = kwargs.pop('is_vision')
patch_trl(is_vision)
self.keys = [] # keys appears in tokenize_row
self.column_names = list(next(iter(kwargs.get('train_dataset'))).keys())
self.need_filter: bool = False
Expand Down
2 changes: 1 addition & 1 deletion swift/trainers/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from .utils import build_tokenized_answer, patch_trl, sort_by_max_length

logger = get_logger()
patch_trl()


class ORPOTrainer(PushToMsHubMixin, SwiftMixin, HFORPOTrainer):
Expand All @@ -20,6 +19,7 @@ def __init__(self, *args, template: Template, test_oom_error=False, **kwargs):
self.template = template
self.streaming = kwargs.pop('streaming')
is_vision = kwargs.pop('is_vision')
patch_trl(is_vision)
self.keys = []
self.column_names = list(next(iter(kwargs.get('train_dataset'))).keys())
self.need_filter: bool = False
Expand Down
21 changes: 19 additions & 2 deletions swift/trainers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import heapq
import inspect
from functools import partial
from types import FunctionType, MethodType
from typing import Dict, List, Optional, Union

Expand Down Expand Up @@ -120,7 +121,7 @@ def sort_by_max_length(dataset: HfDataset, num_dataset: int, is_encoder_decoder:
return dataset.select(idx)


def patch_trl():
def patch_trl(is_vision_model: bool = False):
from .callback import DefaultFlowCallbackNew, PrinterCallbackNew, ProgressCallbackNew
from transformers import trainer

Expand All @@ -129,7 +130,10 @@ def patch_trl():
trainer.PrinterCallback = PrinterCallbackNew

# fix encoder-decoder error
patch_datacollator()
if is_vision_model:
patch_datacollator()
patch_dataset_map()

patch_itds_map()


Expand Down Expand Up @@ -235,3 +239,16 @@ def new_map(self, *args, **kwargs):
IterableDataset.map = new_map
IterableDataset._old_map = old_map
# model.forward = MethodType(_patch_ids_map(map_func), IterableDataset)


def patch_dataset_map():
    """Monkey-patch ``HfDataset.map`` to default ``writer_batch_size`` to 10.

    A smaller writer batch size flushes rows to the Arrow writer more often —
    presumably to cap peak memory during preprocessing of large (e.g. vision)
    rows; TODO(review): confirm the motivation with the caller in ``patch_trl``.

    The patch is idempotent: the pristine method is stashed on the class as
    ``_old_map`` and its presence guards against double-wrapping. Callers that
    pass ``writer_batch_size`` explicitly are left untouched.
    """
    if not hasattr(HfDataset, '_old_map'):
        # Capture inside the guard so we never capture an already-wrapped map.
        original_map = HfDataset.map

        def patched_map(self, *args, **kwargs):
            # Forward positionals unchanged: HfDataset.map accepts
            # ``function=None`` plus several positional parameters, so a
            # wrapper with a required ``function`` arg would narrow the API.
            kwargs.setdefault('writer_batch_size', 10)
            return original_map(self, *args, **kwargs)

        HfDataset.map = patched_map
        HfDataset._old_map = original_map

0 comments on commit 5297d8e

Please sign in to comment.