use token ids in vllm #153

Merged on Jan 6, 2025 (30 commits)

Changes from all commits

Commits
348ca2f  use token ids (AlexPiche, Dec 21, 2024)
ce463f7  skip empty tokens (AlexPiche, Dec 21, 2024)
7bc2e83  fix attention_mask (AlexPiche, Dec 21, 2024)
b591885  fix try except (AlexPiche, Dec 21, 2024)
582a0e7  use complete (AlexPiche, Dec 22, 2024)
51107e9  revert to chat complete (AlexPiche, Dec 22, 2024)
06876e2  rm rm_leading_white_space (AlexPiche, Dec 22, 2024)
210a933  no max batch tokens (AlexPiche, Dec 22, 2024)
03325c8  fix typo (AlexPiche, Dec 22, 2024)
cda2e18  do not encode empty token (AlexPiche, Dec 22, 2024)
0653aeb  128_max_num_batched_tokens (AlexPiche, Dec 22, 2024)
2ec847f  128 max tokens (AlexPiche, Dec 22, 2024)
d58158f  max num seqs (AlexPiche, Dec 22, 2024)
7963a03  --max-num-seqs: 64 (AlexPiche, Dec 22, 2024)
f4a781d  max num seqs 128 (AlexPiche, Dec 23, 2024)
dbc0c0b  256 max num tokens (AlexPiche, Dec 23, 2024)
f74016d  easier log prob (AlexPiche, Dec 23, 2024)
91a1886  64 max num seqs (AlexPiche, Dec 23, 2024)
aa44c4c  no chuncked prefill (AlexPiche, Dec 23, 2024)
769dc64  baseline (AlexPiche, Dec 23, 2024)
51cf53c  max num seqs 64 (AlexPiche, Dec 23, 2024)
3acd833  clean up (AlexPiche, Dec 23, 2024)
97757b1  max num batched tokens 256 (AlexPiche, Dec 24, 2024)
3808056  clean vllm_args (AlexPiche, Dec 27, 2024)
7b1290b  add assert and logger info (AlexPiche, Dec 27, 2024)
7b021f5  Merge remote-tracking branch 'origin/main' into vllm_token_ids (AlexPiche, Dec 27, 2024)
ed901c2  check if entry input ids is empty (AlexPiche, Dec 27, 2024)
1a1ae2c  fix test (AlexPiche, Dec 27, 2024)
2dabb13  simplify finetune preprocessing code (AlexPiche, Jan 6, 2025)
c3e39db  Merge remote-tracking branch 'origin/main' into vllm_token_ids (AlexPiche, Jan 6, 2025)
86 changes: 27 additions & 59 deletions examples/rl_gsm8k/orchestrate_rl.py
@@ -20,6 +20,7 @@
from tqdm import tqdm

import wandb

wandb.require("core")
from .cot_math_agent import (
CoTMathAgent,
@@ -43,6 +44,7 @@
from tapeagents.batch import batch_main_loop
from tapeagents.core import LLMOutputParsingFailureAction, StepMetadata, TrainingText
from tapeagents.finetune.logging_ import flatten_dict_config, init_wandb
from tapeagents.finetune.data import MASKED_TOKEN_ID
from tapeagents.llms import TrainableLLM
from tapeagents.observe import LLMCall, SQLiteWriterThread, retrieve_all_llm_calls
from tapeagents.orchestrator import main_loop
@@ -52,8 +54,13 @@

def annotate_traces_with_ref_logprobs(agent: CoTMathAgent, trace: TrainingText, strict: bool) -> TrainingText | None:
try:
ref_logprobs = agent.llm.get_logprobs(trace.prompt_text, trace.output_text) # type: ignore
prompt_token_ids, completion_token_ids = (
trace.input_ids[: -len(trace.logprobs)],
trace.input_ids[-len(trace.logprobs) :],
)
ref_logprobs = agent.llm.get_logprobs(prompt_token_ids, completion_token_ids) # type: ignore
trace.ref_logprobs = [c["logprob"] for c in ref_logprobs["content"]]
assert len(trace.ref_logprobs) == len(trace.logprobs), f"{len(trace.ref_logprobs)} != {len(trace.logprobs)}"
return trace
except Exception as e:
logger.error(f"Failed to get ref logprobs: {e}")
@@ -111,7 +118,6 @@ def extract_tape_training_samples(
- Dictionary with statistics (reward, steps, success, no_errors)
"""
discarded = []
compute_logprobs = []
tape_prompt_tokens = 0
tape_output_tokens = 0
match cfg.dataset_name:
@@ -155,49 +161,21 @@
llm_call = step.metadata.other["llm_call"]
trace = agent.llm.make_training_text(llm_call.prompt, llm_call.output)

hf_tokens = get_tokens_from_hf_tokenizer(agent.llm.tokenizer, llm_call.prompt, llm_call.output)

logprobs = [c["logprob"] for c in llm_call.logprobs]
vllm_tokens = [c["token"] for c in llm_call.logprobs]

# Huggingface tokenizer for Gemma2B adds an extra newline at the end of the chat template.
# Try to detect this and fix.
if len(vllm_tokens) == len(hf_tokens) - 1 and vllm_tokens == hf_tokens[:-1] and hf_tokens[-1] == "\n":
# The last token is a newline, add it to the vLLM tokens
vllm_tokens.append("\n")
logprobs.append(-20.0)

# Note: tokens produced during generation are not always the same as the tokens produced on the full sequence
if vllm_tokens != hf_tokens:
# the online vLLM tokenizer does not agree with the HF tokenizer
try:
new_logprobs_dict = agent.llm.get_logprobs(trace.prompt_text, trace.output_text) # type: ignore
new_logprobs = [c["logprob"] for c in new_logprobs_dict["content"]]
new_vllm_tokens = [c["token"] for c in new_logprobs_dict["content"]]
assert len(new_vllm_tokens) == len(hf_tokens), "Token mismatch"
logprobs = new_logprobs
compute_logprobs.append(1)
except Exception as e:
logger.error(f"Failed to get logprobs: {e}")
discarded.append(1)
continue
else:
compute_logprobs.append(0)
input_ids = [lp["token_id"] for lp in llm_call.logprobs]
labels = [lp["token_id"] for lp in llm_call.logprobs if lp["generated"]]
# MASKED_TOKEN_ID is -100 and is the default "ignore_index" in nn.CrossEntropyLoss,
# see https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
labels = [MASKED_TOKEN_ID] * (len(input_ids) - len(labels)) + labels

trace.input_ids = input_ids
trace.labels = labels

trace.reward = reward
trace.logprobs = logprobs
trace.logprobs = [lp["logprob"] for lp in llm_call.logprobs if lp["generated"]]
trace.group_id = new_tape.metadata.parent_id
tape_prompt_tokens += llm_call.prompt_length_tokens
tape_output_tokens += llm_call.output_length_tokens
if (
len(trace.logprobs) == llm_call.output_length_tokens
and (llm_call.prompt_length_tokens + llm_call.output_length_tokens) < cfg.finetune.seq_length
):
training_samples.append(trace)
discarded.append(0)
else:
logger.debug(f"Discarding trace: {trace.prompt_text} {trace.output_text}")
discarded.append(1)
training_samples.append(trace)

tape_stats = {
"reward": reward,
@@ -207,8 +185,6 @@
"discarded": np.mean(discarded) if discarded else 0,
"prompt_tokens": tape_prompt_tokens,
"output_tokens": tape_output_tokens,

"compute_logprobs": np.mean(compute_logprobs) if compute_logprobs else 0,
}
return new_tape, training_samples, tape_stats

@@ -246,7 +222,6 @@ def generate_training_data(
no_errors_stats = defaultdict(list)
success_stats = defaultdict(list)
discarded_stats = defaultdict(list)
compute_logprobs_stats = defaultdict(list)
training_samples: List[TrainingText] = []

logger.info(f"Starting {cfg.dataset_name} {split_name} main loop")
@@ -282,7 +257,6 @@ def generate_and_extract_tape_training_samples(
success_stats[new_tape.metadata.parent_id].append(tape_stats["success"])
no_errors_stats[new_tape.metadata.parent_id].append(tape_stats["no_error"])
discarded_stats[new_tape.metadata.parent_id].append(tape_stats["discarded"])
compute_logprobs_stats[new_tape.metadata.parent_id].append(tape_stats["compute_logprobs"])
prompt_tokens += tape_stats["prompt_tokens"]
output_tokens += tape_stats["output_tokens"]

@@ -303,7 +277,6 @@
f"execution_time/{split_name}_make_data": end_make_data - start_make_data,
f"execution_time/{split_name}_tapes_made_per_second": len(new_tapes) / (end_make_data - start_make_data),
f"{split_name}_discarded": np.mean([np.mean(v) for v in discarded_stats.values()]),
f"{split_name}_compute_logprobs": np.mean([np.mean(v) for v in compute_logprobs_stats.values()]),
f"{split_name}_prompt_tokens": prompt_tokens,
f"{split_name}_output_tokens": output_tokens,
},
@@ -351,10 +324,6 @@ def main(cfg: DictConfig):
conf_dir = exp_path / "conf"
os.makedirs(conf_dir, exist_ok=True)
finetune_path = exp_path / "finetune"
remove_leading_white_space = True if "deepseek" in cfg.model_path else False
if remove_leading_white_space:
# vLLM sometimes generate a leading white space https://github.com/vllm-project/vllm/issues/3935
logger.info("Removing leading white space from the model. This is necessary for DeepSeek models")

while state["iteration"] < cfg.max_iterations:
logger.info(f"Starting iteration {state['iteration']}")
@@ -386,8 +355,7 @@
parameters=cfg.llm.parameters,
use_cache=False,
collect_logprobs=True,
remove_leading_white_space=remove_leading_white_space,
observe_llm_calls=False
observe_llm_calls=False,
)

test_llm = TrainableLLM(
@@ -396,8 +364,7 @@
tokenizer_name=str(assistant_model_path),
parameters=cfg.test_llm.parameters,
use_cache=False,
remove_leading_white_space=remove_leading_white_space,
observe_llm_calls=False
observe_llm_calls=False,
)

train_agent = CoTMathAgent.create(llm=llm)
@@ -416,9 +383,10 @@
llm_stats = agent.llm.get_stats()
make_data_took = stats[f"execution_time/{split_name}_make_data"]
more_llm_stats = {
"make_data_output_tokens/s": llm_stats["total_output_tokens"] / make_data_took,
"make_data_prompt_tokens/s": llm_stats["total_prompt_tokens"] / make_data_took,
"make_data_tokens/s": (llm_stats["total_output_tokens"] + llm_stats["total_prompt_tokens"]) / make_data_took,
"make_data_output_tokens/s": llm_stats["total_prompt_tokens"] / make_data_took,
"make_data_prompt_tokens/s": llm_stats["total_output_tokens"] / make_data_took,
"make_data_tokens/s": (llm_stats["total_output_tokens"] + llm_stats["total_prompt_tokens"])
/ make_data_took,
}
for k, v in llm_stats.items():
if "/s" in k:
@@ -530,10 +498,10 @@ def main(cfg: DictConfig):

start_finetune = time.time()
launch_training(
str(conf_dir),
str(state["iteration"]),
str(conf_dir),
str(state["iteration"]),
cfg.accelerate_cfg_path,
use_deepspeed=cfg.use_deepspeed # defaults to False
use_deepspeed=cfg.use_deepspeed, # defaults to False
)
time_finetune = time.time() - start_finetune
time_iteration = time.time() - start_iteration
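The main change in examples/rl_gsm8k/orchestrate_rl.py is that reference logprobs are now requested with token ids rather than prompt and output text. Below is a minimal sketch of the slicing used in annotate_traces_with_ref_logprobs, assuming trace.input_ids holds the full prompt-plus-completion sequence and trace.logprobs contains one entry per generated token; the helper name and toy token ids are illustrative, not part of the PR.

from typing import List, Tuple


def split_prompt_and_completion(input_ids: List[int], num_generated: int) -> Tuple[List[int], List[int]]:
    """Split a full token-id sequence into prompt and completion parts.

    Mirrors the slicing in annotate_traces_with_ref_logprobs, where num_generated
    is len(trace.logprobs), i.e. one logprob per generated token.
    """
    prompt_token_ids = input_ids[:-num_generated]
    completion_token_ids = input_ids[-num_generated:]
    return prompt_token_ids, completion_token_ids


# Toy usage: 5 prompt tokens followed by 3 generated tokens.
full_sequence = [101, 2023, 2003, 1037, 3231, 7592, 2088, 102]
prompt_ids, completion_ids = split_prompt_and_completion(full_sequence, num_generated=3)
assert prompt_ids == [101, 2023, 2003, 1037, 3231]
assert completion_ids == [7592, 2088, 102]

Working in token-id space avoids the tokenizer round trip that previously forced the vLLM and HuggingFace tokenizations to be reconciled, which is why the old hf_tokens comparison and the compute_logprobs fallback are removed above.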
2 changes: 2 additions & 0 deletions tapeagents/core.py
@@ -34,6 +34,8 @@ class TrainingText(BaseModel):
reward: float = 0.0
logprobs: List[float] = Field(default_factory=list)
ref_logprobs: List[float] = Field(default_factory=list)
input_ids: List[int] = Field(default_factory=list)
labels: List[int] = Field(default_factory=list)
group_id: str | None = None

@property
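tapeagents/core.py gains input_ids and labels fields on TrainingText so traces can carry their generation-time tokenization. A minimal sketch of how orchestrate_rl.py appears to populate them from vLLM logprob records follows, assuming each record exposes token_id and a generated flag and that generated tokens come last; the LogprobEntry dataclass is an illustrative stand-in, and MASKED_TOKEN_ID is -100, the default ignore_index of torch.nn.CrossEntropyLoss.

from dataclasses import dataclass
from typing import List, Tuple

MASKED_TOKEN_ID = -100  # default ignore_index of torch.nn.CrossEntropyLoss


@dataclass
class LogprobEntry:
    # Illustrative stand-in for one vLLM logprob record.
    token_id: int
    logprob: float
    generated: bool  # False for prompt tokens, True for sampled tokens


def build_inputs_and_labels(records: List[LogprobEntry]) -> Tuple[List[int], List[int]]:
    """Mask prompt positions so the loss is computed only on generated tokens."""
    input_ids = [r.token_id for r in records]
    generated_ids = [r.token_id for r in records if r.generated]
    labels = [MASKED_TOKEN_ID] * (len(input_ids) - len(generated_ids)) + generated_ids
    return input_ids, labels


# Toy usage: 3 prompt tokens followed by 2 generated tokens.
records = [
    LogprobEntry(token_id=11, logprob=0.0, generated=False),
    LogprobEntry(token_id=12, logprob=0.0, generated=False),
    LogprobEntry(token_id=13, logprob=0.0, generated=False),
    LogprobEntry(token_id=21, logprob=-0.5, generated=True),
    LogprobEntry(token_id=22, logprob=-1.2, generated=True),
]
input_ids, labels = build_inputs_and_labels(records)
assert labels == [-100, -100, -100, 21, 22]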
42 changes: 25 additions & 17 deletions tapeagents/finetune/data.py
@@ -114,24 +114,32 @@ def preprocess_fn(
seq_length: int,
is_rl: bool = False,
) -> BatchEncoding:
encoding = tokenizer(
entry["text"],
return_offsets_mapping=True,
max_length=seq_length,
truncation=True,
)
if "predicted_spans" in entry:
predicted_spans = entry["predicted_spans"]
if "input_ids" in entry and entry["input_ids"]:
# build the `encoding` object from the given tokenization
encoding = BatchEncoding()
encoding["input_ids"] = entry["input_ids"]
encoding["labels"] = entry["labels"]
encoding["attention_mask"] = [1] * len(entry["input_ids"])
else:
text_length = len(entry["text"])
predicted_chars = entry.get("n_predicted", text_length)
predicted_spans = [(text_length - predicted_chars, text_length)]
validate_spans(entry["text"], predicted_spans)
encoding["labels"], _ = mask_labels(
encoding["input_ids"], # type: ignore
encoding["offset_mapping"], # type: ignore
predicted_spans,
)
# tokenize text to build the `encoding` object
encoding = tokenizer(
entry["text"],
return_offsets_mapping=True,
max_length=seq_length,
truncation=True,
)
if "predicted_spans" in entry:
predicted_spans = entry["predicted_spans"]
else:
text_length = len(entry["text"])
predicted_chars = entry.get("n_predicted", text_length)
predicted_spans = [(text_length - predicted_chars, text_length)]
validate_spans(entry["text"], predicted_spans)
encoding["labels"], _ = mask_labels(
encoding["input_ids"], # type: ignore
encoding["offset_mapping"], # type: ignore
predicted_spans,
)
if is_rl:
encoding = prepare_rl_fields(
encoding,
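The preprocess_fn change in tapeagents/finetune/data.py skips re-tokenization when a sample already carries input_ids and labels. A minimal sketch of that branch is below, assuming transformers.BatchEncoding and that pre-tokenized entries always have labels of the same length as input_ids; the helper name and toy entry are illustrative.

from transformers import BatchEncoding


def encoding_from_pretokenized(entry: dict) -> BatchEncoding:
    """Build a BatchEncoding directly from generation-time token ids,
    bypassing the tokenizer path used for plain-text samples."""
    encoding = BatchEncoding()
    encoding["input_ids"] = entry["input_ids"]
    encoding["labels"] = entry["labels"]
    # Every provided token is real (no padding here), so attend to all of them.
    encoding["attention_mask"] = [1] * len(entry["input_ids"])
    return encoding


# Toy usage: 2 prompt tokens (masked in labels) followed by 2 generated tokens.
entry = {"input_ids": [5, 6, 7, 8], "labels": [-100, -100, 7, 8]}
enc = encoding_from_pretokenized(entry)
assert enc["attention_mask"] == [1, 1, 1, 1]

Reusing the ids recorded at generation time keeps training labels aligned with what vLLM actually sampled, instead of depending on the text round-tripping through the HuggingFace tokenizer to the same token boundaries.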