From a83b603d56ee7906eed15802ef7c82a1f21187cc Mon Sep 17 00:00:00 2001 From: heaodong Date: Mon, 31 Jul 2023 21:24:49 +0800 Subject: [PATCH] add files to ImageCaption --- ImageCaptioning/create_input_files.py | 13 ++ ImageCaptioning/datasets.py | 57 ++++++ ImageCaptioning/eval.py | 161 ++++++++++++++++ ImageCaptioning/models.py | 212 +++++++++++++++++++++ ImageCaptioning/utils.py | 263 ++++++++++++++++++++++++++ 5 files changed, 706 insertions(+) create mode 100755 ImageCaptioning/create_input_files.py create mode 100755 ImageCaptioning/datasets.py create mode 100644 ImageCaptioning/eval.py create mode 100755 ImageCaptioning/models.py create mode 100755 ImageCaptioning/utils.py diff --git a/ImageCaptioning/create_input_files.py b/ImageCaptioning/create_input_files.py new file mode 100755 index 0000000..0c73069 --- /dev/null +++ b/ImageCaptioning/create_input_files.py @@ -0,0 +1,13 @@ +"""Create files for dataset""" + +from utils import create_input_files + +if __name__ == '__main__': + # Create input files (along with word map) + create_input_files(dataset='coco', + karpathy_json_path='Deep-Tutorials-for-MindSpore/dataset_coco/dataset_coco.json', + image_folder='Deep-Tutorials-for-MindSpore/dataset_coco/', + captions_per_image=5, + min_word_freq=5, + output_folder='Deep-Tutorials-for-MindSpore/dataset_coco/', + max_len=50) diff --git a/ImageCaptioning/datasets.py b/ImageCaptioning/datasets.py new file mode 100755 index 0000000..b63a38f --- /dev/null +++ b/ImageCaptioning/datasets.py @@ -0,0 +1,57 @@ +# pylint: disable=C0103 +# pylint: disable=E0401 + +"""CaptionDataset""" +import json +import os +import h5py +import mindspore +from mindspore import Tensor + +class CaptionDataset: + """ + A MindSpore Dataset class to be used in a MindSpore DataLoader to create batches. 
+ """ + + def __init__(self, data_folder, data_name, split): + """ + :param data_folder: folder where data files are stored + :param data_name: base name of processed datasets + :param split: split, one of 'TRAIN', 'VAL', or 'TEST' + :param transform: image transform pipeline + """ + self.split = split + assert self.split in {'TRAIN', 'VAL', 'TEST'} + + # Open hdf5 file where images are stored + self.h = h5py.File(os.path.join(data_folder, self.split + '_IMAGES_' + data_name + '.hdf5'), 'r') + self.imgs = self.h['images'] + + # Captions per image + self.cpi = self.h.attrs['captions_per_image'] + + # Load encoded captions (completely into memory) + with open(os.path.join(data_folder, self.split + '_CAPTIONS_' + data_name + '.json'), 'r', encoding='utf-8') as j: + self.captions = json.load(j) + + # Load caption lengths (completely into memory) + with open(os.path.join(data_folder, self.split + '_CAPLENS_' + data_name + '.json'), 'r', encoding='utf-8') as j: + self.caplens = json.load(j) + + # Total number of datapoints + self.dataset_size = len(self.captions) + + def __getitem__(self, i): + img = Tensor(self.imgs[i // self.cpi] / 255., dtype=mindspore.float32) + + caption = Tensor(self.captions[i], dtype=mindspore.int64) + caplen = Tensor(self.caplens[i], dtype=mindspore.int64) + if self.split == 'TRAIN': + return img, caption, caplen + # For validation of testing, also return all 'captions_per_image' captions to find BLEU-4 score + all_captions = Tensor( + self.captions[((i // self.cpi) * self.cpi):(((i // self.cpi) * self.cpi) + self.cpi)], dtype=mindspore.int64) + return img, caption, caplen, all_captions + + def __len__(self): + return self.dataset_size diff --git a/ImageCaptioning/eval.py b/ImageCaptioning/eval.py new file mode 100644 index 0000000..4cdef8a --- /dev/null +++ b/ImageCaptioning/eval.py @@ -0,0 +1,161 @@ +# pylint: disable=C0103 +# pylint: disable=E0401 + +"""evaluation""" +import json +import mindspore +from tqdm import tqdm +from mindspore import ops +from mindspore.dataset import vision +from mindspore import load_checkpoint, load_param_into_net +from mindspore.dataset import GeneratorDataset +from models import Encoder, DecoderWithAttention +from nltk.translate.bleu_score import corpus_bleu +from datasets import CaptionDataset + +data_folder = 'Deep-Tutorials-for-MindSpore/dataset_coco' +data_name = 'coco_5_cap_per_img_5_min_word_freq' +checkpoint = 'decoder_coco_5_cap_per_img_5_min_word_freq_1.ckpt' +word_map_file = 'Deep-Tutorials-for-MindSpore/dataset_coco/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json' + +emb_dim = 512 # dimension of word embeddings +attention_dim = 512 # dimension of attention linear layers +decoder_dim = 512 # dimension of decoder RNN +dropout = 0.5 + +with open(word_map_file, 'r', encoding='utf-8') as j: + word_map = json.load(j) +rev_word_map = {v: k for k, v in word_map.items()} +vocab_size = len(word_map) + +encoder = Encoder() +decoder = DecoderWithAttention(attention_dim=attention_dim, + embed_dim=emb_dim, + decoder_dim=decoder_dim, + vocab_size=vocab_size, + dropout=dropout) +params_dict = load_checkpoint(checkpoint) +load_param_into_net(decoder, params_dict) +encoder.set_train(False) +decoder.set_train(False) + +normalize = vision.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + is_hwc=False) + +def evaluate(beam_size): + """evaluate""" + test_dataset = GeneratorDataset(CaptionDataset(data_folder, data_name, 'TEST'), ['img', 'caption', 'caplen', 'allcaps']) + test_dataset = test_dataset.map(operations=[normalize], 
                                    input_columns='img')
    test_dataset = test_dataset.batch(1)

    references = []
    hypotheses = []

    with tqdm(total=test_dataset.get_dataset_size()) as progress:
        progress.set_description("EVALUATING AT BEAM SIZE " + str(beam_size))
        for i, (image, _, _, allcaps) in enumerate(test_dataset.create_tuple_iterator()):

            k = beam_size
            encoder_out = encoder(image)
            encoder_dim = encoder_out.shape[3]
            encoder_out = encoder_out.view(1, -1, encoder_dim)
            num_pixels = encoder_out.shape[1]
            encoder_out = encoder_out.broadcast_to((k, num_pixels, encoder_dim))

            k_prev_words = mindspore.Tensor([[word_map['<start>']]] * k, dtype=mindspore.int32)
            seqs = k_prev_words
            top_k_scores = ops.zeros((k, 1))

            complete_seqs = []
            complete_seqs_scores = []

            step = 1
            h, c = decoder.init_hidden_state(encoder_out)

            while True:

                embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

                awe, _ = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)

                gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
                awe = gate * awe

                h, c = decoder.decode_step(ops.cat([embeddings, awe], axis=1), (h, c))  # (s, decoder_dim)

                scores = decoder.fc(h)  # (s, vocab_size)
                scores = ops.log_softmax(scores, axis=1)

                # Add
                scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

                # For the first step, all k points will have the same scores (since same k previous words, h, c)
                if step == 1:
                    top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
                else:
                    # Unroll and find top scores, and their unrolled indices
                    top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

                # Convert unrolled indices to actual indices of scores (integer division keeps valid indices)
                prev_word_inds = top_k_words // vocab_size  # (s)
                next_word_inds = top_k_words % vocab_size  # (s)

                # Add new words to sequences
                seqs = ops.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], axis=1)  # (s, step+1)

                # Which sequences are incomplete (didn't reach <end>)?
                incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                                   next_word != word_map['<end>']]
                complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

                # Set aside complete sequences
                if len(complete_inds) > 0:
                    complete_seqs.extend(seqs[complete_inds].asnumpy().tolist())
                    complete_seqs_scores.extend(top_k_scores[complete_inds])
                k -= len(complete_inds)  # reduce beam length accordingly

                # Proceed with incomplete sequences
                if k == 0:
                    break
                seqs = seqs[incomplete_inds]
                h = h[prev_word_inds[incomplete_inds]]
                c = c[prev_word_inds[incomplete_inds]]
                encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
                top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
                k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

                # Break if things have been going on too long
                if step > 50:
                    break
                step += 1

            # Pick the highest-scoring complete sequence
            # (assumes at least one sequence reached <end> within 50 steps)
            best_idx = complete_seqs_scores.index(max(complete_seqs_scores))
            seq = complete_seqs[best_idx]

            # References
            img_caps = allcaps[0].asnumpy().tolist()
            img_captions = list(
                map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}],
                    img_caps))  # remove <start> and pads
            references.append(img_captions)

            # Hypotheses
            hypotheses.append([w for w in seq if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}])

            assert len(references) == len(hypotheses)
            progress.update(1)

    # Calculate BLEU-4 scores
    bleu4 = corpus_bleu(references, hypotheses)

    return bleu4

def main():
    """main"""
    beam_size = 3
    bleu4 = evaluate(beam_size)
    print(f"\nBLEU-4 score @ beam size of {beam_size} is {bleu4:.4f}.")

if __name__ == '__main__':
    main()
diff --git a/ImageCaptioning/models.py b/ImageCaptioning/models.py
new file mode 100755
index 0000000..b635638
--- /dev/null
+++ b/ImageCaptioning/models.py
@@ -0,0 +1,212 @@
# pylint: disable=C0103
# pylint: disable=E0401

"""models"""
from mindcv import resnet101
from mindspore import nn, ops, Parameter
from mindspore.common.initializer import initializer, Uniform

class Encoder(nn.Cell):
    """
    Encoder.
    """
    def __init__(self, encoded_image_size=14):
        super().__init__()
        self.enc_image_size = encoded_image_size

        resnet = resnet101(pretrained=True)

        modules = list(resnet.cells())[:-2]
        self.resnet = nn.SequentialCell(*modules)

        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

        self.fine_tune()

    def construct(self, images):
        """
        Forward propagation.

        :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size)
        :return: encoded images
        """
        out = self.resnet(images)  # (batch_size, 2048, image_size/32, image_size/32)
        out = self.adaptive_pool(out)  # (batch_size, 2048, encoded_image_size, encoded_image_size)
        out = out.permute(0, 2, 3, 1)  # (batch_size, encoded_image_size, encoded_image_size, 2048)
        return out

    def fine_tune(self, fine_tune=True):
        """
        Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder.

        :param fine_tune: Allow?
        """
        for p in self.resnet.get_parameters():
            p.requires_grad = False
        # If fine-tuning, only fine-tune convolutional blocks 2 through 4
        for c in list(self.resnet.cells())[5:]:
            for p in c.get_parameters():
                p.requires_grad = fine_tune

class Attention(nn.Cell):
    """
    Attention Network.
