Custom LLM Agent (with a ChatModel)
This notebook goes through how to create your own custom agent based on a chat model.
An LLM chat agent consists of four parts:
- PromptTemplate: This is the prompt template that can be used to instruct the language model on what to do
- ChatModel: This is the language model that powers the agent
- stop sequence: Instructs the LLM to stop generating as soon as this string is found
- OutputParser: This determines how to parse the LLM output into an AgentAction or AgentFinish object
The LLMAgent is used in an AgentExecutor. This AgentExecutor can largely be thought of as a loop that:
- Passes user input and any previous steps to the Agent (in this case, the LLMAgent)
- If the Agent returns an `AgentFinish`, then return that directly to the user
- If the Agent returns an `AgentAction`, then use that to call a tool and get an `Observation`
- Repeat, passing the `AgentAction` and `Observation` back to the Agent until an `AgentFinish` is emitted.
`AgentAction` is a response that consists of `action` and `action_input`. `action` refers to which tool to use, and `action_input` refers to the input to that tool (in the code below these are the `tool` and `toolInput` fields). `log` can also be provided as more context (that can be used for logging, tracing, etc).
`AgentFinish` is a response that contains the final message to be sent back to the user. This should be used to end an agent run.
With LCEL
import { AgentExecutor } from "langchain/agents";
import { formatLogToString } from "langchain/agents/format_scratchpad/log";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { PromptTemplate } from "langchain/prompts";
import {
AgentAction,
AgentFinish,
AgentStep,
BaseMessage,
HumanMessage,
InputValues,
} from "langchain/schema";
import { RunnableSequence } from "langchain/schema/runnable";
import { SerpAPI } from "langchain/tools";
import { Calculator } from "langchain/tools/calculator";
/**
 * Strings at which generation is cut off.
 * @important A stop token must be set; without one the LLM will happily keep generating text.
 */
const stopSequences = ["\nObservation"];
/** Chat model for the agent: temperature 0, with the stop sequences bound to it. */
const model = new ChatOpenAI({ temperature: 0 }).bind({ stop: stopSequences });
/** Options passed to the SerpAPI tool alongside the API key. */
const serpApiOptions = {
  location: "Austin,Texas,United States",
  hl: "en",
  gl: "us",
};
/** Tools made available to the agent: a SerpAPI tool and a calculator. */
const tools = [
  new SerpAPI(process.env.SERPAPI_API_KEY, serpApiOptions),
  new Calculator(),
];
/** Prefix prompt: the task statement followed by the `{tools}` placeholder. */
const PREFIX = [
  "Answer the following questions as best you can. You have access to the following tools:",
  "{tools}",
].join("\n");
/** Tool instructions prompt: the response format, with a `{tool_names}` placeholder. */
const TOOL_INSTRUCTIONS_TEMPLATE = [
  "Use the following format in your response:",
  "Question: the input question you must answer",
  "Thought: you should always think about what to do",
  "Action: the action to take, should be one of [{tool_names}]",
  "Action Input: the input to the action",
  "Observation: the result of the action",
  "... (this Thought/Action/Action Input/Observation can repeat N times)",
  "Thought: I now know the final answer",
  "Final Answer: the final answer to the original input question",
].join("\n");
/** Suffix prompt: the `{input}` question, ending on an open "Thought:" line. */
const SUFFIX = ["Begin!", "Question: {input}", "Thought:"].join("\n");
/**
 * Build the single-message prompt that is sent to the chat model.
 *
 * The prompt is the prefix (tool descriptions), the format instructions
 * (tool names), the suffix (the question) and the scratchpad string returned
 * by `formatLogToString`, joined with newlines.
 *
 * @param values - Must contain the keys `input` and `intermediate_steps`.
 * @returns An array holding one `HumanMessage` with the formatted prompt.
 * @throws Error when `input` or `intermediate_steps` is missing from `values`.
 */
