#!.venv/bin/python3
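# Export the abstractive summarization model as TorchScript modules:
# download lidiya/bart-large-xsum-samsum from the Hugging Face hub, trace
# the encoder and two decoder variants (with and without a key/value cache),
# and save the tokenizer and traced graphs under data/summarizer/abstractive.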
import os
import sys
import urllib.request

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
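
# Skip the export if a previous run already produced the artifacts.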
if os.path.exists("data/summarizer/abstractive"):
    print("data/summarizer/abstractive already exists. Exiting...")
    sys.exit()
os.makedirs("data/summarizer/abstractive")
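
# torchscript=True makes from_pretrained return a model that outputs plain
# tuples instead of ModelOutput objects, which torch.jit.trace requires.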
MODEL = "lidiya/bart-large-xsum-samsum"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL, torchscript=True)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL)
tokenizer.save_pretrained("data/summarizer/abstractive")
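
# Inference only: freeze all parameters before tracing.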
for p in model.parameters():
    p.requires_grad_(False)
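
# Thin wrappers exposing the encoder, the decoder, and the LM head with
# tensor-only signatures so each piece can be traced separately. "wp" is
# short for "with past": the per-step decoder variant that consumes the
# cached key/value states.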
def run_decoder_wp(input_ids, encoder_hidden_states, past_key_values):
    (hidden_states, past) = model.model.decoder(
        input_ids=input_ids,
        encoder_hidden_states=encoder_hidden_states,
        past_key_values=past_key_values,
    )
    return (model.lm_head(hidden_states), past)
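
# Decoder pass without a cache, used for the first step over the full input.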
def run_decoder(input_ids, encoder_hidden_states):
    (hidden_states, past) = model.model.decoder(
        input_ids=input_ids, encoder_hidden_states=encoder_hidden_states
    )
    return (model.lm_head(hidden_states), past)
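
# Encoder forward pass; index [0] picks last_hidden_state out of the tuple.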
def run_encoder(input_ids):
    return model.model.encoder(input_ids=input_ids)[0]
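
# Dummy inputs for tracing. For this bart-large checkpoint, d_model is 1024
# and the decoder has 16 attention heads, so each head sees 64 dimensions;
# the cache holds one (key, value) pair per decoder layer. Token id 19 is
# an arbitrary placeholder.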
embed_size_per_head = model.config.d_model // model.config.decoder_attention_heads
input_ids = torch.tensor([[19] * 1], dtype=torch.long)
keys = torch.ones(
    model.config.decoder_layers,
    1,
    model.config.decoder_attention_heads,
    1,
    embed_size_per_head,
)
past_key_values = tuple((key, key) for key in keys)
encoder_hidden_states = torch.rand(1, 1024, 1024, dtype=torch.float32)

# Trace the single-step decoder (with key/value cache) on the dummy inputs.
traced_decoder_wp = torch.jit.trace(
    run_decoder_wp, (input_ids, encoder_hidden_states, past_key_values)
)
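# Re-trace with a full-length (1024-token) sequence for the cache-free
# decoder and the encoder.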
input_ids = torch.tensor([[19] * 1024], dtype=torch.long)
traced_decoder = torch.jit.trace(run_decoder, (input_ids, encoder_hidden_states))
traced_encoder = torch.jit.trace(run_encoder, input_ids)
traced_decoder_wp.save("data/summarizer/abstractive/traced_decoder_wp.pt")
traced_decoder.save("data/summarizer/abstractive/traced_decoder.pt")
traced_encoder.save("data/summarizer/abstractive/traced_encoder.pt")

# Also fetch the truncated word2vec embeddings used by the summarizer.
urllib.request.urlretrieve(
    "http://s3.trystract.com/public/truncated_word2vec.bin.gz",
    "data/summarizer/word2vec.bin.gz",
)
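
# A minimal smoke test (a sketch, not part of the original export): reload
# the traced encoder and the cache-free decoder and run one forward pass.
# The all-19 input ids are arbitrary placeholders mirroring the dummy
# inputs used for tracing above.
check_encoder = torch.jit.load("data/summarizer/abstractive/traced_encoder.pt")
check_decoder = torch.jit.load("data/summarizer/abstractive/traced_decoder.pt")
check_ids = torch.tensor([[19] * 1024], dtype=torch.long)
hidden = check_encoder(check_ids)
logits, _past = check_decoder(check_ids, hidden)
# Expect one logit vector per input position over the full vocabulary.
assert logits.shape == (1, 1024, model.config.vocab_size)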