Skip to content

Commit

Permalink
corrected logic of threshold function in multiclass_models.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tmfreiberg committed Nov 27, 2024
1 parent d122b18 commit cd595a6
Showing 1 changed file with 78 additions and 18 deletions.
96 changes: 78 additions & 18 deletions scripts/multiclass_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,26 +862,86 @@ def get_argmax(row: pd.DataFrame,
return inverse_label_codes[dx] # Return the label if label_codes are provided
else:
return dx # Otherwise, return the code itself

def threshold(probabilities: pd.Series,
"""
START: REPAIRING INCORRECT LOGIC OF THRESHOLD FUNCTION BELOW. MODELS WILL NEED TO BE RE-EVALUATED.
"""

# New version of the threshold function.
# The logic is correct but the function has not been tested and models have not been evaluated using this new version.
# Most likely the code will break if the new version of threshold is used, but we'll deal with it then.
# Hopefully this comment will help.

def threshold(probabilities: Union[pd.Series, np.ndarray],
              threshold_dict_help: Union[None, OrderedDict[str, float]],
              threshold_dict_hinder: Union[None, OrderedDict[str, float]],
              prefix: Union[None, str] = None) -> pd.Series:
    """
    Adjusts probabilities based on help and hinder thresholds.

    Every class listed in threshold_dict_help whose probability is strictly
    greater than its threshold is promoted to 1. Afterwards, every class listed
    in threshold_dict_hinder whose probability is strictly less than its
    threshold is demoted to 0. All listed classes are checked independently
    (there is no early exit after the first match). The input is never
    modified; a copy is adjusted and returned.

    NOTE: this version has not been evaluated against the rest of the codebase.

    Parameters:
    probabilities (Union[pd.Series, np.ndarray]): Either a pandas Series holding
        one probability per class, labelled prefix + <class key>, or a numpy
        array that is 1-D (one row of class probabilities) or 2-D
        (rows x classes).
    threshold_dict_help (OrderedDict): Thresholds used to promote probabilities,
        or None to skip promotion.
    threshold_dict_hinder (OrderedDict): Thresholds used to demote probabilities,
        or None to skip demotion.
    prefix (str, optional): Prefix prepended to each class key to build the
        Series label. Defaults to 'prob_'. Ignored for ndarray input.

    Returns:
    pd.Series: Adjusted probabilities with the original index and name when the
        input is a Series; otherwise an adjusted numpy array of the same shape.

    Raises:
    ValueError: If the input is an ndarray and a class key cannot be converted
        to an integer column index.
    IndexError: If the input is an ndarray and a column index is out of range.
    """
    if prefix is None:
        prefix = 'prob_'

    if isinstance(probabilities, pd.Series):
        original_index = probabilities.index
        adjusted = probabilities.to_numpy(copy=True)
        positions = {label: pos for pos, label in enumerate(original_index)}

        def locate(class_):
            # Series entries are matched by label; classes that are absent
            # from the Series are skipped (None) rather than raising.
            return positions.get(f"{prefix}{class_}")
    else:
        original_index = None
        adjusted = np.array(probabilities, copy=True)

        def locate(class_):
            # For plain arrays the dictionary key is the class column index.
            return int(class_)

    # Promote: probability strictly above the help threshold becomes 1.
    if threshold_dict_help is not None:
        for class_, thres in threshold_dict_help.items():
            pos = locate(class_)
            if pos is None:
                continue
            # `...` selects the class column for 2-D input and the single
            # element for 1-D input, so both shapes share one code path.
            current = adjusted[..., pos]
            adjusted[..., pos] = np.where(current > thres, 1, current)

    # Demote: probability strictly below the hinder threshold becomes 0.
    if threshold_dict_hinder is not None:
        for class_, thres in threshold_dict_hinder.items():
            pos = locate(class_)
            if pos is None:
                continue
            current = adjusted[..., pos]
            adjusted[..., pos] = np.where(current < thres, 0, current)

    # Re-wrap Series input so callers get the original labels back.
    if original_index is not None:
        return pd.Series(adjusted, index=original_index, name=probabilities.name)

    return adjusted

# Old version of threshold function. It worked with the rest of the codebase but the logic is incorrect.
# def threshold(probabilities: pd.Series,
# threshold_dict_help: Union[None, OrderedDict[str, float]],
# threshold_dict_hinder: Union[None, OrderedDict[str, float]],
# prefix: Union[None, str] = None) -> pd.Series:
#
# if prefix is None:
# prefix = 'prob_'
#
# if isinstance(threshold_dict_help, OrderedDict):
# for dx, thres in threshold_dict_help.items():
# if prefix + dx in probabilities.index and probabilities[prefix + dx] > thres:
# probabilities[prefix + dx] = 1
# break
# if isinstance(threshold_dict_hinder, OrderedDict):
# for dx, thres in threshold_dict_hinder.items():
# if prefix + dx in probabilities.index and probabilities[prefix + dx] < thres:
# probabilities[prefix + dx] = 0
# break
# return probabilities

"""
END: REPAIRING INCORRECT LOGIC OF THRESHOLD FUNCTION ABOVE. MODELS WILL NEED TO BE RE-EVALUATED.
"""

def aggregate_predictions(df: pd.DataFrame,
label_codes: Dict[int, str],
Expand Down

0 comments on commit cd595a6

Please sign in to comment.