Skip to content

Commit

Permalink
langchain: upgrade mypy (langchain-ai#19163)
Browse files Browse the repository at this point in the history
Update mypy in langchain
  • Loading branch information
eyurtsev authored Mar 15, 2024
1 parent aa785fa commit 745d247
Show file tree
Hide file tree
Showing 40 changed files with 205 additions and 243 deletions.
2 changes: 1 addition & 1 deletion libs/langchain/langchain/agents/chat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def create_prompt(
]
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
return ChatPromptTemplate(input_variables=input_variables, messages=messages) # type: ignore[arg-type]

@classmethod
def from_llm_and_tools(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def create_prompt(
HumanMessagePromptTemplate.from_template(final_prompt),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
return ChatPromptTemplate(input_variables=input_variables, messages=messages) # type: ignore[arg-type]

def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/agents/load_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def _get_eleven_labs_text2speech(**kwargs: Any) -> BaseTool:


def _get_memorize(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool:
return Memorize(llm=llm)
return Memorize(llm=llm) # type: ignore[arg-type]


def _get_google_cloud_texttospeech(**kwargs: Any) -> BaseTool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def create_prompt(
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
return ChatPromptTemplate(messages=messages)
return ChatPromptTemplate(messages=messages) # type: ignore[arg-type, call-arg]

@classmethod
def from_llm_and_tools(
Expand All @@ -220,7 +220,7 @@ def from_llm_and_tools(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
return cls( # type: ignore[call-arg]
llm=llm,
prompt=prompt,
tools=tools,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def create_prompt(
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
return ChatPromptTemplate(messages=messages)
return ChatPromptTemplate(messages=messages) # type: ignore[arg-type, call-arg]

@classmethod
def from_llm_and_tools(
Expand All @@ -298,7 +298,7 @@ def from_llm_and_tools(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
return cls( # type: ignore[call-arg]
llm=llm,
prompt=prompt,
tools=tools,
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/agents/structured_chat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def create_prompt(
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
return ChatPromptTemplate(input_variables=input_variables, messages=messages) # type: ignore[arg-type]

@classmethod
def from_llm_and_tools(
Expand Down
5 changes: 3 additions & 2 deletions libs/langchain/langchain/chains/graph_qa/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,11 @@ def from_llm(
cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
)

qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs)
qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs) # type: ignore[arg-type]

cypher_generation_chain = LLMChain(
llm=cypher_llm or llm, **use_cypher_llm_kwargs
llm=cypher_llm or llm, # type: ignore[arg-type]
**use_cypher_llm_kwargs, # type: ignore[arg-type]
)

if exclude_types and include_types:
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chains/graph_qa/neptune_sparql.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def from_llm(
)
sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt)

return cls(
return cls( # type: ignore[call-arg]
qa_chain=qa_chain,
sparql_generation_chain=sparql_generation_chain,
examples=examples,
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chains/llm_checker/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _load_question_to_checked_assertions_chain(
revised_answer_chain,
]
question_to_checked_assertions_chain = SequentialChain(
chains=chains,
chains=chains, # type: ignore[arg-type]
input_variables=["question"],
output_variables=["revised_statement"],
verbose=True,
Expand Down
46 changes: 24 additions & 22 deletions libs/langchain/langchain/chains/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def _load_hyde_chain(config: dict, **kwargs: Any) -> HypotheticalDocumentEmbedde
else:
raise ValueError("`embeddings` must be present.")
return HypotheticalDocumentEmbedder(
llm_chain=llm_chain, base_embeddings=embeddings, **config
llm_chain=llm_chain, # type: ignore[arg-type]
base_embeddings=embeddings,
**config, # type: ignore[arg-type]
)


Expand Down Expand Up @@ -125,7 +127,7 @@ def _load_map_reduce_documents_chain(

return MapReduceDocumentsChain(
llm_chain=llm_chain,
reduce_documents_chain=reduce_documents_chain,
reduce_documents_chain=reduce_documents_chain, # type: ignore[arg-type]
**config,
)

Expand Down Expand Up @@ -207,7 +209,7 @@ def _load_llm_bash_chain(config: dict, **kwargs: Any) -> Any:
elif "prompt_path" in config:
prompt = load_prompt(config.pop("prompt_path"))
if llm_chain:
return LLMBashChain(llm_chain=llm_chain, prompt=prompt, **config)
return LLMBashChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type]
else:
return LLMBashChain(llm=llm, prompt=prompt, **config)

Expand Down Expand Up @@ -250,10 +252,10 @@ def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain:
revised_answer_prompt = load_prompt(config.pop("revised_answer_prompt_path"))
return LLMCheckerChain(
llm=llm,
create_draft_answer_prompt=create_draft_answer_prompt,
list_assertions_prompt=list_assertions_prompt,
check_assertions_prompt=check_assertions_prompt,
revised_answer_prompt=revised_answer_prompt,
create_draft_answer_prompt=create_draft_answer_prompt, # type: ignore[arg-type]
list_assertions_prompt=list_assertions_prompt, # type: ignore[arg-type]
check_assertions_prompt=check_assertions_prompt, # type: ignore[arg-type]
revised_answer_prompt=revised_answer_prompt, # type: ignore[arg-type]
**config,
)

Expand Down Expand Up @@ -281,7 +283,7 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
elif "prompt_path" in config:
prompt = load_prompt(config.pop("prompt_path"))
if llm_chain:
return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config)
return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type]
else:
return LLMMathChain(llm=llm, prompt=prompt, **config)

Expand All @@ -296,7 +298,7 @@ def _load_map_rerank_documents_chain(
llm_chain = load_chain(config.pop("llm_chain_path"))
else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
return MapRerankDocumentsChain(llm_chain=llm_chain, **config)
return MapRerankDocumentsChain(llm_chain=llm_chain, **config) # type: ignore[arg-type]


def _load_pal_chain(config: dict, **kwargs: Any) -> Any:
Expand All @@ -309,7 +311,7 @@ def _load_pal_chain(config: dict, **kwargs: Any) -> Any:
llm_chain = load_chain(config.pop("llm_chain_path"))
else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
return PALChain(llm_chain=llm_chain, **config)
return PALChain(llm_chain=llm_chain, **config) # type: ignore[arg-type]


def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocumentsChain:
Expand Down Expand Up @@ -337,8 +339,8 @@ def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocuments
elif "document_prompt_path" in config:
document_prompt = load_prompt(config.pop("document_prompt_path"))
return RefineDocumentsChain(
initial_llm_chain=initial_llm_chain,
refine_llm_chain=refine_llm_chain,
initial_llm_chain=initial_llm_chain, # type: ignore[arg-type]
refine_llm_chain=refine_llm_chain, # type: ignore[arg-type]
document_prompt=document_prompt,
**config,
)
Expand All @@ -355,7 +357,7 @@ def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesCha
"One of `combine_documents_chain` or "
"`combine_documents_chain_path` must be present."
)
return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config)
return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config) # type: ignore[arg-type]


def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
Expand All @@ -368,7 +370,7 @@ def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain")
chain = load_chain_from_config(llm_chain_config)
return SQLDatabaseChain(llm_chain=chain, database=database, **config)
return SQLDatabaseChain(llm_chain=chain, database=database, **config) # type: ignore[arg-type]
if "llm" in config:
llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config)
Expand Down Expand Up @@ -403,7 +405,7 @@ def _load_vector_db_qa_with_sources_chain(
"`combine_documents_chain_path` must be present."
)
return VectorDBQAWithSourcesChain(
combine_documents_chain=combine_documents_chain,
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
vectorstore=vectorstore,
**config,
)
Expand All @@ -425,7 +427,7 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA:
"`combine_documents_chain_path` must be present."
)
return RetrievalQA(
combine_documents_chain=combine_documents_chain,
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
retriever=retriever,
**config,
)
Expand All @@ -449,7 +451,7 @@ def _load_retrieval_qa_with_sources_chain(
"`combine_documents_chain_path` must be present."
)
return RetrievalQAWithSourcesChain(
combine_documents_chain=combine_documents_chain,
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
retriever=retriever,
**config,
)
Expand All @@ -471,7 +473,7 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
"`combine_documents_chain_path` must be present."
)
return VectorDBQA(
combine_documents_chain=combine_documents_chain,
combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
vectorstore=vectorstore,
**config,
)
Expand All @@ -495,8 +497,8 @@ def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain:

return GraphCypherQAChain(
graph=graph,
cypher_generation_chain=cypher_generation_chain,
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain, # type: ignore[arg-type]
qa_chain=qa_chain, # type: ignore[arg-type]
**config,
)

Expand Down Expand Up @@ -525,8 +527,8 @@ def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
else:
raise ValueError("`requests_wrapper` must be present.")
return APIChain(
api_request_chain=api_request_chain,
api_answer_chain=api_answer_chain,
api_request_chain=api_request_chain, # type: ignore[arg-type]
api_answer_chain=api_answer_chain, # type: ignore[arg-type]
requests_wrapper=requests_wrapper,
**config,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
)
),
]
prompt = ChatPromptTemplate(messages=messages)
prompt = ChatPromptTemplate(messages=messages) # type: ignore[arg-type, call-arg]

chain = LLMChain(
llm=llm,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def create_qa_with_structure_chain(
HumanMessagePromptTemplate.from_template("Question: {question}"),
HumanMessage(content="Tips: Make sure to answer in the correct format"),
]
prompt = prompt or ChatPromptTemplate(messages=messages)
prompt = prompt or ChatPromptTemplate(messages=messages) # type: ignore[arg-type, call-arg]

chain = LLMChain(
llm=llm,
Expand Down
22 changes: 11 additions & 11 deletions libs/langchain/langchain/chains/qa_with_sources/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ def _load_stuff_chain(
verbose: Optional[bool] = None,
**kwargs: Any,
) -> StuffDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) # type: ignore[arg-type]
return StuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,
document_prompt=document_prompt,
verbose=verbose,
verbose=verbose, # type: ignore[arg-type]
**kwargs,
)

Expand All @@ -83,14 +83,14 @@ def _load_map_reduce_chain(
token_max: int = 3000,
**kwargs: Any,
) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) # type: ignore[arg-type]
_reduce_llm = reduce_llm or llm
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose) # type: ignore[arg-type]
combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_chain,
document_variable_name=combine_document_variable_name,
document_prompt=document_prompt,
verbose=verbose,
verbose=verbose, # type: ignore[arg-type]
)
if collapse_prompt is None:
collapse_chain = None
Expand All @@ -105,7 +105,7 @@ def _load_map_reduce_chain(
llm_chain=LLMChain(
llm=_collapse_llm,
prompt=collapse_prompt,
verbose=verbose,
verbose=verbose, # type: ignore[arg-type]
),
document_variable_name=combine_document_variable_name,
document_prompt=document_prompt,
Expand All @@ -114,13 +114,13 @@ def _load_map_reduce_chain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_chain,
token_max=token_max,
verbose=verbose,
verbose=verbose, # type: ignore[arg-type]
)
return MapReduceDocumentsChain(
llm_chain=map_chain,
reduce_documents_chain=reduce_documents_chain,
document_variable_name=map_reduce_document_variable_name,
verbose=verbose,
verbose=verbose, # type: ignore[arg-type]
**kwargs,
)

Expand All @@ -136,16 +136,16 @@ def _load_refine_chain(
verbose: Optional[bool] = None,
**kwargs: Any,
) -> RefineDocumentsChain:
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) # type: ignore[arg-type]
_refine_llm = refine_llm or llm
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) # type: ignore[arg-type]
return RefineDocumentsChain(
initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain,
document_variable_name=document_variable_name,
initial_response_name=initial_response_name,
document_prompt=document_prompt,
verbose=verbose,
verbose=verbose, # type: ignore[arg-type]
**kwargs,
)

Expand Down
Loading

0 comments on commit 745d247

Please sign in to comment.