async function formatMessages(
  values: InputValues
): Promise<Array<BaseMessage>> {
  /** Check input and intermediate steps are both inside values */
  if (!("input" in values) || !("intermediate_steps" in values)) {
    // The message names the key that is actually checked (`intermediate_steps`).
    throw new Error("Missing input or intermediate_steps from values.");
  }
  /** Extract and cast the intermediate steps as Array<AgentStep>, or use an empty array if none are passed */
  const intermediateSteps = values.intermediate_steps
    ? (values.intermediate_steps as Array<AgentStep>)
    : [];
  /** Call the helper `formatLogToString` which returns the steps as a string */
  const agentScratchpad = formatLogToString(intermediateSteps);
  /** Construct the tool strings: one "name: description" line per tool */
  const toolStrings = tools
    .map((tool) => `${tool.name}: ${tool.description}`)
    .join("\n");
  const toolNames = tools.map((tool) => tool.name).join(",\n");
  /** Create one template per prompt section */
  const prefixTemplate = new PromptTemplate({
    template: PREFIX,
    inputVariables: ["tools"],
  });
  const instructionsTemplate = new PromptTemplate({
    template: TOOL_INSTRUCTIONS_TEMPLATE,
    inputVariables: ["tool_names"],
  });
  const suffixTemplate = new PromptTemplate({
    template: SUFFIX,
    inputVariables: ["input"],
  });
  /** The three sections do not depend on each other, so format them in parallel */
  const [formattedPrefix, formattedInstructions, formattedSuffix] =
    await Promise.all([
      prefixTemplate.format({ tools: toolStrings }),
      instructionsTemplate.format({ tool_names: toolNames }),
      suffixTemplate.format({ input: values.input }),
    ]);
  /** Construct the final prompt string; the scratchpad follows the suffix's trailing "Thought:" */
  const formatted = [
    formattedPrefix,
    formattedInstructions,
    formattedSuffix,
    agentScratchpad,
  ].join("\n");
  /** Return the message as a HumanMessage. */
  return [new HumanMessage(formatted)];
}
/**
 * Custom output parser: turns the chat model's reply into the agent's next step.
 *
 * @param message - The model reply; its `content` must be a string.
 * @returns An `AgentFinish` when the text contains "Final Answer:" (everything
 * after the last occurrence becomes `returnValues.output`), otherwise an
 * `AgentAction` built from the "Action:" / "Action Input:" pair.
 * @throws Error when the content is not a string, or when no
 * "Action:" / "Action Input:" pair can be found.
 */
function customOutputParser(message: BaseMessage): AgentAction | AgentFinish {
  const text = message.content;
  if (typeof text !== "string") {
    throw new Error(
      `Message content is not a string. Received: ${JSON.stringify(
        text,
        null,
        2
      )}`
    );
  }
  /** If the input includes "Final Answer:" return as an instance of `AgentFinish` */
  if (text.includes("Final Answer:")) {
    const parts = text.split("Final Answer:");
    const input = parts[parts.length - 1].trim();
    const finalAnswers = { output: input };
    return { log: text, returnValues: finalAnswers };
  }
  /**
   * Use RegEx to extract the action and its input.
   * The tool-name group is lazy: with the `s` flag a greedy `.*` runs across
   * newlines up to the LAST "\nAction Input: ", which would put earlier input
   * text into the tool name. The input group stays greedy so multi-line
   * inputs are kept whole.
   */
  const match = /Action: (.*?)\nAction Input: (.*)/s.exec(text);
  if (!match) {
    throw new Error(`Could not parse LLM output: ${text}`);
  }
  /** Return as an instance of `AgentAction`, with surrounding double quotes stripped from the input */
  return {
    tool: match[1].trim(),
    toolInput: match[2].trim().replace(/^"+|"+$/g, ""),
    log: text,
  };
}
/**
 * The agent as an LCEL sequence: map the incoming values to `input` and
 * `intermediate_steps` (read from `steps`), format the prompt, call the
 * model, then parse the reply.
 */
const runnable = RunnableSequence.from([
  {
    input: ({ input }: InputValues) => input,
    intermediate_steps: ({ steps }: InputValues) => steps,
  },
  formatMessages,
  model,
  customOutputParser,
]);
/** Hand the runnable to `AgentExecutor` as its agent, together with the tools */
const executor = new AgentExecutor({ agent: runnable, tools });
console.log("Loaded agent.");
const input =
  "Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?";
