diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py index 326508e68..3f8c22193 100644 --- a/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py @@ -4,7 +4,6 @@ from modelscope.msdatasets.ms_dataset import MsDataset from modelscope.utils.constant import DownloadMode -from typing import Optional class ASRDataset(MsDataset): @@ -37,19 +36,18 @@ def load(cls, train_set='train', dev_set='validation', download_mode: Optional[DownloadMode] = None): - if download_mode is not None: - ds_dict = MsDataset.load( - dataset_name=dataset_name, namespace=namespace, download_mode=download_mode) - return ds_dict - else: - if os.path.exists(dataset_name): + if os.path.exists(dataset_name): + if download_mode != DownloadMode.FORCE_REDOWNLOAD: data_dir = dataset_name ds_dict = {} ds_dict['train'] = cls.load_core(data_dir, train_set) ds_dict['validation'] = cls.load_core(data_dir, dev_set) ds_dict['raw_data_dir'] = data_dir - return ds_dict else: ds_dict = MsDataset.load( - dataset_name=dataset_name, namespace=namespace) - return ds_dict + dataset_name=dataset_name, namespace=namespace, download_mode=download_mode) + else: + ds_dict = MsDataset.load( + dataset_name=dataset_name, namespace=namespace, download_mode=download_mode) + return ds_dict +