Skip to main content

Custom LLM Agent (with a ChatModel)

This notebook goes through how to create your own custom agent based on a chat model.

An LLM chat agent consists of four parts:

  • PromptTemplate: This is the prompt template that can be used to instruct the language model on what to do
  • ChatModel: This is the language model that powers the agent
  • stop sequence: Instructs the LLM to stop generating as soon as this string is found
  • OutputParser: This determines how to parse the LLM output into an AgentAction or AgentFinish object

The LLMAgent is used in an AgentExecutor. This AgentExecutor can largely be thought of as a loop that:

  1. Passes user input and any previous steps to the Agent (in this case, the LLMAgent)
  2. If the Agent returns an AgentFinish, then return that directly to the user
  3. If the Agent returns an AgentAction, then use that to call a tool and get an Observation
  4. Repeat, passing the AgentAction and Observation back to the Agent until an AgentFinish is emitted.

AgentAction is a response that consists of tool and toolInput. tool refers to which tool to use, and toolInput refers to the input to that tool. log can also be provided as more context (that can be used for logging, tracing, etc).

AgentFinish is a response that contains the final message to be sent back to the user (in the examples below, under returnValues.output). This should be used to end an agent run.

With LCEL

import { AgentExecutor } from "langchain/agents";
import { formatLogToString } from "langchain/agents/format_scratchpad/log";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { PromptTemplate } from "langchain/prompts";
import {
AgentAction,
AgentFinish,
AgentStep,
BaseMessage,
HumanMessage,
InputValues,
} from "langchain/schema";
import { RunnableSequence } from "langchain/schema/runnable";
import { SerpAPI } from "langchain/tools";
import { Calculator } from "langchain/tools/calculator";

/**
 * The chat model that drives the agent, with a stop sequence bound to every call.
 *
 * @important Keep the stop sequence. Without it the model does not pause after
 * "Action Input" and may go on to write its own "Observation" text instead of
 * waiting for the real tool result.
 */
const model = new ChatOpenAI({ temperature: 0 }).bind({ stop: ["\nObservation"] });

/**
 * Tools available to the agent: a SerpAPI web search (API key taken from the
 * SERPAPI_API_KEY environment variable, with fixed location/hl/gl options) and a
 * calculator.
 */
const tools = [
  new SerpAPI(process.env.SERPAPI_API_KEY, { location: "Austin,Texas,United States", hl: "en", gl: "us" }),
  new Calculator(),
];
/** Prompt header; `{tools}` is filled with one "name: description" line per tool. */
const PREFIX = [
  "Answer the following questions as best you can. You have access to the following tools:",
  "{tools}",
].join("\n");

/** Response-format instructions; `{tool_names}` is filled with the allowed tool names. */
const TOOL_INSTRUCTIONS_TEMPLATE = [
  "Use the following format in your response:",
  "",
  "Question: the input question you must answer",
  "Thought: you should always think about what to do",
  "Action: the action to take, should be one of [{tool_names}]",
  "Action Input: the input to the action",
  "Observation: the result of the action",
  "... (this Thought/Action/Action Input/Observation can repeat N times)",
  "Thought: I now know the final answer",
  "Final Answer: the final answer to the original input question",
].join("\n");

/** Prompt footer; `{input}` is filled with the user's question. */
const SUFFIX = ["Begin!", "", "Question: {input}", "Thought:"].join("\n");

/**
 * Build the single HumanMessage the chat model sees on every agent step.
 *
 * The message contains, in order: PREFIX (the tool list), TOOL_INSTRUCTIONS_TEMPLATE
 * (the Thought/Action/Observation format), SUFFIX (the question), and finally the
 * scratchpad of earlier steps rendered by `formatLogToString`.
 *
 * @param values - Must contain the keys `input` and `intermediate_steps`. The value of
 *   `intermediate_steps` may be undefined, which is treated as "no steps yet".
 * @returns A one-element array holding the formatted HumanMessage.
 * @throws Error when `input` or `intermediate_steps` is absent from `values`.
 */
