Skip to content

Commit

Permalink
Translator updated
Browse files Browse the repository at this point in the history
  • Loading branch information
iryna-kondr committed Jun 18, 2023
1 parent 54c70f4 commit 89c6f4b
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 37 deletions.
48 changes: 23 additions & 25 deletions skllm/openai/base_gpt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, List, Optional, Union
from typing import Any

import numpy as np
import pandas as pd
Expand All @@ -14,24 +14,21 @@
from skllm.utils import to_numpy as _to_numpy


class BaseZeroShotGPTTransformer(_BaseEstimator, _TransformerMixin, _OAIMixin):

class BaseZeroShotGPTTransformer(_BaseEstimator, _TransformerMixin, _OAIMixin):
system_msg = "You are an scikit-learn transformer."
default_output = "Output is unavailable"

def _get_chat_completion(self, X):
"""
Gets the chat completion for the given input using open ai API.
"""Gets the chat completion for the given input using open ai API.
Parameters
----------
X : str
Input string
Returns
-------
str
"""
prompt = self._get_prompt(X)
msgs = []
Expand All @@ -45,10 +42,11 @@ def _get_chat_completion(self, X):
except Exception as e:
print(f"Skipping a sample due to the following error: {str(e)}")
return self.default_output

def fit(self, X: Any = None, y: Any = None, **kwargs: Any) -> BaseZeroShotGPTTransformer:
"""
Fits the model to the data.

def fit(
self, X: Any = None, y: Any = None, **kwargs: Any
) -> BaseZeroShotGPTTransformer:
"""Fits the model to the data.
Parameters
----------
Expand All @@ -60,12 +58,13 @@ def fit(self, X: Any = None, y: Any = None, **kwargs: Any) -> BaseZeroShotGPTTra
-------
self : BaseZeroShotGPTTransformer
"""

return self

def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], **kwargs: Any) -> ndarray:
"""
Converts a list of strings using the open ai API and a predefined prompt.
def transform(
self, X: np.ndarray | pd.Series | list[str], **kwargs: Any
) -> ndarray:
"""Converts a list of strings using the open ai API and a predefined
prompt.
Parameters
----------
Expand All @@ -78,23 +77,22 @@ def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], **kwar
X = _to_numpy(X)
transformed = []
for i in tqdm(range(len(X))):
transformed.append(
self._get_chat_completion(X[i])
)
transformed = np.asarray(transformed, dtype = object)
transformed.append(self._get_chat_completion(X[i]))
transformed = np.asarray(transformed, dtype=object)
return transformed

def fit_transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], y=None, **fit_params) -> ndarray:
"""
Fits and transforms a list of strings using the transform method.
This is modelled to function as the sklearn fit_transform method
def fit_transform(
self, X: np.ndarray | pd.Series | list[str], y=None, **fit_params
) -> ndarray:
"""Fits and transforms a list of strings using the transform method.
This is modelled to function as the sklearn fit_transform method.
Parameters
----------
X : np.ndarray, pd.Series, or list
Returns
-------
ndarray
ndarray
"""
return self.fit(X, y).transform(X)
return self.fit(X, y).transform(X)
14 changes: 13 additions & 1 deletion skllm/preprocessing/gpt_translator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import Optional
from typing import Any, List, Optional, Union

import numpy as np
from numpy import ndarray
from pandas import Series

from skllm.openai.base_gpt import BaseZeroShotGPTTransformer as _BaseGPT
from skllm.prompts.builders import build_translation_prompt
Expand Down Expand Up @@ -35,3 +39,11 @@ def _get_prompt(self, X: str) -> str:
translated sample
"""
return build_translation_prompt(X, self.output_language)

def transform(self, X: Union[ndarray, Series, List[str]], **kwargs: Any) -> ndarray:
    """Translate the samples and remove prompt artifacts from the output.

    The parent ``transform`` is called first; each returned string is then
    cleaned of the literal ``[Translated text:]`` marker and of any triple
    backticks, and surrounding whitespace is stripped.

    Parameters
    ----------
    X : ndarray, Series, or list of str
        Samples passed through unchanged to the parent ``transform``.
    **kwargs : Any
        Extra keyword arguments forwarded to the parent ``transform``.

    Returns
    -------
    ndarray
        Object-dtype array holding one cleaned string per parent output.
    """
    raw_outputs = super().transform(X, **kwargs)
    cleaned = []
    for text in raw_outputs:
        # Drop the label marker first, then the code-fence delimiters.
        without_marker = text.replace("[Translated text:]", "")
        without_fences = without_marker.replace("```", "")
        cleaned.append(without_fences.strip())
    return np.asarray(cleaned, dtype=object)
15 changes: 4 additions & 11 deletions skllm/prompts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,9 @@
"""

TRANSLATION_PROMPT_TEMPLATE = """
You will be provided with an arbitrary text sample, delimited by triple backticks.
Your task is to translate this text to {output_language} language and output the translated text.
If the original text, delimited by triple backticks, is already in {output_language} language, output the original text.
Otherwise, translate the original text, delimited by triple backticks, to {output_language} language, and output the translated text only. Do not output any additional information except the translated text.
Perform the following actions:
1. Determine the language of the text sample.
2. If the language is not {output_language}, translate the text sample to {output_language} language.
3. Output the translated text.
If the text sample provided is not in a recognizable language, output "No translation available".
Do not output any additional information except the translated text.
Text sample: ```{x}```
Translated text:
Original text: ```{x}```
Output:
"""

0 comments on commit 89c6f4b

Please sign in to comment.