Merge pull request THUDM#69 from THUDM/develop
Merge develop branch
Stanislas0 authored Feb 20, 2023
2 parents ba876cb + 84365c7 commit d91ed51
Showing 18 changed files with 1,179 additions and 475 deletions.
70 changes: 70 additions & 0 deletions codegeex/__init__.py
@@ -0,0 +1,70 @@
import copy

from typing import *
from codegeex.megatron.model import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.torch.inference import get_token_stream


def get_model(
    backend: str = "megatron",
    quantized: bool = False,
):
    pass


def generate(
    model: CodeGeeXModel,
    tokenizer: CodeGeeXTokenizer,
    prompt: str,
    out_seq_length: int,
    seq_length: int = 2048,
    top_k: int = 0,
    top_p: float = 1.0,
    temperature: float = 1.0,
    micro_batch_size: int = 1,
    backend: str = "megatron",
    greedy: bool = False,
    verbose: bool = False,
):
    tokens = tokenizer.encode_code(prompt)
    n_token_prompt = len(tokens)

    if verbose:
        print(f"Current prompt:\n{prompt}")
        print("N_token_prompt:", n_token_prompt)

    generated_codes = []
    if backend == "megatron":
        token_stream = get_token_stream(
            model,
            tokenizer,
            seq_length,
            out_seq_length,
            [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
            micro_batch_size=micro_batch_size,
            topk=top_k,
            topp=top_p,
            temperature=temperature,
            greedy=greedy,
        )
        is_finished = [False for _ in range(micro_batch_size)]
        for i, generated in enumerate(token_stream):
            generated_tokens = generated[0]
            for j in range(micro_batch_size):
                if is_finished[j]:
                    continue

                if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(generated_tokens[j]) >= out_seq_length:
                    is_finished[j] = True
                    generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
                    generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
                    generated_code = "".join(generated_code)
                    generated_codes.append(generated_code)
                    if verbose:
                        print(f"\nGenerated code {i}:\n{generated_code}")

            if all(is_finished):
                break

    return generated_codes
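
Note: a minimal usage sketch of the new generate helper, for orientation only (not part of this commit; model and tokenizer construction are elided):

# Sketch only: assumes `model` is an already-loaded CodeGeeXModel and
# `tokenizer` an already-constructed CodeGeeXTokenizer (both elided here).
codes = generate(
    model,
    tokenizer,
    prompt="# language: Python\n# bubble sort\ndef bubble_sort(lst):",
    out_seq_length=256,
    seq_length=2048,
    top_p=0.95,
    temperature=0.8,
    micro_batch_size=1,
    verbose=True,
)
print(codes[0])  # decoded completion; generate strips the prompt tokens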
2 changes: 1 addition & 1 deletion codegeex/benchmark/humaneval-x/generate_humaneval_x.py
@@ -189,7 +189,7 @@ def add_code_generation_args(parser):
        nargs="*",
        type=int,
        default=None,
        help='Identify the type of programming language to generate',
        help='Specify bad ids that will not be used',
    )
    group.add_argument(
        "--quantize",
34 changes: 31 additions & 3 deletions codegeex/data/data_utils.py
@@ -1,23 +1,32 @@
import os
import gzip
import json

from typing import *


LANGUAGE_TAG = {
    "c"            : "// language: C",
    "c++"          : "// language: C++",
    "cpp"          : "// language: C++",
    "c"            : "// language: C",
    "c#"           : "// language: C#",
    "csharp"       : "// language: C#",
    "css"          : "/* language: CSS */",
    "cuda"         : "// language: Cuda",
    "dart"         : "// language: Dart",
    "lua"          : "// language: Lua",
    "objectivec"   : "// language: Objective-C",
    "objective-c"  : "// language: Objective-C",
    "objective-c++": "// language: Objective-C++",
    "python"       : "# language: Python",
    "perl"         : "# language: Perl",
    "prolog"       : "% language: Prolog",
    "swift"        : "// language: swift",
    "lisp"         : "; language: Lisp",
    "java"         : "// language: Java",
    "scala"        : "// language: Scala",
    "tex"          : "% language: TeX",
    "vue"          : "<!--language: Vue-->",
    "markdown"     : "<!--language: Markdown-->",
    "html"         : "<!--language: HTML-->",
    "php"          : "// language: PHP",
    "js"           : "// language: JavaScript",
@@ -26,13 +35,32 @@
"go" : "// language: Go",
"shell" : "# language: Shell",
"rust" : "// language: Rust",
"css" : "/* language: CSS */",
"sql" : "-- language: SQL",
"kotlin" : "// language: Kotlin",
"vb" : "' language: Visual Basic",
"ruby" : "# language: Ruby",
"pascal" : "// language: Pascal",
"r" : "# language: R",
"fortran" : "!language: Fortran",
"lean" : "-- language: Lean",
"matlab" : f"% language: Matlab",
"delphi" : "{language: Delphi}",
"scheme" : "; language: Scheme",
"basic" : "' language: Basic",
"assembly" : "; language: Assembly",
"groovy" : "// language: Groovy",
"abap" : "* language: Abap",
"gdscript" : "# language: GDScript",
"haskell" : "-- language: Haskell",
"julia" : "# language: Julia",
"elixir" : "# language: Elixir",
"excel" : "' language: Excel",
"clojure" : "; language: Clojure",
"actionscript" : "// language: ActionScript",
"solidity" : "// language: Solidity",
"powershell" : "# language: PowerShell",
"erlang" : f"% language: Erlang",
"cobol" : "// language: Cobol",
}


2 changes: 1 addition & 1 deletion codegeex/data/process_pretrain_dataset.py
@@ -58,7 +58,7 @@ def process_sample(

    try:
        if language is not None and language in LANGUAGE_TAG.keys():
            code = LANGUAGE_TAG[language] + sample["code"]
            code = LANGUAGE_TAG[language] + "\n" + sample["code"]
        else:
            code = sample["code"]
    except Exception as e:
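Note: a standalone sketch of what this one-line change does (the sample dict is hypothetical). The language tag is now joined to the code with a newline rather than glued onto its first line:

from codegeex.data.data_utils import LANGUAGE_TAG

sample = {"code": "def add(a, b):\n    return a + b\n"}  # hypothetical sample
code = LANGUAGE_TAG["python"] + "\n" + sample["code"]
print(code)
# # language: Python
# def add(a, b):
#     return a + b
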
6 changes: 6 additions & 0 deletions codegeex/data/processor.py
@@ -67,6 +67,9 @@ def process_sample_strict(self, sample: PromptSample) -> List[Dict[str, List[int
"""
Instead of processing lazily, we turn the iterable into a list.
"""
if sample is None:
return None

return list(self.process_sample(sample))

def process_sample_(self, sample) -> List[Dict[str, List[int]]]:
@@ -141,6 +144,9 @@ def process_sample_strict(self, sample: LabelSample) -> List[Dict[str, List[int]
"""
Instead of processing lazily, we turn the iterable into a list.
"""
if sample is None:
return None

return list(self.process_sample(sample))

def process_sample_(self, sample) -> List[Dict[str, List[int]]]:
26 changes: 26 additions & 0 deletions codegeex/megatron/arguments.py
@@ -415,6 +415,10 @@ def _add_network_size_args(parser):
help="Disable BERT binary head.",
dest="bert_binary_head",
)
group.add_argument(
"--compress",
action="store_true",
)

return parser

@@ -560,6 +564,24 @@ def _add_regularization_args(parser):
    group.add_argument(
        "--sgd-momentum", type=float, default=0.9, help="Momentum factor for sgd"
    )
    group.add_argument(
        "--shrink-logit-embedding-gradient",
        action="store_true",
    )
    group.add_argument(
        "--shrink-embedding-gradient-alpha",
        type=float,
        default=1.0,
        help='Shrink the embedding gradient by this alpha factor',
    )
    group.add_argument(
        "--shrink-embedding-gradient-steps",
        nargs='*',
        default=None,
        help='--shrink-embedding-gradient-steps <x1> <x2> '
             'Shrink embedding gradient alpha for x1 steps, '
             'then warm it up to 1.0 with x2 steps',
    )

    return parser
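
Note: a hedged sketch of the schedule the --shrink-embedding-gradient-steps help text describes; the function name and the linear warmup shape are assumptions, not code from this commit:

def shrink_alpha_at_step(step: int, alpha: float, x1: int, x2: int) -> float:
    # Assumed behavior per the help text: hold the shrink factor at `alpha`
    # for x1 steps, then warm it linearly up to 1.0 over the next x2 steps.
    if step < x1:
        return alpha
    if step < x1 + x2:
        return alpha + (1.0 - alpha) * (step - x1) / x2
    return 1.0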

@@ -751,6 +773,10 @@ def _add_initialization_args(parser):
def _add_inference_args(parser):
    group = parser.add_argument_group(title="initialization")

    group.add_argument(
        '--evaluation',
        action="store_true",
    )
    group.add_argument(
        '--beam-warmup',
        action="store_true",
4 changes: 2 additions & 2 deletions codegeex/megatron/checkpointing.py
@@ -84,7 +84,7 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False):
    if release:
        directory = ""
    else:
        directory = "iter_{:07d}".format(iteration)
        directory = f"global_step{iteration}"
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(
@@ -174,7 +174,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    # Saving is a collective communication
    checkpoint_name = get_checkpoint_name(args.save, iteration)
    # Trim off the filename and mp_rank_* directory.
    for _ in range(3):
    for _ in range(2):
        checkpoint_name = os.path.dirname(checkpoint_name)
    model[0].save_checkpoint(checkpoint_name, client_state=state_dict)

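Note: an illustration of the trimming loop above with a hypothetical path (not from this commit); each os.path.dirname call strips one trailing component, so the loop now removes the last two:

import os

name = "checkpoints/global_step200/mp_rank_00/model_optim_rng.pt"  # hypothetical
for _ in range(2):
    name = os.path.dirname(name)
print(name)  # checkpoints/global_step200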