Skip to content

Commit

Permalink
CR Comments v2
Browse files Browse the repository at this point in the history
  • Loading branch information
Nadav-Barak committed Jun 4, 2023
1 parent 21484d4 commit be7b50f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 19 deletions.
23 changes: 6 additions & 17 deletions skllm/models/gpt_zero_shot_clf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import random
from abc import ABC, abstractmethod
from collections import Counter
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Union, Literal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -34,8 +34,6 @@ class _BaseZeroShotGPTClassifier(ABC, BaseEstimator, ClassifierMixin, _OAIMixin)
default_label : Optional[Union[List[str], str]] , default : 'Random'
The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
label will be chosen based on probabilities from the training set.
random_state : int, default 42
A seed to set for pseudo-random functions, primarily random selection.
"""

def __init__(
Expand All @@ -44,13 +42,10 @@ def __init__(
openai_org: Optional[str] = None,
openai_model: str = "gpt-3.5-turbo",
default_label: Optional[Union[List[str], str]] = 'Random',
random_state: int = 42,
):
self._set_keys(openai_key, openai_org)
self.openai_model = openai_model
self.default_label = default_label
random.seed(random_state)
np.random.seed(random_state)

def _to_np(self, X):
return _to_numpy(X)
Expand Down Expand Up @@ -121,8 +116,6 @@ class ZeroShotGPTClassifier(_BaseZeroShotGPTClassifier):
default_label : Optional[str] , default : 'Random'
The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
label will be chosen based on probabilities from the training set.
random_state : int, default 42
A seed to set for pseudo-random functions, primarily random selection.
"""

def __init__(
Expand All @@ -131,9 +124,8 @@ def __init__(
openai_org: Optional[str] = None,
openai_model: str = "gpt-3.5-turbo",
default_label: Optional[str] = 'Random',
random_state: int = 42,
):
super().__init__(openai_key, openai_org, openai_model, default_label, random_state)
super().__init__(openai_key, openai_org, openai_model, default_label)

def _extract_labels(self, y: Any) -> List[str]:
if isinstance(y, (pd.Series, np.ndarray)):
Expand Down Expand Up @@ -198,24 +190,21 @@ class MultiLabelZeroShotGPTClassifier(_BaseZeroShotGPTClassifier):
openai_model : str , default : "gpt-3.5-turbo"
The OpenAI model to use. See https://beta.openai.com/docs/api-reference/available-models for a list of
available models.
default_label : Optional[Union[List[str], str]] , default : 'Random'
default_label : Optional[Union[List[str], Literal['Random']] , default : 'Random'
The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
label will be chosen based on probabilities from the training set.
max_labels : int , default : 3
The maximum number of labels to predict for each sample.
random_state : int, default 42
A seed to set for pseudo-random functions, primarily random selection.
"""
def __init__(
self,
openai_key: Optional[str] = None,
openai_org: Optional[str] = None,
openai_model: str = "gpt-3.5-turbo",
default_label: Optional[Union[List[str], str]] = 'Random',
default_label: Optional[Union[List[str], Literal['Random']]] = 'Random',
max_labels: int = 3,
random_state: int = 42,
):
super().__init__(openai_key, openai_org, openai_model, default_label, random_state)
super().__init__(openai_key, openai_org, openai_model, default_label)
if max_labels < 2:
raise ValueError("max_labels should be at least 2")
if isinstance(default_label, str) and default_label != "Random":
Expand Down Expand Up @@ -259,7 +248,7 @@ def _predict_single(self, x):
if len(labels) == 0:
labels = self._get_default_label()
if labels is not None and len(labels) > self.max_labels:
labels = random.choices(labels, k=self.max_labels)
labels = labels[:self.max_labels - 1]
return labels

def fit(
Expand Down
10 changes: 8 additions & 2 deletions tests/test_gpt_zero_shot_clf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import random
import unittest
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -36,7 +37,9 @@ def test_fit_predict_unknown_label_set_default(self, mock_get_chat_completion):
mock_get_chat_completion.return_value.choices[0].message = {"content": json.dumps({"label": "new_class"})}

predictions = clf.predict(["text1", "text2", "text3"])
self.assertEqual(predictions, ["class1", "class1", "class1"]) # Random choice
random.seed(42)
np.random.seed(42)
self.assertEqual(predictions, ["class2", "class1", "class1"]) # Random choice

clf.default_label = "default_class"
predictions = clf.predict(["text1", "text2", "text3"])
Expand All @@ -50,7 +53,8 @@ def test_fit_predict_unknown_label_set_default(self, mock_get_chat_completion):
class TestMultiLabelZeroShotGPTClassifier(unittest.TestCase):

def get_mock_clf_model(self):
clf = MultiLabelZeroShotGPTClassifier(openai_key="mock_key", openai_org="mock_org") # Mock keys
clf = MultiLabelZeroShotGPTClassifier(openai_key="mock_key", openai_org="mock_org",
default_label='Random') # Mock keys
X = np.array(["text1", "text2", "text3"])
y = [["class1", "class2"], ["class1", "class2"],
["class1", "class2"]] # Adjusted y to ensure [0.5, 0.5] probability
Expand All @@ -75,6 +79,8 @@ def test_fit_predict_unknown_label_set_default(self, mock_get_chat_completion):
mock_get_chat_completion.return_value.choices[0].message = {"content": json.dumps({"label": "new_class"})}

predictions = clf.predict(["text1", "text2", "text3"])
random.seed(42)
np.random.seed(42)
self.assertEqual(predictions, [['class1'], [], ['class1', 'class2']]) # Random choice

clf.default_label = ["default_class"]
Expand Down

0 comments on commit be7b50f

Please sign in to comment.