diff --git a/pyproject.toml b/pyproject.toml index 2aaa1c0..c4fbbca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,16 +1,16 @@ [tool.poetry] name = "sllim" -version = "0.1.25" +version = "0.1.26" description = "Fixing the openai api" authors = ["Kaiser Pister "] readme = "README.md" homepage = "https://github.com/kpister/sllim" [tool.poetry.dependencies] -python = "^3.9" -openai = "^0.27.8" -tiktoken = "^0.4.0" -pyspellchecker = "^0.7.2" +python = ">= 3.9" +openai = ">= 0.28" +tiktoken = ">= 0.4.0" +pyspellchecker = ">= 0.7.2" [tool.poetry.scripts] check = "sllim.check:run" diff --git a/sllim/__init__.py b/sllim/__init__.py index 22794f0..c330625 100644 --- a/sllim/__init__.py +++ b/sllim/__init__.py @@ -23,12 +23,15 @@ prompt_tokens, completion_tokens = 0, 0 + def load_template(filepath: str) -> tuple[str, str]: with open(filepath, "r") as f: description, text = f.read().split("\n", 1) if not description.startswith("#"): # Not a proper prompt file - logger.warning(f"File {filepath} does not start with a `# description line`.") + logger.warning( + f"File {filepath} does not start with a `# description line`." + ) text = description + "\n" + text description = "" return description, text.strip() @@ -38,11 +41,13 @@ class Message(TypedDict): role: str content: str + class FunctionT(TypedDict): name: str description: str parameters: dict[str, dict[str, str | dict]] + Prompt = TypeVar("Prompt", str, list[Message]) API_PARAMS = dict( @@ -76,6 +81,12 @@ def get_token_counts(): return prompt_tokens, completion_tokens +def reset_token_counts(): + global prompt_tokens, completion_tokens + + prompt_tokens, completion_tokens = 0, 0 + + def try_make(folder_name): try: os.makedirs(folder_name, exist_ok=True) @@ -93,6 +104,7 @@ def close(self, *args, **kwargs): def read(self, *args, **kwargs): return "{}" + @contextmanager def try_open(filename, mode="r"): try: @@ -107,7 +119,6 @@ def try_open(filename, mode="r"): f.close() - def cache(fn): folder_name = ".cache" cache_file = os.path.join(folder_name, f"{fn.__name__}.json") @@ -202,7 +213,9 @@ def chat( } max_tokens_str = "infinity" if max_tokens is None else str(max_tokens) model_str = model if model else deployment_id - logger.info(f"Calling {model_str} using at most {max_tokens_str} with messages: {messages}") + logger.info( + f"Calling {model_str} using at most {max_tokens_str} with messages: {messages}" + ) response = openai.ChatCompletion.create( messages=messages, @@ -263,12 +276,13 @@ def call( message = response.choices[0].message prompt_tokens += response.usage.prompt_tokens completion_tokens += response.usage.completion_tokens - + if message.get("function_call"): return message.function_call - + raise Exception("No function call found in response. %s" % str(message)) + @catch @cache def complete( @@ -311,19 +325,23 @@ def complete( @catch @cache -def embed(text: str | list[str], engine="text-embedding-ada-002", deployment_id: str=None) -> list[float] | list[list[float]]: +def embed( + text: str | list[str], engine="text-embedding-ada-002", deployment_id: str = None +) -> list[float] | list[list[float]]: if deployment_id: engine = None - kwargs = {k: v for k, v in locals().items() if k in ["engine", "deployment_id"] and v is not None} + kwargs = { + k: v + for k, v in locals().items() + if k in ["engine", "deployment_id"] and v is not None + } if isinstance(text, list): text = [t.replace("\n", " ") for t in text] response = openai.Embedding.create(input=text, **kwargs)["data"] return [x["embedding"] for x in response] else: text = text.replace("\n", " ") - return openai.Embedding.create(input=[text], **kwargs)["data"][0][ - "embedding" - ] + return openai.Embedding.create(input=[text], **kwargs)["data"][0]["embedding"] def estimate(messages_or_prompt: Prompt, model="gpt-3.5-turbo"): @@ -465,6 +483,7 @@ def map_reduce(template: Prompt, n=8, **kwargs): finally: collate_caches(params["__function"]) + def thread_map_reduce(template: Prompt, n=8, **kwargs): params = {} for key, value in kwargs.items(): @@ -490,9 +509,12 @@ def thread_map_reduce(template: Prompt, n=8, **kwargs): iterator = to_slices(template, iters, constants) results = ["" for _ in range(max_len)] + def thread_call(message, idx): try: - results[idx] = fn(message, **{k: v for k, v in params.items()}, nocache=True) + results[idx] = fn( + message, **{k: v for k, v in params.items()}, nocache=True + ) except RateLimitError: results[idx] = "" @@ -509,7 +531,7 @@ def thread_call(message, idx): for thread in threads: thread.join() - + return results @@ -534,16 +556,17 @@ def collate_caches(function_name): with try_open(cache_file, "w") as w: json.dump(cache, w) + def to_type_name(_type: str): if "list" in _type or "tuple" in _type: return "array", {"items": {"type": "string"}} - return { "str": "string", "int": "number", }.get(_type, _type), {} + def parse_doc(doc: str): if not doc: return "", {}, [] @@ -576,6 +599,7 @@ def parse_doc(doc: str): return fn_description, properties, required + def create_function_call(fn: Callable) -> FunctionT: description, properties, required = parse_doc(fn.__doc__) return { @@ -585,10 +609,11 @@ def create_function_call(fn: Callable) -> FunctionT: "type": "object", "properties": properties, "required": required, - } + }, } + def format(s: str, **kwargs): for key, value in kwargs.items(): s = s.replace("{" + key + "}", value) - return s \ No newline at end of file + return s diff --git a/sllim/check.py b/sllim/check.py index b6bcd9c..099e386 100644 --- a/sllim/check.py +++ b/sllim/check.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) spell = SpellChecker() -spell.word_frequency.load_words(['schema', 'bulleted']) +spell.word_frequency.load_words(["schema", "bulleted"]) @dataclass @@ -25,6 +25,7 @@ class PromptFile: extra_variables: list[str] = field(default_factory=list) missing_variables: list[str] = field(default_factory=list) + @dataclass class PythonFile: filename: str @@ -53,13 +54,18 @@ def gather(directory, ignore=None, ext=".txt"): return out_files + def check_spelling(pfile: PromptFile): results = spell.unknown(spell.split_words(pfile.text)) - pfile.typos = list(filter(lambda x: len(x) > 3 and "_" not in x and spell.candidates(x), results)) + pfile.typos = list( + filter(lambda x: len(x) > 3 and "_" not in x and spell.candidates(x), results) + ) + def load_variables(pfile: PromptFile): pfile.variables = set(re.findall(r"\{(\w+)\}", pfile.text)) + def check_usage(tfile: PromptFile, pyfiles: list[PythonFile]): for pyfile in pyfiles: for _, (filename, usages) in pyfile.uses.items(): @@ -67,14 +73,19 @@ def check_usage(tfile: PromptFile, pyfiles: list[PythonFile]): continue for variables in usages: - for variable in (set(variables) - set(tfile.variables)): - logger.warning(f"{os.path.relpath(pyfile.filename)}: Variable {variable} not defined in {os.path.relpath(tfile.filename)}.") + for variable in set(variables) - set(tfile.variables): + logger.warning( + f"{os.path.relpath(pyfile.filename)}: Variable {variable} not defined in {os.path.relpath(tfile.filename)}." + ) tfile.extra_variables.append(variable) - - for variable in (set(tfile.variables) - set(variables)): - logger.warning(f"{os.path.relpath(pyfile.filename)}: Variable {variable} not used in {os.path.relpath(tfile.filename)}.") + + for variable in set(tfile.variables) - set(variables): + logger.warning( + f"{os.path.relpath(pyfile.filename)}: Variable {variable} not used in {os.path.relpath(tfile.filename)}." + ) tfile.missing_variables.append(variable) + def check_file(filepath, pyfiles: list[PythonFile]): description, prompt = load_template(filepath) @@ -84,6 +95,7 @@ def check_file(filepath, pyfiles: list[PythonFile]): check_usage(pfile, pyfiles) return pfile + def load_python_file(filepath): try: with open(filepath, "r") as f: @@ -91,10 +103,10 @@ def load_python_file(filepath): except UnicodeDecodeError: logger.warning(f"Unable to read {filepath}.") return None - + if "sllim" not in text or "load_template" not in text: return None - + track_ids = walk(ast.parse(text)) python_file = PythonFile(filename=filepath, text=text, uses=track_ids) return python_file @@ -112,16 +124,26 @@ def get_arg(node: ast.Call): elif isinstance(node.keywords[0], ast.Name): return node.keywords[0].id return None - + + def handle_load_template(node: ast.AST): - if isinstance(node, ast.Call) and hasattr(node.func, "id") and node.func.id == "load_template": + if ( + isinstance(node, ast.Call) + and hasattr(node.func, "id") + and node.func.id == "load_template" + ): return get_arg(node) return None + def handle_use_template(ancestor: ast.AST, node: ast.AST, track_ids): if isinstance(node, ast.Name) and node.id in track_ids: - if isinstance(ancestor, ast.Call) and hasattr(ancestor.func, "id") and ancestor.func.id == "format": + if ( + isinstance(ancestor, ast.Call) + and hasattr(ancestor.func, "id") + and ancestor.func.id == "format" + ): return get_arg(ancestor), ancestor.keywords elif isinstance(node, ast.Attribute) and node.attr == "format": if isinstance(ancestor, ast.Call): @@ -129,6 +151,7 @@ def handle_use_template(ancestor: ast.AST, node: ast.AST, track_ids): return None, None + def walk(tree: ast.AST): track_ids = {} for ancestor in ast.walk(tree): @@ -139,13 +162,14 @@ def walk(tree: ast.AST): else: tid = ancestor.targets[0].id track_ids[tid] = (value, []) - else: + else: key, keywords = handle_use_template(ancestor, child, track_ids) if key and keywords and key in track_ids: track_ids[key][1].append([x.arg for x in keywords]) return track_ids + def run_checks(text_files, python_files): tfiles = [] pyfiles = [] @@ -155,14 +179,14 @@ def run_checks(text_files, python_files): for filepath in text_files: tfiles.append(check_file(filepath, pyfiles)) - + spelling_errors = False variable_errors = False for tfile in tfiles: if tfile.typos: logger.warning(f"Possible typos in {tfile.filename}: {tfile.typos}") spelling_errors = True - + if tfile.extra_variables or tfile.missing_variables: variable_errors = True @@ -171,17 +195,28 @@ def run_checks(text_files, python_files): if not variable_errors: logger.info("No variable errors found.") - + def run(): """Run check command.""" # All the logic of argparse goes in this function - parser = argparse.ArgumentParser(description='Basic prompt checking.') - parser.add_argument('directory', type=str, help='the name of the directory to search for prompts', default=".") - parser.add_argument('--ignore', type=str, help='a directory or file to exclude', default=None, required=False) + parser = argparse.ArgumentParser(description="Basic prompt checking.") + parser.add_argument( + "directory", + type=str, + help="the name of the directory to search for prompts", + default=".", + ) + parser.add_argument( + "--ignore", + type=str, + help="a directory or file to exclude", + default=None, + required=False, + ) args = parser.parse_args() prompt_files = gather(args.directory, args.ignore, ext=".txt") prompt_files += gather(args.directory, args.ignore, ext=".prompt") python_files = gather(args.directory, args.ignore, ext=".py") - run_checks(prompt_files, python_files) \ No newline at end of file + run_checks(prompt_files, python_files)