fix enable_cache (#2813)

Jintao-Huang authored Dec 31, 2024
1 parent 58be39e commit 054ae1a
Showing 12 changed files with 55 additions and 47 deletions.
2 changes: 1 addition & 1 deletion docs/source/Instruction/ReleaseNote3.0.md
@@ -6,7 +6,7 @@

1. Dataset module refactored. Dataset loading is 2-20x faster and encoding 2-4x faster, with support for streaming mode.
- Removed the dataset_name mechanism; datasets are now specified via dataset_id, dataset_dir, or dataset_path.
- Use `--dataset_num_proc` for multi-process acceleration and `--load_from_cache_file true` to use the data preprocessing cache.
- Use `--dataset_num_proc` for multi-process acceleration.
- Use `--streaming` to stream datasets from the hub or from local files.
- Support the `--packing` option for more stable training efficiency.
- Specify `--dataset <dataset_dir>` to load open-source datasets locally.
2 changes: 1 addition & 1 deletion docs/source/Instruction/命令行参数.md
@@ -34,7 +34,7 @@
- data_seed: Random seed for the dataset, default is 42.
- 🔥dataset_num_proc: Number of processes for dataset preprocessing, default is 1.
- 🔥streaming: Stream-read and process the dataset, default is False.
- load_from_cache_file: Use cache for dataset preprocessing, default is False.
- enable_cache: Use cache for dataset preprocessing, default is False.
- Note: If set to True, this may not take effect when the dataset has changed; if training behaves abnormally after modifying this parameter, consider setting it to False.
- download_mode: Dataset download mode, either `reuse_dataset_if_exists` or `force_redownload`; default is reuse_dataset_if_exists.
- strict: If True, an error is thrown as soon as any row in the dataset has a problem; otherwise, erroneous rows are dropped. Default is False.
2 changes: 1 addition & 1 deletion docs/source_en/Instruction/Command-line-parameters.md
@@ -34,7 +34,7 @@ The introduction to command line parameters will cover base arguments, atomic ar
- data_seed: Random seed for the dataset, default is 42.
- 🔥dataset_num_proc: Number of processes for dataset preprocessing, default is 1.
- 🔥streaming: Stream read and process the dataset, default is False.
- load_from_cache_file: Use cache for dataset preprocessing, default is False.
- enable_cache: Use cache for dataset preprocessing, default is False.
- Note: If set to True, changes to the dataset may not take effect because a stale preprocessing cache can be reused. If modifying this parameter leads to issues during training, consider setting it back to False.
- download_mode: Dataset download mode, including `reuse_dataset_if_exists` and `force_redownload`, default is reuse_dataset_if_exists.
- strict: If True, the dataset will throw an error if any row has a problem; otherwise, it will discard the erroneous row. Default is False.
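For background on the `enable_cache` parameter documented above: a minimal sketch, assuming the flag maps onto the Hugging Face `datasets` caching switch that this commit uses (`enable_caching`); the snippet is illustrative, not repository code.

# Sketch: --enable_cache true turns on the global datasets cache, so preprocessing
# done with .map() can reuse previously written Arrow cache files instead of recomputing.
from datasets import enable_caching, is_caching_enabled

enable_caching()             # what --enable_cache true is assumed to trigger
print(is_caching_enabled())  # True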
2 changes: 1 addition & 1 deletion docs/source_en/Instruction/ReleaseNote3.0.md
@@ -6,7 +6,7 @@

1. Dataset module refactoring. The dataset loading speed has improved by 2-20 times, and encoding speed has improved by 2-4 times, with support for streaming mode.
- Removed the dataset_name mechanism; now use dataset_id, dataset_dir, or dataset_path to specify the dataset.
- Use `--dataset_num_proc` to support multi-process acceleration and `--load_from_cache_file true` to use the data preprocessing cache.
- Use `--dataset_num_proc` to support multi-process acceleration.
- Use `--streaming` to support streaming loading of hub and local datasets.
- Support the `--packing` option for more stable training efficiency.
- Use `--dataset <dataset_dir>` to support local loading of open-source datasets.
@@ -5,15 +5,23 @@ CUDA_VISIBLE_DEVICES=0 \
MAX_PIXELS=1003520 \
swift sft \
--model Qwen/Qwen2-VL-7B-Instruct \
--dataset 'modelscope/coco_2014_caption#20000' \
--train_type lora \
--dataset 'swift/OK-VQA_train#1000' \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-4 \
--lora_rank 8 \
--lora_alpha 32 \
--target_modules all-linear \
--freeze_vit true \
--gradient_accumulation_steps 16 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 5
--save_total_limit 5 \
--logging_steps 5 \
--max_length 2048 \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4
15 changes: 12 additions & 3 deletions examples/train/multimodal/dpo.sh
@@ -9,15 +9,24 @@ MAX_PIXELS=1003520 \
swift rlhf \
--rlhf_type dpo \
--model Qwen/Qwen2-VL-7B-Instruct \
--train_type lora \
--dataset swift/RLAIF-V-Dataset \
--train_type lora \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-4 \
--lora_rank 8 \
--lora_alpha 32 \
--target_modules all-linear \
--freeze_vit true \
--gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--save_total_limit 5 \
--deepspeed zero3 \
--logging_steps 5
--logging_steps 5 \
--max_length 2048 \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4
13 changes: 11 additions & 2 deletions examples/train/multimodal/grounding.sh
@@ -3,14 +3,23 @@ CUDA_VISIBLE_DEVICES=0 \
MAX_PIXELS=1003520 \
swift sft \
--model Qwen/Qwen2-VL-7B-Instruct \
--train_type lora \
--dataset 'swift/refcoco:grounding#1000' \
--train_type lora \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-4 \
--lora_rank 8 \
--lora_alpha 32 \
--target_modules all-linear \
--freeze_vit true \
--gradient_accumulation_steps 16 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 5
--logging_steps 5 \
--max_length 2048 \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4
9 changes: 6 additions & 3 deletions swift/llm/argument/base_args/data_args.py
@@ -2,6 +2,8 @@
from dataclasses import dataclass, field
from typing import List, Literal, Optional

from datasets import enable_caching

from swift.llm import DATASET_MAPPING, register_dataset_info
from swift.utils import get_logger

@@ -20,7 +22,7 @@ class DataArguments:
data_seed (Optional[int]): Seed for dataset shuffling. Default is None.
dataset_num_proc (int): Number of processes to use for data loading and preprocessing. Default is 1.
streaming (bool): Flag to enable streaming of datasets. Default is False.
load_from_cache_file (bool): Flag to load dataset from cache file. Default is False.
enable_cache (bool): Flag to load dataset from cache file. Default is False.
download_mode (Literal): Mode for downloading datasets. Default is 'reuse_dataset_if_exists'.
model_name (List[str]): List containing Chinese and English names of the model. Default is [None, None].
model_author (List[str]): List containing Chinese and English names of the model author.
@@ -38,7 +40,7 @@
dataset_num_proc: int = 1
streaming: bool = False

load_from_cache_file: bool = False
enable_cache: bool = False
download_mode: Literal['force_redownload', 'reuse_dataset_if_exists'] = 'reuse_dataset_if_exists'
strict: bool = False
# Chinese name and English name
@@ -58,6 +60,8 @@ def _init_custom_dataset_info(self):
def __post_init__(self):
if self.data_seed is None:
self.data_seed = self.seed
if self.enable_cache:
enable_caching()
if len(self.val_dataset) > 0 or self.streaming:
self.split_dataset_ratio = 0.
if len(self.val_dataset) > 0:
@@ -74,7 +78,6 @@ def get_dataset_kwargs(self):
'streaming': self.streaming,
'use_hf': self.use_hf,
'hub_token': self.hub_token,
'load_from_cache_file': self.load_from_cache_file,
'download_mode': self.download_mode,
'strict': self.strict,
'model_name': self.model_name,
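A condensed sketch of the control flow this diff suggests for `DataArguments`: the cache flag becomes a process-wide switch set once in `__post_init__`, instead of a `load_from_cache_file` kwarg forwarded through `get_dataset_kwargs`. The stand-in class below is hypothetical, not the repository's code.

# Hypothetical, stripped-down stand-in for DataArguments illustrating the pattern above.
from dataclasses import dataclass
from datasets import enable_caching

@dataclass
class CacheArgs:
    enable_cache: bool = False   # replaces the per-call load_from_cache_file flag

    def __post_init__(self):
        if self.enable_cache:
            enable_caching()     # one global toggle; no kwarg needs to be threaded to loaders

CacheArgs(enable_cache=True)     # mirrors parsing --enable_cache true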
2 changes: 2 additions & 0 deletions swift/llm/dataset/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import datasets.fingerprint
from datasets import disable_caching

from swift.utils.torch_utils import _find_local_mac
from . import dataset
@@ -27,3 +28,4 @@ def _update_fingerprint_mac(*args, **kwargs):
datasets.fingerprint.update_fingerprint = _update_fingerprint_mac
datasets.arrow_dataset.update_fingerprint = _update_fingerprint_mac
register_dataset_info()
disable_caching()
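A short sketch of the default this file change sets up, under the assumption that nothing else toggles caching: the dataset package switches the `datasets` cache off at import time, so preprocessing recomputes results unless `enable_cache` later re-enables it.

# Illustrative only: the import-time default after this change.
from datasets import disable_caching, is_caching_enabled

disable_caching()            # executed when swift.llm.dataset is imported
print(is_caching_enabled())  # False -> .map() preprocessing will not reuse on-disk caches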
20 changes: 4 additions & 16 deletions swift/llm/dataset/loader.py
@@ -165,7 +165,6 @@ def _load_dataset_path(dataset_meta: DatasetMeta,
*,
num_proc: int = 1,
strict: bool = False,
load_from_cache_file: bool = False,
streaming: bool = False) -> HfDataset:
dataset_path = dataset_meta.dataset_path

@@ -176,8 +175,7 @@
kwargs['na_filter'] = False
dataset = hf_load_dataset(file_type, data_files=dataset_path, **kwargs)

dataset = dataset_meta.preprocess_func(
dataset, num_proc=num_proc, strict=strict, load_from_cache_file=load_from_cache_file)
dataset = dataset_meta.preprocess_func(dataset, num_proc=num_proc, strict=strict)
dataset = DatasetLoader._remove_useless_columns(dataset)
return dataset

@@ -191,7 +189,6 @@ def _load_repo_dataset(
use_hf: Optional[bool] = None,
hub_token: Optional[str] = None,
strict: bool = False,
load_from_cache_file: bool = False,
revision: Optional[str] = None,
download_mode: Literal['force_redownload', 'reuse_dataset_if_exists'] = 'reuse_dataset_if_exists',
) -> HfDataset:
@@ -244,8 +241,7 @@
dataset = dataset._hf_ds
if streaming and isinstance(dataset, HfDataset):
dataset = dataset.to_iterable_dataset()
dataset = subset.preprocess_func(
dataset, num_proc=num_proc, strict=strict, load_from_cache_file=load_from_cache_file)
dataset = subset.preprocess_func(dataset, num_proc=num_proc, strict=strict)
dataset = DatasetLoader._remove_useless_columns(dataset)
datasets.append(dataset)
return DatasetLoader._concat_datasets(datasets, streaming)
@@ -278,7 +274,6 @@ def post_process(
split_dataset_ratio: float = 0.,
streaming: bool = False,
random_state: Optional[np.random.RandomState] = None,
load_from_cache_file: bool = False,
) -> Tuple[DATASET_TYPE, Optional[DATASET_TYPE]]:
"""Split into train/val datasets and perform dataset sampling."""
assert dataset_sample is None or dataset_sample > 0
@@ -318,8 +313,7 @@
train_sample = dataset_sample - val_sample
assert train_sample > 0
train_dataset, val_dataset = train_dataset.train_test_split(
test_size=val_sample, seed=get_seed(random_state),
load_from_cache_file=load_from_cache_file).values()
test_size=val_sample, seed=get_seed(random_state)).values()
train_dataset = sample_dataset(train_dataset, train_sample, random_state)
return train_dataset, val_dataset

@@ -342,7 +336,6 @@ def load(
use_hf: Optional[bool] = None,
hub_token: Optional[str] = None,
strict: bool = False,
load_from_cache_file: bool = False,
download_mode: Literal['force_redownload', 'reuse_dataset_if_exists'] = 'reuse_dataset_if_exists',
) -> HfDataset:

@@ -351,7 +344,6 @@
dataset_meta=dataset_meta,
num_proc=num_proc,
strict=strict,
load_from_cache_file=load_from_cache_file,
streaming=streaming,
)
else:
@@ -373,7 +365,6 @@ def load(
hub_token=hub_token,
num_proc=num_proc,
strict=strict,
load_from_cache_file=load_from_cache_file,
revision=revision,
streaming=streaming,
download_mode=download_mode)
@@ -407,7 +398,6 @@ def load_dataset(
use_hf: Optional[bool] = None,
hub_token: Optional[str] = None,
strict: bool = False,
load_from_cache_file: bool = False,
download_mode: Literal['force_redownload', 'reuse_dataset_if_exists'] = 'reuse_dataset_if_exists',
# self-cognition
model_name: Union[Tuple[str, str], List[str], None] = None, # zh, en
@@ -417,7 +407,6 @@
Args:
download_mode: Download mode, default is `reuse_dataset_if_exists`.
load_from_cache_file: Use cache file or not, Default False.
strict: Raise if any row is not correct.
hub_token: The token of the hub.
use_hf: Use hf dataset or ms dataset.
@@ -444,7 +433,6 @@
'num_proc': num_proc,
'use_hf': use_hf,
'strict': strict,
'load_from_cache_file': load_from_cache_file,
'download_mode': download_mode,
'streaming': streaming,
'hub_token': hub_token
@@ -461,7 +449,7 @@ def load_dataset(
split_dataset_ratio=split_dataset_ratio,
random_state=seed,
streaming=streaming,
load_from_cache_file=load_from_cache_file)
)
if train_dataset is not None:
train_datasets.append(train_dataset)
if val_dataset is not None:
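Since `load_from_cache_file` is no longer threaded through these loader functions, the underlying `datasets` calls fall back to the global cache switch (in `datasets`, `load_from_cache_file` defaults to the value of `is_caching_enabled()` when omitted). A small sketch of that fallback with toy data, not repository code:

# Toy illustration: map() and train_test_split() without an explicit load_from_cache_file
# kwarg follow whatever enable_caching()/disable_caching() last set globally.
from datasets import Dataset, disable_caching

disable_caching()
ds = Dataset.from_dict({'query': ['a', 'b', 'c', 'd']})
ds = ds.map(lambda row: {'length': len(row['query'])})    # no per-call cache flag
splits = ds.train_test_split(test_size=1, seed=42)        # same: controlled by the global switch
print(splits['train'].num_rows, splits['test'].num_rows)  # 3 1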
6 changes: 2 additions & 4 deletions swift/llm/dataset/preprocessor/core.py
@@ -246,7 +246,6 @@ def __call__(
*,
num_proc: int = 1,
strict: bool = False,
load_from_cache_file: bool = False,
batch_size: int = 1000,
) -> DATASET_TYPE:
from ..utils import sample_dataset
@@ -258,7 +257,7 @@
dataset = self._cast_pil_image(dataset)
map_kwargs = {}
if isinstance(dataset, HfDataset):
map_kwargs.update({'num_proc': num_proc, 'load_from_cache_file': load_from_cache_file})
map_kwargs.update({'num_proc': num_proc})
with self._patch_arrow_writer():
try:
dataset_mapped = dataset.map(
@@ -462,9 +461,8 @@ def __call__(
*,
num_proc: int = 1,
strict: bool = False,
load_from_cache_file: bool = False,
) -> DATASET_TYPE:
dataset = get_features_dataset(dataset)
dataset = dataset.rename_columns(self.columns_mapping)
preprocessor = self._get_preprocessor(dataset)
return preprocessor(dataset, num_proc=num_proc, load_from_cache_file=load_from_cache_file, strict=strict)
return preprocessor(dataset, num_proc=num_proc, strict=strict)
15 changes: 3 additions & 12 deletions swift/llm/train/sft.py
@@ -224,8 +224,7 @@ def _prepare_callbacks(self):

def _stat_dataset(self, dataset: HfDataset):
args = self.args
dataset = GetLengthPreprocessor()(
dataset, num_proc=args.dataset_num_proc, load_from_cache_file=args.load_from_cache_file)
dataset = GetLengthPreprocessor()(dataset, num_proc=args.dataset_num_proc)
_, stat_str = stat_array(dataset['length'])
logger.info(f'Dataset Token Length: {stat_str}')
return stat_str
@@ -243,17 +242,9 @@ def _encode_dataset(self, train_dataset, val_dataset):
else:
preprocessor_cls = PackingPreprocessor if args.packing else EncodePreprocessor
preprocessor = preprocessor_cls(template=template)
train_dataset = preprocessor(
train_dataset,
num_proc=args.dataset_num_proc,
strict=args.strict,
load_from_cache_file=args.load_from_cache_file)
train_dataset = preprocessor(train_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
if val_dataset is not None and not args.predict_with_generate:
val_dataset = preprocessor(
val_dataset,
num_proc=args.dataset_num_proc,
strict=args.strict,
load_from_cache_file=args.load_from_cache_file)
val_dataset = preprocessor(val_dataset, num_proc=args.dataset_num_proc, strict=args.strict)

inputs = train_dataset[0] if hasattr(train_dataset, '__len__') else next(iter(train_dataset))
template.print_inputs(inputs, tokenizer_kwargs=inputs.pop('tokenizer_kwargs', None) or {})
