Merge pull request #6 from M-Puig/martin
aaaaaa
M-Puig authored Dec 11, 2021
2 parents 894bb01 + bfb7cc2 commit 98cb67e
Showing 3 changed files with 1,212 additions and 112 deletions.
349 changes: 349 additions & 0 deletions notebooks/gen_TO.ipynb
@@ -0,0 +1,349 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 25,
"id": "0dbf34a7-a58e-49a2-aa4b-8bca5f58c00e",
"metadata": {},
"outputs": [],
"source": [
"from __future__ import absolute_import\n",
"from __future__ import division\n",
"from __future__ import print_function\n",
"from __future__ import unicode_literals\n",
"\n",
"import torch\n",
"from torch.jit import script, trace\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"import csv\n",
"import random\n",
"import re\n",
"import os\n",
"import unicodedata\n",
"import codecs\n",
"from io import open\n",
"import itertools\n",
"import math\n",
"import pickle\n",
"import statistics\n",
"import sys\n",
"from functools import partial\n",
"\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"import tqdm\n",
"import nltk\n",
"#from google.colab import files"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "88ed704e-8f0c-4750-8c2c-a085911ea7e9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"Using device: cuda\n"
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"print(torch.cuda.is_available())\n",
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
"else:\n",
" device = torch.device(\"cpu\")\n",
"print(\"Using device:\", device)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e20b3775-ea62-4292-97b9-c7f2e3d015a9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"573\n",
" author \\\n",
"0 WILLIAM SHAKESPEARE \n",
"1 DUCHESS OF NEWCASTLE MARGARET CAVENDISH \n",
"2 THOMAS BASTARD \n",
"3 EDMUND SPENSER \n",
"4 RICHARD BARNFIELD \n",
"\n",
" content \\\n",
"0 Let the bird of loudest lay\\r\\nOn the sole Ara... \n",
"1 Sir Charles into my chamber coming in,\\r\\nWhen... \n",
"2 Our vice runs beyond all that old men saw,\\r\\n... \n",
"3 Lo I the man, whose Muse whilome did maske,\\r\\... \n",
"4 Long have I longd to see my love againe,\\r\\nSt... \n",
"\n",
" poem name age type \n",
"0 The Phoenix and the Turtle Renaissance Mythology & Folklore \n",
"1 An Epilogue to the Above Renaissance Mythology & Folklore \n",
"2 Book 7, Epigram 42 Renaissance Mythology & Folklore \n",
"3 from The Faerie Queene: Book I, Canto I Renaissance Mythology & Folklore \n",
"4 Sonnet 16 Renaissance Mythology & Folklore \n"
]
}
],
"source": [
"data_file = '../data/poem/with_epoque.csv'\n",
"dataset = pd.read_csv(data_file)\n",
"print(len(dataset))\n",
"print(dataset.head())"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "f8073de9-4b6d-4117-8704-60a2266e4532",
"metadata": {},
"outputs": [],
"source": [
"def make_data_training(df, bos_token = '<bos> ', eos_token = ' <bos>'):\n",
" inputs = []\n",
" context = []\n",
" targets = []\n",
" for i,rows in df.iterrows(): \n",
" for line in rows['content'].split('\\r\\n'):\n",
" if len(line.strip()) > 0:\n",
" inputs += [bos_token + line]\n",
" targets += [line + eos_token]\n",
" context.append(' '.join([str(rows['poem name']), rows['age'], rows['type']]))\n",
" \n",
" return pd.DataFrame(list(zip(inputs, context, targets)),columns =['text', 'context','target'])\n",
"\n",
"\n",
"#Defining torch dataset class for poems\n",
"class PoemDataset(Dataset):\n",
" def __init__(self, df):\n",
" self.df = df\n",
"\n",
" def __len__(self):\n",
" return len(self.df)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.df.iloc[idx]"
]
},
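{
"cell_type": "markdown",
"id": "7f3a1c20-1111-4a2b-9c3d-0e1f2a3b4c5d",
"metadata": {},
"source": [
"A quick sanity check, added for illustration: `make_data_training` flattens each poem into one row per non-empty line, pairing a `<bos>`-prefixed input with an `<eos>`-suffixed target plus a context string built from the poem's metadata. The toy DataFrame below is hypothetical, not from the real dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f3a1c20-2222-4a2b-9c3d-0e1f2a3b4c5d",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical two-line poem, just to show the expected output shape.\n",
"toy_df = pd.DataFrame([{'content': 'line one\\r\\nline two',\n",
"                        'poem name': 'Toy', 'age': 'Renaissance',\n",
"                        'type': 'Mythology & Folklore'}])\n",
"print(make_data_training(toy_df))\n",
"# Expected rows: text='<bos> line one', target='line one <eos>', etc."
]
},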
{
"cell_type": "code",
"execution_count": 36,
"id": "5a8aecb5-c164-4b77-8849-4c686c2abb90",
"metadata": {},
"outputs": [],
"source": [
"df = make_data_training(dataset)\n",
"\n",
"num_lines = len(df)\n",
"\n",
"idxs = list(range(num_lines))\n",
"\n",
"test_idx = idxs[:int(0.1*num_lines)]\n",
"val_idx = idxs[int(0.1*num_lines):int(0.2*num_lines)]\n",
"train_idx = idxs[int(0.2*num_lines):]\n",
"\n",
"train_df = df.iloc[train_idx].reset_index(drop=True)\n",
"val_df = df.iloc[val_idx].reset_index(drop=True)\n",
"test_df = df.iloc[test_idx].reset_index(drop=True)\n",
"\n",
"train_data = train_df[['context', 'text', 'target']]\n",
"val_data = val_df[['context', 'text', 'target']]\n",
"test_data = test_df[['context', 'text', 'target']]\n",
"\n",
"train_dataset = PoemDataset(train_data)\n",
"val_dataset = PoemDataset(val_data)\n",
"test_dataset = PoemDataset(test_data)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "375d1ba3-256d-474f-9911-d78290c8ae0f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias']\n",
"- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"bert_model_name = 'distilbert-base-uncased' \n",
"\n",
"from transformers import DistilBertTokenizer, DistilBertModel\n",
"from transformers import get_linear_schedule_with_warmup\n",
"from tokenizers.processors import BertProcessing\n",
"\n",
"bert_model = DistilBertModel.from_pretrained(bert_model_name)\n",
"tokenizer = DistilBertTokenizer.from_pretrained(bert_model_name)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "f50e2c02-0f09-4131-88d6-09a244c50e30",
"metadata": {},
"outputs": [],
"source": [
"def transformer_collate_fn(batch, tokenizer):\n",
" bert_vocab = tokenizer.get_vocab()\n",
" bert_pad_token = bert_vocab['[PAD]']\n",
" bert_unk_token = bert_vocab['[UNK]']\n",
" bert_cls_token = bert_vocab['[CLS]']\n",
"\n",
" sentences, targets, masks = [], [], []\n",
" for data in batch:\n",
"\n",
" tokenizer_output = tokenizer([data['text']])\n",
" tokenized_sent = tokenizer_output['input_ids'][0]\n",
" \n",
" tokenizer_target = tokenizer([data['target']])\n",
" tokenized_sent_target = tokenizer_target['input_ids'][0]\n",
" \n",
" mask = tokenizer_output['attention_mask'][0]\n",
" sentences.append(torch.tensor(tokenized_sent))\n",
" targets.append(torch.tensor(tokenized_sent_target))\n",
" masks.append(torch.tensor(mask))\n",
" sentences = pad_sequence(sentences, batch_first=True, padding_value=bert_pad_token)\n",
" targets = pad_sequence(targets, batch_first=True, padding_value=bert_pad_token)\n",
" masks = pad_sequence(masks, batch_first=True, padding_value=0.0)\n",
" return sentences, targets, masks"
]
},
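{
"cell_type": "markdown",
"id": "7f3a1c20-3333-4a2b-9c3d-0e1f2a3b4c5d",
"metadata": {},
"source": [
"Another illustration added here: collate a couple of hypothetical rows the way the DataLoader will. Note that sentences and masks share a padded shape, while targets are padded separately to their own max length."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f3a1c20-4444-4a2b-9c3d-0e1f2a3b4c5d",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical mini-batch; real batches come from PoemDataset rows.\n",
"toy_batch = [\n",
"    {'text': '<bos> a rose is a rose', 'target': 'a rose is a rose <eos>'},\n",
"    {'text': '<bos> so it goes', 'target': 'so it goes <eos>'},\n",
"]\n",
"s, t, m = transformer_collate_fn(toy_batch, tokenizer)\n",
"print(s.shape, m.shape, t.shape)  # sentences and masks match; targets pad independently"
]
},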
{
"cell_type": "code",
"execution_count": null,
"id": "b523e4ae-ecb5-4e1a-baa9-6f3fbe04444b",
"metadata": {},
"outputs": [],
"source": [
"class EratoModel(nn.Module):\n",
" def __init__(self,\n",
" poly_encoder: nn.Module,\n",
" bert_encoder: nn.Module,\n",
" decoder: nn.Module,\n",
" enc_hid_dim=768, #default embedding size\n",
" outputs=2,\n",
" dropout=0.1):\n",
" super().__init__()\n",
" \n",
" self.poly_encoder = poly_encoder\n",
" self.bert_encoder = bert_encoder\n",
" self.decoder = decoder\n",
"\n",
"\n",
" def forward(self,\n",
" src,\n",
" mask):\n",
" bert_output = self.bert_encoder(src, mask)\n",
"\n",
" ### YOUR CODE HERE ###\n",
" hidden_state = bert_output[0] # (bs, seq_len, dim)\n",
"\n",
" return None"
]
},
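{
"cell_type": "markdown",
"id": "7f3a1c20-5555-4a2b-9c3d-0e1f2a3b4c5d",
"metadata": {},
"source": [
"A minimal wiring sketch, added for illustration: the poly-encoder and decoder are not defined in this notebook, so `nn.Identity()` placeholders stand in for them. With the stub `forward` above, the model simply returns DistilBERT's last hidden state."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f3a1c20-6666-4a2b-9c3d-0e1f2a3b4c5d",
"metadata": {},
"outputs": [],
"source": [
"# nn.Identity() placeholders are assumptions, not the project's real modules.\n",
"erato = EratoModel(poly_encoder=nn.Identity(),\n",
"                   bert_encoder=bert_model,\n",
"                   decoder=nn.Identity())\n",
"\n",
"enc = tokenizer(['<bos> example line'], return_tensors='pt')\n",
"hidden = erato(enc['input_ids'], enc['attention_mask'])\n",
"print(hidden.shape)  # (1, seq_len, 768) for distilbert-base-uncased"
]
},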
{
"cell_type": "markdown",
"id": "de5dfb53-f33e-49af-afd9-58b6c838d3bd",
"metadata": {},
"source": [
"# Model Training"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "1d8f03a3-ff6d-4e8b-9305-b3f96af88054",
"metadata": {},
"outputs": [],
"source": [
"#define hyperparameters\n",
"BATCH_SIZE = 10\n",
"LR = 1e-5\n",
"WEIGHT_DECAY = 0\n",
"N_EPOCHS = 3\n",
"CLIP = 1.0\n",
"\n",
"#create pytorch dataloaders from train_dataset, val_dataset, and test_datset\n",
"train_dataloader = DataLoader(train_dataset,batch_size=BATCH_SIZE,collate_fn=partial(transformer_collate_fn, tokenizer=tokenizer), shuffle = True)\n",
"val_dataloader = DataLoader(val_dataset,batch_size=BATCH_SIZE,collate_fn=partial(transformer_collate_fn, tokenizer=tokenizer))\n",
"test_dataloader = DataLoader(test_dataset,batch_size=BATCH_SIZE,collate_fn=partial(transformer_collate_fn, tokenizer=tokenizer))"
]
},
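{
"cell_type": "markdown",
"id": "7f3a1c20-7777-4a2b-9c3d-0e1f2a3b4c5d",
"metadata": {},
"source": [
"A sketch of how the hyperparameters above and the imported `get_linear_schedule_with_warmup` could be wired up. The training loop itself is not part of this commit, so this is one plausible setup, not the project's actual one."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f3a1c20-8888-4a2b-9c3d-0e1f2a3b4c5d",
"metadata": {},
"outputs": [],
"source": [
"# Assumed setup: AdamW over the encoder's parameters with a linear warmup schedule.\n",
"optimizer = optim.AdamW(bert_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
"num_training_steps = N_EPOCHS * len(train_dataloader)\n",
"scheduler = get_linear_schedule_with_warmup(optimizer,\n",
"                                            num_warmup_steps=0,\n",
"                                            num_training_steps=num_training_steps)\n",
"# During training, CLIP would cap gradient norms:\n",
"#   torch.nn.utils.clip_grad_norm_(bert_model.parameters(), CLIP)"
]
},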
{
"cell_type": "code",
"execution_count": 91,
"id": "ca9efe4e-4834-4882-90c3-6783f055653e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([10, 18, 768])\n"
]
}
],
"source": [
"for batch in train_dataloader:\n",
" sentences, targets, masks = batch[0], batch[1], batch[2]\n",
" s = tokenizer.decode(sentences[0,:], skip_special_tokens=False, clean_up_tokenization_spaces=False)\n",
" t = tokenizer.decode(targets[0,:], skip_special_tokens=False, clean_up_tokenization_spaces=False)\n",
" \n",
" bert_output = bert_model(sentences, masks)\n",
" \n",
" print(bert_output[0].shape)\n",
" \n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11bafc31-688f-4e9c-a517-e0f23728f44a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}