Skip to content

Commit

Permalink
Updated testcases to use sequence like operations on dataseries. Impr…
Browse files Browse the repository at this point in the history
…oved __getitem__ on dataseries.
  • Loading branch information
Gabriel Becedillas authored and Gabriel Becedillas committed Dec 18, 2012
1 parent bfa831a commit 5ebf4e8
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 32 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Version 0.10 (TBD)
. [NEW] Returns analysis (pyalgotrade.stratanalyzer.returns.Returns).
. [NEW] Trades analysis (pyalgotrade.stratanalyzer.trades.Trades).
. [NEW] Support for bar timezones (pyalgotrade.marketsession).
. [NEW] Support for sequence like operations in dataseries.
. [NEW] Support for sequence like operations in dataseries (getValue and getValueAbsolute will soon get deprecated).
. [FIX] Fixed returns calculations for short positions. Thanks John Fawcett for explaining this.
. [FIX] Fixed a bug in the Strategy class and the backtesting broker when dealing with multiple instruments. Was not handling absence of bars appropriately. Thanks Fabian Braennstroem for reporting this.

Expand Down
6 changes: 4 additions & 2 deletions pyalgotrade/dataseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

import bar

class DataSeries:
# It is important to inherit object to get __getitem__ to work properly.
# Check http://code.activestate.com/lists/python-list/621258/
class DataSeries(object):
"""Base class for data series. A data series is an abstraction used to manage historical data.
.. note::
Expand All @@ -39,7 +41,7 @@ def __getitem__(self, key):
elif isinstance(key, int) :
if key < 0:
key += self.getLength()
if key >= self.getLength():
if key >= self.getLength() or key < 0:
raise IndexError("Index out of range")
return self.getValueAbsolute(key)
else:
Expand Down
18 changes: 14 additions & 4 deletions testcases/dataseries_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ def testEmpty(self):
ds = dataseries.SequenceDataSeries([])
self.assertTrue(ds.getFirstValidPos() == 0)
self.assertTrue(ds.getLength() == 0)
self.assertTrue(ds.getValue() == None)
self.assertTrue(ds.getValue(1) == None)
with self.assertRaises(IndexError):
ds[-1]
with self.assertRaises(IndexError):
ds[-2]
with self.assertRaises(IndexError):
ds[0]
with self.assertRaises(IndexError):
Expand All @@ -40,8 +42,8 @@ def testNonEmpty(self):
ds = dataseries.SequenceDataSeries(range(10))
self.assertTrue(ds.getFirstValidPos() == 0)
self.assertTrue(ds.getLength() == 10)
self.assertTrue(ds.getValue() == 9)
self.assertTrue(ds.getValue(1) == 8)
self.assertTrue(ds[-1] == 9)
self.assertTrue(ds[-2] == 8)
self.assertTrue(ds[0] == 0)
self.assertTrue(ds[1] == 1)

Expand Down Expand Up @@ -77,6 +79,13 @@ def testSeqLikeOps(self):
sl = slice(0,-1,1)
self.assertEqual(ds[sl], seq[sl])

for i in xrange(-100, 100):
self.assertEqual(ds[i:], seq[i:])

for step in xrange(1, 10):
for i in xrange(-100, 100):
self.assertEqual(ds[i::step], seq[i::step])

class TestBarDataSeries(unittest.TestCase):
def testEmpty(self):
ds = dataseries.BarDataSeries()
Expand Down Expand Up @@ -140,6 +149,7 @@ def testSeqLikeOps(self):

def getTestCases():
ret = []

ret.append(TestSequenceDataSeries("testEmpty"))
ret.append(TestSequenceDataSeries("testNonEmpty"))
ret.append(TestSequenceDataSeries("testSeqLikeOps"))
Expand Down
6 changes: 3 additions & 3 deletions testcases/multi_instrument_strategy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ def onExitOk(self, position):

def __calculatePosSize(self):
cash = self.getBroker().getCash()
lastPrice = self.getFeed().getDataSeries(self.__lag).getValue().getClose()
lastPrice = self.getFeed().getDataSeries(self.__lag)[-1].getClose()
ret = cash / lastPrice
return int(ret)

def onBars(self, bars):
if bars.getBar(self.__lead):
if self.__crossAbove.getValue() == 1 and self.__pos == None:
if self.__crossAbove[-1] == 1 and self.__pos == None:
shares = self.__calculatePosSize()
if shares:
self.__pos = self.enterLong(self.__lag, shares)
elif self.__crossBelow.getValue() == 1 and self.__pos != None:
elif self.__crossBelow[-1] == 1 and self.__pos != None:
self.exitPosition(self.__pos)

class TestCase(unittest.TestCase):
Expand Down
6 changes: 3 additions & 3 deletions testcases/smacrossover_strategy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ def onBars(self, bars):
self.printDebug("%s: O=%s H=%s L=%s C=%s" % (bar.getDateTime(), bar.getOpen(), bar.getHigh(), bar.getLow(), bar.getClose()))

# Wait for enough bars to be available.
if self.__crossAbove.getValue() is None or self.__crossBelow.getValue() is None:
if self.__crossAbove[-1] is None or self.__crossBelow[-1] is None:
return

if self.__crossAbove.getValue() == 1:
if self.__crossAbove[-1] == 1:
if self.__shortPos:
self.exitShortPosition(bars, self.__shortPos)
assert(self.__longPos == None)
self.__longPos = self.enterLongPosition(bars)
elif self.__crossBelow.getValue() == 1:
elif self.__crossBelow[-1] == 1:
if self.__longPos:
self.exitLongPosition(bars, self.__longPos)
assert(self.__shortPos == None)
Expand Down
14 changes: 7 additions & 7 deletions testcases/technical_cross_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def testCrossAboveOnce(self):

# Check for all values.
crs = self.__buildCrossTechnical(cross.CrossAbove, values1, values2, 100)
self.assertTrue(crs.getValue() == 1)
self.assertTrue(crs[-1] == 1)

def testCrossAboveMany(self):
count = 100
Expand All @@ -80,7 +80,7 @@ def testCrossAboveMany(self):

# Check for all values.
crs = self.__buildCrossTechnical(cross.CrossAbove, values1, values2, 100)
self.assertTrue(crs.getValue() == count / 2)
self.assertTrue(crs[-1] == count / 2)

def testCrossBelowOnce(self):
values1 = [1, 1, 1, 10, 1, 1, 1]
Expand All @@ -105,7 +105,7 @@ def testCrossBelowOnce(self):

# Check for all values.
crs = self.__buildCrossTechnical(cross.CrossBelow, values1, values2, 100)
self.assertTrue(crs.getValue() == 1)
self.assertTrue(crs[-1] == 1)

def testCrossBelowMany(self):
count = 100
Expand All @@ -131,7 +131,7 @@ def testCrossBelowMany(self):

# Check for all values.
crs = self.__buildCrossTechnical(cross.CrossBelow, values1, values2, 100)
self.assertTrue(crs.getValue() == count / 2 - 1)
self.assertTrue(crs[-1] == count / 2 - 1)

def testWithSMAs(self):
ds1 = dataseries.SequenceDataSeries()
Expand All @@ -141,11 +141,11 @@ def testWithSMAs(self):
ds1.appendValue(i)
ds2.appendValue(50)
if i < 24:
self.assertTrue(crs.getValue() == None)
self.assertTrue(crs[-1] == None)
elif i == 58:
self.assertTrue(crs.getValue() == 1)
self.assertTrue(crs[-1] == 1)
else:
self.assertTrue(crs.getValue() == 0)
self.assertTrue(crs[-1] == 0)

def getTestCases():
ret = []
Expand Down
43 changes: 38 additions & 5 deletions testcases/technical_ma_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ def __buildSMA(self, period, values):

def testPeriod1(self):
sma = self.__buildSMA(1, [10, 20])

self.assertTrue(sma[0] == 10)
self.assertTrue(sma[1] == 20)
self.assertTrue(sma[-1] == 20)
self.assertTrue(sma[-2] == 10)
with self.assertRaises(IndexError):
sma[2]

self.assertTrue(sma.getValue(-1) == None)
self.assertTrue(sma.getValue() == 20)
self.assertTrue(sma.getValue(1) == 10)
self.assertTrue(sma.getValue(2) == None)
with self.assertRaises(IndexError):
sma[-3]

def testPeriod2(self):
sma = self.__buildSMA(2, [0, 1, 2])
Expand All @@ -50,7 +50,7 @@ def testPeriod2(self):

self.assertTrue(sma[2] == sma.getValue())
self.assertTrue(sma[1] == sma.getValue(1))
self.assertTrue(sma[0] == sma.getValue(2))
self.assertTrue(sma[0] == sma.getValue(2) == None)

def testMultipleValues(self):
period = 5
Expand All @@ -76,6 +76,37 @@ def testStockChartsSMA(self):
def testNinjaTraderSMA(self):
common.test_from_csv(self, "nt-sma-15.csv", lambda inputDS: ma.SMA(inputDS, 15), 3)

def testSeqLikeOps(self):
# ds and seq should be the same.
seq = [1.0 for i in xrange(10)]
ds = self.__buildSMA(1, seq)

# Test length and every item.
self.assertEqual(len(ds), len(seq))
for i in xrange(len(seq)):
self.assertEqual(ds[i], seq[i])

# Test negative indices
self.assertEqual(ds[-1], seq[-1])
self.assertEqual(ds[-2], seq[-2])
self.assertEqual(ds[-9], seq[-9])

# Test slices
sl = slice(0,1,2)
self.assertEqual(ds[sl], seq[sl])
sl = slice(0,9,2)
self.assertEqual(ds[sl], seq[sl])
sl = slice(0,-1,1)
self.assertEqual(ds[sl], seq[sl])

for i in xrange(-100, 100):
self.assertEqual(ds[i:], seq[i:])

for step in xrange(1, 10):
for i in xrange(-100, 100):
self.assertEqual(ds[i::step], seq[i::step])


class WMATestCase(unittest.TestCase):
def __buildWMA(self, weights, values):
from pyalgotrade import dataseries
Expand Down Expand Up @@ -107,12 +138,14 @@ def testStockChartsEMA_Reverse(self):

def getTestCases():
ret = []

ret.append(SMATestCase("testPeriod1"))
ret.append(SMATestCase("testPeriod2"))
ret.append(SMATestCase("testMultipleValues"))
ret.append(SMATestCase("testStockChartsSMA"))
ret.append(SMATestCase("testMultipleValuesSkippingOne"))
ret.append(SMATestCase("testNinjaTraderSMA"))
ret.append(SMATestCase("testSeqLikeOps"))

ret.append(WMATestCase("testPeriod1"))
ret.append(WMATestCase("testPeriod2"))
Expand Down
8 changes: 4 additions & 4 deletions testcases/technical_ratio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def testSimple(self):
with self.assertRaises(IndexError):
ratio[3]

self.assertTrue(ratio.getValue(1) == ratio[1])
self.assertTrue(ratio.getValue() == ratio[2])
self.assertTrue(ratio[-2] == ratio[1])
self.assertTrue(ratio[-1] == ratio[2])

def testNegativeValues(self):
ratio = self.__buildRatio([-1, -2, -1])
Expand All @@ -47,8 +47,8 @@ def testNegativeValues(self):
with self.assertRaises(IndexError):
ratio[3]

self.assertTrue(ratio.getValue(1) == ratio[1])
self.assertTrue(ratio.getValue() == ratio[2])
self.assertTrue(ratio[-2] == ratio[1])
self.assertTrue(ratio[-1] == ratio[2])

def getTestCases():
ret = []
Expand Down
6 changes: 3 additions & 3 deletions testcases/technical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def testInvalidPosNotCached(self):
values.append(None) # Interleave Nones.

testFilter = DataSeriesFilterTest.TestFilter(ds)
self.assertTrue(testFilter.getValue() == None)
self.assertTrue(testFilter.getValue(1) == 9)
self.assertTrue(testFilter.getValue(3) == 8) # We go 3 instead of 2 because we need to skip the interleaved None values.
self.assertTrue(testFilter[-1] == None)
self.assertTrue(testFilter[-2] == 9)
self.assertTrue(testFilter[-4] == 8) # We go 3 instead of 2 because we need to skip the interleaved None values.

self.assertTrue(testFilter[18] == 9)
self.assertTrue(testFilter[19] == None)
Expand Down

0 comments on commit 5ebf4e8

Please sign in to comment.