Skip to content

Commit

Permalink
Make stats.RollingQuantile comply with np.quantile(interpolation="lin…
Browse files Browse the repository at this point in the history
…ear") (online-ml#670)
  • Loading branch information
Saulo Martiello Mastelini authored Aug 19, 2021
1 parent cb797fa commit 67b6267
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 30 deletions.
3 changes: 3 additions & 0 deletions docs/releases/unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

- Added `rules.AMRules`

## stats
- Make `stats.RollingQuantile` match the default behavior of NumPy's `quantile` function (linear interpolation).

## tree

- Unified base class structure applied to all tree models.
Expand Down
24 changes: 12 additions & 12 deletions river/stats/iqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,24 @@ class RollingIQR(base.RollingUnivariate, utils.SortedWindow):
>>> rolling_iqr = stats.RollingIQR(
... q_inf=0.25,
... q_sup=0.75,
... window_size=100
... window_size=101
... )
>>> for i in range(0, 1001):
... rolling_iqr = rolling_iqr.update(i)
... if i % 100 == 0:
... print(rolling_iqr.get())
0
50
50
50
50
50
50
50
50
50
50
0.0
50.0
50.0
50.0
50.0
50.0
50.0
50.0
50.0
50.0
50.0
"""

Expand Down
52 changes: 34 additions & 18 deletions river/stats/quantile.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,24 +181,24 @@ class RollingQuantile(base.RollingUnivariate, utils.SortedWindow):
>>> rolling_quantile = stats.RollingQuantile(
... q=.5,
... window_size=100,
... window_size=101,
... )
>>> for i in range(0, 1001):
>>> for i in range(1001):
... rolling_quantile = rolling_quantile.update(i)
... if i % 100 == 0:
... print(rolling_quantile.get())
0
50
150
250
350
450
550
650
750
850
950
0.0
50.0
150.0
250.0
350.0
450.0
550.0
650.0
750.0
850.0
950.0
References
----------
Expand All @@ -209,7 +209,25 @@ class RollingQuantile(base.RollingUnivariate, utils.SortedWindow):
# Store the quantile `q` and precompute the interpolation positions that apply
# once the window holds `window_size` elements.
def __init__(self, q, window_size):
super().__init__(size=window_size)
self.q = q
# NOTE(review): the hunk header (-209,7 +209,25) implies this `self.idx` line is
# the removed side of the diff and the lines below replace it — confirm.
self.idx = int(round(self.q * self.size + 0.5)) - 1
# Fractional position of the q-quantile among `size` elements.
idx = self.q * (self.size - 1)

# Indices of the two neighbouring elements to interpolate between.
self._lower = int(math.floor(idx))
self._higher = self._lower + 1
# Clamp so `_higher` never exceeds the last valid index (happens when q == 1).
if self._higher > self.size - 1:
self._higher = self.size - 1
# Interpolation weight: how far `idx` lies past `_lower`.
self._frac = idx - self._lower

# Return the `(lower, higher, frac)` triple used to interpolate the quantile:
# the two element indices and the weight applied to their difference.
def _prepare(self):
# While fewer than `self.size` elements are held, the positions depend on the
# current length, so they are recomputed on every call.
if len(self) < self.size:
# Fractional position of the q-quantile among the elements currently held.
idx = self.q * (len(self) - 1)
lower = int(math.floor(idx))
higher = lower + 1
# Clamp `higher` to the last valid index.
if higher > len(self) - 1:
higher = len(self) - 1
frac = idx - lower
return lower, higher, frac

# Otherwise reuse the values computed once in `__init__` from `self.size`.
return self._lower, self._higher, self._frac

@property
def window_size(self):
Expand All @@ -220,7 +238,5 @@ def update(self, x):
return self

# Return the current quantile estimate of the window.
def get(self):
# NOTE(review): the hunk header (-220,7 +238,5) implies the next four lines are
# the removed single-index lookup shown by the diff — confirm.
if len(self) < self.size:
idx = int(round(self.q * len(self) + 0.5)) - 1
return self[idx]
return self[self.idx]
# Added side of the diff: linearly interpolate between the elements at `lower`
# and `higher`, weighting their difference by `frac`.
# presumably `self[...]` indexes the window in sorted order (the class derives
# from utils.SortedWindow) — verify.
lower, higher, frac = self._prepare()
return self[lower] + (self[higher] - self[lower]) * frac
20 changes: 20 additions & 0 deletions river/stats/test_.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,26 @@ def test_univariate(stat, func):
(stats.RollingMean(10), statistics.mean),
(stats.RollingVar(3, ddof=0), np.var),
(stats.RollingVar(10, ddof=0), np.var),
(
stats.RollingQuantile(0.0, 10),
functools.partial(np.quantile, q=0.0, interpolation="linear"),
),
(
stats.RollingQuantile(0.25, 10),
functools.partial(np.quantile, q=0.25, interpolation="linear"),
),
(
stats.RollingQuantile(0.5, 10),
functools.partial(np.quantile, q=0.5, interpolation="linear"),
),
(
stats.RollingQuantile(0.75, 10),
functools.partial(np.quantile, q=0.75, interpolation="linear"),
),
(
stats.RollingQuantile(1, 10),
functools.partial(np.quantile, q=1, interpolation="linear"),
),
],
)
def test_rolling_univariate(stat, func):
Expand Down

0 comments on commit 67b6267

Please sign in to comment.