From 67222e25ad1e0eda7184aea77f0e0feec60652a9 Mon Sep 17 00:00:00 2001
From: Yian Zhang <42804517+YianZhang@users.noreply.github.com>
Date: Sat, 4 Feb 2023 15:06:43 +0800
Subject: [PATCH] Machine Translation, WMT14, and BLEU (#1329)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Implemented the first machine translation scenario, WMT14, along with the
adapter spec, metric, and run spec it relies on.

Fixes #1303

### Usage
`venv/bin/helm-run -r wmt_14:model=openai/davinci,language_pair=fr-en,max_train_instances=1 -m {MAX_EVAL_INSTANCES} -t {NUM_TRAIN_TRIALS} --suite 1`

### Results

| Translation Direction | Ours (mean) | Ours (max) | GPT-3 Paper |
| ------------- | ------------- | ------------- | ------------- |
| fr->en | 34.0 | 37.1 | 33.7 |
| en->fr | 25.5 | 31.1 | 28.3 |

- The metric is corpus-level BLEU, computed with sacrebleu.
- The setting is 1-shot in-context learning, using the same prompt format as the GPT-3 paper.
- Our results are the mean/max over 5 training trials on 100 test examples; it is unclear how many trials OpenAI runs or how many test examples they evaluate on. They likely use the whole test set, which includes ~3,000 examples.
- OpenAI uses beam search for decoding, while we use greedy decoding.

---------

Co-authored-by: 张逸安 Yian
---
 src/helm/benchmark/__init__.py               |  2 +
 .../metrics/machine_translation_metrics.py   | 36 +++++++
 src/helm/benchmark/run_specs.py              | 59 ++++++++++++
 .../benchmark/scenarios/wmt_14_scenario.py   | 96 +++++++++++++++++++
 4 files changed, 193 insertions(+)
 create mode 100644 src/helm/benchmark/metrics/machine_translation_metrics.py
 create mode 100644 src/helm/benchmark/scenarios/wmt_14_scenario.py

diff --git a/src/helm/benchmark/__init__.py b/src/helm/benchmark/__init__.py
index 5e694b0d9f..e1cff2bb01 100644
--- a/src/helm/benchmark/__init__.py
+++ b/src/helm/benchmark/__init__.py
@@ -43,6 +43,7 @@
 from .scenarios import entity_data_imputation_scenario  # noqa
 from .scenarios import big_bench_scenario  # noqa
 from .scenarios import pubmed_qa_scenario  # noqa
+from .scenarios import wmt_14_scenario  # noqa
 
 # Metrics
 from .metrics import basic_metrics  # noqa
@@ -56,6 +57,7 @@
 from .metrics import summarization_metrics  # noqa
 from .metrics import toxicity_metrics  # noqa
 from .metrics import tokens_metric  # noqa
+from .metrics import machine_translation_metrics  # noqa
 
 # Perturbations for data augmentation
 from .augmentations.extra_space_perturbation import ExtraSpacePerturbation  # noqa
diff --git a/src/helm/benchmark/metrics/machine_translation_metrics.py b/src/helm/benchmark/metrics/machine_translation_metrics.py
new file mode 100644
index 0000000000..82e8867703
--- /dev/null
+++ b/src/helm/benchmark/metrics/machine_translation_metrics.py
@@ -0,0 +1,36 @@
+from typing import List
+from sacrebleu import BLEU
+
+from helm.benchmark.adaptation.request_state import RequestState
+from .metric import Metric
+from .metric_name import MetricName
+from .statistic import Stat
+
+
+class MachineTranslationMetric(Metric):
+    """
+    Compute the BLEU score for Machine Translation scenarios. The implementation is based on sacrebleu.
+    """
+
+    def evaluate_instances(self, request_states: List[RequestState]) -> List[Stat]:
+        """
+        Compute the corpus-level metric based on all request_states.
+        """
+
+        bleu = BLEU()
+
+        refs: List[List[str]] = [[]]
+        sys: List = []
+        for request_state in request_states:
+            # Assume there is one reference per instance. TODO: Support multiple references after adding more scenarios.
+            num_references: int = len(request_state.instance.references)
+            if num_references != 1:
+                raise ValueError(f"This instance has {num_references} references, but we currently only support one.")
+            # Usually there is only one completion for each instance.
+            assert request_state.result is not None
+            if len(request_state.result.completions) != 1:
+                raise ValueError("Each request result should have exactly one completion.")
+            sys.append(request_state.result.completions[0].text)
+            refs[0].append(request_state.instance.references[0].output.text)
+        bleu_score = bleu.corpus_score(sys, refs).score
+        return [Stat(MetricName("bleu")).add(bleu_score)]
diff --git a/src/helm/benchmark/run_specs.py b/src/helm/benchmark/run_specs.py
index 4b2cfb449c..ade4378942 100644
--- a/src/helm/benchmark/run_specs.py
+++ b/src/helm/benchmark/run_specs.py
@@ -314,6 +314,27 @@ def get_summarization_adapter_spec(num_sents: int, **kwargs) -> AdapterSpec:
     )
 
 
+def get_machine_translation_adapter_spec(
+    source_language: str, target_language: str, max_train_instances: int, **kwargs
+) -> AdapterSpec:
+    """
+    Used for machine translation.
+    """
+    return AdapterSpec(
+        method=ADAPT_GENERATION,
+        instructions=f"Translate {source_language} to {target_language}:",
+        input_prefix="",
+        input_suffix=" = ",
+        output_prefix="",
+        output_suffix="\n",
+        max_train_instances=max_train_instances,
+        num_outputs=1,
+        stop_sequences=["\n\n"],
+        temperature=0.0,
+        **kwargs,
+    )
+
+
 ############################################################
 # Examples of scenario and adapter specs
 
@@ -489,6 +510,12 @@ def get_code_metric_specs(dataset: str, timeout: float) -> List[MetricSpec]:
     return [MetricSpec(class_name="helm.benchmark.code_metrics.APPSMetric", args=args)]
 
 
+def get_machine_translation_metric_specs() -> List[MetricSpec]:
+    return [
+        MetricSpec(class_name="helm.benchmark.machine_translation_metrics.MachineTranslationMetric", args={})
+    ] + get_basic_metric_specs([])
+
+
 ############################################################
 # Run specs
 
@@ -1634,6 +1661,37 @@ def get_lex_glue_spec(subset: str) -> RunSpec:
     )
 
 
+def get_wmt_14_spec(language_pair: str, max_train_instances: int = 1) -> RunSpec:
+    FULL_LANGUAGE_NAMES = {
+        "cs": "Czech",
+        "de": "German",
+        "fr": "French",
+        "hi": "Hindi",
+        "ru": "Russian",
+        "en": "English",
+    }
+    source_language, target_language = language_pair.split("-")
+
+    scenario_spec = ScenarioSpec(
+        class_name="helm.benchmark.scenarios.wmt_14_scenario.WMT14Scenario",
+        args={"source_language": source_language, "target_language": target_language},
+    )
+
+    adapter_spec = get_machine_translation_adapter_spec(
+        source_language=FULL_LANGUAGE_NAMES[source_language],
+        target_language=FULL_LANGUAGE_NAMES[target_language],
+        max_train_instances=max_train_instances,
+    )
+
+    return RunSpec(
+        name=f"wmt_14:language_pair={language_pair}",
+        scenario_spec=scenario_spec,
+        adapter_spec=adapter_spec,
+        metric_specs=get_machine_translation_metric_specs(),
+        groups=["wmt_14"],
+    )
+
+
 ############################################################
 
 CANONICAL_RUN_SPEC_FUNCS: Dict[str, Callable[..., RunSpec]] = {
@@ -1683,6 +1741,7 @@ def get_lex_glue_spec(subset: str) -> RunSpec:
     "pubmed_qa": get_pubmed_qa_spec,
     "lextreme": get_lextreme_spec,
     "lex_glue": get_lex_glue_spec,
+    "wmt_14": get_wmt_14_spec,
 }
 
 
diff --git a/src/helm/benchmark/scenarios/wmt_14_scenario.py b/src/helm/benchmark/scenarios/wmt_14_scenario.py
new file mode 100644
index 0000000000..6d6e38d7b6
--- /dev/null
+++ b/src/helm/benchmark/scenarios/wmt_14_scenario.py
@@ -0,0 +1,96 @@
+from typing import List, Any
+from datasets import load_dataset
+from helm.common.hierarchical_logger import hlog
+from .scenario import Scenario, Instance, Reference, TRAIN_SPLIT, VALID_SPLIT, TEST_SPLIT, CORRECT_TAG, Input, Output
+
+
+MAX_TRAIN_INSTANCES = 20_000  # This is arbitrary, but 20,000 training examples should be enough.
+
+
+class WMT14Scenario(Scenario):
+    """
+    The 2014 Workshop on Statistical Machine Translation:
+    https://aclanthology.org/W14-3302.pdf
+
+    The scenario consists of 5 subsets, each of which is a parallel corpus between English and another language. The
+    non-English languages include Czech, German, French, Hindi, and Russian.
+
+    For each language pair, the validation and test sets each include around 3,000 examples, while the training set is
+    usually much larger. We therefore randomly downsample the training set to speed up data processing.
+
+    Task prompt structure:
+
+        Translate {source_language} to {target_language}:
+        {source sentence} = {target sentence}
+
+    Example from WMT14 Fr-En:
+
+        Source: Assemblée générale
+        Reference: General Assembly
+    """
+
+    name = "WMT_14"
+    description = "Scenario for the 2014 Workshop on Statistical Machine Translation"
+    tags = ["machine_translation"]
+
+    def __init__(self, source_language: str, target_language: str):
+        super().__init__()
+        valid_languages = set(["cs", "de", "fr", "hi", "ru", "en"])
+        self.source_language = source_language
+        self.target_language = target_language
+        if self.source_language not in valid_languages or self.target_language not in valid_languages:
+            raise ValueError("WMT14 only includes the following languages: cs, de, fr, hi, ru, en.")
+        if self.source_language == self.target_language:
+            raise ValueError("The source language and the target language should be different.")
+        if self.source_language != "en" and self.target_language != "en":
+            raise ValueError("One of the languages should be English.")
+
+    def _deduplicate(self, dataset: List):
+        """
+        Remove instances whose target-language sentence has already been seen, keeping the first occurrence.
+        """
+
+        deduplicated_dataset = []
+        seen_labels = set()
+        for example in dataset:
+            if example[self.target_language] not in seen_labels:
+                seen_labels.add(example[self.target_language])
+                deduplicated_dataset.append(example)
+        return deduplicated_dataset
+
+    def get_instances(self) -> List[Instance]:
+        hlog("Loading the HuggingFace dataset. The first time could take several minutes.")
+        subset_name = f"{self.source_language if self.source_language!='en' else self.target_language}-en"
+        hf_dataset: Any = load_dataset("wmt14", subset_name)
+        splits = {"train": TRAIN_SPLIT, "validation": VALID_SPLIT, "test": TEST_SPLIT}
+
+        instances: List[Instance] = []
+        hlog("Generating instances")
+        # The training sets are too large, so we only take a random subset of each.
+        hf_dataset["train"] = hf_dataset["train"].shuffle(seed=42)[:MAX_TRAIN_INSTANCES]
+        hf_dataset["train"]["translation"] = self._deduplicate(hf_dataset["train"]["translation"])
+        for example in hf_dataset["train"]["translation"]:
+            source_sentence: str = example[self.source_language]
+            target_sentence: str = example[self.target_language]
+            instances.append(
+                Instance(
+                    input=Input(text=source_sentence),
+                    references=[Reference(Output(text=target_sentence), tags=[CORRECT_TAG])],
+                    split="train",
+                )
+            )
+
+        # No special handling needed for validation or test.
+        for split_name in ["validation", "test"]:
+            split = splits[split_name]
+            for example in hf_dataset[split_name]:
+                source_sentence = example["translation"][self.source_language]
+                target_sentence = example["translation"][self.target_language]
+                instances.append(
+                    Instance(
+                        input=Input(text=source_sentence),
+                        references=[Reference(Output(text=target_sentence), tags=[CORRECT_TAG])],
+                        split=split,
+                    )
+                )
+        return instances
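
For reviewers unfamiliar with sacrebleu, here is a minimal standalone sketch (not part of the patch) of the corpus-level call that `MachineTranslationMetric.evaluate_instances` boils down to. The sentences are invented and exist only to show the expected argument shapes: a flat list of system outputs, and a list of reference streams (a single stream here, matching the one-reference-per-instance assumption above).

```python
# Minimal sketch of the sacrebleu call used by MachineTranslationMetric.
# The sentences below are made up for illustration only.
from sacrebleu import BLEU

# One system output per test instance.
hypotheses = [
    "The General Assembly met on Tuesday.",
    "The committee adopted the resolution.",
]
# One reference stream; references[0][i] is the reference for hypotheses[i].
references = [
    [
        "The General Assembly convened on Tuesday.",
        "The committee adopted the resolution.",
    ]
]

bleu = BLEU()
result = bleu.corpus_score(hypotheses, references)
print(f"BLEU = {result.score:.1f}")  # Corpus-level score, comparable to the table above.
```

Because BLEU is a corpus-level statistic, the metric overrides `evaluate_instances` and scores all completions in one call rather than averaging per-instance scores.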