Skip to content

Commit

Permalink
✨feat:更新脚本调用方式
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyuxin committed Jun 15, 2023
1 parent f39e637 commit 7009154
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 40 deletions.
2 changes: 1 addition & 1 deletion mteb-zh/generate_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ReportMainScore:
task_mapping[TaskType.all] = {k: v for _, mapping in task_mapping for k, v in mapping.items()}


def generate_report_csv(results_dir: Path, task_type: TaskType = TaskType.classification):
def generate_report_csv(results_dir: Path, task_type: TaskType = TaskType.Classification):
scores = defaultdict(list)
output_file: Path = Path(f'mteb-zh-{task_type.value}.csv')

Expand Down
8 changes: 4 additions & 4 deletions mteb-zh/mteb_zh/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@


class TaskType(str, Enum):
classification = 'classification'
reranking = 'reranking'
retrieval = 'retrieval'
all = 'all'
Classification = 'Classification'
Reranking = 'Reranking'
Retrieval = 'Retrieval'
All = 'All'


class TNews(AbsTaskClassification):
Expand Down
56 changes: 21 additions & 35 deletions mteb-zh/run_mteb_zh.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pathlib import Path
from typing import Annotated

import typer
from mteb import MTEB
from mteb import MTEB, AbsTask

from mteb_zh.models import load_model, ModelType
from mteb_zh.tasks import (
Expand All @@ -17,47 +18,32 @@
)


default_tasks: list[AbsTask] = [
TYQSentiment(),
TNews(),
JDIphone(),
StockComSentiment(),
GubaEastmony(),
IFlyTek(),
T2RReranking(2),
T2RRetrieval(10000),
]


def main(
model_type: ModelType,
model_type: Annotated[ModelType, typer.Option()],
model_name: str | None = None,
task_type: TaskType = TaskType.classification,
task_type: TaskType = TaskType.Classification,
output_folder: Path = Path('results'),
):
output_folder = Path(output_folder)
model = load_model(model_type, model_name)

match task_type:
case TaskType.classification:
tasks = [
TYQSentiment(),
TNews(),
JDIphone(),
StockComSentiment(),
GubaEastmony(),
IFlyTek(),
]
case TaskType.reranking:
tasks = [
T2RReranking(2),
]
case TaskType.retrieval:
tasks = [
T2RRetrieval(10000),
]
case TaskType.all:
tasks = [
TYQSentiment(),
TNews(),
JDIphone(),
StockComSentiment(),
GubaEastmony(),
IFlyTek(),
T2RReranking(2),
T2RRetrieval(100000),
]
case _:
raise ValueError(f'Unknown task type: {task_type}')
print(tasks)
if task_type is TaskType.All:
tasks = default_tasks
else:
tasks = [task for task in default_tasks if task.description['type'] == task_type.value]

evaluation = MTEB(tasks=tasks)
evaluation.run(model, output_folder=str(output_folder / model_type.value))

Expand Down

0 comments on commit 7009154

Please sign in to comment.