Skip to content

Commit

Permalink
fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ankargren authored and iampelle committed Aug 29, 2022
1 parent 239eb1a commit 7fd9606
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,15 @@ def _sufficient_statistics(self) -> DataFrame:
)
.assign(
**{
ORIGINAL_POINT_ESTIMATE: lambda df: confidence_computers[ZTEST].point_estimate(
df, arg_dict
ORIGINAL_POINT_ESTIMATE: lambda df: df[self._point_estimate_column]
if self._point_estimate_column is not None
else (
confidence_computers[ZTEST].point_estimate(df, arg_dict)
if df[self._method_column].values[0] == ZTESTLINREG
else confidence_computers[df[self._method_column].values[0]].point_estimate(
df, arg_dict
)
)
if df[self._method_column].values[0] == ZTESTLINREG
else confidence_computers[df[self._method_column].values[0]].point_estimate(df, arg_dict)
}
)
.assign(
Expand All @@ -291,9 +295,13 @@ def _sufficient_statistics(self) -> DataFrame:
)
.assign(
**{
ORIGINAL_VARIANCE: lambda df: confidence_computers[ZTEST].variance(df, arg_dict)
if df[self._method_column].values[0] == ZTESTLINREG
else confidence_computers[df[self._method_column].values[0]].variance(df, arg_dict)
ORIGINAL_VARIANCE: lambda df: df[self._var_column]
if self._var_column is not None
else (
confidence_computers[ZTEST].variance(df, arg_dict)
if df[self._method_column].values[0] == ZTESTLINREG
else confidence_computers[df[self._method_column].values[0]].variance(df, arg_dict)
)
}
)
.pipe(
Expand Down Expand Up @@ -528,8 +536,7 @@ def join(df: DataFrame) -> DataFrame:
)
.pipe(
drop_and_rename_columns,
[NULL_HYPOTHESIS, ALTERNATIVE_HYPOTHESIS, f"current_total_{self._denominator}"]
+ ([ORIGINAL_POINT_ESTIMATE] if ORIGINAL_POINT_ESTIMATE in df.columns else []),
[NULL_HYPOTHESIS, ALTERNATIVE_HYPOTHESIS, f"current_total_{self._denominator}"],
)
.assign(**{PREFERENCE_TEST: lambda df: TWO_SIDED if self._correction_method == SPOT_1 else df[PREFERENCE]})
.assign(**{POWER: self._power})
Expand Down Expand Up @@ -1081,7 +1088,7 @@ def _powered_effect_and_required_sample_size_from_difference_df(df: DataFrame, a
z_power=z_power,
binary=binary,
non_inferiority=non_inferiority,
avg_column=ORIGINAL_POINT_ESTIMATE,
avg_column=ORIGINAL_POINT_ESTIMATE + SFX1,
var_column=VARIANCE + SFX1,
)

Expand All @@ -1093,7 +1100,7 @@ def _powered_effect_and_required_sample_size_from_difference_df(df: DataFrame, a
binary=binary,
non_inferiority=non_inferiority,
hypothetical_effect=df[ALTERNATIVE_HYPOTHESIS] - df[NULL_HYPOTHESIS],
control_avg=df[ORIGINAL_POINT_ESTIMATE],
control_avg=df[ORIGINAL_POINT_ESTIMATE + SFX1],
control_var=df[VARIANCE + SFX1],
kappa=kappa,
)
Expand All @@ -1106,7 +1113,7 @@ def _powered_effect_and_required_sample_size_from_difference_df(df: DataFrame, a
binary=binary,
non_inferiority=non_inferiority,
hypothetical_effect=df[ALTERNATIVE_HYPOTHESIS] - df[NULL_HYPOTHESIS],
control_avg=df[ORIGINAL_POINT_ESTIMATE],
control_avg=df[ORIGINAL_POINT_ESTIMATE + SFX1],
control_var=df[VARIANCE + SFX1],
kappa=kappa,
)
Expand Down
10 changes: 4 additions & 6 deletions tests/frequentist/test_ztest_linreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd

import spotify_confidence
from spotify_confidence.analysis.constants import REGRESSION_PARAM, DECREASE_PREFFERED, INCREASE_PREFFERED
from spotify_confidence.analysis.constants import REGRESSION_PARAM, DECREASE_PREFFERED


class TestUnivariateSingleMetric(object):
Expand Down Expand Up @@ -209,9 +209,7 @@ def setup(self):

def test_summary(self):
summary_ztest = self.ztest.summary(verbose=True).drop(["_method"], axis=1)
summary_ztestlinreg = self.ztestlinreg.summary(verbose=True).drop(
["_method", "original_variance", "original_point_estimate"], axis=1
)
summary_ztestlinreg = self.ztestlinreg.summary(verbose=True).drop(["_method"], axis=1)
pd.testing.assert_frame_equal(summary_ztest, summary_ztestlinreg)


Expand Down Expand Up @@ -453,8 +451,8 @@ def setup(self):
data = pd.DataFrame({"variation_name": list(map(str, d)), "metric_name": m, "y": y, "x": x})
data = (
data.assign(xy=y * x)
.assign(x2=x ** 2)
.assign(y2=y ** 2)
.assign(x2=x**2)
.assign(y2=y**2)
.groupby(["variation_name", "metric_name"])
.agg({"y": ["sum", "count"], "y2": "sum", "x": "sum", "x2": "sum", "xy": "sum"})
.reset_index()
Expand Down

0 comments on commit 7fd9606

Please sign in to comment.