feat: optimize override app model config convert (#874)

This commit is contained in:
takatost
2023-08-16 20:48:42 +08:00
committed by GitHub
parent cd11613952
commit cc2d71c253
12 changed files with 166 additions and 141 deletions

View File

@@ -112,7 +112,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2:
if len(intermediate_steps) >= 2 and self.summary_llm:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]

View File

@@ -65,7 +65,8 @@ class AgentExecutor:
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_model_instance.client,
summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
@@ -74,7 +75,8 @@ class AgentExecutor:
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client,
summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
@@ -83,7 +85,8 @@ class AgentExecutor:
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client,
summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:

View File

@@ -60,17 +60,7 @@ class ConversationMessageTask:
def init(self):
override_model_configs = None
if self.is_override:
override_model_configs = {
"model": self.app_model_config.model_dict,
"pre_prompt": self.app_model_config.pre_prompt,
"agent_mode": self.app_model_config.agent_mode_dict,
"opening_statement": self.app_model_config.opening_statement,
"suggested_questions": self.app_model_config.suggested_questions_list,
"suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
"more_like_this": self.app_model_config.more_like_this_dict,
"sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
"user_input_form": self.app_model_config.user_input_form_list,
}
override_model_configs = self.app_model_config.to_dict()
introduction = ''
system_instruction = ''

View File

@@ -2,7 +2,7 @@ import logging
from langchain.schema import OutputParserException
from core.model_providers.error import LLMError
from core.model_providers.error import LLMError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
@@ -108,13 +108,16 @@ class LLMGenerator:
_input = prompt.format_prompt(histories=histories)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=256,
temperature=0
try:
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=256,
temperature=0
)
)
)
except ProviderTokenNotInitError:
return []
prompts = [PromptMessage(content=_input.to_string())]

View File

@@ -14,6 +14,7 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -78,13 +79,16 @@ class OrchestratorRuleParser:
elif planning_strategy == PlanningStrategy.ROUTER:
planning_strategy = PlanningStrategy.REACT_ROUTER
summary_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_kwargs=ModelKwargs(
temperature=0,
max_tokens=500
try:
summary_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_kwargs=ModelKwargs(
temperature=0,
max_tokens=500
)
)
)
except ProviderTokenNotInitError as e:
summary_model_instance = None
tools = self.to_tools(
tool_configs=tool_configs,

View File

@@ -1,7 +1,10 @@
import json
import re
from typing import Any
from langchain.schema import BaseOutputParser
from core.model_providers.error import LLMError
from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
@@ -12,5 +15,10 @@ class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
def parse(self, text: str) -> Any:
json_string = text.strip()
json_obj = json.loads(json_string)
action_match = re.search(r".*(\[\".+\"\]).*", json_string, re.DOTALL)
if action_match is not None:
json_obj = json.loads(action_match.group(1).strip(), strict=False)
else:
raise LLMError("Could not parse LLM output: {text}")
return json_obj