Feat/blocking function call (#2247)
@@ -1,7 +1,7 @@
 import logging
 import json
 
-from typing import Optional, List, Tuple, Union
+from typing import Optional, List, Tuple, Union, cast
 from datetime import datetime
 from mimetypes import guess_extension
 
@@ -27,7 +27,10 @@ from core.entities.application_entities import ModelConfigEntity, \
     AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelFeature
+from core.model_runtime.utils.encoders import jsonable_encoder
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_manager import ModelInstance
 from core.file.message_file_parser import FileTransferMethod
 
 logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ class BaseAssistantApplicationRunner(AppRunner):
                  prompt_messages: Optional[List[PromptMessage]] = None,
                  variables_pool: Optional[ToolRuntimeVariablePool] = None,
                  db_variables: Optional[ToolConversationVariables] = None,
+                 model_instance: ModelInstance = None
                  ) -> None:
         """
         Agent runner
@@ -71,6 +75,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         self.history_prompt_messages = prompt_messages
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
+        self.model_instance = model_instance
 
         # init callback
         self.agent_callback = DifyAgentCallbackHandler()
@@ -95,6 +100,14 @@ class BaseAssistantApplicationRunner(AppRunner):
             MessageAgentThought.message_id == self.message.id,
         ).count()
 
+        # check if model supports stream tool call
+        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
+        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
+            self.stream_tool_call = True
+        else:
+            self.stream_tool_call = False
+
     def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
         """
         Repacket app orchestration config
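For reference, the substantive change in this hunk set is the check added at the end of the diff: the runner now holds a ModelInstance and records whether the model's schema advertises ModelFeature.STREAM_TOOL_CALL. Below is a minimal sketch of that pattern and of how a caller might branch on the resulting flag; the two helper functions and the dispatch are illustrative assumptions, not code from this commit.

# Illustrative sketch only, not part of the commit. The helper names below are
# hypothetical; only the cast / get_model_schema / ModelFeature check is taken from the diff.
from typing import cast

from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel


def supports_stream_tool_call(model_instance: ModelInstance) -> bool:
    # Same check the runner performs in __init__: ask the provider for the model schema
    # and look for the STREAM_TOOL_CALL feature flag.
    llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
    model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
    return bool(model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []))


def pick_function_call_strategy(model_instance: ModelInstance) -> str:
    # Hypothetical dispatch: models that can stream tool calls keep a streaming loop,
    # everything else falls back to a blocking (non-streaming) function-call loop,
    # which is what the PR title "Feat/blocking function call" refers to.
    if supports_stream_tool_call(model_instance):
        return "streaming"
    return "blocking"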