Commit
add reset for token counts
kpister committed Oct 7, 2023
1 parent 772edd2 commit e3e1f0a
Showing 3 changed files with 100 additions and 40 deletions.
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -1,16 +1,16 @@
 [tool.poetry]
 name = "sllim"
-version = "0.1.25"
+version = "0.1.26"
 description = "Fixing the openai api"
 authors = ["Kaiser Pister <[email protected]>"]
 readme = "README.md"
 homepage = "https://github.com/kpister/sllim"

 [tool.poetry.dependencies]
-python = "^3.9"
-openai = "^0.27.8"
-tiktoken = "^0.4.0"
-pyspellchecker = "^0.7.2"
+python = ">= 3.9"
+openai = ">= 0.28"
+tiktoken = ">= 0.4.0"
+pyspellchecker = ">= 0.7.2"

 [tool.poetry.scripts]
 check = "sllim.check:run"
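Note on the dependency change: under Poetry's caret semantics, "^0.27.8" expands to ">=0.27.8,<0.28.0" (and "^3.9" to ">=3.9,<4.0"), so the old pin could never resolve openai 0.28. A quick sketch of that expansion using the packaging library (illustrative only, not a project dependency):

from packaging.specifiers import SpecifierSet
from packaging.version import Version

caret = SpecifierSet(">=0.27.8,<0.28.0")  # what Poetry's "^0.27.8" means
assert Version("0.27.9") in caret
assert Version("0.28.0") not in caret  # why the old pin blocked openai 0.28
assert Version("0.28.0") in SpecifierSet(">=0.28")  # the new, open-ended form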
55 changes: 40 additions & 15 deletions sllim/__init__.py
@@ -23,12 +23,15 @@

 prompt_tokens, completion_tokens = 0, 0

+
 def load_template(filepath: str) -> tuple[str, str]:
     with open(filepath, "r") as f:
         description, text = f.read().split("\n", 1)
     if not description.startswith("#"):
         # Not a proper prompt file
-        logger.warning(f"File {filepath} does not start with a `# description line`.")
+        logger.warning(
+            f"File {filepath} does not start with a `# description line`."
+        )
         text = description + "\n" + text
         description = ""
     return description, text.strip()
@@ -38,11 +41,13 @@ class Message(TypedDict):
     role: str
     content: str

+
 class FunctionT(TypedDict):
     name: str
     description: str
     parameters: dict[str, dict[str, str | dict]]

+
 Prompt = TypeVar("Prompt", str, list[Message])

 API_PARAMS = dict(
@@ -76,6 +81,12 @@ def get_token_counts():
     return prompt_tokens, completion_tokens


+def reset_token_counts():
+    global prompt_tokens, completion_tokens
+
+    prompt_tokens, completion_tokens = 0, 0
+
+
 def try_make(folder_name):
     try:
         os.makedirs(folder_name, exist_ok=True)
@@ -93,6 +104,7 @@ def close(self, *args, **kwargs):
     def read(self, *args, **kwargs):
         return "{}"

+
 @contextmanager
 def try_open(filename, mode="r"):
     try:
@@ -107,7 +119,6 @@ def try_open(filename, mode="r"):
         f.close()


-
 def cache(fn):
     folder_name = ".cache"
     cache_file = os.path.join(folder_name, f"{fn.__name__}.json")
@@ -202,7 +213,9 @@ def chat(
     }
     max_tokens_str = "infinity" if max_tokens is None else str(max_tokens)
     model_str = model if model else deployment_id
-    logger.info(f"Calling {model_str} using at most {max_tokens_str} with messages: {messages}")
+    logger.info(
+        f"Calling {model_str} using at most {max_tokens_str} with messages: {messages}"
+    )

     response = openai.ChatCompletion.create(
         messages=messages,
@@ -263,12 +276,13 @@ def call(
     message = response.choices[0].message
     prompt_tokens += response.usage.prompt_tokens
     completion_tokens += response.usage.completion_tokens
+
     if message.get("function_call"):
         return message.function_call

     raise Exception("No function call found in response. %s" % str(message))


 @catch
 @cache
 def complete(
@@ -311,19 +325,23 @@ def complete(

 @catch
 @cache
-def embed(text: str | list[str], engine="text-embedding-ada-002", deployment_id: str=None) -> list[float] | list[list[float]]:
+def embed(
+    text: str | list[str], engine="text-embedding-ada-002", deployment_id: str = None
+) -> list[float] | list[list[float]]:
     if deployment_id:
         engine = None
-    kwargs = {k: v for k, v in locals().items() if k in ["engine", "deployment_id"] and v is not None}
+    kwargs = {
+        k: v
+        for k, v in locals().items()
+        if k in ["engine", "deployment_id"] and v is not None
+    }
     if isinstance(text, list):
         text = [t.replace("\n", " ") for t in text]
         response = openai.Embedding.create(input=text, **kwargs)["data"]
         return [x["embedding"] for x in response]
     else:
         text = text.replace("\n", " ")
-        return openai.Embedding.create(input=[text], **kwargs)["data"][0][
-            "embedding"
-        ]
+        return openai.Embedding.create(input=[text], **kwargs)["data"][0]["embedding"]


 def estimate(messages_or_prompt: Prompt, model="gpt-3.5-turbo"):
@@ -465,6 +483,7 @@ def map_reduce(template: Prompt, n=8, **kwargs):
     finally:
         collate_caches(params["__function"])

+
 def thread_map_reduce(template: Prompt, n=8, **kwargs):
     params = {}
     for key, value in kwargs.items():
@@ -490,9 +509,12 @@ def thread_map_reduce(template: Prompt, n=8, **kwargs):
     iterator = to_slices(template, iters, constants)

     results = ["" for _ in range(max_len)]
+
     def thread_call(message, idx):
         try:
-            results[idx] = fn(message, **{k: v for k, v in params.items()}, nocache=True)
+            results[idx] = fn(
+                message, **{k: v for k, v in params.items()}, nocache=True
+            )
         except RateLimitError:
             results[idx] = ""

@@ -509,7 +531,7 @@ def thread_call(message, idx):

     for thread in threads:
         thread.join()
-
+
     return results


@@ -534,16 +556,17 @@ def collate_caches(function_name):
     with try_open(cache_file, "w") as w:
         json.dump(cache, w)

+
 def to_type_name(_type: str):
     if "list" in _type or "tuple" in _type:
         return "array", {"items": {"type": "string"}}

-
     return {
         "str": "string",
         "int": "number",
     }.get(_type, _type), {}

+
 def parse_doc(doc: str):
     if not doc:
         return "", {}, []
@@ -576,6 +599,7 @@ def parse_doc(doc: str):

     return fn_description, properties, required

+
 def create_function_call(fn: Callable) -> FunctionT:
     description, properties, required = parse_doc(fn.__doc__)
     return {
@@ -585,10 +609,11 @@ def create_function_call(fn: Callable) -> FunctionT:
             "type": "object",
             "properties": properties,
             "required": required,
-        }
+        },
     }

+
 def format(s: str, **kwargs):
     for key, value in kwargs.items():
         s = s.replace("{" + key + "}", value)
-    return s
\ No newline at end of file
+    return s
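The new reset_token_counts pairs with the existing get_token_counts; both operate on the module-level counters that chat, call, and complete increment. A minimal usage sketch (the message content is hypothetical and the package's default model settings are assumed):

from sllim import chat, get_token_counts, reset_token_counts

chat([{"role": "user", "content": "Say hello"}])  # accumulates usage counts
prompt, completion = get_token_counts()
print(f"{prompt} prompt tokens, {completion} completion tokens")

reset_token_counts()  # zero both counters, e.g. between measured runs
assert get_token_counts() == (0, 0)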
75 changes: 55 additions & 20 deletions sllim/check.py
@@ -12,7 +12,7 @@
 logger = logging.getLogger(__name__)

 spell = SpellChecker()
-spell.word_frequency.load_words(['schema', 'bulleted'])
+spell.word_frequency.load_words(["schema", "bulleted"])


 @dataclass
@@ -25,6 +25,7 @@ class PromptFile:
     extra_variables: list[str] = field(default_factory=list)
     missing_variables: list[str] = field(default_factory=list)

+
 @dataclass
 class PythonFile:
     filename: str
@@ -53,28 +54,38 @@ def gather(directory, ignore=None, ext=".txt"):

     return out_files

+
 def check_spelling(pfile: PromptFile):
     results = spell.unknown(spell.split_words(pfile.text))
-    pfile.typos = list(filter(lambda x: len(x) > 3 and "_" not in x and spell.candidates(x), results))
+    pfile.typos = list(
+        filter(lambda x: len(x) > 3 and "_" not in x and spell.candidates(x), results)
+    )

+
 def load_variables(pfile: PromptFile):
     pfile.variables = set(re.findall(r"\{(\w+)\}", pfile.text))

+
 def check_usage(tfile: PromptFile, pyfiles: list[PythonFile]):
     for pyfile in pyfiles:
         for _, (filename, usages) in pyfile.uses.items():
             if filename not in tfile.filename:
                 continue

             for variables in usages:
-                for variable in (set(variables) - set(tfile.variables)):
-                    logger.warning(f"{os.path.relpath(pyfile.filename)}: Variable {variable} not defined in {os.path.relpath(tfile.filename)}.")
+                for variable in set(variables) - set(tfile.variables):
+                    logger.warning(
+                        f"{os.path.relpath(pyfile.filename)}: Variable {variable} not defined in {os.path.relpath(tfile.filename)}."
+                    )
                     tfile.extra_variables.append(variable)

-                for variable in (set(tfile.variables) - set(variables)):
-                    logger.warning(f"{os.path.relpath(pyfile.filename)}: Variable {variable} not used in {os.path.relpath(tfile.filename)}.")
+                for variable in set(tfile.variables) - set(variables):
+                    logger.warning(
+                        f"{os.path.relpath(pyfile.filename)}: Variable {variable} not used in {os.path.relpath(tfile.filename)}."
+                    )
                     tfile.missing_variables.append(variable)

+
 def check_file(filepath, pyfiles: list[PythonFile]):
     description, prompt = load_template(filepath)

@@ -84,17 +95,18 @@ def check_file(filepath, pyfiles: list[PythonFile]):
     check_usage(pfile, pyfiles)
     return pfile

+
 def load_python_file(filepath):
     try:
         with open(filepath, "r") as f:
             text = f.read()
     except UnicodeDecodeError:
         logger.warning(f"Unable to read {filepath}.")
         return None

     if "sllim" not in text or "load_template" not in text:
         return None

     track_ids = walk(ast.parse(text))
     python_file = PythonFile(filename=filepath, text=text, uses=track_ids)
     return python_file
@@ -112,23 +124,34 @@ def get_arg(node: ast.Call):
     elif isinstance(node.keywords[0], ast.Name):
         return node.keywords[0].id
     return None


+
 def handle_load_template(node: ast.AST):
-    if isinstance(node, ast.Call) and hasattr(node.func, "id") and node.func.id == "load_template":
+    if (
+        isinstance(node, ast.Call)
+        and hasattr(node.func, "id")
+        and node.func.id == "load_template"
+    ):
         return get_arg(node)

     return None

+
 def handle_use_template(ancestor: ast.AST, node: ast.AST, track_ids):
     if isinstance(node, ast.Name) and node.id in track_ids:
-        if isinstance(ancestor, ast.Call) and hasattr(ancestor.func, "id") and ancestor.func.id == "format":
+        if (
+            isinstance(ancestor, ast.Call)
+            and hasattr(ancestor.func, "id")
+            and ancestor.func.id == "format"
+        ):
             return get_arg(ancestor), ancestor.keywords
     elif isinstance(node, ast.Attribute) and node.attr == "format":
         if isinstance(ancestor, ast.Call):
             return ancestor.func.value.id, ancestor.keywords

     return None, None

+
 def walk(tree: ast.AST):
     track_ids = {}
     for ancestor in ast.walk(tree):
@@ -139,13 +162,14 @@ def walk(tree: ast.AST):
             else:
                 tid = ancestor.targets[0].id
             track_ids[tid] = (value, [])
-        else:
+        else:
             key, keywords = handle_use_template(ancestor, child, track_ids)
             if key and keywords and key in track_ids:
                 track_ids[key][1].append([x.arg for x in keywords])

     return track_ids


 def run_checks(text_files, python_files):
     tfiles = []
     pyfiles = []
@@ -155,14 +179,14 @@ def run_checks(text_files, python_files):

     for filepath in text_files:
         tfiles.append(check_file(filepath, pyfiles))
-
+
     spelling_errors = False
     variable_errors = False
     for tfile in tfiles:
         if tfile.typos:
             logger.warning(f"Possible typos in {tfile.filename}: {tfile.typos}")
             spelling_errors = True
-
+
         if tfile.extra_variables or tfile.missing_variables:
             variable_errors = True

@@ -171,17 +195,28 @@ def run():

     if not variable_errors:
         logger.info("No variable errors found.")

+
 def run():
     """Run check command."""
     # All the logic of argparse goes in this function
-    parser = argparse.ArgumentParser(description='Basic prompt checking.')
-    parser.add_argument('directory', type=str, help='the name of the directory to search for prompts', default=".")
-    parser.add_argument('--ignore', type=str, help='a directory or file to exclude', default=None, required=False)
+    parser = argparse.ArgumentParser(description="Basic prompt checking.")
+    parser.add_argument(
+        "directory",
+        type=str,
+        help="the name of the directory to search for prompts",
+        default=".",
+    )
+    parser.add_argument(
+        "--ignore",
+        type=str,
+        help="a directory or file to exclude",
+        default=None,
+        required=False,
+    )

     args = parser.parse_args()
     prompt_files = gather(args.directory, args.ignore, ext=".txt")
     prompt_files += gather(args.directory, args.ignore, ext=".prompt")
     python_files = gather(args.directory, args.ignore, ext=".py")
-    run_checks(prompt_files, python_files)
\ No newline at end of file
+    run_checks(prompt_files, python_files)
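For reference, run() above is the entry point that pyproject.toml wires to the check console script. The same checks can be driven programmatically; a short sketch, assuming it is invoked from a project root containing prompt and Python files:

from sllim.check import gather, run_checks

prompt_files = gather(".", ignore=None, ext=".txt")
prompt_files += gather(".", ignore=None, ext=".prompt")
python_files = gather(".", ignore=None, ext=".py")
run_checks(prompt_files, python_files)  # logs typos and variable mismatches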
