forked from BeastByteAI/scikit-llm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgpt_translator.py
49 lines (39 loc) · 1.37 KB
/
gpt_translator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from typing import Any, List, Optional, Union
import numpy as np
from numpy import ndarray
from pandas import Series
from skllm.openai.base_gpt import BaseZeroShotGPTTransformer as _BaseGPT
from skllm.prompts.builders import build_translation_prompt
class GPTTranslator(_BaseGPT):
    """A text translator.

    Translates each input sample into ``output_language``. The request to the
    model is issued by the base class (``_BaseGPT.transform``); this subclass
    supplies the translation prompt and cleans up the returned text.

    Parameters
    ----------
    openai_key : Optional[str], default=None
        OpenAI API key, forwarded to the base-class helper ``_set_keys``.
    openai_org : Optional[str], default=None
        OpenAI organization ID, forwarded to ``_set_keys``.
    openai_model : str, default="gpt-3.5-turbo"
        Name of the model to query.
    output_language : str, default="English"
        Language the samples are translated into.
    """

    # System message for the chat model; presumably read by the base class
    # when building requests — confirm in _BaseGPT.
    system_msg = "You are a text translator."
    # Presumably the value the base class substitutes when a request yields
    # no usable answer — NOTE(review): confirm in _BaseGPT.
    default_output = "Translation is unavailable."

    def __init__(
        self,
        openai_key: Optional[str] = None,
        openai_org: Optional[str] = None,
        openai_model: str = "gpt-3.5-turbo",
        output_language: str = "English",
    ) -> None:
        self._set_keys(openai_key, openai_org)
        self.openai_model = openai_model
        self.output_language = output_language

    def _get_prompt(self, X: str) -> str:
        """Generates the translation prompt for a single input sample.

        Parameters
        ----------
        X : str
            sample to translate

        Returns
        -------
        str
            the prompt asking the model to translate ``X`` into
            ``self.output_language`` (not the translation itself)
        """
        return build_translation_prompt(X, self.output_language)

    def transform(self, X: Union[ndarray, Series, List[str]], **kwargs: Any) -> ndarray:
        """Translate every sample in ``X``.

        Parameters
        ----------
        X : Union[ndarray, Series, List[str]]
            Samples to translate.
        **kwargs : Any
            Forwarded unchanged to the base-class ``transform``.

        Returns
        -------
        ndarray
            1-D object array of translated strings. Every occurrence of
            ``"[Translated text:]"`` and of the Markdown fence ``"```"`` is
            removed, and surrounding whitespace is stripped.
        """
        raw = super().transform(X, **kwargs)
        # Drop artifacts the model may include in its reply: the
        # "[Translated text:]" label (presumably echoed from the prompt
        # template) and Markdown code fences. Strip only after both are gone
        # so whitespace left around them is removed too.
        cleaned = [
            text.replace("[Translated text:]", "").replace("```", "").strip()
            for text in raw
        ]
        return np.asarray(cleaned, dtype=object)