Commit

add

Huan He committed Feb 6, 2023
1 parent 9a9ddc4 commit 7e619c8
Showing 7 changed files with 226 additions and 341 deletions.
116 changes: 21 additions & 95 deletions algorithms/RAINCOAT.py
@@ -121,23 +121,23 @@ def forward(self, x):
        ef = self.nn2(self.con1(ef).squeeze())
        et = self.cnn(x)
        f = torch.concat([ef,et],-1)
-        return F.normalize(f)
+        return F.normalize(f), out_ft

class tf_decoder(nn.Module):
    def __init__(self, configs):
        super(tf_decoder, self).__init__()
-        self.nn = nn.LayerNorm([3, 128],eps=1e-04)
-        self.nn2 = nn.LayerNorm([3, 128],eps=1e-04)
+        self.input_channels, self.sequence_len = configs.input_channels, configs.sequence_len
+        self.nn = nn.LayerNorm([self.input_channels, self.sequence_len],eps=1e-04)
-        self.fc1 = nn.Linear(64, 3*128)
-        self.convT = torch.nn.ConvTranspose1d(64, 128, 3, stride=1)
+        self.convT = torch.nn.ConvTranspose1d(configs.final_out_channels, self.sequence_len, self.input_channels, stride=1)
        self.modes = configs.fourier_modes
        self.conv_block1 = nn.Sequential(
            nn.ConvTranspose1d(configs.final_out_channels, configs.mid_channels, kernel_size=3,
                               stride=1),
            nn.BatchNorm1d(configs.mid_channels),
            nn.ReLU(),
            # nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
-            # nn.Dropout(configs.dropout)
+            nn.Dropout(configs.dropout)
        )
        self.conv_block2 = nn.Sequential(
            nn.ConvTranspose1d(configs.mid_channels, configs.sequence_len, \
@@ -146,16 +146,17 @@ def __init__(self, configs):
            nn.ReLU(),
            # nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
        )
+        self.lin = nn.Linear(configs.final_out_channels, self.input_channels * self.sequence_len)

    def forward(self, f, out_ft):
        # freq, time = f.chunk(2,dim=1)
        x_low = self.nn(torch.fft.irfft(out_ft, n=128))
        et = f[:,self.modes:]
-        x_high = self.conv_block1(et.unsqueeze(2))
-        x_high = self.conv_block2(x_high).permute(0,2,1)
+        # x_high = self.conv_block1(et.unsqueeze(2))
+        # x_high = self.conv_block2(x_high).permute(0,2,1)
        # x_high = self.nn2(F.gelu((self.fc1(time).reshape(-1, 3, 128))))
-        # print(x_low.shape, time.shape)
-        # x_high = self.nn2(F.relu(self.convT(time.unsqueeze(2))).permute(0,2,1))
+        x_high = self.nn(F.relu(self.convT(et.unsqueeze(2))).permute(0,2,1))
+        # x_high = self.nn(F.relu(self.lin(et).reshape(-1, self.input_channels, self.sequence_len)))
        return x_low + x_high
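For orientation, the `x_low + x_high` return above is a low/high-frequency split: `out_ft` holds truncated rFFT coefficients, so `torch.fft.irfft(out_ft, n=128)` recovers the smooth component while the learned transposed-conv branch models the residual. A minimal, self-contained sketch of that decomposition (batch/channel sizes are illustrative; the 64-mode cutoff echoes `fourier_modes = 64` in the configs):

    import torch

    B, C, L = 32, 3, 128                  # illustrative batch, channels, length
    x = torch.randn(B, C, L)

    ft = torch.fft.rfft(x, n=L)           # complex spectrum, shape (B, C, L//2 + 1)
    ft_low = ft.clone()
    ft_low[..., 64:] = 0                  # keep only the first 64 Fourier modes

    x_low = torch.fft.irfft(ft_low, n=L)  # smooth, low-frequency reconstruction
    x_high = x - x_low                    # residual a time-domain branch must model
    print(x_low.shape, x_high.abs().mean())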

class RAINCOAT(Algorithm):
@@ -187,12 +188,12 @@ def __init__(self, configs, hparams, device):
    def update(self, src_x, src_y, trg_x):

        self.optimizer.zero_grad()
-        src_feat = self.feature_extractor(src_x)
-        trg_feat = self.feature_extractor(trg_x)
-        # src_recon = self.decoder(src_feat, self.feature_extractor.recons)
-        # trg_recon = self.decoder(trg_feat, out)
-        # recons = 1e-4*(self.recons(src_recon, src_x) + self.recons(trg_recon, src_x))
-        # recons.backward(retain_graph=True)
+        src_feat, out_s = self.feature_extractor(src_x)
+        trg_feat, out_t = self.feature_extractor(trg_x)
+        src_recon = self.decoder(src_feat, out_s)
+        trg_recon = self.decoder(trg_feat, out_t)
+        recons = 1e-4*(self.recons(src_recon, src_x)+self.recons(trg_recon, trg_x))
+        recons.backward(retain_graph=True)
        dr, _, _ = self.sink(src_feat, trg_feat)
        sink_loss = 1 *dr
        sink_loss.backward(retain_graph=True)
@@ -206,88 +207,13 @@ def update(self, src_x, src_y, trg_x):

    def correct(self,src_x, src_y, trg_x):
        self.coptimizer.zero_grad()
-        src_feat, out = self.encoder(src_x)
-        trg_feat, out = self.encoder(trg_x)
-        src_recon = self.decoder(src_feat, out)
-        trg_recon = self.decoder(trg_feat, out)
-        recons = self.recons(trg_recon, trg_x)+self.recons(src_recon, src_x)
+        src_feat, out_s = self.feature_extractor(src_x)
+        trg_feat, out_t = self.feature_extractor(trg_x)
+        src_recon = self.decoder(src_feat, out_s)
+        trg_recon = self.decoder(trg_feat, out_t)
+        recons = self.recons(trg_recon, trg_x) + 0.1*self.recons(src_recon, src_x)
        recons.backward()
        self.coptimizer.step()
        return {'recon': recons.item()}


-# class TFAC(Algorithm):
-#     """
-#     TFAC: Time Frequency Domain Adaptation with Correct
-#     """
-#     def __init__(self, configs, hparams, device):
-#         super(TFAC, self).__init__(configs)
-#         self.encoder = tf_encoder(configs).to(device)
-#         self.decoder = tf_decoder(configs).to(device)
-#         # self.classifier = ResClassifier_MME(configs).to(device)
-#         self.classifier = classifier(configs).to(device)
-#         # self.classifier.weights_init()
-
-#         self.optimizer = torch.optim.Adam(
-#             list(self.encoder.parameters()) + \
-#             # list(self.decoder.parameters())\
-#             list(self.classifier.parameters()),
-#             lr=hparams["learning_rate"],
-#             weight_decay=hparams["weight_decay"]
-#         )
-#         self.coptimizer = torch.optim.Adam(
-#             list(self.encoder.parameters())+list(self.decoder.parameters()),
-#             lr=1*hparams["learning_rate"],
-#             weight_decay=hparams["weight_decay"]
-#         )
-
-#         self.hparams = hparams
-
-#         # self.loss_func = losses.TripletMarginLoss()
-#         self.recons = nn.L1Loss(reduction='sum').to(device)
-#         # self.recons = nn.MSELoss(reduction='sum').to(device)
-#         self.pi = torch.acos(torch.zeros(1)).item() * 2
-#         self.loss_func = losses.ContrastiveLoss(pos_margin=0.5)
-#         # self.loss_func = OrthogonalProjectionLoss(gamma=0.5)
-#         self.sink = SinkhornDistance(eps=5e-3, max_iter=500)
-
-#     def update(self, src_x, src_y, trg_x):
-
-#         self.optimizer.zero_grad()
-#         # self.classifier.weight_norm()
-#         src_feat, out = self.encoder(src_x)
-#         trg_feat, out = self.encoder(trg_x)
-#         # src_recon = self.decoder(src_feat, out)
-#         # trg_recon = self.decoder(trg_feat, out)
-#         # recons = self.recons(src_recon, src_x)+self.recons(trg_recon, src_x)
-#         # recons.backward(retain_graph=True)
-#         dr, _, _ = self.sink(src_feat, trg_feat)
-#         sink_loss = 1 *dr
-#         sink_loss.backward(retain_graph=True)
-#         #
-#         # loss= 3*sink_loss + recons
-#         lossinner = 1 * self.loss_func(src_feat, src_y)
-#         # lossinner.backward(retain_graph=True)
-#         # lossinner = 2 * self.op_loss (src_feat, src_y)
-#         lossinner.backward(retain_graph=True)
-#         src_pred = self.classifier(src_feat)
-
-#         loss_cls = 1 *self.cross_entropy(src_pred, src_y)
-#         loss_cls.backward(retain_graph=True)
-#         # loss = 10 * sink_loss + loss_cls + 1*lossinner
-#         # loss.backward()
-#         self.optimizer.step()
-#         return {'Src_cls_loss': loss_cls.item(),'Sink': sink_loss.item(), 'inner': lossinner.item()}
-
-#     def correct(self,src_x, src_y, trg_x):
-#         self.coptimizer.zero_grad()
-#         src_feat, out = self.encoder(src_x)
-#         trg_feat, out = self.encoder(trg_x)
-#         src_recon = self.decoder(src_feat, out)
-#         trg_recon = self.decoder(trg_feat, out)
-#         recons = self.recons(src_recon, src_x)+self.recons(trg_recon, src_x)
-#         recons.backward()
-#         self.coptimizer.step()
-#         return {'recon': recons.item()}
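The `self.sink(src_feat, trg_feat)` term in `update` above is a Sinkhorn (entropic optimal-transport) divergence between source and target feature batches. A log-domain re-derivation for illustration — not the repo's `SinkhornDistance` class — with `eps` and `max_iter` echoing the hyper-parameters visible above, uniform weights on both batches, and the feature width matching the new `out_dim = 192` in the configs change below:

    import torch

    def sinkhorn_cost(x, y, eps=5e-3, max_iter=500):
        # Squared-Euclidean cost matrix between the two feature clouds.
        C = torch.cdist(x, y, p=2) ** 2
        mu = torch.full((x.size(0),), 1.0 / x.size(0))
        nu = torch.full((y.size(0),), 1.0 / y.size(0))
        u, v = torch.zeros_like(mu), torch.zeros_like(nu)
        for _ in range(max_iter):
            # Log-domain Sinkhorn updates: numerically stable at small eps.
            u = eps * (mu.log() - torch.logsumexp((-C + u[:, None] + v[None, :]) / eps, dim=1)) + u
            v = eps * (nu.log() - torch.logsumexp((-C + u[:, None] + v[None, :]) / eps, dim=0)) + v
        pi = torch.exp((-C + u[:, None] + v[None, :]) / eps)  # transport plan
        return (pi * C).sum()

    src_feat = torch.nn.functional.normalize(torch.randn(32, 192))
    trg_feat = torch.nn.functional.normalize(torch.randn(32, 192))
    print(sinkhorn_cost(src_feat, trg_feat))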


112 changes: 112 additions & 0 deletions algorithms/algorithms.py
@@ -186,7 +186,119 @@ def update(self, src_x, src_y, trg_x, trg_index, step, epoch, len_dataloader):

        return {'total_loss': loss.item(), 'src_loss': src_loss.item(), 'loss_nc': loss_nc.item(), 'loss_ent': loss_nc.item()}


+class OVANet(Algorithm):
+    """
+    OVANet https://arxiv.org/pdf/2104.03344v3.pdf
+    Based on PyTorch implementation: https://github.com/VisionLearningGroup/OVANet
+    """
+    def __init__(self, backbone_fe, configs, hparams, device):
+        super().__init__(configs)
+
+        self.device = device
+        self.hparams = hparams
+        self.criterion = nn.CrossEntropyLoss()
+
+        self.feature_extractor = backbone_fe(configs)  # G
+        self.classifier1 = classifier(configs)  # C1
+
+        configs2 = configs
+        configs2.num_classes = configs.num_classes * 2
+
+        self.classifier2 = classifier(configs2)  # C2
+
+        self.feature_extractor.to(device)
+        self.classifier1.to(device)
+        self.classifier2.to(device)
+
+        self.opt_g = SGD(self.feature_extractor.parameters(), momentum=self.hparams['sgd_momentum'],
+                         lr=self.hparams['learning_rate'], weight_decay=0.0005, nesterov=True)
+        self.opt_c = SGD(list(self.classifier1.parameters()) + list(self.classifier2.parameters()), lr=1.0,
+                         momentum=self.hparams['sgd_momentum'], weight_decay=0.0005,
+                         nesterov=True)
+
+        param_lr_g = []
+        for param_group in self.opt_g.param_groups:
+            param_lr_g.append(param_group["lr"])
+        param_lr_c = []
+        for param_group in self.opt_c.param_groups:
+            param_lr_c.append(param_group["lr"])
+
+        self.param_lr_g = param_lr_g
+        self.param_lr_c = param_lr_c
+
+    @staticmethod
+    def _inv_lr_scheduler(param_lr, optimizer, iter_num, gamma=10,
+                          power=0.75, init_lr=0.001, weight_decay=0.0005,
+                          max_iter=10000):
+        #10000
+        """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs."""
+        #max_iter = 10000
+        gamma = 10.0
+        lr = init_lr * (1 + gamma * min(1.0, iter_num / max_iter)) ** (-power)
+        i = 0
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr * param_lr[i]
+            i += 1
+        return lr
+
+    def update(self, src_x, src_y, trg_x, step, epoch, len_train_source, len_train_target):
+
+        # Applying classifier network => replacing G, C2 in paper
+        self.feature_extractor.train()
+        self.classifier1.train()
+        self.classifier2.train()
+
+        self._inv_lr_scheduler(self.param_lr_g, self.opt_g, step,
+                               init_lr=self.hparams['learning_rate'],
+                               max_iter=self.hparams['min_step'])
+        self._inv_lr_scheduler(self.param_lr_c, self.opt_c, step,
+                               init_lr=self.hparams['learning_rate'],
+                               max_iter=self.hparams['min_step'])
+
+        self.opt_g.zero_grad()
+        self.opt_c.zero_grad()
+
+        # self.classifier2.weight_norm()
+
+        ## Source loss calculation
+        out_s = self.classifier1(self.feature_extractor(src_x))
+        out_open = self.classifier2(self.feature_extractor(src_x))
+
+        ## source classification loss
+        loss_s = self.criterion(out_s, src_y)
+        ## open set loss for source
+        out_open = out_open.view(out_s.size(0), 2, -1)
+        open_loss_pos, open_loss_neg = ova_loss(out_open, src_y)
+        ## b x 2 x C
+        loss_open = 0.5 * (open_loss_pos + open_loss_neg)
+        ## open set loss for target
+        all = loss_s + loss_open
+
+        # OEM - Open Entropy Minimization
+        no_adapt = False
+        if not no_adapt:  # TODO: Figure out if this needs to be altered
+            out_open_t = self.classifier2(self.feature_extractor(trg_x))
+            out_open_t = out_open_t.view(trg_x.size(0), 2, -1)
+
+            ent_open = open_entropy(out_open_t)
+            all += self.hparams['multi'] * ent_open
+
+        all.backward()
+
+        self.opt_g.step()
+        self.opt_c.step()
+        self.opt_g.zero_grad()
+        self.opt_c.zero_grad()
+
+        return {'src_loss': loss_s.item(),
+                'open_loss': loss_open.item(),
+                'open_src_pos_loss': open_loss_pos.item(),
+                'open_src_neg_loss': open_loss_neg.item(),
+                'open_trg_loss': ent_open.item()
+                }

class AdaMatch(Algorithm):
    """
    AdaMatch https://arxiv.org/abs/2106.04732
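Two details of the new OVANet class are worth a quick check. First, `_inv_lr_scheduler` implements the inverse decay lr(t) = init_lr * (1 + gamma * min(1, t/max_iter))^(-power) rather than a step decay (despite its docstring); a tiny runnable trace of the curve, with illustrative init_lr and max_iter:

    init_lr, gamma, power, max_iter = 0.001, 10.0, 0.75, 10000
    for it in (0, 100, 1000, 10000):
        lr = init_lr * (1 + gamma * min(1.0, it / max_iter)) ** (-power)
        print(it, round(lr, 6))

Second, at test time OVANet (per the paper linked above; the inference step is not part of this diff) takes C1's argmax as the closed-set prediction and flags a sample as unknown when C2's one-vs-all "inlier" probability for that class falls below 0.5. A sketch with random stand-ins for the classifier outputs:

    import torch
    import torch.nn.functional as F

    batch, num_classes = 4, 6
    out_c1 = torch.randn(batch, num_classes)       # closed-set logits from C1
    out_c2 = torch.randn(batch, 2 * num_classes)   # one-vs-all logits from C2

    pred = out_c1.argmax(dim=1)
    out_open = F.softmax(out_c2.view(batch, 2, num_classes), dim=1)
    inlier_prob = out_open[torch.arange(batch), 1, pred]  # p(known | predicted class)
    unknown = inlier_prob < 0.5
    print(pred, unknown)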
2 changes: 1 addition & 1 deletion configs/data_model_configs.py
@@ -24,7 +24,7 @@ def __init__(self):
        self.dropout = 0.5
        self.num_classes = 6
        self.fourier_modes = 64
-        self.out_dim = 128
+        self.out_dim = 192
        # CNN and RESNET features

        self.mid_channels = 64
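A plausible reading of this bump (an assumption, not stated in the diff): `out_dim` is the width of the concatenated feature `torch.concat([ef, et], -1)` returned by `tf_encoder.forward`, so 192 would correspond to, e.g., a 128-dim frequency embedding (2 x `fourier_modes`) plus a 64-dim time branch:

    fourier_modes = 64
    freq_dim = 2 * fourier_modes  # e.g. two numbers (amplitude/phase) per retained mode
    time_dim = 64                 # hypothetical CNN branch width
    assert freq_dim + time_dim == 192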
6 changes: 3 additions & 3 deletions main.py
@@ -13,15 +13,15 @@

# ======== Experiments Name ================
parser.add_argument('--save_dir', default='experiments_logs', type=str, help='Directory containing all experiments')
-parser.add_argument('--experiment_description', default='WISDM-RAINCOAT', type=str, help='Name of your experiment (EEG, HAR, HHAR_SA, WISDM')
-parser.add_argument('--run_description', default='WISDM-RAINCOAT', type=str, help='name of your runs')
+parser.add_argument('--experiment_description', default='HAR-RAINCOAT', type=str, help='Name of your experiment (EEG, HAR, HHAR_SA, WISDM')
+parser.add_argument('--run_description', default='HAR-RAINCOAT', type=str, help='name of your runs')

# ========= Select the DA methods ============
parser.add_argument('--da_method', default='RAINCOAT', type=str, help='DANN, Deep_Coral, RAINCOAT, MMDA, VADA, DIRT, CDAN, AdaMatch, HoMM, CoDATS')

# ========= Select the DATASET ==============
parser.add_argument('--data_path', default=r'./data', type=str, help='Path containing dataset')
-parser.add_argument('--dataset', default='WISDM', type=str, help='Dataset of choice: (WISDM - EEG - HAR - HHAR_SA, Boiler)')
+parser.add_argument('--dataset', default='HAR', type=str, help='Dataset of choice: (WISDM - EEG - HAR - HHAR_SA, Boiler)')

# ========= Select the BACKBONE ==============
parser.add_argument('--backbone', default='CNN', type=str, help='Backbone of choice: (CNN - RESNET18 - TCN)')
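With these new defaults, a run of this commit's configuration would look something like the following (illustrative invocation; the flags are the ones defined above):

    python main.py --da_method RAINCOAT --dataset HAR --backbone CNN --data_path ./data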
25 changes: 24 additions & 1 deletion models/loss.py
@@ -13,7 +13,30 @@ def forward(self, x):
        b = b.sum(dim=1)
        return -1.0 * b.mean(dim=0)

-
+def ova_loss(out_open, label):
+    assert len(out_open.size()) == 3
+    assert out_open.size(1) == 2
+
+    out_open = F.softmax(out_open, 1)
+
+    label_p = torch.zeros((out_open.size(0),
+                           out_open.size(2))).long().cuda()
+    label_range = torch.arange(0, out_open.size(0)).long()  # - 1
+    label_p[label_range, label] = 1
+    label_n = 1 - label_p
+    open_loss_pos = torch.mean(torch.sum(-torch.log(out_open[:, 1, :]
+                                                    + 1e-8) * label_p, 1))
+    open_loss_neg = torch.mean(torch.max(-torch.log(out_open[:, 0, :] +
+                                                    1e-8) * label_n, 1)[0])
+    return open_loss_pos, open_loss_neg
+
+def open_entropy(out_open):
+    assert len(out_open.size()) == 3
+    assert out_open.size(1) == 2
+    out_open = F.softmax(out_open, 1)
+    ent_open = torch.mean(torch.mean(torch.sum(-out_open * torch.log(out_open + 1e-8), 1), 1))
+    return ent_open

class VAT(nn.Module):
    def __init__(self, model, device):
        super(VAT, self).__init__()
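For intuition about the two new losses: for each class c, channel 1 of `out_open` scores "sample belongs to c" and channel 0 scores "sample does not belong to c". The positive term raises the true class's channel-1 probability, the negative term penalizes only the hardest wrong class (the max over `label_n`), and `open_entropy` measures the binary known/unknown uncertainty that the target-side OEM term drives down. A CPU-only sanity check of these semantics (the repo's `ova_loss` calls `.cuda()`; shapes here are illustrative):

    import torch
    import torch.nn.functional as F

    batch, num_classes = 4, 6
    out_open = torch.randn(batch, 2, num_classes)
    label = torch.randint(0, num_classes, (batch,))

    probs = F.softmax(out_open, dim=1)
    label_p = F.one_hot(label, num_classes).float()
    pos = torch.mean(torch.sum(-torch.log(probs[:, 1, :] + 1e-8) * label_p, 1))
    neg = torch.mean(torch.max(-torch.log(probs[:, 0, :] + 1e-8) * (1 - label_p), 1)[0])
    ent = torch.mean(torch.mean(torch.sum(-probs * torch.log(probs + 1e-8), 1), 1))
    print(pos.item(), neg.item(), ent.item())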