
Merge pull request microsoft#49 from microsoft/gderosa/gpt2_export
pr(archai): Adds support for GPT2 export and optimization with ONNX.
gugarosa authored Dec 17, 2021
2 parents ff365d7 + 15d2305 commit c023c03
Showing 10 changed files with 156 additions and 66 deletions.
24 changes: 22 additions & 2 deletions archai/nlp/nvidia_transformer_xl/onnx/export_torch_to_onnx.py
@@ -26,11 +26,21 @@ def parse_args():
choices=['mem_transformer', 'hf_gpt2', 'hf_transfo_xl'],
help='Type of model to be exported.')

parser.add_argument('--opset_version',
type=int,
default=11,
help='Version of ONNX operators.')

parser.add_argument('--opt_level',
type=int,
default=0,
help='Level of the ORT optimization.')

parser.add_argument('--num_heads',
type=int,
default=8,
help='Number of attention heads (for fusion).')

parser.add_argument('--optimization',
action='store_true',
help='Applies optimization to the exported model.')
@@ -52,19 +62,29 @@ def parse_args():
torch_model_path = args.torch_model_path
onnx_model_path = args.onnx_model_path
model_type = args.model_type
opset_version = args.opset_version
opt_level = args.opt_level
num_heads = args.num_heads
optimization = args.optimization
quantization = args.quantization

# Loads the PyTorch model
model, model_config = load_from_pt(model_type, torch_model_path)

# Exports to ONNX
export_onnx_from_pt(model, model_config, model_type, onnx_model_path, share_weights=False)
export_onnx_from_pt(model,
model_config,
model_type,
onnx_model_path,
share_weights=True,
opset_version=opset_version)

# Whether optimization should be applied
if optimization:
ort_model_path = optimize_onnx(model_type, onnx_model_path, opt_level=opt_level)
ort_model_path = optimize_onnx(model_type,
onnx_model_path,
num_heads=num_heads,
opt_level=opt_level)

# Re-points the ONNX path so that quantization runs on the optimized model
onnx_model_path = ort_model_path
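For reference, a sketch of invoking the updated script with the new flags, assuming the two path arguments (read as `args.torch_model_path` and `args.onnx_model_path` above) are also exposed as flags; paths and head count below are hypothetical:

    python archai/nlp/nvidia_transformer_xl/onnx/export_torch_to_onnx.py \
        --torch_model_path checkpoint.pt \
        --onnx_model_path model.onnx \
        --model_type hf_gpt2 \
        --opset_version 11 \
        --opt_level 0 \
        --num_heads 12 \
        --optimization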
70 changes: 42 additions & 28 deletions archai/nlp/nvidia_transformer_xl/onnx/onnx_utils/configs.py
@@ -37,20 +37,57 @@ def mockups(self) -> Dict[str, Any]:
}

@property
def inputs(self) -> None:
def inputs(self) -> OrderedDict:
"""Defines the inputs and their shapes to be used when exporting to ONNX.
"""

raise NotImplementedError
# Shape of past states
# [past_key_values, batch_size, n_head, past_seq_len, d_head]
pasts = [(f'past_{i}', {1: 'batch_size', 3: 'past_seq_len'}) for i in range(self.config['n_layer'])]
return OrderedDict([('input_ids', {0: 'batch_size', 1: 'seq_len'})] + pasts)

@property
def outputs(self) -> None:
def outputs(self) -> OrderedDict:
"""Defines the outputs and their shapes to be used when exporting to ONNX.
"""

raise NotImplementedError
# Shape of present states (past states when outputting)
# [2, batch_size, n_head, total_seq_len, d_head]
# Note total_seq_len is current seq_len + past_seq_len
presents = [(f'present_{i}', {1: 'batch_size', 3: 'total_seq_len'}) for i in range(self.config['n_layer'])]
return OrderedDict([('probs', {0: 'batch_size'})] + presents)


class HfGPT2OnnxConfig(OnnxConfig):
"""Provides an ONNX-export configuration for HfGPT2.
"""

def __init__(self, model_config: Dict[str, Any]) -> None:
"""Initializes the configuration.
Args:
model_config: Model configuration.
"""

super().__init__(model_config)

self.config['past_key_values'] = 2
self.config['model_type'] = 'gpt2'

@property
def mockups(self) -> Dict[str, Any]:
"""Defines the mockups (inputs) to be used when exporting to ONNX.
"""

return {
'input_ids': torch.randint(0, self.config['n_token'], (BATCH_SIZE, SEQ_LEN)),
'past_key_values': tuple([torch.zeros(self.config['past_key_values'], BATCH_SIZE, self.config['n_head'], SEQ_LEN, self.config['d_head']) for _ in range(self.config['n_layer'])])
}


class MemTransformerLMOnnxConfig(OnnxConfig):
@@ -88,26 +125,3 @@ def mockups(self) -> Dict[str, Any]:
'input_ids': torch.randint(0, self.config['n_token'], (BATCH_SIZE, SEQ_LEN)),
'past_key_values': tuple([torch.zeros(self.config['past_key_values'], BATCH_SIZE, self.config['n_head'], SEQ_LEN, self.config['d_head']) for _ in range(self.config['n_layer'])])
}

@property
def inputs(self) -> OrderedDict:
"""Defines the inputs and their shapes to be used when exporting to ONNX.
"""

# Shape of past states
# [past_key_values, batch_size, n_head, past_seq_len, d_head]
pasts = [(f'past_{i}', {0: str(self.config['past_key_values']), 1: 'batch_size', 2: str(self.config['n_head']), 3: 'past_seq_len', 4: str(self.config['d_head'])}) for i in range(self.config['n_layer'])]
return OrderedDict([('input_ids', {0: 'batch_size', 1: 'seq_len'})] + pasts)

@property
def outputs(self) -> OrderedDict:
"""Defines the outputs and their shapes to be used when exporting to ONNX.
"""

# Shape of present states (past states when outputting)
# [past_key_values, batch_size, n_head, total_seq_len, d_head]
# Note total_seq_len is current seq_len + past_seq_len
presents = [(f'present_{i}', {0: str(self.config['past_key_values']), 1: 'batch_size', 2: str(self.config['n_head']), 3: 'total_seq_len', 4: str(self.config['d_head'])}) for i in range(self.config['n_layer'])]
return OrderedDict([('probs', {0: 'batch_size', 1: str(self.config['n_token'])})] + presents)
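The new base-class `inputs` and `outputs` properties reduce to a small dynamic-axes mapping. A minimal sketch of what they produce, assuming a hypothetical two-layer config:

    from collections import OrderedDict

    n_layer = 2  # hypothetical; the real properties read self.config['n_layer']

    # Only the listed axes are dynamic; every other axis is fixed at export time
    pasts = [(f'past_{i}', {1: 'batch_size', 3: 'past_seq_len'}) for i in range(n_layer)]
    inputs = OrderedDict([('input_ids', {0: 'batch_size', 1: 'seq_len'})] + pasts)

    presents = [(f'present_{i}', {1: 'batch_size', 3: 'total_seq_len'}) for i in range(n_layer)]
    outputs = OrderedDict([('probs', {0: 'batch_size'})] + presents)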
25 changes: 19 additions & 6 deletions archai/nlp/nvidia_transformer_xl/onnx/onnx_utils/export.py
Original file line number Diff line number Diff line change
@@ -10,15 +10,16 @@
from onnx import helper, load_model, numpy_helper, save

from archai.nlp.nvidia_transformer_xl.onnx.onnx_utils.operators import tril_onnx, triu_onnx
from archai.nlp.nvidia_transformer_xl.onnx.onnx_utils.configs import MemTransformerLMOnnxConfig
from archai.nlp.nvidia_transformer_xl.onnx.onnx_utils.configs import HfGPT2OnnxConfig, MemTransformerLMOnnxConfig

# List of available ONNX configurations
AVAILABLE_ONNX_CONFIGS = {
'hf_gpt2': HfGPT2OnnxConfig,
'mem_transformer': MemTransformerLMOnnxConfig
}


def weight_sharing(onnx_model_path: str) -> None:
def weight_sharing(onnx_model_path: str, model_type: str) -> None:
"""Shares weights between embedding and softmax layers.
Args:
@@ -40,12 +41,24 @@ def _find_weights_by_shape(weights, shape):
# Gathers weights and nodes from the loaded model
weights = {w.name:w for w in model.graph.initializer}
nodes = {n.name:n for n in model.graph.node}
n_emb_weight = len(list(filter(lambda x: 'word_emb.emb_layers' in x, weights.keys())))
n_cutoffs = n_emb_weight - 1

if model_type == 'hf_gpt2':
n_emb_weight = 1
n_cutoffs = 0
elif model_type == 'mem_transformer':
n_emb_weight = len(list(filter(lambda x: 'word_emb.emb_layers' in x, weights.keys())))
n_cutoffs = n_emb_weight - 1
else:
raise ValueError(f'Model {model_type} not supported for weight sharing.')


for i in range(n_emb_weight):
# Grabs the embedding weights pointer and removes from the graph
emb_weight_name = f'word_emb.emb_layers.{i}.weight'

if model_type == 'hf_gpt2':
emb_weight_name = 'transformer.wte.weight'

emb_weight = numpy_helper.to_array(weights[emb_weight_name])
model.graph.initializer.remove(weights[emb_weight_name])

@@ -81,7 +94,7 @@ def export_onnx_from_pt(model: torch.nn.Module,
onnx_model_path: str,
share_weights: Optional[bool] = True,
do_constant_folding: Optional[bool] = True,
opset_version: Optional[int] = 12) -> None:
opset_version: Optional[int] = 11) -> None:
"""Exports a PyTorch-based model to ONNX.
Args:
@@ -125,4 +138,4 @@ def export_onnx_from_pt(model: torch.nn.Module,

# Applies weight sharing
if share_weights:
weight_sharing(onnx_model_path)
weight_sharing(onnx_model_path, model_type)
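The export entry point can also be driven from Python. A minimal sketch using the functions touched by this commit (the checkpoint path is hypothetical):

    from archai.nlp.nvidia_transformer_xl.onnx.onnx_utils.export import export_onnx_from_pt
    from archai.nlp.nvidia_transformer_xl.onnx.onnx_utils.load import load_from_pt

    # Loads the PyTorch model and its configuration dictionary
    model, model_config = load_from_pt('hf_gpt2', 'checkpoint.pt')

    # Exports to ONNX; share_weights=True ties embedding and softmax weights afterwards
    export_onnx_from_pt(model, model_config, 'hf_gpt2', 'model.onnx',
                        share_weights=True, opset_version=11)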
40 changes: 30 additions & 10 deletions archai/nlp/nvidia_transformer_xl/onnx/onnx_utils/forward.py
@@ -7,18 +7,42 @@
import torch.nn.functional as F


def forward_with_probs(self,
input_ids: torch.LongTensor,
past_key_values: Optional[Tuple[torch.FloatTensor, ...]] = None
) -> torch.FloatTensor:
def forward_gpt2_onnx(self,
input_ids: torch.LongTensor,
past_key_values: Optional[Tuple[torch.FloatTensor, ...]] = None
) -> Tuple[torch.FloatTensor, ...]:
"""Overrides the HfGPT2 forward by returning probabilities and past key/values.
Args:
input_ids: Input tensor.
past_key_values: Past pre-computed key/values tensor.
Returns:
(Tuple[torch.FloatTensor, ...]): Output probabilities and past key/values.
"""

outputs = self.transformer(input_ids, past_key_values=past_key_values)

hidden_states = outputs[0]
preds = F.softmax(self.lm_head(hidden_states[:,-1,:]), dim=-1)
past_key_values = tuple([torch.stack(p) for p in outputs.past_key_values])

return preds, past_key_values


def forward_memformer_onnx(self,
input_ids: torch.LongTensor,
past_key_values: Optional[Tuple[torch.FloatTensor, ...]] = None
) -> Tuple[torch.FloatTensor, ...]:
"""Overrides the MemTransformerLM forward by returning probabilities.
Args:
input_ids: Input tensor.
past_key_values: Past pre-computed key/values tensor.
Returns:
(torch.FloatTensor): Output probabilities.
(Tuple[torch.FloatTensor, ...]): Output probabilities and past key/values.
"""

@@ -41,10 +65,6 @@ def forward_with_probs(self,
# Calculates the output predictions/probabilities
preds = self.crit(hidden_preds)

# As we are using batch_size x seq_len for input and outputs
# We need to change the view of the predictions if they exist
preds = preds.view(input_ids.size(1), -1) if preds is not None else None

# Reshapes past_key_values back to standard shape
past_key_values = tuple([p.permute([0, 2, 3, 1, 4]) for p in past_key_values])

@@ -79,7 +99,7 @@ def _compute_logit(hidden: torch.FloatTensor,
return logit


def crit_forward_with_probs(self, hidden: torch.FloatTensor) -> torch.FloatTensor:
def crit_forward_memformer_onnx(self, hidden: torch.FloatTensor) -> torch.FloatTensor:
"""Overrides the Projective Adaptive Softmax forward by returning probabilities.
Args:
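The `torch.stack` call in `forward_gpt2_onnx` fuses each layer's (key, value) pair into a single tensor, matching the `[2, batch_size, n_head, seq_len, d_head]` layout used by the `past_*`/`present_*` graph bindings. A standalone sketch with hypothetical sizes:

    import torch

    # Hugging Face GPT-2 returns one (key, value) pair per layer, each of shape
    # [batch_size, n_head, seq_len, d_head]
    key = torch.zeros(1, 12, 4, 64)
    value = torch.zeros(1, 12, 4, 64)

    # Stacking the pair yields a single [2, batch_size, n_head, seq_len, d_head] tensor
    present = torch.stack((key, value))
    assert present.shape == (2, 1, 12, 4, 64)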
19 changes: 14 additions & 5 deletions archai/nlp/nvidia_transformer_xl/onnx/onnx_utils/load.py
@@ -8,11 +8,11 @@

from onnxruntime import (GraphOptimizationLevel, InferenceSession,
SessionOptions)

from archai.nlp.nvidia_transformer_xl.models.archai_model import ArchaiModel
from archai.nlp.nvidia_transformer_xl.models.available_models import AVAILABLE_MODELS

from archai.nlp.nvidia_transformer_xl.onnx.onnx_utils.forward import (crit_forward_with_probs,
forward_with_probs)
from archai.nlp.nvidia_transformer_xl.onnx.onnx_utils.forward import (crit_forward_memformer_onnx, forward_gpt2_onnx,
forward_memformer_onnx)

# Constants available in onnxruntime
# that enables performance optimization
@@ -79,8 +79,17 @@ def load_from_pt(model_type: str, torch_model_path: str) -> Tuple[ArchaiModel, d

# Overrides forward functions if MemTransformerLM
if model_type == 'mem_transformer':
model.forward = types.MethodType(forward_with_probs, model)
model.crit.forward = types.MethodType(crit_forward_with_probs, model.crit)
model.forward = types.MethodType(forward_memformer_onnx, model)
model.crit.forward = types.MethodType(crit_forward_memformer_onnx, model.crit)

if model_type == 'hf_gpt2':
model = model.model
model.forward = types.MethodType(forward_gpt2_onnx, model)

if type(model_config['d_head']) is list:
model_config['d_head'] = model_config['d_head'][0]
if type(model_config['n_head']) is list:
model_config['n_head'] = model_config['n_head'][0]

# Puts the model in evaluation mode to disable dropout
model.eval()
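The imports at the top of load.py point at the standard onnxruntime session API; the exact setup inside `load_from_onnx` is not shown in this diff, but a minimal sketch of opening an exported model (path and optimization level are hypothetical) looks like:

    from onnxruntime import (GraphOptimizationLevel, InferenceSession,
                             SessionOptions)

    options = SessionOptions()
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
    session = InferenceSession('model.onnx', options)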
@@ -44,6 +44,7 @@ def __init__(self, model_type: str) -> None:
self.attention_mask_format = AttentionMaskFormat.AttentionMask

if model_type == 'hf_gpt2':
self.enable_embed_layer_norm = False
self.enable_skip_layer_norm = False

def use_raw_attention_mask(self,
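The two flags disabled for `hf_gpt2` also exist on onnxruntime's upstream FusionOptions, which this class appears to extend; a sketch against the upstream class (module path assumed, since the file name is not captured in this view):

    from onnxruntime.transformers.fusion_options import FusionOptions

    # 'gpt2' is onnxruntime's model-type name for GPT-2 optimizations
    options = FusionOptions('gpt2')
    options.enable_embed_layer_norm = False
    options.enable_skip_layer_norm = False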
@@ -306,7 +306,7 @@ def optimize(self,
self.fuse_shape()

# Removes useless Reshape nodes that linger in the graph
self.utils.remove_useless_reshape_nodes(self)
self.utils.remove_useless_reshape_nodes()

# Post-processing step
self.clean_graph()
12 changes: 11 additions & 1 deletion archai/nlp/nvidia_transformer_xl/onnx/onnx_utils/optimization.py
@@ -22,6 +22,7 @@

def optimize_onnx(model_type: str,
onnx_model_path: str,
num_heads: Optional[int] = 8,
use_gpu: Optional[bool] = False,
opt_level: Optional[int] = 0,
only_ort: Optional[bool] = False,
@@ -32,6 +33,7 @@ def optimize_onnx(model_type: str,
Args:
model_type: Type of model to be optimized.
onnx_model_path: Path to the ONNX model to be optimized.
num_heads: Number of attention heads.
use_gpu: Whether to use GPU during optimization.
opt_level: Level of optimization.
only_ort: Whether to only apply ORT optimization.
@@ -70,7 +72,15 @@
# Loads the ORT-optimized model, optimizer and fusion options
ort_model = load_model(ort_model_path or onnx_model_path)
ort_model_path = create_file_name_identifier(Path(onnx_model_path), '_opt')
optimizer = AVAILABLE_ONNX_OPTS[model_type](ort_model)

# Builds the positional arguments for the optimizer
optimizer_args = (ort_model, )
if model_type == 'hf_gpt2':
# Passes `hidden_size` as zero for backward compatibility
optimizer_args += (num_heads, 0)

optimizer = AVAILABLE_ONNX_OPTS[model_type](*optimizer_args)
options = FusionOptions(model_type)

# Optimizes the model
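A minimal sketch of calling the updated optimizer from Python (the model path is hypothetical, and `num_heads` must match the exported checkpoint, e.g. 12 for GPT-2 small):

    from archai.nlp.nvidia_transformer_xl.onnx.onnx_utils.optimization import optimize_onnx

    ort_model_path = optimize_onnx('hf_gpt2', 'model.onnx',
                                   num_heads=12, opt_level=0)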
11 changes: 5 additions & 6 deletions archai/nlp/nvidia_transformer_xl/onnx/validate_onnx_export.py
@@ -56,12 +56,11 @@ def parse_args():
model_onnx = load_from_onnx(onnx_model_path)

# Checks the type of attention to define the `past_key_values`
# `k`, `v` and relative embeddings
if model_config['attn_type'] == 0:
n_past_values = 3
else:
# `k` and `v`
n_past_values = 2
n_past_values = 2
if model_type == 'mem_transformer':
if model_config['attn_type'] == 0:
# `k`, `v` and relative embeddings
n_past_values = 3

# Defines PyTorch inputs
torch.manual_seed(0)
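Following the mockup convention in configs.py, the validation inputs can then be built as in this sketch, with `n_past_values` chosen as above (all sizes are hypothetical):

    import torch

    # n_past_values is 2 for key/value pairs, or 3 when mem_transformer's
    # attn_type == 0 also carries relative embeddings
    n_past_values, n_layer = 2, 16
    batch_size, n_head, seq_len, d_head = 1, 8, 32, 64

    past_key_values = tuple(torch.zeros(n_past_values, batch_size, n_head,
                                        seq_len, d_head)
                            for _ in range(n_layer))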