diff --git a/mteb-zh/generate_report.py b/mteb-zh/generate_report.py index cf0d5c7..2d36d96 100644 --- a/mteb-zh/generate_report.py +++ b/mteb-zh/generate_report.py @@ -34,7 +34,7 @@ class ReportMainScore: task_mapping[TaskType.all] = {k: v for _, mapping in task_mapping for k, v in mapping.items()} -def generate_report_csv(results_dir: Path, task_type: TaskType = TaskType.classification): +def generate_report_csv(results_dir: Path, task_type: TaskType = TaskType.Classification): scores = defaultdict(list) output_file: Path = Path(f'mteb-zh-{task_type.value}.csv') diff --git a/mteb-zh/mteb_zh/tasks.py b/mteb-zh/mteb_zh/tasks.py index 9a0f3f1..d2d26bd 100644 --- a/mteb-zh/mteb_zh/tasks.py +++ b/mteb-zh/mteb_zh/tasks.py @@ -16,10 +16,10 @@ class TaskType(str, Enum): - classification = 'classification' - reranking = 'reranking' - retrieval = 'retrieval' - all = 'all' + Classification = 'Classification' + Reranking = 'Reranking' + Retrieval = 'Retrieval' + All = 'All' class TNews(AbsTaskClassification): diff --git a/mteb-zh/run_mteb_zh.py b/mteb-zh/run_mteb_zh.py index 2d93715..090dbf2 100644 --- a/mteb-zh/run_mteb_zh.py +++ b/mteb-zh/run_mteb_zh.py @@ -1,7 +1,8 @@ from pathlib import Path +from typing import Annotated import typer -from mteb import MTEB +from mteb import MTEB, AbsTask from mteb_zh.models import load_model, ModelType from mteb_zh.tasks import ( @@ -17,47 +18,32 @@ ) +default_tasks: list[AbsTask] = [ + TYQSentiment(), + TNews(), + JDIphone(), + StockComSentiment(), + GubaEastmony(), + IFlyTek(), + T2RReranking(2), + T2RRetrieval(10000), +] + + def main( - model_type: ModelType, + model_type: Annotated[ModelType, typer.Option()], model_name: str | None = None, - task_type: TaskType = TaskType.classification, + task_type: TaskType = TaskType.Classification, output_folder: Path = Path('results'), ): output_folder = Path(output_folder) model = load_model(model_type, model_name) - match task_type: - case TaskType.classification: - tasks = [ - TYQSentiment(), - TNews(), - JDIphone(), - StockComSentiment(), - GubaEastmony(), - IFlyTek(), - ] - case TaskType.reranking: - tasks = [ - T2RReranking(2), - ] - case TaskType.retrieval: - tasks = [ - T2RRetrieval(10000), - ] - case TaskType.all: - tasks = [ - TYQSentiment(), - TNews(), - JDIphone(), - StockComSentiment(), - GubaEastmony(), - IFlyTek(), - T2RReranking(2), - T2RRetrieval(100000), - ] - case _: - raise ValueError(f'Unknown task type: {task_type}') - print(tasks) + if task_type is TaskType.All: + tasks = default_tasks + else: + tasks = [task for task in default_tasks if task.description['type'] == task_type.value] + evaluation = MTEB(tasks=tasks) evaluation.run(model, output_folder=str(output_folder / model_type.value))