async function formatMessages(
  values: InputValues
): Promise<Array<BaseMessage>> {
  /** Both keys must be present (the runnable's mapping step supplies them) */
  if (!("input" in values) || !("intermediate_steps" in values)) {
    throw new Error("Missing input or intermediate_steps from values.");
  }
  /** Cast the steps to Array<AgentStep>; no steps yet means an empty scratchpad */
  const intermediateSteps = (values.intermediate_steps ?? []) as Array<AgentStep>;
  /** `formatLogToString` renders the previous steps as a single string */
  const agentScratchpad = formatLogToString(intermediateSteps);

  /** One "name: description" line per tool, and a comma-separated list of tool names */
  const toolStrings = tools
    .map((tool) => `${tool.name}: ${tool.description}`)
    .join("\n");
  // The names are substituted inside "[...]" in the instructions, so separate them inline.
  const toolNames = tools.map((tool) => tool.name).join(", ");

  /** Create one template per prompt section */
  const prefixTemplate = new PromptTemplate({
    template: PREFIX,
    inputVariables: ["tools"],
  });
  const instructionsTemplate = new PromptTemplate({
    template: TOOL_INSTRUCTIONS_TEMPLATE,
    inputVariables: ["tool_names"],
  });
  const suffixTemplate = new PromptTemplate({
    template: SUFFIX,
    inputVariables: ["input"],
  });

  /** Format all three templates by passing in their input variables */
  const formattedPrefix = await prefixTemplate.format({
    tools: toolStrings,
  });
  const formattedInstructions = await instructionsTemplate.format({
    tool_names: toolNames,
  });
  const formattedSuffix = await suffixTemplate.format({
    input: values.input,
  });

  /** Construct the final prompt string; the scratchpad goes after "Thought:" */
  const formatted = [
    formattedPrefix,
    formattedInstructions,
    formattedSuffix,
    agentScratchpad,
  ].join("\n");
  /** Return the prompt as a single HumanMessage. */
  return [new HumanMessage(formatted)];
}

/**
 * Parse the chat model's reply into the next agent move.
 *
 * - If the text contains "Final Answer:", returns an AgentFinish whose
 *   `returnValues.output` is the trimmed text after the last "Final Answer:".
 * - Otherwise it expects "Action: <tool>\nAction Input: <input>" and returns an
 *   AgentAction. The tool name stops at the FIRST "\nAction Input:", and surrounding
 *   double quotes are stripped from the input.
 *
 * @param message - The model reply; its `content` must be a plain string.
 * @returns AgentFinish or AgentAction, both carrying the full text as `log`.
 * @throws Error if the content is not a string or matches neither format.
 */
function customOutputParser(message: BaseMessage): AgentAction | AgentFinish {
  const text = message.content;
  if (typeof text !== "string") {
    throw new Error(
      `Message content is not a string. Received: ${JSON.stringify(
        text,
        null,
        2
      )}`
    );
  }
  /** If the text includes "Final Answer" return it as an `AgentFinish` */
  if (text.includes("Final Answer:")) {
    const parts = text.split("Final Answer:");
    const input = parts[parts.length - 1].trim();
    const finalAnswers = { output: input };
    return { log: text, returnValues: finalAnswers };
  }
  /**
   * Extract the action and its input. The `s` flag lets the input span several lines.
   * The tool name is captured lazily: a greedy capture would run on to the LAST
   * "\nAction Input:" whenever the model emits more than one action.
   */
  const match = /Action: (.*?)\nAction Input: (.*)/s.exec(text);
  if (!match) {
    throw new Error(`Could not parse LLM output: ${text}`);
  }
  /** Return as an `AgentAction` */
  return {
    tool: match[1].trim(),
    toolInput: match[2].trim().replace(/^"+|"+$/g, ""),
    log: text,
  };
}

