Merge pull request #6 from M-Puig/martin
aaaaaa
Showing 3 changed files with 1,212 additions and 112 deletions.
@@ -0,0 +1,349 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 25,
"id": "0dbf34a7-a58e-49a2-aa4b-8bca5f58c00e",
"metadata": {},
"outputs": [],
"source": [
"from __future__ import absolute_import\n",
"from __future__ import division\n",
"from __future__ import print_function\n",
"from __future__ import unicode_literals\n",
"\n",
"import torch\n",
"from torch.jit import script, trace\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"import csv\n",
"import random\n",
"import re\n",
"import os\n",
"import unicodedata\n",
"import codecs\n",
"from io import open\n",
"import itertools\n",
"import math\n",
"import pickle\n",
"import statistics\n",
"import sys\n",
"from functools import partial\n",
"\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"import tqdm\n",
"import nltk\n",
"#from google.colab import files"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "88ed704e-8f0c-4750-8c2c-a085911ea7e9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"Using device: cuda\n"
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"print(torch.cuda.is_available())\n",
"if torch.cuda.is_available():\n",
"    device = torch.device(\"cuda\")\n",
"else:\n",
"    device = torch.device(\"cpu\")\n",
"print(\"Using device:\", device)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e20b3775-ea62-4292-97b9-c7f2e3d015a9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"573\n",
"  author \\\n",
"0  WILLIAM SHAKESPEARE \n",
"1  DUCHESS OF NEWCASTLE MARGARET CAVENDISH \n",
"2  THOMAS BASTARD \n",
"3  EDMUND SPENSER \n",
"4  RICHARD BARNFIELD \n",
"\n",
"  content \\\n",
"0  Let the bird of loudest lay\\r\\nOn the sole Ara... \n",
"1  Sir Charles into my chamber coming in,\\r\\nWhen... \n",
"2  Our vice runs beyond all that old men saw,\\r\\n... \n",
"3  Lo I the man, whose Muse whilome did maske,\\r\\... \n",
"4  Long have I longd to see my love againe,\\r\\nSt... \n",
"\n",
"  poem name  age  type \n",
"0  The Phoenix and the Turtle  Renaissance  Mythology & Folklore \n",
"1  An Epilogue to the Above  Renaissance  Mythology & Folklore \n",
"2  Book 7, Epigram 42  Renaissance  Mythology & Folklore \n",
"3  from The Faerie Queene: Book I, Canto I  Renaissance  Mythology & Folklore \n",
"4  Sonnet 16  Renaissance  Mythology & Folklore \n"
]
}
],
"source": [
"data_file = '../data/poem/with_epoque.csv'\n",
"dataset = pd.read_csv(data_file)\n",
"print(len(dataset))\n",
"print(dataset.head())"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "f8073de9-4b6d-4117-8704-60a2266e4532",
"metadata": {},
"outputs": [],
"source": [
"def make_data_training(df, bos_token = '<bos> ', eos_token = ' <bos>'):\n", | ||
" inputs = []\n", | ||
" context = []\n", | ||
" targets = []\n", | ||
" for i,rows in df.iterrows(): \n", | ||
" for line in rows['content'].split('\\r\\n'):\n", | ||
" if len(line.strip()) > 0:\n", | ||
" inputs += [bos_token + line]\n", | ||
" targets += [line + eos_token]\n", | ||
" context.append(' '.join([str(rows['poem name']), rows['age'], rows['type']]))\n", | ||
" \n", | ||
" return pd.DataFrame(list(zip(inputs, context, targets)),columns =['text', 'context','target'])\n", | ||
"\n", | ||
"\n", | ||
"#Defining torch dataset class for poems\n", | ||
"class PoemDataset(Dataset):\n", | ||
" def __init__(self, df):\n", | ||
" self.df = df\n", | ||
"\n", | ||
" def __len__(self):\n", | ||
" return len(self.df)\n", | ||
"\n", | ||
" def __getitem__(self, idx):\n", | ||
" return self.df.iloc[idx]" | ||
] | ||
}, | ||
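{
"cell_type": "markdown",
"id": "make-data-demo-md",
"metadata": {},
"source": [
"A quick sanity check, not part of the original pipeline: running `make_data_training` on a made-up one-poem DataFrame (all values below are illustrative) to show the `text` / `context` / `target` layout it produces."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "make-data-demo-code",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: a made-up one-poem DataFrame whose columns mirror the real CSV.\n",
"toy = pd.DataFrame([{\n",
"    'poem name': 'Toy Sonnet',\n",
"    'age': 'Renaissance',\n",
"    'type': 'Mythology & Folklore',\n",
"    'content': 'First toy line\\r\\nSecond toy line'\n",
"}])\n",
"print(make_data_training(toy))\n",
"# Each row pairs '<bos> line' with 'line <eos>' plus the poem-level context string."
]
},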
{
"cell_type": "code",
"execution_count": 36,
"id": "5a8aecb5-c164-4b77-8849-4c686c2abb90",
"metadata": {},
"outputs": [],
"source": [
"df = make_data_training(dataset)\n",
"\n",
"num_lines = len(df)\n",
"\n",
"idxs = list(range(num_lines))\n",
"\n",
"test_idx = idxs[:int(0.1*num_lines)]\n",
"val_idx = idxs[int(0.1*num_lines):int(0.2*num_lines)]\n",
"train_idx = idxs[int(0.2*num_lines):]\n",
"\n",
"train_df = df.iloc[train_idx].reset_index(drop=True)\n",
"val_df = df.iloc[val_idx].reset_index(drop=True)\n",
"test_df = df.iloc[test_idx].reset_index(drop=True)\n",
"\n",
"train_data = train_df[['context', 'text', 'target']]\n",
"val_data = val_df[['context', 'text', 'target']]\n",
"test_data = test_df[['context', 'text', 'target']]\n",
"\n",
"train_dataset = PoemDataset(train_data)\n",
"val_dataset = PoemDataset(val_data)\n",
"test_dataset = PoemDataset(test_data)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "375d1ba3-256d-474f-9911-d78290c8ae0f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias']\n",
"- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"bert_model_name = 'distilbert-base-uncased'\n",
"\n",
"from transformers import DistilBertTokenizer, DistilBertModel\n",
"from transformers import get_linear_schedule_with_warmup\n",
"from tokenizers.processors import BertProcessing\n",
"\n",
"bert_model = DistilBertModel.from_pretrained(bert_model_name)\n",
"tokenizer = DistilBertTokenizer.from_pretrained(bert_model_name)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "f50e2c02-0f09-4131-88d6-09a244c50e30",
"metadata": {},
"outputs": [],
"source": [
"def transformer_collate_fn(batch, tokenizer):\n",
"    bert_vocab = tokenizer.get_vocab()\n",
"    bert_pad_token = bert_vocab['[PAD]']\n",
"    bert_unk_token = bert_vocab['[UNK]']\n",
"    bert_cls_token = bert_vocab['[CLS]']\n",
"\n",
"    sentences, targets, masks = [], [], []\n",
"    for data in batch:\n",
"\n",
"        tokenizer_output = tokenizer([data['text']])\n",
"        tokenized_sent = tokenizer_output['input_ids'][0]\n",
"\n",
"        tokenizer_target = tokenizer([data['target']])\n",
"        tokenized_sent_target = tokenizer_target['input_ids'][0]\n",
"\n",
"        mask = tokenizer_output['attention_mask'][0]\n",
"        sentences.append(torch.tensor(tokenized_sent))\n",
"        targets.append(torch.tensor(tokenized_sent_target))\n",
"        masks.append(torch.tensor(mask))\n",
"    sentences = pad_sequence(sentences, batch_first=True, padding_value=bert_pad_token)\n",
"    targets = pad_sequence(targets, batch_first=True, padding_value=bert_pad_token)\n",
"    masks = pad_sequence(masks, batch_first=True, padding_value=0.0)\n",
"    return sentences, targets, masks"
]
},
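{
"cell_type": "markdown",
"id": "collate-demo-md",
"metadata": {},
"source": [
"A small check of the collate function, again illustrative rather than from the original notebook: a hand-built two-row batch, to confirm that sentences, targets, and attention masks come back padded to a common length."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "collate-demo-code",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: two made-up rows shaped like PoemDataset items.\n",
"demo_batch = [\n",
"    {'text': '<bos> Shall I compare thee', 'target': 'Shall I compare thee <eos>'},\n",
"    {'text': '<bos> So long lives this', 'target': 'So long lives this <eos>'},\n",
"]\n",
"demo_sents, demo_tgts, demo_masks = transformer_collate_fn(demo_batch, tokenizer)\n",
"print(demo_sents.shape, demo_tgts.shape, demo_masks.shape)"
]
},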
{
"cell_type": "code",
"execution_count": null,
"id": "b523e4ae-ecb5-4e1a-baa9-6f3fbe04444b",
"metadata": {},
"outputs": [],
"source": [
"class EratoModel(nn.Module):\n",
"    def __init__(self,\n",
"                 poly_encoder: nn.Module,\n",
"                 bert_encoder: nn.Module,\n",
"                 decoder: nn.Module,\n",
"                 enc_hid_dim=768,  # default embedding size\n",
"                 outputs=2,\n",
"                 dropout=0.1):\n",
"        super().__init__()\n",
"\n",
"        self.poly_encoder = poly_encoder\n",
"        self.bert_encoder = bert_encoder\n",
"        self.decoder = decoder\n",
"\n",
"    def forward(self, src, mask):\n",
"        bert_output = self.bert_encoder(src, mask)\n",
"\n",
"        # Still a stub in this commit: the poly_encoder and decoder are not\n",
"        # wired in yet, so return the BERT hidden states rather than None.\n",
"        hidden_state = bert_output[0]  # (bs, seq_len, dim)\n",
"\n",
"        return hidden_state"
]
},
{
"cell_type": "markdown",
"id": "de5dfb53-f33e-49af-afd9-58b6c838d3bd",
"metadata": {},
"source": [
"# Model Training"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "1d8f03a3-ff6d-4e8b-9305-b3f96af88054",
"metadata": {},
"outputs": [],
"source": [
"# Define hyperparameters\n",
"BATCH_SIZE = 10\n",
"LR = 1e-5\n",
"WEIGHT_DECAY = 0\n",
"N_EPOCHS = 3\n",
"CLIP = 1.0\n",
"\n",
"# Create PyTorch dataloaders from train_dataset, val_dataset, and test_dataset\n",
"train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=partial(transformer_collate_fn, tokenizer=tokenizer), shuffle=True)\n",
"val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=partial(transformer_collate_fn, tokenizer=tokenizer))\n",
"test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=partial(transformer_collate_fn, tokenizer=tokenizer))"
]
},
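{
"cell_type": "markdown",
"id": "train-loop-sketch-md",
"metadata": {},
"source": [
"The commit stops short of the training loop itself, so the following is only a minimal sketch consistent with the hyperparameters above. It assumes a finished model whose forward pass returns per-token logits over the tokenizer vocabulary, which `EratoModel` does not implement yet; `train_one_epoch` and its arguments are hypothetical names."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "train-loop-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"def train_one_epoch(model, dataloader, optimizer, criterion):\n",
"    \"\"\"Sketch of one epoch; assumes model(sentences, masks) -> (batch, seq_len, vocab) logits.\"\"\"\n",
"    model.train()\n",
"    for sentences, targets, masks in dataloader:\n",
"        sentences = sentences.to(device)\n",
"        targets = targets.to(device)\n",
"        masks = masks.to(device)\n",
"        optimizer.zero_grad()\n",
"        logits = model(sentences, masks)\n",
"        loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))\n",
"        loss.backward()\n",
"        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)\n",
"        optimizer.step()\n",
"\n",
"# Hypothetical usage once the model is complete:\n",
"# optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
"# criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)\n",
"# for epoch in range(N_EPOCHS):\n",
"#     train_one_epoch(model, train_dataloader, optimizer, criterion)"
]
},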
{
"cell_type": "code",
"execution_count": 91,
"id": "ca9efe4e-4834-4882-90c3-6783f055653e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([10, 18, 768])\n"
]
}
],
"source": [
"for batch in train_dataloader:\n",
"    sentences, targets, masks = batch[0], batch[1], batch[2]\n",
"    s = tokenizer.decode(sentences[0, :], skip_special_tokens=False, clean_up_tokenization_spaces=False)\n",
"    t = tokenizer.decode(targets[0, :], skip_special_tokens=False, clean_up_tokenization_spaces=False)\n",
"\n",
"    bert_output = bert_model(sentences, masks)\n",
"\n",
"    print(bert_output[0].shape)\n",
"\n",
"    break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11bafc31-688f-4e9c-a517-e0f23728f44a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}