#!/usr/local/bin/python
# -*- coding: utf-8 -*-
# common.py -- forked from pythonstock/stock.
# Requires the MySQLdb driver: `apk add py-mysqldb` (Alpine) or `pip install mysqlclient`.
import platform
import datetime
import time
import sys
import os
import MySQLdb
from sqlalchemy import create_engine
from sqlalchemy.types import NVARCHAR
from sqlalchemy import inspect
import tushare as ts
import pandas as pd
import traceback
# Read the database settings from environment variables so the same code works
# in both development mode and Docker mode; each setting falls back to a
# built-in default when its variable is not set.
MYSQL_HOST = os.environ.get('MYSQL_HOST', "mariadb")
MYSQL_USER = os.environ.get('MYSQL_USER', "root")
MYSQL_PWD = os.environ.get('MYSQL_PWD', "mariadb")
MYSQL_DB = os.environ.get('MYSQL_DB', "stock_data")
print("MYSQL_HOST :", MYSQL_HOST, ",MYSQL_USER :", MYSQL_USER, ",MYSQL_DB :", MYSQL_DB)
MYSQL_CONN_URL = "mysql+mysqldb://" + MYSQL_USER + ":" + MYSQL_PWD + "@" + MYSQL_HOST + ":3306/" + MYSQL_DB + "?charset=utf8"
# Log the URL with the password masked so credentials never end up in stdout
# or in collected log files. MYSQL_CONN_URL itself still carries the password.
print("MYSQL_CONN_URL :", "mysql+mysqldb://" + MYSQL_USER + ":***@" + MYSQL_HOST + ":3306/" + MYSQL_DB + "?charset=utf8")
def engine():
    """Create a SQLAlchemy engine for the default database.

    Returns:
        A new engine built from ``MYSQL_CONN_URL``, created with the
        ``encoding='utf8'`` and ``convert_unicode=True`` options.
    """
    engine_options = {'encoding': 'utf8', 'convert_unicode': True}
    return create_engine(MYSQL_CONN_URL, **engine_options)
def engine_to_db(to_db):
    """Create a SQLAlchemy engine for the database named ``to_db``.

    The host, user and password are the module-level MySQL settings; only the
    database part of the URL is taken from the argument.

    Args:
        to_db: Name of the database to connect to.

    Returns:
        A new engine, created with the ``encoding='utf8'`` and
        ``convert_unicode=True`` options.
    """
    credentials = MYSQL_USER + ":" + MYSQL_PWD
    location = MYSQL_HOST + ":3306/" + to_db
    url = "mysql+mysqldb://" + credentials + "@" + location + "?charset=utf8"
    return create_engine(url, encoding='utf8', convert_unicode=True)
# Open a direct MySQLdb connection (not a SQLAlchemy engine) and return a cursor.
def conn():
    """Connect to the default MySQL database and return a cursor on it.

    The connection is opened with the module-level MySQL settings and switched
    to autocommit before the cursor is created.

    Returns:
        A cursor obtained from the new connection.

    Raises:
        Exception: Whatever ``MySQLdb.connect`` raised. The error is printed
            and then re-raised unchanged.
    """
    try:
        db = MySQLdb.connect(MYSQL_HOST, MYSQL_USER, MYSQL_PWD, MYSQL_DB, charset="utf8")
    except Exception as e:
        print("conn error :", e)
        # Without a connection there is nothing to return; falling through
        # would only fail on the unbound ``db`` with an unrelated
        # UnboundLocalError that hides the real cause.
        raise
    db.autocommit(on=True)
    return db.cursor()
# Generic helper: insert data into a table and create its primary key, so that
# re-running a job keeps the rows unique.
def insert_db(data, table_name, write_index, primary_keys):
    """Insert ``data`` into ``table_name`` of the default database.

    Delegates to ``insert_other_db`` with ``MYSQL_DB`` as the target database;
    see that function for the meaning of the remaining arguments.
    """
    insert_other_db(
        to_db=MYSQL_DB,
        data=data,
        table_name=table_name,
        write_index=write_index,
        primary_keys=primary_keys,
    )