/**
 * The agent as an LCEL pipeline. On each call it:
 *  1. copies `input` and `steps` from the executor's values onto the keys
 *     `formatMessages` expects (`input`, `intermediate_steps`),
 *  2. renders the prompt with `formatMessages`,
 *  3. calls the stop-bound chat model,
 *  4. turns the reply into an AgentAction or AgentFinish via `customOutputParser`.
 */
const runnable = RunnableSequence.from([
  {
    input: (values: InputValues) => values.input,
    intermediate_steps: (values: InputValues) => values.steps,
  },
  formatMessages,
  model,
  customOutputParser,
]);

/** The executor repeatedly asks the agent, runs the chosen tool and feeds the result back. */
const executor = new AgentExecutor({ agent: runnable, tools });
console.log("Loaded agent.");

const input = `Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?`;
console.log(`Executing with input "${input}"...`);

const result = await executor.invoke({ input });
console.log(`Got output ${result.output}`);
/**
 * Example output (will vary between runs):
 * Got output Harry Styles' current age raised to the 0.23 power is approximately 2.1156502324195268.
 */

API Reference:

With LLMChain

import {
AgentActionOutputParser,
AgentExecutor,
LLMSingleActionAgent,
} from "langchain/agents";
import { LLMChain } from "langchain/chains";
import { ChatOpenAI } from "langchain/chat_models/openai";
import {
BaseChatPromptTemplate,
SerializedBasePromptTemplate,
renderTemplate,
} from "langchain/prompts";
import {
AgentAction,
AgentFinish,
AgentStep,
BaseMessage,
HumanMessage,
InputValues,
PartialValues,
} from "langchain/schema";
import { SerpAPI, Tool } from "langchain/tools";
import { Calculator } from "langchain/tools/calculator";

/** Opening line of the prompt; the tool list is appended after it. */
const PREFIX = "Answer the following questions as best you can. You have access to the following tools:";

/**
 * Build the response-format instructions.
 *
 * @param toolNames - Text placed inside the "[...]" of the Action line.
 * @returns The multi-line instruction block.
 */
function formatInstructions(toolNames: string): string {
  return [
    "Use the following format in your response:",
    "",
    "Question: the input question you must answer",
    "Thought: you should always think about what to do",
    `Action: the action to take, should be one of [${toolNames}]`,
    "Action Input: the input to the action",
    "Observation: the result of the action",
    "... (this Thought/Action/Action Input/Observation can repeat N times)",
    "Thought: I now know the final answer",
    "Final Answer: the final answer to the original input question",
  ].join("\n");
}

/** f-string footer: `{input}` is the question, `{agent_scratchpad}` the prior steps. */
const SUFFIX = ["Begin!", "", "Question: {input}", "Thought:{agent_scratchpad}"].join("\n");

/**
 * Chat prompt template that renders the whole agent prompt as one HumanMessage:
 * PREFIX, the tool list, the format instructions and SUFFIX (question + scratchpad).
 */
class CustomPromptTemplate extends BaseChatPromptTemplate {
  /** Tools listed in the prompt and offered to the model. */
  tools: Tool[];

  constructor(args: { tools: Tool[]; inputVariables: string[] }) {
    super({ inputVariables: args.inputVariables });
    this.tools = args.tools;
  }

  _getPromptType(): string {
    return "chat";
  }

