Skip to content

Commit

Permalink
Update Eval Script (langchain-ai#164)
Browse files Browse the repository at this point in the history
Add a branch so everything is encapsulated in the runnable.
Expose single "get_chain()" function to make it easier to evaluate whatever is in main.py
  • Loading branch information
hinthornw authored Sep 26, 2023
1 parent f01d22b commit 0700eff
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 22 deletions.
63 changes: 63 additions & 0 deletions _scripts/evaluate_chat_langchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# TODO: Consolidate all these scripts into a single script
# This is ugly
import argparse
import functools
import json

from langchain import load as langchain_load
from langchain.chat_models import ChatAnthropic, ChatOpenAI
from langchain.smith import RunEvalConfig
from langsmith import Client, RunEvaluator
from langsmith.evaluation.evaluator import EvaluationResult
from langsmith.schemas import Example, Run

# Ugly. Requires PYTHONPATH=$(PWD) to run
from main import create_chain, get_retriever

# Maps the CLI --model-provider value to its LangChain chat-model class.
_PROVIDER_MAP = {
"openai": ChatOpenAI,
"anthropic": ChatAnthropic,
}

# Default model name used for each provider when constructing the LLM.
_MODEL_MAP = {
"openai": "gpt-3.5-turbo-16k",
"anthropic": "claude-2",
}


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset-name", default="Chat LangChain Complex Questions")
parser.add_argument("--model-provider", default="openai")
args = parser.parse_args()
client = Client()
# Check dataset exists
ds = client.read_dataset(dataset_name=args.dataset_name)
retriever = get_retriever()
llm = _PROVIDER_MAP[args.model_provider](
model=_MODEL_MAP[args.model_provider], temperature=0
)

# In app, we always pass in a chat history, but for evaluation we don't
# necessarily do that. Add that handling here.
def construct_eval_chain():
chain = create_chain(
retriever=retriever,
llm=llm,
)
return {
"question": lambda x: x["question"],
"chat_history": (lambda x: x.get("chat_history", [])),
} | chain

eval_config = RunEvalConfig(
evaluators=["qa"],
prediction_key="output",
)
results = client.run_on_dataset(
dataset_name=args.dataset_name,
llm_or_chain_factory=construct_eval_chain,
evaluation=eval_config,
tags=["simple_chain"],
verbose=True,
)
51 changes: 29 additions & 22 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from langchain.vectorstores import Weaviate
from langsmith import Client
from pydantic import BaseModel
from langchain.schema.runnable import RunnableBranch

from constants import WEAVIATE_DOCS_INDEX_NAME

Expand Down Expand Up @@ -73,27 +74,27 @@ def get_retriever():
return weaviate_client.as_retriever(search_kwargs=dict(k=6))


def create_retriever_chain(chat_history, llm, retriever: BaseRetriever):
def create_retriever_chain(llm, retriever: BaseRetriever):
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)

if chat_history:
condense_question_chain = (
{
"question": itemgetter("question"),
"chat_history": itemgetter("chat_history"),
}
| CONDENSE_QUESTION_PROMPT
| llm
| StrOutputParser()
).with_config(
{
"run_name": "CondenseQuestion",
}
)
retriever_chain = condense_question_chain | retriever
else:
retriever_chain = (itemgetter("question")) | retriever
return retriever_chain
initial_chain = (itemgetter("question")) | retriever
condense_question_chain = (
{
"question": itemgetter("question"),
"chat_history": itemgetter("chat_history"),
}
| CONDENSE_QUESTION_PROMPT
| llm
| StrOutputParser()
).with_config(
run_name="CondenseQuestion",
)
conversation_chain = condense_question_chain | retriever

return RunnableBranch(
((lambda x: "chat_history" in x), conversation_chain),
initial_chain,
)


def format_docs(docs, max_tokens=200):
Expand All @@ -106,8 +107,11 @@ def format_docs(docs, max_tokens=200):

def create_chain(
llm,
retriever_chain,
retriever,
) -> Runnable:
retriever_chain = create_retriever_chain(llm, retriever).with_config(
run_name="FindDocs"
)
_context = RunnableMap(
{
"context": retriever_chain | format_docs,
Expand Down Expand Up @@ -189,8 +193,11 @@ async def chat_endpoint(request: ChatRequest):
streaming=True,
temperature=0,
)
docs_chain = create_retriever_chain(converted_chat_history, llm, get_retriever())
answer_chain = create_chain(llm, docs_chain.with_config(run_name="FindDocs"))
retriever = get_retriever()
answer_chain = create_chain(
llm,
retriever,
)
stream = answer_chain.astream_log(
{
"question": question,
Expand Down

0 comments on commit 0700eff

Please sign in to comment.