mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 10:56:52 +08:00
feat: optimize override app model config convert (#874)
This commit is contained in:
@@ -112,7 +112,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||
if len(intermediate_steps) >= 2:
|
||||
if len(intermediate_steps) >= 2 and self.summary_llm:
|
||||
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||
should_summary_messages = [AIMessage(content=observation)
|
||||
for _, observation in should_summary_intermediate_steps]
|
||||
|
||||
@@ -65,7 +65,8 @@ class AgentExecutor:
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
summary_llm=self.configuration.summary_model_instance.client,
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
|
||||
@@ -74,7 +75,8 @@ class AgentExecutor:
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_model_instance.client,
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
|
||||
@@ -83,7 +85,8 @@ class AgentExecutor:
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_model_instance.client,
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.ROUTER:
|
||||
|
||||
@@ -60,17 +60,7 @@ class ConversationMessageTask:
|
||||
def init(self):
|
||||
override_model_configs = None
|
||||
if self.is_override:
|
||||
override_model_configs = {
|
||||
"model": self.app_model_config.model_dict,
|
||||
"pre_prompt": self.app_model_config.pre_prompt,
|
||||
"agent_mode": self.app_model_config.agent_mode_dict,
|
||||
"opening_statement": self.app_model_config.opening_statement,
|
||||
"suggested_questions": self.app_model_config.suggested_questions_list,
|
||||
"suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
|
||||
"more_like_this": self.app_model_config.more_like_this_dict,
|
||||
"sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
|
||||
"user_input_form": self.app_model_config.user_input_form_list,
|
||||
}
|
||||
override_model_configs = self.app_model_config.to_dict()
|
||||
|
||||
introduction = ''
|
||||
system_instruction = ''
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
|
||||
from langchain.schema import OutputParserException
|
||||
|
||||
from core.model_providers.error import LLMError
|
||||
from core.model_providers.error import LLMError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs
|
||||
@@ -108,13 +108,16 @@ class LLMGenerator:
|
||||
|
||||
_input = prompt.format_prompt(histories=histories)
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=256,
|
||||
temperature=0
|
||||
try:
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=256,
|
||||
temperature=0
|
||||
)
|
||||
)
|
||||
)
|
||||
except ProviderTokenNotInitError:
|
||||
return []
|
||||
|
||||
prompts = [PromptMessage(content=_input.to_string())]
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
|
||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.model_providers.error import ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
@@ -78,13 +79,16 @@ class OrchestratorRuleParser:
|
||||
elif planning_strategy == PlanningStrategy.ROUTER:
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
|
||||
summary_model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=self.tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
temperature=0,
|
||||
max_tokens=500
|
||||
try:
|
||||
summary_model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=self.tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
temperature=0,
|
||||
max_tokens=500
|
||||
)
|
||||
)
|
||||
)
|
||||
except ProviderTokenNotInitError as e:
|
||||
summary_model_instance = None
|
||||
|
||||
tools = self.to_tools(
|
||||
tool_configs=tool_configs,
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
from core.model_providers.error import LLMError
|
||||
from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
||||
|
||||
|
||||
@@ -12,5 +15,10 @@ class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
|
||||
|
||||
def parse(self, text: str) -> Any:
|
||||
json_string = text.strip()
|
||||
json_obj = json.loads(json_string)
|
||||
action_match = re.search(r".*(\[\".+\"\]).*", json_string, re.DOTALL)
|
||||
if action_match is not None:
|
||||
json_obj = json.loads(action_match.group(1).strip(), strict=False)
|
||||
else:
|
||||
raise LLMError("Could not parse LLM output: {text}")
|
||||
|
||||
return json_obj
|
||||
|
||||
Reference in New Issue
Block a user