Skip to content

Commit

Permalink
US stock code supports Windows
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupr committed Jan 26, 2021
1 parent df55653 commit 1a1c459
Show file tree
Hide file tree
Showing 11 changed files with 201 additions and 97 deletions.
5 changes: 3 additions & 2 deletions qlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT License.


__version__ = "0.6.1.dev"
__version__ = "0.6.1.99.dev"


import os
Expand All @@ -15,7 +15,7 @@
import subprocess
from pathlib import Path

from .utils import can_use_cache, init_instance_by_config, get_module_by_module_path
from .utils import can_use_cache, init_instance_by_config, check_qlib_data
from .workflow.utils import experiment_exit_handler

# init qlib
Expand Down Expand Up @@ -88,6 +88,7 @@ def init(default_conf="client", **kwargs):
R.register(qr)
# clean up experiment when python program ends
experiment_exit_handler()
check_qlib_data(C)


def _mount_nfs_uri(C):
Expand Down
36 changes: 11 additions & 25 deletions qlib/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
import traceback
import numpy as np
import pandas as pd
from pathlib import Path
from multiprocessing import Pool

from .cache import H
from ..config import C
from .ops import *
from ..log import get_module_logger
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
from .base import Feature
from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
Expand Down Expand Up @@ -215,23 +214,6 @@ def get_inst_type(cls, inst):
return cls.LIST
raise ValueError(f"Unknown instrument type {inst}")

