Skip to content

Commit

Permalink
Merge pull request vnpy#3285 from noranhe/add-type-declaration-trader
Browse files Browse the repository at this point in the history
[Add] type declaration - vnpy.trader
  • Loading branch information
vnpy authored May 1, 2022
2 parents dfa1f60 + fce3f99 commit 6d9486a
Show file tree
Hide file tree
Showing 10 changed files with 244 additions and 251 deletions.
6 changes: 3 additions & 3 deletions vnpy/trader/app.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
""""""

from abc import ABC
from pathlib import Path

from .engine import BaseEngine


class BaseApp(ABC):
"""
Expand All @@ -13,6 +13,6 @@ class BaseApp(ABC):
app_module: str = "" # App module string used in import_module
app_path: Path = "" # Absolute path of app folder
display_name: str = "" # Name for display on the menu.
engine_class = None # App engine class
engine_class: BaseEngine = None # App engine class
widget_name: str = "" # Class name of app widget
icon_name: str = "" # Icon file name of app widget
101 changes: 50 additions & 51 deletions vnpy/trader/converter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
""""""
from copy import copy
from typing import Dict, List

Expand All @@ -16,7 +15,7 @@
class OffsetConverter:
""""""

def __init__(self, main_engine: MainEngine):
def __init__(self, main_engine: MainEngine) -> None:
""""""
self.main_engine: MainEngine = main_engine
self.holdings: Dict[str, "PositionHolding"] = {}
Expand All @@ -26,38 +25,38 @@ def update_position(self, position: PositionData) -> None:
if not self.is_convert_required(position.vt_symbol):
return

holding = self.get_position_holding(position.vt_symbol)
holding: PositionHolding = self.get_position_holding(position.vt_symbol)
holding.update_position(position)

def update_trade(self, trade: TradeData) -> None:
""""""
if not self.is_convert_required(trade.vt_symbol):
return

holding = self.get_position_holding(trade.vt_symbol)
holding: PositionHolding = self.get_position_holding(trade.vt_symbol)
holding.update_trade(trade)

def update_order(self, order: OrderData) -> None:
""""""
if not self.is_convert_required(order.vt_symbol):
return

holding = self.get_position_holding(order.vt_symbol)
holding: PositionHolding = self.get_position_holding(order.vt_symbol)
holding.update_order(order)

def update_order_request(self, req: OrderRequest, vt_orderid: str) -> None:
""""""
if not self.is_convert_required(req.vt_symbol):
return

holding = self.get_position_holding(req.vt_symbol)
holding: PositionHolding = self.get_position_holding(req.vt_symbol)
holding.update_order_request(req, vt_orderid)

def get_position_holding(self, vt_symbol: str) -> "PositionHolding":
""""""
holding = self.holdings.get(vt_symbol, None)
holding: PositionHolding = self.holdings.get(vt_symbol, None)
if not holding:
contract = self.main_engine.get_contract(vt_symbol)
contract: ContractData = self.main_engine.get_contract(vt_symbol)
holding = PositionHolding(contract)
self.holdings[vt_symbol] = holding
return holding
Expand All @@ -72,7 +71,7 @@ def convert_order_request(
if not self.is_convert_required(req.vt_symbol):
return [req]

holding = self.get_position_holding(req.vt_symbol)
holding: PositionHolding = self.get_position_holding(req.vt_symbol)

if lock:
return holding.convert_order_request_lock(req)
Expand All @@ -87,7 +86,7 @@ def is_convert_required(self, vt_symbol: str) -> bool:
"""
Check if the contract needs offset convert.
"""
contract = self.main_engine.get_contract(vt_symbol)
contract: ContractData = self.main_engine.get_contract(vt_symbol)

# Only contracts with long-short position mode requires convert
if not contract:
Expand All @@ -101,7 +100,7 @@ def is_convert_required(self, vt_symbol: str) -> bool:
class PositionHolding:
""""""

def __init__(self, contract: ContractData):
def __init__(self, contract: ContractData) -> None:
""""""
self.vt_symbol: str = contract.vt_symbol
self.exchange: Exchange = contract.exchange
Expand Down Expand Up @@ -149,7 +148,7 @@ def update_order_request(self, req: OrderRequest, vt_orderid: str) -> None:
""""""
gateway_name, orderid = vt_orderid.split(".")

order = req.create_order_data(orderid, gateway_name)
order: OrderData = req.create_order_data(orderid, gateway_name)
self.update_order(order)

def update_trade(self, trade: TradeData) -> None:
Expand Down Expand Up @@ -208,7 +207,7 @@ def calculate_frozen(self) -> None:
if order.offset == Offset.OPEN:
continue

frozen = order.volume - order.traded
frozen: float = order.volume - order.traded

if order.direction == Direction.LONG:
if order.offset == Offset.CLOSETODAY:
Expand Down Expand Up @@ -255,28 +254,28 @@ def convert_order_request_shfe(self, req: OrderRequest) -> List[OrderRequest]:
return [req]

if req.direction == Direction.LONG:
pos_available = self.short_pos - self.short_pos_frozen
td_available = self.short_td - self.short_td_frozen
pos_available: float = self.short_pos - self.short_pos_frozen
td_available: float = self.short_td - self.short_td_frozen
else:
pos_available = self.long_pos - self.long_pos_frozen
td_available = self.long_td - self.long_td_frozen
pos_available: float = self.long_pos - self.long_pos_frozen
td_available: float = self.long_td - self.long_td_frozen

if req.volume > pos_available:
return []
elif req.volume <= td_available:
req_td = copy(req)
req_td: OrderRequest = copy(req)
req_td.offset = Offset.CLOSETODAY
return [req_td]
else:
req_list = []
req_list: List[OrderRequest] = []

if td_available > 0:
req_td = copy(req)
req_td: OrderRequest = copy(req)
req_td.offset = Offset.CLOSETODAY
req_td.volume = td_available
req_list.append(req_td)

req_yd = copy(req)
req_yd: OrderRequest = copy(req)
req_yd.offset = Offset.CLOSEYESTERDAY
req_yd.volume = req.volume - td_available
req_list.append(req_yd)
Expand All @@ -286,26 +285,26 @@ def convert_order_request_shfe(self, req: OrderRequest) -> List[OrderRequest]:
def convert_order_request_lock(self, req: OrderRequest) -> List[OrderRequest]:
""""""
if req.direction == Direction.LONG:
td_volume = self.short_td
yd_available = self.short_yd - self.short_yd_frozen
td_volume: float = self.short_td
yd_available: float = self.short_yd - self.short_yd_frozen
else:
td_volume = self.long_td
yd_available = self.long_yd - self.long_yd_frozen
td_volume: float = self.long_td
yd_available: float = self.long_yd - self.long_yd_frozen

# If there is td_volume, we can only lock position
if td_volume:
req_open = copy(req)
req_open: OrderRequest = copy(req)
req_open.offset = Offset.OPEN
return [req_open]
# If no td_volume, we close opposite yd position first
# then open new position
else:
close_volume = min(req.volume, yd_available)
open_volume = max(0, req.volume - yd_available)
req_list = []
close_volume: float = min(req.volume, yd_available)
open_volume: float = max(0, req.volume - yd_available)
req_list: List[OrderRequest] = []

if yd_available:
req_yd = copy(req)
req_yd: OrderRequest = copy(req)
if self.exchange in [Exchange.SHFE, Exchange.INE]:
req_yd.offset = Offset.CLOSEYESTERDAY
else:
Expand All @@ -314,7 +313,7 @@ def convert_order_request_lock(self, req: OrderRequest) -> List[OrderRequest]:
req_list.append(req_yd)

if open_volume:
req_open = copy(req)
req_open: OrderRequest = copy(req)
req_open.offset = Offset.OPEN
req_open.volume = open_volume
req_list.append(req_open)
Expand All @@ -324,64 +323,64 @@ def convert_order_request_lock(self, req: OrderRequest) -> List[OrderRequest]:
def convert_order_request_net(self, req: OrderRequest) -> List[OrderRequest]:
""""""
if req.direction == Direction.LONG:
pos_available = self.short_pos - self.short_pos_frozen
td_available = self.short_td - self.short_td_frozen
yd_available = self.short_yd - self.short_yd_frozen
pos_available: float = self.short_pos - self.short_pos_frozen
td_available: float = self.short_td - self.short_td_frozen
yd_available: float = self.short_yd - self.short_yd_frozen
else:
pos_available = self.long_pos - self.long_pos_frozen
td_available = self.long_td - self.long_td_frozen
yd_available = self.long_yd - self.long_yd_frozen
pos_available: float = self.long_pos - self.long_pos_frozen
td_available: float = self.long_td - self.long_td_frozen
yd_available: float = self.long_yd - self.long_yd_frozen

# Split close order to close today/yesterday for SHFE/INE exchange
if req.exchange in {Exchange.SHFE, Exchange.INE}:
reqs = []
volume_left = req.volume
reqs: List[OrderRequest] = []
volume_left: float = req.volume

if td_available:
td_volume = min(td_available, volume_left)
td_volume: float = min(td_available, volume_left)
volume_left -= td_volume

td_req = copy(req)
td_req: OrderRequest = copy(req)
td_req.offset = Offset.CLOSETODAY
td_req.volume = td_volume
reqs.append(td_req)

if volume_left and yd_available:
yd_volume = min(yd_available, volume_left)
yd_volume: float = min(yd_available, volume_left)
volume_left -= yd_volume

yd_req = copy(req)
yd_req: OrderRequest = copy(req)
yd_req.offset = Offset.CLOSEYESTERDAY
yd_req.volume = yd_volume
reqs.append(yd_req)

if volume_left > 0:
open_volume = volume_left
open_volume: float = volume_left

open_req = copy(req)
open_req: OrderRequest = copy(req)
open_req.offset = Offset.OPEN
open_req.volume = open_volume
reqs.append(open_req)

return reqs
# Just use close for other exchanges
else:
reqs = []
volume_left = req.volume
reqs: List[OrderRequest] = []
volume_left: float = req.volume

if pos_available:
close_volume = min(pos_available, volume_left)
close_volume: float = min(pos_available, volume_left)
volume_left -= pos_available

close_req = copy(req)
close_req: OrderRequest = copy(req)
close_req.offset = Offset.CLOSE
close_req.volume = close_volume
reqs.append(close_req)

if volume_left > 0:
open_volume = volume_left
open_volume: float = volume_left

open_req = copy(req)
open_req: OrderRequest = copy(req)
open_req.offset = Offset.OPEN
open_req.volume = open_volume
reqs.append(open_req)
Expand Down
7 changes: 4 additions & 3 deletions vnpy/trader/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from datetime import datetime
from types import ModuleType
from typing import List
from pytz import timezone
from dataclasses import dataclass
Expand All @@ -17,7 +18,7 @@ def convert_tz(dt: datetime) -> datetime:
"""
Convert timezone of datetime object to DB_TZ.
"""
dt = dt.astimezone(DB_TZ)
dt: datetime = dt.astimezone(DB_TZ)
return dt.replace(tzinfo=None)


Expand Down Expand Up @@ -128,10 +129,10 @@ def get_database() -> BaseDatabase:

# Try to import database module
try:
module = import_module(module_name)
module: ModuleType = import_module(module_name)
except ModuleNotFoundError:
print(f"找不到数据库驱动{module_name},使用默认的SQLite数据库")
module = import_module("vnpy_sqlite")
module: ModuleType = import_module("vnpy_sqlite")

# Create database object from module
database = module.Database()
Expand Down
5 changes: 3 additions & 2 deletions vnpy/trader/datafeed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC
from types import ModuleType
from typing import Optional, List
from importlib import import_module

Expand Down Expand Up @@ -46,10 +47,10 @@ def get_datafeed() -> BaseDatafeed:

# Try to import datafeed module
try:
module = import_module(module_name)
module: ModuleType = import_module(module_name)
except ModuleNotFoundError:
print(f"找不到数据服务驱动{module_name},使用默认的RQData数据服务")
module = import_module("vnpy_rqdata")
module: ModuleType = import_module("vnpy_rqdata")

# Create datafeed object from module
datafeed = module.Datafeed()
Expand Down
Loading

0 comments on commit 6d9486a

Please sign in to comment.