Skip to content

Commit

Permalink
import deepspeed only if use_deepspeed is True
Browse files Browse the repository at this point in the history
  • Loading branch information
manmay-nakhashi committed Jul 10, 2023
1 parent 31a2e15 commit 82724cc
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tortoise/models/autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import deepspeed
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
Expand Down Expand Up @@ -362,6 +361,7 @@ def post_init_gpt2_config(self, use_deepspeed=False):
self.mel_head
)
if use_deepspeed:
import deepspeed
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
mp_size=1,
replace_with_kernel_inject=True,
Expand Down

0 comments on commit 82724cc

Please sign in to comment.