
Commit

fix stream bugs (modelscope#1794)
Jintao-Huang authored Aug 22, 2024
1 parent ab8476c commit 3d5a07c
Showing 5 changed files with 12 additions and 16 deletions.
4 changes: 2 additions & 2 deletions swift/llm/rlhf.py
@@ -14,8 +14,8 @@
                          is_ddp_plus_mp, is_dist, is_master, plot_images, seed_everything, show_layers)
 from .sft import _get_train_val_dataset
 from .tuner import prepare_model
-from .utils import (TEMPLATE_MAPPING, RLHFArguments, Template, get_dataset, get_model_tokenizer, get_template,
-                    get_time_info, set_generation_config)
+from .utils import (TEMPLATE_MAPPING, RLHFArguments, Template, get_model_tokenizer, get_template, get_time_info,
+                    set_generation_config)
 
 logger = get_logger()
 
19 changes: 8 additions & 11 deletions swift/llm/utils/dataset.py
@@ -10,7 +10,6 @@
 import datasets.fingerprint
 import json
 import numpy as np
-import pandas as pd
 from datasets import Dataset as HfDataset
 from datasets import IterableDataset as HfIterableDataset
 from datasets import concatenate_datasets, interleave_datasets
@@ -20,7 +19,7 @@
 from tqdm.auto import tqdm
 from transformers.utils import strtobool
 
-from swift.utils import get_logger, get_seed, is_dist, is_local_master, read_from_jsonl, transform_jsonl_to_df
+from swift.utils import get_logger, get_seed, is_dist, is_local_master
 from swift.utils.torch_utils import _find_local_mac
 from .media import MediaCache, MediaTag
 from .preprocess import (AlpacaPreprocessor, ClsPreprocessor, ComposePreprocessor, ConversationsPreprocessor,
@@ -316,9 +315,6 @@ def load_ms_dataset(dataset_id: str,
                     use_hf: bool = False,
                     streaming: bool = False,
                     revision: Optional[str] = None) -> Optional[DATASET_TYPE]:
-    if not use_hf:
-        from modelscope import MsDataset
-
     if subset_split_list is None or len(subset_split_list) == 0:
         return None
     dataset_list = []
@@ -338,6 +334,7 @@
             except Exception:
                 raise
         else:
+            from modelscope import MsDataset
             if is_dist() and not is_local_master():
                 force_redownload = False
             else:
@@ -429,8 +426,8 @@ def _post_preprocess(
     streaming_buffer_size = kwargs.get('streaming_buffer_size', 16384)
     if streaming_val_size > 0:
         train_dataset = train_dataset.shuffle(seed=get_seed(random_state), buffer_size=streaming_buffer_size)
-        val_dataset = dataset.take(int(streaming_val_size))
-        train_dataset = dataset.skip(int(streaming_val_size))
+        val_dataset = train_dataset.take(int(streaming_val_size))
+        train_dataset = train_dataset.skip(int(streaming_val_size))
 
     res = []
     for dataset in [train_dataset, val_dataset]:
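This hunk is the core streaming fix: the validation split is now carved out of the shuffled train_dataset rather than dataset. A minimal sketch of the same take/skip pattern on a streaming Hugging Face dataset (the toy data, seed, and sizes are illustrative assumptions, not swift's code):

    # Minimal sketch (not swift's code): splitting a streaming dataset with
    # take/skip after shuffling, mirroring the fixed lines above.
    from datasets import Dataset

    stream = Dataset.from_dict({'text': [f'sample {i}' for i in range(100)]}).to_iterable_dataset()

    streaming_val_size = 10
    streaming_buffer_size = 16384

    # Shuffle the stream, then carve the validation split out of the *shuffled* stream.
    train_dataset = stream.shuffle(seed=42, buffer_size=streaming_buffer_size)
    val_dataset = train_dataset.take(int(streaming_val_size))
    train_dataset = train_dataset.skip(int(streaming_val_size))

    print(len(list(val_dataset)))    # 10
    print(len(list(train_dataset)))  # 90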
@@ -2556,10 +2553,10 @@ def generate_example(dataset):
                     model_n, model_a = model_name[0], model_author[0]
                 else:
                     model_n, model_a = model_name[1], model_author[1]
-                    yield {
-                        'query': d['query'].replace('{{NAME}}', model_n).replace('{{AUTHOR}}', model_a),
-                        'response': d['response'].replace('{{NAME}}', model_n).replace('{{AUTHOR}}', model_a)
-                    }
+                yield {
+                    'query': d['query'].replace('{{NAME}}', model_n).replace('{{AUTHOR}}', model_a),
+                    'response': d['response'].replace('{{NAME}}', model_n).replace('{{AUTHOR}}', model_a)
+                }
 
         dataset = HfIterableDataset.from_generator(generate_example, gen_kwargs={'dataset': dataset})
     else:
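In the hunk above, the yield block appears to move out of the else branch so a record is emitted for every item in the stream, not only when the second name/author pair is chosen. A minimal sketch of the surrounding pattern, a generator wrapped by HfIterableDataset.from_generator via gen_kwargs (the toy records and the hard-coded name are assumptions; swift's generate_example also fills {{AUTHOR}} and picks between two name/author pairs):

    # Minimal sketch of the from_generator pattern used above.
    from datasets import IterableDataset as HfIterableDataset

    def generate_example(dataset):
        for d in dataset:
            # Yield for every record, matching the dedented block above.
            yield {
                'query': d['query'].replace('{{NAME}}', 'swift-demo'),
                'response': d['response'].replace('{{NAME}}', 'swift-demo'),
            }

    toy = [{'query': 'Who are you, {{NAME}}?', 'response': 'I am {{NAME}}.'}]
    stream = HfIterableDataset.from_generator(generate_example, gen_kwargs={'dataset': toy})
    print(next(iter(stream)))  # {'query': 'Who are you, swift-demo?', 'response': 'I am swift-demo.'}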
2 changes: 1 addition & 1 deletion swift/llm/utils/preprocess.py
@@ -12,10 +12,10 @@
 from .media import MediaTag
 from .template import History
 
-PreprocessFunc = Callable[[HfDataset], HfDataset]
 dataset_enable_cache = strtobool(os.environ.get('DATASET_ENABLE_CACHE', 'False'))
 
 DATASET_TYPE = Union[HfDataset, HfIterableDataset]
+PreprocessFunc = Callable[[DATASET_TYPE], DATASET_TYPE]
 
 logger = get_logger()
 
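With DATASET_TYPE covering both map-style and iterable datasets, PreprocessFunc is now typed over either flavor. A short sketch of how the widened alias reads (only the two alias lines come from the patch; the passthrough function is a hypothetical stand-in):

    from typing import Callable, Union

    from datasets import Dataset as HfDataset
    from datasets import IterableDataset as HfIterableDataset

    DATASET_TYPE = Union[HfDataset, HfIterableDataset]
    PreprocessFunc = Callable[[DATASET_TYPE], DATASET_TYPE]

    def passthrough_preprocess(dataset: DATASET_TYPE) -> DATASET_TYPE:
        # A hypothetical preprocessor: accepts either dataset flavor and returns it unchanged.
        return dataset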
2 changes: 1 addition & 1 deletion swift/llm/utils/template.py
Expand Up @@ -391,7 +391,7 @@ def add_default_tags(self, example: Dict[str, Any]) -> None:
example[media_key] = [m for m in example[media_key] if m]
num_media = len(example[media_key])
num_new_tags = num_media - num_media_tags
assert num_new_tags >= 0, (f'Number of media: {num_media}, number of media_tags: {num_media_tags}')
assert num_new_tags >= 0, f'Number of media: {num_media}, number of media_tags: {num_media_tags}'
if history:
history[0][0] = media_tag * num_new_tags + history[0][0]
else:
1 change: 0 additions & 1 deletion swift/trainers/kto_trainer.py
@@ -130,7 +130,6 @@ def train(self, *args, **kwargs) -> torch.Tensor:
     @staticmethod
     def stat_dataset(llm_dataset, is_encoder_decoder: bool = False) -> Any:
         _token_len = []
-        from datasets import Dataset as HfDataset
         from swift.utils.np_utils import stat_array
         if isinstance(llm_dataset, HfDataset):
             prompt_input_ids = llm_dataset['prompt_input_ids']
