diff --git a/vnpy/trader/app.py b/vnpy/trader/app.py index d831329485..650155c0d9 100644 --- a/vnpy/trader/app.py +++ b/vnpy/trader/app.py @@ -1,8 +1,8 @@ -"""""" - from abc import ABC from pathlib import Path +from .engine import BaseEngine + class BaseApp(ABC): """ @@ -13,6 +13,6 @@ class BaseApp(ABC): app_module: str = "" # App module string used in import_module app_path: Path = "" # Absolute path of app folder display_name: str = "" # Name for display on the menu. - engine_class = None # App engine class + engine_class: BaseEngine = None # App engine class widget_name: str = "" # Class name of app widget icon_name: str = "" # Icon file name of app widget diff --git a/vnpy/trader/converter.py b/vnpy/trader/converter.py index 8a0c2947c4..e23d69a66b 100644 --- a/vnpy/trader/converter.py +++ b/vnpy/trader/converter.py @@ -1,4 +1,3 @@ -"""""" from copy import copy from typing import Dict, List @@ -16,7 +15,7 @@ class OffsetConverter: """""" - def __init__(self, main_engine: MainEngine): + def __init__(self, main_engine: MainEngine) -> None: """""" self.main_engine: MainEngine = main_engine self.holdings: Dict[str, "PositionHolding"] = {} @@ -26,7 +25,7 @@ def update_position(self, position: PositionData) -> None: if not self.is_convert_required(position.vt_symbol): return - holding = self.get_position_holding(position.vt_symbol) + holding: PositionHolding = self.get_position_holding(position.vt_symbol) holding.update_position(position) def update_trade(self, trade: TradeData) -> None: @@ -34,7 +33,7 @@ def update_trade(self, trade: TradeData) -> None: if not self.is_convert_required(trade.vt_symbol): return - holding = self.get_position_holding(trade.vt_symbol) + holding: PositionHolding = self.get_position_holding(trade.vt_symbol) holding.update_trade(trade) def update_order(self, order: OrderData) -> None: @@ -42,7 +41,7 @@ def update_order(self, order: OrderData) -> None: if not self.is_convert_required(order.vt_symbol): return - holding = self.get_position_holding(order.vt_symbol) + holding: PositionHolding = self.get_position_holding(order.vt_symbol) holding.update_order(order) def update_order_request(self, req: OrderRequest, vt_orderid: str) -> None: @@ -50,14 +49,14 @@ def update_order_request(self, req: OrderRequest, vt_orderid: str) -> None: if not self.is_convert_required(req.vt_symbol): return - holding = self.get_position_holding(req.vt_symbol) + holding: PositionHolding = self.get_position_holding(req.vt_symbol) holding.update_order_request(req, vt_orderid) def get_position_holding(self, vt_symbol: str) -> "PositionHolding": """""" - holding = self.holdings.get(vt_symbol, None) + holding: PositionHolding = self.holdings.get(vt_symbol, None) if not holding: - contract = self.main_engine.get_contract(vt_symbol) + contract: ContractData = self.main_engine.get_contract(vt_symbol) holding = PositionHolding(contract) self.holdings[vt_symbol] = holding return holding @@ -72,7 +71,7 @@ def convert_order_request( if not self.is_convert_required(req.vt_symbol): return [req] - holding = self.get_position_holding(req.vt_symbol) + holding: PositionHolding = self.get_position_holding(req.vt_symbol) if lock: return holding.convert_order_request_lock(req) @@ -87,7 +86,7 @@ def is_convert_required(self, vt_symbol: str) -> bool: """ Check if the contract needs offset convert. """ - contract = self.main_engine.get_contract(vt_symbol) + contract: ContractData = self.main_engine.get_contract(vt_symbol) # Only contracts with long-short position mode requires convert if not contract: @@ -101,7 +100,7 @@ def is_convert_required(self, vt_symbol: str) -> bool: class PositionHolding: """""" - def __init__(self, contract: ContractData): + def __init__(self, contract: ContractData) -> None: """""" self.vt_symbol: str = contract.vt_symbol self.exchange: Exchange = contract.exchange @@ -149,7 +148,7 @@ def update_order_request(self, req: OrderRequest, vt_orderid: str) -> None: """""" gateway_name, orderid = vt_orderid.split(".") - order = req.create_order_data(orderid, gateway_name) + order: OrderData = req.create_order_data(orderid, gateway_name) self.update_order(order) def update_trade(self, trade: TradeData) -> None: @@ -208,7 +207,7 @@ def calculate_frozen(self) -> None: if order.offset == Offset.OPEN: continue - frozen = order.volume - order.traded + frozen: float = order.volume - order.traded if order.direction == Direction.LONG: if order.offset == Offset.CLOSETODAY: @@ -255,28 +254,28 @@ def convert_order_request_shfe(self, req: OrderRequest) -> List[OrderRequest]: return [req] if req.direction == Direction.LONG: - pos_available = self.short_pos - self.short_pos_frozen - td_available = self.short_td - self.short_td_frozen + pos_available: float = self.short_pos - self.short_pos_frozen + td_available: float = self.short_td - self.short_td_frozen else: - pos_available = self.long_pos - self.long_pos_frozen - td_available = self.long_td - self.long_td_frozen + pos_available: float = self.long_pos - self.long_pos_frozen + td_available: float = self.long_td - self.long_td_frozen if req.volume > pos_available: return [] elif req.volume <= td_available: - req_td = copy(req) + req_td: OrderRequest = copy(req) req_td.offset = Offset.CLOSETODAY return [req_td] else: - req_list = [] + req_list: List[OrderRequest] = [] if td_available > 0: - req_td = copy(req) + req_td: OrderRequest = copy(req) req_td.offset = Offset.CLOSETODAY req_td.volume = td_available req_list.append(req_td) - req_yd = copy(req) + req_yd: OrderRequest = copy(req) req_yd.offset = Offset.CLOSEYESTERDAY req_yd.volume = req.volume - td_available req_list.append(req_yd) @@ -286,26 +285,26 @@ def convert_order_request_shfe(self, req: OrderRequest) -> List[OrderRequest]: def convert_order_request_lock(self, req: OrderRequest) -> List[OrderRequest]: """""" if req.direction == Direction.LONG: - td_volume = self.short_td - yd_available = self.short_yd - self.short_yd_frozen + td_volume: float = self.short_td + yd_available: float = self.short_yd - self.short_yd_frozen else: - td_volume = self.long_td - yd_available = self.long_yd - self.long_yd_frozen + td_volume: float = self.long_td + yd_available: float = self.long_yd - self.long_yd_frozen # If there is td_volume, we can only lock position if td_volume: - req_open = copy(req) + req_open: OrderRequest = copy(req) req_open.offset = Offset.OPEN return [req_open] # If no td_volume, we close opposite yd position first # then open new position else: - close_volume = min(req.volume, yd_available) - open_volume = max(0, req.volume - yd_available) - req_list = [] + close_volume: float = min(req.volume, yd_available) + open_volume: float = max(0, req.volume - yd_available) + req_list: List[OrderRequest] = [] if yd_available: - req_yd = copy(req) + req_yd: OrderRequest = copy(req) if self.exchange in [Exchange.SHFE, Exchange.INE]: req_yd.offset = Offset.CLOSEYESTERDAY else: @@ -314,7 +313,7 @@ def convert_order_request_lock(self, req: OrderRequest) -> List[OrderRequest]: req_list.append(req_yd) if open_volume: - req_open = copy(req) + req_open: OrderRequest = copy(req) req_open.offset = Offset.OPEN req_open.volume = open_volume req_list.append(req_open) @@ -324,41 +323,41 @@ def convert_order_request_lock(self, req: OrderRequest) -> List[OrderRequest]: def convert_order_request_net(self, req: OrderRequest) -> List[OrderRequest]: """""" if req.direction == Direction.LONG: - pos_available = self.short_pos - self.short_pos_frozen - td_available = self.short_td - self.short_td_frozen - yd_available = self.short_yd - self.short_yd_frozen + pos_available: float = self.short_pos - self.short_pos_frozen + td_available: float = self.short_td - self.short_td_frozen + yd_available: float = self.short_yd - self.short_yd_frozen else: - pos_available = self.long_pos - self.long_pos_frozen - td_available = self.long_td - self.long_td_frozen - yd_available = self.long_yd - self.long_yd_frozen + pos_available: float = self.long_pos - self.long_pos_frozen + td_available: float = self.long_td - self.long_td_frozen + yd_available: float = self.long_yd - self.long_yd_frozen # Split close order to close today/yesterday for SHFE/INE exchange if req.exchange in {Exchange.SHFE, Exchange.INE}: - reqs = [] - volume_left = req.volume + reqs: List[OrderRequest] = [] + volume_left: float = req.volume if td_available: - td_volume = min(td_available, volume_left) + td_volume: float = min(td_available, volume_left) volume_left -= td_volume - td_req = copy(req) + td_req: OrderRequest = copy(req) td_req.offset = Offset.CLOSETODAY td_req.volume = td_volume reqs.append(td_req) if volume_left and yd_available: - yd_volume = min(yd_available, volume_left) + yd_volume: float = min(yd_available, volume_left) volume_left -= yd_volume - yd_req = copy(req) + yd_req: OrderRequest = copy(req) yd_req.offset = Offset.CLOSEYESTERDAY yd_req.volume = yd_volume reqs.append(yd_req) if volume_left > 0: - open_volume = volume_left + open_volume: float = volume_left - open_req = copy(req) + open_req: OrderRequest = copy(req) open_req.offset = Offset.OPEN open_req.volume = open_volume reqs.append(open_req) @@ -366,22 +365,22 @@ def convert_order_request_net(self, req: OrderRequest) -> List[OrderRequest]: return reqs # Just use close for other exchanges else: - reqs = [] - volume_left = req.volume + reqs: List[OrderRequest] = [] + volume_left: float = req.volume if pos_available: - close_volume = min(pos_available, volume_left) + close_volume: float = min(pos_available, volume_left) volume_left -= pos_available - close_req = copy(req) + close_req: OrderRequest = copy(req) close_req.offset = Offset.CLOSE close_req.volume = close_volume reqs.append(close_req) if volume_left > 0: - open_volume = volume_left + open_volume: float = volume_left - open_req = copy(req) + open_req: OrderRequest = copy(req) open_req.offset = Offset.OPEN open_req.volume = open_volume reqs.append(open_req) diff --git a/vnpy/trader/database.py b/vnpy/trader/database.py index 3ec6fc0001..79e8c5cfbd 100644 --- a/vnpy/trader/database.py +++ b/vnpy/trader/database.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from datetime import datetime +from types import ModuleType from typing import List from pytz import timezone from dataclasses import dataclass @@ -17,7 +18,7 @@ def convert_tz(dt: datetime) -> datetime: """ Convert timezone of datetime object to DB_TZ. """ - dt = dt.astimezone(DB_TZ) + dt: datetime = dt.astimezone(DB_TZ) return dt.replace(tzinfo=None) @@ -128,10 +129,10 @@ def get_database() -> BaseDatabase: # Try to import database module try: - module = import_module(module_name) + module: ModuleType = import_module(module_name) except ModuleNotFoundError: print(f"找不到数据库驱动{module_name},使用默认的SQLite数据库") - module = import_module("vnpy_sqlite") + module: ModuleType = import_module("vnpy_sqlite") # Create database object from module database = module.Database() diff --git a/vnpy/trader/datafeed.py b/vnpy/trader/datafeed.py index 3dd376bd0c..7a8de5c270 100644 --- a/vnpy/trader/datafeed.py +++ b/vnpy/trader/datafeed.py @@ -1,4 +1,5 @@ from abc import ABC +from types import ModuleType from typing import Optional, List from importlib import import_module @@ -46,10 +47,10 @@ def get_datafeed() -> BaseDatafeed: # Try to import datafeed module try: - module = import_module(module_name) + module: ModuleType = import_module(module_name) except ModuleNotFoundError: print(f"找不到数据服务驱动{module_name},使用默认的RQData数据服务") - module = import_module("vnpy_rqdata") + module: ModuleType = import_module("vnpy_rqdata") # Create datafeed object from module datafeed = module.Datafeed() diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index aa33036bcc..4e58ba3511 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -1,11 +1,9 @@ -""" -""" - import logging from logging import Logger import smtplib import os from abc import ABC +from pathlib import Path from datetime import datetime from email.message import EmailMessage from queue import Empty, Queue @@ -51,7 +49,7 @@ class MainEngine: Acts as the core of the trading platform. """ - def __init__(self, event_engine: EventEngine = None): + def __init__(self, event_engine: EventEngine = None) -> None: """""" if event_engine: self.event_engine: EventEngine = event_engine @@ -71,7 +69,7 @@ def add_engine(self, engine_class: Any) -> "BaseEngine": """ Add function engine. """ - engine = engine_class(self, self.event_engine) + engine: BaseEngine = engine_class(self, self.event_engine) self.engines[engine.engine_name] = engine return engine @@ -81,9 +79,9 @@ def add_gateway(self, gateway_class: Type[BaseGateway], gateway_name: str = "") """ # Use default name if gateway_name not passed if not gateway_name: - gateway_name = gateway_class.default_name + gateway_name: str = gateway_class.default_name - gateway = gateway_class(self.event_engine, gateway_name) + gateway: BaseGateway = gateway_class(self.event_engine, gateway_name) self.gateways[gateway_name] = gateway # Add gateway supported exchanges into engine @@ -97,10 +95,10 @@ def add_app(self, app_class: Type[BaseApp]) -> "BaseEngine": """ Add app. """ - app = app_class() + app: BaseApp = app_class() self.apps[app.app_name] = app - engine = self.add_engine(app.engine_class) + engine: BaseEngine = self.add_engine(app.engine_class) return engine def init_engines(self) -> None: @@ -115,15 +113,15 @@ def write_log(self, msg: str, source: str = "") -> None: """ Put log event with specific message. """ - log = LogData(msg=msg, gateway_name=source) - event = Event(EVENT_LOG, log) + log: LogData = LogData(msg=msg, gateway_name=source) + event: Event = Event(EVENT_LOG, log) self.event_engine.put(event) def get_gateway(self, gateway_name: str) -> BaseGateway: """ Return gateway object by name. """ - gateway = self.gateways.get(gateway_name, None) + gateway: BaseGateway = self.gateways.get(gateway_name, None) if not gateway: self.write_log(f"找不到底层接口:{gateway_name}") return gateway @@ -132,7 +130,7 @@ def get_engine(self, engine_name: str) -> "BaseEngine": """ Return engine object by name. """ - engine = self.engines.get(engine_name, None) + engine: BaseEngine = self.engines.get(engine_name, None) if not engine: self.write_log(f"找不到引擎:{engine_name}") return engine @@ -141,7 +139,7 @@ def get_default_setting(self, gateway_name: str) -> Optional[Dict[str, Any]]: """ Get default setting dict of a specific gateway. """ - gateway = self.get_gateway(gateway_name) + gateway: BaseGateway = self.get_gateway(gateway_name) if gateway: return gateway.get_default_setting() return None @@ -168,7 +166,7 @@ def connect(self, setting: dict, gateway_name: str) -> None: """ Start connection of a specific gateway. """ - gateway = self.get_gateway(gateway_name) + gateway: BaseGateway = self.get_gateway(gateway_name) if gateway: gateway.connect(setting) @@ -176,7 +174,7 @@ def subscribe(self, req: SubscribeRequest, gateway_name: str) -> None: """ Subscribe tick data update of a specific gateway. """ - gateway = self.get_gateway(gateway_name) + gateway: BaseGateway = self.get_gateway(gateway_name) if gateway: gateway.subscribe(req) @@ -184,7 +182,7 @@ def send_order(self, req: OrderRequest, gateway_name: str) -> str: """ Send new order request to a specific gateway. """ - gateway = self.get_gateway(gateway_name) + gateway: BaseGateway = self.get_gateway(gateway_name) if gateway: return gateway.send_order(req) else: @@ -194,7 +192,7 @@ def cancel_order(self, req: CancelRequest, gateway_name: str) -> None: """ Send cancel order request to a specific gateway. """ - gateway = self.get_gateway(gateway_name) + gateway: BaseGateway = self.get_gateway(gateway_name) if gateway: gateway.cancel_order(req) @@ -202,7 +200,7 @@ def send_quote(self, req: QuoteRequest, gateway_name: str) -> str: """ Send new quote request to a specific gateway. """ - gateway = self.get_gateway(gateway_name) + gateway: BaseGateway = self.get_gateway(gateway_name) if gateway: return gateway.send_quote(req) else: @@ -212,7 +210,7 @@ def cancel_quote(self, req: CancelRequest, gateway_name: str) -> None: """ Send cancel quote request to a specific gateway. """ - gateway = self.get_gateway(gateway_name) + gateway: BaseGateway = self.get_gateway(gateway_name) if gateway: gateway.cancel_quote(req) @@ -220,7 +218,7 @@ def query_history(self, req: HistoryRequest, gateway_name: str) -> Optional[List """ Query bar history data from a specific gateway. """ - gateway = self.get_gateway(gateway_name) + gateway: BaseGateway = self.get_gateway(gateway_name) if gateway: return gateway.query_history(req) else: @@ -251,13 +249,13 @@ def __init__( main_engine: MainEngine, event_engine: EventEngine, engine_name: str, - ): + ) -> None: """""" - self.main_engine = main_engine - self.event_engine = event_engine - self.engine_name = engine_name + self.main_engine: MainEngine = main_engine + self.event_engine: EventEngine = event_engine + self.engine_name: str = engine_name - def close(self): + def close(self) -> None: """""" pass @@ -267,7 +265,7 @@ class LogEngine(BaseEngine): Processes log event and output with logging module. """ - def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + def __init__(self, main_engine: MainEngine, event_engine: EventEngine) -> None: """""" super(LogEngine, self).__init__(main_engine, event_engine, "log") @@ -279,7 +277,7 @@ def __init__(self, main_engine: MainEngine, event_engine: EventEngine): self.logger: Logger = logging.getLogger("veighna") self.logger.setLevel(self.level) - self.formatter = logging.Formatter( + self.formatter: logging.Formatter = logging.Formatter( "%(asctime)s %(levelname)s: %(message)s" ) @@ -297,14 +295,14 @@ def add_null_handler(self) -> None: """ Add null handler for logger. """ - null_handler = logging.NullHandler() + null_handler: logging.NullHandler = logging.NullHandler() self.logger.addHandler(null_handler) def add_console_handler(self) -> None: """ Add console output of log. """ - console_handler = logging.StreamHandler() + console_handler: logging.StreamHandler = logging.StreamHandler() console_handler.setLevel(self.level) console_handler.setFormatter(self.formatter) self.logger.addHandler(console_handler) @@ -313,12 +311,12 @@ def add_file_handler(self) -> None: """ Add file output of log. """ - today_date = datetime.now().strftime("%Y%m%d") - filename = f"vt_{today_date}.log" - log_path = get_folder_path("log") - file_path = log_path.joinpath(filename) + today_date: str = datetime.now().strftime("%Y%m%d") + filename: str = f"vt_{today_date}.log" + log_path: Path = get_folder_path("log") + file_path: Path = log_path.joinpath(filename) - file_handler = logging.FileHandler( + file_handler: logging.FileHandler = logging.FileHandler( file_path, mode="a", encoding="utf8" ) file_handler.setLevel(self.level) @@ -333,7 +331,7 @@ def process_log_event(self, event: Event) -> None: """ Process log event. """ - log = event.data + log: LogData = event.data self.logger.log(log.level, log.msg) @@ -342,7 +340,7 @@ class OmsEngine(BaseEngine): Provides order management system function. """ - def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + def __init__(self, main_engine: MainEngine, event_engine: EventEngine) -> None: """""" super(OmsEngine, self).__init__(main_engine, event_engine, "oms") @@ -532,7 +530,7 @@ def get_all_active_orders(self, vt_symbol: str = "") -> List[OrderData]: if not vt_symbol: return list(self.active_orders.values()) else: - active_orders = [ + active_orders: List[OrderData] = [ order for order in self.active_orders.values() if order.vt_symbol == vt_symbol @@ -542,13 +540,12 @@ def get_all_active_orders(self, vt_symbol: str = "") -> List[OrderData]: def get_all_active_quotes(self, vt_symbol: str = "") -> List[QuoteData]: """ Get all active quotes by vt_symbol. - If vt_symbol is empty, return all active qutoes. """ if not vt_symbol: return list(self.active_quotes.values()) else: - active_quotes = [ + active_quotes: List[QuoteData] = [ quote for quote in self.active_quotes.values() if quote.vt_symbol == vt_symbol @@ -561,7 +558,7 @@ class EmailEngine(BaseEngine): Provides email sending function. """ - def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + def __init__(self, main_engine: MainEngine, event_engine: EventEngine) -> None: """""" super(EmailEngine, self).__init__(main_engine, event_engine, "email") @@ -579,9 +576,9 @@ def send_email(self, subject: str, content: str, receiver: str = "") -> None: # Use default receiver if not specified. if not receiver: - receiver = SETTINGS["email.receiver"] + receiver: str = SETTINGS["email.receiver"] - msg = EmailMessage() + msg: EmailMessage = EmailMessage() msg["From"] = SETTINGS["email.sender"] msg["To"] = receiver msg["Subject"] = subject @@ -593,7 +590,7 @@ def run(self) -> None: """""" while self.active: try: - msg = self.queue.get(block=True, timeout=1) + msg: EmailMessage = self.queue.get(block=True, timeout=1) with smtplib.SMTP_SSL( SETTINGS["email.server"], SETTINGS["email.port"] diff --git a/vnpy/trader/gateway.py b/vnpy/trader/gateway.py index 64e1764cb3..d5dcabb21b 100644 --- a/vnpy/trader/gateway.py +++ b/vnpy/trader/gateway.py @@ -1,7 +1,3 @@ -""" - -""" - from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Callable from copy import copy @@ -84,7 +80,7 @@ class BaseGateway(ABC): # Exchanges supported in the gateway. exchanges: List[Exchange] = [] - def __init__(self, event_engine: EventEngine, gateway_name: str): + def __init__(self, event_engine: EventEngine, gateway_name: str) -> None: """""" self.event_engine: EventEngine = event_engine self.gateway_name: str = gateway_name @@ -93,7 +89,7 @@ def on_event(self, type: str, data: Any = None) -> None: """ General event push. """ - event = Event(type, data) + event: Event = Event(type, data) self.event_engine.put(event) def on_tick(self, tick: TickData) -> None: @@ -160,7 +156,7 @@ def write_log(self, msg: str) -> None: """ Write a log event from gateway. """ - log = LogData(msg=msg, gateway_name=self.gateway_name) + log: LogData = LogData(msg=msg, gateway_name=self.gateway_name) self.on_log(log) @abstractmethod @@ -283,7 +279,7 @@ class LocalOrderManager: Management tool to support use local order id for trading. """ - def __init__(self, gateway: BaseGateway, order_prefix: str = ""): + def __init__(self, gateway: BaseGateway, order_prefix: str = "") -> None: """""" self.gateway: BaseGateway = gateway @@ -306,7 +302,7 @@ def __init__(self, gateway: BaseGateway, order_prefix: str = ""): self.cancel_request_buf: Dict[str, CancelRequest] = {} # local_orderid: req # Hook cancel order function - self._cancel_order: Callable[CancelRequest] = gateway.cancel_order + self._cancel_order: Callable = gateway.cancel_order gateway.cancel_order = self.cancel_order def new_local_orderid(self) -> str: @@ -314,14 +310,14 @@ def new_local_orderid(self) -> str: Generate a new local orderid. """ self.order_count += 1 - local_orderid = self.order_prefix + str(self.order_count).rjust(8, "0") + local_orderid: str = self.order_prefix + str(self.order_count).rjust(8, "0") return local_orderid def get_local_orderid(self, sys_orderid: str) -> str: """ Get local orderid with sys orderid. """ - local_orderid = self.sys_local_orderid_map.get(sys_orderid, "") + local_orderid: str = self.sys_local_orderid_map.get(sys_orderid, "") if not local_orderid: local_orderid = self.new_local_orderid() @@ -333,7 +329,7 @@ def get_sys_orderid(self, local_orderid: str) -> str: """ Get sys orderid with local orderid. """ - sys_orderid = self.local_sys_orderid_map.get(local_orderid, "") + sys_orderid: str = self.local_sys_orderid_map.get(local_orderid, "") return sys_orderid def update_orderid_map(self, local_orderid: str, sys_orderid: str) -> None: @@ -353,7 +349,7 @@ def check_push_data(self, sys_orderid: str) -> None: if sys_orderid not in self.push_data_buf: return - data = self.push_data_buf.pop(sys_orderid) + data: dict = self.push_data_buf.pop(sys_orderid) if self.push_data_callback: self.push_data_callback(data) @@ -365,7 +361,7 @@ def add_push_data(self, sys_orderid: str, data: dict) -> None: def get_order_with_sys_orderid(self, sys_orderid: str) -> Optional[OrderData]: """""" - local_orderid = self.sys_local_orderid_map.get(sys_orderid, None) + local_orderid: str = self.sys_local_orderid_map.get(sys_orderid, None) if not local_orderid: return None else: @@ -373,7 +369,7 @@ def get_order_with_sys_orderid(self, sys_orderid: str) -> Optional[OrderData]: def get_order_with_local_orderid(self, local_orderid: str) -> OrderData: """""" - order = self.orders[local_orderid] + order: OrderData = self.orders[local_orderid] return copy(order) def on_order(self, order: OrderData) -> None: @@ -384,9 +380,8 @@ def on_order(self, order: OrderData) -> None: self.gateway.on_order(order) def cancel_order(self, req: CancelRequest) -> None: - """ - """ - sys_orderid = self.get_sys_orderid(req.orderid) + """""" + sys_orderid: str = self.get_sys_orderid(req.orderid) if not sys_orderid: self.cancel_request_buf[req.orderid] = req return @@ -394,10 +389,9 @@ def cancel_order(self, req: CancelRequest) -> None: self._cancel_order(req) def check_cancel_request(self, local_orderid: str) -> None: - """ - """ + """""" if local_orderid not in self.cancel_request_buf: return - req = self.cancel_request_buf.pop(local_orderid) + req: CancelRequest = self.cancel_request_buf.pop(local_orderid) self.gateway.cancel_order(req) diff --git a/vnpy/trader/object.py b/vnpy/trader/object.py index 228c7e2479..87bace382f 100644 --- a/vnpy/trader/object.py +++ b/vnpy/trader/object.py @@ -74,9 +74,9 @@ class TickData(BaseData): localtime: datetime = None - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.vt_symbol = f"{self.symbol}.{self.exchange.value}" + self.vt_symbol: str = f"{self.symbol}.{self.exchange.value}" @dataclass @@ -98,9 +98,9 @@ class BarData(BaseData): low_price: float = 0 close_price: float = 0 - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.vt_symbol = f"{self.symbol}.{self.exchange.value}" + self.vt_symbol: str = f"{self.symbol}.{self.exchange.value}" @dataclass @@ -124,10 +124,10 @@ class OrderData(BaseData): datetime: datetime = None reference: str = "" - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.vt_symbol = f"{self.symbol}.{self.exchange.value}" - self.vt_orderid = f"{self.gateway_name}.{self.orderid}" + self.vt_symbol: str = f"{self.symbol}.{self.exchange.value}" + self.vt_orderid: str = f"{self.gateway_name}.{self.orderid}" def is_active(self) -> bool: """ @@ -139,7 +139,7 @@ def create_cancel_request(self) -> "CancelRequest": """ Create cancel request object from order. """ - req = CancelRequest( + req: CancelRequest = CancelRequest( orderid=self.orderid, symbol=self.symbol, exchange=self.exchange ) return req @@ -163,11 +163,11 @@ class TradeData(BaseData): volume: float = 0 datetime: datetime = None - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.vt_symbol = f"{self.symbol}.{self.exchange.value}" - self.vt_orderid = f"{self.gateway_name}.{self.orderid}" - self.vt_tradeid = f"{self.gateway_name}.{self.tradeid}" + self.vt_symbol: str = f"{self.symbol}.{self.exchange.value}" + self.vt_orderid: str = f"{self.gateway_name}.{self.orderid}" + self.vt_tradeid: str = f"{self.gateway_name}.{self.tradeid}" @dataclass @@ -186,10 +186,10 @@ class PositionData(BaseData): pnl: float = 0 yd_volume: float = 0 - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.vt_symbol = f"{self.symbol}.{self.exchange.value}" - self.vt_positionid = f"{self.vt_symbol}.{self.direction.value}" + self.vt_symbol: str = f"{self.symbol}.{self.exchange.value}" + self.vt_positionid: str = f"{self.vt_symbol}.{self.direction.value}" @dataclass @@ -204,10 +204,10 @@ class AccountData(BaseData): balance: float = 0 frozen: float = 0 - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.available = self.balance - self.frozen - self.vt_accountid = f"{self.gateway_name}.{self.accountid}" + self.available: float = self.balance - self.frozen + self.vt_accountid: str = f"{self.gateway_name}.{self.accountid}" @dataclass @@ -219,9 +219,9 @@ class LogData(BaseData): msg: str level: int = INFO - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.time = datetime.now() + self.time: datetime = datetime.now() @dataclass @@ -250,9 +250,9 @@ class ContractData(BaseData): option_portfolio: str = "" option_index: str = "" # for identifying options with same strike price - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.vt_symbol = f"{self.symbol}.{self.exchange.value}" + self.vt_symbol: str = f"{self.symbol}.{self.exchange.value}" @dataclass @@ -276,10 +276,10 @@ class QuoteData(BaseData): datetime: datetime = None reference: str = "" - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.vt_symbol = f"{self.symbol}.{self.exchange.value}" - self.vt_quoteid = f"{self.gateway_name}.{self.quoteid}" + self.vt_symbol: str = f"{self.symbol}.{self.exchange.value}" + self.vt_quoteid: str = f"{self.gateway_name}.{self.quoteid}" def is_active(self) -> bool: """ @@ -291,7 +291,7 @@ def create_cancel_request(self) -> "CancelRequest": """ Create cancel request object from quote. """ - req = CancelRequest( + req: CancelRequest = CancelRequest( orderid=self.quoteid, symbol=self.symbol, exchange=self.exchange ) return req @@ -306,9 +306,9 @@ class SubscribeRequest: symbol: str exchange: Exchange - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.vt_symbol = f"{self.symbol}.{self.exchange.value}" + self.vt_symbol: str = f"{self.symbol}.{self.exchange.value}" @dataclass @@ -326,15 +326,15 @@ class OrderRequest: offset: Offset = Offset.NONE reference: str = "" - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.vt_symbol = f"{self.symbol}.{self.exchange.value}" + self.vt_symbol: str = f"{self.symbol}.{self.exchange.value}" def create_order_data(self, orderid: str, gateway_name: str) -> OrderData: """ Create order data from request. """ - order = OrderData( + order: OrderData = OrderData( symbol=self.symbol, exchange=self.exchange, orderid=orderid, @@ -359,9 +359,9 @@ class CancelRequest: symbol: str exchange: Exchange - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.vt_symbol = f"{self.symbol}.{self.exchange.value}" + self.vt_symbol: str = f"{self.symbol}.{self.exchange.value}" @dataclass @@ -376,9 +376,9 @@ class HistoryRequest: end: datetime = None interval: Interval = None - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.vt_symbol = f"{self.symbol}.{self.exchange.value}" + self.vt_symbol: str = f"{self.symbol}.{self.exchange.value}" @dataclass @@ -397,15 +397,15 @@ class QuoteRequest: ask_offset: Offset = Offset.NONE reference: str = "" - def __post_init__(self): + def __post_init__(self) -> None: """""" - self.vt_symbol = f"{self.symbol}.{self.exchange.value}" + self.vt_symbol: str = f"{self.symbol}.{self.exchange.value}" def create_quote_data(self, quoteid: str, gateway_name: str) -> QuoteData: """ Create quote data from request. """ - quote = QuoteData( + quote: QuoteData = QuoteData( symbol=self.symbol, exchange=self.exchange, quoteid=quoteid, diff --git a/vnpy/trader/optimize.py b/vnpy/trader/optimize.py index 72bd4e97ab..4681fdc0be 100644 --- a/vnpy/trader/optimize.py +++ b/vnpy/trader/optimize.py @@ -1,6 +1,7 @@ from typing import Dict, List, Callable, Tuple from itertools import product from concurrent.futures import ProcessPoolExecutor +from _collections_abc import dict_keys, dict_values from random import random, choice from time import perf_counter from multiprocessing import Manager, Pool, get_context @@ -63,13 +64,13 @@ def set_target(self, target_name: str) -> None: def generate_settings(self) -> List[dict]: """""" - keys = self.params.keys() - values = self.params.values() - products = list(product(*values)) + keys: dict_keys = self.params.keys() + values: dict_values = self.params.values() + products: list = list(product(*values)) - settings = [] + settings: list = [] for p in products: - setting = dict(zip(keys, p)) + setting: dict = dict(zip(keys, p)) settings.append(setting) return settings @@ -140,8 +141,8 @@ def generate_parameter() -> list: def mutate_individual(individual: list, indpb: float) -> tuple: """""" - size = len(individual) - paramlist = generate_parameter() + size: int = len(individual) + paramlist: list = generate_parameter() for i in range(size): if random() < indpb: individual[i] = paramlist[i] @@ -153,7 +154,7 @@ def mutate_individual(individual: list, indpb: float) -> tuple: cache: Dict[Tuple, Tuple] = manager.dict() # Set up toolbox - toolbox = base.Toolbox() + toolbox: base.Toolbox = base.Toolbox() toolbox.register("individual", tools.initIterate, creator.Individual, generate_parameter) toolbox.register("population", tools.initRepeat, list, toolbox.individual) toolbox.register("mate", tools.cxTwoPoint) diff --git a/vnpy/trader/setting.py b/vnpy/trader/setting.py index 01ebeff64f..27111acd89 100644 --- a/vnpy/trader/setting.py +++ b/vnpy/trader/setting.py @@ -45,5 +45,5 @@ def get_settings(prefix: str = "") -> Dict[str, Any]: - prefix_length = len(prefix) + prefix_length: int = len(prefix) return {k[prefix_length:]: v for k, v in SETTINGS.items() if k.startswith(prefix)} diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index cc68d7247b..8e9f28a6b5 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -2,11 +2,12 @@ General utility functions. """ +from datetime import datetime import json import logging import sys from pathlib import Path -from typing import Callable, Dict, Tuple, Union, Optional +from typing import Callable, Dict, Tuple, Union, Optional, Any from decimal import Decimal from math import floor, ceil @@ -17,7 +18,7 @@ from .constant import Exchange, Interval -log_formatter = logging.Formatter('[%(asctime)s] %(message)s') +log_formatter: logging.Formatter = logging.Formatter('[%(asctime)s] %(message)s') def extract_vt_symbol(vt_symbol: str) -> Tuple[str, Exchange]: @@ -39,8 +40,8 @@ def _get_trader_dir(temp_name: str) -> Tuple[Path, Path]: """ Get path where trader is running in. """ - cwd = Path.cwd() - temp_path = cwd.joinpath(temp_name) + cwd: Path = Path.cwd() + temp_path: Path = cwd.joinpath(temp_name) # If .vntrader folder exists in current working directory, # then use it as trader running path. @@ -48,8 +49,8 @@ def _get_trader_dir(temp_name: str) -> Tuple[Path, Path]: return cwd, temp_path # Otherwise use home path of system. - home_path = Path.home() - temp_path = home_path.joinpath(temp_name) + home_path: Path = Path.home() + temp_path: Path = home_path.joinpath(temp_name) # Create .vntrader folder under home path if not exist. if not temp_path.exists(): @@ -73,7 +74,7 @@ def get_folder_path(folder_name: str) -> Path: """ Get path for temp folder with folder name. """ - folder_path = TEMP_DIR.joinpath(folder_name) + folder_path: Path = TEMP_DIR.joinpath(folder_name) if not folder_path.exists(): folder_path.mkdir() return folder_path @@ -83,8 +84,8 @@ def get_icon_path(filepath: str, ico_name: str) -> str: """ Get path for icon file with ico name. """ - ui_path = Path(filepath).parent - icon_path = ui_path.joinpath("ico", ico_name) + ui_path: Path = Path(filepath).parent + icon_path: Path = ui_path.joinpath("ico", ico_name) return str(icon_path) @@ -92,11 +93,11 @@ def load_json(filename: str) -> dict: """ Load data from json file in temp path. """ - filepath = get_file_path(filename) + filepath: Path = get_file_path(filename) if filepath.exists(): with open(filepath, mode="r", encoding="UTF-8") as f: - data = json.load(f) + data: Any = json.load(f) return data else: save_json(filename, {}) @@ -107,7 +108,7 @@ def save_json(filename: str, data: dict) -> None: """ Save data into json file in temp path. """ - filepath = get_file_path(filename) + filepath: Path = get_file_path(filename) with open(filepath, mode="w+", encoding="UTF-8") as f: json.dump( data, @@ -121,9 +122,9 @@ def round_to(value: float, target: float) -> float: """ Round price to price tick value. """ - value = Decimal(str(value)) - target = Decimal(str(target)) - rounded = float(int(round(value / target)) * target) + value: Decimal = Decimal(str(value)) + target: Decimal = Decimal(str(target)) + rounded: float = float(int(round(value / target)) * target) return rounded @@ -131,9 +132,9 @@ def floor_to(value: float, target: float) -> float: """ Similar to math.floor function, but to target float number. """ - value = Decimal(str(value)) - target = Decimal(str(target)) - result = float(int(floor(value / target)) * target) + value: Decimal = Decimal(str(value)) + target: Decimal = Decimal(str(target)) + result: float = float(int(floor(value / target)) * target) return result @@ -141,9 +142,9 @@ def ceil_to(value: float, target: float) -> float: """ Similar to math.ceil function, but to target float number. """ - value = Decimal(str(value)) - target = Decimal(str(target)) - result = float(int(ceil(value / target)) * target) + value: Decimal = Decimal(str(value)) + target: Decimal = Decimal(str(target)) + result: float = float(int(ceil(value / target)) * target) return result @@ -151,7 +152,7 @@ def get_digits(value: float) -> int: """ Get number of digits after decimal point. """ - value_str = str(value) + value_str: str = str(value) if "e-" in value_str: _, buf = value_str.split("e-") @@ -168,7 +169,6 @@ class BarGenerator: For: 1. generating 1 minute bar data from tick data 2. generating x minute bar/x hour bar data from 1 minute data - Notice: 1. for x minute bar, x must be able to divide 60: 2, 3, 5, 6, 10, 15, 20, 30 2. for x hour bar, x can be any number @@ -180,7 +180,7 @@ def __init__( window: int = 0, on_window_bar: Callable = None, interval: Interval = Interval.MINUTE - ): + ) -> None: """Constructor""" self.bar: BarData = None self.on_bar: Callable = on_bar @@ -200,7 +200,7 @@ def update_tick(self, tick: TickData) -> None: """ Update new tick data into generator. """ - new_minute = False + new_minute: bool = False # Filter tick data with 0 last price if not tick.last_price: @@ -250,10 +250,10 @@ def update_tick(self, tick: TickData) -> None: self.bar.datetime = tick.datetime if self.last_tick: - volume_change = tick.volume - self.last_tick.volume + volume_change: float = tick.volume - self.last_tick.volume self.bar.volume += max(volume_change, 0) - turnover_change = tick.turnover - self.last_tick.turnover + turnover_change: float = tick.turnover - self.last_tick.turnover self.bar.turnover += max(turnover_change, 0) self.last_tick = tick @@ -271,7 +271,7 @@ def update_bar_minute_window(self, bar: BarData) -> None: """""" # If not inited, create window bar object if not self.window_bar: - dt = bar.datetime.replace(second=0, microsecond=0) + dt: datetime = bar.datetime.replace(second=0, microsecond=0) self.window_bar = BarData( symbol=bar.symbol, exchange=bar.exchange, @@ -307,7 +307,7 @@ def update_bar_hour_window(self, bar: BarData) -> None: """""" # If not inited, create window bar object if not self.hour_bar: - dt = bar.datetime.replace(minute=0, second=0, microsecond=0) + dt: datetime = bar.datetime.replace(minute=0, second=0, microsecond=0) self.hour_bar = BarData( symbol=bar.symbol, exchange=bar.exchange, @@ -323,7 +323,7 @@ def update_bar_hour_window(self, bar: BarData) -> None: ) return - finished_bar = None + finished_bar: BarData = None # If minute is 59, update minute bar into window bar and push if bar.datetime.minute == 59: @@ -348,7 +348,7 @@ def update_bar_hour_window(self, bar: BarData) -> None: elif bar.datetime.hour != self.hour_bar.datetime.hour: finished_bar = self.hour_bar - dt = bar.datetime.replace(minute=0, second=0, microsecond=0) + dt: datetime = bar.datetime.replace(minute=0, second=0, microsecond=0) self.hour_bar = BarData( symbol=bar.symbol, exchange=bar.exchange, @@ -422,7 +422,7 @@ def generate(self) -> Optional[BarData]: """ Generate the bar data and call callback immediately. """ - bar = self.bar + bar: BarData = self.bar if self.bar: bar.datetime = bar.datetime.replace(second=0, microsecond=0) @@ -439,7 +439,7 @@ class ArrayManager(object): 2. calculating technical indicator value """ - def __init__(self, size: int = 100): + def __init__(self, size: int = 100) -> None: """Constructor""" self.count: int = 0 self.size: int = size @@ -530,7 +530,7 @@ def sma(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ Simple moving average. """ - result = talib.SMA(self.close, n) + result: np.ndarray = talib.SMA(self.close, n) if array: return result return result[-1] @@ -539,7 +539,7 @@ def ema(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ Exponential moving average. """ - result = talib.EMA(self.close, n) + result: np.ndarray = talib.EMA(self.close, n) if array: return result return result[-1] @@ -548,7 +548,7 @@ def kama(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ KAMA. """ - result = talib.KAMA(self.close, n) + result: np.ndarray = talib.KAMA(self.close, n) if array: return result return result[-1] @@ -557,7 +557,7 @@ def wma(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ WMA. """ - result = talib.WMA(self.close, n) + result: np.ndarray = talib.WMA(self.close, n) if array: return result return result[-1] @@ -572,7 +572,7 @@ def apo( """ APO. """ - result = talib.APO(self.close, fast_period, slow_period, matype) + result: np.ndarray = talib.APO(self.close, fast_period, slow_period, matype) if array: return result return result[-1] @@ -581,7 +581,7 @@ def cmo(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ CMO. """ - result = talib.CMO(self.close, n) + result: np.ndarray = talib.CMO(self.close, n) if array: return result return result[-1] @@ -590,7 +590,7 @@ def mom(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ MOM. """ - result = talib.MOM(self.close, n) + result: np.ndarray = talib.MOM(self.close, n) if array: return result return result[-1] @@ -605,7 +605,7 @@ def ppo( """ PPO. """ - result = talib.PPO(self.close, fast_period, slow_period, matype) + result: np.ndarray = talib.PPO(self.close, fast_period, slow_period, matype) if array: return result return result[-1] @@ -614,7 +614,7 @@ def roc(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ ROC. """ - result = talib.ROC(self.close, n) + result: np.ndarray = talib.ROC(self.close, n) if array: return result return result[-1] @@ -623,7 +623,7 @@ def rocr(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ ROCR. """ - result = talib.ROCR(self.close, n) + result: np.ndarray = talib.ROCR(self.close, n) if array: return result return result[-1] @@ -632,7 +632,7 @@ def rocp(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ ROCP. """ - result = talib.ROCP(self.close, n) + result: np.ndarray = talib.ROCP(self.close, n) if array: return result return result[-1] @@ -641,7 +641,7 @@ def rocr_100(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ ROCR100. """ - result = talib.ROCR100(self.close, n) + result: np.ndarray = talib.ROCR100(self.close, n) if array: return result return result[-1] @@ -650,7 +650,7 @@ def trix(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ TRIX. """ - result = talib.TRIX(self.close, n) + result: np.ndarray = talib.TRIX(self.close, n) if array: return result return result[-1] @@ -659,7 +659,7 @@ def std(self, n: int, nbdev: int = 1, array: bool = False) -> Union[float, np.nd """ Standard deviation. """ - result = talib.STDDEV(self.close, n, nbdev) + result: np.ndarray = talib.STDDEV(self.close, n, nbdev) if array: return result return result[-1] @@ -668,7 +668,7 @@ def obv(self, array: bool = False) -> Union[float, np.ndarray]: """ OBV. """ - result = talib.OBV(self.close, self.volume) + result: np.ndarray = talib.OBV(self.close, self.volume) if array: return result return result[-1] @@ -677,7 +677,7 @@ def cci(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ Commodity Channel Index (CCI). """ - result = talib.CCI(self.high, self.low, self.close, n) + result: np.ndarray = talib.CCI(self.high, self.low, self.close, n) if array: return result return result[-1] @@ -686,7 +686,7 @@ def atr(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ Average True Range (ATR). """ - result = talib.ATR(self.high, self.low, self.close, n) + result: np.ndarray = talib.ATR(self.high, self.low, self.close, n) if array: return result return result[-1] @@ -695,7 +695,7 @@ def natr(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ NATR. """ - result = talib.NATR(self.high, self.low, self.close, n) + result: np.ndarray = talib.NATR(self.high, self.low, self.close, n) if array: return result return result[-1] @@ -704,7 +704,7 @@ def rsi(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ Relative Strenght Index (RSI). """ - result = talib.RSI(self.close, n) + result: np.ndarray = talib.RSI(self.close, n) if array: return result return result[-1] @@ -733,7 +733,7 @@ def adx(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ ADX. """ - result = talib.ADX(self.high, self.low, self.close, n) + result: np.ndarray = talib.ADX(self.high, self.low, self.close, n) if array: return result return result[-1] @@ -742,7 +742,7 @@ def adxr(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ ADXR. """ - result = talib.ADXR(self.high, self.low, self.close, n) + result: np.ndarray = talib.ADXR(self.high, self.low, self.close, n) if array: return result return result[-1] @@ -751,7 +751,7 @@ def dx(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ DX. """ - result = talib.DX(self.high, self.low, self.close, n) + result: np.ndarray = talib.DX(self.high, self.low, self.close, n) if array: return result return result[-1] @@ -760,7 +760,7 @@ def minus_di(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ MINUS_DI. """ - result = talib.MINUS_DI(self.high, self.low, self.close, n) + result: np.ndarray = talib.MINUS_DI(self.high, self.low, self.close, n) if array: return result return result[-1] @@ -769,7 +769,7 @@ def plus_di(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ PLUS_DI. """ - result = talib.PLUS_DI(self.high, self.low, self.close, n) + result: np.ndarray = talib.PLUS_DI(self.high, self.low, self.close, n) if array: return result return result[-1] @@ -778,7 +778,7 @@ def willr(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ WILLR. """ - result = talib.WILLR(self.high, self.low, self.close, n) + result: np.ndarray = talib.WILLR(self.high, self.low, self.close, n) if array: return result return result[-1] @@ -793,7 +793,7 @@ def ultosc( """ Ultimate Oscillator. """ - result = talib.ULTOSC(self.high, self.low, self.close, time_period1, time_period2, time_period3) + result: np.ndarray = talib.ULTOSC(self.high, self.low, self.close, time_period1, time_period2, time_period3) if array: return result return result[-1] @@ -802,7 +802,7 @@ def trange(self, array: bool = False) -> Union[float, np.ndarray]: """ TRANGE. """ - result = talib.TRANGE(self.high, self.low, self.close) + result: np.ndarray = talib.TRANGE(self.high, self.low, self.close) if array: return result return result[-1] @@ -819,11 +819,11 @@ def boll( """ Bollinger Channel. """ - mid = self.sma(n, array) - std = self.std(n, 1, array) + mid: Union[float, np.ndarray] = self.sma(n, array) + std: Union[float, np.ndarray] = self.std(n, 1, array) - up = mid + std * dev - down = mid - std * dev + up: Union[float, np.ndarray] = mid + std * dev + down: Union[float, np.ndarray] = mid - std * dev return up, down @@ -839,11 +839,11 @@ def keltner( """ Keltner Channel. """ - mid = self.sma(n, array) - atr = self.atr(n, array) + mid: Union[float, np.ndarray] = self.sma(n, array) + atr: Union[float, np.ndarray] = self.atr(n, array) - up = mid + atr * dev - down = mid - atr * dev + up: Union[float, np.ndarray] = mid + atr * dev + down: Union[float, np.ndarray] = mid - atr * dev return up, down @@ -856,8 +856,8 @@ def donchian( """ Donchian Channel. """ - up = talib.MAX(self.high, n) - down = talib.MIN(self.low, n) + up: np.ndarray = talib.MAX(self.high, n) + down: np.ndarray = talib.MIN(self.low, n) if array: return up, down @@ -884,7 +884,7 @@ def aroonosc(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ Aroon Oscillator. """ - result = talib.AROONOSC(self.high, self.low, n) + result: np.ndarray = talib.AROONOSC(self.high, self.low, n) if array: return result @@ -894,7 +894,7 @@ def minus_dm(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ MINUS_DM. """ - result = talib.MINUS_DM(self.high, self.low, n) + result: np.ndarray = talib.MINUS_DM(self.high, self.low, n) if array: return result @@ -904,7 +904,7 @@ def plus_dm(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ PLUS_DM. """ - result = talib.PLUS_DM(self.high, self.low, n) + result: np.ndarray = talib.PLUS_DM(self.high, self.low, n) if array: return result @@ -914,7 +914,7 @@ def mfi(self, n: int, array: bool = False) -> Union[float, np.ndarray]: """ Money Flow Index. """ - result = talib.MFI(self.high, self.low, self.close, self.volume, n) + result: np.ndarray = talib.MFI(self.high, self.low, self.close, self.volume, n) if array: return result return result[-1] @@ -923,7 +923,7 @@ def ad(self, array: bool = False) -> Union[float, np.ndarray]: """ AD. """ - result = talib.AD(self.high, self.low, self.close, self.volume) + result: np.ndarray = talib.AD(self.high, self.low, self.close, self.volume) if array: return result return result[-1] @@ -937,7 +937,7 @@ def adosc( """ ADOSC. """ - result = talib.ADOSC(self.high, self.low, self.close, self.volume, fast_period, slow_period) + result: np.ndarray = talib.ADOSC(self.high, self.low, self.close, self.volume, fast_period, slow_period) if array: return result return result[-1] @@ -946,7 +946,7 @@ def bop(self, array: bool = False) -> Union[float, np.ndarray]: """ BOP. """ - result = talib.BOP(self.open, self.high, self.low, self.close) + result: np.ndarray = talib.BOP(self.open, self.high, self.low, self.close) if array: return result @@ -995,7 +995,7 @@ def virtual(func: Callable) -> Callable: def _get_file_logger_handler(filename: str) -> logging.FileHandler: - handler = file_handlers.get(filename, None) + handler: logging.FileHandler = file_handlers.get(filename, None) if handler is None: handler = logging.FileHandler(filename) file_handlers[filename] = handler # Am i need a lock? @@ -1006,8 +1006,8 @@ def get_file_logger(filename: str) -> logging.Logger: """ return a logger that writes records into a file. """ - logger = logging.getLogger(filename) - handler = _get_file_logger_handler(filename) # get singleton handler. + logger: logging.Logger = logging.getLogger(filename) + handler: logging.FileHandler = _get_file_logger_handler(filename) # get singleton handler. handler.setFormatter(log_formatter) logger.addHandler(handler) # each handler will be added only once. return logger