def convert_instruments(self, instrument):
_instruments_map = getattr(self, "_instruments_map", None)
if _instruments_map is None:
_df_list = []
# FIXME: each process will read these files
for _path in Path(C.get_data_path()).joinpath("instruments").glob("*.txt"):
_df = pd.read_csv(_path, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
_df_list.append(_df.iloc[:, [0, -1]])
df = pd.concat(_df_list, sort=False)
df["inst"] = df["inst"].astype(str)
df = df.fillna(axis=1, method="ffill")
df = df.sort_values("inst").drop_duplicates(subset=["inst"], keep="first")
df["save_inst"] = df["save_inst"].astype(str)
_instruments_map = df.set_index("inst").iloc[:, 0].to_dict()
setattr(self, "_instruments_map", _instruments_map)
return _instruments_map.get(instrument, instrument)


class FeatureProvider(abc.ABC):
"""Feature provider class
Expand Down Expand Up @@ -590,12 +572,16 @@ def _load_instruments(self, market):
fname = self._uri_inst.format(market)
if not os.path.exists(fname):
raise ValueError("instruments not exists for market " + market)

_instruments = dict()
df = pd.read_csv(fname, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
df["start_datetime"] = pd.to_datetime(df["start_datetime"])
df["end_datetime"] = pd.to_datetime(df["end_datetime"])
df["inst"] = df["inst"].astype(str)
df["save_inst"] = df.loc[:, ["inst", "save_inst"]].fillna(axis=1, method="ffill")["save_inst"].astype(str)
df = pd.read_csv(
fname,
sep="\t",
usecols=[0, 1, 2],
names=["inst", "start_datetime", "end_datetime"],
dtype={"inst": str},
parse_dates=["start_datetime", "end_datetime"],
)
for row in df.itertuples(index=False):
_instruments.setdefault(row[0], []).append((row[1], row[2]))
return _instruments
Expand Down Expand Up @@ -652,7 +638,7 @@ def _uri_data(self):
def feature(self, instrument, field, start_index, end_index, freq):
# validate
field = str(field).lower()[1:]
instrument = Inst.convert_instruments(instrument)
instrument = code_to_fname(instrument)
uri_data = self._uri_data.format(instrument.lower(), field, freq)
if not os.path.exists(uri_data):
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
Expand Down
6 changes: 5 additions & 1 deletion qlib/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ def setUpClass(cls) -> None:
print(f"Qlib data is not found in {provider_uri}")

GetData().qlib_data(
name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri
name="qlib_data_simple",
region="cn",
interval="1d",
target_dir=provider_uri,
delete_old=False,
)
init(provider_uri=provider_uri, region=REG_CN)
91 changes: 80 additions & 11 deletions qlib/tests/data.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import re
import qlib
import shutil
import zipfile
import requests
import datetime
from tqdm import tqdm
from pathlib import Path
from loguru import logger


class GetData:
DATASET_VERSION = "v1"
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
QLIB_DATA_NAME = "{dataset_name}_{region}_{interval}_{qlib_version}.zip"

def __init__(self, delete_zip_file=False):
"""
Expand All @@ -20,13 +27,24 @@ def __init__(self, delete_zip_file=False):
"""
self.delete_zip_file = delete_zip_file

def _download_data(self, file_name: str, target_dir: [Path, str]):
def normalize_dataset_version(self, dataset_version: str = None):
if dataset_version is None:
dataset_version = self.DATASET_VERSION
return dataset_version

def merge_remote_url(self, file_name: str, dataset_version: str = None):
return f"{self.REMOTE_URL}/{self.normalize_dataset_version(dataset_version)}/{file_name}"

def _download_data(
self, file_name: str, target_dir: [Path, str], delete_old: bool = True, dataset_version: str = None
):
target_dir = Path(target_dir).expanduser()
target_dir.mkdir(exist_ok=True, parents=True)
# saved file name
_target_file_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "_" + file_name
target_path = target_dir.joinpath(_target_file_name)

url = f"{self.REMOTE_URL}/{file_name}"
target_path = target_dir.joinpath(file_name)

url = self.merge_remote_url(file_name, dataset_version)
resp = requests.get(url, stream=True)
if resp.status_code != 200:
raise requests.exceptions.HTTPError()
Expand All @@ -42,19 +60,59 @@ def _download_data(self, file_name: str, target_dir: [Path, str]):
fp.write(chuck)
p_bar.update(chuck_size)

self._unzip(target_path, target_dir)
self._unzip(target_path, target_dir, delete_old)
if self.delete_zip_file:
target_path.unlike()

def check_dataset(self, file_name: str, dataset_version: str = None):
url = self.merge_remote_url(file_name, dataset_version)
resp = requests.get(url, stream=True)
status = True
if resp.status_code == 404:
status = False
return status

@staticmethod
def _unzip(file_path: Path, target_dir: Path):
def _unzip(file_path: Path, target_dir: Path, delete_old: bool = True):
if delete_old:
logger.warning(
f"will delete the old qlib data directory(features, instruments, calendars, features_cache, dataset_cache): {target_dir}"
)
GetData._delete_qlib_data(target_dir)
logger.info(f"{file_path} unzipping......")
with zipfile.ZipFile(str(file_path.resolve()), "r") as zp:
for _file in tqdm(zp.namelist()):
zp.extract(_file, str(target_dir.resolve()))

@staticmethod
def _delete_qlib_data(file_dir: Path):
logger.info(f"delete {file_dir}")
rm_dirs = []
for _name in ["features", "calendars", "instruments", "features_cache", "dataset_cache"]:
_p = file_dir.joinpath(_name)
if _p.exists():
rm_dirs.append(str(_p.resolve()))
if rm_dirs:
flag = input(
f"Will be deleted: "
f"\n\t{rm_dirs}"
f"\nIf you do not need to delete {file_dir}, please change the <--target_dir>"
f"\nAre you sure you want to delete, yes(Y/y), no (N/n):"
)
if str(flag) not in ["Y", "y"]:
exit()
for _p in rm_dirs:
logger.warning(f"delete: {_p}")
shutil.rmtree(_p)

def qlib_data(
self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"
self,
name="qlib_data",
target_dir="~/.qlib/qlib_data/cn_data",
version=None,
interval="1d",
region="cn",
delete_old=True,
):
"""download cn qlib data from remote
Expand All @@ -65,20 +123,31 @@ def qlib_data(
name: str
dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data
version: str
data version, value from [v0, v1, ..., latest], by default latest
data version, value from [v1, ...], by default None(use script to specify version)
interval: str
data freq, value from [1d], by default 1d
region: str
data region, value from [cn, us], by default cn
delete_old: bool
delete an existing directory, by default True
Examples
---------
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
-------
"""
file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip"
self._download_data(file_name.lower(), target_dir)
qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__))

def _get_file_name(v):
return self.QLIB_DATA_NAME.format(
dataset_name=name, region=region.lower(), interval=interval.lower(), qlib_version=v
)

file_name = _get_file_name(qlib_version)
if not self.check_dataset(file_name, version):
file_name = _get_file_name("latest")
self._download_data(file_name.lower(), target_dir, delete_old, dataset_version=version)

def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
"""download cn csv data from remote
Expand Down
53 changes: 50 additions & 3 deletions qlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pathlib import Path
from typing import Union, Tuple

from .. import __version__ as qlib_version
from ..config import C
from ..log import get_module_logger

Expand Down Expand Up @@ -643,15 +644,28 @@ def exists_qlib_data(qlib_dir):
# check instruments
code_names = set(map(lambda x: x.name.lower(), features_dir.iterdir()))
_instrument = instruments_dir.joinpath("all.txt")
df = pd.read_csv(_instrument, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
df = df.iloc[:, [0, -1]].fillna(axis=1, method="ffill")
miss_code = set(df.iloc[:, -1].apply(str.lower)) - set(code_names)
miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
return False

return True


def check_qlib_data(qlib_config):
inst_dir = Path(qlib_config["provider_uri"]).joinpath("instruments")
for _p in inst_dir.glob("*.txt"):
try:
assert len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns) == 3, (
f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:"
f"\n\tIf you are using the data provided by qlib: "
f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset"
f"\n\tIf you are using your own data, please dump the data again: "
f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format"
)
except AssertionError:
raise


def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
"""
make the df index sorted
Expand Down Expand Up @@ -742,3 +756,36 @@ def load_dataset(path_or_obj):
elif extension == ".csv":
return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1])
raise ValueError(f"unsupported file type `{extension}`")


def code_to_fname(code: str):
"""stock code to file name
Parameters
----------
code: str
"""
# NOTE: In windows, the following name is I/O device, and the file with the corresponding name cannot be created
# reference: https://superuser.com/questions/86999/why-cant-i-name-a-folder-or-file-con-in-windows
replace_names = ["CON", "PRN", "AUX", "NUL"]
replace_names += [f"COM{i}" for i in range(10)]
replace_names += [f"LPT{i}" for i in range(10)]

prefix = "_qlib_"
if str(code).upper() in replace_names:
code = prefix + str(code)

return code


def fname_to_code(fname: str):
"""file name to stock code
Parameters
----------
fname: str
"""
prefix = "_qlib_"
if fname.startswith(prefix):
fname = fname.lstrip(prefix)
return fname
2 changes: 1 addition & 1 deletion scripts/data_collector/yahoo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pip install -r requirements.txt

### Download data and Normalize data
```bash
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d --normalize_dir ~/.qlib/stock_data/normalize
```

### Download Data
Expand Down
14 changes: 7 additions & 7 deletions scripts/data_collector/yahoo/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from loguru import logger
from yahooquery import Ticker
from dateutil.tz import tzlocal
from qlib.utils import code_to_fname

CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
Expand All @@ -40,7 +41,7 @@ def __init__(
end=None,
interval="1d",
max_workers=4,
max_collector_count=5,
max_collector_count=2,
delay=0,
check_data_length: bool = False,
limit_nums: int = None,
Expand All @@ -55,7 +56,7 @@ def __init__(
max_workers: int
workers, default 4
max_collector_count: int
default 5
default 2
delay: float
time.sleep(delay), default 0
interval: str
Expand Down Expand Up @@ -147,11 +148,10 @@ def save_stock(self, symbol, df: pd.DataFrame):
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
df["symbol"] = symbol
if stock_path.exists():
with stock_path.open("a") as fp:
df.to_csv(fp, index=False, header=False)
_temp_df = pd.read_csv(stock_path, nrows=0)
df.loc[:, _temp_df.columns].to_csv(stock_path, index=False, header=False, mode="a")
else:
with stock_path.open("w") as fp:
df.to_csv(fp, index=False)
df.to_csv(stock_path, index=False, mode="w")

def _save_small_data(self, symbol, df):
if len(df) <= self.min_numbers_trading:
Expand Down Expand Up @@ -350,7 +350,7 @@ def download_index_data(self):
pass

def normalize_symbol(self, symbol):
return symbol.upper()
return code_to_fname(symbol).upper()

@property
def _timezone(self):
Expand Down
Loading

0 comments on commit 1a1c459

Please sign in to comment.