Commit: revert changes in examples
aakashapoorv committed May 21, 2024
1 parent e2c027f commit 95d36c2
Showing 2 changed files with 2 additions and 14 deletions.
example_chat_completion.py: 1 addition & 7 deletions
@@ -1,8 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
-import os
-
 from typing import List, Optional
 
 import fire
@@ -30,10 +28,6 @@ def main(
     `max_gen_len` is optional because finetuned models are able to stop generations naturally.
     """
-    assert 1 <= max_seq_len <= 8192, f"max_seq_len must be between 1 and 8192, got {max_seq_len}."
-    assert os.path.isdir(ckpt_dir), f"Checkpoint directory '{ckpt_dir}' does not exist."
-    assert os.path.isfile(tokenizer_path), f"Tokenizer file '{tokenizer_path}' does not exist."
-
     generator = Llama.build(
         ckpt_dir=ckpt_dir,
         tokenizer_path=tokenizer_path,
@@ -87,4 +81,4 @@ def main(
 
 
 if __name__ == "__main__":
-    fire.Fire(main)
+    fire.Fire(main)
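
The lines removed here were pre-flight checks that ran before Llama.build. A minimal sketch of how a caller could keep equivalent validation in a wrapper of their own, outside the example (the helper name validate_paths and the example paths are illustrative, not part of this repository):

import os


def validate_paths(ckpt_dir: str, tokenizer_path: str, max_seq_len: int) -> None:
    # Same checks as the reverted assertions: fail early with a clear message
    # instead of surfacing a harder-to-read error from Llama.build.
    assert 1 <= max_seq_len <= 8192, f"max_seq_len must be between 1 and 8192, got {max_seq_len}."
    assert os.path.isdir(ckpt_dir), f"Checkpoint directory '{ckpt_dir}' does not exist."
    assert os.path.isfile(tokenizer_path), f"Tokenizer file '{tokenizer_path}' does not exist."


# Illustrative usage (paths are placeholders):
# validate_paths("Meta-Llama-3-8B-Instruct/", "Meta-Llama-3-8B-Instruct/tokenizer.model", max_seq_len=512)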
example_text_completion.py: 1 addition & 7 deletions
@@ -1,8 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
-import os
-
 from typing import List
 
 import fire
@@ -26,10 +24,6 @@ def main(
     The context window of llama3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192.
     `max_gen_len` is needed because pre-trained models usually do not stop completions naturally.
     """
-    assert 1 <= max_seq_len <= 8192, f"max_seq_len must be between 1 and 8192, got {max_seq_len}."
-    assert os.path.isdir(ckpt_dir), f"Checkpoint directory '{ckpt_dir}' does not exist."
-    assert os.path.isfile(tokenizer_path), f"Tokenizer file '{tokenizer_path}' does not exist."
-
     generator = Llama.build(
         ckpt_dir=ckpt_dir,
         tokenizer_path=tokenizer_path,
@@ -67,4 +61,4 @@ def main(
 
 
 if __name__ == "__main__":
-    fire.Fire(main)
+    fire.Fire(main)
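
Both examples end by handing main to python-fire, which maps the function's keyword arguments onto command-line flags. A self-contained toy sketch of that pattern (the file name and argument values below are made up for illustration; the real examples are normally launched through a distributed launcher such as torchrun rather than invoked directly):

# toy_cli.py (hypothetical file, not part of the repository)
import fire


def main(ckpt_dir: str, tokenizer_path: str, max_seq_len: int = 512):
    # fire.Fire exposes these parameters as --ckpt_dir, --tokenizer_path, --max_seq_len.
    print(f"ckpt_dir={ckpt_dir} tokenizer_path={tokenizer_path} max_seq_len={max_seq_len}")


if __name__ == "__main__":
    fire.Fire(main)

# Invocation sketch:
#   python toy_cli.py --ckpt_dir /path/to/checkpoints --tokenizer_path /path/to/tokenizer.model --max_seq_len 512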