# Variant of insert_db that writes to an explicitly named database.
def insert_other_db(to_db, data, table_name, write_index, primary_keys):
    """Append ``data`` to ``table_name`` in ``to_db`` and make sure it has a primary key.

    Every column (and the index, when it is written) is declared as
    NVARCHAR(255). After the rows are appended, the table is inspected and a
    primary key is added when none exists yet. A failure while adding the
    key is printed and otherwise ignored.

    Args:
        to_db: Name of the target database; also passed as the ``schema``.
        data: Object providing ``columns``, ``index`` and ``to_sql`` --
            presumably a pandas DataFrame.
        table_name: Name of the table to append to.
        write_index: When truthy, the index is written as a column too.
        primary_keys: Column list inserted verbatim into
            ``ADD PRIMARY KEY (...)``; must come from trusted code, not from
            user input.
    """
    engine_mysql = engine_to_db(to_db)
    try:
        col_name_list = data.columns.tolist()
        if write_index:
            # Put the index name first so the index column gets the same
            # NVARCHAR type as the data columns.
            col_name_list.insert(0, data.index.name)
        print(col_name_list)
        data.to_sql(name=table_name, con=engine_mysql, schema=to_db, if_exists='append',
                    dtype={col_name: NVARCHAR(length=255) for col_name in col_name_list}, index=write_index)
        # Reflection API: http://docs.sqlalchemy.org/en/latest/core/reflection.html
        # get_pk_constraint() is the supported replacement for the deprecated
        # get_primary_keys().
        insp = inspect(engine_mysql)
        if not insp.get_pk_constraint(table_name).get('constrained_columns'):
            with engine_mysql.connect() as con:
                try:
                    con.execute('ALTER TABLE `%s` ADD PRIMARY KEY (%s);' % (table_name, primary_keys))
                except Exception as e:
                    # Best effort: the rows are already stored at this point.
                    print("################## ADD PRIMARY KEY ERROR :", e)
    finally:
        # A fresh engine is created per call; release its pooled connections.
        engine_mysql.dispose()
# Execute a statement that writes data.
def insert(sql, params=()):
    """Run one SQL statement on a cursor from ``conn()``.

    The statement is printed before it runs. An error raised by ``execute``
    is printed and not propagated.

    Args:
        sql: SQL text, passed to the cursor's ``execute``.
        params: Parameters passed to ``execute`` alongside ``sql``.
    """
    with conn() as cursor:
        print("insert sql:" + sql)
        try:
            cursor.execute(sql, params)
        except Exception as err:
            print("error :", err)
# Query data.
def select(sql, params=()):
    """Run a query on a cursor from ``conn()`` and return all rows.

    The statement is printed before it runs.

    Args:
        sql: SQL text, passed to the cursor's ``execute``.
        params: Parameters passed to ``execute`` alongside ``sql``.

    Returns:
        The result of the cursor's ``fetchall()``, or an empty tuple when the
        query failed. The error is printed in that case.
    """
    with conn() as db:
        print("select sql:" + sql)
        try:
            db.execute(sql, params)
            # Fetch only after a successful execute; fetching from a cursor
            # whose statement failed would raise a second, misleading error.
            return db.fetchall()
        except Exception as e:
            print("error :", e)
            return ()
# Count rows.
def select_count(sql, params=()):
    """Run a query that yields a single number and return it as an int.

    The statement is printed before it runs.

    Args:
        sql: SQL text, passed to the cursor's ``execute``.
        params: Parameters passed to ``execute`` alongside ``sql``.

    Returns:
        The first column of the only result row converted with ``int``.
        Returns 0 when the query failed (the error is printed), when it did
        not return exactly one row, or when the value is NULL.
    """
    with conn() as db:
        print("select sql:" + sql)
        try:
            db.execute(sql, params)
            # Fetch only after a successful execute.
            result = db.fetchall()
        except Exception as e:
            print("error :", e)
            return 0
    # Only a single row holding a single value is meaningful here.
    if len(result) == 1 and result[0][0] is not None:
        return int(result[0][0])
    return 0
# Shared entry point: derive the date argument(s) from the command line.
def run_with_args(run_fun):
    """Call ``run_fun`` with a datetime chosen from ``sys.argv``.

    Supported invocations::

        python xxx.py                 # one call with the current time
        python xxx.py 2017-07-01      # one call with that date
        python xxx.py 2017-07-01 10   # ten calls, one per consecutive day

    The multi-day form exists so historical data can be re-run. An exception
    raised by ``run_fun`` is printed with its traceback and does not stop the
    remaining calls. A banner with the database settings and the start time
    is printed first, and the elapsed time is printed at the end.

    Args:
        run_fun: Callable taking a single ``datetime.datetime``.
    """
    now = datetime.datetime.now()  # the default run date is "today"
    now_str = now.strftime("%Y-%m-%d %H:%M:%S.%f")
    str_db = "MYSQL_HOST :" + MYSQL_HOST + ", MYSQL_USER :" + MYSQL_USER + ", MYSQL_DB :" + MYSQL_DB
    print("\n######################### " + str_db + " ######################### ")
    print("\n######################### begin run %s %s #########################" % (run_fun, now_str))
    start = time.time()

    argc = len(sys.argv)
    if argc == 3:
        # python xxx.py 2017-07-01 10
        year, month, day = sys.argv[1].split("-")
        loop = int(sys.argv[2])
        first_day = datetime.datetime(int(year), int(month), int(day))
        # Lazy on purpose: each date is computed right before its run.
        run_dates = (first_day + datetime.timedelta(days=offset) for offset in range(0, loop))
    elif argc == 2:
        # python xxx.py 2017-07-01
        year, month, day = sys.argv[1].split("-")
        run_dates = iter([datetime.datetime(int(year), int(month), int(day))])
    else:
        run_dates = iter([now])  # use the current time

    for run_date in run_dates:
        try:
            run_fun(run_date)
        except Exception as e:
            print("error :", e)
            traceback.print_exc()

    print("######################### finish %s , use time: %s #########################" % (
        now_str, time.time() - start))
# Template for the cache directory, set up once when the module is loaded.
# The two placeholders are filled in by get_hist_data_cache with the first
# seven characters of the end date and the end date itself.
bash_stock_tmp = "/data/cache/hist_data_cache/%s/%s/"
# Create only the fixed root of the cache tree. Passing the unformatted
# template to makedirs would create directories literally named "%s".
_stock_cache_root = bash_stock_tmp.split("%s", 1)[0]
if not os.path.exists(_stock_cache_root):
    os.makedirs(_stock_cache_root, exist_ok=True)  # creates intermediate directories as needed
    print("######################### init tmp dir #########################")
# Cached loader for historical stock data, to speed up processing.
def get_hist_data_cache(code, date_start, date_end):
    """Return historical data for ``code``, served from an on-disk cache when possible.

    The cache file lives under ``bash_stock_tmp`` in a directory built from
    the first seven characters of ``date_end`` and ``date_end`` itself (a
    month folder and a day folder, so old caches are easy to delete). Its
    name is made of ``date_end`` and ``code``; ``date_start`` is not part of
    the cache key.

    Args:
        code: Stock code passed to ``ts.get_hist_data``.
        date_start: Start date passed to ``ts.get_hist_data``.
        date_end: End date string; also used to build the cache path.

    Returns:
        The cached object if a cache file exists. Otherwise the freshly
        fetched data sorted by index, or ``None`` when ``ts.get_hist_data``
        returned ``None`` (nothing is cached in that case).
    """
    cache_dir = bash_stock_tmp % (date_end[0:7], date_end)
    # exist_ok avoids a crash when another process creates the directory
    # between an existence check and the creation.
    os.makedirs(cache_dir, exist_ok=True)
    cache_file = cache_dir + "%s^%s.gzip.pickle" % (date_end, code)
    # Serve from the gzip-compressed cache when it exists.
    if os.path.isfile(cache_file):
        print("######### read from cache #########", cache_file)
        # NOTE(review): unpickling runs code stored in the file; this is only
        # safe while the cache directory is written solely by this function.
        return pd.read_pickle(cache_file, compression="gzip")
    print("######### get data, write cache #########", code, date_start, date_end)
    stock = ts.get_hist_data(code, start=date_start, end=date_end)
    if stock is None:
        return None
    # Sort by index (the dates); the axis must be passed by keyword on
    # current pandas versions.
    stock = stock.sort_index(axis=0)
    stock.to_pickle(cache_file, compression="gzip")
    return stock