feat: server multi models support (#799)

This commit is contained in:
takatost
2023-08-12 00:57:00 +08:00
committed by GitHub
parent d8b712b325
commit 5fa2161b05
213 changed files with 10556 additions and 2579 deletions

View File

@@ -1,20 +1,17 @@
from typing import cast, List
from typing import List
from langchain import OpenAI
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BaseMessage
from core.constant import llm_constant
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
class CalcTokenMixin:
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
llm = cast(ChatOpenAI, llm)
return llm.get_num_tokens_from_messages(messages)
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
return model_instance.get_num_tokens(to_prompt_messages(messages))
def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
"""
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
@@ -22,10 +19,9 @@ class CalcTokenMixin:
:param messages:
:return:
"""
llm = cast(ChatOpenAI, llm)
llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
completion_max_tokens = llm.max_tokens
used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
llm_max_tokens = model_instance.model_rules.max_tokens.max
completion_max_tokens = model_instance.model_kwargs.max_tokens
used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
return rest_tokens

View File

@@ -4,9 +4,11 @@ from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -14,6 +16,12 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
An Multi Dataset Retrieve Agent driven by Router.
"""
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""

View File

@@ -6,7 +6,8 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
@@ -84,7 +85,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))

View File

@@ -3,20 +3,28 @@ from typing import cast, List
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from pydantic import BaseModel
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.model_providers.models.llm.base import BaseLLM
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
model_instance: BaseLLM
def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages

View File

@@ -6,7 +6,8 @@ from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFuncti
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
@@ -84,7 +85,7 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))

View File

@@ -0,0 +1,162 @@
import re
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
model_instance: BaseLLM
dataset_tools: Sequence[BaseTool]
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
return should use agent
Using the ReACT mode to determine whether an agent is needed is costly,
so it's better to just use an Agent for reasoning, which is cheaper.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.dataset_tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.dataset_tools) == 1:
tool = next(iter(self.dataset_tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
try:
return self.output_parser.parse(full_output)
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
formatted_tools = "\n".join(tool_strings)
unique_tool_names = set(tool.name for tool in tools)
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
output_parser=output_parser,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
dataset_tools=tools,
**kwargs,
)

View File

@@ -14,7 +14,7 @@ from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.model_providers.models.llm.base import BaseLLM
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -53,6 +53,12 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
@@ -89,7 +95,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if prompts:
messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

View File

@@ -3,7 +3,6 @@ import logging
from typing import Union, Optional
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool
@@ -13,14 +12,17 @@ from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
MULTI_FUNCTION_CALL = 'multi_function_call'
@@ -28,10 +30,9 @@ class PlanningStrategy(str, enum.Enum):
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
llm: BaseLanguageModel
model_instance: BaseLLM
tools: list[BaseTool]
summary_llm: BaseLanguageModel
dataset_llm: BaseLanguageModel
summary_model_instance: BaseLLM
memory: Optional[BaseChatMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6
@@ -60,36 +61,49 @@ class AgentExecutor:
def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
llm=self.configuration.llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_llm,
summary_llm=self.configuration.summary_model_instance.client,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_llm,
summary_llm=self.configuration.summary_model_instance.client,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_llm,
summary_llm=self.configuration.summary_model_instance.client,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
llm=self.configuration.dataset_llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
verbose=True
)
else:
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")