Commit

Update llama.cpp
abetlen committed Apr 11, 2023
1 parent 213cc5c commit 9f1e565
Showing 2 changed files with 40 additions and 9 deletions.
47 changes: 39 additions & 8 deletions llama_cpp/llama_cpp.py
@@ -1,9 +1,21 @@
import sys
import os
import ctypes
from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array, c_uint8, c_size_t
from ctypes import (
c_int,
c_float,
c_char_p,
c_void_p,
c_bool,
POINTER,
Structure,
Array,
c_uint8,
c_size_t,
)
import pathlib


# Load the library
def _load_shared_library(lib_base_name):
# Determine the file extension based on the platform
@@ -22,10 +34,10 @@ def _load_shared_library(lib_base_name):
# for llamacpp) and "llama" (default name for this repo)
_lib_paths = [
_base_path / f"lib{lib_base_name}{lib_ext}",
_base_path / f"{lib_base_name}{lib_ext}"
_base_path / f"{lib_base_name}{lib_ext}",
]

if ("LLAMA_CPP_LIB" in os.environ):
if "LLAMA_CPP_LIB" in os.environ:
lib_base_name = os.environ["LLAMA_CPP_LIB"]
_lib = pathlib.Path(lib_base_name)
_base_path = _lib.parent.resolve()
@@ -43,7 +55,10 @@ def _load_shared_library(lib_base_name):
except Exception as e:
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")

raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found")
raise FileNotFoundError(
f"Shared library with base name '{lib_base_name}' not found"
)


# Specify the base name of the shared library to load
_lib_base_name = "llama"
@@ -95,6 +110,10 @@ class llama_context_params(Structure):

llama_context_params_p = POINTER(llama_context_params)

LLAMA_FTYPE_ALL_F32 = ctypes.c_int(0)
LLAMA_FTYPE_MOSTLY_F16 = ctypes.c_int(1) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = ctypes.c_int(2) # except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = ctypes.c_int(3) # except 1d tensors

# Functions

@@ -106,18 +125,23 @@ def llama_context_default_params() -> llama_context_params:
_lib.llama_context_default_params.argtypes = []
_lib.llama_context_default_params.restype = llama_context_params


def llama_mmap_supported() -> c_bool:
return _lib.llama_mmap_supported()


_lib.llama_mmap_supported.argtypes = []
_lib.llama_mmap_supported.restype = c_bool


def llama_mlock_supported() -> c_bool:
return _lib.llama_mlock_supported()


_lib.llama_mlock_supported.argtypes = []
_lib.llama_mlock_supported.restype = c_bool
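
A hedged sketch of how a caller might consult these capability helpers before enabling the matching options; the use_mmap and use_mlock field names on llama_context_params are assumptions here, since the struct definition is collapsed in this diff.

from llama_cpp import llama_cpp

params = llama_cpp.llama_context_default_params()
# Only request mmap/mlock when the underlying libllama build reports support.
# NOTE: the use_mmap / use_mlock field names are assumed, not shown in this hunk.
if llama_cpp.llama_mmap_supported():
    params.use_mmap = True
if llama_cpp.llama_mlock_supported():
    params.use_mlock = True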


# Various functions for loading a ggml llama model.
# Allocate (almost) all memory needed for the model.
# Return NULL on failure
@@ -142,42 +166,49 @@ def llama_free(ctx: llama_context_p):

# TODO: not great API - very likely to change
# Returns 0 on success
def llama_model_quantize(
fname_inp: bytes, fname_out: bytes, itype: c_int
) -> c_int:
def llama_model_quantize(fname_inp: bytes, fname_out: bytes, itype: c_int) -> c_int:
return _lib.llama_model_quantize(fname_inp, fname_out, itype)


_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int]
_lib.llama_model_quantize.restype = c_int
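
An illustrative call to the quantization binding above, using one of the new LLAMA_FTYPE_* constants as the itype argument; treating the ftype value as the itype is an assumption, and the model paths are placeholders.

from llama_cpp import llama_cpp

# Placeholder paths: an f16 ggml model in, a q4_0 quantized model out.
fname_inp = b"./models/7B/ggml-model-f16.bin"
fname_out = b"./models/7B/ggml-model-q4_0.bin"

# Assumption: LLAMA_FTYPE_MOSTLY_Q4_0 doubles as the itype value expected here.
ret = llama_cpp.llama_model_quantize(
    fname_inp, fname_out, llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_0
)
if ret != 0:
    raise RuntimeError(f"llama_model_quantize failed with code {ret}")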


# Returns the KV cache that will contain the context for the
# ongoing prediction with the model.
def llama_get_kv_cache(ctx: llama_context_p):
return _lib.llama_get_kv_cache(ctx)


_lib.llama_get_kv_cache.argtypes = [llama_context_p]
_lib.llama_get_kv_cache.restype = POINTER(c_uint8)


# Returns the size of the KV cache
def llama_get_kv_cache_size(ctx: llama_context_p) -> c_size_t:
return _lib.llama_get_kv_cache_size(ctx)


_lib.llama_get_kv_cache_size.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_size.restype = c_size_t


# Returns the number of tokens in the KV cache
def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int:
return _lib.llama_get_kv_cache_token_count(ctx)


_lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_token_count.restype = c_int


# Sets the KV cache containing the current context for the model
def llama_set_kv_cache(ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int):
def llama_set_kv_cache(
ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int
):
return _lib.llama_set_kv_cache(ctx, kv_cache, n_size, n_token_count)


_lib.llama_set_kv_cache.argtypes = [llama_context_p, POINTER(c_uint8), c_size_t, c_int]
_lib.llama_set_kv_cache.restype = None
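
To show how the four KV-cache bindings above fit together, a minimal save/restore sketch; it assumes an already-initialized llama_context_p named ctx and simply copies the cache bytes out and back.

import ctypes
from llama_cpp import llama_cpp

def save_kv_cache(ctx):
    # Snapshot the KV cache: raw bytes plus the token count they represent.
    size = llama_cpp.llama_get_kv_cache_size(ctx)
    n_tokens = llama_cpp.llama_get_kv_cache_token_count(ctx)
    src = llama_cpp.llama_get_kv_cache(ctx)  # POINTER(c_uint8)
    buf = (ctypes.c_uint8 * size)()
    ctypes.memmove(buf, src, size)
    return buf, size, n_tokens

def restore_kv_cache(ctx, buf, size, n_tokens):
    # Write a previously saved snapshot back into the context.
    llama_cpp.llama_set_kv_cache(ctx, buf, size, n_tokens)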

2 changes: 1 addition & 1 deletion vendor/llama.cpp
