Estimate costs for a given run suite (stanford-crfm#1480)
teetone authored Apr 17, 2023
1 parent 7f59fee commit 769f6e0
Showing 4 changed files with 98 additions and 11 deletions.
83 changes: 83 additions & 0 deletions scripts/estimate_cost.py
@@ -0,0 +1,83 @@
import argparse
import json
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict

"""
Given a run suite directory, outputs metrics needed to estimate the cost of running.

Usage:

    python3 scripts/estimate_cost.py benchmark_output/runs/<Name of the run suite>
"""


@dataclass
class ModelCost:
    total_num_prompt_tokens: int = 0

    total_max_num_completion_tokens: int = 0

    @property
    def total_tokens(self) -> int:
        return self.total_num_prompt_tokens + self.total_max_num_completion_tokens

    def add_prompt_tokens(self, num_tokens: int):
        self.total_num_prompt_tokens += num_tokens

    def add_num_completion_tokens(self, num_tokens: int):
        self.total_max_num_completion_tokens += num_tokens


class CostCalculator:
    def __init__(self, run_suite_path: str):
        self._run_suite_path: str = run_suite_path

    def aggregate(self) -> Dict[str, ModelCost]:
        """Sums up the estimated number of tokens."""
        models_to_costs: Dict[str, ModelCost] = defaultdict(ModelCost)

        for run_dir in os.listdir(self._run_suite_path):
            run_path: str = os.path.join(self._run_suite_path, run_dir)

            if not os.path.isdir(run_path):
                continue

            run_spec_path: str = os.path.join(run_path, "run_spec.json")
            if not os.path.isfile(run_spec_path):
                continue

            # Extract the model name
            with open(run_spec_path) as f:
                run_spec = json.load(f)
                model: str = run_spec["adapter_spec"]["model"]

            metrics_path: str = os.path.join(run_path, "stats.json")
            with open(metrics_path) as f:
                metrics = json.load(f)

            for metric in metrics:
                cost: ModelCost = models_to_costs[model]
                metric_name: str = metric["name"]["name"]
                if metric_name == "num_prompt_tokens":
                    cost.add_prompt_tokens(metric["sum"])
                elif metric_name == "max_num_completion_tokens":
                    cost.add_num_completion_tokens(metric["sum"])

        return models_to_costs


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("run_suite_path", type=str, help="Path to runs folder")
    args = parser.parse_args()

    calculator = CostCalculator(args.run_suite_path)
    model_costs: Dict[str, ModelCost] = calculator.aggregate()
    for model_name, model_cost in model_costs.items():
        print(
            f"{model_name}: Total prompt tokens={model_cost.total_num_prompt_tokens} + "
            f"Total max completion tokens={model_cost.total_max_num_completion_tokens} = "
            f"{model_cost.total_tokens}"
        )
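For reference, here is a minimal sketch of the inputs the script reads, derived only from the fields it accesses above (adapter_spec.model in run_spec.json; name.name and sum in stats.json). The run directory name, model name, and token counts are made up for illustration:

import json
import os
import tempfile

# Create a throwaway run suite with a single run directory (names are illustrative).
suite_path = tempfile.mkdtemp()
run_path = os.path.join(suite_path, "mmlu:subject=anatomy,model=openai_davinci")
os.makedirs(run_path)

# run_spec.json: estimate_cost.py only reads adapter_spec.model.
with open(os.path.join(run_path, "run_spec.json"), "w") as f:
    json.dump({"adapter_spec": {"model": "openai/davinci"}}, f)

# stats.json: a list of stats; the script only reads name.name and sum.
with open(os.path.join(run_path, "stats.json"), "w") as f:
    json.dump(
        [
            {"name": {"name": "num_prompt_tokens"}, "sum": 12345},
            {"name": {"name": "max_num_completion_tokens"}, "sum": 6789},
        ],
        f,
    )

# Running python3 scripts/estimate_cost.py <suite_path> on this layout should print:
# openai/davinci: Total prompt tokens=12345 + Total max completion tokens=6789 = 19134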
2 changes: 1 addition & 1 deletion src/helm/benchmark/__init__.py
@@ -70,7 +70,7 @@
 from .metrics import ranking_metrics  # noqa
 from .metrics import summarization_metrics  # noqa
 from .metrics import toxicity_metrics  # noqa
-from .metrics import tokens_metric  # noqa
+from .metrics import dry_run_metrics  # noqa
 from .metrics import machine_translation_metrics  # noqa
 from .metrics import summarization_critique_metrics  # noqa

18 changes: 11 additions & 7 deletions src/helm/benchmark/metrics/dry_run_metrics.py
@@ -6,6 +6,8 @@
 from helm.benchmark.adaptation.scenario_state import ScenarioState
 from helm.benchmark.adaptation.request_state import RequestState
 from helm.benchmark.metrics.statistic import Stat, merge_stat
+from helm.benchmark.window_services.window_service import WindowService
+from helm.benchmark.window_services.window_service_factory import WindowServiceFactory
 from .metric import Metric, MetricResult, PerInstanceStats
 from .metric_name import MetricName
 from .metric_service import MetricService
@@ -33,22 +35,24 @@ def process(self, request_state: RequestState) -> List[Stat]:
 
         # Maximum number of tokens in the completions
         # This is an overestimate of the actual number of output tokens since sequences can early terminate
-        # TODO: rename this to max_num_completion_tokens
-        stats.append(Stat(MetricName("max_num_output_tokens")).add(request.num_completions * request.max_tokens))
+        stats.append(Stat(MetricName("max_num_completion_tokens")).add(request.num_completions * request.max_tokens))
 
+        # Get number of tokens in the prompt
+        tokenizer: WindowService = WindowServiceFactory.get_window_service(request.model, self.metric_service)
+        num_prompt_tokens: int = tokenizer.get_num_tokens(request.prompt)
+        stats.append(Stat(MetricName("num_prompt_tokens")).add(num_prompt_tokens))
 
         return stats
 
 
-class TokensMetric(Metric):
-    """
-    Estimates the total number of tokens that will be used based on the requests.
-    """
+class DryRunMetric(Metric):
+    """Metrics for dry run."""
 
     def __init__(self):
         self.token_cost_estimator = AutoTokenCostEstimator()
 
     def __repr__(self):
-        return "TokensMetric()"
+        return "DryRunMetric"
 
     def evaluate(
         self,
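To make the overestimate in the max_num_completion_tokens stat above concrete: it records num_completions * max_tokens per request, an upper bound that ignores early termination. The numbers below are hypothetical:

# Hypothetical request settings, not taken from the diff.
num_completions = 5
max_tokens = 100

# The dry-run stat counts the worst case: every completion uses its full token budget.
max_num_completion_tokens = num_completions * max_tokens
print(max_num_completion_tokens)  # 500, even if every completion stops early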
6 changes: 3 additions & 3 deletions src/helm/benchmark/runner.py
@@ -19,10 +19,10 @@
 from .adaptation.adapter_spec import AdapterSpec
 from .data_preprocessor import DataPreprocessor
 from .executor import ExecutionSpec, Executor
+from .metrics.dry_run_metrics import DryRunMetric
 from .metrics.metric_name import MetricName
 from .metrics.metric_service import MetricService
 from .metrics.metric import Metric, MetricSpec, MetricResult, PerInstanceStats, create_metric, Stat
-from .metrics.tokens_metric import TokensMetric
 from .window_services.tokenizer_service import TokenizerService


@@ -166,8 +166,8 @@ def run_one(self, run_spec: RunSpec):
         # When performing a dry run, only estimate the number of tokens instead
         # of calculating the metrics.
         metrics: List[Metric] = (
-            [] if self.dry_run else [create_metric(metric_spec) for metric_spec in run_spec.metric_specs]
-        ) + [TokensMetric()]
+            [DryRunMetric()] if self.dry_run else [create_metric(metric_spec) for metric_spec in run_spec.metric_specs]
+        )
         stats: List[Stat] = []
         per_instance_stats: List[PerInstanceStats] = []
         with htrack_block(f"{len(metrics)} metrics"):
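For readers skimming the diff, here is a schematic restatement of the runner change using stand-in classes; these are not the real HELM objects, and the metric spec names are illustrative:

from typing import List


class TokensMetric:
    """Stand-in for the old always-on token-counting metric."""


class DryRunMetric:
    """Stand-in for the new dry-run-only metric."""


def create_metric(spec: str) -> str:
    # Stand-in for HELM's create_metric; just echoes the spec name.
    return spec


metric_specs: List[str] = ["exact_match", "toxicity"]  # illustrative only


def metrics_before(dry_run: bool) -> list:
    # Old behavior: TokensMetric was appended to every run, dry run or not.
    return ([] if dry_run else [create_metric(s) for s in metric_specs]) + [TokensMetric()]


def metrics_after(dry_run: bool) -> list:
    # New behavior: a dry run computes only DryRunMetric; a normal run computes only
    # the metrics declared in the run spec.
    return [DryRunMetric()] if dry_run else [create_metric(s) for s in metric_specs]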
