Skip to content

Commit

Permalink
fix type checking/merging
Browse files Browse the repository at this point in the history
  • Loading branch information
alistairewj committed Nov 19, 2019
1 parent 4af86fe commit caec77c
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions pyroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __init__(self, target, preds):

# initialize vars
self.K = len(self.preds)
self.n_obs = len(target)
self.n_pos = np.sum(target == 1)
self.n_obs = len(self.target)
self.n_pos = np.sum(self.target == 1)
self.n_neg = self.n_obs - self.n_pos

# First parse the predictions into matrices X and Y
Expand Down Expand Up @@ -108,22 +108,33 @@ def _parse_inputs(self, preds, target):
"""
if type(preds) is list:
# convert preds into a dictionary
predictors = [x for x in range(len(preds))]
preds = OrderedDict(enumerate(preds))
if len(preds) == len(target):
predictors = [0]
preds = OrderedDict([(0, np.asarray(preds))])
elif hasattr(preds[0], '__len__'):
# convert preds into a dictionary
predictors = [x for x in range(len(preds))]
preds = OrderedDict(
[[i, np.asarray(p)] for i, p in enumerate(preds)]
)
else:
raise TypeError(
'unable to parse preds list with element type %s',
type(preds[0])
)
elif type(preds) is pd.DataFrame:
# preds is a dict - convert to ordered
predictors = list(preds.columns)
preds = OrderedDict(zip(preds.columns, preds.T.values))
elif 'array' in str(type(preds)):
# convert preds into a dictionary
predictors = [0]
preds = OrderedDict([[0, preds]])
preds = OrderedDict([[0, np.asarray(preds)]])
elif type(preds) is dict:
# preds is a dict - convert to ordered
predictors = list(preds.keys())
predictors.sort()
preds = OrderedDict([[c, preds[c]] for c in predictors])
preds = OrderedDict([[c, np.asarray(preds[c])] for c in predictors])
elif type(preds) is not OrderedDict:
raise ValueError(
'Unrecognized type "%s" for predictions.', str(type(preds))
Expand All @@ -135,6 +146,8 @@ def _parse_inputs(self, preds, target):

if type(target) is pd.Series:
target = target.values
elif type(target) in (list, tuple):
target = np.asarray(target)
elif type(target) is not np.ndarray:
raise TypeError(
'target should be type np.ndarray, was %s', type(target)
Expand Down

0 comments on commit caec77c

Please sign in to comment.