Created
April 28, 2023 06:47
-
-
Save didmar/112b5886993cea00a6d88c3df39771b6 to your computer and use it in GitHub Desktop.
IterativeAgentExecutor (LangChain)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Modified langchain.agents.AgentExecutor that runs one step at a time. (Based on langchain 0.0.147) | |
See the example usage in main() at the bottom of this file. | |
""" | |
import time | |
from typing import Dict, Any, List, Union, Tuple, Iterator | |
from langchain.agents import AgentExecutor, ZeroShotAgent | |
from langchain.input import get_color_mapping | |
from langchain.llms.fake import FakeListLLM | |
from langchain.schema import AgentAction, AgentFinish | |
class IterativeAgentExecutor(AgentExecutor):
    """
    Modified AgentExecutor that runs one step at a time.

    Calling an instance returns a generator instead of a dict. It yields, in
    order:

    * one ``List[Tuple[AgentAction, str]]`` per intermediate agent step (the
      action(s) taken in that step paired with their observation), and
    * finally the chain's output dict, as produced by ``prep_outputs``.

    NOTE(review): because ``__call__`` and ``_call`` are generators here,
    base-class helpers that expect ``self(...)`` / ``self._call(...)`` to
    return a dict directly (e.g. ``Chain.run``) will not work on this class;
    confirm against the langchain version in use before relying on them.
    """

    def __call__(
        self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
    ) -> Iterator[Union[Dict[str, Any], List[Tuple[AgentAction, str]]]]:
        """Run the agent one step at a time, yielding progress as it goes.

        Args:
            inputs: Dictionary of inputs, or single input if chain expects
                only one param.
            return_only_outputs: boolean for whether to return only outputs in the
                response. If True, only new keys generated by this chain will be
                returned. If False, both input keys and new keys generated by this
                chain will be returned. Defaults to False.

        Yields:
            Each intermediate step's list of ``(AgentAction, observation)``
            pairs, followed by the final output dict from ``prep_outputs``.

        Raises:
            ValueError: if ``_call`` finishes without yielding an output dict.
            Any exception raised while running the agent is reported through
            ``on_chain_error`` and then re-raised unchanged.
        """
        inputs = self.prep_inputs(inputs)
        self.callback_manager.on_chain_start(
            {"name": self.__class__.__name__},
            inputs,
            verbose=self.verbose,
        )
        outputs = None
        try:
            for step in self._call(inputs):
                # The first dict yielded by _call is the final result; every
                # item before it is an intermediate step to surface to the caller.
                if isinstance(step, dict):
                    outputs = step
                    break
                yield step
        except (KeyboardInterrupt, Exception) as e:
            self.callback_manager.on_chain_error(e, verbose=self.verbose)
            # Bare raise re-raises the active exception with its original traceback.
            raise
        if outputs is None:
            error = ValueError("Chain did not return any outputs (dict)")
            # Report the failure so on_chain_start gets a matching error callback.
            self.callback_manager.on_chain_error(error, verbose=self.verbose)
            raise error
        self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
        yield self.prep_outputs(inputs, outputs, return_only_outputs)

    def _call(
        self, inputs: Dict[str, str]
    ) -> Iterator[Union[Dict[str, Any], List[Tuple[AgentAction, str]]]]:
        """
        Run text through the agent, one step per iteration.

        Generator version of the executor loop. It yields each intermediate
        step's list of ``(AgentAction, observation)`` pairs as soon as that
        step has run. Its last item is the final output dict built by
        ``self._return``, which it yields rather than returns.
        """
        # Construct a mapping of tool name to tool for easy lookup
        name_to_tool_map = {tool.name: tool for tool in self.tools}
        # We construct a mapping from each tool to a color, used for logging.
        color_mapping = get_color_mapping(
            [tool.name for tool in self.tools], excluded_colors=["green"]
        )
        intermediate_steps: List[Tuple[AgentAction, str]] = []
        # Let's start tracking the number of iterations and time elapsed
        iterations = 0
        time_elapsed = 0.0
        start_time = time.time()
        # We now enter the agent loop (until it returns something).
        while self._should_continue(iterations, time_elapsed):
            next_step_output = self._take_next_step(
                name_to_tool_map, color_mapping, inputs, intermediate_steps
            )
            if isinstance(next_step_output, AgentFinish):
                yield self._return(next_step_output, intermediate_steps)
                return
            # Surface the step to the caller before recording it in the
            # history, so the caller can inspect it (or stop) after each step.
            yield next_step_output
            intermediate_steps.extend(next_step_output)
            if len(next_step_output) == 1:
                # See if tool should return directly
                tool_return = self._get_tool_return(next_step_output[0])
                if tool_return is not None:
                    yield self._return(tool_return, intermediate_steps)
                    return
            iterations += 1
            time_elapsed = time.time() - start_time
        # _should_continue returned False (presumably an iteration/time limit
        # was hit): ask the agent for its stopped response instead.
        output = self.agent.return_stopped_response(
            self.early_stopping_method, intermediate_steps, **inputs
        )
        yield self._return(output, intermediate_steps)
def main():
    r"""
    Demonstrate IterativeAgentExecutor driven by a scripted fake LLM.

    The fake LLM first asks for two tools that are not registered ("A", then
    "X") and then gives a final answer. Every item the executor yields is
    printed and followed by a pause for Enter. The demo therefore stops three
    times: after each of the two intermediate steps and after the final
    output dict.

    Expected console output:

        > Entering new IterativeAgentExecutor chain...
        Action: A
        Action Input: B
        Observation: A is not a valid tool, try another one.
        Thought:
        STEP: [(AgentAction(tool='A', tool_input='B', log='Action: A\nAction Input: B'), 'A is not a valid tool, try another one.')]
        Press Enter to continue...
        Action: X
        Action Input: Y
        Observation: X is not a valid tool, try another one.
        Thought:
        STEP: [(AgentAction(tool='X', tool_input='Y', log='Action: X\nAction Input: Y'), 'X is not a valid tool, try another one.')]
        Press Enter to continue...
        Final Answer: Some stuff.
        > Finished chain.
        STEP: {'input': 'This is a test.', 'output': 'Some stuff.'}
        Press Enter to continue...
    """
    # Canned LLM completions, consumed in order, one per agent step.
    scripted_replies = [
        'Action: A\nAction Input: B',
        'Action: X\nAction Input: Y',
        "Final Answer: Some stuff.",
    ]
    # The same (empty) tool list is shared by the agent and the executor.
    no_tools = []
    executor = IterativeAgentExecutor.from_agent_and_tools(
        agent=ZeroShotAgent.from_llm_and_tools(
            FakeListLLM(responses=scripted_replies),
            no_tools,
        ),
        tools=no_tools,
        verbose=True,
    )
    step_stream = executor({"input": "This is a test."})
    for step in step_stream:
        print(f"\nSTEP: {step}")
        input("\nPress Enter to continue...")


# Only run the interactive demo when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment