From 5c3e778e12476cd724a8399171c5f7d3683b581c Mon Sep 17 00:00:00 2001 From: AKuzina Date: Mon, 22 Nov 2021 17:19:24 +0100 Subject: [PATCH] add char-level PennTreeBank dataset --- __init__.py | 0 config.py | 10 ++ dataset.py | 64 ++++++++---- datasets/__init__.py | 1 + datasets/penn_tree_bank_char.py | 142 ++++++++++++++++++++++++++ model.py | 30 +++++- models/__init__.py | 4 +- models/ckcnn.py | 75 +++++++++++++- models/tcn.py | 35 +++++++ requirements.txt | 1 + tester.py | 68 ++++++++++++- trainer.py | 170 +++++++++++++++++++++++++++++++- 12 files changed, 574 insertions(+), 26 deletions(-) create mode 100644 __init__.py create mode 100644 datasets/penn_tree_bank_char.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/config.py b/config.py index c69b84a..95f0f68 100644 --- a/config.py +++ b/config.py @@ -64,6 +64,8 @@ def get_config(): # This parameter is automatically derived from the other parameters of the run. It specifies # the path where the network parameters will be saved / loaded from. report_auc=False, + report_ppl=False, + report_bpc=False, max_epochs_no_improvement=100, # -------------------------- # Parameters of TCNs / BFCNNs @@ -81,6 +83,12 @@ def get_config(): # kernels. e.g., 32. pool=False, # **Not used in our experiments -> Worse performance.** # If True, it adds a max pool layer after each Residual Block. + emb_dropout=0., + # Embeding dropout for PennTreeBank dataset + emb_size=0, + # Size of the embedding + tied_weights=True, + # If true - use the same weight matrix for encoder and decoder # -------------------------- # Parameters of SIREN kernelnet_omega_0=0.0, @@ -109,6 +117,8 @@ def get_config(): drop_rate=0, # Specifies the rate at which data will be droped from the original dataset. Used for experiments # With missing data. e.g., 30, 50, 70. + # 6. PennTreeBank: valid sequence length. Seq_len = effective history + valid sequence length (only second part used for traning) + valid_seq_len=0, ) default_config = ml_collections.ConfigDict(default_config) return default_config diff --git a/dataset.py b/dataset.py index cc6ca5d..788ebc7 100644 --- a/dataset.py +++ b/dataset.py @@ -7,6 +7,7 @@ SpeechCommands, CharTrajectories, PhysioNet, + PennTreeBankChar, ) import ml_collections @@ -30,8 +31,10 @@ def dataset_constructor( "SpeechCommands": SpeechCommands, "CharTrajectories": CharTrajectories, "PhysioNet": PhysioNet, + 'PennTreeBankChar': PennTreeBankChar, }[config.dataset] - + if config.dataset == 'PennTreeBankChar': + eval_batch_size = 10 training_set = dataset( partition="train", seq_length=config.seq_length, @@ -39,6 +42,8 @@ def dataset_constructor( mfcc=config.mfcc, sr=config.sr_train, dropped_rate=config.drop_rate, + valid_seq_len=config.valid_seq_len, + batch_size=config.batch_size, ) test_set = dataset( partition="test", @@ -49,8 +54,10 @@ def dataset_constructor( if config.sr_test == 0 else config.sr_test, # Test set can be sample differently. 
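As the comment in config.py above describes, every window of seq_length characters splits into a context-only prefix ("effective history") and valid_seq_len positions that actually receive predictions. A minimal sketch of that arithmetic, with purely illustrative numbers (the patch does not fix these hyperparameters):

# Illustrative values only; seq_length and valid_seq_len come from the run config.
seq_length = 400        # total characters fed to the model per window
valid_seq_len = 320     # characters per window that contribute to the loss

eff_history = seq_length - valid_seq_len   # 80 context-only positions at the start of each window
assert 0 < valid_seq_len <= seq_length

# trainer.py and tester.py (further down) crop logits and labels with [:, eff_history:],
# so consecutive windows can overlap without counting any character twice.
print(f"effective history per window: {eff_history} characters")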
dropped_rate=config.drop_rate, + valid_seq_len=config.valid_seq_len, + batch_size=eval_batch_size, ) - if config.dataset in ["SpeechCommands", "CharTrajectories", "PhysioNet"]: + if config.dataset in ["SpeechCommands", "CharTrajectories", "PhysioNet", "PennTreeBankChar"]: validation_set = dataset( partition="val", seq_length=config.seq_length, @@ -58,6 +65,8 @@ def dataset_constructor( mfcc=config.mfcc, sr=config.sr_train, dropped_rate=config.drop_rate, + valid_seq_len=config.valid_seq_len, + batch_size=eval_batch_size, ) else: validation_set = None @@ -74,29 +83,48 @@ def get_dataset( :return: Tuple ( dict(train_loader, val_loader) , test_loader) """ training_set, validation_set, test_set = dataset_constructor(config) + if config.dataset in ["PennTreeBankChar"]: + with config.unlocked(): + config.vocab_size = len(training_set.dictionary) + training_loader = torch.utils.data.DataLoader( + training_set, + batch_sampler=training_set.sampler, + num_workers=num_workers, + ) + test_loader = torch.utils.data.DataLoader( + test_set, + batch_sampler=test_set.sampler, + num_workers=num_workers, + ) - training_loader = torch.utils.data.DataLoader( - training_set, - batch_size=config.batch_size, - shuffle=True, - num_workers=num_workers, - ) - test_loader = torch.utils.data.DataLoader( - test_set, - batch_size=config.batch_size, - shuffle=False, - num_workers=num_workers, - ) - - if validation_set is not None: val_loader = torch.utils.data.DataLoader( validation_set, + batch_sampler=validation_set.sampler, + num_workers=num_workers, + ) + else: + training_loader = torch.utils.data.DataLoader( + training_set, + batch_size=config.batch_size, + shuffle=True, + num_workers=num_workers, + ) + test_loader = torch.utils.data.DataLoader( + test_set, batch_size=config.batch_size, shuffle=False, num_workers=num_workers, ) - else: - val_loader = test_loader + + if validation_set is not None: + val_loader = torch.utils.data.DataLoader( + validation_set, + batch_size=config.batch_size, + shuffle=False, + num_workers=num_workers, + ) + else: + val_loader = test_loader dataloaders = {"train": training_loader, "validation": val_loader} diff --git a/datasets/__init__.py b/datasets/__init__.py index dd7ee7b..6fbbc9a 100644 --- a/datasets/__init__.py +++ b/datasets/__init__.py @@ -11,3 +11,4 @@ ) from .physionet import PhysioNet from .char_trajectories import CharTrajectories +from .penn_tree_bank_char import PennTreeBankChar diff --git a/datasets/penn_tree_bank_char.py b/datasets/penn_tree_bank_char.py new file mode 100644 index 0000000..174d7bf --- /dev/null +++ b/datasets/penn_tree_bank_char.py @@ -0,0 +1,142 @@ +""" +Adapted from https://github.com/locuslab/TCN/blob/master/TCN/ +""" +import pickle +from collections import Counter +import os +import numpy as np +import torch +import pathlib +from .utils import load_data, save_data +import observations + + +class PennTreeBankChar(torch.utils.data.Dataset): + def __init__( + self, + partition: int, + seq_length: int, + valid_seq_len: int, + batch_size: int, + **kwargs, + ): + self.seq_len = seq_length + self.valid_seq_len = valid_seq_len + self.batch_size = batch_size + self.root = pathlib.Path("./data") + self.base_loc = self.root / "penn" + data_loc = self.base_loc / "preprocessed_data_char" + + if os.path.exists(data_loc): + self.dictionary = pickle.load(open(str(data_loc / 'dictionary_char'), 'rb')) + else: + train, valid, test = self._process_data() + + if not os.path.exists(data_loc): + os.mkdir(data_loc) + pickle.dump(self.dictionary, open(str(data_loc / 
'dictionary_char'), 'wb')) + save_data( + data_loc, + train=train, + valid=valid, + test=test, + ) + + self.X, self.y = self.load_data(data_loc, partition) + if partition == 'train': + self.sampler = SequentialBatchSampler(self) + else: + self.sampler = SequentialBatchSampler(self, shuffle=False) + super(PennTreeBankChar, self).__init__() + + def __getitem__(self, ind): + b = ind // len(self.X[0]) + i = ind - b * len(self.X[0]) + return self.X[b][i], self.y[b][i] + + def __len__(self): + return len(self.X[0]) * len(self.X) + + def create_seq(self, data, batch_size): + nbatch = data.size(0) // batch_size + data = data.narrow(0, 0, nbatch * batch_size).view(batch_size, -1) ## crop tail + x = [] + y = [] + L = data.shape[1] + for i in range(0, L-1, self.valid_seq_len): + if i + self.seq_len - self.valid_seq_len >= L - 1: + continue + end = min(i + self.seq_len, L - 1) + x.append(data[:, i: end].contiguous()) + y.append(data[:, i+1: end+1].contiguous()) + return x, y + + def _process_data(self): + self.dictionary = Dictionary() + train, test, valid = getattr(observations, 'ptb')(self.base_loc) + for c in train + ' ' + test + '' + valid: + self.dictionary.add_word(c) + self.dictionary.prep_dict() + + + train = self._char_to_tensor(train) + valid = self._char_to_tensor(valid) + test = self._char_to_tensor(test) + return train, valid, test + + def _char_to_tensor(self, string): + tensor = torch.zeros(len(string)).long() + for i in range(len(string)): + tensor[i] = self.dictionary.char2idx[string[i]] + return tensor + + def load_data(self, data_loc, partition): + tensors = load_data(data_loc) + if partition == "train": + data = tensors["train"] + elif partition == "val": + data = tensors["valid"] + elif partition == "test": + data = tensors["test"] + else: + raise NotImplementedError("the set {} is not implemented.".format(set)) + X, y = self.create_seq(data, self.batch_size) + return X, y + + +class Dictionary(object): + def __init__(self): + self.char2idx = {} + self.idx2char = [] + self.counter = Counter() + + def add_word(self, word): + self.counter[word] += 1 + + def prep_dict(self): + for char in self.counter: + if char not in self.char2idx: + self.idx2char.append(char) + self.char2idx[char] = len(self.idx2char) - 1 + + def __len__(self): + return len(self.idx2char) + + +class SequentialBatchSampler(torch.utils.data.Sampler): + def __init__(self, data_source, shuffle=True): + super(SequentialBatchSampler, self).__init__(data_source) + self.X = data_source.X + if shuffle: + self.sampler = torch.utils.data.SubsetRandomSampler(np.arange(len(self.X))) + else: + self.sampler = np.arange(len(self.X)) + self.batch_size = self.X[0].shape[0] + + def __iter__(self): + for idx in self.sampler: + batch = [idx * self.batch_size + j for j in range(self.batch_size)] + yield batch + + def __len__(self): + return len(self.X) diff --git a/model.py b/model.py index d57b8c8..74fa00e 100644 --- a/model.py +++ b/model.py @@ -24,6 +24,8 @@ def get_model(config): in_channels = 1 elif config.dataset in ["PhysioNet"]: in_channels = 75 + elif config.dataset in ["PennTreeBankChar"]: + in_channels = config.emb_size else: raise NotImplementedError("Dataset {} not found.".format(config.dataset)) @@ -80,6 +82,15 @@ def get_model(config): kernel_size=config.cnn_kernel_size, dropout=config.dropout, ), + "PennTreeBankChar_TCN": lambda: models.PTB_TCN( + input_size=config.emb_size, + output_size=config.vocab_size, + num_channels=[config.no_hidden] * (config.no_blocks-1) + [config.emb_size], + 
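A self-contained toy version of create_seq from the new penn_tree_bank_char.py above may help: windows start every valid_seq_len characters, are at most seq_len long, and the targets are the inputs shifted by one character. All sizes below are toy values chosen only for illustration.

import torch

corpus = torch.arange(50)                      # stand-in for the encoded character stream
batch_size, seq_len, valid_seq_len = 2, 12, 8

nbatch = corpus.size(0) // batch_size
data = corpus.narrow(0, 0, nbatch * batch_size).view(batch_size, -1)   # (2, 25); tail cropped

x, y = [], []
L = data.shape[1]
for i in range(0, L - 1, valid_seq_len):       # a new window starts every valid_seq_len characters
    if i + seq_len - valid_seq_len >= L - 1:   # skip windows with no position left to predict
        continue
    end = min(i + seq_len, L - 1)
    x.append(data[:, i:end])                   # inputs
    y.append(data[:, i + 1:end + 1])           # targets: inputs shifted by one character

print([t.shape for t in x])                    # consecutive windows overlap by seq_len - valid_seq_len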
kernel_size=config.cnn_kernel_size, + dropout=config.dropout, + emb_dropout=config.emb_dropout, + tied_weights=config.tied_weights, + ), "AddProblem_CKCNN": lambda: models.AddProblem_CKCNN( in_channels=in_channels, hidden_channels=config.no_hidden, @@ -168,6 +179,23 @@ def get_model(config): weight_dropout=config.weight_dropout, pool=config.pool, ), + "PennTreeBankChar_CKCNN": lambda: models.seqText_CKCNN( + in_channels=in_channels, + out_channels=config.vocab_size, + hidden_channels=config.no_hidden, + num_blocks=config.no_blocks, + kernelnet_hidden_channels=config.kernelnet_no_hidden, + kernelnet_activation_function=config.kernelnet_activation_function, + kernelnet_norm_type=config.kernelnet_norm_type, + dim_linear=1, + bias=True, + omega_0=config.kernelnet_omega_0, + dropout=config.dropout, + weight_dropout=config.weight_dropout, + pool=config.pool, + emb_dropout=config.emb_dropout, + tied_weights=config.tied_weights + ), "CharTrajectories_CKCNN": lambda: models.seqImg_CKCNN( in_channels=in_channels, out_channels=20, @@ -187,7 +215,7 @@ def get_model(config): # print number parameters print("Number of parameters:", ckconv.utils.num_params(model)) - # wandb.run.summary["no_params"] = ckconv.utils.num_params(model) + wandb.run.summary["no_params"] = ckconv.utils.num_params(model) # Check if multi-GPU available and if so, use the available GPU's print("GPU's available:", torch.cuda.device_count()) diff --git a/models/__init__.py b/models/__init__.py index 1218246..62b5d5c 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,3 +1,3 @@ -from .tcn import AddProblem_TCN, CopyMemory_TCN, MNIST_TCN -from .ckcnn import CopyMemory_CKCNN, AddProblem_CKCNN, seqImg_CKCNN +from .tcn import AddProblem_TCN, CopyMemory_TCN, MNIST_TCN, PTB_TCN +from .ckcnn import CopyMemory_CKCNN, AddProblem_CKCNN, seqImg_CKCNN, seqText_CKCNN from .bfcnn import seqImg_BFCNN diff --git a/models/ckcnn.py b/models/ckcnn.py index 9406c36..d68e832 100644 --- a/models/ckcnn.py +++ b/models/ckcnn.py @@ -17,16 +17,20 @@ def __init__( dropout: float, weight_dropout: float, pool: bool, # Always False in our experiments. 
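The optional out_channels argument added just below only affects the width of the final CKBlock; earlier blocks keep hidden_channels, and seqText_CKCNN passes out_channels=in_channels so the last block maps back to the embedding size for the (optionally tied) decoder. A small sketch of the resulting channel plan, with hypothetical sizes:

# Hypothetical sizes; this mirrors the per-block logic added to CKCNN.__init__ below.
in_channels, hidden_channels, num_blocks, out_channels = 100, 256, 4, 100

widths = []
for i in range(num_blocks):
    block_in = in_channels if i == 0 else hidden_channels
    block_out = hidden_channels
    if i == num_blocks - 1 and out_channels is not None:
        block_out = out_channels               # only the last block is narrowed / widened
    widths.append((block_in, block_out))

print(widths)   # [(100, 256), (256, 256), (256, 256), (256, 100)]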
+ out_channels=None, ): super(CKCNN, self).__init__() blocks = [] for i in range(num_blocks): block_in_channels = in_channels if i == 0 else hidden_channels + block_out_channels = hidden_channels + if i == num_blocks-1 and out_channels is not None: + block_out_channels = out_channels blocks.append( ckconv.nn.CKBlock( block_in_channels, - hidden_channels, + block_out_channels, kernelnet_hidden_channels, kernelnet_activation_function, kernelnet_norm_type, @@ -183,3 +187,72 @@ def forward(self, x): if out.shape[-1] == 1: out = out.squeeze(-1) return out + + +class seqText_CKCNN(CKCNN): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + num_blocks: int, + kernelnet_hidden_channels: int, + kernelnet_activation_function: str, + kernelnet_norm_type: str, + dim_linear: int, + bias: bool, + omega_0: bool, + dropout: float, + weight_dropout: float, + pool: bool, + # vocab_size: int, + emb_dropout: float, + tied_weights=True, + ): + super().__init__( + in_channels, + hidden_channels, + num_blocks, + kernelnet_hidden_channels, + kernelnet_activation_function, + kernelnet_norm_type, + dim_linear, + bias, + omega_0, + dropout, + weight_dropout, + pool, + out_channels=in_channels + ) + self.encoder = torch.nn.Embedding(out_channels, in_channels) + self.drop = torch.nn.Dropout(emb_dropout) + + self.finallyr = torch.nn.Linear( + in_features=in_channels, out_features=out_channels + ) + self.train_len = None + self.init_weights() + if tied_weights: + self.finallyr.weight = self.encoder.weight + print("Weight tied") + + def init_weights(self): + self.encoder.weight.data.normal_(0, 0.01) + self.finallyr.weight.data.normal_(0, 0.01) + self.finallyr.bias.data.fill_(value=0.0) + + def forward(self, x, return_emb=False): + emb = self.prepare_input(x) + y1 = self.backbone(emb)[:, :, self.train_len-x.shape[-1]:] # MB x hid_size x seq_len + out = self.finallyr(y1.transpose(1, 2)) # MB x seq_len x voc_size + if return_emb: + return out, y1.transpose(1, 2) + return out + + def prepare_input(self, x): + if self.train_len is None: + self.train_len = x.shape[-1] + emb = self.drop(self.encoder(x)).transpose(1, 2) # MB x emb_size x seq_len + if emb.shape[-1] < self.train_len: + emb = torch.nn.functional.pad(emb, (self.train_len - emb.shape[-1], 0)) + return emb diff --git a/models/tcn.py b/models/tcn.py index 6d63f55..f8355de 100644 --- a/models/tcn.py +++ b/models/tcn.py @@ -194,3 +194,38 @@ def init_weights(self): def forward(self, x): y1 = self.tcn(x) return self.linear(y1[:, :, -1]) + + +class PTB_TCN(nn.Module): + def __init__(self, input_size, + output_size, + num_channels, + kernel_size, + dropout, + emb_dropout=0.1, + tied_weights=True): + super(PTB_TCN, self).__init__() + self.encoder = nn.Embedding(output_size, input_size) + self.tcn = TemporalConvNet(input_size, num_channels, kernel_size, dropout=dropout) + + self.linear = nn.Linear(num_channels[-1], output_size) + if tied_weights: + if num_channels[-1] != input_size: + raise ValueError('When using the tied flag, nhid must be equal to emsize') + self.linear.weight = self.encoder.weight + print("Weight tied") + self.drop = nn.Dropout(emb_dropout) + self.init_weights() + + def init_weights(self): + self.encoder.weight.data.normal_(0, 0.01) + self.linear.weight.data.normal_(0, 0.01) + self.linear.bias.data.fill_(0) + + def forward(self, x, return_emb=False): + emb = self.drop(self.encoder(x)).transpose(1, 2) # MB x emb_size x seq_len + y1 = self.tcn(emb) # MB x n_ch x seq_len + out = self.linear(y1.transpose(1, 2)) # MB x 
seq_len x voc_size + if return_emb: + return out, y1.transpose(1, 2) + return out diff --git a/requirements.txt b/requirements.txt index ee96eac..779eeb1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -115,3 +115,4 @@ wcwidth==0.2.5 webencodings==0.5.1 wheel==0.36.2 zipp==3.4.0 +observations==0.1.4 diff --git a/tester.py b/tester.py index 3b81c98..de0368d 100644 --- a/tester.py +++ b/tester.py @@ -3,7 +3,7 @@ import numpy as np import wandb - +import math import sklearn # project @@ -20,6 +20,7 @@ def test(model, test_loader, config): "SpeechCommands": _test_classif, "CharTrajectories": _test_classif, "PhysioNet": _test_classif, + "PennTreeBankChar": _test_language_modeling, }[config.dataset] test_acc = test_function(model, test_loader, config) @@ -43,6 +44,11 @@ def _test_classif(model, test_loader, config): pred_y_cpus = [] auc = 0 + if config.report_ppl: + criterion = torch.nn.CrossEntropyLoss() + running_ppl = 0. + ppl_N = 0 + with torch.no_grad(): # Iterate through data for inputs, labels in test_loader: @@ -61,6 +67,8 @@ def _test_classif(model, test_loader, config): preds = (outputs > 0.0).int() else: _, preds = torch.max(outputs, 1) + if len(labels.shape) > 1: + labels = labels.reshape(-1) total += labels.size(0) correct += (preds == labels).sum().item() @@ -70,6 +78,11 @@ def _test_classif(model, test_loader, config): true_y_cpus.append(labels.detach().cpu()) pred_y_cpus.append(outputs.detach().cpu()) + if config.report_ppl: + loss = criterion(outputs, labels) + running_ppl += (inputs.size(1) - config.seq_length + config.valid_seq_len) * loss.item() + ppl_N += inputs.size(1) - config.seq_length + config.valid_seq_len + # Print results test_acc = correct / total print( @@ -85,4 +98,57 @@ def _test_classif(model, test_loader, config): auc = sklearn.metrics.roc_auc_score(true_y_cpus, pred_y_cpus) print(f"AUC: {auc}") + if config.report_ppl: + ppl = math.exp(running_ppl / ppl_N) + print(f"PPL: {ppl}") + return test_acc, ppl + return test_acc, auc + + +def _test_language_modeling(model, test_loader, config): + # send model to device + device = config.device + model.eval() + model.to(device) + eff_history = config.seq_length - config.valid_seq_len + + # Summarize results + criterion = torch.nn.CrossEntropyLoss() + total = 0 + running_loss = 0 + + if config.report_ppl or config.report_bpc: + running_ppl = 0. + ppl_N = 0 + + with torch.no_grad(): + # Iterate through data + for inputs, labels in test_loader: + inputs = inputs.to(device) + labels = labels.to(device)[:, eff_history:].contiguous().view(-1) + outputs = model(inputs) + outputs = outputs[:, eff_history:].contiguous().view(-1, config.vocab_size) + loss = criterion(outputs, labels) + running_loss += loss.item() * labels.shape[0] + total += labels.shape[0] + + if config.report_ppl or config.report_bpc: + n = inputs.shape[1] - eff_history + running_ppl += n * loss.item() + ppl_N += n + + # Print results + test_loss = running_loss / total + print(f"\tTest loss: {test_loss:.2f}") + ppl =0. + if config.report_ppl: + ppl = math.exp(running_ppl / ppl_N) + print(f"\tTest PPL: {ppl:.2f}") + + bpc = 0. 
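Both metrics are simple transforms of the same quantity, the average per-character cross-entropy running_ppl / ppl_N (in nats): perplexity exponentiates it, and BPC (computed just below) rescales it to bits. A quick numeric check with an illustrative loss value:

import math

avg_nll = 1.1                      # illustrative per-character cross-entropy in nats
ppl = math.exp(avg_nll)            # perplexity, as above          -> ~3.004
bpc = avg_nll / math.log(2)        # bits per character, as below  -> ~1.587
print(f"PPL: {ppl:.3f}  BPC: {bpc:.3f}")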
+ if config.report_bpc: + bpc = (running_ppl / ppl_N) / math.log(2) + print(f"\tTest BPC: {bpc:.2f}") + + return test_loss, ppl, bpc diff --git a/trainer.py b/trainer.py index 23dd511..16084df 100644 --- a/trainer.py +++ b/trainer.py @@ -5,6 +5,7 @@ import copy import os import datetime +import math import numpy as np # logging @@ -28,6 +29,7 @@ def train(model, dataloaders, config, test_loader): "SpeechCommands": torch.nn.CrossEntropyLoss(), "CharTrajectories": torch.nn.CrossEntropyLoss(), "PhysioNet": torch.nn.BCEWithLogitsLoss(), + "PennTreeBankChar": torch.nn.CrossEntropyLoss(), }[config.dataset] train_function = { @@ -38,6 +40,8 @@ def train(model, dataloaders, config, test_loader): "SpeechCommands": _train_classif, "CharTrajectories": _train_classif, "PhysioNet": _train_classif, + "PennTreeBank": _train_language_modeling, + "PennTreeBankChar": _train_language_modeling, }[config.dataset] # Define optimizer and scheduler @@ -125,7 +129,7 @@ def get_scheduler(optimizer, config): def _train_classif( - model, criterion, optimizer, dataloader, lr_scheduler, config, test_loader + model, criterion, optimizer, dataloader, lr_scheduler, config, test_loader ): weight_regularizer = ckconv.nn.LnLoss(weight_loss=config.weight_decay, norm_type=2) # Training parameters @@ -289,8 +293,8 @@ def _train_classif( # Update scheduler if ( - isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) - and phase == "validation" + isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) + and phase == "validation" ): lr_scheduler.step(epoch_acc) @@ -312,3 +316,163 @@ def _train_classif( model.load_state_dict(best_model_wts) # Return model and histories return model + + +def _train_language_modeling( + model, criterion, optimizer, dataloader, lr_scheduler, config, test_loader +): + print(f'Vocabulary size: {config.vocab_size} \n') + weight_regularizer = ckconv.nn.LnLoss(weight_loss=config.weight_decay, norm_type=2) + # Training parameters + epochs = config.epochs + device = config.device + eff_history = config.seq_length - config.valid_seq_len + + # Save best performing weights + best_model_wts = copy.deepcopy(model.state_dict()) + best_loss = 999 + + # Counter for epochs without improvement + epochs_no_improvement = 0 + max_epochs_no_improvement = config.max_epochs_no_improvement + + # compute num params + # total_params = sum(p.numel() for p in model.parameters()) + # if config.tied_weights: + # total_params -= sum(p.numel() for p in model.module.encoder.parameters()) + # wandb.run.summary["num_param"] = total_params + # iterate over epochs + for epoch in range(epochs): + print("Epoch {}/{}".format(epoch + 1, epochs)) + print("-" * 30) + # Print current learning rate + for param_group in optimizer.param_groups: + print("Learning Rate: {}".format(param_group["lr"])) + print("-" * 30) + # log learning_rate of the epoch + wandb.log({"lr": optimizer.param_groups[0]["lr"]}, step=epoch + 1) + + # Each epoch consist of training and validation + for phase in ["train", "validation"]: + if phase == "train": + model.train() + else: + model.eval() + + # Accumulate loss + running_loss = 0 + total = 0 + running_ppl = 0. 
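In the batch loop that follows, logits and labels are cropped to the last valid_seq_len positions before the cross-entropy is taken, so the effective-history prefix only provides context. A shape sketch with assumed toy dimensions:

import torch

# Toy dimensions; in the real code they come from config and the dataloader.
batch, seq_length, valid_seq_len, vocab_size = 4, 12, 8, 50
eff_history = seq_length - valid_seq_len                 # 4 context-only positions

outputs = torch.randn(batch, seq_length, vocab_size)     # model logits: MB x seq_len x voc_size
labels = torch.randint(vocab_size, (batch, seq_length))  # next-character targets

outputs = outputs[:, eff_history:].contiguous().view(-1, vocab_size)   # (batch * valid_seq_len, vocab)
labels = labels[:, eff_history:].contiguous().view(-1)                 # (batch * valid_seq_len,)

loss = torch.nn.functional.cross_entropy(outputs, labels)
print(outputs.shape, labels.shape, round(loss.item(), 3))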
+ ppl_N = 0 + + # iterate over data + for inputs, labels in dataloader[phase]: + inputs = inputs.to(device) + labels = labels.to(device)[:, eff_history:].contiguous().view(-1) + + optimizer.zero_grad() + train = phase == "train" + with torch.set_grad_enabled(train): + # FwrdPhase: + outputs, emb = model(inputs, return_emb=True) + outputs = outputs[:, eff_history:].contiguous().view(-1, config.vocab_size) + loss = criterion(outputs, labels) + + if config.report_ppl or config.report_bpc: + n = inputs.shape[1] - eff_history + running_ppl += n * loss.item() + ppl_N += n + + # statistics + running_loss += loss.item() * labels.shape[0] + total += labels.shape[0] + # Regularization: + if config.weight_decay != 0.0: + loss = loss + weight_regularizer(model) + + # BwrdPhase: + if phase == "train": + loss.backward() + if config.clip > 0: + torch.nn.utils.clip_grad_norm_( + model.parameters(), config.clip + ) + optimizer.step() + + # statistics of the epoch + epoch_loss = running_loss / total + print("{} Loss: {:.2f}".format(phase, epoch_loss)) + + # log statistics of the epoch + wandb.log( + {"loss" + "_" + phase: epoch_loss}, + step=epoch + 1, + ) + + if config.report_ppl: + epoch_ppl = math.exp(running_ppl / ppl_N) + print(f"PPL: {epoch_ppl:.2f}") + wandb.log( + {f"ppl_{phase}": epoch_ppl}, + step=epoch + 1, + ) + if config.report_bpc: + epoch_bpc = (running_ppl / ppl_N) / math.log(2) + print(f"BPC: {epoch_bpc:.2f}") + wandb.log( + {f"bpc_{phase}": epoch_bpc}, + step=epoch + 1, + ) + # If better validation accuracy, replace best weights and compute the test performance + if phase == "validation" and epoch_loss < best_loss: + best_loss = epoch_loss + best_model_wts = copy.deepcopy(model.state_dict()) + + # Log best results so far and the weights of the model. + wandb.run.summary["best_val_loss"] = best_loss + + # Clean CUDA Memory + del inputs, outputs, labels + torch.cuda.empty_cache() + # Perform test and log results + test_loss, test_ppl, test_bpc = test(model, test_loader, config) + + if config.report_ppl: + wandb.run.summary["best_val_ppl"] = epoch_ppl + wandb.run.summary["best_test_ppl"] = test_ppl + wandb.log({"test_ppl": test_ppl}, step=epoch + 1) + if config.report_bpc: + wandb.run.summary["best_val_bpc"] = epoch_bpc + wandb.run.summary["best_test_bpc"] = test_bpc + wandb.log({"test_bpc": test_bpc}, step=epoch + 1) + + # Reset counter of epochs without progress + epochs_no_improvement = 0 + + elif phase == "validation" and epoch_loss >= best_loss: + # Otherwise, increase counter + epochs_no_improvement += 1 + + # Update scheduler + if ( + isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) + and phase == "validation" + ): + lr_scheduler.step(-epoch_loss) + + # Update scheduler + if isinstance(lr_scheduler, torch.optim.lr_scheduler.MultiStepLR): + lr_scheduler.step() + print() + + # Check how many epochs without improvement have passed, and, if required, stop training. + if epochs_no_improvement == max_epochs_no_improvement: + print( + f"Stopping training due to {epochs_no_improvement} epochs of no improvement in validation accuracy." + ) + break + + # Load best model weights + model.load_state_dict(best_model_wts) + # Return model and histories + return model
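Finally, the weight tying used by both PTB_TCN and seqText_CKCNN is the standard trick of reusing the embedding matrix as the decoder weight, which is why the last layer width must equal emb_size (the "nhid must be equal to emsize" check above, and out_channels=in_channels in the CKCNN variant). A stripped-down, standalone sketch of just that mechanism (the TinyTiedLM module and all sizes here are illustrative, not part of the patch):

import torch
import torch.nn as nn

class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int, tied_weights: bool = True):
        super().__init__()
        self.encoder = nn.Embedding(vocab_size, emb_size)   # weight: vocab_size x emb_size
        self.decoder = nn.Linear(emb_size, vocab_size)      # weight: vocab_size x emb_size as well
        if tied_weights:
            # nn.Linear stores its weight as (out_features, in_features) = (vocab, emb),
            # so the embedding matrix can be shared directly, as in the models above.
            self.decoder.weight = self.encoder.weight

    def forward(self, x):                  # x: (batch, seq_len) character indices
        h = self.encoder(x)                # (batch, seq_len, emb_size); the real models run a TCN / CKCNN here
        return self.decoder(h)             # (batch, seq_len, vocab_size)

model = TinyTiedLM(vocab_size=50, emb_size=100)          # PTB-char uses a ~50-character vocabulary
logits = model(torch.randint(50, (2, 16)))
assert model.decoder.weight is model.encoder.weight
print(logits.shape)                                      # torch.Size([2, 16, 50])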