-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathtext_to_onnx.py
26 lines (17 loc) · 876 Bytes
/
text_to_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from encoder.export_text_encoder import TextEncoder
# Export the TextEncoder of CLIP to an ONNX model.
if __name__ == '__main__':
    import clip

    device = "cpu"
    # Load the full pretrained CLIP model; its state_dict supplies the
    # text-tower weights for the standalone TextEncoder built below.
    model, preprocess = clip.load("ViT-B/32", device=device)
    model.eval()

    # Hyperparameters match the text transformer of CLIP ViT-B/32.
    text_encoder = TextEncoder(embed_dim=512, context_length=77, vocab_size=49408,
                               transformer_width=512, transformer_heads=8, transformer_layers=12)
    # strict=False: the full CLIP state_dict also contains visual-tower keys
    # that the text encoder does not have, so unexpected keys are expected.
    # Missing keys, however, indicate a genuinely incomplete load — surface them.
    missing_keys, unexpected_keys = text_encoder.load_state_dict(model.state_dict(), strict=False)
    if missing_keys:
        print(f"Warning: missing keys when loading TextEncoder weights: {missing_keys}")
    text_encoder.eval()

    # A sample tokenized prompt serves as the example input for the export trace.
    input_tensor = clip.tokenize("a diagram").to(device)
    onnx_filename = 'clip-text-encoder.onnx'
    # Name the graph inputs/outputs so downstream ONNX tooling gets stable names.
    torch.onnx.export(text_encoder, input_tensor, onnx_filename,
                      input_names=['text_tokens'], output_names=['text_features'])
    # Optionally simplify the exported graph afterwards with:
    # python -m onnxsim clip-text-encoder.onnx clip-text-encoder-optimized.onnx