import torch
from torch.utils.data import DataLoader

from src.iemodata import IEMOData
from src.mosidata import MOSIData
from src.simsdata import SIMSData


# Feature/config namespace passed to IEMOData; the values follow the IEMOCAP
# preprocessing conventions (comparE audio, denseface visual, bert_large text).
class opt:
    A_type = "comparE"
    V_type = "denseface"
    L_type = "bert_large"
    norm_method = "trn"
    corpus_name = "IEMOCAP"
    in_mem = False
    cvNo = 1


def get_data(args, split="train", full_data=False):
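    """Build the dataset object for the requested corpus and split."""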
    if args.dataset == "iemocap":
        # IEMOData uses its own split names: trn / val / tst.
        set_names = {"train": "trn", "valid": "val", "test": "tst"}
        if split not in set_names:
            raise ValueError("Unknown split: {}".format(split))
        data = IEMOData(
            opt,
            args.data_path,
            set_name=set_names[split],
            drop_rate=args.drop_rate,
            full_data=full_data,
        )
    elif args.dataset in ("mosi", "mosei"):
        data = MOSIData(
            args.data_path, split, drop_rate=args.drop_rate, full_data=full_data
        )
    elif args.dataset == "sims":
        data = SIMSData(
            args.data_path, split, drop_rate=args.drop_rate, full_data=full_data
        )
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    return data


def get_loader(args):
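    """Wrap each split in a DataLoader.

    Returns the loaders plus the feature dimensions, the number of samples
    per split, and the sequence length.
    """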
    dataloaders = {}
    n_nums = []
    if args.dataset == "iemocap":
        for split in ["train", "valid", "test"]:
            dataset = get_data(args, split)
            # IEMOCAP batches are variable-length, so the dataset's own
            # collate_fn is required.
            dataloaders[split] = DataLoader(
                dataset,
                batch_size=args.batch_size,
                drop_last=False,
                collate_fn=dataset.collate_fn,
            )
            # get_dim()/get_seq_len() are assumed identical across splits;
            # the last split iterated determines the returned values.
            orig_dims = dataset.get_dim()
            n_nums.append(len(dataset))
            seq_len = dataset.get_seq_len()
    else:
        for split in ["train", "valid", "test"]:
            dataset = get_data(args, split)
            dataloaders[split] = DataLoader(dataset, batch_size=args.batch_size)
            orig_dims = dataset.get_dim()
            n_nums.append(len(dataset))
            seq_len = dataset.get_seq_len()
    return dataloaders, orig_dims, n_nums, seq_len


def transfer_model(new_model, pretrained):
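    """Copy matching weights from a pretrained checkpoint into new_model and freeze them.

    The modality projections and the output layer are excluded so they can be
    re-initialized and trained for the new task.
    """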
    excluded = [
        "proj_l.weight",
        "proj_a.weight",
        "proj_v.weight",
        "out_layer.weight",
        "out_layer.bias",
    ]
    model = torch.load(pretrained)
    pretrain_dict = model.state_dict()
    new_dict = new_model.state_dict()
    state_dict = {}
    for k, v in pretrain_dict.items():
        if k in new_dict and k not in excluded:
            state_dict[k] = v
        else:
            print("Skipping pretrained key: {}".format(k))
    new_dict.update(state_dict)
    new_model.load_state_dict(new_dict)
    # Freeze every parameter that was transferred from the pretrained model.
    for name, param in new_model.named_parameters():
        if name in pretrain_dict and name not in excluded:
            param.requires_grad = False
    return new_model
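

# A minimal usage sketch (hypothetical argument values; the real entry point
# builds `args` with argparse and model-specific options):
#
#     from argparse import Namespace
#
#     args = Namespace(
#         dataset="iemocap",
#         data_path="/path/to/features",
#         drop_rate=0.0,
#         batch_size=32,
#     )
#     dataloaders, orig_dims, n_nums, seq_len = get_loader(args)
#     for batch in dataloaders["train"]:
#         ...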