console.log('Executing with input "' + input + '"...');
const result = await executor.invoke({ input });
console.log("Got output " + result.output);
/**
 * Sample run:
 * Got output Harry Styles' current age raised to the 0.23 power is approximately 2.1156502324195268.
 */
API Reference:
- AgentExecutor from langchain/agents
- formatLogToString from langchain/agents/format_scratchpad/log
- ChatOpenAI from langchain/chat_models/openai
- PromptTemplate from langchain/prompts
- AgentAction from langchain/schema
- AgentFinish from langchain/schema
- AgentStep from langchain/schema
- BaseMessage from langchain/schema
- HumanMessage from langchain/schema
- InputValues from langchain/schema
- RunnableSequence from langchain/schema/runnable
- SerpAPI from langchain/tools
- Calculator from langchain/tools/calculator
With LLMChain
import {
AgentActionOutputParser,
AgentExecutor,
LLMSingleActionAgent,
} from "langchain/agents";
import { LLMChain } from "langchain/chains";
import { ChatOpenAI } from "langchain/chat_models/openai";
import {
BaseChatPromptTemplate,
SerializedBasePromptTemplate,
renderTemplate,
} from "langchain/prompts";
import {
AgentAction,
AgentFinish,
AgentStep,
BaseMessage,
HumanMessage,
InputValues,
PartialValues,
} from "langchain/schema";
import { SerpAPI, Tool } from "langchain/tools";
import { Calculator } from "langchain/tools/calculator";
/** Opening line of the prompt; the tool descriptions are appended after it. */
const PREFIX =
  "Answer the following questions as best you can. You have access to the following tools:";
/**
 * Build the response-format instructions.
 * @param toolNames - Text inserted between the square brackets of the "Action:" line.
 * @returns The instruction lines joined with newlines.
 */
const formatInstructions = (toolNames: string) => {
  const lines = [
    "Use the following format in your response:",
    "Question: the input question you must answer",
    "Thought: you should always think about what to do",
    "Action: the action to take, should be one of [" + toolNames + "]",
    "Action Input: the input to the action",
    "Observation: the result of the action",
    "... (this Thought/Action/Action Input/Observation can repeat N times)",
    "Thought: I now know the final answer",
    "Final Answer: the final answer to the original input question",
  ];
  return lines.join("\n");
};
/** Closing part of the prompt, with `{input}` and `{agent_scratchpad}` placeholders. */
const SUFFIX = ["Begin!", "Question: {input}", "Thought:{agent_scratchpad}"].join(
  "\n"
);
/**
 * Chat prompt template that assembles the agent prompt from the module-level
 * PREFIX / SUFFIX strings, the tool list and the intermediate steps.
 */
class CustomPromptTemplate extends BaseChatPromptTemplate {
  tools: Tool[];

  constructor(args: { tools: Tool[]; inputVariables: string[] }) {
    super({ inputVariables: args.inputVariables });
    this.tools = args.tools;
  }

  _getPromptType(): string {
    return "chat";
  }

  /**
   * Render the whole prompt as a single human message.
   * @param values - Read for `intermediate_steps`; every key is also passed to the template renderer.
   * @returns An array with one `HumanMessage`.
   */
  async formatMessages(values: InputValues): Promise<BaseMessage[]> {
    /** Collect the "name: description" lines and the bare names in one pass */
    const descriptionLines: string[] = [];
    const names: string[] = [];
    for (const tool of this.tools) {
      descriptionLines.push(`${tool.name}: ${tool.description}`);
      names.push(tool.name);
    }
    /** Sections of the template are separated by a blank line */
    const template = [
      PREFIX,
      descriptionLines.join("\n"),
      formatInstructions(names.join("\n")),
      SUFFIX,
    ].join("\n\n");
    /** Build the scratchpad: each step adds its log, its observation and a new "Thought:" */
    const steps = values.intermediate_steps as AgentStep[];
    const scratchpad = steps
      .map(({ action, observation }) =>
        [action.log, `\nObservation: ${observation}`, "Thought:"].join("\n")
      )
      .join("");
    /** `values` is spread last, so a key it carries wins over the computed scratchpad */
    const rendered = renderTemplate(template, "f-string", {
      agent_scratchpad: scratchpad,
      ...values,
    });
    return [new HumanMessage(rendered)];
  }

