mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-12 04:16:54 +08:00
feat: universal chat in explore (#649)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Union, Tuple
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
@@ -8,20 +9,21 @@ from langchain.llms import BaseLLM
|
||||
from langchain.schema import BaseMessage, HumanMessage
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
|
||||
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.constant import llm_constant
|
||||
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
|
||||
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
|
||||
DifyStdOutCallbackHandler
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
|
||||
from core.llm.error import LLMBadRequestError
|
||||
from core.llm.fake import FakeLLM
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from core.chain.main_chain_builder import MainChainBuilder
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBStringBufferSharedMemory
|
||||
from core.orchestrator_rule_parser import OrchestratorRuleParser
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
|
||||
@@ -69,18 +71,33 @@ class Completion:
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
# build main chain include agent
|
||||
main_chain = MainChainBuilder.to_langchain_components(
|
||||
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
|
||||
|
||||
# init orchestrator rule parser
|
||||
orchestrator_rule_parser = OrchestratorRuleParser(
|
||||
tenant_id=app.tenant_id,
|
||||
agent_mode=app_model_config.agent_mode_dict,
|
||||
rest_tokens=rest_tokens_for_context_and_memory,
|
||||
memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
|
||||
conversation_message_task=conversation_message_task
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
|
||||
chain_output = ''
|
||||
if main_chain:
|
||||
chain_output = main_chain.run(query)
|
||||
# parse sensitive_word_avoidance_chain
|
||||
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
|
||||
if sensitive_word_avoidance_chain:
|
||||
query = sensitive_word_avoidance_chain.run(query)
|
||||
|
||||
# get agent executor
|
||||
agent_executor = orchestrator_rule_parser.to_agent_executor(
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
rest_tokens=rest_tokens_for_context_and_memory,
|
||||
chain_callback=chain_callback
|
||||
)
|
||||
|
||||
# run agent executor
|
||||
agent_execute_result = None
|
||||
if agent_executor:
|
||||
should_use_agent = agent_executor.should_use_agent(query)
|
||||
if should_use_agent:
|
||||
agent_execute_result = agent_executor.run(query)
|
||||
|
||||
# run the final llm
|
||||
try:
|
||||
@@ -90,7 +107,7 @@ class Completion:
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
chain_output=chain_output,
|
||||
agent_execute_result=agent_execute_result,
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
streaming=streaming
|
||||
@@ -105,9 +122,20 @@ class Completion:
|
||||
|
||||
@classmethod
|
||||
def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
chain_output: str,
|
||||
agent_execute_result: Optional[AgentExecuteResult],
|
||||
conversation_message_task: ConversationMessageTask,
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
|
||||
# When no extra pre prompt is specified,
|
||||
# the output of the agent can be used directly as the main output content without calling LLM again
|
||||
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
|
||||
and agent_execute_result.strategy != PlanningStrategy.ROUTER:
|
||||
final_llm = FakeLLM(response=agent_execute_result.output,
|
||||
origin_llm=agent_execute_result.configuration.llm,
|
||||
streaming=streaming)
|
||||
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
|
||||
response = final_llm.generate([[HumanMessage(content=query)]])
|
||||
return response
|
||||
|
||||
final_llm = LLMBuilder.to_llm_from_model(
|
||||
tenant_id=tenant_id,
|
||||
model=app_model_config.model_dict,
|
||||
@@ -122,7 +150,7 @@ class Completion:
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
chain_output=chain_output,
|
||||
agent_execute_result=agent_execute_result,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
@@ -142,16 +170,9 @@ class Completion:
|
||||
@classmethod
|
||||
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
|
||||
pre_prompt: str, query: str, inputs: dict,
|
||||
chain_output: Optional[str],
|
||||
agent_execute_result: Optional[AgentExecuteResult],
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
|
||||
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
|
||||
# disable template string in query
|
||||
# query_params = JinjaPromptTemplate.from_template(template=query).input_variables
|
||||
# if query_params:
|
||||
# for query_param in query_params:
|
||||
# if query_param not in inputs:
|
||||
# inputs[query_param] = '{{' + query_param + '}}'
|
||||
|
||||
if mode == 'completion':
|
||||
prompt_template = JinjaPromptTemplate.from_template(
|
||||
template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
@@ -165,18 +186,13 @@ When answer to user:
|
||||
- If you don't know when you are not sure, ask for clarification.
|
||||
Avoid mentioning that you obtained the information from the context.
|
||||
And answer according to the language of the user's question.
|
||||
""" if chain_output else "")
|
||||
""" if agent_execute_result else "")
|
||||
+ (pre_prompt + "\n" if pre_prompt else "")
|
||||
+ "{{query}}\n"
|
||||
)
|
||||
|
||||
if chain_output:
|
||||
inputs['context'] = chain_output
|
||||
# context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
|
||||
# if context_params:
|
||||
# for context_param in context_params:
|
||||
# if context_param not in inputs:
|
||||
# inputs[context_param] = '{{' + context_param + '}}'
|
||||
if agent_execute_result:
|
||||
inputs['context'] = agent_execute_result.output
|
||||
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
|
||||
prompt_content = prompt_template.format(
|
||||
@@ -206,8 +222,8 @@ And answer according to the language of the user's question.
|
||||
if pre_prompt_inputs:
|
||||
human_inputs.update(pre_prompt_inputs)
|
||||
|
||||
if chain_output:
|
||||
human_inputs['context'] = chain_output
|
||||
if agent_execute_result:
|
||||
human_inputs['context'] = agent_execute_result.output
|
||||
human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
|
||||
|
||||
<context>
|
||||
@@ -240,18 +256,10 @@ And answer according to the language of the user's question.
|
||||
- max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
|
||||
|
||||
# disable template string in query
|
||||
# histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
|
||||
# if histories_params:
|
||||
# for histories_param in histories_params:
|
||||
# if histories_param not in human_inputs:
|
||||
# human_inputs[histories_param] = '{{' + histories_param + '}}'
|
||||
|
||||
human_message_prompt += "\n\n" if human_message_prompt else ""
|
||||
human_message_prompt += "Here is the chat histories between human and assistant, " \
|
||||
"inside <histories></histories> XML tags.\n\n<histories>"
|
||||
human_message_prompt += histories + "</histories>"
|
||||
"inside <histories></histories> XML tags.\n\n<histories>\n"
|
||||
human_message_prompt += histories + "\n</histories>"
|
||||
|
||||
human_message_prompt += query_prompt
|
||||
|
||||
@@ -263,10 +271,13 @@ And answer according to the language of the user's question.
|
||||
|
||||
messages.append(human_message)
|
||||
|
||||
return messages, ['\nHuman:']
|
||||
for message in messages:
|
||||
message.content = re.sub(r'<\|.*?\|>', '', message.content)
|
||||
|
||||
return messages, ['\nHuman:', '</histories>']
|
||||
|
||||
@classmethod
|
||||
def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
|
||||
def get_llm_callbacks(cls, llm: BaseLanguageModel,
|
||||
streaming: bool,
|
||||
conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
|
||||
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
|
||||
@@ -277,8 +288,7 @@ And answer according to the language of the user's question.
|
||||
|
||||
@classmethod
|
||||
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
|
||||
max_token_limit: int) -> \
|
||||
str:
|
||||
max_token_limit: int) -> str:
|
||||
"""Get memory messages."""
|
||||
memory.max_token_limit = max_token_limit
|
||||
memory_key = memory.memory_variables[0]
|
||||
@@ -329,7 +339,7 @@ And answer according to the language of the user's question.
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
chain_output=None,
|
||||
agent_execute_result=None,
|
||||
memory=None
|
||||
)
|
||||
|
||||
@@ -379,6 +389,7 @@ And answer according to the language of the user's question.
|
||||
query=message.query,
|
||||
inputs=message.inputs,
|
||||
chain_output=None,
|
||||
agent_execute_result=None,
|
||||
memory=None
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user