
Commit

Remove unused test_dataset (huggingface#34516)
thisisiron authored Nov 5, 2024
1 parent 663c851 commit 45b0c76
Showing 1 changed file with 0 additions and 35 deletions.
examples/pytorch/contrastive-image-text/run_clip.py: 0 additions & 35 deletions
@@ -141,10 +141,6 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
     )
-    test_file: Optional[str] = field(
-        default=None,
-        metadata={"help": "An optional input testing data file (a jsonlines file)."},
-    )
     max_seq_length: Optional[int] = field(
         default=128,
         metadata={
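
Note: DataTrainingArguments is parsed with transformers.HfArgumentParser, which turns each dataclass field into a command-line flag, so dropping the test_file field above also drops the --test_file option. A minimal, self-contained sketch of that mechanism (illustrative only; the class is trimmed to a single field and is not the full one from run_clip.py):

    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser

    @dataclass
    class DataTrainingArguments:
        validation_file: Optional[str] = field(
            default=None,
            metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
        )

    parser = HfArgumentParser(DataTrainingArguments)
    (data_args,) = parser.parse_args_into_dataclasses()
    print(data_args.validation_file)  # set via --validation_file; no --test_file flag exists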
@@ -190,9 +186,6 @@ def __post_init__(self):
         if self.validation_file is not None:
             extension = self.validation_file.split(".")[-1]
             assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
-        if self.test_file is not None:
-            extension = self.test_file.split(".")[-1]
-            assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
 
 
 dataset_name_mapping = {
@@ -315,9 +308,6 @@ def main():
         if data_args.validation_file is not None:
             data_files["validation"] = data_args.validation_file
             extension = data_args.validation_file.split(".")[-1]
-        if data_args.test_file is not None:
-            data_files["test"] = data_args.test_file
-            extension = data_args.test_file.split(".")[-1]
         dataset = load_dataset(
             extension,
             data_files=data_files,
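
Note: with the test branch removed, the data_files dict passed to datasets.load_dataset holds at most "train" and "validation" entries, and extension (restricted to "csv" or "json" by __post_init__ above) selects the loader. A minimal sketch with hypothetical file names:

    from datasets import load_dataset

    data_files = {"train": "train.json", "validation": "val.json"}  # hypothetical paths
    extension = data_files["train"].split(".")[-1]  # "json"
    dataset = load_dataset(extension, data_files=data_files)
    print(dataset)  # DatasetDict with "train" and "validation" splits; no "test" split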
@@ -387,8 +377,6 @@ def _freeze_params(module):
         column_names = dataset["train"].column_names
     elif training_args.do_eval:
         column_names = dataset["validation"].column_names
-    elif training_args.do_predict:
-        column_names = dataset["test"].column_names
     else:
         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
         return
@@ -490,29 +478,6 @@ def filter_corrupt_images(examples):
         # Transform images on the fly as doing it on the whole dataset takes too much time.
         eval_dataset.set_transform(transform_images)
 
-    if training_args.do_predict:
-        if "test" not in dataset:
-            raise ValueError("--do_predict requires a test dataset")
-        test_dataset = dataset["test"]
-        if data_args.max_eval_samples is not None:
-            max_eval_samples = min(len(test_dataset), data_args.max_eval_samples)
-            test_dataset = test_dataset.select(range(max_eval_samples))
-
-        test_dataset = test_dataset.filter(
-            filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
-        )
-        test_dataset = test_dataset.map(
-            function=tokenize_captions,
-            batched=True,
-            num_proc=data_args.preprocessing_num_workers,
-            remove_columns=[col for col in column_names if col != image_column],
-            load_from_cache_file=not data_args.overwrite_cache,
-            desc="Running tokenizer on test dataset",
-        )
-
-        # Transform images on the fly as doing it on the whole dataset takes too much time.
-        test_dataset.set_transform(transform_images)
-
     # 8. Initialize our trainer
     trainer = Trainer(
         model=model,
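
Note: the deleted block prepared test_dataset through the same filter/map/set_transform pipeline as the eval split, but per the commit title it was unused: nothing downstream consumed it. For reference, a minimal self-contained sketch of the datasets set_transform pattern the script relies on (illustrative only, not from run_clip.py):

    from datasets import Dataset

    ds = Dataset.from_dict({"caption": ["a cat", "a dog"]})

    def upper_captions(batch):
        # Applied lazily at access time; no new dataset is materialized.
        return {"caption": [c.upper() for c in batch["caption"]]}

    ds.set_transform(upper_captions)
    print(ds[0])  # {'caption': 'A CAT'}, computed on the fly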