  /**
   * Render the prompt for one agent step.
   *
   * @param values - Expected to hold `input` and, once steps exist,
   *   `intermediate_steps`. A missing `intermediate_steps` yields an empty scratchpad.
   * @returns A single HumanMessage containing the full prompt.
   */
  async formatMessages(values: InputValues): Promise<BaseMessage[]> {
    const toolStrings = this.tools
      .map((tool) => `${tool.name}: ${tool.description}`)
      .join("\n");
    // The names are placed inside "[...]" in the instructions, so list them inline.
    const toolNames = this.tools.map((tool) => tool.name).join(", ");
    const instructions = formatInstructions(toolNames);

    /** Rebuild the scratchpad in the same Action/Observation/Thought format the model was shown */
    const intermediateSteps = (values.intermediate_steps ?? []) as AgentStep[];
    const agentScratchpad = intermediateSteps.reduce(
      (thoughts, { action, observation }) =>
        `${thoughts}${action.log}\nObservation: ${observation}\nThought:`,
      ""
    );

    // Only SUFFIX is an f-string template. Tool descriptions and instructions are joined
    // as plain text so that braces inside them cannot break rendering. The computed
    // scratchpad is spread last so an incoming `agent_scratchpad` key cannot override it.
    const suffix = renderTemplate(SUFFIX, "f-string", {
      ...values,
      agent_scratchpad: agentScratchpad,
    });
    const formatted = [PREFIX, toolStrings, instructions, suffix].join("\n\n");
    return [new HumanMessage(formatted)];
  }

  partial(_values: PartialValues): Promise<BaseChatPromptTemplate> {
    throw new Error("Not implemented");
  }

  serialize(): SerializedBasePromptTemplate {
    throw new Error("Not implemented");
  }
}

/**
 * Output parser for the LLMChain-based agent. It turns the raw model text into an
 * AgentFinish (when "Final Answer:" is present) or an AgentAction.
 */
class CustomOutputParser extends AgentActionOutputParser {
  lc_namespace = ["langchain", "agents", "custom_llm_agent_chat"];

  /**
   * @param text - Raw model output.
   * @returns AgentFinish whose `output` is the trimmed text after the last
   *   "Final Answer:", or an AgentAction built from "Action:"/"Action Input:". The full
   *   text is kept as `log` in both cases.
   * @throws Error if neither format is found.
   */
  async parse(text: string): Promise<AgentAction | AgentFinish> {
    if (text.includes("Final Answer:")) {
      const parts = text.split("Final Answer:");
      const input = parts[parts.length - 1].trim();
      const finalAnswers = { output: input };
      return { log: text, returnValues: finalAnswers };
    }

    // Lazy tool-name capture: stop at the FIRST "\nAction Input:". With the `s` flag a
    // greedy capture would run on to the last one if the model emits several actions.
    const match = /Action: (.*?)\nAction Input: (.*)/s.exec(text);
    if (!match) {
      throw new Error(`Could not parse LLM output: ${text}`);
    }

    return {
      tool: match[1].trim(),
      // Strip surrounding double quotes the model often adds around the input.
      toolInput: match[2].trim().replace(/^"+|"+$/g, ""),
      log: text,
    };
  }

  getFormatInstructions(): string {
    throw new Error("Not implemented");
  }
}

/**
 * Wire up the LLMChain-based agent (chat model, SerpAPI search and calculator tools,
 * custom prompt and parser), run it once on a sample question and log the answer.
 */
export const run = async () => {
  const model = new ChatOpenAI({ temperature: 0 });
  const tools = [
    new SerpAPI(process.env.SERPAPI_API_KEY, {
      location: "Austin,Texas,United States",
      hl: "en",
      gl: "us",
    }),
    new Calculator(),
  ];

  const llmChain = new LLMChain({
    prompt: new CustomPromptTemplate({
      tools,
      inputVariables: ["input", "agent_scratchpad"],
    }),
    llm: model,
  });

  // The stop sequence halts generation before the model writes its own "Observation".
  const agent = new LLMSingleActionAgent({
    llmChain,
    outputParser: new CustomOutputParser(),
    stop: ["\nObservation"],
  });
  const executor = new AgentExecutor({
    agent,
    tools,
  });
  console.log("Loaded agent.");

  const input = `Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?`;

  console.log(`Executing with input "${input}"...`);

  const result = await executor.invoke({ input });

  console.log(`Got output ${result.output}`);
};

// Handle rejection explicitly instead of leaving a floating promise.
run().catch((error: unknown) => {
  console.error(error);
  process.exitCode = 1;
});

API Reference: