Feat/environment variables in workflow (#6515)
Co-authored-by: JzoNg <jzongcode@gmail.com>
@@ -1,7 +1,5 @@
 import os
 
-from configs import dify_config
-
 if os.environ.get("DEBUG", "false").lower() != 'true':
     from gevent import monkey
@@ -23,7 +21,9 @@ from flask import Flask, Response, request
 from flask_cors import CORS
 from werkzeug.exceptions import Unauthorized
 
+import contexts
 from commands import register_commands
+from configs import dify_config
 
 # DO NOT REMOVE BELOW
 from events import event_handlers
@@ -181,7 +181,10 @@ def load_user_from_request(request_from_flask_login):
     decoded = PassportService().verify(auth_token)
     user_id = decoded.get('user_id')
 
-    return AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
+    account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
+    if account:
+        contexts.tenant_id.set(account.current_tenant_id)
+    return account
 
 
 @login_manager.unauthorized_handler
@@ -406,7 +406,6 @@ class DataSetConfig(BaseSettings):
         default=False,
     )
 
 
 class WorkspaceConfig(BaseSettings):
     """
     Workspace configs
@@ -0,0 +1,2 @@
+# TODO: Update all string in code to use this constant
+HIDDEN_VALUE = '[__HIDDEN__]'

3 api/contexts/__init__.py Normal file
@@ -0,0 +1,3 @@
+from contextvars import ContextVar
+
+tenant_id: ContextVar[str] = ContextVar('tenant_id')
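The new api/contexts/__init__.py module gives each request its own tenant scope. A minimal sketch (standard library only; names mirror the module above) of how the ContextVar behaves:

    from contextvars import ContextVar

    tenant_id: ContextVar[str] = ContextVar('tenant_id')

    tenant_id.set('tenant-123')  # bound in the current context only
    print(tenant_id.get())       # 'tenant-123'
    # A freshly started thread gets an empty context, so .get() there raises
    # LookupError unless the caller copies the context in, as the generator
    # changes later in this diff do.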
@@ -212,7 +212,7 @@ class AppCopyApi(Resource):
         parser.add_argument('icon_background', type=str, location='json')
         args = parser.parse_args()
 
-        data = AppDslService.export_dsl(app_model=app_model)
+        data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
         app = AppDslService.import_and_create_new_app(
             tenant_id=current_user.current_tenant_id,
             data=data,
@@ -234,8 +234,13 @@ class AppExportApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()
 
+        # Add include_secret params
+        parser = reqparse.RequestParser()
+        parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args')
+        args = parser.parse_args()
+
         return {
-            "data": AppDslService.export_dsl(app_model=app_model)
+            "data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret'])
         }
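With the new flag, the console export endpoint can include secret environment variables in the exported DSL. A hypothetical request (base URL, app id, route prefix, and token below are placeholders, not values from this diff):

    import requests

    resp = requests.get(
        'https://dify.example.com/console/api/apps/<app_id>/export',
        params={'include_secret': 'true'},  # parsed by inputs.boolean above
        headers={'Authorization': 'Bearer <console-access-token>'},
    )
    print(resp.json()['data'])  # DSL string, secrets included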
@@ -13,6 +13,7 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.segments import factory
 from core.errors.error import AppInvokeQuotaExceededError
 from fields.workflow_fields import workflow_fields
 from fields.workflow_run_fields import workflow_run_node_execution_fields
@@ -41,7 +42,7 @@ class DraftWorkflowApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
 
         # fetch draft workflow by app_model
         workflow_service = WorkflowService()
         workflow = workflow_service.get_draft_workflow(app_model=app_model)
@@ -64,13 +65,15 @@ class DraftWorkflowApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()
 
-        content_type = request.headers.get('Content-Type')
+        content_type = request.headers.get('Content-Type', '')
 
         if 'application/json' in content_type:
             parser = reqparse.RequestParser()
             parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
             parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
             parser.add_argument('hash', type=str, required=False, location='json')
+            # TODO: set this to required=True after frontend is updated
+            parser.add_argument('environment_variables', type=list, required=False, location='json')
             args = parser.parse_args()
         elif 'text/plain' in content_type:
             try:
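The environment_variables list accepted above is a list of variable mappings, each carrying a name, value, and value_type (consumed by factory.build_variable_from_mapping later in this diff). An illustrative draft-sync payload, with invented example values:

    payload = {
        'graph': {},     # elided
        'features': {},  # elided
        'hash': '<last-synced-hash>',
        'environment_variables': [
            {'name': 'API_BASE', 'value': 'https://api.example.com', 'value_type': 'string'},
            {'name': 'RETRY_LIMIT', 'value': 3, 'value_type': 'number'},
            {'name': 'API_KEY', 'value': 'sk-...', 'value_type': 'secret'},
        ],
    }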
@@ -84,7 +87,8 @@ class DraftWorkflowApi(Resource):
                 args = {
                     'graph': data.get('graph'),
                     'features': data.get('features'),
-                    'hash': data.get('hash')
+                    'hash': data.get('hash'),
+                    'environment_variables': data.get('environment_variables')
                 }
             except json.JSONDecodeError:
                 return {'message': 'Invalid JSON data'}, 400
@@ -94,12 +98,15 @@ class DraftWorkflowApi(Resource):
         workflow_service = WorkflowService()
 
         try:
+            environment_variables_list = args.get('environment_variables') or []
+            environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
             workflow = workflow_service.sync_draft_workflow(
                 app_model=app_model,
-                graph=args.get('graph'),
-                features=args.get('features'),
+                graph=args['graph'],
+                features=args['features'],
                 unique_hash=args.get('hash'),
-                account=current_user
+                account=current_user,
+                environment_variables=environment_variables,
             )
         except WorkflowHashNotEqualError:
             raise DraftWorkflowNotSync()

@@ -1,6 +1,7 @@
-from typing import Optional, Union
+from collections.abc import Mapping
+from typing import Any
 
-from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom
+from core.app.app_config.entities import AppAdditionalFeatures
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
 from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
@@ -10,37 +11,19 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
     SuggestedQuestionsAfterAnswerConfigManager,
 )
 from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
-from models.model import AppMode, AppModelConfig
+from models.model import AppMode
 
 
 class BaseAppConfigManager:
 
-    @classmethod
-    def convert_to_config_dict(cls, config_from: EasyUIBasedAppModelConfigFrom,
-                               app_model_config: Union[AppModelConfig, dict],
-                               config_dict: Optional[dict] = None) -> dict:
-        """
-        Convert app model config to config dict
-        :param config_from: app model config from
-        :param app_model_config: app model config
-        :param config_dict: app model config dict
-        :return:
-        """
-        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
-            app_model_config_dict = app_model_config.to_dict()
-            config_dict = app_model_config_dict.copy()
-
-        return config_dict
-
     @classmethod
-    def convert_features(cls, config_dict: dict, app_mode: AppMode) -> AppAdditionalFeatures:
+    def convert_features(cls, config_dict: Mapping[str, Any], app_mode: AppMode) -> AppAdditionalFeatures:
         """
         Convert app config to app model config
 
         :param config_dict: app config
         :param app_mode: app mode
         """
-        config_dict = config_dict.copy()
+        config_dict = dict(config_dict.items())
 
         additional_features = AppAdditionalFeatures()
         additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(

@@ -1,11 +1,12 @@
-from typing import Optional
+from collections.abc import Mapping
+from typing import Any, Optional
 
 from core.app.app_config.entities import FileExtraConfig
 
 
 class FileUploadConfigManager:
     @classmethod
-    def convert(cls, config: dict, is_vision: bool = True) -> Optional[FileExtraConfig]:
+    def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
         """
         Convert model config to model config

@@ -3,13 +3,13 @@ from core.app.app_config.entities import TextToSpeechEntity
 
 class TextToSpeechConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> bool:
+    def convert(cls, config: dict):
         """
         Convert model config to model config
 
         :param config: model config args
         """
-        text_to_speech = False
+        text_to_speech = None
         text_to_speech_dict = config.get('text_to_speech')
         if text_to_speech_dict:
             if text_to_speech_dict.get('enabled'):

@@ -1,3 +1,4 @@
+import contextvars
 import logging
 import os
 import threading
@@ -8,6 +9,7 @@ from typing import Union
 from flask import Flask, current_app
 from pydantic import ValidationError
 
+import contexts
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
@@ -107,6 +109,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             extras=extras,
             trace_manager=trace_manager
         )
+        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
 
         return self._generate(
             app_model=app_model,
@@ -173,6 +176,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                 inputs=args['inputs']
             )
         )
+        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
 
         return self._generate(
             app_model=app_model,
@@ -225,6 +229,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             'queue_manager': queue_manager,
             'conversation_id': conversation.id,
             'message_id': message.id,
+            'user': user,
+            'context': contextvars.copy_context()
         })
 
         worker_thread.start()
@@ -249,7 +255,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                          application_generate_entity: AdvancedChatAppGenerateEntity,
                          queue_manager: AppQueueManager,
                          conversation_id: str,
-                         message_id: str) -> None:
+                         message_id: str,
+                         user: Account,
+                         context: contextvars.Context) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
@@ -259,6 +267,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         :param message_id: message ID
         :return:
         """
+        for var, val in context.items():
+            var.set(val)
         with flask_app.app_context():
             try:
                 runner = AdvancedChatAppRunner()
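The worker changes above exist so that context variables set on the request thread (here contexts.tenant_id) survive into the generate worker. A standalone sketch of the same pattern, standard library only:

    import contextvars
    import threading

    tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar('tenant_id')

    def worker(context: contextvars.Context) -> None:
        # Rebind every variable captured on the request thread.
        for var, val in context.items():
            var.set(val)
        print(tenant_id.get())  # 'tenant-123'

    tenant_id.set('tenant-123')
    t = threading.Thread(target=worker, kwargs={'context': contextvars.copy_context()})
    t.start()
    t.join()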
@@ -1,7 +1,8 @@
 import logging
 import os
 import time
-from typing import Optional, cast
+from collections.abc import Mapping
+from typing import Any, Optional, cast
 
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
@@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
 from core.moderation.base import ModerationException
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.workflow_engine_manager import WorkflowEngineManager
@@ -87,7 +89,7 @@ class AdvancedChatAppRunner(AppRunner):
 
         db.session.close()
 
-        workflow_callbacks = [WorkflowEventTriggerCallback(
+        workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
             queue_manager=queue_manager,
             workflow=workflow
         )]
@@ -161,7 +163,7 @@ class AdvancedChatAppRunner(AppRunner):
             self, queue_manager: AppQueueManager,
             app_record: App,
             app_generate_entity: AdvancedChatAppGenerateEntity,
-            inputs: dict,
+            inputs: Mapping[str, Any],
             query: str,
             message_id: str
     ) -> bool:
@@ -1,9 +1,11 @@
 import json
 from collections.abc import Generator
-from typing import cast
+from typing import Any, cast
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppBlockingResponse,
+    AppStreamResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
@@ -18,12 +20,13 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
     _blocking_response_type = ChatbotAppBlockingResponse
 
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
         """
         Convert blocking full response.
         :param blocking_response: blocking response
         :return:
         """
+        blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
         response = {
             'event': 'message',
             'task_id': blocking_response.task_id,
@@ -39,7 +42,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response
 
     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
         """
         Convert blocking simple response.
         :param blocking_response: blocking response
@@ -53,8 +56,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response
 
     @classmethod
-    def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -83,8 +85,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         yield json.dumps(response_chunk)
 
     @classmethod
-    def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -118,7 +118,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         self._stream_generate_routes = self._get_stream_generate_routes()
         self._conversation_name_generate_thread = None
 
-    def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
+    def process(self):
         """
         Process generate task pipeline.
         :return:
@@ -141,8 +141,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         else:
             return self._to_blocking_response(generator)
 
-    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
-            -> ChatbotAppBlockingResponse:
+    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
         """
         Process blocking response.
         :return:
@@ -172,8 +171,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
         raise Exception('Queue listening stopped unexpectedly.')
 
-    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-            -> Generator[ChatbotAppStreamResponse, None, None]:
+    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
         """
         To stream response.
         :return:
@@ -14,13 +14,13 @@ from core.app.entities.queue_entities import (
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
-from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeType
 from models.workflow import Workflow
 
 
-class WorkflowEventTriggerCallback(BaseWorkflowCallback):
+class WorkflowEventTriggerCallback(WorkflowCallback):
 
     def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
         self._queue_manager = queue_manager
@@ -1,7 +1,7 @@
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Generator
-from typing import Union
+from typing import Any, Union
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -15,44 +15,41 @@ class AppGenerateResponseConverter(ABC):
     @classmethod
     def convert(cls, response: Union[
         AppBlockingResponse,
-        Generator[AppStreamResponse, None, None]
-    ], invoke_from: InvokeFrom) -> Union[
-        dict,
-        Generator[str, None, None]
-    ]:
+        Generator[AppStreamResponse, Any, None]
+    ], invoke_from: InvokeFrom):
         if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
-            if isinstance(response, cls._blocking_response_type):
+            if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_full_response(response)
             else:
-                def _generate():
+                def _generate_full_response() -> Generator[str, Any, None]:
                     for chunk in cls.convert_stream_full_response(response):
                         if chunk == 'ping':
                             yield f'event: {chunk}\n\n'
                         else:
                             yield f'data: {chunk}\n\n'
 
-                return _generate()
+                return _generate_full_response()
         else:
-            if isinstance(response, cls._blocking_response_type):
+            if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_simple_response(response)
             else:
-                def _generate():
+                def _generate_simple_response() -> Generator[str, Any, None]:
                     for chunk in cls.convert_stream_simple_response(response):
                         if chunk == 'ping':
                             yield f'event: {chunk}\n\n'
                         else:
                             yield f'data: {chunk}\n\n'
 
-                return _generate()
+                return _generate_simple_response()
 
     @classmethod
     @abstractmethod
-    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict:
+    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
        raise NotImplementedError
 
     @classmethod
     @abstractmethod
-    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict:
+    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
         raise NotImplementedError
 
     @classmethod
@@ -68,7 +65,7 @@ class AppGenerateResponseConverter(ABC):
         raise NotImplementedError
 
     @classmethod
-    def _get_simple_metadata(cls, metadata: dict) -> dict:
+    def _get_simple_metadata(cls, metadata: dict[str, Any]):
         """
         Get simple metadata.
         :param metadata: metadata
@@ -1,3 +1,4 @@
+import contextvars
 import logging
 import os
 import threading
@@ -8,6 +9,7 @@ from typing import Union
 from flask import Flask, current_app
 from pydantic import ValidationError
 
+import contexts
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.base_app_generator import BaseAppGenerator
 from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
@@ -38,7 +40,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
                  invoke_from: InvokeFrom,
                  stream: bool = True,
                  call_depth: int = 0,
-                 ) -> Union[dict, Generator[dict, None, None]]:
+                 ):
         """
         Generate App response.
@@ -86,6 +88,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             call_depth=call_depth,
             trace_manager=trace_manager
         )
+        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
 
         return self._generate(
             app_model=app_model,
@@ -126,7 +129,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
         worker_thread = threading.Thread(target=self._generate_worker, kwargs={
             'flask_app': current_app._get_current_object(),
             'application_generate_entity': application_generate_entity,
-            'queue_manager': queue_manager
+            'queue_manager': queue_manager,
+            'context': contextvars.copy_context()
         })
 
         worker_thread.start()
@@ -150,8 +154,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
                                     node_id: str,
                                     user: Account,
                                     args: dict,
-                                    stream: bool = True) \
-            -> Union[dict, Generator[dict, None, None]]:
+                                    stream: bool = True):
         """
         Generate App response.
@@ -193,6 +196,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
                 inputs=args['inputs']
             )
         )
+        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
 
         return self._generate(
             app_model=app_model,
@@ -205,7 +209,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
 
     def _generate_worker(self, flask_app: Flask,
                          application_generate_entity: WorkflowAppGenerateEntity,
-                         queue_manager: AppQueueManager) -> None:
+                         queue_manager: AppQueueManager,
+                         context: contextvars.Context) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
@@ -213,6 +218,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
         :param queue_manager: queue manager
         :return:
         """
+        for var, val in context.items():
+            var.set(val)
         with flask_app.app_context():
             try:
                 # workflow app
@@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
     InvokeFrom,
     WorkflowAppGenerateEntity,
 )
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.workflow_engine_manager import WorkflowEngineManager
@@ -57,7 +58,7 @@ class WorkflowAppRunner:
 
         db.session.close()
 
-        workflow_callbacks = [WorkflowEventTriggerCallback(
+        workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
             queue_manager=queue_manager,
             workflow=workflow
         )]

@@ -14,13 +14,13 @@ from core.app.entities.queue_entities import (
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
-from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeType
 from models.workflow import Workflow
 
 
-class WorkflowEventTriggerCallback(BaseWorkflowCallback):
+class WorkflowEventTriggerCallback(WorkflowCallback):
 
     def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
         self._queue_manager = queue_manager
@@ -2,7 +2,7 @@ from typing import Optional
 
 from core.app.entities.queue_entities import AppQueueEvent
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeType
@@ -15,7 +15,7 @@ _TEXT_COLOR_MAPPING = {
 }
 
 
-class WorkflowLoggingCallback(BaseWorkflowCallback):
+class WorkflowLoggingCallback(WorkflowCallback):
 
     def __init__(self) -> None:
         self.current_node_id = None
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from enum import Enum
 from typing import Any, Optional
@@ -76,7 +77,7 @@ class AppGenerateEntity(BaseModel):
     # app config
     app_config: AppConfig
 
-    inputs: dict[str, Any]
+    inputs: Mapping[str, Any]
     files: list[FileVar] = []
     user_id: str
@@ -140,7 +141,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
     app_config: WorkflowUIBasedAppConfig
 
     conversation_id: Optional[str] = None
-    query: Optional[str] = None
+    query: str
 
     class SingleIterationRunEntity(BaseModel):
         """
27 api/core/app/segments/__init__.py Normal file
@@ -0,0 +1,27 @@
+from .segment_group import SegmentGroup
+from .segments import Segment
+from .types import SegmentType
+from .variables import (
+    ArrayVariable,
+    FileVariable,
+    FloatVariable,
+    IntegerVariable,
+    ObjectVariable,
+    SecretVariable,
+    StringVariable,
+    Variable,
+)
+
+__all__ = [
+    'IntegerVariable',
+    'FloatVariable',
+    'ObjectVariable',
+    'SecretVariable',
+    'FileVariable',
+    'StringVariable',
+    'ArrayVariable',
+    'Variable',
+    'SegmentType',
+    'SegmentGroup',
+    'Segment'
+]
64 api/core/app/segments/factory.py Normal file
@@ -0,0 +1,64 @@
+from collections.abc import Mapping
+from typing import Any
+
+from core.file.file_obj import FileVar
+
+from .segments import Segment, StringSegment
+from .types import SegmentType
+from .variables import (
+    ArrayVariable,
+    FileVariable,
+    FloatVariable,
+    IntegerVariable,
+    ObjectVariable,
+    SecretVariable,
+    StringVariable,
+    Variable,
+)
+
+
+def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
+    if (value_type := m.get('value_type')) is None:
+        raise ValueError('missing value type')
+    if not m.get('name'):
+        raise ValueError('missing name')
+    if (value := m.get('value')) is None:
+        raise ValueError('missing value')
+    match value_type:
+        case SegmentType.STRING:
+            return StringVariable.model_validate(m)
+        case SegmentType.NUMBER if isinstance(value, int):
+            return IntegerVariable.model_validate(m)
+        case SegmentType.NUMBER if isinstance(value, float):
+            return FloatVariable.model_validate(m)
+        case SegmentType.SECRET:
+            return SecretVariable.model_validate(m)
+        case SegmentType.NUMBER if not isinstance(value, float | int):
+            raise ValueError(f'invalid number value {value}')
+    raise ValueError(f'not supported value type {value_type}')
+
+
+def build_anonymous_variable(value: Any, /) -> Variable:
+    if isinstance(value, str):
+        return StringVariable(name='anonymous', value=value)
+    if isinstance(value, int):
+        return IntegerVariable(name='anonymous', value=value)
+    if isinstance(value, float):
+        return FloatVariable(name='anonymous', value=value)
+    if isinstance(value, dict):
+        # TODO: Limit the depth of the object
+        obj = {k: build_anonymous_variable(v) for k, v in value.items()}
+        return ObjectVariable(name='anonymous', value=obj)
+    if isinstance(value, list):
+        # TODO: Limit the depth of the array
+        elements = [build_anonymous_variable(v) for v in value]
+        return ArrayVariable(name='anonymous', value=elements)
+    if isinstance(value, FileVar):
+        return FileVariable(name='anonymous', value=value)
+    raise ValueError(f'not supported value {value}')
+
+
+def build_segment(value: Any, /) -> Segment:
+    if isinstance(value, str):
+        return StringSegment(value=value)
+    raise ValueError(f'not supported value {value}')
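A quick sketch of the factory in use (example names and values invented for illustration). Note that SegmentType is a str enum, so the plain 'number' string from an API payload matches the enum case:

    from core.app.segments import factory

    env_var = factory.build_variable_from_mapping(
        {'name': 'RETRY_LIMIT', 'value': 3, 'value_type': 'number'}
    )
    # -> IntegerVariable(name='RETRY_LIMIT', value=3)

    anon = factory.build_anonymous_variable({'host': 'example.com', 'port': 443})
    # -> ObjectVariable whose values are themselves Variables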
17 api/core/app/segments/parser.py Normal file
@@ -0,0 +1,17 @@
+import re
+
+from core.app.segments import SegmentGroup, factory
+from core.workflow.entities.variable_pool import VariablePool
+
+VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
+
+
+def convert_template(*, template: str, variable_pool: VariablePool):
+    parts = re.split(VARIABLE_PATTERN, template)
+    segments = []
+    for part in parts:
+        if '.' in part and (value := variable_pool.get(part.split('.'))):
+            segments.append(value)
+        else:
+            segments.append(factory.build_segment(part))
+    return SegmentGroup(segments=segments)
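convert_template splits a prompt on {{#node_id.key#}} markers and resolves each marker against the pool; unmatched text falls through to plain string segments. A sketch (pool construction elided; see VariablePool later in this diff, and note the variable name here is illustrative):

    from core.app.segments.parser import convert_template

    # Assume `pool` already holds the environment variable selected by ('env', 'API_KEY').
    group = convert_template(template='key={{#env.API_KEY#}}', variable_pool=pool)
    group.text  # the raw value, for runtime use
    group.log   # same string, but secrets render obfuscated via SecretVariable.log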
19 api/core/app/segments/segment_group.py Normal file
@@ -0,0 +1,19 @@
+from pydantic import BaseModel
+
+from .segments import Segment
+
+
+class SegmentGroup(BaseModel):
+    segments: list[Segment]
+
+    @property
+    def text(self):
+        return ''.join([segment.text for segment in self.segments])
+
+    @property
+    def log(self):
+        return ''.join([segment.log for segment in self.segments])
+
+    @property
+    def markdown(self):
+        return ''.join([segment.markdown for segment in self.segments])
39 api/core/app/segments/segments.py Normal file
@@ -0,0 +1,39 @@
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict, field_validator
+
+from .types import SegmentType
+
+
+class Segment(BaseModel):
+    model_config = ConfigDict(frozen=True)
+
+    value_type: SegmentType
+    value: Any
+
+    @field_validator('value_type')
+    def validate_value_type(cls, value):
+        """
+        This validator checks if the provided value is equal to the default value of the 'value_type' field.
+        If the value is different, a ValueError is raised.
+        """
+        if value != cls.model_fields['value_type'].default:
+            raise ValueError("Cannot modify 'value_type'")
+        return value
+
+    @property
+    def text(self) -> str:
+        return str(self.value)
+
+    @property
+    def log(self) -> str:
+        return str(self.value)
+
+    @property
+    def markdown(self) -> str:
+        return str(self.value)
+
+
+class StringSegment(Segment):
+    value_type: SegmentType = SegmentType.STRING
+    value: str
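Segments are frozen pydantic models, and the value_type validator stops callers from overriding a subclass's declared type. A small sketch of both guards:

    from core.app.segments.segments import StringSegment

    s = StringSegment(value='hello')
    s.text  # 'hello'

    # Both of the following raise pydantic.ValidationError:
    # StringSegment(value='hello', value_type='number')  # validator rejects the override
    # s.value = 'other'                                  # ConfigDict(frozen=True) forbids mutation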
17 api/core/app/segments/types.py Normal file
@@ -0,0 +1,17 @@
+from enum import Enum
+
+
+class SegmentType(str, Enum):
+    STRING = 'string'
+    NUMBER = 'number'
+    FILE = 'file'
+
+    SECRET = 'secret'
+
+    OBJECT = 'object'
+
+    ARRAY = 'array'
+    ARRAY_STRING = 'array[string]'
+    ARRAY_NUMBER = 'array[number]'
+    ARRAY_OBJECT = 'array[object]'
+    ARRAY_FILE = 'array[file]'
83 api/core/app/segments/variables.py Normal file
@@ -0,0 +1,83 @@
+import json
+from collections.abc import Mapping, Sequence
+
+from pydantic import Field
+
+from core.file.file_obj import FileVar
+from core.helper import encrypter
+
+from .segments import Segment, StringSegment
+from .types import SegmentType
+
+
+class Variable(Segment):
+    """
+    A variable is a segment that has a name.
+    """
+
+    id: str = Field(
+        default='',
+        description="Unique identity for variable. It's only used by environment variables now.",
+    )
+    name: str
+
+
+class StringVariable(StringSegment, Variable):
+    pass
+
+
+class FloatVariable(Variable):
+    value_type: SegmentType = SegmentType.NUMBER
+    value: float
+
+
+class IntegerVariable(Variable):
+    value_type: SegmentType = SegmentType.NUMBER
+    value: int
+
+
+class ObjectVariable(Variable):
+    value_type: SegmentType = SegmentType.OBJECT
+    value: Mapping[str, Variable]
+
+    @property
+    def text(self) -> str:
+        # TODO: Process variables.
+        return json.dumps(self.model_dump()['value'], ensure_ascii=False)
+
+    @property
+    def log(self) -> str:
+        # TODO: Process variables.
+        return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
+
+    @property
+    def markdown(self) -> str:
+        # TODO: Use markdown code block
+        return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
+
+
+class ArrayVariable(Variable):
+    value_type: SegmentType = SegmentType.ARRAY
+    value: Sequence[Variable]
+
+    @property
+    def markdown(self) -> str:
+        return '\n'.join(['- ' + item.markdown for item in self.value])
+
+
+class FileVariable(Variable):
+    value_type: SegmentType = SegmentType.FILE
+    # TODO: embed FileVar in this model.
+    value: FileVar
+
+    @property
+    def markdown(self) -> str:
+        return self.value.to_markdown()
+
+
+class SecretVariable(StringVariable):
+    value_type: SegmentType = SegmentType.SECRET
+
+    @property
+    def log(self) -> str:
+        return encrypter.obfuscated_token(self.value)
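The only behavioral difference of SecretVariable is its log representation. A sketch with an invented token value:

    from core.app.segments import SecretVariable

    secret = SecretVariable(name='API_KEY', value='sk-1234567890abcdef')
    secret.text  # 'sk-1234567890abcdef'  (real value, used at runtime)
    secret.log   # 'sk-123************ef' (masked via encrypter.obfuscated_token)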
@@ -1,9 +1,11 @@
 import os
+from collections.abc import Mapping, Sequence
 from typing import Any, Optional, TextIO, Union
 
 from pydantic import BaseModel
 
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.tools.entities.tool_entities import ToolInvokeMessage
 
 _TEXT_COLOR_MAPPING = {
     "blue": "36;1",
@@ -43,7 +45,7 @@ class DifyAgentCallbackHandler(BaseModel):
     def on_tool_start(
         self,
         tool_name: str,
-        tool_inputs: dict[str, Any],
+        tool_inputs: Mapping[str, Any],
     ) -> None:
         """Do nothing."""
         print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
@@ -51,8 +53,8 @@ class DifyAgentCallbackHandler(BaseModel):
     def on_tool_end(
         self,
         tool_name: str,
-        tool_inputs: dict[str, Any],
-        tool_outputs: str,
+        tool_inputs: Mapping[str, Any],
+        tool_outputs: Sequence[ToolInvokeMessage],
         message_id: Optional[str] = None,
         timer: Optional[Any] = None,
         trace_manager: Optional[TraceQueueManager] = None

@@ -1,4 +1,5 @@
-from typing import Union
+from collections.abc import Mapping, Sequence
+from typing import Any, Union
 
 import requests
@@ -16,7 +17,7 @@ class MessageFileParser:
         self.tenant_id = tenant_id
         self.app_id = app_id
 
-    def validate_and_transform_files_arg(self, files: list[dict], file_extra_config: FileExtraConfig,
+    def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig,
                                          user: Union[Account, EndUser]) -> list[FileVar]:
         """
         validate and transform files arg
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
 CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
 CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY
 
-CODE_EXECUTION_TIMEOUT= (10, 60)
+CODE_EXECUTION_TIMEOUT = (10, 60)
 
 class CodeExecutionException(Exception):
     pass
@@ -64,7 +64,7 @@ class CodeExecutor:
 
     @classmethod
     def execute_code(cls,
-                     language: Literal['python3', 'javascript', 'jinja2'],
+                     language: CodeLanguage,
                      preload: str,
                      code: str,
                      dependencies: Optional[list[CodeDependency]] = None) -> str:
@@ -119,7 +119,7 @@ class CodeExecutor:
         return response.data.stdout
 
     @classmethod
-    def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
+    def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
         """
         Execute code
         :param language: code language
@@ -6,11 +6,16 @@ from models.account import Tenant
 
 
 def obfuscated_token(token: str):
-    return token[:6] + '*' * (len(token) - 8) + token[-2:]
+    if not token:
+        return token
+    if len(token) <= 8:
+        return '*' * 20
+    return token[:6] + '*' * 12 + token[-2:]
 
 
 def encrypt_token(tenant_id: str, token: str):
-    tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
+    if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
+        raise ValueError(f'Tenant with id {tenant_id} not found')
     encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
     return base64.b64encode(encrypted_token).decode()
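The obfuscated_token rewrite fixes two leaks: the fixed-width mask no longer reveals the token's length, and short or empty tokens no longer slip through the slicing. Worked examples (values invented):

    obfuscated_token('sk-1234567890abcdef')
    # old: 'sk-123***********ef'   (11 asterisks, leaking that len(token) == 19)
    # new: 'sk-123************ef'  (always 12 asterisks)

    obfuscated_token('short')
    # old: 'shortrt'               ('*' * (5 - 8) is empty, so the token leaked)
    # new: '********************'  (20 asterisks for anything 8 chars or fewer)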
@@ -413,6 +413,7 @@ class LBModelManager:
         for load_balancing_config in self._load_balancing_configs:
             if load_balancing_config.name == "__inherit__":
                 if not managed_credentials:
+                    # FIXME: Mutation to loop iterable `self._load_balancing_configs` during iteration
                     # remove __inherit__ if managed credentials is not provided
                     self._load_balancing_configs.remove(load_balancing_config)
                 else:
@@ -501,7 +501,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             sub_messages.append(sub_message_dict)
             message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
-            message = cast(AssistantPromptMessage, message)
+            # message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
                 message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
@@ -42,6 +42,7 @@ class BaseKeyword(ABC):
             doc_id = text.metadata['doc_id']
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:
+                # FIXME: Mutation to loop iterable `texts` during iteration
                 texts.remove(text)
 
         return texts
@@ -61,6 +61,7 @@ class BaseVector(ABC):
             doc_id = text.metadata['doc_id']
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:
+                # FIXME: Mutation to loop iterable `texts` during iteration
                 texts.remove(text)
 
         return texts
@@ -157,6 +157,7 @@ class Vector:
             doc_id = text.metadata['doc_id']
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:
+                # FIXME: Mutation to loop iterable `texts` during iteration
                 texts.remove(text)
 
         return texts
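The three FIXME comments above flag the same bug pattern: calling texts.remove() while iterating over texts skips the element after each removal. A sketch of the usual fix, filtering into a new list instead of mutating in place (illustrative only, not the change made in this commit):

    def deduplicate(self, texts: list) -> list:
        # Build a new list rather than removing from `texts` mid-iteration.
        return [text for text in texts
                if not self.text_exists(text.metadata['doc_id'])]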
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Mapping
 from copy import deepcopy
 from enum import Enum
 from typing import Any, Optional, Union
@@ -190,8 +191,9 @@ class Tool(BaseModel, ABC):
 
         return result
 
-    def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
+    def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]:
         # update tool_parameters
+        # TODO: Fix type error.
         if self.runtime.runtime_parameters:
             tool_parameters.update(self.runtime.runtime_parameters)
@@ -208,7 +210,7 @@ class Tool(BaseModel, ABC):
 
         return result
 
-    def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
+    def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> dict[str, Any]:
         """
         Transform tool parameters type
         """
@@ -241,7 +243,7 @@ class Tool(BaseModel, ABC):
 
         :return: the runtime parameters
         """
-        return self.parameters
+        return self.parameters or []
 
     def get_all_runtime_parameters(self) -> list[ToolParameter]:
         """
@@ -1,4 +1,5 @@
 import json
+from collections.abc import Mapping
 from copy import deepcopy
 from datetime import datetime, timezone
 from mimetypes import guess_type
@@ -46,7 +47,7 @@ class ToolEngine:
         if isinstance(tool_parameters, str):
             # check if this tool has only one parameter
             parameters = [
-                parameter for parameter in tool.get_runtime_parameters()
+                parameter for parameter in tool.get_runtime_parameters() or []
                 if parameter.form == ToolParameter.ToolParameterForm.LLM
             ]
             if parameters and len(parameters) == 1:
@@ -123,8 +124,8 @@ class ToolEngine:
         return error_response, [], ToolInvokeMeta.error_instance(error_response)
 
     @staticmethod
-    def workflow_invoke(tool: Tool, tool_parameters: dict,
-                        user_id: str, workflow_id: str,
+    def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any],
+                        user_id: str,
                         workflow_tool_callback: DifyWorkflowCallbackHandler,
                         workflow_call_depth: int,
                        ) -> list[ToolInvokeMessage]:
@@ -141,7 +142,9 @@ class ToolEngine:
         if isinstance(tool, WorkflowTool):
             tool.workflow_call_depth = workflow_call_depth + 1
 
-        response = tool.invoke(user_id, tool_parameters)
+        if tool.runtime and tool.runtime.runtime_parameters:
+            tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
+        response = tool.invoke(user_id=user_id, tool_parameters=tool_parameters)
 
         # hit the callback handler
         workflow_tool_callback.on_tool_end(
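The new call site merges runtime parameters into a fresh dict before invoking, with the caller's tool_parameters winning on key collisions (later entries in a dict literal override earlier ones). Illustrative values:

    runtime  = {'temperature': 0.2, 'api_base': 'https://api.example.com'}
    supplied = {'temperature': 0.7}

    merged = {**runtime, **supplied}
    # {'temperature': 0.7, 'api_base': 'https://api.example.com'}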
@@ -6,7 +6,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeType
 
 
-class BaseWorkflowCallback(ABC):
+class WorkflowCallback(ABC):
     @abstractmethod
     def on_workflow_run_started(self) -> None:
         """
@@ -78,7 +78,7 @@ class BaseWorkflowCallback(ABC):
                            node_type: NodeType,
                            node_run_index: int = 1,
                            node_data: Optional[BaseNodeData] = None,
-                           inputs: dict = None,
+                           inputs: Optional[dict] = None,
                            predecessor_node_id: Optional[str] = None,
                            metadata: Optional[dict] = None) -> None:
         """

@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from enum import Enum
 from typing import Any, Optional
@@ -82,9 +83,9 @@ class NodeRunResult(BaseModel):
     """
     status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
 
-    inputs: Optional[dict] = None  # node inputs
+    inputs: Optional[Mapping[str, Any]] = None  # node inputs
     process_data: Optional[dict] = None  # process data
-    outputs: Optional[dict] = None  # node outputs
+    outputs: Optional[Mapping[str, Any]] = None  # node outputs
     metadata: Optional[dict[NodeRunMetadataKey, Any]] = None  # node metadata
 
     edge_source_handle: Optional[str] = None  # source handle id of node with multiple branches
@@ -1,101 +1,141 @@
-from enum import Enum
-from typing import Any, Optional, Union
+from collections import defaultdict
+from collections.abc import Mapping, Sequence
+from typing import Any, Union
+
+from typing_extensions import deprecated
 
+from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory
 from core.file.file_obj import FileVar
 from core.workflow.entities.node_entities import SystemVariable
 
 VariableValue = Union[str, int, float, dict, list, FileVar]
 
-
-class ValueType(Enum):
-    """
-    Value Type Enum
-    """
-    STRING = "string"
-    NUMBER = "number"
-    OBJECT = "object"
-    ARRAY_STRING = "array[string]"
-    ARRAY_NUMBER = "array[number]"
-    ARRAY_OBJECT = "array[object]"
-    ARRAY_FILE = "array[file]"
-    FILE = "file"
+SYSTEM_VARIABLE_NODE_ID = 'sys'
+ENVIRONMENT_VARIABLE_NODE_ID = 'env'
 
 
 class VariablePool:
 
-    def __init__(self, system_variables: dict[SystemVariable, Any],
-                 user_inputs: dict) -> None:
+    def __init__(
+        self,
+        system_variables: Mapping[SystemVariable, Any],
+        user_inputs: Mapping[str, Any],
+        environment_variables: Sequence[Variable],
+    ) -> None:
         # system variables
         # for example:
         # {
         #     'query': 'abc',
         #     'files': []
         # }
-        self.variables_mapping = {}
+
+        # Variable dictionary is a dictionary for looking up variables by their selector.
+        # The first element of the selector is the node id, it's the first-level key in the dictionary.
+        # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
+        # elements of the selector except the first one.
+        self._variable_dictionary: dict[str, dict[int, Variable]] = defaultdict(dict)
+
+        # TODO: This user inputs is not used for pool.
         self.user_inputs = user_inputs
 
         # Add system variables to the variable pool
         self.system_variables = system_variables
-        for system_variable, value in system_variables.items():
-            self.append_variable('sys', [system_variable.value], value)
+        for key, value in system_variables.items():
+            self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
 
-    def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
+        # Add environment variables to the variable pool
+        for var in environment_variables or []:
+            self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
+
+    def add(self, selector: Sequence[str], value: Any, /) -> None:
         """
-        Append variable
-        :param node_id: node id
-        :param variable_key_list: variable key list, like: ['result', 'text']
-        :param value: value
-        :return:
+        Adds a variable to the variable pool.
+
+        Args:
+            selector (Sequence[str]): The selector for the variable.
+            value (VariableValue): The value of the variable.
+
+        Raises:
+            ValueError: If the selector is invalid.
+
+        Returns:
+            None
         """
-        if node_id not in self.variables_mapping:
-            self.variables_mapping[node_id] = {}
+        if len(selector) < 2:
+            raise ValueError('Invalid selector')
 
-        variable_key_list_hash = hash(tuple(variable_key_list))
+        if value is None:
+            return
 
-        self.variables_mapping[node_id][variable_key_list_hash] = value
+        if not isinstance(value, Variable):
+            v = factory.build_anonymous_variable(value)
+        else:
+            v = value
 
-    def get_variable_value(self, variable_selector: list[str],
-                           target_value_type: Optional[ValueType] = None) -> Optional[VariableValue]:
+        hash_key = hash(tuple(selector[1:]))
+        self._variable_dictionary[selector[0]][hash_key] = v
+
+    def get(self, selector: Sequence[str], /) -> Variable | None:
         """
-        Get variable
-        :param variable_selector: include node_id and variables
-        :param target_value_type: target value type
-        :return:
+        Retrieves the value from the variable pool based on the given selector.
+
+        Args:
+            selector (Sequence[str]): The selector used to identify the variable.
+
+        Returns:
+            Any: The value associated with the given selector.
+
+        Raises:
+            ValueError: If the selector is invalid.
         """
-        if len(variable_selector) < 2:
-            raise ValueError('Invalid value selector')
-
-        node_id = variable_selector[0]
-        if node_id not in self.variables_mapping:
-            return None
-
-        # fetch variable keys, pop node_id
-        variable_key_list = variable_selector[1:]
-
-        variable_key_list_hash = hash(tuple(variable_key_list))
-
-        value = self.variables_mapping[node_id].get(variable_key_list_hash)
-
-        if target_value_type:
-            if target_value_type == ValueType.STRING:
-                return str(value)
-            elif target_value_type == ValueType.NUMBER:
-                return int(value)
-            elif target_value_type == ValueType.OBJECT:
-                if not isinstance(value, dict):
-                    raise ValueError('Invalid value type: object')
-            elif target_value_type in [ValueType.ARRAY_STRING,
-                                       ValueType.ARRAY_NUMBER,
-                                       ValueType.ARRAY_OBJECT,
-                                       ValueType.ARRAY_FILE]:
-                if not isinstance(value, list):
-                    raise ValueError(f'Invalid value type: {target_value_type.value}')
+        if len(selector) < 2:
+            raise ValueError('Invalid selector')
+        hash_key = hash(tuple(selector[1:]))
+        value = self._variable_dictionary[selector[0]].get(hash_key)
 
         return value
 
-    def clear_node_variables(self, node_id: str) -> None:
+    @deprecated('This method is deprecated, use `get` instead.')
+    def get_any(self, selector: Sequence[str], /) -> Any | None:
         """
-        Clear node variables
-        :param node_id: node id
-        :return:
+        Retrieves the value from the variable pool based on the given selector.
+
+        Args:
+            selector (Sequence[str]): The selector used to identify the variable.
+
+        Returns:
+            Any: The value associated with the given selector.
+
+        Raises:
+            ValueError: If the selector is invalid.
         """
-        if node_id in self.variables_mapping:
-            self.variables_mapping.pop(node_id)
+        if len(selector) < 2:
+            raise ValueError('Invalid selector')
+        hash_key = hash(tuple(selector[1:]))
+        value = self._variable_dictionary[selector[0]].get(hash_key)
+
+        if value is None:
+            return value
+        if isinstance(value, ArrayVariable):
+            return [element.value for element in value.value]
+        if isinstance(value, ObjectVariable):
+            return {k: v.value for k, v in value.value.items()}
+        return value.value if value else None
+
+    def remove(self, selector: Sequence[str], /):
+        """
+        Remove variables from the variable pool based on the given selector.
+
+        Args:
+            selector (Sequence[str]): A sequence of strings representing the selector.
+
+        Returns:
+            None
+        """
+        if not selector:
+            return
+        if len(selector) == 1:
+            self._variable_dictionary[selector[0]] = {}
+            return
+        hash_key = hash(tuple(selector[1:]))
+        self._variable_dictionary[selector[0]].pop(hash_key, None)
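Putting the rewritten pool together: environment variables land under the reserved 'env' node id and are read back with the same selector shape as node outputs. A sketch (system variable values abbreviated; names follow the code above):

    from core.app.segments import SecretVariable
    from core.workflow.entities.node_entities import SystemVariable
    from core.workflow.entities.variable_pool import VariablePool

    pool = VariablePool(
        system_variables={SystemVariable.QUERY: 'hi'},
        user_inputs={},
        environment_variables=[SecretVariable(name='API_KEY', value='sk-...')],
    )

    pool.get(('env', 'API_KEY'))       # -> SecretVariable; its .log is obfuscated
    pool.add(('node1', 'result'), 42)  # plain values are wrapped, here into IntegerVariable
    pool.get_any(('node1', 'result'))  # -> 42 (deprecated plain-value accessor)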
@@ -1,7 +1,5 @@
-import json
 from typing import cast
 
-from core.file.file_obj import FileVar
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
@@ -19,7 +17,7 @@ from models.workflow import WorkflowNodeExecutionStatus
 
 class AnswerNode(BaseNode):
     _node_data_cls = AnswerNodeData
-    node_type = NodeType.ANSWER
+    _node_type: NodeType = NodeType.ANSWER
 
     def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
@@ -28,7 +26,7 @@ class AnswerNode(BaseNode):
         :return:
         """
         node_data = self.node_data
-        node_data = cast(self._node_data_cls, node_data)
+        node_data = cast(AnswerNodeData, node_data)
 
         # generate routes
         generate_routes = self.extract_generate_route_from_node_data(node_data)
@@ -38,31 +36,9 @@ class AnswerNode(BaseNode):
             if part.type == "var":
                 part = cast(VarGenerateRouteChunk, part)
                 value_selector = part.value_selector
-                value = variable_pool.get_variable_value(
-                    variable_selector=value_selector
-                )
-
-                text = ''
-                if isinstance(value, str | int | float):
-                    text = str(value)
-                elif isinstance(value, dict):
-                    # other types
-                    text = json.dumps(value, ensure_ascii=False)
-                elif isinstance(value, FileVar):
-                    # convert file to markdown
-                    text = value.to_markdown()
-                elif isinstance(value, list):
-                    for item in value:
-                        if isinstance(item, FileVar):
-                            text += item.to_markdown() + ' '
-
-                    text = text.strip()
-
-                if not text and value:
-                    # other types
-                    text = json.dumps(value, ensure_ascii=False)
-
-                answer += text
+                value = variable_pool.get(value_selector)
+                if value:
+                    answer += value.markdown
             else:
                 part = cast(TextGenerateRouteChunk, part)
                 answer += part.text
@@ -82,7 +58,7 @@ class AnswerNode(BaseNode):
         :return:
         """
         node_data = cls._node_data_cls(**config.get("data", {}))
-        node_data = cast(cls._node_data_cls, node_data)
+        node_data = cast(AnswerNodeData, node_data)
 
         return cls.extract_generate_route_from_node_data(node_data)
@@ -143,7 +119,7 @@ class AnswerNode(BaseNode):
         :return:
         """
         node_data = node_data
-        node_data = cast(cls._node_data_cls, node_data)
+        node_data = cast(AnswerNodeData, node_data)
 
         variable_template_parser = VariableTemplateParser(template=node_data.answer)
         variable_selectors = variable_template_parser.extract_variable_selectors()
@@ -1,9 +1,10 @@
 from abc import ABC, abstractmethod
+from collections.abc import Mapping, Sequence
 from enum import Enum
-from typing import Optional
+from typing import Any, Optional
 
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
@@ -46,7 +47,7 @@ class BaseNode(ABC):
     node_data: BaseNodeData
     node_run_result: Optional[NodeRunResult] = None
 
-    callbacks: list[BaseWorkflowCallback]
+    callbacks: Sequence[WorkflowCallback]
 
     def __init__(self, tenant_id: str,
                  app_id: str,
@@ -54,8 +55,8 @@ class BaseNode(ABC):
                  user_id: str,
                  user_from: UserFrom,
                  invoke_from: InvokeFrom,
-                 config: dict,
-                 callbacks: list[BaseWorkflowCallback] = None,
+                 config: Mapping[str, Any],
+                 callbacks: Sequence[WorkflowCallback] | None = None,
                 workflow_call_depth: int = 0) -> None:
         self.tenant_id = tenant_id
         self.app_id = app_id
@@ -65,7 +66,8 @@ class BaseNode(ABC):
         self.invoke_from = invoke_from
         self.workflow_call_depth = workflow_call_depth
 
-        self.node_id = config.get("id")
+        # TODO: May need to check if key exists.
+        self.node_id = config["id"]
         if not self.node_id:
             raise ValueError("Node ID is required.")
@@ -59,11 +59,8 @@ class CodeNode(BaseNode):
        variables = {}
        for variable_selector in node_data.variables:
            variable = variable_selector.variable
            value = variable_pool.get_variable_value(
                variable_selector=variable_selector.value_selector
            )

            variables[variable] = value
            value = variable_pool.get(variable_selector.value_selector)
            variables[variable] = value.value if value else None
        # Run code
        try:
            result = CodeExecutor.execute_workflow_code_template(

@@ -10,7 +10,7 @@ from models.workflow import WorkflowNodeExecutionStatus

class EndNode(BaseNode):
    _node_data_cls = EndNodeData
    node_type = NodeType.END
    _node_type = NodeType.END

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
@@ -19,16 +19,13 @@ class EndNode(BaseNode):
        :return:
        """
        node_data = self.node_data
        node_data = cast(self._node_data_cls, node_data)
        node_data = cast(EndNodeData, node_data)
        output_variables = node_data.outputs

        outputs = {}
        for variable_selector in output_variables:
            value = variable_pool.get_variable_value(
                variable_selector=variable_selector.value_selector
            )

            outputs[variable_selector.variable] = value
            value = variable_pool.get(variable_selector.value_selector)
            outputs[variable_selector.variable] = value.value if value else None

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -45,7 +42,7 @@ class EndNode(BaseNode):
        :return:
        """
        node_data = cls._node_data_cls(**config.get("data", {}))
        node_data = cast(cls._node_data_cls, node_data)
        node_data = cast(EndNodeData, node_data)

        return cls.extract_generate_nodes_from_node_data(graph, node_data)

@@ -57,7 +54,7 @@ class EndNode(BaseNode):
        :param node_data: node data object
        :return:
        """
        nodes = graph.get('nodes')
        nodes = graph.get('nodes', [])
        node_mapping = {node.get('id'): node for node in nodes}

        variable_selectors = node_data.outputs

@@ -9,7 +9,7 @@ import httpx
import core.helper.ssrf_proxy as ssrf_proxy
from configs import dify_config
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request.entities import (
    HttpRequestNodeAuthorization,
    HttpRequestNodeBody,
@@ -212,13 +212,11 @@ class HttpExecutor:
            raise ValueError('self.authorization config is required')
        if authorization.config is None:
            raise ValueError('authorization config is required')
        if authorization.config.type != 'bearer' and authorization.config.header is None:
            raise ValueError('authorization config header is required')

        if self.authorization.config.api_key is None:
            raise ValueError('api_key is required')

        if not self.authorization.config.header:
        if not authorization.config.header:
            authorization.config.header = 'Authorization'

        if self.authorization.config.type == 'bearer':
@@ -335,16 +333,13 @@ class HttpExecutor:
        if variable_pool:
            variable_value_mapping = {}
            for variable_selector in variable_selectors:
                value = variable_pool.get_variable_value(
                    variable_selector=variable_selector.value_selector, target_value_type=ValueType.STRING
                )

                if value is None:
                variable = variable_pool.get(variable_selector.value_selector)
                if variable is None:
                    raise ValueError(f'Variable {variable_selector.variable} not found')

                if escape_quotes and isinstance(value, str):
                    value = value.replace('"', '\\"')

                if escape_quotes and isinstance(variable.value, str):
                    value = variable.value.replace('"', '\\"')
                else:
                    value = variable.value
                variable_value_mapping[variable_selector.variable] = value

            return variable_template_parser.format(variable_value_mapping), variable_selectors

@@ -3,6 +3,7 @@ from mimetypes import guess_extension
from os import path
from typing import cast

from core.app.segments import parser
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.base_node_data_entities import BaseNodeData
@@ -51,6 +52,9 @@ class HttpRequestNode(BaseNode):

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
        # TODO: Switch to use segment directly
        if node_data.authorization.config and node_data.authorization.config.api_key:
            node_data.authorization.config.api_key = parser.convert_template(template=node_data.authorization.config.api_key, variable_pool=variable_pool).text

        # init http executor
        http_executor = None

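`parser.convert_template` from the new `core.app.segments` module renders a template string against the variable pool and returns a segment group: its `.text` attribute is the fully rendered string (used for the api_key above), while `.log` is a log-oriented rendering (used by the tool node later in this diff). A hedged sketch, with an illustrative template in the `{{#node_id.variable#}}` syntax:

    from core.app.segments import parser

    segment_group = parser.convert_template(
        template='Bearer {{#start.api_key#}}',  # illustrative template
        variable_pool=variable_pool,            # an existing VariablePool
    )
    rendered = segment_group.text  # rendered string, for execution
    logged = segment_group.log     # rendering intended for logs
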
@@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import Optional, cast

from core.workflow.entities.base_node_data_entities import BaseNodeData
@@ -11,7 +12,7 @@ from models.workflow import WorkflowNodeExecutionStatus

class IfElseNode(BaseNode):
    _node_data_cls = IfElseNodeData
    node_type = NodeType.IF_ELSE
    _node_type = NodeType.IF_ELSE

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
@@ -20,7 +21,7 @@ class IfElseNode(BaseNode):
        :return:
        """
        node_data = self.node_data
        node_data = cast(self._node_data_cls, node_data)
        node_data = cast(IfElseNodeData, node_data)

        node_inputs = {
            "conditions": []
@@ -138,14 +139,12 @@ class IfElseNode(BaseNode):
        else:
            raise ValueError(f"Invalid comparison operator: {comparison_operator}")

    def process_conditions(self, variable_pool: VariablePool, conditions: list[Condition]):
    def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
        input_conditions = []
        group_result = []

        for condition in conditions:
            actual_value = variable_pool.get_variable_value(
                variable_selector=condition.variable_selector
            )
            actual_variable = variable_pool.get_any(condition.variable_selector)

            if condition.value is not None:
                variable_template_parser = VariableTemplateParser(template=condition.value)
@@ -153,9 +152,7 @@ class IfElseNode(BaseNode):
                variable_selectors = variable_template_parser.extract_variable_selectors()
                if variable_selectors:
                    for variable_selector in variable_selectors:
                        value = variable_pool.get_variable_value(
                            variable_selector=variable_selector.value_selector
                        )
                        value = variable_pool.get_any(variable_selector.value_selector)
                        expected_value = variable_template_parser.format({variable_selector.variable: value})
                else:
                    expected_value = condition.value
@@ -165,13 +162,13 @@ class IfElseNode(BaseNode):
            comparison_operator = condition.comparison_operator
            input_conditions.append(
                {
                    "actual_value": actual_value,
                    "actual_value": actual_variable,
                    "expected_value": expected_value,
                    "comparison_operator": comparison_operator
                }
            )

            result = self.evaluate_condition(actual_value, expected_value, comparison_operator)
            result = self.evaluate_condition(actual_variable, expected_value, comparison_operator)
            group_result.append(result)

        return input_conditions, group_result

@@ -20,7 +20,8 @@ class IterationNode(BaseIterationNode):
        """
        Run the node.
        """
        iterator = variable_pool.get_variable_value(cast(IterationNodeData, self.node_data).iterator_selector)
        self.node_data = cast(IterationNodeData, self.node_data)
        iterator = variable_pool.get_any(self.node_data.iterator_selector)

        if not isinstance(iterator, list):
            raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
@@ -63,15 +64,15 @@ class IterationNode(BaseIterationNode):
        """
        node_data = cast(IterationNodeData, self.node_data)

        variable_pool.append_variable(self.node_id, ['index'], state.index)
        variable_pool.add((self.node_id, 'index'), state.index)
        # get the iterator value
        iterator = variable_pool.get_variable_value(node_data.iterator_selector)
        iterator = variable_pool.get_any(node_data.iterator_selector)

        if iterator is None or not isinstance(iterator, list):
            return

        if state.index < len(iterator):
            variable_pool.append_variable(self.node_id, ['item'], iterator[state.index])
            variable_pool.add((self.node_id, 'item'), iterator[state.index])

    def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
        """
@@ -87,7 +88,7 @@ class IterationNode(BaseIterationNode):
        :return: True if iteration limit is reached, False otherwise
        """
        node_data = cast(IterationNodeData, self.node_data)
        iterator = variable_pool.get_variable_value(node_data.iterator_selector)
        iterator = variable_pool.get_any(node_data.iterator_selector)

        if iterator is None or not isinstance(iterator, list):
            return True
@@ -100,9 +101,9 @@ class IterationNode(BaseIterationNode):
        :param variable_pool: variable pool
        """
        output_selector = cast(IterationNodeData, self.node_data).output_selector
        output = variable_pool.get_variable_value(output_selector)
        output = variable_pool.get_any(output_selector)
        # clear the output for this iteration
        variable_pool.append_variable(self.node_id, output_selector[1:], None)
        variable_pool.remove([self.node_id] + output_selector[1:])
        state.current_output = output
        if output is not None:
            state.outputs.append(output)

@@ -41,7 +41,8 @@ class KnowledgeRetrievalNode(BaseNode):
        node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)

        # extract variables
        query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
        variable = variable_pool.get(node_data.query_variable_selector)
        query = variable.value if variable else None
        variables = {
            'query': query
        }

@@ -41,7 +41,7 @@ from models.workflow import WorkflowNodeExecutionStatus

class LLMNode(BaseNode):
    _node_data_cls = LLMNodeData
    node_type = NodeType.LLM
    _node_type = NodeType.LLM

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
@@ -90,7 +90,7 @@ class LLMNode(BaseNode):
        # fetch prompt messages
        prompt_messages, stop = self._fetch_prompt_messages(
            node_data=node_data,
            query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
            query=variable_pool.get_any(['sys', SystemVariable.QUERY.value])
            if node_data.memory else None,
            query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
            inputs=inputs,
@@ -238,8 +238,8 @@ class LLMNode(BaseNode):

        for variable_selector in node_data.prompt_config.jinja2_variables or []:
            variable = variable_selector.variable
            value = variable_pool.get_variable_value(
                variable_selector=variable_selector.value_selector
            value = variable_pool.get_any(
                variable_selector.value_selector
            )

        def parse_dict(d: dict) -> str:
@@ -302,7 +302,7 @@ class LLMNode(BaseNode):
        variable_selectors = variable_template_parser.extract_variable_selectors()

        for variable_selector in variable_selectors:
            variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
            variable_value = variable_pool.get_any(variable_selector.value_selector)
            if variable_value is None:
                raise ValueError(f'Variable {variable_selector.variable} not found')

@@ -313,7 +313,7 @@ class LLMNode(BaseNode):
            query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
                                        .extract_variable_selectors())
            for variable_selector in query_variable_selectors:
                variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
                variable_value = variable_pool.get_any(variable_selector.value_selector)
                if variable_value is None:
                    raise ValueError(f'Variable {variable_selector.variable} not found')

@@ -331,7 +331,7 @@ class LLMNode(BaseNode):
        if not node_data.vision.enabled:
            return []

        files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
        files = variable_pool.get_any(['sys', SystemVariable.FILES.value])
        if not files:
            return []

@@ -350,7 +350,7 @@ class LLMNode(BaseNode):
        if not node_data.context.variable_selector:
            return None

        context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
        context_value = variable_pool.get_any(node_data.context.variable_selector)
        if context_value:
            if isinstance(context_value, str):
                return context_value
@@ -496,7 +496,7 @@ class LLMNode(BaseNode):
            return None

        # get conversation id
        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION_ID.value])
        conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value])
        if conversation_id is None:
            return None

@@ -71,9 +71,10 @@ class ParameterExtractorNode(LLMNode):
        Run the node.
        """
        node_data = cast(ParameterExtractorNodeData, self.node_data)
        query = variable_pool.get_variable_value(node_data.query)
        if not query:
        variable = variable_pool.get(node_data.query)
        if not variable:
            raise ValueError("Input variable content not found or is empty")
        query = variable.value

        inputs = {
            'query': query,
@@ -564,7 +565,8 @@ class ParameterExtractorNode(LLMNode):
        variable_template_parser = VariableTemplateParser(instruction)
        inputs = {}
        for selector in variable_template_parser.extract_variable_selectors():
            inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
            variable = variable_pool.get(selector.value_selector)
            inputs[selector.variable] = variable.value if variable else None

        return variable_template_parser.format(inputs)

@@ -41,7 +41,8 @@ class QuestionClassifierNode(LLMNode):
        node_data = cast(QuestionClassifierNodeData, node_data)

        # extract variables
        query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
        variable = variable_pool.get(node_data.query_variable_selector)
        query = variable.value if variable else None
        variables = {
            'query': query
        }
@@ -294,7 +295,8 @@ class QuestionClassifierNode(LLMNode):
        variable_template_parser = VariableTemplateParser(template=instruction)
        variable_selectors.extend(variable_template_parser.extract_variable_selectors())
        for variable_selector in variable_selectors:
            variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
            variable = variable_pool.get(variable_selector.value_selector)
            variable_value = variable.value if variable else None
            if variable_value is None:
                raise ValueError(f'Variable {variable_selector.variable} not found')

@@ -9,7 +9,7 @@ from models.workflow import WorkflowNodeExecutionStatus

class StartNode(BaseNode):
    _node_data_cls = StartNodeData
    node_type = NodeType.START
    _node_type = NodeType.START

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
@@ -18,7 +18,7 @@ class StartNode(BaseNode):
        :return:
        """
        # Get cleaned inputs
        cleaned_inputs = variable_pool.user_inputs
        cleaned_inputs = dict(variable_pool.user_inputs)

        for var in variable_pool.system_variables:
            cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]

@@ -44,12 +44,9 @@ class TemplateTransformNode(BaseNode):
        # Get variables
        variables = {}
        for variable_selector in node_data.variables:
            variable = variable_selector.variable
            value = variable_pool.get_variable_value(
                variable_selector=variable_selector.value_selector
            )

            variables[variable] = value
            variable_name = variable_selector.variable
            value = variable_pool.get_any(variable_selector.value_selector)
            variables[variable_name] = value
        # Run code
        try:
            result = CodeExecutor.execute_workflow_code_template(

@@ -29,6 +29,7 @@ class ToolEntity(BaseModel):

class ToolNodeData(BaseNodeData, ToolEntity):
    class ToolInput(BaseModel):
        # TODO: check this type
        value: Union[Any, list[str]]
        type: Literal['mixed', 'variable', 'constant']

@@ -1,10 +1,11 @@
from collections.abc import Mapping, Sequence
from os import path
from typing import Optional, cast
from typing import Any, cast

from core.app.segments import parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
@@ -20,6 +21,7 @@ class ToolNode(BaseNode):
    """
    Tool Node
    """

    _node_data_cls = ToolNodeData
    _node_type = NodeType.TOOL

@@ -50,23 +52,24 @@ class ToolNode(BaseNode):
                },
                error=f'Failed to get tool runtime: {str(e)}'
            )

        # get parameters
        parameters = self._generate_parameters(variable_pool, node_data, tool_runtime)
        tool_parameters = tool_runtime.get_runtime_parameters() or []
        parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data)
        parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data, for_log=True)

        try:
            messages = ToolEngine.workflow_invoke(
                tool=tool_runtime,
                tool_parameters=parameters,
                user_id=self.user_id,
                workflow_id=self.workflow_id,
                workflow_tool_callback=DifyWorkflowCallbackHandler(),
                workflow_call_depth=self.workflow_call_depth,
            )
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters,
                inputs=parameters_for_log,
                metadata={
                    NodeRunMetadataKey.TOOL_INFO: tool_info
                },
@@ -86,21 +89,34 @@ class ToolNode(BaseNode):
            metadata={
                NodeRunMetadataKey.TOOL_INFO: tool_info
            },
            inputs=parameters
            inputs=parameters_for_log
        )

    def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData, tool_runtime: Tool) -> dict:
    def _generate_parameters(
        self,
        *,
        tool_parameters: Sequence[ToolParameter],
        variable_pool: VariablePool,
        node_data: ToolNodeData,
        for_log: bool = False,
    ) -> Mapping[str, Any]:
        """
        Generate parameters
        """
        tool_parameters = tool_runtime.get_all_runtime_parameters()
        Generate parameters based on the given tool parameters, variable pool, and node data.

        def fetch_parameter(name: str) -> Optional[ToolParameter]:
            return next((parameter for parameter in tool_parameters if parameter.name == name), None)
        Args:
            tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
            variable_pool (VariablePool): The variable pool containing the variables.
            node_data (ToolNodeData): The data associated with the tool node.

        Returns:
            Mapping[str, Any]: A dictionary containing the generated parameters.

        """
        tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}

        result = {}
        for parameter_name in node_data.tool_parameters:
            parameter = fetch_parameter(parameter_name)
            parameter = tool_parameters_dictionary.get(parameter_name)
            if not parameter:
                continue
            if parameter.type == ToolParameter.ToolParameterType.FILE:
@@ -108,35 +124,21 @@ class ToolNode(BaseNode):
                    v.to_dict() for v in self._fetch_files(variable_pool)
                ]
            else:
                input = node_data.tool_parameters[parameter_name]
                if input.type == 'mixed':
                    result[parameter_name] = self._format_variable_template(input.value, variable_pool)
                elif input.type == 'variable':
                    result[parameter_name] = variable_pool.get_variable_value(input.value)
                elif input.type == 'constant':
                    result[parameter_name] = input.value
                tool_input = node_data.tool_parameters[parameter_name]
                segment_group = parser.convert_template(
                    template=str(tool_input.value),
                    variable_pool=variable_pool,
                )
                result[parameter_name] = segment_group.log if for_log else segment_group.text

        return result

    def _format_variable_template(self, template: str, variable_pool: VariablePool) -> str:
        """
        Format variable template
        """
        inputs = {}
        template_parser = VariableTemplateParser(template)
        for selector in template_parser.extract_variable_selectors():
            inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)

        return template_parser.format(inputs)

    def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
        files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
        if not files:
            return []

        return files

    def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
    def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
        # FIXME: ensure this is a ArrayVariable contains FileVariable.
        variable = variable_pool.get(['sys', SystemVariable.FILES.value])
        return [file_var.value for file_var in variable.value] if variable else []

    def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]
        """

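With the segment rewrite, the three former tool input types ('mixed', 'variable', 'constant') collapse into a single template conversion, and `_generate_parameters` is invoked twice: once for execution and once with `for_log=True`, so that `NodeRunResult.inputs` records `segment_group.log` rather than `segment_group.text`. A sketch of the call pattern from `_run` (the assumption here is that the log rendering is what keeps sensitive resolved values out of stored inputs):

    parameters = self._generate_parameters(
        tool_parameters=tool_parameters,
        variable_pool=variable_pool,
        node_data=node_data,
    )
    parameters_for_log = self._generate_parameters(
        tool_parameters=tool_parameters,
        variable_pool=variable_pool,
        node_data=node_data,
        for_log=True,  # render with segment_group.log
    )
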
@@ -19,28 +19,27 @@ class VariableAggregatorNode(BaseNode):
        inputs = {}

        if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
            for variable in node_data.variables:
                value = variable_pool.get_variable_value(variable)

                if value is not None:
            for selector in node_data.variables:
                variable = variable_pool.get(selector)
                if variable is not None:
                    outputs = {
                        "output": value
                        "output": variable.value
                    }

                    inputs = {
                        '.'.join(variable[1:]): value
                        '.'.join(selector[1:]): variable.value
                    }
                    break
        else:
            for group in node_data.advanced_settings.groups:
                for variable in group.variables:
                    value = variable_pool.get_variable_value(variable)
                for selector in group.variables:
                    variable = variable_pool.get(selector)

                    if value is not None:
                    if variable is not None:
                        outputs[group.group_name] = {
                            'output': value
                            'output': variable.value
                        }
                        inputs['.'.join(variable[1:])] = value
                        inputs['.'.join(selector[1:])] = variable.value
                        break

        return NodeRunResult(

@@ -3,12 +3,46 @@ from collections.abc import Mapping
from typing import Any

from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool

REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
REGEX = re.compile(r'\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}')


def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str:
    """
    This is an alternative to the VariableTemplateParser class,
    offering the same functionality but with better readability and ease of use.
    """
    variable_keys = [match[0] for match in re.findall(REGEX, template)]
    variable_keys = list(set(variable_keys))

    # This key_selector is a tuple of (key, selector) where selector is a list of keys
    # e.g. ('#node_id.query.name#', ['node_id', 'query', 'name'])
    key_selectors = filter(
        lambda t: len(t[1]) >= 2,
        ((key, selector.replace('#', '').split('.')) for key, selector in zip(variable_keys, variable_keys)),
    )
    inputs = {key: variable_pool.get_any(selector) for key, selector in key_selectors}

    def replacer(match):
        key = match.group(1)
        # return original matched string if key not found
        value = inputs.get(key, match.group(0))
        if value is None:
            value = ''
        value = str(value)
        # remove template variables if required
        return re.sub(REGEX, r'{\1}', value)

    result = re.sub(REGEX, replacer, template)
    result = re.sub(r'<\|.*?\|>', '', result)
    return result


class VariableTemplateParser:
    """
    !NOTE: Consider to use the new `segments` module instead of this class.

    A class for parsing and manipulating template variables in a string.

    Rules:
@@ -72,14 +106,11 @@ class VariableTemplateParser:
            if len(split_result) < 2:
                continue

            variable_selectors.append(VariableSelector(
                variable=variable_key,
                value_selector=split_result
            ))
            variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result))

        return variable_selectors

    def format(self, inputs: Mapping[str, Any], remove_template_variables: bool = True) -> str:
    def format(self, inputs: Mapping[str, Any]) -> str:
        """
        Formats the template string by replacing the template variables with their corresponding values.

@@ -90,6 +121,7 @@ class VariableTemplateParser:
        Returns:
            The formatted string with template variables replaced by their values.
        """

        def replacer(match):
            key = match.group(1)
            value = inputs.get(key, match.group(0))  # return original matched string if key not found
@@ -99,11 +131,9 @@ class VariableTemplateParser:
            # convert the value to string
            if isinstance(value, list | dict | bool | int | float):
                value = str(value)

            # remove template variables if required
            if remove_template_variables:
                return VariableTemplateParser.remove_template_variables(value)
            return value
            return VariableTemplateParser.remove_template_variables(value)

        prompt = re.sub(REGEX, replacer, self.template)
        return re.sub(r'<\|.*?\|>', '', prompt)

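`parse_mixed_template` resolves `{{#node_id.variable#}}` references straight from the pool: selectors with fewer than two parts are skipped, unresolved keys are left in place, and `<|...|>` runs are stripped from the result. A minimal, self-contained sketch (the import path for the function is assumed from the file being patched here):

    from core.workflow.entities.variable_pool import VariablePool
    from core.workflow.utils.variable_template_parser import parse_mixed_template

    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(['node_1', 'name'], 'Dify')

    text = parse_mixed_template(template='Hello, {{#node_1.name#}}!', variable_pool=pool)
    # 'Hello, Dify!'
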
@@ -1,14 +1,15 @@
import logging
import time
from typing import Optional, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast

from configs import dify_config
from core.app.app_config.entities import FileExtraConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.errors import WorkflowNodeRunFailedError
@@ -35,7 +36,7 @@ from models.workflow import (
    WorkflowNodeExecutionStatus,
)

node_classes = {
node_classes: Mapping[NodeType, type[BaseNode]] = {
    NodeType.START: StartNode,
    NodeType.END: EndNode,
    NodeType.ANSWER: AnswerNode,
@@ -86,14 +87,14 @@ class WorkflowEngineManager:

        return default_config

    def run_workflow(self, workflow: Workflow,
    def run_workflow(self, *, workflow: Workflow,
                     user_id: str,
                     user_from: UserFrom,
                     invoke_from: InvokeFrom,
                     user_inputs: dict,
                     system_inputs: Optional[dict] = None,
                     callbacks: list[BaseWorkflowCallback] = None,
                     call_depth: Optional[int] = 0,
                     user_inputs: Mapping[str, Any],
                     system_inputs: Mapping[SystemVariable, Any],
                     callbacks: Sequence[WorkflowCallback],
                     call_depth: int = 0,
                     variable_pool: Optional[VariablePool] = None) -> None:
        """
        :param workflow: Workflow instance
@@ -122,7 +123,8 @@ class WorkflowEngineManager:
        if not variable_pool:
            variable_pool = VariablePool(
                system_variables=system_inputs,
                user_inputs=user_inputs
                user_inputs=user_inputs,
                environment_variables=workflow.environment_variables,
            )

        workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
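`run_workflow` becomes keyword-only: `system_inputs` and `callbacks` are now required, `call_depth` defaults to 0, and the workflow's environment variables are injected into the pool at construction. A sketch of a caller after this change (enum members and inputs are illustrative):

    WorkflowEngineManager().run_workflow(
        workflow=workflow,  # a Workflow whose environment_variables are set
        user_id=user_id,
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        user_inputs={'query': 'hello'},
        system_inputs={SystemVariable.QUERY: 'hello'},
        callbacks=[],
    )
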
@@ -154,7 +156,7 @@ class WorkflowEngineManager:

    def _run_workflow(self, workflow: Workflow,
                      workflow_run_state: WorkflowRunState,
                      callbacks: list[BaseWorkflowCallback] = None,
                      callbacks: Sequence[WorkflowCallback],
                      start_at: Optional[str] = None,
                      end_at: Optional[str] = None) -> None:
        """
@@ -173,8 +175,8 @@ class WorkflowEngineManager:
        graph = workflow.graph_dict

        try:
            predecessor_node: BaseNode = None
            current_iteration_node: BaseIterationNode = None
            predecessor_node: BaseNode | None = None
            current_iteration_node: BaseIterationNode | None = None
            has_entry_node = False
            max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS
            max_execution_time = dify_config.WORKFLOW_MAX_EXECUTION_TIME
@@ -235,7 +237,7 @@ class WorkflowEngineManager:
                        # move to next iteration
                        next_node_id = next_iteration
                        # get next id
                        next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
                        next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks)

                if not next_node:
                    break
@@ -295,7 +297,7 @@ class WorkflowEngineManager:
                            workflow_run_state.current_iteration_state = None
                            continue
                    else:
                        next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
                        next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks)

                # run workflow, run multiple target nodes in the future
                self._run_workflow_node(
@@ -381,7 +383,8 @@ class WorkflowEngineManager:
        # init variable pool
        variable_pool = VariablePool(
            system_variables={},
            user_inputs={}
            user_inputs={},
            environment_variables=workflow.environment_variables,
        )

        # variable selector to variable mapping
@@ -419,7 +422,7 @@ class WorkflowEngineManager:
                                 node_id: str,
                                 user_id: str,
                                 user_inputs: dict,
                                 callbacks: list[BaseWorkflowCallback] = None,
                                 callbacks: Sequence[WorkflowCallback],
                                 ) -> None:
        """
        Single iteration run workflow node
@@ -446,7 +449,8 @@ class WorkflowEngineManager:
        # init variable pool
        variable_pool = VariablePool(
            system_variables={},
            user_inputs={}
            user_inputs={},
            environment_variables=workflow.environment_variables,
        )

        # variable selector to variable mapping
@@ -535,7 +539,7 @@ class WorkflowEngineManager:
            end_at=end_node_id
        )

    def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None:
    def _workflow_run_success(self, callbacks: Sequence[WorkflowCallback]) -> None:
        """
        Workflow run success
        :param callbacks: workflow callbacks
@@ -547,7 +551,7 @@ class WorkflowEngineManager:
            callback.on_workflow_run_succeeded()

    def _workflow_run_failed(self, error: str,
                             callbacks: list[BaseWorkflowCallback] = None) -> None:
                             callbacks: Sequence[WorkflowCallback]) -> None:
        """
        Workflow run failed
        :param error: error message
@@ -560,11 +564,11 @@ class WorkflowEngineManager:
            error=error
        )

    def _workflow_iteration_started(self, graph: dict,
    def _workflow_iteration_started(self, *, graph: Mapping[str, Any],
                                    current_iteration_node: BaseIterationNode,
                                    workflow_run_state: WorkflowRunState,
                                    predecessor_node_id: Optional[str] = None,
                                    callbacks: list[BaseWorkflowCallback] = None) -> None:
                                    callbacks: Sequence[WorkflowCallback]) -> None:
        """
        Workflow iteration started
        :param current_iteration_node: current iteration node
@@ -597,10 +601,10 @@ class WorkflowEngineManager:
        # add steps
        workflow_run_state.workflow_node_steps += 1

    def _workflow_iteration_next(self, graph: dict,
    def _workflow_iteration_next(self, *, graph: Mapping[str, Any],
                                 current_iteration_node: BaseIterationNode,
                                 workflow_run_state: WorkflowRunState,
                                 callbacks: list[BaseWorkflowCallback] = None) -> None:
                                 callbacks: Sequence[WorkflowCallback]) -> None:
        """
        Workflow iteration next
        :param workflow_run_state: workflow run state
@@ -627,11 +631,11 @@ class WorkflowEngineManager:
            nodes = [node for node in nodes if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id]

            for node in nodes:
                workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id'))
                workflow_run_state.variable_pool.remove((node.get('id'),))

    def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode,
    def _workflow_iteration_completed(self, *, current_iteration_node: BaseIterationNode,
                                      workflow_run_state: WorkflowRunState,
                                      callbacks: list[BaseWorkflowCallback] = None) -> None:
                                      callbacks: Sequence[WorkflowCallback]) -> None:
        if callbacks:
            if isinstance(workflow_run_state.current_iteration_state, IterationState):
                for callback in callbacks:
@@ -644,10 +648,10 @@ class WorkflowEngineManager:
                        }
                    )

    def _get_next_overall_node(self, workflow_run_state: WorkflowRunState,
                               graph: dict,
    def _get_next_overall_node(self, *, workflow_run_state: WorkflowRunState,
                               graph: Mapping[str, Any],
                               predecessor_node: Optional[BaseNode] = None,
                               callbacks: list[BaseWorkflowCallback] = None,
                               callbacks: Sequence[WorkflowCallback],
                               start_at: Optional[str] = None,
                               end_at: Optional[str] = None) -> Optional[BaseNode]:
        """
@@ -739,9 +743,9 @@ class WorkflowEngineManager:
        )

    def _get_node(self, workflow_run_state: WorkflowRunState,
                  graph: dict,
                  graph: Mapping[str, Any],
                  node_id: str,
                  callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]:
                  callbacks: Sequence[WorkflowCallback]):
        """
        Get node from graph by node id
        """
@@ -752,7 +756,7 @@ class WorkflowEngineManager:
        for node_config in nodes:
            if node_config.get('id') == node_id:
                node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
                node_cls = node_classes.get(node_type)
                node_cls = node_classes[node_type]
                return node_cls(
                    tenant_id=workflow_run_state.tenant_id,
                    app_id=workflow_run_state.app_id,
@@ -765,8 +769,6 @@ class WorkflowEngineManager:
                    workflow_call_depth=workflow_run_state.workflow_call_depth
                )

        return None

    def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
        """
        Check timeout
@@ -785,10 +787,10 @@ class WorkflowEngineManager:
            if node_and_result.node_id == node_id
        ])

    def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
    def _run_workflow_node(self, *, workflow_run_state: WorkflowRunState,
                           node: BaseNode,
                           predecessor_node: Optional[BaseNode] = None,
                           callbacks: list[BaseWorkflowCallback] = None) -> None:
                           callbacks: Sequence[WorkflowCallback]) -> None:
        if callbacks:
            for callback in callbacks:
                callback.on_workflow_node_execute_started(
@@ -894,10 +896,8 @@ class WorkflowEngineManager:
        :param variable_value: variable value
        :return:
        """
        variable_pool.append_variable(
            node_id=node_id,
            variable_key_list=variable_key_list,
            value=variable_value
        variable_pool.add(
            [node_id] + variable_key_list, variable_value
        )

        # if variable_value is a dict, then recursively append variables
@@ -946,7 +946,7 @@ class WorkflowEngineManager:
                                              tenant_id: str,
                                              node_instance: BaseNode):
        for variable_key, variable_selector in variable_mapping.items():
            if variable_key not in user_inputs:
            if variable_key not in user_inputs and not variable_pool.get(variable_selector):
                raise ValueError(f'Variable key {variable_key} not found in user inputs.')

            # fetch variable node id from variable selector
@@ -956,7 +956,7 @@ class WorkflowEngineManager:
            # get value
            value = user_inputs.get(variable_key)

            # temp fix for image type
            # FIXME: temp fix for image type
            if node_instance.node_type == NodeType.LLM:
                new_value = []
                if isinstance(value, list):
@@ -983,8 +983,4 @@ class WorkflowEngineManager:
                    value = new_value

            # append variable and value to variable pool
            variable_pool.append_variable(
                node_id=variable_node_id,
                variable_key_list=variable_key_list,
                value=value
            )
            variable_pool.add([variable_node_id]+variable_key_list, value)

@@ -1,8 +1,38 @@
from flask_restful import fields

from core.app.segments import SecretVariable, Variable
from core.helper import encrypter
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField


class EnvironmentVariableField(fields.Raw):
    def format(self, value):
        # Mask secret variables values in environment_variables
        if isinstance(value, SecretVariable):
            return {
                'id': value.id,
                'name': value.name,
                'value': encrypter.obfuscated_token(value.value),
                'value_type': value.value_type.value,
            }
        elif isinstance(value, Variable):
            return {
                'id': value.id,
                'name': value.name,
                'value': value.value,
                'value_type': value.value_type.value,
            }
        return value


environment_variable_fields = {
    'id': fields.String,
    'name': fields.String,
    'value': fields.Raw,
    'value_type': fields.String(attribute='value_type.value'),
}

workflow_fields = {
    'id': fields.String,
    'graph': fields.Raw(attribute='graph_dict'),
@@ -13,4 +43,5 @@ workflow_fields = {
    'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True),
    'updated_at': TimestampField,
    'tool_published': fields.Boolean,
    'environment_variables': fields.List(EnvironmentVariableField()),
}

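`EnvironmentVariableField.format` is what keeps plaintext secrets out of API responses: a SecretVariable is serialized with an obfuscated token in place of its value, while an ordinary Variable passes through with its value unchanged. In isolation (the variable instances are illustrative):

    field = EnvironmentVariableField()

    masked = field.format(secret_variable)   # assuming a SecretVariable instance
    # masked['value'] holds encrypter.obfuscated_token(...), never the plaintext

    plain = field.format(string_variable)    # assuming a non-secret Variable instance
    # plain['value'] is the stored value, unchanged
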
@@ -0,0 +1,33 @@
"""add environment variable to workflow model

Revision ID: 8e5588e6412e
Revises: 6e957a32015b
Create Date: 2024-07-22 03:27:16.042533

"""
import sqlalchemy as sa
from alembic import op

import models as models

# revision identifiers, used by Alembic.
revision = '8e5588e6412e'
down_revision = '6e957a32015b'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('workflows', schema=None) as batch_op:
        batch_op.add_column(sa.Column('environment_variables', sa.Text(), server_default='{}', nullable=False))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('workflows', schema=None) as batch_op:
        batch_op.drop_column('environment_variables')

    # ### end Alembic commands ###
@@ -48,7 +48,7 @@ class Account(UserMixin, db.Model):
        return self._current_tenant

    @current_tenant.setter
    def current_tenant(self, value):
    def current_tenant(self, value: "Tenant"):
        tenant = value
        ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=self.id).first()
        if ta:
@@ -62,7 +62,7 @@ class Account(UserMixin, db.Model):
        return self._current_tenant.id

    @current_tenant_id.setter
    def current_tenant_id(self, value):
    def current_tenant_id(self, value: str):
        try:
            tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
                .filter(Tenant.id == value) \

@@ -1,7 +1,16 @@
import json
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Optional, Union
from typing import Any, Optional, Union

import contexts
from constants import HIDDEN_VALUE
from core.app.segments import (
    SecretVariable,
    Variable,
    factory,
)
from core.helper import encrypter
from extensions.ext_database import db
from libs import helper
from models import StringUUID
@@ -112,6 +121,7 @@ class Workflow(db.Model):
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_by = db.Column(StringUUID)
    updated_at = db.Column(db.DateTime)
    _environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}')

    @property
    def created_by_account(self):
@@ -122,11 +132,11 @@ class Workflow(db.Model):
        return db.session.get(Account, self.updated_by) if self.updated_by else None

    @property
    def graph_dict(self):
        return json.loads(self.graph) if self.graph else None
    def graph_dict(self) -> Mapping[str, Any]:
        return json.loads(self.graph) if self.graph else {}

    @property
    def features_dict(self):
    def features_dict(self) -> Mapping[str, Any]:
        return json.loads(self.features) if self.features else {}

    def user_input_form(self, to_old_structure: bool = False) -> list:
@@ -177,6 +187,72 @@ class Workflow(db.Model):
            WorkflowToolProvider.app_id == self.app_id
        ).first() is not None

    @property
    def environment_variables(self) -> Sequence[Variable]:
        # TODO: find some way to init `self._environment_variables` when instance created.
        if self._environment_variables is None:
            self._environment_variables = '{}'

        tenant_id = contexts.tenant_id.get()

        environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables)
        results = [factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()]

        # decrypt secret variables value
        decrypt_func = (
            lambda var: var.model_copy(
                update={'value': encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}
            )
            if isinstance(var, SecretVariable)
            else var
        )
        results = list(map(decrypt_func, results))
        return results

    @environment_variables.setter
    def environment_variables(self, value: Sequence[Variable]):
        tenant_id = contexts.tenant_id.get()

        value = list(value)
        if any(var for var in value if not var.id):
            raise ValueError('environment variable require a unique id')

        # Compare inputs and origin variables, if the value is HIDDEN_VALUE, use the origin variable value (only update `name`).
        origin_variables_dictionary = {var.id: var for var in self.environment_variables}
        for i, variable in enumerate(value):
            if variable.id in origin_variables_dictionary and variable.value == HIDDEN_VALUE:
                value[i] = origin_variables_dictionary[variable.id].model_copy(update={'name': variable.name})

        # encrypt secret variables value
        encrypt_func = (
            lambda var: var.model_copy(
                update={'value': encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}
            )
            if isinstance(var, SecretVariable)
            else var
        )
        encrypted_vars = list(map(encrypt_func, value))
        environment_variables_json = json.dumps(
            {var.name: var.model_dump() for var in encrypted_vars},
            ensure_ascii=False,
        )
        self._environment_variables = environment_variables_json

    def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]:
        environment_variables = list(self.environment_variables)
        environment_variables = [
            v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={'value': ''})
            for v in environment_variables
        ]

        result = {
            'graph': self.graph_dict,
            'features': self.features_dict,
            'environment_variables': [var.model_dump(mode='json') for var in environment_variables],
        }
        return result


class WorkflowRunTriggeredFrom(Enum):
    """
    Workflow Run Triggered From Enum

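Round trip of the property pair above: the setter encrypts SecretVariable values per tenant and persists the collection as JSON keyed by variable name, and the getter rebuilds Variables via the factory and decrypts on read. Because values equal to HIDDEN_VALUE are swapped for the stored originals, a client that echoes the placeholder back can rename a secret without ever resending it. A hedged sketch (assumes contexts.tenant_id was set for the request, e.g. by the login loader):

    env_vars = workflow.environment_variables  # decrypted Variable/SecretVariable instances

    renamed = [
        var.model_copy(update={'name': 'API_KEY', 'value': HIDDEN_VALUE})
        if isinstance(var, SecretVariable)
        else var
        for var in env_vars
    ]
    workflow.environment_variables = renamed  # original secret value kept, name updated
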
@@ -15,7 +15,7 @@ from models.dataset import Dataset, DatasetQuery, Document
@app.celery.task(queue='dataset')
def clean_unused_datasets_task():
    click.echo(click.style('Start clean unused datasets indexes.', fg='green'))
    clean_days = int(dify_config.CLEAN_DAY_SETTING)
    clean_days = dify_config.CLEAN_DAY_SETTING
    start_at = time.perf_counter()
    thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
    page = 1

@@ -47,7 +47,7 @@ class AccountService:
        )

    @staticmethod
    def load_user(user_id: str) -> Account:
    def load_user(user_id: str) -> None | Account:
        account = Account.query.filter_by(id=user_id).first()
        if not account:
            return None
@@ -55,7 +55,7 @@ class AccountService:
        if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]:
            raise Unauthorized("Account is banned or closed.")

        current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
        current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
        if current_tenant:
            account.current_tenant_id = current_tenant.tenant_id
        else:

@@ -3,6 +3,7 @@ import logging
import httpx
import yaml  # type: ignore

from core.app.segments import factory
from events.app_event import app_model_config_was_updated, app_was_created
from extensions.ext_database import db
from models.account import Account
@@ -150,7 +151,7 @@ class AppDslService:
        )

    @classmethod
    def export_dsl(cls, app_model: App) -> str:
    def export_dsl(cls, app_model: App, include_secret:bool = False) -> str:
        """
        Export app
        :param app_model: App instance
@@ -171,7 +172,7 @@ class AppDslService:
        }

        if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
            cls._append_workflow_export_data(export_data, app_model)
            cls._append_workflow_export_data(export_data=export_data, app_model=app_model, include_secret=include_secret)
        else:
            cls._append_model_config_export_data(export_data, app_model)

@@ -235,13 +236,16 @@ class AppDslService:
        )

        # init draft workflow
        environment_variables_list = workflow_data.get('environment_variables') or []
        environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
        workflow_service = WorkflowService()
        draft_workflow = workflow_service.sync_draft_workflow(
            app_model=app,
            graph=workflow_data.get('graph', {}),
            features=workflow_data.get('features', {}),
            unique_hash=None,
            account=account
            account=account,
            environment_variables=environment_variables,
        )
        workflow_service.publish_workflow(
            app_model=app,
@@ -276,12 +280,15 @@ class AppDslService:
        unique_hash = None

        # sync draft workflow
        environment_variables_list = workflow_data.get('environment_variables') or []
        environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
        draft_workflow = workflow_service.sync_draft_workflow(
            app_model=app_model,
            graph=workflow_data.get('graph', {}),
            features=workflow_data.get('features', {}),
            unique_hash=unique_hash,
            account=account
            account=account,
            environment_variables=environment_variables,
        )

        return draft_workflow
@@ -377,7 +384,7 @@ class AppDslService:
        return app

    @classmethod
    def _append_workflow_export_data(cls, export_data: dict, app_model: App) -> None:
    def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None:
        """
        Append workflow export data
        :param export_data: export data
@@ -388,10 +395,7 @@ class AppDslService:
        if not workflow:
            raise ValueError("Missing draft workflow configuration, please check.")

        export_data['workflow'] = {
            "graph": workflow.graph_dict,
            "features": workflow.features_dict
        }
        export_data['workflow'] = workflow.to_dict(include_secret=include_secret)

    @classmethod
    def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:

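Export now takes an opt-in flag: with the default include_secret=False, Workflow.to_dict blanks SecretVariable values in the exported DSL, and passing True keeps them. Sketch:

    dsl_without_secrets = AppDslService.export_dsl(app_model=app_model)
    dsl_with_secrets = AppDslService.export_dsl(app_model=app_model, include_secret=True)
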
@@ -133,6 +133,7 @@ class ModelLoadBalancingService:
        # move the inherit configuration to the first
        for i, load_balancing_config in enumerate(load_balancing_configs):
            if load_balancing_config.name == '__inherit__':
                # FIXME: Mutation to loop iterable `load_balancing_configs` during iteration
                inherit_config = load_balancing_configs.pop(i)
                load_balancing_configs.insert(0, inherit_config)

@@ -4,7 +4,6 @@ from os import path
from typing import Optional

import requests
from flask import current_app

from configs import dify_config
from constants.languages import languages

@@ -199,7 +199,8 @@ class WorkflowConverter:
            version='draft',
            graph=json.dumps(graph),
            features=json.dumps(features),
            created_by=account_id
            created_by=account_id,
            environment_variables=[],
        )

        db.session.add(workflow)

@@ -1,10 +1,12 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.segments import Variable
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
@@ -61,11 +63,16 @@ class WorkflowService:
|
||||
|
||||
return workflow
|
||||
|
||||
def sync_draft_workflow(self, app_model: App,
|
||||
graph: dict,
|
||||
features: dict,
|
||||
unique_hash: Optional[str],
|
||||
account: Account) -> Workflow:
|
||||
def sync_draft_workflow(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
graph: dict,
|
||||
features: dict,
|
||||
unique_hash: Optional[str],
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
:raises WorkflowHashNotEqualError
|
||||
@@ -73,10 +80,8 @@ class WorkflowService:
|
||||
# fetch draft workflow by app_model
|
||||
workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if workflow:
|
||||
# validate unique hash
|
||||
if workflow.unique_hash != unique_hash:
|
||||
raise WorkflowHashNotEqualError()
|
||||
if workflow and workflow.unique_hash != unique_hash:
|
||||
raise WorkflowHashNotEqualError()
|
||||
|
||||
# validate features structure
|
||||
self.validate_features_structure(
|
||||
@@ -93,7 +98,8 @@ class WorkflowService:
|
||||
version='draft',
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account.id
|
||||
created_by=account.id,
|
||||
environment_variables=environment_variables
|
||||
)
|
||||
db.session.add(workflow)
|
||||
# update draft workflow if found
|
||||
@@ -102,6 +108,7 @@ class WorkflowService:
            workflow.features = json.dumps(features)
            workflow.updated_by = account.id
            workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
            workflow.environment_variables = environment_variables

        # commit db session changes
        db.session.commit()
@@ -137,7 +144,8 @@ class WorkflowService:
            version=str(datetime.now(timezone.utc).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables
        )

        # commit db session changes
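Publishing snapshots the draft, so its environment variables travel with the published version. A compressed sketch of that copy (attribute names follow the hunk; the function itself is illustrative, not the service method):

    from datetime import datetime, timezone
    from models.workflow import Workflow

    # Illustrative: publishing copies the draft's environment variables verbatim.
    def publish(draft_workflow, account):
        return Workflow(
            version=str(datetime.now(timezone.utc).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
        )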
@@ -55,9 +55,9 @@ def test_execute_code(setup_code_executor_mock):
    )

    # construct variable pool
    pool = VariablePool(system_variables={}, user_inputs={})
    pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1)
    pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2)
    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(['1', '123', 'args1'], 1)
    pool.add(['1', '123', 'args2'], 2)

    # execute node
    result = node.run(pool)
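This and the following test migrations document the VariablePool API change: `append_variable(node_id=..., variable_key_list=..., value=...)` becomes `add(selector, value)`, where the selector is a single sequence whose head is the node id, and `environment_variables` is now a required constructor argument. A minimal before/after, mirroring the lines above:

    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])

    # old: pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1)
    # new: the node id folds into the selector sequence
    pool.add(['1', '123', 'args1'], 1)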
@@ -109,9 +109,9 @@ def test_execute_code_output_validator(setup_code_executor_mock):
    )

    # construct variable pool
    pool = VariablePool(system_variables={}, user_inputs={})
    pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1)
    pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2)
    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(['1', '123', 'args1'], 1)
    pool.add(['1', '123', 'args2'], 2)

    # execute node
    result = node.run(pool)
@@ -18,9 +18,9 @@ BASIC_NODE_DATA = {
}

# construct variable pool
pool = VariablePool(system_variables={}, user_inputs={})
pool.append_variable(node_id='a', variable_key_list=['b123', 'args1'], value=1)
pool.append_variable(node_id='a', variable_key_list=['b123', 'args2'], value=2)
pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
pool.add(['a', 'b123', 'args1'], 1)
pool.add(['a', 'b123', 'args2'], 2)


@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True)
@@ -70,8 +70,8 @@ def test_execute_llm(setup_openai_mock):
        SystemVariable.FILES: [],
        SystemVariable.CONVERSATION_ID: 'abababa',
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})
    pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
    }, user_inputs={}, environment_variables=[])
    pool.add(['abc', 'output'], 'sunny')

    credentials = {
        'openai_api_key': os.environ.get('OPENAI_API_KEY')
@@ -185,8 +185,8 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
        SystemVariable.FILES: [],
        SystemVariable.CONVERSATION_ID: 'abababa',
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})
    pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
    }, user_inputs={}, environment_variables=[])
    pool.add(['abc', 'output'], 'sunny')

    credentials = {
        'openai_api_key': os.environ.get('OPENAI_API_KEY')
@@ -123,7 +123,7 @@ def test_function_calling_parameter_extractor(setup_openai_mock):
        SystemVariable.FILES: [],
        SystemVariable.CONVERSATION_ID: 'abababa',
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})
    }, user_inputs={}, environment_variables=[])

    result = node.run(pool)
@@ -181,7 +181,7 @@ def test_instructions(setup_openai_mock):
        SystemVariable.FILES: [],
        SystemVariable.CONVERSATION_ID: 'abababa',
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})
    }, user_inputs={}, environment_variables=[])

    result = node.run(pool)
@@ -247,7 +247,7 @@ def test_chat_parameter_extractor(setup_anthropic_mock):
        SystemVariable.FILES: [],
        SystemVariable.CONVERSATION_ID: 'abababa',
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})
    }, user_inputs={}, environment_variables=[])

    result = node.run(pool)
@@ -311,7 +311,7 @@ def test_completion_parameter_extractor(setup_openai_mock):
        SystemVariable.FILES: [],
        SystemVariable.CONVERSATION_ID: 'abababa',
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})
    }, user_inputs={}, environment_variables=[])

    result = node.run(pool)
@@ -424,7 +424,7 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
        SystemVariable.FILES: [],
        SystemVariable.CONVERSATION_ID: 'abababa',
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})
    }, user_inputs={}, environment_variables=[])

    result = node.run(pool)
@@ -38,9 +38,9 @@ def test_execute_code(setup_code_executor_mock):
    )

    # construct variable pool
    pool = VariablePool(system_variables={}, user_inputs={})
    pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1)
    pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=3)
    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(['1', '123', 'args1'], 1)
    pool.add(['1', '123', 'args2'], 3)

    # execute node
    result = node.run(pool)
@@ -6,8 +6,8 @@ from models.workflow import WorkflowNodeExecutionStatus


def test_tool_variable_invoke():
    pool = VariablePool(system_variables={}, user_inputs={})
    pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value='1+1')
    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(['1', '123', 'args1'], '1+1')

    node = ToolNode(
        tenant_id='1',

@@ -45,8 +45,8 @@ def test_tool_variable_invoke():
    assert result.outputs['files'] == []

def test_tool_mixed_invoke():
    pool = VariablePool(system_variables={}, user_inputs={})
    pool.append_variable(node_id='1', variable_key_list=['args1'], value='1+1')
    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(['1', 'args1'], '1+1')

    node = ToolNode(
        tenant_id='1',
53
api/tests/unit_tests/app/test_segment.py
Normal file
@@ -0,0 +1,53 @@
from core.app.segments import SecretVariable, parser
from core.helper import encrypter
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool


def test_segment_group_to_text():
    variable_pool = VariablePool(
        system_variables={
            SystemVariable('user_id'): 'fake-user-id',
        },
        user_inputs={},
        environment_variables=[
            SecretVariable(name='secret_key', value='fake-secret-key'),
        ],
    )
    variable_pool.add(('node_id', 'custom_query'), 'fake-user-query')
    template = (
        'Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}.'
    )
    segments_group = parser.convert_template(template=template, variable_pool=variable_pool)

    assert segments_group.text == 'Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key.'
    assert (
        segments_group.log
        == f"Hello, fake-user-id! Your query is fake-user-query. And your key is {encrypter.obfuscated_token('fake-secret-key')}."
    )


def test_convert_constant_to_segment_group():
    variable_pool = VariablePool(
        system_variables={},
        user_inputs={},
        environment_variables=[],
    )
    template = 'Hello, world!'
    segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
    assert segments_group.text == 'Hello, world!'
    assert segments_group.log == 'Hello, world!'


def test_convert_variable_to_segment_group():
    variable_pool = VariablePool(
        system_variables={
            SystemVariable('user_id'): 'fake-user-id',
        },
        user_inputs={},
        environment_variables=[],
    )
    template = '{{#sys.user_id#}}'
    segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
    assert segments_group.text == 'fake-user-id'
    assert segments_group.log == 'fake-user-id'
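The test above also fixes the template conventions: `{{#sys.*#}}` resolves system variables, `{{#node_id.*#}}` node outputs, and `{{#env.*#}}` environment variables, with secrets obfuscated in the `.log` rendering but plain in `.text`. A sketch of the log-side masking, under the assumption that the mask only reveals the token's edges; the real helper is core.helper.encrypter.obfuscated_token and its exact format may differ:

    # Hypothetical helper body, not the real obfuscated_token implementation.
    def obfuscate(token: str) -> str:
        if len(token) <= 8:
            return '*' * len(token)
        return token[:2] + '*' * (len(token) - 6) + token[-4:]

    print(obfuscate('fake-secret-key'))  # e.g. 'fa*********-key'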
91
api/tests/unit_tests/app/test_variables.py
Normal file
@@ -0,0 +1,91 @@
import pytest
from pydantic import ValidationError

from core.app.segments import (
    FloatVariable,
    IntegerVariable,
    SecretVariable,
    SegmentType,
    StringVariable,
    factory,
)


def test_string_variable():
    test_data = {'value_type': 'string', 'name': 'test_text', 'value': 'Hello, World!'}
    result = factory.build_variable_from_mapping(test_data)
    assert isinstance(result, StringVariable)


def test_integer_variable():
    test_data = {'value_type': 'number', 'name': 'test_int', 'value': 42}
    result = factory.build_variable_from_mapping(test_data)
    assert isinstance(result, IntegerVariable)


def test_float_variable():
    test_data = {'value_type': 'number', 'name': 'test_float', 'value': 3.14}
    result = factory.build_variable_from_mapping(test_data)
    assert isinstance(result, FloatVariable)


def test_secret_variable():
    test_data = {'value_type': 'secret', 'name': 'test_secret', 'value': 'secret_value'}
    result = factory.build_variable_from_mapping(test_data)
    assert isinstance(result, SecretVariable)


def test_invalid_value_type():
    test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'}
    with pytest.raises(ValueError):
        factory.build_variable_from_mapping(test_data)


def test_frozen_variables():
    var = StringVariable(name='text', value='text')
    with pytest.raises(ValidationError):
        var.value = 'new value'

    int_var = IntegerVariable(name='integer', value=42)
    with pytest.raises(ValidationError):
        int_var.value = 100

    float_var = FloatVariable(name='float', value=3.14)
    with pytest.raises(ValidationError):
        float_var.value = 2.718

    secret_var = SecretVariable(name='secret', value='secret_value')
    with pytest.raises(ValidationError):
        secret_var.value = 'new_secret_value'


def test_variable_value_type_immutable():
    with pytest.raises(ValidationError):
        StringVariable(value_type=SegmentType.ARRAY, name='text', value='text')

    with pytest.raises(ValidationError):
        StringVariable.model_validate({'value_type': 'not text', 'name': 'text', 'value': 'text'})

    var = IntegerVariable(name='integer', value=42)
    with pytest.raises(ValidationError):
        IntegerVariable(value_type=SegmentType.ARRAY, name=var.name, value=var.value)

    var = FloatVariable(name='float', value=3.14)
    with pytest.raises(ValidationError):
        FloatVariable(value_type=SegmentType.ARRAY, name=var.name, value=var.value)

    var = SecretVariable(name='secret', value='secret_value')
    with pytest.raises(ValidationError):
        SecretVariable(value_type=SegmentType.ARRAY, name=var.name, value=var.value)


def test_build_a_blank_string():
    result = factory.build_variable_from_mapping(
        {
            'value_type': 'string',
            'name': 'blank',
            'value': '',
        }
    )
    assert isinstance(result, StringVariable)
    assert result.value == ''
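These cases pin down the factory's dispatch rules: `'number'` maps to `IntegerVariable` or `FloatVariable` depending on the Python type of `value`, and unknown `value_type`s raise. A sketch of that dispatch, assuming the variants share the `name`/`value` constructor shape shown in the tests; the shipped logic lives in core.app.segments.factory:

    from core.app.segments import FloatVariable, IntegerVariable, SecretVariable, StringVariable

    # Illustrative dispatch only, not the real build_variable_from_mapping.
    def build_variable(mapping: dict):
        value_type, value = mapping['value_type'], mapping['value']
        if value_type == 'string':
            return StringVariable(name=mapping['name'], value=value)
        if value_type == 'secret':
            return SecretVariable(name=mapping['name'], value=value)
        if value_type == 'number':
            # integers and floats share the 'number' value_type on the wire
            cls = IntegerVariable if isinstance(value, int) else FloatVariable
            return cls(name=mapping['name'], value=value)
        raise ValueError(f'unknown value_type: {value_type}')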
@@ -1,3 +1,4 @@
import os
from textwrap import dedent

import pytest

@@ -48,7 +49,9 @@ def test_dify_config(example_env_file):
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
def test_flask_configs(example_env_file):
    flask_app = Flask('app')
    flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump())
    # clear system environment variables
    os.environ.clear()
    flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump())  # pyright: ignore
    config = flask_app.config

    # configs read from pydantic-settings
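The added `os.environ.clear()` makes the test prove that `DifyConfig` reads from the `.env` file itself rather than from the process variables `pymilvus` leaked into `os.environ`. A condensed sketch of that shape; the `from configs import DifyConfig` path is an assumption, as the test's own imports are not shown in this hunk:

    import os
    from flask import Flask
    from configs import DifyConfig  # import path assumed

    def make_app(env_file: str) -> Flask:
        app = Flask('app')
        os.environ.clear()  # prove values come from the env file, not the process
        app.config.from_mapping(DifyConfig(_env_file=env_file).model_dump())
        return app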
@@ -31,9 +31,9 @@ def test_execute_answer():
    pool = VariablePool(system_variables={
        SystemVariable.FILES: [],
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})
    pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny')
    pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.')
    }, user_inputs={}, environment_variables=[])
    pool.add(['start', 'weather'], 'sunny')
    pool.add(['llm', 'text'], 'You are a helpful AI.')

    # Mock db.session.close()
    db.session.close = MagicMock()
@@ -121,24 +121,24 @@ def test_execute_if_else_result_true():
    pool = VariablePool(system_variables={
        SystemVariable.FILES: [],
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})
    pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
    pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
    pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde')
    pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde')
    pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc')
    pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab')
    pool.append_variable(node_id='start', variable_key_list=['is'], value='ab')
    pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab')
    pool.append_variable(node_id='start', variable_key_list=['empty'], value='')
    pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa')
    pool.append_variable(node_id='start', variable_key_list=['equals'], value=22)
    pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23)
    pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23)
    pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21)
    pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22)
    pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21)
    pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212')
    }, user_inputs={}, environment_variables=[])
    pool.add(['start', 'array_contains'], ['ab', 'def'])
    pool.add(['start', 'array_not_contains'], ['ac', 'def'])
    pool.add(['start', 'contains'], 'cabcde')
    pool.add(['start', 'not_contains'], 'zacde')
    pool.add(['start', 'start_with'], 'abc')
    pool.add(['start', 'end_with'], 'zzab')
    pool.add(['start', 'is'], 'ab')
    pool.add(['start', 'is_not'], 'aab')
    pool.add(['start', 'empty'], '')
    pool.add(['start', 'not_empty'], 'aaa')
    pool.add(['start', 'equals'], 22)
    pool.add(['start', 'not_equals'], 23)
    pool.add(['start', 'greater_than'], 23)
    pool.add(['start', 'less_than'], 21)
    pool.add(['start', 'greater_than_or_equal'], 22)
    pool.add(['start', 'less_than_or_equal'], 21)
    pool.add(['start', 'not_null'], '1212')

    # Mock db.session.close()
    db.session.close = MagicMock()
@@ -184,9 +184,9 @@ def test_execute_if_else_result_false():
    pool = VariablePool(system_variables={
        SystemVariable.FILES: [],
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})
    pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['1ab', 'def'])
    pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ab', 'def'])
    }, user_inputs={}, environment_variables=[])
    pool.add(['start', 'array_contains'], ['1ab', 'def'])
    pool.add(['start', 'array_not_contains'], ['ab', 'def'])

    # Mock db.session.close()
    db.session.close = MagicMock()
95
api/tests/unit_tests/models/test_workflow.py
Normal file
@@ -0,0 +1,95 @@
from unittest import mock
from uuid import uuid4

import contexts
from constants import HIDDEN_VALUE
from core.app.segments import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from models.workflow import Workflow


def test_environment_variables():
    contexts.tenant_id.set('tenant_id')

    # Create a Workflow instance
    workflow = Workflow()

    # Create some EnvironmentVariable instances
    variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())})
    variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())})
    variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())})
    variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())})

    with (
        mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'),
        mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'),
    ):
        # Set the environment_variables property of the Workflow instance
        variables = [variable1, variable2, variable3, variable4]
        workflow.environment_variables = variables

        # Get the environment_variables property and assert its value
        assert workflow.environment_variables == variables


def test_update_environment_variables():
    contexts.tenant_id.set('tenant_id')

    # Create a Workflow instance
    workflow = Workflow()

    # Create some EnvironmentVariable instances
    variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())})
    variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())})
    variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())})
    variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())})

    with (
        mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'),
        mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'),
    ):
        variables = [variable1, variable2, variable3, variable4]

        # Set the environment_variables property of the Workflow instance
        workflow.environment_variables = variables
        assert workflow.environment_variables == [variable1, variable2, variable3, variable4]

        # Update the name of variable3 and keep the value as it is
        variables[2] = variable3.model_copy(
            update={
                'name': 'new name',
                'value': HIDDEN_VALUE,
            }
        )

        workflow.environment_variables = variables
        assert workflow.environment_variables[2].name == 'new name'
        assert workflow.environment_variables[2].value == variable3.value


def test_to_dict():
    contexts.tenant_id.set('tenant_id')

    # Create a Workflow instance
    workflow = Workflow()
    workflow.graph = '{}'
    workflow.features = '{}'

    # Create some EnvironmentVariable instances
    with (
        mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'),
        mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'),
    ):
        # Set the environment_variables property of the Workflow instance
        workflow.environment_variables = [
            SecretVariable.model_validate({'name': 'secret', 'value': 'secret', 'id': str(uuid4())}),
            StringVariable.model_validate({'name': 'text', 'value': 'text', 'id': str(uuid4())}),
        ]

        workflow_dict = workflow.to_dict()
        assert workflow_dict['environment_variables'][0]['value'] == ''
        assert workflow_dict['environment_variables'][1]['value'] == 'text'

        workflow_dict = workflow.to_dict(include_secret=True)
        assert workflow_dict['environment_variables'][0]['value'] == 'secret'
        assert workflow_dict['environment_variables'][1]['value'] == 'text'
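test_update_environment_variables pins down the `HIDDEN_VALUE` contract: when a client echoes the `[__HIDDEN__]` placeholder back, the setter must keep the previously stored secret instead of overwriting it. A sketch of that merge rule, assuming variables are matched pairwise; the setter's real body is not shown in this diff:

    from constants import HIDDEN_VALUE

    # Illustrative merge rule only: keep the old secret when the client
    # sends back the HIDDEN_VALUE placeholder instead of a real value.
    def merge_variable(old, new):
        if new.value == HIDDEN_VALUE and old is not None:
            return new.model_copy(update={'value': old.value})
        return new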