feat: universal chat in explore (#649)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>

Authored by: John Wang
Committed by: GitHub on 2023-07-27 13:08:57 +08:00
Parent: 94b54b7ca9
Commit: 4fdb37771a

64 changed files with 3186 additions and 858 deletions

api/core/completion.py

@@ -1,4 +1,5 @@
 import logging
+import re
 from typing import Optional, List, Union, Tuple
 from langchain.base_language import BaseLanguageModel
@@ -8,20 +9,21 @@ from langchain.llms import BaseLLM
 from langchain.schema import BaseMessage, HumanMessage
 from requests.exceptions import ChunkedEncodingError
+from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
     DifyStdOutCallbackHandler
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.llm.error import LLMBadRequestError
+from core.llm.fake import FakeLLM
 from core.llm.llm_builder import LLMBuilder
-from core.chain.main_chain_builder import MainChainBuilder
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
 from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
     ReadOnlyConversationTokenDBStringBufferSharedMemory
+from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
@@ -69,18 +71,33 @@ class Completion:
             streaming=streaming
         )
 
-        # build main chain include agent
-        main_chain = MainChainBuilder.to_langchain_components(
+        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
+
+        # init orchestrator rule parser
+        orchestrator_rule_parser = OrchestratorRuleParser(
             tenant_id=app.tenant_id,
-            agent_mode=app_model_config.agent_mode_dict,
-            rest_tokens=rest_tokens_for_context_and_memory,
-            memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
-            conversation_message_task=conversation_message_task
+            app_model_config=app_model_config
         )
 
-        chain_output = ''
-        if main_chain:
-            chain_output = main_chain.run(query)
+        # parse sensitive_word_avoidance_chain
+        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
+        if sensitive_word_avoidance_chain:
+            query = sensitive_word_avoidance_chain.run(query)
+
+        # get agent executor
+        agent_executor = orchestrator_rule_parser.to_agent_executor(
+            conversation_message_task=conversation_message_task,
+            memory=memory,
+            rest_tokens=rest_tokens_for_context_and_memory,
+            chain_callback=chain_callback
+        )
+
+        # run agent executor
+        agent_execute_result = None
+        if agent_executor:
+            should_use_agent = agent_executor.should_use_agent(query)
+            if should_use_agent:
+                agent_execute_result = agent_executor.run(query)
 
         # run the final llm
         try:
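
The hunk above swaps the old single MainChainBuilder pipeline for a two-stage orchestration: the query first passes through an optional sensitive-word-avoidance chain, then an agent executor decides whether agent/tool reasoning is needed at all before anything runs. A minimal sketch of the resulting control flow, using only names that appear in this diff (the wrapper function itself is hypothetical, not part of the commit):

    # Hypothetical wrapper; every class and method name is taken from the diff.
    def run_orchestration(app, app_model_config, query, memory,
                          conversation_message_task, rest_tokens):
        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
        parser = OrchestratorRuleParser(
            tenant_id=app.tenant_id,
            app_model_config=app_model_config
        )

        # stage 1: optionally rewrite the query to avoid flagged terms
        avoidance_chain = parser.to_sensitive_word_avoidance_chain([chain_callback])
        if avoidance_chain:
            query = avoidance_chain.run(query)

        # stage 2: build the agent executor, but run it only when
        # should_use_agent() decides the query actually needs it
        agent_executor = parser.to_agent_executor(
            conversation_message_task=conversation_message_task,
            memory=memory,
            rest_tokens=rest_tokens,
            chain_callback=chain_callback
        )
        agent_execute_result = None
        if agent_executor and agent_executor.should_use_agent(query):
            agent_execute_result = agent_executor.run(query)
        return query, agent_execute_result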
@@ -90,7 +107,7 @@ class Completion:
                 app_model_config=app_model_config,
                 query=query,
                 inputs=inputs,
-                chain_output=chain_output,
+                agent_execute_result=agent_execute_result,
                 conversation_message_task=conversation_message_task,
                 memory=memory,
                 streaming=streaming
@@ -105,9 +122,20 @@ class Completion:
     @classmethod
     def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
-                      chain_output: str,
+                      agent_execute_result: Optional[AgentExecuteResult],
                       conversation_message_task: ConversationMessageTask,
                       memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
+        # When no extra pre prompt is specified,
+        # the output of the agent can be used directly as the main output content without calling LLM again
+        if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
+                and agent_execute_result.strategy != PlanningStrategy.ROUTER:
+            final_llm = FakeLLM(response=agent_execute_result.output,
+                                origin_llm=agent_execute_result.configuration.llm,
+                                streaming=streaming)
+            final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
+            response = final_llm.generate([[HumanMessage(content=query)]])
+            return response
+
         final_llm = LLMBuilder.to_llm_from_model(
             tenant_id=tenant_id,
             model=app_model_config.model_dict,
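
The early return added here means that when no pre_prompt is configured and the agent already produced an answer (under any strategy except ROUTER), the agent output is replayed through FakeLLM instead of paying for a second model call. FakeLLM's implementation is not part of this diff; a minimal stand-in consistent with how it is constructed and called above might look like:

    # Hypothetical stand-in for core.llm.fake.FakeLLM (not shown in this diff):
    # no model call is made, the prepared response is simply replayed through
    # the registered callbacks so streaming clients and usage logging behave
    # the same as for a real completion.
    class EchoLLM:
        def __init__(self, response: str, streaming: bool = False):
            self.response = response
            self.streaming = streaming
            self.callbacks = []

        def generate(self, prompts):
            if self.streaming:
                for token in self.response.split(" "):
                    for cb in self.callbacks:
                        cb.on_llm_new_token(token + " ")  # langchain callback hook
            return self.response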
@@ -122,7 +150,7 @@ class Completion:
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,
-            chain_output=chain_output,
+            agent_execute_result=agent_execute_result,
             memory=memory
         )
@@ -142,16 +170,9 @@ class Completion:
     @classmethod
     def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
                             pre_prompt: str, query: str, inputs: dict,
-                            chain_output: Optional[str],
+                            agent_execute_result: Optional[AgentExecuteResult],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
-        # disable template string in query
-        # query_params = JinjaPromptTemplate.from_template(template=query).input_variables
-        # if query_params:
-        #     for query_param in query_params:
-        #         if query_param not in inputs:
-        #             inputs[query_param] = '{{' + query_param + '}}'
-
         if mode == 'completion':
             prompt_template = JinjaPromptTemplate.from_template(
                 template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
@@ -165,18 +186,13 @@ When answer to user:
 - If you don't know when you are not sure, ask for clarification.
 Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
-""" if chain_output else "")
+""" if agent_execute_result else "")
                           + (pre_prompt + "\n" if pre_prompt else "")
                           + "{{query}}\n"
             )
 
-        if chain_output:
-            inputs['context'] = chain_output
-            # context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
-            # if context_params:
-            #     for context_param in context_params:
-            #         if context_param not in inputs:
-            #             inputs[context_param] = '{{' + context_param + '}}'
+        if agent_execute_result:
+            inputs['context'] = agent_execute_result.output
 
         prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
 
         prompt_content = prompt_template.format(
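
In completion mode the agent output now lands in the {{context}} slot of the template. Roughly what the rendered prompt looks like, as an illustration only: plain jinja2 stands in for Dify's JinjaPromptTemplate wrapper, the template text is abbreviated, and the values are invented:

    from jinja2 import Template

    # abbreviated version of the template assembled in the hunk above
    template = Template(
        "Use the following context as your learned knowledge, "
        "inside <context></context> XML tags.\n\n"
        "<context>\n{{ context }}\n</context>\n\n"
        "{{ query }}\n"
    )
    print(template.render(
        context="Tokyo is the capital of Japan.",   # agent_execute_result.output
        query="What is the capital of Japan?"
    ))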
@@ -206,8 +222,8 @@ And answer according to the language of the user's question.
         if pre_prompt_inputs:
             human_inputs.update(pre_prompt_inputs)
 
-        if chain_output:
-            human_inputs['context'] = chain_output
+        if agent_execute_result:
+            human_inputs['context'] = agent_execute_result.output
             human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
 
 <context>
@@ -240,18 +256,10 @@ And answer according to the language of the user's question.
                           - max_tokens - curr_message_tokens
             rest_tokens = max(rest_tokens, 0)
             histories = cls.get_history_messages_from_memory(memory, rest_tokens)
-
-            # disable template string in query
-            # histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
-            # if histories_params:
-            #     for histories_param in histories_params:
-            #         if histories_param not in human_inputs:
-            #             human_inputs[histories_param] = '{{' + histories_param + '}}'
 
             human_message_prompt += "\n\n" if human_message_prompt else ""
             human_message_prompt += "Here is the chat histories between human and assistant, " \
-                                    "inside <histories></histories> XML tags.\n\n<histories>"
-            human_message_prompt += histories + "</histories>"
+                                    "inside <histories></histories> XML tags.\n\n<histories>\n"
+            human_message_prompt += histories + "\n</histories>"
 
         human_message_prompt += query_prompt
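
The only substantive change to the histories block is the pair of added newlines, which keeps the transcript from running into the XML tags:

    # before: ...<histories>Human: hi\nAssistant: hello</histories>
    # after, using the exact strings from the hunk above:
    histories = "Human: hi\nAssistant: hello"
    block = ("Here is the chat histories between human and assistant, "
             "inside <histories></histories> XML tags.\n\n<histories>\n"
             + histories + "\n</histories>")

The closing tag matters again below, where '</histories>' is registered as an extra stop sequence.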
@@ -263,10 +271,13 @@ And answer according to the language of the user's question.
             messages.append(human_message)
 
-        return messages, ['\nHuman:']
+        for message in messages:
+            message.content = re.sub(r'<\|.*?\|>', '', message.content)
+
+        return messages, ['\nHuman:', '</histories>']
 
     @classmethod
-    def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+    def get_llm_callbacks(cls, llm: BaseLanguageModel,
                           streaming: bool,
                           conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
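
Two defensive tweaks land in this hunk: ChatML-style special tokens are scrubbed from every outgoing message, and '</histories>' joins '\nHuman:' as a stop sequence so the model cannot keep fabricating transcript past the tag. The scrub uses only the stdlib, so its behavior is easy to check (sample string invented for illustration):

    import re

    # strips ChatML-style control tokens such as <|im_start|> / <|endoftext|>
    sample = "hi <|im_start|>assistant<|im_end|> there"
    print(re.sub(r'<\|.*?\|>', '', sample))  # -> "hi assistant there"

Widening get_llm_callbacks to accept any BaseLanguageModel is what lets the FakeLLM short-circuit earlier in the file reuse the same callback wiring as the real streaming models.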
@@ -277,8 +288,7 @@ And answer according to the language of the user's question.
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
-                                         max_token_limit: int) -> \
-            str:
+                                         max_token_limit: int) -> str:
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory_key = memory.memory_variables[0]
@@ -329,7 +339,7 @@ And answer according to the language of the user's question.
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,
-            chain_output=None,
+            agent_execute_result=None,
             memory=None
         )
@@ -379,6 +389,7 @@ And answer according to the language of the user's question.
             query=message.query,
             inputs=message.inputs,
-            chain_output=None,
+            agent_execute_result=None,
             memory=None
         )