  /** Not supported by this template. */
  partial(_values: PartialValues): Promise<BaseChatPromptTemplate> {
    throw new Error("Not implemented");
  }

  /** Not supported by this template. */
  serialize(): SerializedBasePromptTemplate {
    throw new Error("Not implemented");
  }
}
/**
 * Output parser that turns the model's text into an `AgentAction` or an
 * `AgentFinish`.
 */
class CustomOutputParser extends AgentActionOutputParser {
  lc_namespace = ["langchain", "agents", "custom_llm_agent_chat"];

  /**
   * Parse one model reply.
   * @param text - Raw text produced by the model.
   * @returns An `AgentFinish` when the text contains "Final Answer:" (everything
   * after the last occurrence becomes `returnValues.output`), otherwise an
   * `AgentAction` built from the "Action:" / "Action Input:" pair.
   * @throws Error when no "Action:" / "Action Input:" pair can be found.
   */
  async parse(text: string): Promise<AgentAction | AgentFinish> {
    if (text.includes("Final Answer:")) {
      const parts = text.split("Final Answer:");
      const input = parts[parts.length - 1].trim();
      const finalAnswers = { output: input };
      return { log: text, returnValues: finalAnswers };
    }
    // The tool-name group is lazy: with the `s` flag a greedy `.*` runs across
    // newlines up to the LAST "\nAction Input: ", which would put earlier input
    // text into the tool name. The input group stays greedy so multi-line
    // inputs are kept whole.
    const match = /Action: (.*?)\nAction Input: (.*)/s.exec(text);
    if (!match) {
      throw new Error(`Could not parse LLM output: ${text}`);
    }
    return {
      tool: match[1].trim(),
      // Strip double quotes wrapped around the input.
      toolInput: match[2].trim().replace(/^"+|"+$/g, ""),
      log: text,
    };
  }

  /** Not supported by this parser. */
  getFormatInstructions(): string {
    throw new Error("Not implemented");
  }
}
/**
 * Build the LLMChain-based agent, run it on a sample question and log the
 * answer.
 */
export const run = async () => {
  const chatModel = new ChatOpenAI({ temperature: 0 });
  const serpApi = new SerpAPI(process.env.SERPAPI_API_KEY, {
    location: "Austin,Texas,United States",
    hl: "en",
    gl: "us",
  });
  const tools = [serpApi, new Calculator()];

  /** Chain = custom prompt + chat model */
  const prompt = new CustomPromptTemplate({
    tools,
    inputVariables: ["input", "agent_scratchpad"],
  });
  const llmChain = new LLMChain({ prompt, llm: chatModel });

  /** The stop sequence is configured on the agent in this variant */
  const agent = new LLMSingleActionAgent({
    llmChain,
    outputParser: new CustomOutputParser(),
    stop: ["\nObservation"],
  });
  const executor = new AgentExecutor({ agent, tools });
  console.log("Loaded agent.");

  const input =
    "Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?";
  console.log('Executing with input "' + input + '"...');
  const result = await executor.invoke({ input });
  console.log("Got output " + result.output);
};
run();
API Reference:
- AgentActionOutputParser from langchain/agents
- AgentExecutor from langchain/agents
- LLMSingleActionAgent from langchain/agents
- LLMChain from langchain/chains
- ChatOpenAI from langchain/chat_models/openai
- BaseChatPromptTemplate from langchain/prompts
- SerializedBasePromptTemplate from langchain/prompts
- renderTemplate from langchain/prompts
- AgentAction from langchain/schema
- AgentFinish from langchain/schema
- AgentStep from langchain/schema
- BaseMessage from langchain/schema
- HumanMessage from langchain/schema
- InputValues from langchain/schema
- PartialValues from langchain/schema
- SerpAPI from langchain/tools
- Tool from langchain/tools
- Calculator from langchain/tools/calculator