Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
@@ -2,30 +2,40 @@ import json
import logging
import time

from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Union, Optional, cast

from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage

from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from models.model import MessageChain, MessageAgentThought, Message


class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
    """Callback Handler that prints to std out."""
    raise_error: bool = True

    def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
    def __init__(self, model_config: ModelConfigEntity,
                 queue_manager: ApplicationQueueManager,
                 message: Message,
                 message_chain: MessageChain) -> None:
        """Initialize callback handler."""
        self.model_instance = model_instance
        self.conversation_message_task = conversation_message_task
        self.model_config = model_config
        self.queue_manager = queue_manager
        self.message = message
        self.message_chain = message_chain
        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
        self.model_type_instance = cast(LargeLanguageModel, model_type_instance)
        self._agent_loops = []
        self._current_loop = None
        self._message_agent_thought = None
        self.current_chain = None

    @property
    def agent_loops(self) -> List[AgentLoop]:
@@ -46,65 +56,60 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
        """Whether to ignore chain callbacks."""
        return True

    def on_llm_before_invoke(self, prompt_messages: list[PromptMessage]) -> None:
        if not self._current_loop:
            # Agent start with a LLM query
            self._current_loop = AgentLoop(
                position=len(self._agent_loops) + 1,
                prompt="\n".join([prompt_message.content for prompt_message in prompt_messages]),
                status='llm_started',
                started_at=time.perf_counter()
            )

    def on_llm_after_invoke(self, result: RuntimeLLMResult) -> None:
        if self._current_loop and self._current_loop.status == 'llm_started':
            self._current_loop.status = 'llm_end'
            if result.usage:
                self._current_loop.prompt_tokens = result.usage.prompt_tokens
            else:
                self._current_loop.prompt_tokens = self.model_type_instance.get_num_tokens(
                    model=self.model_config.model,
                    credentials=self.model_config.credentials,
                    prompt_messages=[UserPromptMessage(content=self._current_loop.prompt)]
                )

            completion_message = result.message
            if completion_message.tool_calls:
                self._current_loop.completion \
                    = json.dumps({'function_call': completion_message.tool_calls})
            else:
                self._current_loop.completion = completion_message.content

            if result.usage:
                self._current_loop.completion_tokens = result.usage.completion_tokens
            else:
                self._current_loop.completion_tokens = self.model_type_instance.get_num_tokens(
                    model=self.model_config.model,
                    credentials=self.model_config.credentials,
                    prompt_messages=[AssistantPromptMessage(content=self._current_loop.completion)]
                )

    def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List[BaseMessage]],
            **kwargs: Any
    ) -> Any:
        if not self._current_loop:
            # Agent start with a LLM query
            self._current_loop = AgentLoop(
                position=len(self._agent_loops) + 1,
                prompt="\n".join([message.content for message in messages[0]]),
                status='llm_started',
                started_at=time.perf_counter()
            )
        pass

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Print out the prompts."""
        # serialized={'name': 'OpenAI'}
        # prompts=['Answer the following questions...\nThought:']
        # kwargs={}
        if not self._current_loop:
            # Agent start with a LLM query
            self._current_loop = AgentLoop(
                position=len(self._agent_loops) + 1,
                prompt=prompts[0],
                status='llm_started',
                started_at=time.perf_counter()
            )
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Do nothing."""
        # kwargs={}
        if self._current_loop and self._current_loop.status == 'llm_started':
            self._current_loop.status = 'llm_end'
            if response.llm_output:
                self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
            else:
                self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
                    [PromptMessage(content=self._current_loop.prompt)]
                )
            completion_generation = response.generations[0][0]
            if isinstance(completion_generation, ChatGeneration):
                completion_message = completion_generation.message
                if 'function_call' in completion_message.additional_kwargs:
                    self._current_loop.completion \
                        = json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
                else:
                    self._current_loop.completion = response.generations[0][0].text
            else:
                self._current_loop.completion = completion_generation.text

            if response.llm_output:
                self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
            else:
                self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
                    [PromptMessage(content=self._current_loop.completion)]
                )
        pass

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -150,10 +155,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
            if completion is not None:
                self._current_loop.completion = completion

            self._message_agent_thought = self.conversation_message_task.on_agent_start(
                self.current_chain,
                self._current_loop
            )
            self._message_agent_thought = self._init_agent_thought()

    def on_tool_end(
        self,
@@ -176,9 +178,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
            self._current_loop.completed_at = time.perf_counter()
            self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at

            self.conversation_message_task.on_agent_end(
                self._message_agent_thought, self.model_instance, self._current_loop
            )
            self._complete_agent_thought(self._message_agent_thought)

            self._agent_loops.append(self._current_loop)
            self._current_loop = None
@@ -202,17 +202,62 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
            self._current_loop.completed_at = time.perf_counter()
            self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
            self._current_loop.thought = '[DONE]'
            self._message_agent_thought = self.conversation_message_task.on_agent_start(
                self.current_chain,
                self._current_loop
            )
            self._message_agent_thought = self._init_agent_thought()

            self.conversation_message_task.on_agent_end(
                self._message_agent_thought, self.model_instance, self._current_loop
            )
            self._complete_agent_thought(self._message_agent_thought)

            self._agent_loops.append(self._current_loop)
            self._current_loop = None
            self._message_agent_thought = None
        elif not self._current_loop and self._agent_loops:
            self._agent_loops[-1].status = 'agent_finish'

    def _init_agent_thought(self) -> MessageAgentThought:
        message_agent_thought = MessageAgentThought(
            message_id=self.message.id,
            message_chain_id=self.message_chain.id,
            position=self._current_loop.position,
            thought=self._current_loop.thought,
            tool=self._current_loop.tool_name,
            tool_input=self._current_loop.tool_input,
            message=self._current_loop.prompt,
            message_price_unit=0,
            answer=self._current_loop.completion,
            answer_price_unit=0,
            created_by_role=('account' if self.message.from_source == 'console' else 'end_user'),
            created_by=(self.message.from_account_id
                        if self.message.from_source == 'console' else self.message.from_end_user_id)
        )

        db.session.add(message_agent_thought)
        db.session.commit()

        self.queue_manager.publish_agent_thought(message_agent_thought)

        return message_agent_thought

    def _complete_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
        loop_message_tokens = self._current_loop.prompt_tokens
        loop_answer_tokens = self._current_loop.completion_tokens

        # transform usage
        llm_usage = self.model_type_instance._calc_response_usage(
            self.model_config.model,
            self.model_config.credentials,
            loop_message_tokens,
            loop_answer_tokens
        )

        message_agent_thought.observation = self._current_loop.tool_output
        message_agent_thought.tool_process_data = ''  # currently not support
        message_agent_thought.message_token = loop_message_tokens
        message_agent_thought.message_unit_price = llm_usage.prompt_unit_price
        message_agent_thought.message_price_unit = llm_usage.prompt_price_unit
        message_agent_thought.answer_token = loop_answer_tokens
        message_agent_thought.answer_unit_price = llm_usage.completion_unit_price
        message_agent_thought.answer_price_unit = llm_usage.completion_price_unit
        message_agent_thought.latency = self._current_loop.latency
        message_agent_thought.tokens = self._current_loop.prompt_tokens + self._current_loop.completion_tokens
        message_agent_thought.total_price = llm_usage.total_price
        message_agent_thought.currency = llm_usage.currency
        db.session.commit()
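
Taken together, the hunks above move the agent-loop bookkeeping off ConversationMessageTask: the handler now persists MessageAgentThought rows itself and publishes them through the ApplicationQueueManager. A minimal wiring sketch is shown below; it is an illustration only, the module path is assumed from the file being edited, and the model_config, message, and message_chain objects are assumed to be supplied by the surrounding Dify application code rather than defined in this commit.

# Illustrative sketch (assumption, not part of the commit): constructing the
# refactored handler. The arguments are assumed to come from the app layer.
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler


def build_agent_callback(model_config, queue_manager: ApplicationQueueManager, message, message_chain):
    # The handler writes MessageAgentThought rows and publishes them via the
    # queue manager instead of delegating to ConversationMessageTask.
    return AgentLoopGatherCallbackHandler(
        model_config=model_config,
        queue_manager=queue_manager,
        message=message,
        message_chain=message_chain,
    )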
@@ -1,74 +0,0 @@
import json
import logging
from json import JSONDecodeError

from typing import Any, Dict, List, Union, Optional

from langchain.callbacks.base import BaseCallbackHandler

from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.conversation_message_task import ConversationMessageTask


class DatasetToolCallbackHandler(BaseCallbackHandler):
    """Callback Handler that prints to std out."""
    raise_error: bool = True

    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
        """Initialize callback handler."""
        self.queries = []
        self.conversation_message_task = conversation_message_task

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return True

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return True

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return False

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        tool_name: str = serialized.get('name')
        dataset_id = tool_name.removeprefix('dataset-')

        try:
            input_dict = json.loads(input_str.replace("'", "\""))
            query = input_dict.get('query')
        except JSONDecodeError:
            query = input_str

        self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        pass


    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        logging.debug("Dataset tool on_llm_error: %s", error)
@@ -1,16 +0,0 @@
from pydantic import BaseModel


class ChainResult(BaseModel):
    type: str = None
    prompt: dict = None
    completion: dict = None

    status: str = 'chain_started'
    completed: bool = False

    started_at: float = None
    completed_at: float = None

    agent_result: dict = None
    """only when type is 'AgentExecutor'"""
@@ -1,6 +0,0 @@
from pydantic import BaseModel


class DatasetQueryObj(BaseModel):
    dataset_id: str = None
    query: str = None
@@ -1,8 +0,0 @@
from pydantic import BaseModel


class LLMMessage(BaseModel):
    prompt: str = ''
    prompt_tokens: int = 0
    completion: str = ''
    completion_tokens: int = 0
@@ -1,17 +1,44 @@
from typing import List
from typing import List, Union

from langchain.schema import Document

from core.conversation_message_task import ConversationMessageTask
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from extensions.ext_database import db
from models.dataset import DocumentSegment
from models.dataset import DocumentSegment, DatasetQuery
from models.model import DatasetRetrieverResource


class DatasetIndexToolCallbackHandler:
    """Callback handler for dataset tool."""

    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
        self.conversation_message_task = conversation_message_task
    def __init__(self, queue_manager: ApplicationQueueManager,
                 app_id: str,
                 message_id: str,
                 user_id: str,
                 invoke_from: InvokeFrom) -> None:
        self._queue_manager = queue_manager
        self._app_id = app_id
        self._message_id = message_id
        self._user_id = user_id
        self._invoke_from = invoke_from

    def on_query(self, query: str, dataset_id: str) -> None:
        """
        Handle query.
        """
        dataset_query = DatasetQuery(
            dataset_id=dataset_id,
            content=query,
            source='app',
            source_app_id=self._app_id,
            created_by_role=('account'
                             if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
            created_by=self._user_id
        )

        db.session.add(dataset_query)
        db.session.commit()

    def on_tool_end(self, documents: List[Document]) -> None:
        """Handle tool end."""
@@ -30,4 +57,27 @@ class DatasetIndexToolCallbackHandler:

    def return_retriever_resource_info(self, resource: List):
        """Handle return_retriever_resource_info."""
        self.conversation_message_task.on_dataset_query_finish(resource)
        if resource and len(resource) > 0:
            for item in resource:
                dataset_retriever_resource = DatasetRetrieverResource(
                    message_id=self._message_id,
                    position=item.get('position'),
                    dataset_id=item.get('dataset_id'),
                    dataset_name=item.get('dataset_name'),
                    document_id=item.get('document_id'),
                    document_name=item.get('document_name'),
                    data_source_type=item.get('data_source_type'),
                    segment_id=item.get('segment_id'),
                    score=item.get('score') if 'score' in item else None,
                    hit_count=item.get('hit_count') if 'hit_count' in item else None,
                    word_count=item.get('word_count') if 'word_count' in item else None,
                    segment_position=item.get('segment_position') if 'segment_position' in item else None,
                    index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
                    content=item.get('content'),
                    retriever_from=item.get('retriever_from'),
                    created_by=self._user_id
                )
                db.session.add(dataset_retriever_resource)
                db.session.commit()

        self._queue_manager.publish_retriever_resources(resource)
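
The DatasetIndexToolCallbackHandler now records each dataset query and the retrieved resources itself, keyed by app, message, and user, instead of delegating to ConversationMessageTask. A small usage sketch follows, under the assumption that the class lives in core.callback_handler.index_tool_callback_handler and that the queue manager and identifiers are provided by the caller; none of this wiring code is part of the commit.

# Hypothetical usage sketch: logging a retrieval query with the reworked handler.
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import InvokeFrom


def log_dataset_query(queue_manager, app_id: str, message_id: str, user_id: str,
                      dataset_id: str, query: str) -> DatasetIndexToolCallbackHandler:
    handler = DatasetIndexToolCallbackHandler(
        queue_manager=queue_manager,
        app_id=app_id,
        message_id=message_id,
        user_id=user_id,
        invoke_from=InvokeFrom.DEBUGGER,  # assumed invocation source for this sketch
    )
    # Persists a DatasetQuery row for the retrieval, attributed to the calling user.
    handler.on_query(query=query, dataset_id=dataset_id)
    return handler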
@@ -1,284 +0,0 @@
import logging
import threading
import time
from typing import Any, Dict, List, Union, Optional

from flask import Flask, current_app
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage
from pydantic import BaseModel

from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
    ConversationTaskInterruptException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
    ImagePromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.moderation.base import ModerationOutputsResult, ModerationAction
from core.moderation.factory import ModerationFactory


class ModerationRule(BaseModel):
    type: str
    config: Dict[str, Any]


class LLMCallbackHandler(BaseCallbackHandler):
    raise_error: bool = True

    def __init__(self, model_instance: BaseLLM,
                 conversation_message_task: ConversationMessageTask):
        self.model_instance = model_instance
        self.llm_message = LLMMessage()
        self.start_at = None
        self.conversation_message_task = conversation_message_task

        self.output_moderation_handler = None
        self.init_output_moderation()

    def init_output_moderation(self):
        app_model_config = self.conversation_message_task.app_model_config
        sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict

        if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
            self.output_moderation_handler = OutputModerationHandler(
                tenant_id=self.conversation_message_task.tenant_id,
                app_id=self.conversation_message_task.app.id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance_dict.get("type"),
                    config=sensitive_word_avoidance_dict.get("config")
                ),
                on_message_replace_func=self.conversation_message_task.on_message_replace
            )

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List[BaseMessage]],
            **kwargs: Any
    ) -> Any:
        real_prompts = []
        for message in messages[0]:
            if message.type == 'human':
                role = 'user'
            elif message.type == 'ai':
                role = 'assistant'
            else:
                role = 'system'

            real_prompts.append({
                "role": role,
                "text": message.content,
                "files": [{
                    "type": file.type.value,
                    "data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
                    "detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
                } for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
            })

        self.llm_message.prompt = real_prompts
        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        self.llm_message.prompt = [{
            "role": 'user',
            "text": prompts[0]
        }]

        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        if self.output_moderation_handler:
            self.output_moderation_handler.stop_thread()

            self.llm_message.completion = self.output_moderation_handler.moderation_completion(
                completion=response.generations[0][0].text,
                public_event=True if self.conversation_message_task.streaming else False
            )
        else:
            self.llm_message.completion = response.generations[0][0].text

        if not self.conversation_message_task.streaming:
            self.conversation_message_task.append_message_text(self.llm_message.completion)

        if response.llm_output and 'token_usage' in response.llm_output:
            if 'prompt_tokens' in response.llm_output['token_usage']:
                self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']

            if 'completion_tokens' in response.llm_output['token_usage']:
                self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
            else:
                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                    [PromptMessage(content=self.llm_message.completion)])
        else:
            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                [PromptMessage(content=self.llm_message.completion)])

        self.conversation_message_task.save_message(self.llm_message)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
            # stop subscribe new token when output moderation should direct output
            ex = ConversationTaskInterruptException()
            self.on_llm_error(error=ex)
            raise ex

        try:
            self.conversation_message_task.append_message_text(token)
            self.llm_message.completion += token

            if self.output_moderation_handler:
                self.output_moderation_handler.append_new_token(token)
        except ConversationTaskStoppedException as ex:
            self.on_llm_error(error=ex)
            raise ex

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        if self.output_moderation_handler:
            self.output_moderation_handler.stop_thread()

        if isinstance(error, ConversationTaskStoppedException):
            if self.conversation_message_task.streaming:
                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                    [PromptMessage(content=self.llm_message.completion)]
                )
                self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
        if isinstance(error, ConversationTaskInterruptException):
            self.llm_message.completion = self.output_moderation_handler.get_final_output()
            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                [PromptMessage(content=self.llm_message.completion)]
            )
            self.conversation_message_task.save_message(llm_message=self.llm_message)
        else:
            logging.debug("on_llm_error: %s", error)


class OutputModerationHandler(BaseModel):
    DEFAULT_BUFFER_SIZE: int = 300

    tenant_id: str
    app_id: str

    rule: ModerationRule
    on_message_replace_func: Any

    thread: Optional[threading.Thread] = None
    thread_running: bool = True
    buffer: str = ''
    is_final_chunk: bool = False
    final_output: Optional[str] = None

    class Config:
        arbitrary_types_allowed = True

    def should_direct_output(self):
        return self.final_output is not None

    def get_final_output(self):
        return self.final_output

    def append_new_token(self, token: str):
        self.buffer += token

        if not self.thread:
            self.thread = self.start_thread()

    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
        self.buffer = completion
        self.is_final_chunk = True

        result = self.moderation(
            tenant_id=self.tenant_id,
            app_id=self.app_id,
            moderation_buffer=completion
        )

        if not result or not result.flagged:
            return completion

        if result.action == ModerationAction.DIRECT_OUTPUT:
            final_output = result.preset_response
        else:
            final_output = result.text

        if public_event:
            self.on_message_replace_func(final_output)

        return final_output

    def start_thread(self) -> threading.Thread:
        buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
        thread = threading.Thread(target=self.worker, kwargs={
            'flask_app': current_app._get_current_object(),
            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
        })

        thread.start()

        return thread

    def stop_thread(self):
        if self.thread and self.thread.is_alive():
            self.thread_running = False

    def worker(self, flask_app: Flask, buffer_size: int):
        with flask_app.app_context():
            current_length = 0
            while self.thread_running:
                moderation_buffer = self.buffer
                buffer_length = len(moderation_buffer)
                if not self.is_final_chunk:
                    chunk_length = buffer_length - current_length
                    if 0 <= chunk_length < buffer_size:
                        time.sleep(1)
                        continue

                current_length = buffer_length

                result = self.moderation(
                    tenant_id=self.tenant_id,
                    app_id=self.app_id,
                    moderation_buffer=moderation_buffer
                )

                if not result or not result.flagged:
                    continue

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    final_output = result.preset_response
                    self.final_output = final_output
                else:
                    final_output = result.text + self.buffer[len(moderation_buffer):]

                # trigger replace event
                if self.thread_running:
                    self.on_message_replace_func(final_output)

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    break

    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
        try:
            moderation_factory = ModerationFactory(
                name=self.rule.type,
                app_id=app_id,
                tenant_id=tenant_id,
                config=self.rule.config
            )

            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
            return result
        except Exception as e:
            logging.error("Moderation Output error: %s", e)

        return None
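
The deleted OutputModerationHandler above moderated streamed output in chunks: tokens accumulate in buffer, and a worker thread re-runs moderation whenever the buffer has grown by at least buffer_size characters since the last check. A standalone sketch of that re-checking loop follows for illustration; moderate() is a stand-in for the real ModerationFactory call, not an actual API.

# Standalone sketch of the chunked moderation loop (illustrative only).
import time


def moderate(text: str) -> bool:
    # Stand-in for the real moderation call; returns True when the text is flagged.
    return False


def watch_buffer(get_buffer, is_running, buffer_size: int = 300) -> None:
    """Re-check the streamed buffer each time it grows by at least buffer_size characters."""
    checked_length = 0
    while is_running():
        buffer = get_buffer()
        if len(buffer) - checked_length < buffer_size:
            time.sleep(1)  # not enough new text since the last check; wait for more tokens
            continue
        checked_length = len(buffer)
        if moderate(buffer):
            break  # the real handler replaces the message and may stop streaming here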
@@ -1,76 +0,0 @@
import logging
import time

from typing import Any, Dict, Union

from langchain.callbacks.base import BaseCallbackHandler

from core.callback_handler.entity.chain_result import ChainResult
from core.conversation_message_task import ConversationMessageTask


class MainChainGatherCallbackHandler(BaseCallbackHandler):
    """Callback Handler that prints to std out."""
    raise_error: bool = True

    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
        """Initialize callback handler."""
        self._current_chain_result = None
        self._current_chain_message = None
        self.conversation_message_task = conversation_message_task
        self.agent_callback = None

    def clear_chain_results(self) -> None:
        self._current_chain_result = None
        self._current_chain_message = None
        if self.agent_callback:
            self.agent_callback.current_chain = None

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return True

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return True

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Print out that we are entering a chain."""
        if not self._current_chain_result:
            chain_type = serialized['id'][-1]
            if chain_type:
                self._current_chain_result = ChainResult(
                    type=chain_type,
                    prompt=inputs,
                    started_at=time.perf_counter()
                )
                self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
                if self.agent_callback:
                    self.agent_callback.current_chain = self._current_chain_message

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Print out that we finished a chain."""
        if self._current_chain_result and self._current_chain_result.status == 'chain_started':
            self._current_chain_result.status = 'chain_ended'
            self._current_chain_result.completion = outputs
            self._current_chain_result.completed = True
            self._current_chain_result.completed_at = time.perf_counter()

            self.conversation_message_task.on_chain_end(self._current_chain_message, self._current_chain_result)

            self.clear_chain_results()

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        logging.debug("Dataset tool on_chain_error: %s", error)
        self.clear_chain_results()
@@ -79,8 +79,11 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
        """Run on agent action."""
        tool = action.tool
        tool_input = action.tool_input
        action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
        thought = action.log[:action_name_position].strip() if action.log else ''
        try:
            action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
            thought = action.log[:action_name_position].strip() if action.log else ''
        except ValueError:
            thought = ''

        log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
        print_text("\n[on_agent_action]\n" + log + "\n", color='green')
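
The final hunk wraps the thought extraction in try/except because str.index raises ValueError when the "\nAction:" marker is missing from the agent log; the old `if action.log` guard only covered an empty log. A standalone illustration of the failure mode the new guard handles:

# Illustration only: extracting the thought from a log that lacks the marker.
log_without_marker = "Final Answer: 42"

try:
    action_name_position = log_without_marker.index("\nAction:") + 1
    thought = log_without_marker[:action_name_position].strip()
except ValueError:
    # Without the try/except, .index() would raise here and crash the callback.
    thought = ''

print(repr(thought))  # ''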