Feat/blocking function call (#2247)

This commit is contained in:
Yeuoly
2024-01-30 15:25:37 +08:00
committed by GitHub
parent 1ea18a2922
commit 6d5b386394
33 changed files with 429 additions and 94 deletions

View File

@@ -1,7 +1,7 @@
import logging
import json
from typing import Optional, List, Tuple, Union
from typing import Optional, List, Tuple, Union, cast
from datetime import datetime
from mimetypes import guess_extension
@@ -27,7 +27,10 @@ from core.entities.application_entities import ModelConfigEntity, \
AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.utils.encoders import jsonable_encoder
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_manager import ModelInstance
from core.file.message_file_parser import FileTransferMethod
logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ class BaseAssistantApplicationRunner(AppRunner):
prompt_messages: Optional[List[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None
) -> None:
"""
Agent runner
@@ -71,6 +75,7 @@ class BaseAssistantApplicationRunner(AppRunner):
self.history_prompt_messages = prompt_messages
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance
# init callback
self.agent_callback = DifyAgentCallbackHandler()
@@ -95,6 +100,14 @@ class BaseAssistantApplicationRunner(AppRunner):
MessageAgentThought.message_id == self.message.id,
).count()
# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
self.stream_tool_call = True
else:
self.stream_tool_call = False
def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
"""
Repacket app orchestration config