File: //proc/1233/root/usr/local/lib/python3.10/dist-packages/langchain/agents/mrkl/base.py
"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf."""
from __future__ import annotations
from typing import Any, Callable, List, NamedTuple, Optional, Sequence
from langchain_core._api import deprecated
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import BaseTool, Tool
from langchain_core.tools.render import render_text_description
from pydantic import Field
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.output_parser import MRKLOutputParser
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.utils import validate_tools_single_input
from langchain.chains import LLMChain
class ChainConfig(NamedTuple):
    """Immutable description of one action exposed to the MRKL system.

    Attributes:
        action_name: Name under which the action is registered.
        action: Callable that is run when the action is selected.
        action_description: Text describing what the action does.
    """

    # Name under which the action is registered.
    action_name: str
    # Callable that is run when the action is selected.
    action: Callable
    # Text describing what the action does.
    action_description: str
@deprecated(
    "0.1.0",
    message=AGENT_DEPRECATION_WARNING,
    removal="1.0",
)
class ZeroShotAgent(Agent):
    """Zero-shot agent used by the MRKL chain.

    Parameters:
        output_parser: Parser applied to the agent's LLM output. A fresh
            ``MRKLOutputParser`` is created when none is supplied.
    """

    output_parser: AgentOutputParser = Field(default_factory=MRKLOutputParser)

    @classmethod
    def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
        """Return a new ``MRKLOutputParser``; ``kwargs`` are accepted but unused."""
        return MRKLOutputParser()

    @property
    def _agent_type(self) -> str:
        """Identifier of this agent type (``ZERO_SHOT_REACT_DESCRIPTION``)."""
        return AgentType.ZERO_SHOT_REACT_DESCRIPTION

    @property
    def observation_prefix(self) -> str:
        """Prefix placed in front of an observation.

        Returns:
            The string "Observation: " (with a trailing space).
        """
        return "Observation: "

    @property
    def llm_prefix(self) -> str:
        """Prefix placed in front of an LLM call.

        Returns:
            The string "Thought:" (no trailing space).
        """
        return "Thought:"

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Assemble the zero-shot agent prompt.

        The template is the prefix, the rendered tool descriptions, the format
        instructions and the suffix, joined by blank lines.

        Args:
            tools: Tools the agent may use; their rendered descriptions and
                names are embedded in the prompt.
            prefix: Text placed before the tool list. Defaults to PREFIX.
            suffix: Text placed after the format instructions.
                Defaults to SUFFIX.
            format_instructions: Usage instructions; formatted with
                ``tool_names`` set to the comma-separated tool names.
                Defaults to FORMAT_INSTRUCTIONS.
            input_variables: Input variables the prompt expects. When ``None``
                or empty, the prompt is built with
                ``PromptTemplate.from_template`` instead.

        Returns:
            A PromptTemplate wrapping the assembled template.
        """
        rendered_tools = render_text_description(list(tools))
        joined_names = ", ".join(tool.name for tool in tools)
        sections = [
            prefix,
            rendered_tools,
            format_instructions.format(tool_names=joined_names),
            suffix,
        ]
        template = "\n\n".join(sections)
        if not input_variables:
            return PromptTemplate.from_template(template)
        return PromptTemplate(template=template, input_variables=input_variables)

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Agent:
        """Build an agent from an LLM and a set of tools.

        The tools are validated first; a prompt is then created with
        ``create_prompt`` and wrapped, together with the LLM, in an LLMChain.

        Args:
            llm: Language model that powers the agent.
            tools: Tools made available to the agent.
            callback_manager: Callback manager handed to the LLMChain.
                Defaults to None.
            output_parser: Output parser to use; when falsy, the class default
                parser is used. Defaults to None.
            prefix: Prompt prefix. Defaults to PREFIX.
            suffix: Prompt suffix. Defaults to SUFFIX.
            format_instructions: Prompt format instructions.
                Defaults to FORMAT_INSTRUCTIONS.
            input_variables: Prompt input variables. Defaults to None.
            kwargs: Extra keyword arguments forwarded to the agent constructor.

        Returns:
            An instance of ``cls`` restricted to the names of ``tools``.
        """
        cls._validate_tools(tools)
        chain = LLMChain(  # type: ignore[misc]
            llm=llm,
            prompt=cls.create_prompt(
                tools,
                prefix=prefix,
                suffix=suffix,
                format_instructions=format_instructions,
                input_variables=input_variables,
            ),
            callback_manager=callback_manager,
        )
        allowed = [tool.name for tool in tools]
        parser = output_parser if output_parser else cls._get_default_output_parser()
        return cls(
            llm_chain=chain,
            allowed_tools=allowed,
            output_parser=parser,
            **kwargs,
        )

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Check that ``tools`` are usable by this agent.

        Runs the single-input validation, then requires at least one tool and
        a non-``None`` description on every tool, and finally delegates to the
        parent class validation.

        Raises:
            ValueError: If ``tools`` is empty or a tool has no description.
        """
        validate_tools_single_input(cls.__name__, tools)
        if len(tools) == 0:
            raise ValueError(
                f"Got no tools for {cls.__name__}. At least one tool must be provided."
            )
        # Only the first tool lacking a description is reported.
        undescribed = next((t for t in tools if t.description is None), None)
        if undescribed is not None:
            raise ValueError(
                f"Got a tool {undescribed.name} without a description. For this agent, "
                f"a description must always be provided."
            )
        super()._validate_tools(tools)
@deprecated(
    "0.1.0",
    message=AGENT_DEPRECATION_WARNING,
    removal="1.0",
)
class MRKLChain(AgentExecutor):
    """Agent executor implementing the MRKL system."""

    @classmethod
    def from_chains(
        cls, llm: BaseLanguageModel, chains: List[ChainConfig], **kwargs: Any
    ) -> AgentExecutor:
        """Convenience constructor for a MRKL chain.

        Each chain configuration is wrapped in a ``Tool``, and a
        ``ZeroShotAgent`` is built from the LLM and those tools.

        Args:
            llm: Language model that powers the agent.
            chains: Configurations of the chains the MRKL system can call.
            **kwargs: Extra keyword arguments passed to the executor
                constructor (not to the agent).

        Returns:
            An initialized MRKL chain.
        """
        tools = []
        for config in chains:
            tools.append(
                Tool(
                    name=config.action_name,
                    func=config.action,
                    description=config.action_description,
                )
            )
        zero_shot_agent = ZeroShotAgent.from_llm_and_tools(llm, tools)
        return cls(agent=zero_shot_agent, tools=tools, **kwargs)