+ """ + + def __init__(self, encoder_dim, decoder_dim, attention_dim): + """ + :param encoder_dim: feature size of encoded images + :param decoder_dim: size of decoder's RNN + :param attention_dim: size of the attention network + """ + super().__init__() + self.encoder_att = nn.Dense(encoder_dim, attention_dim) + self.decoder_att = nn.Dense(decoder_dim, attention_dim) + self.full_att = nn.Dense(attention_dim, 1) + self.relu = nn.ReLU() + self.softmax = nn.Softmax(axis=1) # softmax layer to calculate weights + + def construct(self, encoder_out, decoder_hidden): + """ + Forward propagation. + + :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) + :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) + :return: attention weighted encoding, weights + """ + att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim) + att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim) + att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels) + alpha = self.softmax(att) # (batch_size, num_pixels) + attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(axis=1) # (batch_size, encoder_dim) + + return attention_weighted_encoding, alpha + +class DecoderWithAttention(nn.Cell): + """ + Decoder. + """ + + def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5): + """ + :param attention_dim: size of attention network + :param embed_dim: embedding size + :param decoder_dim: size of decoder's RNN + :param vocab_size: size of vocabulary + :param encoder_dim: feature size of encoded images + :param dropout: dropout + """ + super().__init__() + + self.encoder_dim = encoder_dim + self.attention_dim = attention_dim + self.embed_dim = embed_dim + self.decoder_dim = decoder_dim + self.vocab_size = vocab_size + self.dropout = dropout + + self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network + + self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer + self.dropout = nn.Dropout(p=self.dropout) + self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, has_bias=True) # decoding LSTMCell + self.init_h = nn.Dense(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell + self.init_c = nn.Dense(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell + self.f_beta = nn.Dense(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate + self.sigmoid = nn.Sigmoid() + self.fc = nn.Dense(decoder_dim, vocab_size) # linear layer to find scores over vocabulary + self.init_weights() # initialize some layers with the uniform distribution + + def init_weights(self): + """ + initialize embedding layer and fc layer + """ + self.embedding.embedding_table.set_data(initializer(Uniform(0.1), + self.embedding.embedding_table.shape, + self.embedding.embedding_table.dtype)) + self.fc.bias.set_data(initializer('zero', self.fc.bias.shape, self.fc.bias.dtype)) + self.fc.weight.set_data(initializer(Uniform(0.1), + self.fc.weight.shape, + self.fc.weight.dtype)) + + def load_pretrained_embeddings(self, embeddings): + """ + Loads embedding layer with pre-trained embeddings. + + :param embeddings: pre-trained embeddings + """ + self.embedding.embedding_table = Parameter(embeddings) + + def fine_tune_embeddings(self, fine_tune=True): + """ + Allow fine-tuning of embedding layer? 
        (Only makes sense to not-allow if using pre-trained embeddings).

        :param fine_tune: Allow?
        """
        for p in self.embedding.get_parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        """
        Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.

        :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
        :return: hidden state, cell state
        """
        mean_encoder_out = encoder_out.mean(axis=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def construct(self, encoder_out, encoded_captions, caption_lengths):
        """
        Forward propagation.

        :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
        :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
        :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
        :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
        """

        batch_size = encoder_out.shape[0]
        encoder_dim = encoder_out.shape[-1]
        vocab_size = self.vocab_size

        # Flatten image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.shape[1]

        # Sort input data by decreasing lengths; why? apparent below
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(axis=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        embeddings = self.embedding(encoded_captions)

        h, c = self.init_hidden_state(encoder_out)

        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        # So, decoding lengths are actual lengths - 1
        decode_lengths = (caption_lengths - 1).asnumpy().tolist()

        # Create tensors to hold word prediction scores and alphas
        predictions = ops.zeros((batch_size, max(decode_lengths), vocab_size))
        alphas = ops.zeros((batch_size, max(decode_lengths), num_pixels))

        # At each time-step, decode by
        # attention-weighing the encoder's output based on the decoder's previous hidden state output
        # then generate a new word in the decoder with the previous word and the attention weighted encoding
        for t in range(max(decode_lengths)):
            batch_size_t = sum(l for l in decode_lengths if l > t)
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                ops.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], axis=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind
diff --git a/ImageCaptioning/utils.py b/ImageCaptioning/utils.py
new file mode 100755
index 0000000..e885df8
--- /dev/null
+++ b/ImageCaptioning/utils.py
@@ -0,0 +1,263 @@
# pylint: disable=C0103
# pylint: disable=E0401

"""utils"""
import os
import json
from collections import Counter
from random import seed, choice, sample
import mindspore
import h5py
import numpy as np
from mindspore import ops, save_checkpoint
from tqdm import tqdm
from imageio import imread
from PIL import Image

def create_input_files(dataset, karpathy_json_path, image_folder, captions_per_image, min_word_freq, output_folder,
                       max_len=100):
    """
    Creates input files for training, validation, and test data.

    :param dataset: name of dataset, one of 'coco', 'flickr8k', 'flickr30k'
    :param karpathy_json_path: path of Karpathy JSON file with splits and captions
    :param image_folder: folder with downloaded images
    :param captions_per_image: number of captions to sample per image
    :param min_word_freq: words occurring less frequently than this threshold are binned as <unk>s
    :param output_folder: folder to save files
    :param max_len: don't sample captions longer than this length
    """

    assert dataset in {'coco', 'flickr8k', 'flickr30k'}

    with open(karpathy_json_path, 'r', encoding='utf-8') as j:
        data = json.load(j)

    train_image_paths = []
    train_image_captions = []
    val_image_paths = []
    val_image_captions = []
    test_image_paths = []
    test_image_captions = []
    word_freq = Counter()

    for img in data['images']:
        captions = []
        for c in img['sentences']:
            # Update word frequency
            word_freq.update(c['tokens'])
            if len(c['tokens']) <= max_len:
                captions.append(c['tokens'])

        if len(captions) == 0:
            continue

        path = os.path.join(image_folder, img['filepath'], img['filename']) if dataset == 'coco' else os.path.join(
            image_folder, img['filename'])

        if img['split'] in {'train', 'restval'}:
            train_image_paths.append(path)
            train_image_captions.append(captions)
        elif img['split'] in {'val'}:
            val_image_paths.append(path)
            val_image_captions.append(captions)
        elif img['split'] in {'test'}:
            test_image_paths.append(path)
            test_image_captions.append(captions)

    assert len(train_image_paths) == len(train_image_captions)
    assert len(val_image_paths) == len(val_image_captions)
    assert len(test_image_paths) == len(test_image_captions)

    # Build the word map; rare words map to <unk>
    words = [w[0] for w in word_freq.items() if w[1] > min_word_freq]
    word_map = {k: v + 1 for v, k in enumerate(words)}
    word_map['<unk>'] = len(word_map) + 1
    word_map['<start>'] = len(word_map) + 1
    word_map['<end>'] = len(word_map) + 1
    word_map['<pad>'] = 0

    # Create a base/root name for all output files
    base_filename = dataset + '_' + str(captions_per_image) + '_cap_per_img_' + str(min_word_freq) + '_min_word_freq'

    # Save word map to a JSON
    with open(os.path.join(output_folder, 'WORDMAP_' + base_filename + '.json'), 'w', encoding='utf-8') as j:
        json.dump(word_map, j)

    # Sample captions for each image, save images to HDF5 file, and captions and their lengths to JSON files
    seed(123)
    for impaths, imcaps, split in [(train_image_paths, train_image_captions, 'TRAIN'),
                                   (val_image_paths, val_image_captions, 'VAL'),
                                   (test_image_paths, test_image_captions, 'TEST')]:

        with h5py.File(os.path.join(output_folder, split + '_IMAGES_' + base_filename + '.hdf5'), 'a') as h:
            # Make a note of the number of captions we are sampling per image
            h.attrs['captions_per_image'] = captions_per_image

            # Create dataset inside HDF5 file to store images
            images = h.create_dataset('images', (len(impaths), 3, 256, 256), dtype='uint8')

            print(f"\nReading {split} images and captions, storing to file...\n")

            enc_captions = []
            caplens = []

            for i, path in enumerate(tqdm(impaths)):

                # Sample captions
                if len(imcaps[i]) < captions_per_image:
                    captions = imcaps[i] + [choice(imcaps[i]) for _ in range(captions_per_image - len(imcaps[i]))]
                else:
                    captions = sample(imcaps[i], k=captions_per_image)
                # Sanity check
                assert len(captions) == captions_per_image

                # Read images
                img = imread(impaths[i])
                if len(img.shape) == 2:
                    img = img[:, :, np.newaxis]
                    img = np.concatenate([img, img, img], axis=2)
                img = np.array(Image.fromarray(img).resize((256, 256)))
                img = img.transpose(2, 0, 1)
                assert img.shape == (3, 256, 256)
                assert np.max(img) <= 255

                # Save image to HDF5 file
                images[i] = img

                for j, c in enumerate(captions):
                    # Encode captions
                    enc_c = [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in c] + [
                        word_map['<end>']] + [word_map['<pad>']] * (max_len - len(c))

                    # Find caption lengths
                    c_len = len(c) + 2

                    enc_captions.append(enc_c)
                    caplens.append(c_len)

            # Sanity check
            assert images.shape[0] * captions_per_image == len(enc_captions) == len(caplens)

            # Save encoded captions and their lengths to JSON files
            with open(os.path.join(output_folder, split + '_CAPTIONS_' + base_filename + '.json'), 'w', encoding='utf-8') as j:
                json.dump(enc_captions, j)

            with open(os.path.join(output_folder, split + '_CAPLENS_' + base_filename + '.json'), 'w', encoding='utf-8') as j:
                json.dump(caplens, j)

def init_embedding(embeddings):
    """
    Creates a tensor of the same shape as the given embedding tensor,
    filled with values from the uniform distribution.

    :param embeddings: embedding tensor
    :return: uniformly initialized embedding tensor
    """
    bias = np.sqrt(3.0 / embeddings.shape[1])
    return ops.uniform(embeddings.shape,
                       mindspore.Tensor(-bias, mindspore.float32),
                       mindspore.Tensor(bias, mindspore.float32))

def load_embeddings(emb_file, word_map):
    """
    Creates an embedding tensor for the specified word map, for loading into the model.

    :param emb_file: file containing embeddings (stored in GloVe format)
    :param word_map: word map
    :return: embeddings in the same order as the words in the word map, dimension of embeddings
    """

    # Find embedding dimension
    with open(emb_file, 'r', encoding='utf-8') as f:
        emb_dim = len(f.readline().split(' ')) - 1

    vocab = set(word_map.keys())
    embeddings = ops.zeros((len(vocab), emb_dim), mindspore.float32)
    embeddings = init_embedding(embeddings)

    print("\nLoading embeddings...")
    with open(emb_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.split(' ')

            emb_word = line[0]
            embedding = list(map(float, filter(lambda n: n and not n.isspace(), line[1:])))

            if emb_word not in vocab:
                continue

            embeddings[word_map[emb_word]] = mindspore.Tensor(embedding, dtype=mindspore.float32)

    return embeddings, emb_dim

def clip_gradient(grads, grad_clip):
    """clip gradient"""
    grads = ops.clip_by_value(grads, -grad_clip, grad_clip)
    return grads

class AverageMeter:
    '''
    AverageMeter
    '''
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        '''reset'''
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        '''update'''
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def adjust_learning_rate(optimizer, shrink_factor):
    """
    Shrinks learning rate by a specified factor.

    :param optimizer: optimizer whose learning rate must be shrunk.
    :param shrink_factor: factor in interval (0, 1) to multiply learning rate with.
    """

    print("\nDECAYING learning rate.")
    new_lr = optimizer.learning_rate.value() * shrink_factor
    optimizer.learning_rate.set_data(new_lr)
    print(f"The new learning rate is {new_lr}")

def accuracy(scores, targets, k):
    """
    Computes top-k accuracy, from predicted and true labels.

    :param scores: scores from the model
    :param targets: true labels
    :param k: k in top-k accuracy
    :return: top-k accuracy
    """

    batch_size = targets.shape[0]
    _, ind = scores.topk(k, 1, True, True)
    correct = ind.equal(targets.view(-1, 1).expand_as(ind))
    correct_total = correct.view(-1).float().sum()  # 0D tensor
    return correct_total * (100.0 / batch_size)

def save_model(data_name, epoch, decoder, is_best=False):
    """
    save model
    """

    # encoder_filename = f'encoder_{data_name}_{epoch}.ckpt'
    decoder_filename = f'decoder_{data_name}_{epoch}.ckpt'
    # save_checkpoint(encoder, encoder_filename)
    save_checkpoint(decoder, decoder_filename)
    if is_best:
        # save_checkpoint(encoder, 'BEST_' + encoder_filename)
        save_checkpoint(decoder, 'BEST_' + decoder_filename)

def pack_padded(inputs, lengths):
    """
    Truncates padded sequences to the longest length in the batch and flattens
    the batch and time dimensions, so scores and targets can be compared position-wise.
    """
    packed_tensor = inputs[:, :max(lengths)]
    packed_tensor = packed_tensor.flatten(start_dim=0, end_dim=1)
    return packed_tensor
\ No newline at end of file