-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
231 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### character-level LSTM in Pytorch" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 该笔记记录的模型如下所示\n", | ||
"<img src=\"assets/charseq.jpeg\" width=\"500\">" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 导入相关的包" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np \n", | ||
"import torch \n", | ||
"from torch import nn\n", | ||
"import torch.nn.functional as F" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 读取数据" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"with open('data/anna.txt', 'r') as f:\n", | ||
" text = f.read()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"'Chapter 1\\n\\n\\nHappy families are all alike; every unhappy family is unhappy in its own\\nway.\\n\\nEverythin'" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"text[:100]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 对其进行标记(tokenization)\n", | ||
"即是为其数据建造一个词典,让每个字符均有对应的标签" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"chars = tuple(set(text))\n", | ||
"int2char = dict(enumerate(chars))\n", | ||
"char2int = {ch:li for li,ch in int2char.items()}\n", | ||
"\n", | ||
"encoded = np.array([char2int[ch] for ch in text])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[14 45 48 54 76 72 74 46 75 22 22 22 82 48 54 54 18 46 80 48 12 49 27 49\n", | ||
" 72 62 46 48 74 72 46 48 27 27 46 48 27 49 28 72 4 46 72 2 72 74 18 46\n", | ||
" 70 39 45 48 54 54 18 46 80 48 12 49 27 18 46 49 62 46 70 39 45 48 54 54\n", | ||
" 18 46 49 39 46 49 76 62 46 9 51 39 22 51 48 18 0 22 22 58 2 72 74 18\n", | ||
" 76 45 49 39]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(encoded[:100])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 将其转换为one-hot编码\n", | ||
"构造一个函数,使其具有使得原本的文本转化为one-hot功能" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def one_hot_encoded(arr,n_labels):\n", | ||
" \n", | ||
" one_hot = np.zeros((np.multiply(*arr.shape),n_labels),dtype = np.float32)\n", | ||
" \n", | ||
" one_hot [np.arange(one_hot.shape[0]),arr.flatten()] = 1\n", | ||
" \n", | ||
" ont_hot = one_hot.reshape((*arr.shape,n_labels))\n", | ||
" \n", | ||
" return one_hot" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", | ||
" [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", | ||
" [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", | ||
" [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", | ||
" [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"test_seq = np.array([[2,4,5,7,2]])\n", | ||
"one_hot = one_hot_encoded(test_seq , 10)\n", | ||
"print(one_hot)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 建造一个mini-batch\n", | ||
"此处对应的是英文文本的batch,且训练数据是一段英文文本,故其batch的组成是其不同的英文序列\n", | ||
"且为了保留batch的完整,需要丢弃掉最后一部分空余的数据" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_batches(arr,batch_size,seq_length):\n", | ||
" \n", | ||
" batch_size_total = batch_size * seq_length\n", | ||
" \n", | ||
" n_batches = len(arr) //batch_size_total\n", | ||
" \n", | ||
" arr = arr[:n_batches * batch_size_total]\n", | ||
" \n", | ||
" arr = arr.reshape((batch_size , -1))\n", | ||
" \n", | ||
" for i in range(0,arr.shape[1],seq_length):\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"np.multiply?" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |