Skip to content

Commit

Permalink
Refactor test inference
Browse files (browse the repository at this point in the history)
  • Loading branch information
Stanislas0 committed Nov 29, 2022
1 parent 38f1623 commit 3281f0b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
2 changes: 2 additions & 0 deletions codegeex/torch/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def get_token_stream(
topp: float = 1.0,
topk: int = 0.0,
greedy: bool = False,
recompute: bool = False,
):
context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_token_id, seq_length)

Expand Down Expand Up @@ -197,6 +198,7 @@ def get_token_stream(
topp=topp,
topk=topk,
greedy=greedy,
recompute=recompute,
)

for tokens, lengths in batch_token_iterator:
Expand Down
9 changes: 8 additions & 1 deletion tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from codegeex.megatron.initialize import initialize_megatron
from codegeex.megatron.model import CodeGeeXModel
from codegeex.megatron.code_generation_utils import get_token_stream
from codegeex.quantization import quantize

torch.set_printoptions(precision=8)

Expand Down Expand Up @@ -80,7 +81,7 @@ def add_code_generation_args(parser):
group.add_argument(
"--ws-encoding-length",
type=int,
default=80,
default=10,
help="Length of whitespace encoding",
)
group.add_argument(
Expand Down Expand Up @@ -123,6 +124,10 @@ def add_code_generation_args(parser):
default=None,
help='Identify the type of programming language to generate',
)
group.add_argument(
"--quantize",
action="store_true",
)

return parser

Expand Down Expand Up @@ -151,6 +156,8 @@ def main():
model.eval()
if args.fp16 and args.ln_fp16:
model.half()
if args.quantize:
model = quantize(model, weight_bit_width=8, backend="megatron")
model.cuda()

with open(args.prompt_file, "r") as f:
Expand Down

0 comments on commit 3281f0b

Please sign in to comment.