Skip to content

Commit 5051b61

Browse files
committed
REF: Slightly more robust input handling in places
1 parent daa1da9 commit 5051b61

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

backtesting/backtesting.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,15 @@ def init():
127127
value = func(*args, **kwargs)
128128

129129
try:
130+
if isinstance(value, pd.DataFrame):
131+
value = value.values.T
130132
value = np.asarray(value)
131133
except Exception:
132134
raise ValueError('Indicators must return array-like sequences of values')
133135
if value.shape[-1] != len(self._data.Close):
134-
raise ValueError('Indicators must be arrays of same length as `data`')
136+
raise ValueError('Indicators must be (a tuple of) arrays of same length as `data`'
137+
'(data: {}, indicator "{}": {})'.format(len(self._data.Close),
138+
name, value.shape))
135139

136140
if plot and overlay is None:
137141
x = value / self._data.Close
@@ -682,7 +686,7 @@ def run(self, **kwargs) -> pd.Series:
682686

683687
# Skip first few candles where indicators are still "warming up"
684688
# +1 to have at least two entries available
685-
start = 1 + max((np.isnan(indicator).argmin()
689+
start = 1 + max((np.isnan(indicator.astype(float)).argmin()
686690
for _, indicator in indicator_attrs), default=0)
687691

688692
# Disable "invalid value encountered in ..." warnings. Comparison

backtesting/lib.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,10 @@ def SMA(series, n):
197197
self.sma = self.I(SMA, daily, 10, plot=False)
198198
199199
"""
200-
if not isinstance(series, pd.Series):
200+
if not isinstance(series, (pd.Series, pd.DataFrame)):
201201
assert isinstance(series, _Array), \
202-
'resample_apply() takes either a `pd.Series` or a `Strategy.data.*` array'
202+
'resample_apply() takes either a `pd.Series`, `pd.DataFrame`, ' \
203+
'or a `Strategy.data.*` array'
203204
series = series.to_series()
204205

205206
resampled = series.resample(rule, label='right').agg('last').dropna()

0 commit comments

Comments
 (0)