Apply ruff flake8-comprehensions (huggingface#21694)
Skylion007 authored Feb 22, 2023
1 parent df06fb1 commit 5e8c8eb
Showing 230 changed files with 971 additions and 955 deletions.
14 changes: 6 additions & 8 deletions examples/flax/image-captioning/run_image_captioning_flax.py
@@ -892,14 +892,12 @@ def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = set(
[
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
]
)
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)

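This first hunk is the fix that recurs across nearly all of the Flax example scripts in this commit: a list comprehension wrapped in set() builds a throwaway list before the set, while a set comprehension builds the set directly. In ruff's flake8-comprehensions rules this is the C403 family, if I recall the code correctly. A minimal sketch with made-up parameter paths standing in for flat_params.keys():

# Hypothetical flattened-parameter paths, standing in for flat_params.keys().
param_paths = [
    ("encoder", "layer_0", "self_attn", "kernel"),
    ("encoder", "layer_0", "layer_norm", "scale"),
    ("decoder", "final_layer_norm", "bias"),
]
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]

# Before: the list comprehension is materialized, then copied into a set.
old = set(
    [path[-2:] for name in layer_norm_candidates for path in param_paths if name in "".join(path).lower()]
)

# After: the set comprehension yields the same elements without the temporary list.
new = {path[-2:] for name in layer_norm_candidates for path in param_paths if name in "".join(path).lower()}

assert old == new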
14 changes: 6 additions & 8 deletions examples/flax/language-modeling/run_bart_dlm_flax.py
@@ -756,14 +756,12 @@ def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = set(
[
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
]
)
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)

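The same decay_mask_fn edit repeats in most of the Flax scripts below, so it is only called out here. For context, the boolean pytree this function returns is what the scripts hand to the optimizer's mask argument so that biases and LayerNorm parameters are excluded from weight decay. A rough, self-contained sketch of that wiring with toy parameters (it assumes flax and optax are installed, as these examples require; the parameter names are invented):

import optax
from flax import traverse_util

def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)

# Toy parameter tree: the dense kernel should decay; the biases and the LayerNorm scale should not.
params = {
    "dense": {"kernel": [[1.0]], "bias": [0.0]},
    "layer_norm": {"scale": [1.0], "bias": [0.0]},
}

print(decay_mask_fn(params))
# {'dense': {'kernel': True, 'bias': False}, 'layer_norm': {'scale': False, 'bias': False}}

# optax applies weight decay only where the mask is True.
tx = optax.adamw(learning_rate=1e-3, weight_decay=0.01, mask=decay_mask_fn)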
14 changes: 6 additions & 8 deletions examples/flax/language-modeling/run_clm_flax.py
@@ -648,14 +648,12 @@ def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = set(
[
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
]
)
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)

14 changes: 6 additions & 8 deletions examples/flax/language-modeling/run_mlm_flax.py
@@ -679,14 +679,12 @@ def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = set(
[
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
]
)
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)

14 changes: 6 additions & 8 deletions examples/flax/language-modeling/run_t5_mlm_flax.py
@@ -791,14 +791,12 @@ def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = set(
[
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
]
)
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)

16 changes: 7 additions & 9 deletions examples/flax/question-answering/run_qa.py
@@ -333,14 +333,12 @@ def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = set(
[
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
]
)
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)

@@ -642,7 +640,7 @@ def prepare_train_features(examples):

return tokenized_examples

processed_raw_datasets = dict()
processed_raw_datasets = {}
if training_args.do_train:
if "train" not in raw_datasets:
raise ValueError("--do_train requires a train dataset")
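Alongside the comprehension rewrites, run_qa.py and several files later in the diff replace dict() constructor calls with literals, as in processed_raw_datasets above. That is ruff's C408 rule, going from memory: dict(), list(), and tuple() with no arguments, and dict() with keyword arguments, are slower and less direct than the literal spelling. A tiny sketch of both flavors:

# Before: constructor calls.
processed = dict()
metrics = dict(n_obs=10, runtime=3, seconds_per_sample=0.3)

# After: literals build the same objects directly.
processed = {}
metrics = {"n_obs": 10, "runtime": 3, "seconds_per_sample": 0.3}

assert processed == {} and metrics["n_obs"] == 10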
14 changes: 6 additions & 8 deletions examples/flax/summarization/run_summarization_flax.py
@@ -742,14 +742,12 @@ def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = set(
[
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
]
)
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)

18 changes: 8 additions & 10 deletions examples/flax/text-classification/run_flax_glue.py
@@ -229,14 +229,12 @@ def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = set(
[
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
]
)
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)

@@ -449,7 +447,7 @@ def main():
):
# Some have all caps in their config, some don't.
label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
if sorted(label_name_to_id.keys()) == sorted(label_list):
logger.info(
f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
"Using it!"
@@ ... @@
else:
logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: ",
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
"\nIgnoring the model labels as a result.",
)
elif data_args.task_name is None:
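The label-matching branch above also drops redundant list() wrappers around sorted(): sorted() already returns a new list, so the extra call only makes a second copy (ruff's C413, if I have the code right). The same cleanup appears in the checkpoint-globbing scripts below. A quick sketch:

labels = ["neutral", "entailment", "contradiction"]

# Before: sorted() already returns a list; the outer list() is a redundant copy.
old = list(sorted(labels))

# After: identical result, one allocation fewer.
new = sorted(labels)

assert old == new == ["contradiction", "entailment", "neutral"]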
14 changes: 6 additions & 8 deletions examples/flax/token-classification/run_flax_ner.py
@@ -290,14 +290,12 @@ def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = set(
[
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
]
)
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params.keys()
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)

2 changes: 1 addition & 1 deletion examples/legacy/pytorch-lightning/run_glue.py
@@ -192,7 +192,7 @@ def main():

# Optionally, predict on dev set and write to output_dir
if args.do_predict:
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True))
model = model.load_from_checkpoint(checkpoints[-1])
return trainer.test(model)

2 changes: 1 addition & 1 deletion examples/legacy/pytorch-lightning/run_ner.py
@@ -211,6 +211,6 @@ def add_model_specific_args(parser, root_dir):
# pl use this default format to create a checkpoint:
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
# /pytorch_lightning/callbacks/model_checkpoint.py#L322
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True))
model = model.load_from_checkpoint(checkpoints[-1])
trainer.test(model)
6 changes: 3 additions & 3 deletions examples/legacy/question-answering/run_squad.py
@@ -810,10 +810,10 @@ def main():
logger.info("Loading checkpoints saved during training for evaluation")
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(
checkpoints = [
os.path.dirname(c)
for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
)
]

else:
logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
@@ ... @@
# Evaluate
result = evaluate(args, model, tokenizer, prefix=global_step)

result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
results.update(result)

logger.info("Results: {}".format(results))
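run_squad.py shows the generator-expression variants of the same rule family: a generator fed to list() becomes a list comprehension and dict((key, value) for ...) becomes a dict comprehension (C400 and C402 in ruff's numbering, as far as I remember). A sketch with invented file names and metrics in place of the script's glob results and evaluate() output:

import os

weight_files = ["out/checkpoint-100/pytorch_model.bin", "out/checkpoint-200/pytorch_model.bin"]
result = {"exact_match": 81.2, "f1": 88.6}  # hypothetical metrics
global_step = "200"

# Before: generator expressions handed to the list() / dict() constructors.
old_ckpts = list(os.path.dirname(c) for c in sorted(weight_files))
old_result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())

# After: comprehensions build the same containers in one step.
new_ckpts = [os.path.dirname(c) for c in sorted(weight_files)]
new_result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}

assert old_ckpts == new_ckpts == ["out/checkpoint-100", "out/checkpoint-200"]
assert old_result == new_result == {"exact_match_200": 81.2, "f1_200": 88.6}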
2 changes: 1 addition & 1 deletion examples/legacy/run_openai_gpt.py
@@ -189,7 +189,7 @@ def tokenize_and_encode(obj):
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
elif isinstance(obj, int):
return obj
return list(tokenize_and_encode(o) for o in obj)
return [tokenize_and_encode(o) for o in obj]

logger.info("Encoding dataset...")
train_dataset = load_rocstories_dataset(args.train_dataset)
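tokenize_and_encode recurses over arbitrarily nested containers, so the only thing the changed line must preserve is that it returns a list; the comprehension form does exactly that. A toy sketch with a fake character-level encoder in place of the real GPT tokenizer the script closes over:

def tokenize_and_encode(obj, encode=lambda s: [ord(c) for c in s]):
    # encode() stands in for tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj)).
    if isinstance(obj, str):
        return encode(obj)
    elif isinstance(obj, int):
        return obj
    return [tokenize_and_encode(o, encode) for o in obj]

print(tokenize_and_encode((("hi", 3), "a")))
# [[[104, 105], 3], [97]]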
6 changes: 3 additions & 3 deletions examples/legacy/run_swag.py
@@ -696,9 +696,9 @@ def main():
checkpoints = [args.model_name_or_path]

if args.eval_all_checkpoints:
checkpoints = list(
checkpoints = [
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
)
]

logger.info("Evaluate the following checkpoints: %s", checkpoints)

@@ ... @@
# Evaluate
result = evaluate(args, model, tokenizer, prefix=global_step)

result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
results.update(result)

logger.info("Results: {}".format(results))
4 changes: 2 additions & 2 deletions examples/legacy/seq2seq/run_distributed_eval.py
@@ -111,7 +111,7 @@ def eval_data_dir(
if num_return_sequences > 1:
preds = chunks(preds, num_return_sequences) # batch size chunks, each of size num_return_seq
for i, pred in enumerate(preds):
results.append(dict(pred=pred, id=ids[i].item()))
results.append({"pred": pred, "id": ids[i].item()})
save_json(results, save_path)
return results, sampler.num_replicas

@@ -232,7 +232,7 @@ def combine_partial_results(partial_results) -> List:
records = []
for partial_result in partial_results:
records.extend(partial_result)
records = list(sorted(records, key=lambda x: x["id"]))
records = sorted(records, key=lambda x: x["id"])
preds = [x["pred"] for x in records]
return preds

2 changes: 1 addition & 1 deletion examples/legacy/seq2seq/run_eval.py
@@ -76,7 +76,7 @@ def generate_summaries_or_translations(
fout.close()
runtime = int(time.time() - start_time) # seconds
n_obs = len(examples)
return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
return {"n_obs": n_obs, "runtime": runtime, "seconds_per_sample": round(runtime / n_obs, 4)}


def datetime_now():
2 changes: 1 addition & 1 deletion examples/legacy/seq2seq/run_eval_search.py
@@ -36,7 +36,7 @@ def parse_search_arg(search):
groups = search.split()
entries = {k: vs for k, vs in (g.split("=") for g in groups)}
entry_names = list(entries.keys())
sets = [list(f"--{k} {v}" for v in vs.split(":")) for k, vs in entries.items()]
sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
matrix = [list(x) for x in itertools.product(*sets)]
return matrix, entry_names

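parse_search_arg builds the grid of flag combinations that run_eval_search sweeps over; the fix only turns the inner list(... for ...) into a list comprehension. To make the behavior concrete, here is the updated function run on a toy search string (the flag names are invented for the example):

import itertools

def parse_search_arg(search):
    groups = search.split()
    entries = {k: vs for k, vs in (g.split("=") for g in groups)}
    entry_names = list(entries.keys())
    sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
    matrix = [list(x) for x in itertools.product(*sets)]
    return matrix, entry_names

matrix, names = parse_search_arg("num_beams=2:4 length_penalty=0.8:1.0")
print(names)   # ['num_beams', 'length_penalty']
print(matrix)
# [['--num_beams 2', '--length_penalty 0.8'], ['--num_beams 2', '--length_penalty 1.0'],
#  ['--num_beams 4', '--length_penalty 0.8'], ['--num_beams 4', '--length_penalty 1.0']]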
2 changes: 1 addition & 1 deletion examples/legacy/seq2seq/utils.py
@@ -456,7 +456,7 @@ def pickle_save(obj, path):


def flatten_list(summary_ids: List[List]):
return [x for x in itertools.chain.from_iterable(summary_ids)]
return list(itertools.chain.from_iterable(summary_ids))


def save_git_info(folder_path: str) -> None:
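flatten_list is the opposite case: a comprehension that does nothing but copy its iterable element by element. A plain list() call is the clearer spelling (ruff's C416, unnecessary comprehension, if I recall the code), and the same pattern is removed from the audio, image, and CLIP example scripts further down. A sketch:

import itertools

summary_ids = [[101, 2023, 102], [101, 7592, 102]]

# Before: the comprehension adds nothing over iterating the chain directly.
old = [x for x in itertools.chain.from_iterable(summary_ids)]

# After: list() on the iterator is the idiomatic form.
new = list(itertools.chain.from_iterable(summary_ids))

assert old == new == [101, 2023, 102, 101, 7592, 102]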
(changed file path not captured)
@@ -293,7 +293,7 @@ def train_transforms(batch):
audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
)
output_batch["input_values"].append(wav)
output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
output_batch["labels"] = list(batch[data_args.label_column_name])

return output_batch

@@ ... @@
for audio in batch[data_args.audio_column_name]:
wav = audio["array"]
output_batch["input_values"].append(wav)
output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
output_batch["labels"] = list(batch[data_args.label_column_name])

return output_batch

# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = raw_datasets["train"].features[data_args.label_column_name].names
label2id, id2label = dict(), dict()
label2id, id2label = {}, {}
for i, label in enumerate(labels):
label2id[label] = str(i)
id2label[str(i)] = label
6 changes: 3 additions & 3 deletions examples/pytorch/benchmarking/plot_csv_file.py
@@ -83,7 +83,7 @@ def can_convert_to_float(string):
class Plot:
def __init__(self, args):
self.args = args
self.result_dict = defaultdict(lambda: dict(bsz=[], seq_len=[], result={}))
self.result_dict = defaultdict(lambda: {"bsz": [], "seq_len": [], "result": {}})

with open(self.args.csv_file, newline="") as csv_file:
reader = csv.DictReader(csv_file)
@@ -116,8 +116,8 @@ def plot(self):
axis.set_major_formatter(ScalarFormatter())

for model_name_idx, model_name in enumerate(self.result_dict.keys()):
batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
batch_sizes = sorted(set(self.result_dict[model_name]["bsz"]))
sequence_lengths = sorted(set(self.result_dict[model_name]["seq_len"]))
results = self.result_dict[model_name]["result"]

(x_axis_array, inner_loop_array) = (
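plot_csv_file.py combines two of the rules: the defaultdict factory now returns a dict literal, and sorted(list(set(...))) drops the inner list() because sorted() accepts any iterable, a set included (C414 in ruff's numbering, from memory). A sketch with invented batch sizes:

bsz = [8, 16, 8, 32, 16]  # hypothetical batch-size column from the benchmark CSV

# Before: set -> list -> sorted list, with an unneeded intermediate copy.
old = sorted(list(set(bsz)))

# After: sorted() consumes the set directly.
new = sorted(set(bsz))

assert old == new == [8, 16, 32]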
2 changes: 1 addition & 1 deletion examples/pytorch/contrastive-image-text/run_clip.py
@@ -397,7 +397,7 @@ def _freeze_params(module):
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
def tokenize_captions(examples):
captions = [caption for caption in examples[caption_column]]
captions = list(examples[caption_column])
text_inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True)
examples["input_ids"] = text_inputs.input_ids
examples["attention_mask"] = text_inputs.attention_mask
(changed file path not captured)
@@ -250,7 +250,7 @@ def main():
# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = dataset["train"].features["labels"].names
label2id, id2label = dict(), dict()
label2id, id2label = {}, {}
for i, label in enumerate(labels):
label2id[label] = str(i)
id2label[str(i)] = label
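This last hunk is another dict() to {} literal swap for the label maps. As an aside that goes slightly beyond the commit (which keeps the loop as is), the two mappings could equally be written as dict comprehensions, in keeping with the comprehension idioms this cleanup promotes; the class names below are invented:

labels = ["dog", "cat", "bird"]  # hypothetical class names

# Equivalent construction of the label maps as dict comprehensions.
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}

assert label2id["cat"] == "1" and id2label["2"] == "bird"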