forked from lucasjinreal/tensorflow_poems
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpoems.py
79 lines (70 loc) · 2.77 KB
/
poems.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# -*- coding: utf-8 -*-
# file: poems.py
# author: JinTian
# time: 08/03/2017 7:39 PM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import collections
import numpy as np
start_token = 'B'
end_token = 'E'


def process_poems(file_name):
    """Read a poem corpus and encode every usable poem as a list of ints.

    Each line of ``file_name`` (read as UTF-8) is expected to have the form
    ``title:content``.  A line is skipped when:

    * it does not split into exactly two parts on ``':'``;
    * its content (with spaces removed) contains one of the rejected
      markers ``_``, ``(``, full-width ``(``, ``《``, ``[`` or either of the
      ``start_token`` / ``end_token`` sentinels;
    * its content is shorter than 5 or longer than 79 characters.

    Every kept poem is wrapped as ``start_token + content + end_token``
    before encoding.

    Args:
        file_name: Path of the corpus file.

    Returns:
        A tuple ``(poems_vector, word_int_map, words)`` where

        * ``poems_vector`` is a list with one list of int ids per kept poem,
        * ``word_int_map`` maps each character to its id,
        * ``words`` lists the characters by descending frequency (ties keep
          first-seen order) followed by a single ``' '`` entry, so the id of
          a character is its index in ``words``.
    """
    # Characters that disqualify a poem.  The original condition tested the
    # ASCII '(' twice; the second test is presumably meant to be the
    # full-width form (the list already contains the full-width '《'), so
    # both variants are rejected here.  NOTE(review): confirm against corpus.
    rejected_marks = (
        '_',
        '(',
        '\uff08',  # full-width '('
        '《',
        '[',
        start_token,
        end_token,
    )

    poems = []
    with open(file_name, "r", encoding='utf-8') as f:
        # Iterate the file lazily instead of loading every line up front.
        for line in f:
            try:
                # Unpacking fails unless there is exactly one ':' on the line.
                _title, content = line.strip().split(':')
            except ValueError:
                # Malformed line: skipping it is the intended best effort.
                continue
            content = content.replace(' ', '')
            if any(mark in content for mark in rejected_marks):
                continue
            if len(content) < 5 or len(content) > 79:
                continue
            poems.append(start_token + content + end_token)

    # Most frequent characters get the smallest ids.
    counter = collections.Counter(char for poem in poems for char in poem)
    words = [char for char, _count in counter.most_common()]
    # Spaces were removed from every poem above, so ' ' gets an id of its own
    # at the end of the vocabulary.
    words.append(' ')
    word_int_map = {word: index for index, word in enumerate(words)}
    poems_vector = [[word_int_map[char] for char in poem] for poem in poems]
    return poems_vector, word_int_map, words
def generate_batch(batch_size, poems_vec, word_to_int):
    """Cut the encoded poems into padded input/target batches.

    ``poems_vec`` is consumed in consecutive groups of ``batch_size`` poems;
    poems left over after the last full group are dropped.  Within a group
    every poem is right-padded with the id of ``' '`` up to the length of the
    longest poem in that group.

    The target array is the input array shifted one position to the left,
    with the last column left unchanged, e.g.::

        x_data             y_data
        [6, 2, 4, 6, 9]    [2, 4, 6, 9, 9]
        [1, 4, 2, 8, 5]    [4, 2, 8, 5, 5]

    Args:
        batch_size: Number of poems per batch; must be positive.
        poems_vec: Sequence of poems, each a sequence of int ids.
        word_to_int: Mapping that contains the padding id under ``' '``.

    Returns:
        A tuple ``(x_batches, y_batches)`` of equally long lists holding
        ``np.int32`` arrays of shape ``(batch_size, longest_poem_in_batch)``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
        KeyError: If ``word_to_int`` has no ``' '`` entry.
    """
    if batch_size <= 0:
        # 0 used to surface as a bare ZeroDivisionError and negative values
        # silently produced no batches at all.
        raise ValueError(
            'batch_size must be a positive integer, got {!r}'.format(batch_size))

    # The padding id does not change between batches, so look it up once.
    pad_id = word_to_int[' ']
    n_chunk = len(poems_vec) // batch_size

    x_batches = []
    y_batches = []
    for i in range(n_chunk):
        start_index = i * batch_size
        chunk = poems_vec[start_index:start_index + batch_size]
        width = max(map(len, chunk))

        # Start from an all-padding array and copy each poem into its row.
        x_data = np.full((batch_size, width), pad_id, np.int32)
        for row, poem in enumerate(chunk):
            x_data[row, :len(poem)] = poem

        # Targets: next-token view of the inputs (last column is kept as is).
        y_data = np.copy(x_data)
        y_data[:, :-1] = x_data[:, 1:]

        x_batches.append(x_data)
        y_batches.append(y_data)
    return x_batches, y_batches