
Commit

Add claude support
pseudotensor committed Nov 27, 2023
1 parent b1e730d commit bd5afa5
Showing 8 changed files with 97 additions and 6 deletions.
2 changes: 2 additions & 0 deletions gradio_utils/prompt_form.py
@@ -32,6 +32,8 @@ def get_avatars(base_model, model_path_llama, inference_server=''):
bot_avatar = "models/openai.png"
elif 'hugging' in base_model.lower():
bot_avatar = "models/hf-logo.png"
elif 'claude' in base_model.lower():
bot_avatar = "models/anthropic.jpeg"
else:
bot_avatar = "models/h2oai.png"
human_avatar = human_avatar if os.path.isfile(human_avatar) else None
Binary file added models/anthropic.jpeg
Binary file added models/anthropic.png
1 change: 1 addition & 0 deletions reqs_optional/requirements_optional_langchain.txt
@@ -12,6 +12,7 @@ sentence_transformers==2.2.2
# optional: for OpenAI endpoint or embeddings (requires key)
openai==1.3.5
replicate==0.20.0
anthropic==0.7.4

# local vector db
chromadb==0.4.18
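As an illustrative aside (not part of this commit), a minimal sketch of exercising the pinned SDK directly to confirm the key and pin work; the model name and prompt are placeholders:

import os
import anthropic  # pinned above as anthropic==0.7.4

# The 0.7.x SDK exposes a synchronous client plus HUMAN_PROMPT/AI_PROMPT markers
# for the completions API used by the Claude 2.x models.
client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
resp = client.completions.create(
    model="claude-2.1",
    max_tokens_to_sample=256,
    prompt=f"{anthropic.HUMAN_PROMPT} Say hello.{anthropic.AI_PROMPT}",
)
print(resp.completion)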
15 changes: 15 additions & 0 deletions src/enums.py
@@ -50,6 +50,7 @@ class PromptType(Enum):
open_chat = 44
open_chat_correct = 44
open_chat_code = 44
anthropic = 21


class DocumentSubset(Enum):
@@ -157,6 +158,20 @@ class LangChainAgent(Enum):
"code-cushman-001": 2048,
})

anthropic_mapping = {
"claude-2.1": 200000,
"claude-2": 100000,
"claude-2.0": 100000,
"claude-instant-1.2": 100000
}

anthropic_mapping_outputs = {
"claude-2.1": 4096,
"claude-2": 4096,
"claude-2.0": 4096,
"claude-instant-1.2": 4096,
}

openai_supports_functiontools = ["gpt-4-0613", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613",
"gpt-4-1106-preview", "gpt-35-turbo-1106"]

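As an illustrative aside (not part of the diff), the two mappings above might be combined downstream roughly as follows to budget input tokens for a given Claude model:

# Sketch only: values come from anthropic_mapping / anthropic_mapping_outputs above.
base_model = "claude-2.1"
max_seq_len = anthropic_mapping[base_model]             # total context window (200000)
max_output_len = anthropic_mapping_outputs[base_model]  # max generated tokens (4096)
max_input_tokens = max_seq_len - max_output_len         # rough budget left for the prompt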
26 changes: 22 additions & 4 deletions src/gen.py
@@ -47,7 +47,7 @@
super_source_postfix, t5_type, get_langchain_prompts, gr_to_lg, invalid_key_msg, docs_joiner_default, \
docs_ordering_types_default, docs_token_handling_default, max_input_tokens_public, max_total_input_tokens_public, \
max_top_k_docs_public, max_top_k_docs_default, max_total_input_tokens_public_api, max_top_k_docs_public_api, \
max_input_tokens_public_api, model_token_mapping_outputs
max_input_tokens_public_api, model_token_mapping_outputs, anthropic_mapping, anthropic_mapping_outputs
from loaders import get_loaders
from utils import set_seed, clear_torch_cache, NullContext, wrapped_partial, EThread, get_githash, \
import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, \
@@ -472,6 +472,10 @@ def main(
Use: "sagemaker_chat:<endpoint name>" for chat models that AWS sets up as dialog
Use: "sagemaker:<endpoint name>" for foundation models that AWS only text as inputs
Or Address can be for Anthropic Claude. Ensure key is set in env ANTHROPIC_API_KEY
Use: "anthropic:<model name>"
E.g. "anthropic:claude-2"
:param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
:param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
:param system_prompt: Universal system prompt to use if model supports, like LLaMa2, regardless of prompt_type definition.
@@ -2221,7 +2225,8 @@ def get_model(
inference_server.startswith('openai') or
inference_server.startswith('vllm') or
inference_server.startswith('replicate') or
inference_server.startswith('sagemaker')
inference_server.startswith('sagemaker') or
inference_server.startswith('anthropic')
):
max_output_len = None
if inference_server.startswith('openai'):
@@ -2236,6 +2241,18 @@
max_output_len = model_token_mapping_outputs[base_model]
else:
max_output_len = None
if inference_server.startswith('anthropic'):
assert os.getenv('ANTHROPIC_API_KEY'), "Set environment for ANTHROPIC_API_KEY"
# Don't return None, None for model, tokenizer so triggers
# include small token cushion
if base_model in anthropic_mapping:
max_seq_len = anthropic_mapping[base_model]
else:
raise ValueError("Invalid base_model=%s for inference_server=%s" % (base_model, inference_server))
if base_model in anthropic_mapping_outputs:
max_output_len = anthropic_mapping_outputs[base_model]
else:
max_output_len = None
if inference_server.startswith('replicate'):
assert len(inference_server.split(':')) >= 3, "Expected replicate:model string, got %s" % inference_server
assert os.getenv('REPLICATE_API_TOKEN'), "Set environment for REPLICATE_API_TOKEN"
@@ -2255,7 +2272,7 @@
assert os.getenv('AWS_SECRET_ACCESS_KEY'), "Set environment for AWS_SECRET_ACCESS_KEY"
# Don't return None, None for model, tokenizer so triggers
# include small token cushion
if inference_server.startswith('openai') or tokenizer is None:
if inference_server.startswith('openai') or tokenizer is None or inference_server.startswith('anthropic'):
# don't use the fake (tiktoken) tokenizer for vLLM/replicate when the actual model and tokenizer are known
assert max_seq_len is not None, "Please pass --max_seq_len=<max_seq_len> for unknown or non-HF model %s" % base_model
tokenizer = FakeTokenizer(model_max_length=max_seq_len - 50, is_openai=True)
@@ -3078,7 +3095,8 @@ def evaluate(
inference_server.startswith('replicate') or \
inference_server.startswith('sagemaker') or \
inference_server.startswith('openai_azure_chat') or \
inference_server.startswith('openai_azure')
inference_server.startswith('openai_azure') or \
inference_server.startswith('anthropic')
do_langchain_path = langchain_mode not in [False, 'Disabled', 'LLM'] or \
langchain_only_model or \
force_langchain_evaluate or \
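Taken together, a hypothetical launch against the new backend might look like the sketch below (assumes src/ is on the import path and the key is exported; the model name is illustrative):

import os
from gen import main  # src/gen.py

# checked again by the assert in get_model()
assert os.getenv('ANTHROPIC_API_KEY'), "export ANTHROPIC_API_KEY first"

# inference_server follows the "anthropic:<model name>" convention documented in main()'s docstring
main(inference_server='anthropic:claude-2.1',
     base_model='claude-2.1')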
55 changes: 54 additions & 1 deletion src/gpt_langchain.py
@@ -966,7 +966,7 @@ def get_token_ids(self, text: str) -> List[int]:
# return _get_token_ids_default_method(text)


from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic
from langchain.llms import OpenAI, AzureOpenAI, Replicate


@@ -1261,6 +1261,37 @@ async def agenerate_prompt(
)


class H2OChatAnthropic(ChatAnthropic, ExtraChat):
system_prompt: Any = None
chat_conversation: Any = []

# max_new_tokens0: Any = None # FIXME: Doesn't seem to have same max_tokens == -1 for prompts==1

def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
prompt_messages = self.get_messages(prompts)
# prompt_messages = [p.to_messages() for p in prompts]
return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)

async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
prompt_messages = self.get_messages(prompts)
# prompt_messages = [p.to_messages() for p in prompts]
return await self.agenerate(
prompt_messages, stop=stop, callbacks=callbacks, **kwargs
)


class H2OAzureOpenAI(AzureOpenAI):
max_new_tokens0: Any = None # FIXME: Doesn't seem to have same max_tokens == -1 for prompts==1

@@ -1522,6 +1553,28 @@ def get_llm(use_openai_model=False,
else:
# vllm goes here
prompt_type = prompt_type or 'plain'
elif inference_server.startswith('anthropic'):
cls = H2OChatAnthropic

# LangChain oddly passes some things directly and the rest via model_kwargs
model_kwargs = dict()
kwargs_extra = {}
kwargs_extra.update(dict(system_prompt=system_prompt, chat_conversation=chat_conversation))

callbacks = [StreamingGradioCallbackHandler(max_time=max_time, verbose=verbose)]
llm = cls(model=model_name,
anthropic_api_key=os.getenv('ANTHROPIC_API_KEY'),
top_p=top_p if do_sample else 1,
top_k=top_k,
temperature=temperature if do_sample else 0,
callbacks=callbacks if stream_output else None,
streaming=stream_output,
default_request_timeout=max_time,
model_kwargs=model_kwargs,
**kwargs_extra
)
streamer = callbacks[0] if stream_output else None
prompt_type = inference_server
elif inference_server and inference_server.startswith('sagemaker'):
callbacks = [StreamingGradioCallbackHandler(max_time=max_time, verbose=verbose)] # FIXME
streamer = None
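For context, a hypothetical standalone use of the wrapper class added above (a sketch; the constructor fields mirror the get_llm() branch, and the model and prompts are placeholders):

import os
from gpt_langchain import H2OChatAnthropic  # src/gpt_langchain.py

llm = H2OChatAnthropic(
    model='claude-2.1',
    anthropic_api_key=os.getenv('ANTHROPIC_API_KEY'),
    temperature=0,
    system_prompt='You are a helpful assistant.',  # consumed by ExtraChat.get_messages()
    chat_conversation=[],
)
# single call through LangChain's chat-model interface; invoke() routes via generate_prompt(),
# so the system_prompt/chat_conversation fields above are folded into the messages
print(llm.invoke("What can Claude 2.1 do?").content)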
4 changes: 3 additions & 1 deletion src/prompter.py
@@ -569,7 +569,9 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context,
humanstr = PreInstruct
botstr = PreResponse
elif prompt_type in [PromptType.openai_chat.value, str(PromptType.openai_chat.value),
PromptType.openai_chat.name]:
PromptType.openai_chat.name] or \
prompt_type in [PromptType.anthropic.value, str(PromptType.anthropic.value),
PromptType.anthropic.name]:
# prompting and termination all handled by endpoint
preprompt = """"""
start = ''
