FEAT: NEW WORKFLOW ENGINE (#3160)

Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
takatost
2024-04-08 18:51:46 +08:00
committed by GitHub
parent 2fb9850af5
commit 7753ba2d37
1161 changed files with 103836 additions and 10327 deletions

View File

View File

@@ -0,0 +1,155 @@
import json
from typing import cast
from core.file.file_obj import FileVar
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.base_node import BaseNode
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
class AnswerNode(BaseNode):
    """
    Workflow node that renders the final answer text.

    The node's ``answer`` template may reference workflow variables with
    ``{{variable}}`` placeholders; at run time each placeholder is replaced
    by the corresponding value from the variable pool.
    """
    _node_data_cls = AnswerNodeData
    node_type = NodeType.ANSWER

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node
        :param variable_pool: variable pool
        :return: succeeded result whose outputs contain the rendered "answer"
        """
        node_data = self.node_data
        node_data = cast(self._node_data_cls, node_data)

        # generate routes: ordered text / variable chunks of the answer template
        generate_routes = self.extract_generate_route_from_node_data(node_data)

        answer = ''
        for part in generate_routes:
            if part.type == "var":
                part = cast(VarGenerateRouteChunk, part)
                value_selector = part.value_selector
                value = variable_pool.get_variable_value(
                    variable_selector=value_selector
                )

                text = ''
                if isinstance(value, str | int | float):
                    text = str(value)
                elif isinstance(value, dict):
                    # serialize mappings as JSON
                    text = json.dumps(value, ensure_ascii=False)
                elif isinstance(value, FileVar):
                    # convert file to markdown
                    text = value.to_markdown()
                elif isinstance(value, list):
                    # render each file item as markdown, space-separated
                    for item in value:
                        if isinstance(item, FileVar):
                            text += item.to_markdown() + ' '

                    text = text.strip()

                    if not text and value:
                        # non-empty list without any FileVar items: fall back to JSON
                        text = json.dumps(value, ensure_ascii=False)
                # NOTE(review): None and other unhandled types render as '' — confirm intended

                answer += text
            else:
                part = cast(TextGenerateRouteChunk, part)
                answer += part.text

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={
                "answer": answer
            }
        )

    @classmethod
    def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
        """
        Extract generate route selectors from a raw node config.
        :param config: node config; "data" is parsed into AnswerNodeData
        :return: ordered chunks (text / variable) that make up the answer
        """
        node_data = cls._node_data_cls(**config.get("data", {}))
        node_data = cast(cls._node_data_cls, node_data)

        return cls.extract_generate_route_from_node_data(node_data)

    @classmethod
    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
        """
        Extract generate route from node data.
        :param node_data: node data object
        :return: ordered chunks (text / variable) that make up the answer
        """
        variable_template_parser = VariableTemplateParser(template=node_data.answer)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        value_selector_mapping = {
            variable_selector.variable: variable_selector.value_selector
            for variable_selector in variable_selectors
        }

        variable_keys = list(value_selector_mapping.keys())

        # format answer template
        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
        template_variable_keys = template_parser.variable_keys

        # Take the intersection of variable_keys and template_variable_keys
        variable_keys = list(set(variable_keys) & set(template_variable_keys))

        # wrap each recognized placeholder in Ω markers so the template can be
        # split into alternating text / variable chunks below.
        # NOTE(review): a literal 'Ω' inside the answer text would corrupt the
        # split — assumed not to occur in practice; confirm.
        template = node_data.answer
        for var in variable_keys:
            template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')

        generate_routes = []
        for part in template.split('Ω'):
            if part:
                if cls._is_variable(part, variable_keys):
                    var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
                    value_selector = value_selector_mapping[var_key]
                    generate_routes.append(VarGenerateRouteChunk(
                        value_selector=value_selector
                    ))
                else:
                    generate_routes.append(TextGenerateRouteChunk(
                        text=part
                    ))

        return generate_routes

    @classmethod
    def _is_variable(cls, part: str, variable_keys: list[str]) -> bool:
        # a chunk is a variable placeholder when it is wrapped in {{ }} and its
        # key survived the intersection with the template's variable keys
        cleaned_part = part.replace('{{', '').replace('}}', '')
        return part.startswith('{{') and cleaned_part in variable_keys

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping.
        :param node_data: node data
        :return: mapping of variable name -> value selector path
        """
        node_data = node_data
        node_data = cast(cls._node_data_cls, node_data)

        variable_template_parser = VariableTemplateParser(template=node_data.answer)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        variable_mapping = {}
        for variable_selector in variable_selectors:
            variable_mapping[variable_selector.variable] = variable_selector.value_selector

        return variable_mapping

View File

@@ -0,0 +1,34 @@
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
class AnswerNodeData(BaseNodeData):
    """
    Answer Node Data.
    """
    # answer template; may contain {{variable}} placeholders
    answer: str
class GenerateRouteChunk(BaseModel):
    """
    Generate Route Chunk (base class for answer template chunks).
    """
    # discriminator set by subclasses: "var" or "text"
    type: str
class VarGenerateRouteChunk(GenerateRouteChunk):
    """
    Var Generate Route Chunk: a placeholder resolved from the variable pool.
    """
    type: str = "var"
    # selector path used to look up the value in the variable pool
    value_selector: list[str]
class TextGenerateRouteChunk(GenerateRouteChunk):
    """
    Text Generate Route Chunk: a literal text fragment of the answer template.
    """
    type: str = "text"
    # literal text emitted verbatim
    text: str

View File

@@ -0,0 +1,142 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
class UserFrom(Enum):
    """
    Origin of the user triggering a workflow run.
    """
    ACCOUNT = "account"
    END_USER = "end-user"

    @classmethod
    def value_of(cls, value: str) -> "UserFrom":
        """
        Look up the member whose value equals *value*.

        :param value: raw string value
        :return: the matching enum member
        :raises ValueError: if no member has that value
        """
        member = next((candidate for candidate in cls if candidate.value == value), None)
        if member is None:
            raise ValueError(f"Invalid value: {value}")
        return member
class BaseNode(ABC):
    """
    Abstract base class for all workflow nodes.

    A subclass declares its data model via ``_node_data_cls`` and its type
    via ``_node_type``, and implements ``_run`` and
    ``_extract_variable_selector_to_variable_mapping``.
    """
    _node_data_cls: type[BaseNodeData]
    _node_type: NodeType

    tenant_id: str
    app_id: str
    workflow_id: str
    user_id: str
    user_from: UserFrom

    node_id: str
    node_data: BaseNodeData
    node_run_result: Optional[NodeRunResult] = None

    callbacks: list[BaseWorkflowCallback]

    def __init__(self, tenant_id: str,
                 app_id: str,
                 workflow_id: str,
                 user_id: str,
                 user_from: UserFrom,
                 config: dict,
                 # fix: was `list[BaseWorkflowCallback] = None` — implicit
                 # Optional is disallowed by PEP 484
                 callbacks: Optional[list[BaseWorkflowCallback]] = None) -> None:
        """
        Initialize the node from its graph config.

        :param tenant_id: tenant id
        :param app_id: app id
        :param workflow_id: workflow id
        :param user_id: id of the user triggering the run
        :param user_from: origin of that user (account / end-user)
        :param config: node config; must contain "id"; "data" is passed to the
            node data model
        :param callbacks: optional workflow callbacks
        :raises ValueError: if the config carries no node id
        """
        self.tenant_id = tenant_id
        self.app_id = app_id
        self.workflow_id = workflow_id
        self.user_id = user_id
        self.user_from = user_from

        self.node_id = config.get("id")
        if not self.node_id:
            raise ValueError("Node ID is required.")

        self.node_data = self._node_data_cls(**config.get("data", {}))
        self.callbacks = callbacks or []

    @abstractmethod
    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node (implemented by subclasses).
        :param variable_pool: variable pool
        :return: node run result
        """
        raise NotImplementedError

    def run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node entry point.
        :param variable_pool: variable pool
        :return: node run result (also stored on ``self.node_run_result``)
        """
        result = self._run(
            variable_pool=variable_pool
        )

        self.node_run_result = result
        return result

    def publish_text_chunk(self, text: str,
                           # fix: was `list[str] = None` — implicit Optional
                           value_selector: Optional[list[str]] = None) -> None:
        """
        Publish a text chunk to all registered callbacks.
        :param text: chunk text
        :param value_selector: value selector the chunk originates from
        :return:
        """
        if self.callbacks:
            for callback in self.callbacks:
                callback.on_node_text_chunk(
                    node_id=self.node_id,
                    text=text,
                    metadata={
                        "node_type": self.node_type,
                        "value_selector": value_selector
                    }
                )

    @classmethod
    def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping from a raw node config.
        :param config: node config
        :return: mapping of variable name -> value selector path
        """
        node_data = cls._node_data_cls(**config.get("data", {}))
        return cls._extract_variable_selector_to_variable_mapping(node_data)

    @classmethod
    @abstractmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping (implemented by subclasses).
        :param node_data: node data
        :return: mapping of variable name -> value selector path
        """
        raise NotImplementedError

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.
        :param filters: filter by node config parameters.
        :return: default config dict (empty unless overridden)
        """
        return {}

    @property
    def node_type(self) -> NodeType:
        """
        Get node type
        :return: the node's type
        """
        return self._node_type

View File

View File

@@ -0,0 +1,348 @@
import os
from typing import Optional, Union, cast
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.code.entities import CodeNodeData
from models.workflow import WorkflowNodeExecutionStatus
MAX_NUMBER = int(os.environ.get('CODE_MAX_NUMBER', '9223372036854775807'))
MIN_NUMBER = int(os.environ.get('CODE_MIN_NUMBER', '-9223372036854775808'))
MAX_PRECISION = 20
MAX_DEPTH = 5
MAX_STRING_LENGTH = int(os.environ.get('CODE_MAX_STRING_LENGTH', '80000'))
MAX_STRING_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_STRING_ARRAY_LENGTH', '30'))
MAX_OBJECT_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_OBJECT_ARRAY_LENGTH', '30'))
MAX_NUMBER_ARRAY_LENGTH = int(os.environ.get('CODE_MAX_NUMBER_ARRAY_LENGTH', '1000'))
JAVASCRIPT_DEFAULT_CODE = """function main({arg1, arg2}) {
return {
result: arg1 + arg2
}
}"""
PYTHON_DEFAULT_CODE = """def main(arg1: int, arg2: int) -> dict:
return {
"result": arg1 + arg2,
}"""
class CodeNode(BaseNode):
    """
    Workflow node that executes user-provided code in a sandbox and
    validates / transforms the execution result against the declared
    output schema.
    """
    _node_data_cls = CodeNodeData
    node_type = NodeType.CODE

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.
        :param filters: filter by node config parameters.
        :return: default node config for the requested code language
        """
        # Both languages share the same config skeleton; only the language
        # tag and the starter snippet differ (previously two near-identical
        # dict literals were duplicated).
        code_language = "python3"
        code = PYTHON_DEFAULT_CODE
        if filters and filters.get("code_language") == "javascript":
            code_language = "javascript"
            code = JAVASCRIPT_DEFAULT_CODE

        return {
            "type": "code",
            "config": {
                "variables": [
                    {
                        "variable": "arg1",
                        "value_selector": []
                    },
                    {
                        "variable": "arg2",
                        "value_selector": []
                    }
                ],
                "code_language": code_language,
                "code": code,
                "outputs": {
                    "result": {
                        "type": "string",
                        "children": None
                    }
                }
            }
        }

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run code.
        :param variable_pool: variable pool
        :return: succeeded result with transformed outputs, or failed result
            with the execution / validation error
        """
        node_data = self.node_data
        node_data: CodeNodeData = cast(self._node_data_cls, node_data)

        # Get code language
        code_language = node_data.code_language
        code = node_data.code

        # Resolve input variables from the pool
        variables = {}
        for variable_selector in node_data.variables:
            variable = variable_selector.variable
            value = variable_pool.get_variable_value(
                variable_selector=variable_selector.value_selector
            )

            variables[variable] = value

        # Run code in the sandboxed executor
        try:
            result = CodeExecutor.execute_code(
                language=code_language,
                code=code,
                inputs=variables
            )

            # Validate and transform result against the declared output schema
            result = self._transform_result(result, node_data.outputs)
        except (CodeExecutionException, ValueError) as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e)
            )

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=variables,
            outputs=result
        )

    def _check_string(self, value: str, variable: str) -> str:
        """
        Validate a string output.
        :param value: value to check
        :param variable: dotted path of the output (for error messages)
        :return: the value with NUL bytes stripped
        :raises ValueError: if not a string or too long
        """
        if not isinstance(value, str):
            raise ValueError(f"{variable} in output form must be a string")

        if len(value) > MAX_STRING_LENGTH:
            raise ValueError(f'{variable} in output form must be less than {MAX_STRING_LENGTH} characters')

        return value.replace('\x00', '')

    def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
        """
        Validate a number output.
        :param value: value to check
        :param variable: dotted path of the output (for error messages)
        :return: the validated value
        :raises ValueError: if not a number, out of range, or too precise
        """
        if not isinstance(value, int | float):
            raise ValueError(f"{variable} in output form must be a number")

        if value > MAX_NUMBER or value < MIN_NUMBER:
            raise ValueError(f'{variable} in input form is out of range.')

        if isinstance(value, float):
            # Reject floats with too many decimal digits. Fix: floats whose
            # repr has no '.' (e.g. '1e-07') previously raised IndexError on
            # split('.')[1]; treat them as having no decimal part.
            decimal_part = str(value).partition('.')[2]
            if len(decimal_part) > MAX_PRECISION:
                raise ValueError(f'{variable} in output form has too high precision.')

        return value

    def _transform_result(self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]],
                          prefix: str = '',
                          depth: int = 1) -> dict:
        """
        Validate and transform an execution result.

        With a schema, each declared output is checked and copied into the
        returned dict; without a schema, values are validated recursively by
        their runtime type and the original dict is returned.

        :param result: raw result dict from the executor
        :param output_schema: declared output schema, or None to infer types
        :param prefix: dotted path prefix for error messages
        :param depth: current nesting depth (bounded by MAX_DEPTH)
        :return: validated (and, with a schema, transformed) result
        :raises ValueError: on any schema or type violation
        """
        if depth > MAX_DEPTH:
            raise ValueError("Depth limit reached, object too deep.")

        transformed_result = {}
        if output_schema is None:
            # No declared schema: validate by runtime instance type.
            for output_name, output_value in result.items():
                if isinstance(output_value, dict):
                    self._transform_result(
                        result=output_value,
                        output_schema=None,
                        prefix=f'{prefix}.{output_name}' if prefix else output_name,
                        depth=depth + 1
                    )
                elif isinstance(output_value, int | float):
                    self._check_number(
                        value=output_value,
                        variable=f'{prefix}.{output_name}' if prefix else output_name
                    )
                elif isinstance(output_value, str):
                    self._check_string(
                        value=output_value,
                        variable=f'{prefix}.{output_name}' if prefix else output_name
                    )
                elif isinstance(output_value, list):
                    # Arrays must be homogeneous; the first element decides
                    # the expected element type. Empty lists pass unchecked.
                    first_element = output_value[0] if len(output_value) > 0 else None
                    if first_element is not None:
                        if isinstance(first_element, int | float) and all(isinstance(value, int | float) for value in output_value):
                            for i, value in enumerate(output_value):
                                self._check_number(
                                    value=value,
                                    variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]'
                                )
                        elif isinstance(first_element, str) and all(isinstance(value, str) for value in output_value):
                            for i, value in enumerate(output_value):
                                self._check_string(
                                    value=value,
                                    variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]'
                                )
                        elif isinstance(first_element, dict) and all(isinstance(value, dict) for value in output_value):
                            for i, value in enumerate(output_value):
                                self._transform_result(
                                    result=value,
                                    output_schema=None,
                                    prefix=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]',
                                    depth=depth + 1
                                )
                        else:
                            raise ValueError(f'Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type.')
                else:
                    raise ValueError(f'Output {prefix}.{output_name} is not a valid type.')

            # Schema-less validation returns the original dict unchanged.
            return result

        parameters_validated = {}
        for output_name, output_config in output_schema.items():
            dot = '.' if prefix else ''
            if output_config.type == 'object':
                # check if output is object
                if not isinstance(result.get(output_name), dict):
                    raise ValueError(
                        f'Output {prefix}{dot}{output_name} is not an object, got {type(result.get(output_name))} instead.'
                    )

                transformed_result[output_name] = self._transform_result(
                    result=result[output_name],
                    output_schema=output_config.children,
                    prefix=f'{prefix}.{output_name}',
                    depth=depth + 1
                )
            elif output_config.type == 'number':
                # check if number available
                transformed_result[output_name] = self._check_number(
                    value=result[output_name],
                    variable=f'{prefix}{dot}{output_name}'
                )
            elif output_config.type == 'string':
                # check if string available
                transformed_result[output_name] = self._check_string(
                    value=result[output_name],
                    variable=f'{prefix}{dot}{output_name}',
                )
            elif output_config.type == 'array[number]':
                # check if array of number available
                if not isinstance(result[output_name], list):
                    raise ValueError(
                        f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
                    )

                if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH:
                    raise ValueError(
                        f'{prefix}{dot}{output_name} in output form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters.'
                    )

                transformed_result[output_name] = [
                    self._check_number(
                        value=value,
                        variable=f'{prefix}{dot}{output_name}[{i}]'
                    )
                    for i, value in enumerate(result[output_name])
                ]
            elif output_config.type == 'array[string]':
                # check if array of string available
                if not isinstance(result[output_name], list):
                    raise ValueError(
                        f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
                    )

                if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH:
                    raise ValueError(
                        f'{prefix}{dot}{output_name} in output form must be less than {MAX_STRING_ARRAY_LENGTH} characters.'
                    )

                transformed_result[output_name] = [
                    self._check_string(
                        value=value,
                        variable=f'{prefix}{dot}{output_name}[{i}]'
                    )
                    for i, value in enumerate(result[output_name])
                ]
            elif output_config.type == 'array[object]':
                # check if array of object available
                if not isinstance(result[output_name], list):
                    raise ValueError(
                        f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.'
                    )

                if len(result[output_name]) > MAX_OBJECT_ARRAY_LENGTH:
                    raise ValueError(
                        f'{prefix}{dot}{output_name} in output form must be less than {MAX_OBJECT_ARRAY_LENGTH} characters.'
                    )

                for i, value in enumerate(result[output_name]):
                    if not isinstance(value, dict):
                        raise ValueError(
                            f'Output {prefix}{dot}{output_name}[{i}] is not an object, got {type(value)} instead at index {i}.'
                        )

                transformed_result[output_name] = [
                    self._transform_result(
                        result=value,
                        output_schema=output_config.children,
                        prefix=f'{prefix}{dot}{output_name}[{i}]',
                        depth=depth + 1
                    )
                    for i, value in enumerate(result[output_name])
                ]
            else:
                raise ValueError(f'Output type {output_config.type} is not supported.')

            parameters_validated[output_name] = True

        # check if all output parameters are validated
        if len(parameters_validated) != len(result):
            raise ValueError('Not all output parameters are validated.')

        return transformed_result

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping.
        :param node_data: node data
        :return: mapping of variable name -> value selector path
        """
        return {
            variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
        }

View File

@@ -0,0 +1,20 @@
from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class CodeNodeData(BaseNodeData):
    """
    Code Node Data.
    """
    class Output(BaseModel):
        # declared type of this output value
        type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]']
        # nested schema for 'object' / 'array[object]' outputs; None otherwise
        children: Optional[dict[str, 'Output']]

    # input variables resolved from the variable pool before execution
    variables: list[VariableSelector]
    # sandbox language the code is executed with
    code_language: Literal['python3', 'javascript']
    code: str
    # schema the execution result is validated against
    outputs: dict[str, Output]

View File

View File

@@ -0,0 +1,46 @@
from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.entities import EndNodeData
from models.workflow import WorkflowNodeExecutionStatus
class EndNode(BaseNode):
    """
    Terminal workflow node: gathers the configured output variables from
    the variable pool and returns them as the run's outputs.
    """
    _node_data_cls = EndNodeData
    node_type = NodeType.END

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node
        :param variable_pool: variable pool
        :return: succeeded result carrying the collected outputs
        """
        node_data = cast(self._node_data_cls, self.node_data)

        outputs = {
            selector.variable: variable_pool.get_variable_value(
                variable_selector=selector.value_selector
            )
            for selector in node_data.outputs
        }

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=outputs,
            outputs=outputs
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: empty mapping — the end node declares no variables of its own
        """
        return {}

View File

@@ -0,0 +1,9 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class EndNodeData(BaseNodeData):
    """
    END Node Data.
    """
    # variables to surface as the workflow run's final outputs
    outputs: list[VariableSelector]

View File

@@ -0,0 +1,43 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, validator
from core.workflow.entities.base_node_data_entities import BaseNodeData
class HttpRequestNodeData(BaseNodeData):
    """
    HTTP Request Node Data.
    """
    class Authorization(BaseModel):
        # NOTE(review): this inner model is named `Config`, which pydantic v1
        # also uses for model configuration — appears tolerated here, but
        # confirm against the pinned pydantic version.
        class Config(BaseModel):
            # auth scheme; None when no scheme applies
            type: Literal[None, 'basic', 'bearer', 'custom']
            api_key: Union[None, str]
            # header name the credential is written to; defaults to
            # 'Authorization' downstream when unset
            header: Union[None, str]

        type: Literal['no-auth', 'api-key']
        config: Optional[Config]

        @validator('config', always=True, pre=True)
        def check_config(cls, v, values):
            """
            Check config, if type is no-auth, config should be None, otherwise it should be a dict.
            """
            if values['type'] == 'no-auth':
                return None
            else:
                if not v or not isinstance(v, dict):
                    raise ValueError('config should be a dict')

                return v

    class Body(BaseModel):
        # how the request body is encoded
        type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json']
        # raw body template; key:value lines for form types
        data: Union[None, str]

    method: Literal['get', 'post', 'put', 'patch', 'delete', 'head']
    url: str
    authorization: Authorization
    # newline-separated "key: value" lines (templates allowed)
    headers: str
    # newline-separated "key: value" lines (templates allowed)
    params: str
    body: Optional[Body]

View File

@@ -0,0 +1,402 @@
import json
from copy import deepcopy
from random import randint
from typing import Any, Optional, Union
from urllib.parse import urlencode
import httpx
import requests
import core.helper.ssrf_proxy as ssrf_proxy
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.http_request.entities import HttpRequestNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
# (connect, read) timeout in seconds passed to the HTTP client
HTTP_REQUEST_DEFAULT_TIMEOUT = (10, 60)

# size caps for responses; binary (file) responses get a larger budget
MAX_BINARY_SIZE = 1024 * 1024 * 10  # 10MB
READABLE_MAX_BINARY_SIZE = '10MB'
MAX_TEXT_SIZE = 1024 * 1024 // 10  # 0.1MB
READABLE_MAX_TEXT_SIZE = '0.1MB'
class HttpExecutorResponse:
    """
    Uniform wrapper over an httpx or requests response, exposing headers,
    content, size, and file-detection helpers.
    """
    headers: dict[str, str]
    response: Union[httpx.Response, requests.Response]

    def __init__(self, response: Union[httpx.Response, requests.Response] = None):
        """
        init
        """
        collected: dict[str, str] = {}
        if isinstance(response, (httpx.Response, requests.Response)):
            collected = {name: value for name, value in response.headers.items()}
        self.headers = collected
        self.response = response

    @property
    def is_file(self) -> bool:
        """
        check if response is file
        """
        content_type = self.get_content_type()
        return any(marker in content_type for marker in ('image', 'audio', 'video'))

    def get_content_type(self) -> str:
        """
        get content type
        """
        for name, value in self.headers.items():
            if name.lower() == 'content-type':
                return value
        return ''

    def extract_file(self) -> tuple[str, bytes]:
        """
        extract file from response if content type is file related
        """
        if not self.is_file:
            return '', b''
        return self.get_content_type(), self.body

    @property
    def content(self) -> str:
        """
        get content
        """
        if not isinstance(self.response, (httpx.Response, requests.Response)):
            raise ValueError(f'Invalid response type {type(self.response)}')
        return self.response.text

    @property
    def body(self) -> bytes:
        """
        get body
        """
        if not isinstance(self.response, (httpx.Response, requests.Response)):
            raise ValueError(f'Invalid response type {type(self.response)}')
        return self.response.content

    @property
    def status_code(self) -> int:
        """
        get status code
        """
        if not isinstance(self.response, (httpx.Response, requests.Response)):
            raise ValueError(f'Invalid response type {type(self.response)}')
        return self.response.status_code

    @property
    def size(self) -> int:
        """
        get size
        """
        return len(self.body)

    @property
    def readable_size(self) -> str:
        """
        get size as a human-readable string
        """
        one_kib = 1024
        one_mib = 1024 * 1024
        if self.size < one_kib:
            return f'{self.size} bytes'
        if self.size < one_mib:
            return f'{(self.size / 1024):.2f} KB'
        return f'{(self.size / 1024 / 1024):.2f} MB'
class HttpExecutor:
    """
    Builds and executes the HTTP request described by an HttpRequestNodeData,
    resolving {{variable}} templates in URL, params, headers, and body.
    """
    server_url: str
    method: str
    authorization: HttpRequestNodeData.Authorization
    params: dict[str, Any]
    headers: dict[str, Any]
    body: Union[None, str]
    files: Union[None, dict[str, Any]]
    boundary: str
    variable_selectors: list[VariableSelector]

    def __init__(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
        """
        Initialize the executor from node data; without a variable pool the
        templates are only parsed (for selector extraction), not resolved.
        """
        self.server_url = node_data.url
        self.method = node_data.method
        self.authorization = node_data.authorization
        self.params = {}
        self.headers = {}
        self.body = None
        self.files = None

        # parse and (optionally) resolve all templates
        self.variable_selectors = []
        self._init_template(node_data, variable_pool)

    def _is_json_body(self, body: HttpRequestNodeData.Body):
        """
        check if body is valid JSON
        """
        if body and body.type == 'json':
            try:
                json.loads(body.data)
                return True
            # fix: was a bare `except:` — only parse/type errors mean
            # "not JSON"; anything else should propagate
            except (TypeError, ValueError):
                return False

        return False

    def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
        """
        Parse URL / params / headers / body templates and populate the
        request fields.
        :raises ValueError: on malformed key:value lines
        """
        variable_selectors = []

        # extract all template in url
        self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool)

        # extract all template in params
        params, params_variable_selectors = self._format_template(node_data.params, variable_pool)

        # fill in params from newline-separated "key: value" lines
        kv_paris = params.split('\n')
        for kv in kv_paris:
            if not kv.strip():
                continue

            # fix: split only on the first ':' so values containing ':'
            # (e.g. URLs, timestamps) are kept intact instead of raising
            kv = kv.split(':', 1)
            if len(kv) == 2:
                k, v = kv
            elif len(kv) == 1:
                k, v = kv[0], ''
            else:
                raise ValueError(f'Invalid params {kv}')

            # NOTE(review): param values are intentionally not stripped
            # (unlike headers) — confirm this asymmetry is wanted
            self.params[k.strip()] = v

        # extract all template in headers
        headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool)

        # fill in headers
        kv_paris = headers.split('\n')
        for kv in kv_paris:
            if not kv.strip():
                continue

            # fix: first-colon split, same as params above
            kv = kv.split(':', 1)
            if len(kv) == 2:
                k, v = kv
            elif len(kv) == 1:
                k, v = kv[0], ''
            else:
                raise ValueError(f'Invalid headers {kv}')

            self.headers[k.strip()] = v.strip()

        # extract all template in body
        body_data_variable_selectors = []
        if node_data.body:
            # check if it's a valid JSON
            is_valid_json = self._is_json_body(node_data.body)

            body_data = node_data.body.data or ''
            if body_data:
                body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json)

            if node_data.body.type == 'json':
                self.headers['Content-Type'] = 'application/json'
            elif node_data.body.type == 'x-www-form-urlencoded':
                self.headers['Content-Type'] = 'application/x-www-form-urlencoded'

            if node_data.body.type in ['form-data', 'x-www-form-urlencoded']:
                body = {}
                kv_paris = body_data.split('\n')
                for kv in kv_paris:
                    if not kv.strip():
                        continue

                    # fix: first-colon split, same as params/headers above
                    kv = kv.split(':', 1)
                    if len(kv) == 2:
                        body[kv[0].strip()] = kv[1]
                    elif len(kv) == 1:
                        body[kv[0].strip()] = ''
                    else:
                        raise ValueError(f'Invalid body {kv}')

                if node_data.body.type == 'form-data':
                    self.files = {
                        k: ('', v) for k, v in body.items()
                    }

                    random_str = lambda n: ''.join([chr(randint(97, 122)) for _ in range(n)])
                    self.boundary = f'----WebKitFormBoundary{random_str(16)}'

                    self.headers['Content-Type'] = f'multipart/form-data; boundary={self.boundary}'
                else:
                    self.body = urlencode(body)
            elif node_data.body.type in ['json', 'raw-text']:
                self.body = body_data
            elif node_data.body.type == 'none':
                self.body = ''

        self.variable_selectors = (server_url_variable_selectors + params_variable_selectors
                                   + headers_variable_selectors + body_data_variable_selectors)

    def _assembling_headers(self) -> dict[str, Any]:
        """
        Merge configured headers with authorization headers.
        :raises ValueError: if api-key auth is selected without an api_key
        """
        authorization = deepcopy(self.authorization)
        headers = deepcopy(self.headers) or {}
        if self.authorization.type == 'api-key':
            if self.authorization.config.api_key is None:
                raise ValueError('api_key is required')

            if not self.authorization.config.header:
                authorization.config.header = 'Authorization'

            if self.authorization.config.type == 'bearer':
                headers[authorization.config.header] = f'Bearer {authorization.config.api_key}'
            elif self.authorization.config.type == 'basic':
                headers[authorization.config.header] = f'Basic {authorization.config.api_key}'
            elif self.authorization.config.type == 'custom':
                headers[authorization.config.header] = authorization.config.api_key

        return headers

    def _validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> HttpExecutorResponse:
        """
        Validate the response size and wrap it.
        :raises ValueError: on unknown response type or oversized payload
        """
        if isinstance(response, httpx.Response | requests.Response):
            executor_response = HttpExecutorResponse(response)
        else:
            raise ValueError(f'Invalid response type {type(response)}')

        if executor_response.is_file:
            if executor_response.size > MAX_BINARY_SIZE:
                raise ValueError(f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.')
        else:
            if executor_response.size > MAX_TEXT_SIZE:
                raise ValueError(f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.')

        return executor_response

    def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
        """
        Perform the HTTP request through the SSRF proxy.
        :raises ValueError: on unsupported HTTP method
        """
        kwargs = {
            'url': self.server_url,
            'headers': headers,
            'params': self.params,
            'timeout': HTTP_REQUEST_DEFAULT_TIMEOUT,
            'follow_redirects': True
        }

        if self.method == 'get':
            response = ssrf_proxy.get(**kwargs)
        elif self.method == 'post':
            response = ssrf_proxy.post(data=self.body, files=self.files, **kwargs)
        elif self.method == 'put':
            response = ssrf_proxy.put(data=self.body, files=self.files, **kwargs)
        elif self.method == 'delete':
            response = ssrf_proxy.delete(data=self.body, files=self.files, **kwargs)
        elif self.method == 'patch':
            response = ssrf_proxy.patch(data=self.body, files=self.files, **kwargs)
        elif self.method == 'head':
            response = ssrf_proxy.head(**kwargs)
        elif self.method == 'options':
            response = ssrf_proxy.options(**kwargs)
        else:
            raise ValueError(f'Invalid http method {self.method}')

        return response

    def invoke(self) -> HttpExecutorResponse:
        """
        Assemble headers, perform the request, and validate the response.
        """
        # assemble headers
        headers = self._assembling_headers()

        # do http request
        response = self._do_http_request(headers)

        # validate response
        return self._validate_and_parse_response(response)

    def to_raw_request(self) -> str:
        """
        Render the request as a raw HTTP/1.1 message (for process data / debugging).
        """
        server_url = self.server_url
        if self.params:
            server_url += f'?{urlencode(self.params)}'

        raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n'

        headers = self._assembling_headers()
        for k, v in headers.items():
            raw_request += f'{k}: {v}\n'

        raw_request += '\n'

        # if files, use multipart/form-data with boundary
        if self.files:
            boundary = self.boundary
            raw_request += f'--{boundary}'
            for k, v in self.files.items():
                raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n'
                raw_request += f'{v[1]}\n'
                raw_request += f'--{boundary}'
            raw_request += '--'
        else:
            raw_request += self.body or ''

        return raw_request

    # fix: annotation was `variable_pool: VariablePool` although __init__
    # passes None when extracting selectors without a pool
    def _format_template(self, template: str, variable_pool: Optional[VariablePool], escape_quotes: bool = False) \
            -> tuple[str, list[VariableSelector]]:
        """
        Resolve {{variable}} templates against the pool (if given) and return
        the formatted string plus the selectors found.
        :raises ValueError: if a referenced variable is missing from the pool
        """
        variable_template_parser = VariableTemplateParser(template=template)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        if variable_pool:
            variable_value_mapping = {}
            for variable_selector in variable_selectors:
                value = variable_pool.get_variable_value(
                    variable_selector=variable_selector.value_selector,
                    target_value_type=ValueType.STRING
                )

                if value is None:
                    raise ValueError(f'Variable {variable_selector.variable} not found')

                if escape_quotes:
                    value = value.replace('"', '\\"')

                variable_value_mapping[variable_selector.variable] = value

            return variable_template_parser.format(variable_value_mapping), variable_selectors
        else:
            return template, variable_selectors

View File

@@ -0,0 +1,112 @@
import logging
from mimetypes import guess_extension
from os import path
from typing import cast
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.http_request.entities import HttpRequestNodeData
from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse
from models.workflow import WorkflowNodeExecutionStatus
class HttpRequestNode(BaseNode):
    """Workflow node that performs an HTTP request and exposes the response."""

    _node_data_cls = HttpRequestNodeData
    node_type = NodeType.HTTP_REQUEST

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Execute the HTTP request described by the node data.

        :param variable_pool: variable pool used to resolve template variables
        :return: FAILED result (raw request included in process_data when the
            executor was built) on any error; otherwise SUCCEEDED with status
            code, body, headers and any extracted image files.
        """
        node_data: HttpRequestNodeData = cast(self._node_data_cls, self.node_data)

        # init http executor
        http_executor = None
        try:
            http_executor = HttpExecutor(node_data=node_data, variable_pool=variable_pool)

            # invoke http executor
            response = http_executor.invoke()
        except Exception as e:
            # keep the raw request for debugging when the executor was built
            process_data = {}
            if http_executor:
                process_data = {
                    'request': http_executor.to_raw_request(),
                }
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                process_data=process_data
            )

        files = self.extract_files(http_executor.server_url, response)

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={
                'status_code': response.status_code,
                # when files were extracted, the binary payload is not duplicated in `body`
                'body': response.content if not files else '',
                'headers': response.headers,
                'files': files,
            },
            process_data={
                'request': http_executor.to_raw_request(),
            }
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: mapping of variable name -> value selector path
        """
        try:
            http_executor = HttpExecutor(node_data=node_data)

            variable_selectors = http_executor.variable_selectors

            variable_mapping = {}
            for variable_selector in variable_selectors:
                variable_mapping[variable_selector.variable] = variable_selector.value_selector

            return variable_mapping
        except Exception as e:
            # best-effort: an unparsable node config yields an empty mapping
            logging.exception("Failed to extract variable selector to variable mapping: %s", e)
            return {}

    def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]:
        """
        Extract image files from the response and persist them as tool files.

        :param url: request URL, used to derive a filename
        :param response: executor response to inspect
        :return: list of FileVar objects (empty when the response is not an image)
        """
        files = []
        mimetype, file_binary = response.extract_file()

        # guard against a missing mimetype *before* substring-testing it
        # (`'image' not in None` would raise); only image payloads are extracted
        if not mimetype or 'image' not in mimetype:
            return files

        # extract filename from url
        filename = path.basename(url)
        # extract extension if possible
        extension = guess_extension(mimetype) or '.bin'

        tool_file = ToolFileManager.create_file_by_raw(
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
            file_binary=file_binary,
            mimetype=mimetype,
        )

        files.append(FileVar(
            tenant_id=self.tenant_id,
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.TOOL_FILE,
            related_id=tool_file.id,
            filename=filename,
            extension=extension,
            mime_type=mimetype,
        ))

        return files

View File

@@ -0,0 +1,26 @@
from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
class IfElseNodeData(BaseNodeData):
    """
    If-Else Node Data.
    """
    class Condition(BaseModel):
        """
        A single comparison between a pooled variable and an expected value.
        """
        variable_selector: list[str]
        comparison_operator: Literal[
            # for string or array
            "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
            # for number (non-ASCII operators restored: the list previously
            # contained empty strings where ≠ / ≥ / ≤ had been stripped)
            "=", "≠", ">", "<", "≥", "≤", "null", "not null"
        ]
        # expected value, always transported as a string; None for unary operators
        value: Optional[str] = None

    # how the individual condition results are combined into the branch decision
    logical_operator: Literal["and", "or"] = "and"
    conditions: list[Condition]

View File

@@ -0,0 +1,398 @@
from typing import Optional, cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.if_else.entities import IfElseNodeData
from models.workflow import WorkflowNodeExecutionStatus
class IfElseNode(BaseNode):
    """
    Branching workflow node: evaluates conditions against the variable pool
    and routes execution through the "true" or "false" source edge.
    """

    _node_data_cls = IfElseNodeData
    node_type = NodeType.IF_ELSE

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node
        :param variable_pool: variable pool
        :return: result carrying the evaluated conditions and the chosen edge
        """
        node_data = self.node_data
        node_data = cast(self._node_data_cls, node_data)

        node_inputs = {
            "conditions": []
        }

        process_datas = {
            "condition_results": []
        }

        try:
            logical_operator = node_data.logical_operator

            # resolve every condition's actual value from the pool first,
            # so the full input set is recorded even if evaluation fails
            input_conditions = []
            for condition in node_data.conditions:
                actual_value = variable_pool.get_variable_value(
                    variable_selector=condition.variable_selector
                )

                expected_value = condition.value

                input_conditions.append({
                    "actual_value": actual_value,
                    "expected_value": expected_value,
                    "comparison_operator": condition.comparison_operator
                })

            node_inputs["conditions"] = input_conditions

            # dispatch tables: unary operators inspect only the actual value,
            # binary operators also compare against the expected value.
            # (The ≠ / ≥ / ≤ operators were previously mojibake-stripped to "",
            # which made three branches compare against the empty string.)
            unary_operations = {
                "empty": self._assert_empty,
                "not empty": self._assert_not_empty,
                "null": self._assert_null,
                "not null": self._assert_not_null,
            }
            binary_operations = {
                "contains": self._assert_contains,
                "not contains": self._assert_not_contains,
                "start with": self._assert_start_with,
                "end with": self._assert_end_with,
                "is": self._assert_is,
                "is not": self._assert_is_not,
                "=": self._assert_equal,
                "≠": self._assert_not_equal,
                ">": self._assert_greater_than,
                "<": self._assert_less_than,
                "≥": self._assert_greater_than_or_equal,
                "≤": self._assert_less_than_or_equal,
            }

            for input_condition in input_conditions:
                actual_value = input_condition["actual_value"]
                expected_value = input_condition["expected_value"]
                comparison_operator = input_condition["comparison_operator"]

                if comparison_operator in unary_operations:
                    compare_result = unary_operations[comparison_operator](actual_value)
                elif comparison_operator in binary_operations:
                    compare_result = binary_operations[comparison_operator](actual_value, expected_value)
                else:
                    # unknown operator: contributes no result
                    continue

                process_datas["condition_results"].append({
                    **input_condition,
                    "result": compare_result
                })
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=node_inputs,
                process_data=process_datas,
                error=str(e)
            )

        # "and": every evaluated condition must hold; "or": at least one must
        if logical_operator == "and":
            compare_result = False not in [condition["result"] for condition in process_datas["condition_results"]]
        else:
            compare_result = True in [condition["result"] for condition in process_datas["condition_results"]]

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=node_inputs,
            process_data=process_datas,
            edge_source_handle="false" if not compare_result else "true",
            outputs={
                "result": compare_result
            }
        )

    def _coerce_expected_number(self, actual_value: int | float, expected_value: str) -> int | float:
        """
        Validate that the actual value is numeric and convert the expected
        value (always transported as a string) to the matching numeric type.
        :raises ValueError: when the actual value is not a number, or the
            expected value cannot be parsed as one
        """
        if not isinstance(actual_value, int | float):
            raise ValueError('Invalid actual value type: number')

        return int(expected_value) if isinstance(actual_value, int) else float(expected_value)

    def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
        """
        Assert contains
        :param actual_value: actual value (string or array)
        :param expected_value: expected value
        :return: True when expected_value occurs in actual_value
        """
        if not actual_value:
            return False

        if not isinstance(actual_value, str | list):
            raise ValueError('Invalid actual value type: string or array')

        return expected_value in actual_value

    def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
        """
        Assert not contains
        :param actual_value: actual value (string or array)
        :param expected_value: expected value
        :return: True when expected_value does not occur in actual_value
        """
        if not actual_value:
            # an empty/missing value contains nothing
            return True

        if not isinstance(actual_value, str | list):
            raise ValueError('Invalid actual value type: string or array')

        return expected_value not in actual_value

    def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
        """
        Assert start with
        :param actual_value: actual value
        :param expected_value: expected prefix
        :return: True when actual_value starts with expected_value
        """
        if not actual_value:
            return False

        if not isinstance(actual_value, str):
            raise ValueError('Invalid actual value type: string')

        return actual_value.startswith(expected_value)

    def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
        """
        Assert end with
        :param actual_value: actual value
        :param expected_value: expected suffix
        :return: True when actual_value ends with expected_value
        """
        if not actual_value:
            return False

        if not isinstance(actual_value, str):
            raise ValueError('Invalid actual value type: string')

        return actual_value.endswith(expected_value)

    def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
        """
        Assert is (string equality)
        :param actual_value: actual value
        :param expected_value: expected value
        :return: True when the strings are equal
        """
        if actual_value is None:
            return False

        if not isinstance(actual_value, str):
            raise ValueError('Invalid actual value type: string')

        return actual_value == expected_value

    def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
        """
        Assert is not (string inequality)
        :param actual_value: actual value
        :param expected_value: expected value
        :return: True when the strings differ
        """
        if actual_value is None:
            return False

        if not isinstance(actual_value, str):
            raise ValueError('Invalid actual value type: string')

        return actual_value != expected_value

    def _assert_empty(self, actual_value: Optional[str]) -> bool:
        """
        Assert empty
        :param actual_value: actual value
        :return: True when the value is falsy (None, empty string, ...)
        """
        return not actual_value

    def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
        """
        Assert not empty
        :param actual_value: actual value
        :return: True when the value is truthy
        """
        return bool(actual_value)

    def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
        """
        Assert equal (numeric)
        :param actual_value: actual value
        :param expected_value: expected value (string form)
        :return: True when the numbers are equal
        """
        if actual_value is None:
            return False

        return actual_value == self._coerce_expected_number(actual_value, expected_value)

    def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
        """
        Assert not equal (numeric)
        :param actual_value: actual value
        :param expected_value: expected value (string form)
        :return: True when the numbers differ
        """
        if actual_value is None:
            return False

        return actual_value != self._coerce_expected_number(actual_value, expected_value)

    def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
        """
        Assert greater than
        :param actual_value: actual value
        :param expected_value: expected value (string form)
        :return: True when actual_value > expected_value
        """
        if actual_value is None:
            return False

        return actual_value > self._coerce_expected_number(actual_value, expected_value)

    def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
        """
        Assert less than
        :param actual_value: actual value
        :param expected_value: expected value (string form)
        :return: True when actual_value < expected_value
        """
        if actual_value is None:
            return False

        return actual_value < self._coerce_expected_number(actual_value, expected_value)

    def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
        """
        Assert greater than or equal
        :param actual_value: actual value
        :param expected_value: expected value (string form)
        :return: True when actual_value >= expected_value
        """
        if actual_value is None:
            return False

        return actual_value >= self._coerce_expected_number(actual_value, expected_value)

    def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
        """
        Assert less than or equal
        :param actual_value: actual value
        :param expected_value: expected value (string form)
        :return: True when actual_value <= expected_value
        """
        if actual_value is None:
            return False

        return actual_value <= self._coerce_expected_number(actual_value, expected_value)

    def _assert_null(self, actual_value: Optional[int | float]) -> bool:
        """
        Assert null
        :param actual_value: actual value
        :return: True when the value is None
        """
        return actual_value is None

    def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
        """
        Assert not null
        :param actual_value: actual value
        :return: True when the value is not None
        """
        return actual_value is not None

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: always empty — branch conditions resolve variables at runtime
        """
        return {}

View File

@@ -0,0 +1,51 @@
from typing import Any, Literal, Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
class RerankingModelConfig(BaseModel):
    """
    Reranking Model Config.

    Passed to ModelManager.get_model_instance as the provider/model pair
    for a RERANK-type model.
    """
    # rerank model provider identifier
    provider: str
    # rerank model name within that provider
    model: str
class MultipleRetrievalConfig(BaseModel):
    """
    Multiple Retrieval Config.

    Used when retrieval_mode == 'multiple': every dataset is searched and the
    merged results are reranked with the configured model.
    """
    # number of documents to keep after reranking
    top_k: int
    # minimum rerank score; None disables threshold filtering
    score_threshold: Optional[float]
    # model used to rerank the merged cross-dataset results
    reranking_model: RerankingModelConfig
class ModelConfig(BaseModel):
    """
    Model Config.

    LLM configuration for single-mode dataset routing.
    """
    # LLM provider identifier
    provider: str
    # model name within the provider
    name: str
    # completion mode (e.g. "chat"); the node rejects an empty value
    mode: str
    # provider-specific inference parameters; may carry a 'stop' entry that
    # the node splits out separately
    completion_params: dict[str, Any] = {}
class SingleRetrievalConfig(BaseModel):
    """
    Single Retrieval Config.

    Used when retrieval_mode == 'single': an LLM routes the query to one dataset.
    """
    # the routing LLM
    model: ModelConfig
class KnowledgeRetrievalNodeData(BaseNodeData):
    """
    Knowledge retrieval Node Data.
    """
    type: str = 'knowledge-retrieval'
    # selector path of the query variable in the variable pool
    query_variable_selector: list[str]
    # datasets that may be searched
    dataset_ids: list[str]
    # 'single': an LLM routes to one dataset; 'multiple': search all and rerank
    retrieval_mode: Literal['single', 'multiple']
    # used when retrieval_mode == 'multiple'
    multiple_retrieval_config: Optional[MultipleRetrievalConfig]
    # used when retrieval_mode == 'single'
    single_retrieval_config: Optional[SingleRetrievalConfig]

View File

@@ -0,0 +1,443 @@
import threading
from typing import Any, cast
from flask import Flask, current_app
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalService
from core.rerank.rerank import RerankRunner
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.workflow.nodes.knowledge_retrieval.multi_dataset_react_route import ReactMultiDatasetRouter
from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus
# Fallback retrieval settings applied when a dataset has no retrieval_model
# of its own (see _single_retrieve / _retriever below).
default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enabled': False
}
class KnowledgeRetrievalNode(BaseNode):
    """
    Workflow node that retrieves relevant segments from one or more datasets
    and returns them as context records for downstream nodes.
    """

    _node_data_cls = KnowledgeRetrievalNodeData
    node_type = NodeType.KNOWLEDGE_RETRIEVAL

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """Resolve the query from the pool, run retrieval and wrap the outcome."""
        node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)

        # extract variables
        query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
        variables = {
            'query': query
        }
        if not query:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error="Query is required."
            )
        # retrieve knowledge
        try:
            results = self._fetch_dataset_retriever(
                node_data=node_data, query=query
            )
            outputs = {
                'result': results
            }
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=variables,
                process_data=None,
                outputs=outputs
            )
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e)
            )

    def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[
            dict[str, Any]]:
        """
        Retrieve documents for `query` from the node's datasets and convert
        them into context records (content + source metadata).

        :param node_data: node data
        :param query: query
        """
        available_datasets = []
        dataset_ids = node_data.dataset_ids
        for dataset_id in dataset_ids:
            # get dataset from dataset id
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == self.tenant_id,
                Dataset.id == dataset_id
            ).first()

            # pass if dataset is not available
            if not dataset:
                continue

            # skip datasets that contain no available document
            # (the original tested available_document_count twice; deduplicated)
            if dataset.available_document_count == 0:
                continue

            available_datasets.append(dataset)

        all_documents = []
        if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
            all_documents = self._single_retrieve(available_datasets, node_data, query)
        elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
            all_documents = self._multiple_retrieve(available_datasets, node_data, query)

        context_list = []
        if all_documents:
            # remember the retrieval/rerank score per doc id, when present
            document_score_list = {}
            for item in all_documents:
                if 'score' in item.metadata and item.metadata['score']:
                    document_score_list[item.metadata['doc_id']] = item.metadata['score']

            index_node_ids = [document.metadata['doc_id'] for document in all_documents]
            segments = DocumentSegment.query.filter(
                DocumentSegment.dataset_id.in_(dataset_ids),
                DocumentSegment.completed_at.isnot(None),
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True,
                DocumentSegment.index_node_id.in_(index_node_ids)
            ).all()

            if segments:
                # keep the segments in the documents' retrieval order
                index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
                sorted_segments = sorted(segments,
                                         key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
                                                                                           float('inf')))

                # position counter initialized once before the loop: it was
                # previously reset per segment, pinning every 'position' to 1
                resource_number = 1
                for segment in sorted_segments:
                    dataset = Dataset.query.filter_by(
                        id=segment.dataset_id
                    ).first()
                    document = Document.query.filter(Document.id == segment.document_id,
                                                     Document.enabled == True,
                                                     Document.archived == False,
                                                     ).first()
                    if dataset and document:
                        source = {
                            'metadata': {
                                '_source': 'knowledge',
                                'position': resource_number,
                                'dataset_id': dataset.id,
                                'dataset_name': dataset.name,
                                'document_id': document.id,
                                'document_name': document.name,
                                'document_data_source_type': document.data_source_type,
                                'segment_id': segment.id,
                                'retriever_from': 'workflow',
                                'score': document_score_list.get(segment.index_node_id, None),
                                'segment_hit_count': segment.hit_count,
                                'segment_word_count': segment.word_count,
                                'segment_position': segment.position,
                                'segment_index_node_hash': segment.index_node_hash,
                            },
                            'title': document.name
                        }
                        if segment.answer:
                            source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
                        else:
                            source['content'] = segment.content

                        context_list.append(source)
                        resource_number += 1
        return context_list

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """Map the node's single input variable 'query' to its selector path."""
        node_data = cast(cls._node_data_cls, node_data)
        variable_mapping = {}
        variable_mapping['query'] = node_data.query_variable_selector
        return variable_mapping

    def _single_retrieve(self, available_datasets, node_data, query):
        """
        'single' mode: let an LLM route the query to exactly one dataset
        (function calling when the model supports it, ReAct otherwise),
        then retrieve from that dataset only.
        """
        # each dataset is exposed to the router as a zero-argument tool
        tools = []
        for dataset in available_datasets:
            description = dataset.description
            if not description:
                description = 'useful for when you want to answer queries about the ' + dataset.name

            description = description.replace('\n', '').replace('\r', '')
            message_tool = PromptMessageTool(
                name=dataset.id,
                description=description,
                parameters={
                    "type": "object",
                    "properties": {},
                    "required": [],
                }
            )
            tools.append(message_tool)

        # fetch model config
        model_instance, model_config = self._fetch_model_config(node_data)
        # check model is support tool calling
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)
        # get model schema
        model_schema = model_type_instance.get_model_schema(
            model=model_config.model,
            credentials=model_config.credentials
        )

        if not model_schema:
            # consistent empty result (was `return None`; both are falsy to callers)
            return []

        planning_strategy = PlanningStrategy.REACT_ROUTER
        features = model_schema.features
        if features:
            if ModelFeature.TOOL_CALL in features \
                    or ModelFeature.MULTI_TOOL_CALL in features:
                planning_strategy = PlanningStrategy.ROUTER

        dataset_id = None
        if planning_strategy == PlanningStrategy.REACT_ROUTER:
            react_multi_dataset_router = ReactMultiDatasetRouter()
            dataset_id = react_multi_dataset_router.invoke(query, tools, node_data, model_config, model_instance,
                                                           self.user_id, self.tenant_id)
        elif planning_strategy == PlanningStrategy.ROUTER:
            function_call_router = FunctionCallMultiDatasetRouter()
            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)

        if dataset_id:
            # get retrieval model config
            dataset = db.session.query(Dataset).filter(
                Dataset.id == dataset_id
            ).first()
            if dataset:
                retrieval_model_config = dataset.retrieval_model \
                    if dataset.retrieval_model else default_retrieval_model

                # get top k
                top_k = retrieval_model_config['top_k']
                # get retrieval method
                retrival_method = retrieval_model_config['search_method']
                # get reranking model
                reranking_model = retrieval_model_config['reranking_model'] \
                    if retrieval_model_config['reranking_enable'] else None
                # get score threshold
                score_threshold = .0
                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
                if score_threshold_enabled:
                    score_threshold = retrieval_model_config.get("score_threshold")

                results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
                                                    query=query,
                                                    top_k=top_k, score_threshold=score_threshold,
                                                    reranking_model=reranking_model)
                self._on_query(query, [dataset_id])
                if results:
                    self._on_retrival_end(results)
                return results
        return []

    def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
            ModelInstance, ModelConfigWithCredentialsEntity]:
        """
        Fetch model config
        :param node_data: node data
        :return: model instance and a fully validated model config
        :raises ValueError: unknown model or missing mode
        :raises ProviderTokenNotInitError, ModelCurrentlyNotSupportError,
            QuotaExceededError: when the configured model cannot be used
        """
        model_name = node_data.single_retrieval_config.model.name
        provider_name = node_data.single_retrieval_config.model.provider

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id,
            model_type=ModelType.LLM,
            provider=provider_name,
            model=model_name
        )

        provider_model_bundle = model_instance.provider_model_bundle
        model_type_instance = model_instance.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_credentials = model_instance.credentials

        # check model
        provider_model = provider_model_bundle.configuration.get_provider_model(
            model=model_name,
            model_type=ModelType.LLM
        )

        if provider_model is None:
            raise ValueError(f"Model {model_name} not exist.")

        if provider_model.status == ModelStatus.NO_CONFIGURE:
            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
        elif provider_model.status == ModelStatus.NO_PERMISSION:
            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

        # model config
        completion_params = node_data.single_retrieval_config.model.completion_params
        stop = []
        if 'stop' in completion_params:
            # 'stop' travels separately from the other inference parameters
            stop = completion_params['stop']
            del completion_params['stop']

        # get model mode
        model_mode = node_data.single_retrieval_config.model.mode
        if not model_mode:
            raise ValueError("LLM mode is required.")

        model_schema = model_type_instance.get_model_schema(
            model_name,
            model_credentials
        )

        if not model_schema:
            raise ValueError(f"Model {model_name} not exist.")

        return model_instance, ModelConfigWithCredentialsEntity(
            provider=provider_name,
            model=model_name,
            model_schema=model_schema,
            mode=model_mode,
            provider_model_bundle=provider_model_bundle,
            credentials=model_credentials,
            parameters=completion_params,
            stop=stop,
        )

    def _multiple_retrieve(self, available_datasets, node_data, query):
        """
        'multiple' mode: query every dataset concurrently, then rerank the
        merged results with the configured rerank model.
        """
        threads = []
        all_documents = []
        dataset_ids = [dataset.id for dataset in available_datasets]
        for dataset in available_datasets:
            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
                'flask_app': current_app._get_current_object(),
                'dataset_id': dataset.id,
                'query': query,
                'top_k': node_data.multiple_retrieval_config.top_k,
                'all_documents': all_documents,
            })
            threads.append(retrieval_thread)
            retrieval_thread.start()
        for thread in threads:
            thread.join()
        # do rerank for searched documents
        model_manager = ModelManager()
        rerank_model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id,
            provider=node_data.multiple_retrieval_config.reranking_model.provider,
            model_type=ModelType.RERANK,
            model=node_data.multiple_retrieval_config.reranking_model.model
        )

        rerank_runner = RerankRunner(rerank_model_instance)
        all_documents = rerank_runner.run(query, all_documents,
                                          node_data.multiple_retrieval_config.score_threshold,
                                          node_data.multiple_retrieval_config.top_k)
        self._on_query(query, dataset_ids)
        if all_documents:
            self._on_retrival_end(all_documents)
        return all_documents

    def _on_retrival_end(self, documents: list[Document]) -> None:
        """Handle retrival end: bump the hit count of every matched segment."""
        for document in documents:
            query = db.session.query(DocumentSegment).filter(
                DocumentSegment.index_node_id == document.metadata['doc_id']
            )

            # narrow the update to the owning dataset when known
            if 'dataset_id' in document.metadata:
                query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])

            # add hit count to document segment
            query.update(
                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                synchronize_session=False
            )

            db.session.commit()

    def _on_query(self, query: str, dataset_ids: list[str]) -> None:
        """
        Handle query: record one DatasetQuery row per searched dataset.
        """
        if not query:
            return

        for dataset_id in dataset_ids:
            dataset_query = DatasetQuery(
                dataset_id=dataset_id,
                content=query,
                source='app',
                source_app_id=self.app_id,
                created_by_role=self.user_from.value,
                created_by=self.user_id
            )
            db.session.add(dataset_query)

        db.session.commit()

    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
        """
        Thread worker: retrieve from one dataset and extend the shared
        `all_documents` list with the results.
        """
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == self.tenant_id,
                Dataset.id == dataset_id
            ).first()

            if not dataset:
                return []

            # get retrieval model , if the model is not setting , using default
            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

            if dataset.indexing_technique == "economy":
                # use keyword table query
                documents = RetrievalService.retrieve(retrival_method='keyword_search',
                                                      dataset_id=dataset.id,
                                                      query=query,
                                                      top_k=top_k
                                                      )
                if documents:
                    all_documents.extend(documents)
            else:
                if top_k > 0:
                    # retrieval source
                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
                                                          dataset_id=dataset.id,
                                                          query=query,
                                                          top_k=top_k,
                                                          score_threshold=retrieval_model['score_threshold']
                                                          if retrieval_model['score_threshold_enabled'] else None,
                                                          reranking_model=retrieval_model['reranking_model']
                                                          if retrieval_model['reranking_enable'] else None
                                                          )

                    all_documents.extend(documents)

View File

@@ -0,0 +1,47 @@
from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
class FunctionCallMultiDatasetRouter:
    """Routes a query to one dataset by letting a function-calling LLM pick a tool."""

    def invoke(
            self,
            query: str,
            dataset_tools: list[PromptMessageTool],
            model_config: ModelConfigWithCredentialsEntity,
            model_instance: ModelInstance,
    ) -> Union[str, None]:
        """Given input, decided what to do.

        :param model_config: kept for interface parity with the ReAct router;
            not used by the function-calling path
        :return: the chosen tool (dataset) name, or None when no tool applies
            or the model invocation fails (routing is best-effort)
        """
        if not dataset_tools:
            return None
        if len(dataset_tools) == 1:
            # a single candidate needs no model round-trip
            return dataset_tools[0].name

        try:
            prompt_messages = [
                SystemPromptMessage(content='You are a helpful AI assistant.'),
                UserPromptMessage(content=query)
            ]
            result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                tools=dataset_tools,
                stream=False,
                model_parameters={
                    'temperature': 0.2,
                    'top_p': 0.3,
                    'max_tokens': 1500
                }
            )
            if result.message.tool_calls:
                # the first tool call names the selected dataset
                return result.message.tool_calls[0].function.name
            return None
        except Exception:
            # best-effort routing: any model failure falls back to "no dataset"
            return None

View File

@@ -0,0 +1,254 @@
from collections.abc import Generator, Sequence
from typing import Optional, Union
from langchain import PromptTemplate
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.schema import AgentAction
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from core.workflow.nodes.llm.llm_node import LLMNode
# ReAct-style prompt instructing the model to emit a single JSON blob per
# action; fed to the router as `format_instructions` and parsed by
# StructuredChatOutputParser. Do not reformat: the literal text is part of
# the prompt contract ({{ }} are escaped braces for str.format).
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{
"action": "Final Answer",
"action_input": "Final response to human"
}}
```"""
class ReactMultiDatasetRouter:
def invoke(
        self,
        query: str,
        dataset_tools: list[PromptMessageTool],
        node_data: KnowledgeRetrievalNodeData,
        model_config: ModelConfigWithCredentialsEntity,
        model_instance: ModelInstance,
        user_id: str,
        tenant_id: str,
) -> Union[str, None]:
    """Given input, decided what to do.

    Runs a ReAct round-trip to pick one dataset tool for `query`.

    :return: the chosen tool (dataset) name, or None when no tool applies or
        the ReAct invocation fails (routing is best-effort)
    """
    if not dataset_tools:
        return None
    if len(dataset_tools) == 1:
        # a single candidate needs no model round-trip
        return dataset_tools[0].name

    try:
        return self._react_invoke(query=query, node_data=node_data, model_config=model_config,
                                  model_instance=model_instance, tools=dataset_tools,
                                  user_id=user_id, tenant_id=tenant_id)
    except Exception:
        # best-effort routing: any model/parsing failure falls back to "no dataset"
        return None
    def _react_invoke(
        self,
        query: str,
        node_data: KnowledgeRetrievalNodeData,
        model_config: ModelConfigWithCredentialsEntity,
        model_instance: ModelInstance,
        tools: Sequence[PromptMessageTool],
        user_id: str,
        tenant_id: str,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
    ) -> Union[str, None]:
        """
        Run a single ReAct round to let the LLM pick one dataset tool.

        Builds a chat or completion prompt (depending on the model mode),
        invokes the model once, and parses the structured action it returns.
        :param query: user query to route
        :param node_data: knowledge retrieval node data
        :param model_config: model config with credentials
        :param model_instance: model instance used for the call
        :param tools: candidate dataset tools
        :param user_id: user id attributed to the LLM call
        :param tenant_id: tenant whose quota is deducted
        :return: chosen tool name, or None when the model answered directly
        """
        if model_config.mode == "chat":
            prompt = self.create_chat_prompt(
                query=query,
                tools=tools,
                prefix=prefix,
                suffix=suffix,
                human_message_template=human_message_template,
                format_instructions=format_instructions,
            )
        else:
            prompt = self.create_completion_prompt(
                tools=tools,
                prefix=prefix,
                format_instructions=format_instructions,
                input_variables=None
            )
        # stop after the first action so we never generate fake observations
        stop = ['Observation:']
        # handle invoke result
        prompt_transform = AdvancedPromptTransform()
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt,
            inputs={},
            query='',
            files=[],
            context='',
            memory_config=None,
            memory=None,
            model_config=model_config
        )
        result_text, usage = self._invoke_llm(
            node_data=node_data,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            stop=stop,
            user_id=user_id,
            tenant_id=tenant_id
        )
        output_parser = StructuredChatOutputParser()
        agent_decision = output_parser.parse(result_text)
        # an AgentAction names a tool; anything else (e.g. a final answer)
        # means no tool was selected
        if isinstance(agent_decision, AgentAction):
            return agent_decision.tool
        return None
    def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData,
                    model_instance: ModelInstance,
                    prompt_messages: list[PromptMessage],
                    stop: list[str], user_id: str, tenant_id: str) -> tuple[str, LLMUsage]:
        """
        Invoke large language model
        :param node_data: node data
        :param model_instance: model instance
        :param prompt_messages: prompt messages
        :param stop: stop words
        :param user_id: user id attributed to the call
        :param tenant_id: tenant whose quota is deducted
        :return: (full response text, token usage)
        """
        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=node_data.single_retrieval_config.model.completion_params,
            stop=stop,
            stream=True,
            user=user_id,
        )
        # handle invoke result (accumulate the stream into full text + usage)
        text, usage = self._handle_invoke_result(
            invoke_result=invoke_result
        )
        # deduct quota (shared logic with the LLM node)
        LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
        return text, usage
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
"""
Handle invoke result
:param invoke_result: invoke result
:return:
"""
model = None
prompt_messages = []
full_text = ''
usage = None
for result in invoke_result:
text = result.delta.message.content
full_text += text
if not model:
model = result.model
if not prompt_messages:
prompt_messages = result.prompt_messages
if not usage and result.delta.usage:
usage = result.delta.usage
if not usage:
usage = LLMUsage.empty_usage()
return full_text, usage
def create_chat_prompt(
self,
query: str,
tools: Sequence[PromptMessageTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> list[ChatModelMessage]:
tool_strings = []
for tool in tools:
tool_strings.append(f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
formatted_tools = "\n".join(tool_strings)
unique_tool_names = set(tool.name for tool in tools)
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
prompt_messages = []
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=template
)
prompt_messages.append(system_prompt_messages)
user_prompt_message = ChatModelMessage(
role=PromptMessageRole.USER,
text=query
)
prompt_messages.append(user_prompt_message)
return prompt_messages
    def create_completion_prompt(
        self,
        tools: Sequence[PromptMessageTool],
        prefix: str = PREFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[list[str]] = None,
    ) -> PromptTemplate:
        """Create prompt in the style of the zero shot agent.
        Args:
            tools: List of tools the agent will have access to, used to format the
                prompt.
            prefix: String to put before the list of tools.
            format_instructions: format-instructions template; its {tool_names}
                placeholder is filled in here.
            input_variables: List of input variables the final prompt will expect.
                Defaults to ["input", "agent_scratchpad"].
        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        # local completion-mode suffix; {input} and {agent_scratchpad} remain
        # as template variables for the caller to fill
        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""
        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)

View File

View File

@@ -0,0 +1,49 @@
from typing import Any, Literal, Optional, Union
from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
class ModelConfig(BaseModel):
    """
    Model Config.

    Identifies the model the LLM node invokes and carries its raw
    completion parameters.
    """
    # model provider identifier (e.g. "openai")
    provider: str
    # model name within the provider
    name: str
    # invocation mode; nodes branch on "chat" vs completion-style modes
    mode: str
    # provider-specific parameters; may include a "stop" word list
    completion_params: dict[str, Any] = {}
class ContextConfig(BaseModel):
    """
    Context Config.

    Controls whether retrieved context is injected into the prompt and
    which variable supplies it.
    """
    # when False the LLM node skips context fetching entirely
    enabled: bool
    # selector path of the variable holding the context value
    variable_selector: Optional[list[str]] = None
class VisionConfig(BaseModel):
    """
    Vision Config.
    """
    class Configs(BaseModel):
        """
        Configs.

        Extra options forwarded to vision-capable models.
        """
        detail: Literal['low', 'high']
    # when False the LLM node does not attach files from the variable pool
    enabled: bool
    configs: Optional[Configs] = None
class LLMNodeData(BaseNodeData):
    """
    LLM Node Data.
    """
    # model to invoke
    model: ModelConfig
    # chat messages (chat mode) or a single completion template
    prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
    # optional conversation-memory settings
    memory: Optional[MemoryConfig] = None
    # retrieved-context injection settings
    context: ContextConfig
    # vision / file-attachment settings
    vision: VisionConfig

View File

@@ -0,0 +1,554 @@
from collections.abc import Generator
from typing import Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus
class LLMNode(BaseNode):
_node_data_cls = LLMNodeData
node_type = NodeType.LLM
    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node
        :param variable_pool: variable pool
        :return: run result; any exception is converted into a FAILED result
            carrying whatever inputs/process data were gathered so far
        """
        node_data = self.node_data
        node_data = cast(self._node_data_cls, node_data)
        node_inputs = None
        process_data = None
        try:
            # fetch variables and fetch values from variable pool
            inputs = self._fetch_inputs(node_data, variable_pool)
            # NOTE(review): `inputs` is resolved but never recorded into
            # node_inputs below - confirm whether that is intended
            node_inputs = {}
            # fetch files
            files: list[FileVar] = self._fetch_files(node_data, variable_pool)
            if files:
                node_inputs['#files#'] = [file.to_dict() for file in files]
            # fetch context value
            context = self._fetch_context(node_data, variable_pool)
            if context:
                node_inputs['#context#'] = context
            # fetch model config
            model_instance, model_config = self._fetch_model_config(node_data.model)
            # fetch memory
            memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
            # fetch prompt messages; the sys query is only passed along when
            # memory is configured
            prompt_messages, stop = self._fetch_prompt_messages(
                node_data=node_data,
                query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
                if node_data.memory else None,
                inputs=inputs,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
            process_data = {
                'model_mode': model_config.mode,
                'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                    model_mode=model_config.mode,
                    prompt_messages=prompt_messages
                )
            }
            # handle invoke result
            result_text, usage = self._invoke_llm(
                node_data_model=node_data.model,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop
            )
        except Exception as e:
            # surface the failure as a node result instead of raising
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=node_inputs,
                process_data=process_data
            )
        outputs = {
            'text': result_text,
            'usage': jsonable_encoder(usage)
        }
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=node_inputs,
            process_data=process_data,
            outputs=outputs,
            metadata={
                NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                NodeRunMetadataKey.CURRENCY: usage.currency
            }
        )
    def _invoke_llm(self, node_data_model: ModelConfig,
                    model_instance: ModelInstance,
                    prompt_messages: list[PromptMessage],
                    stop: list[str]) -> tuple[str, LLMUsage]:
        """
        Invoke large language model
        :param node_data_model: node data model
        :param model_instance: model instance
        :param prompt_messages: prompt messages
        :param stop: stop words
        :return: (full response text, token usage)
        """
        # release the DB connection before a potentially long streaming call
        db.session.close()
        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=node_data_model.completion_params,
            stop=stop,
            stream=True,
            user=self.user_id,
        )
        # handle invoke result (accumulates the stream, publishing chunks)
        text, usage = self._handle_invoke_result(
            invoke_result=invoke_result
        )
        # deduct quota
        self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
        return text, usage
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
"""
Handle invoke result
:param invoke_result: invoke result
:return:
"""
model = None
prompt_messages = []
full_text = ''
usage = None
for result in invoke_result:
text = result.delta.message.content
full_text += text
self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
if not model:
model = result.model
if not prompt_messages:
prompt_messages = result.prompt_messages
if not usage and result.delta.usage:
usage = result.delta.usage
if not usage:
usage = LLMUsage.empty_usage()
return full_text, usage
    def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
        """
        Fetch inputs

        Resolves every variable referenced by the prompt template(s) from the
        variable pool.
        :param node_data: node data
        :param variable_pool: variable pool
        :return: variable name -> resolved value
        :raises ValueError: when a referenced variable is missing from the pool
        """
        inputs = {}
        prompt_template = node_data.prompt_template
        variable_selectors = []
        if isinstance(prompt_template, list):
            # chat mode: collect selectors across all messages
            for prompt in prompt_template:
                variable_template_parser = VariableTemplateParser(template=prompt.text)
                variable_selectors.extend(variable_template_parser.extract_variable_selectors())
        elif isinstance(prompt_template, CompletionModelPromptTemplate):
            variable_template_parser = VariableTemplateParser(template=prompt_template.text)
            variable_selectors = variable_template_parser.extract_variable_selectors()
        for variable_selector in variable_selectors:
            variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
            if variable_value is None:
                raise ValueError(f'Variable {variable_selector.variable} not found')
            inputs[variable_selector.variable] = variable_value
        return inputs
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
"""
Fetch files
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
if not node_data.vision.enabled:
return []
files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
if not files:
return []
return files
    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
        """
        Fetch context
        :param node_data: node data
        :param variable_pool: variable pool
        :return: context text (string value, or joined item contents), or
            None when context is disabled/unset
        """
        if not node_data.context.enabled:
            return None
        if not node_data.context.variable_selector:
            return None
        context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
        if context_value:
            if isinstance(context_value, str):
                return context_value
            elif isinstance(context_value, list):
                # list of retrieval items: join their contents and report the
                # originating retriever resources through callbacks
                context_str = ''
                original_retriever_resource = []
                for item in context_value:
                    if 'content' not in item:
                        raise ValueError(f'Invalid context structure: {item}')
                    context_str += item['content'] + '\n'
                    retriever_resource = self._convert_to_original_retriever_resource(item)
                    if retriever_resource:
                        original_retriever_resource.append(retriever_resource)
                # side effect: notify listeners about the retrieval sources
                if self.callbacks:
                    for callback in self.callbacks:
                        callback.on_event(
                            event=QueueRetrieverResourcesEvent(
                                retriever_resources=original_retriever_resource
                            )
                        )
                return context_str.strip()
        return None
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
"""
Convert to original retriever resource, temp.
:param context_dict: context dict
:return:
"""
if ('metadata' in context_dict and '_source' in context_dict['metadata']
and context_dict['metadata']['_source'] == 'knowledge'):
metadata = context_dict.get('metadata', {})
source = {
'position': metadata.get('position'),
'dataset_id': metadata.get('dataset_id'),
'dataset_name': metadata.get('dataset_name'),
'document_id': metadata.get('document_id'),
'document_name': metadata.get('document_name'),
'data_source_type': metadata.get('document_data_source_type'),
'segment_id': metadata.get('segment_id'),
'retriever_from': metadata.get('retriever_from'),
'score': metadata.get('score'),
'hit_count': metadata.get('segment_hit_count'),
'word_count': metadata.get('segment_word_count'),
'segment_position': metadata.get('segment_position'),
'index_node_hash': metadata.get('segment_index_node_hash'),
'content': context_dict.get('content'),
}
return source
return None
    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        """
        Fetch model config

        Resolves the configured model into a usable instance and validates
        its availability before the node invokes it.
        :param node_data_model: node data model
        :return: (model instance, model config with credentials)
        :raises ValueError: unknown model, missing mode, or missing schema
        :raises ProviderTokenNotInitError: credentials not configured
        :raises ModelCurrentlyNotSupportError: hosted model not available
        :raises QuotaExceededError: provider quota exhausted
        """
        model_name = node_data_model.name
        provider_name = node_data_model.provider
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id,
            model_type=ModelType.LLM,
            provider=provider_name,
            model=model_name
        )
        provider_model_bundle = model_instance.provider_model_bundle
        model_type_instance = model_instance.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)
        model_credentials = model_instance.credentials
        # check model availability for this tenant
        provider_model = provider_model_bundle.configuration.get_provider_model(
            model=model_name,
            model_type=ModelType.LLM
        )
        if provider_model is None:
            raise ValueError(f"Model {model_name} not exist.")
        if provider_model.status == ModelStatus.NO_CONFIGURE:
            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
        elif provider_model.status == ModelStatus.NO_PERMISSION:
            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
        # model config: split the "stop" word list out of completion params
        completion_params = node_data_model.completion_params
        stop = []
        if 'stop' in completion_params:
            stop = completion_params['stop']
            del completion_params['stop']
        # get model mode
        model_mode = node_data_model.mode
        if not model_mode:
            raise ValueError("LLM mode is required.")
        model_schema = model_type_instance.get_model_schema(
            model_name,
            model_credentials
        )
        if not model_schema:
            raise ValueError(f"Model {model_name} not exist.")
        return model_instance, ModelConfigWithCredentialsEntity(
            provider=provider_name,
            model=model_name,
            model_schema=model_schema,
            mode=model_mode,
            provider_model_bundle=provider_model_bundle,
            credentials=model_credentials,
            parameters=completion_params,
            stop=stop,
        )
def _fetch_memory(self, node_data_memory: Optional[MemoryConfig],
variable_pool: VariablePool,
model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
"""
Fetch memory
:param node_data_memory: node data memory
:param variable_pool: variable pool
:return:
"""
if not node_data_memory:
return None
# get conversation id
conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
if conversation_id is None:
return None
# get conversation
conversation = db.session.query(Conversation).filter(
Conversation.app_id == self.app_id,
Conversation.id == conversation_id
).first()
if not conversation:
return None
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
return memory
    def _fetch_prompt_messages(self, node_data: LLMNodeData,
                               query: Optional[str],
                               inputs: dict[str, str],
                               files: list[FileVar],
                               context: Optional[str],
                               memory: Optional[TokenBufferMemory],
                               model_config: ModelConfigWithCredentialsEntity) \
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Fetch prompt messages
        :param node_data: node data
        :param query: query (only set when memory is configured)
        :param inputs: resolved template inputs
        :param files: attached files
        :param context: context text
        :param memory: conversation memory
        :param model_config: model config
        :return: (prompt messages, stop words from the model config)
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=node_data.prompt_template,
            inputs=inputs,
            query=query if query else '',
            files=files,
            context=context,
            memory_config=node_data.memory,
            memory=memory,
            model_config=model_config
        )
        # stop words were split out of completion_params in _fetch_model_config
        stop = model_config.stop
        return prompt_messages, stop
    @classmethod
    def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
        """
        Deduct LLM quota

        Only applies to system (hosted) providers; other provider types are
        not metered here.
        :param tenant_id: tenant id
        :param model_instance: model instance
        :param usage: usage
        :return:
        """
        provider_model_bundle = model_instance.provider_model_bundle
        provider_configuration = provider_model_bundle.configuration
        if provider_configuration.using_provider_type != ProviderType.SYSTEM:
            return
        system_configuration = provider_configuration.system_configuration
        quota_unit = None
        for quota_configuration in system_configuration.quota_configurations:
            if quota_configuration.quota_type == system_configuration.current_quota_type:
                quota_unit = quota_configuration.quota_unit
                # a quota limit of -1 means unlimited: nothing to deduct
                if quota_configuration.quota_limit == -1:
                    return
                break
        used_quota = None
        if quota_unit:
            if quota_unit == QuotaUnit.TOKENS:
                used_quota = usage.total_tokens
            elif quota_unit == QuotaUnit.CREDITS:
                used_quota = 1
                # gpt-4 family calls are priced at 20 credits each
                if 'gpt-4' in model_instance.model:
                    used_quota = 20
            else:
                used_quota = 1
        if used_quota is not None:
            # guarded update: only deduct while some quota remains
            db.session.query(Provider).filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == model_instance.provider,
                Provider.provider_type == ProviderType.SYSTEM.value,
                Provider.quota_type == system_configuration.current_quota_type.value,
                Provider.quota_limit > Provider.quota_used
            ).update({'quota_used': Provider.quota_used + used_quota})
            db.session.commit()
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
node_data = node_data
node_data = cast(cls._node_data_cls, node_data)
prompt_template = node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list):
for prompt in prompt_template:
variable_template_parser = VariableTemplateParser(template=prompt.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
else:
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {}
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
if node_data.context.enabled:
variable_mapping['#context#'] = node_data.context.variable_selector
if node_data.vision.enabled:
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
return variable_mapping
    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.
        :param filters: filter by node config parameters.
        :return: default prompt templates for chat and completion models
        """
        return {
            "type": "llm",
            "config": {
                "prompt_templates": {
                    "chat_model": {
                        "prompts": [
                            {
                                "role": "system",
                                "text": "You are a helpful AI assistant."
                            }
                        ]
                    },
                    # completion models embed the chat history and query as
                    # template variables directly in the prompt text
                    "completion_model": {
                        "conversation_histories_role": {
                            "user_prefix": "Human",
                            "assistant_prefix": "Assistant"
                        },
                        "prompt": {
                            "text": "Here is the chat histories between human and assistant, inside "
                                    "<histories></histories> XML tags.\n\n<histories>\n{{"
                                    "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
                        },
                        "stop": ["Human:"]
                    }
                }
            }
        }

View File

@@ -0,0 +1,36 @@
from typing import Any, Optional
from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
class ModelConfig(BaseModel):
    """
    Model Config.

    NOTE(review): duplicates the LLM node's ModelConfig - consider sharing
    one definition if they are meant to stay identical.
    """
    # model provider identifier (e.g. "openai")
    provider: str
    # model name within the provider
    name: str
    # invocation mode; nodes branch on "chat" vs completion-style modes
    mode: str
    # provider-specific parameters; may include a "stop" word list
    completion_params: dict[str, Any] = {}
class ClassConfig(BaseModel):
    """
    Class Config.

    One classification category: a stable id (used for edge routing) plus a
    human-readable name shown to the LLM.
    """
    id: str
    name: str
class QuestionClassifierNodeData(BaseNodeData):
    """
    Question Classifier Node Data.

    (Docstring previously said "Knowledge retrieval Node Data" - copy-paste error.)
    """
    # selector of the variable holding the query text to classify
    query_variable_selector: list[str]
    type: str = 'question-classifier'
    model: ModelConfig
    # candidate categories; the LLM is asked to pick exactly one
    classes: list[ClassConfig]
    # optional extra instructions injected into the classification prompt
    instruction: Optional[str]
    memory: Optional[MemoryConfig]

View File

@@ -0,0 +1,253 @@
import json
import logging
from typing import Optional, Union, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.question_classifier.template_prompts import (
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
QUESTION_CLASSIFIER_COMPLETION_PROMPT,
QUESTION_CLASSIFIER_SYSTEM_PROMPT,
QUESTION_CLASSIFIER_USER_PROMPT_1,
QUESTION_CLASSIFIER_USER_PROMPT_2,
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
from models.workflow import WorkflowNodeExecutionStatus
class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData
node_type = NodeType.QUESTION_CLASSIFIER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
node_data = cast(QuestionClassifierNodeData, node_data)
# extract variables
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
variables = {
'query': query
}
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
# fetch memory
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt(
node_data=node_data,
context='',
query=query,
memory=memory,
model_config=model_config
)
# handle invoke result
result_text, usage = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop
)
categories = [_class.name for _class in node_data.classes]
try:
result_text_json = json.loads(result_text.strip('```JSON\n'))
categories = result_text_json.get('categories', [])
except Exception:
logging.error(f"Failed to parse result text: {result_text}")
try:
process_data = {
'model_mode': model_config.mode,
'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode,
prompt_messages=prompt_messages
),
'usage': jsonable_encoder(usage),
'topics': categories[0] if categories else ''
}
outputs = {
'class_name': categories[0] if categories else ''
}
classes = node_data.classes
classes_map = {class_.name: class_.id for class_ in classes}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=process_data,
outputs=outputs,
edge_source_handle=classes_map.get(categories[0], None)
)
except ValueError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
node_data = node_data
node_data = cast(cls._node_data_cls, node_data)
variable_mapping = {'query': node_data.query_variable_selector}
return variable_mapping
    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.
        :param filters: filter by node config parameters.
        :return: default config dict
        """
        # NOTE(review): the key is "instructions" while the node data field is
        # named "instruction" - confirm the front-end expects the plural form.
        return {
            "type": "question-classifier",
            "config": {
                "instructions": ""
            }
        }
    def _fetch_prompt(self, node_data: QuestionClassifierNodeData,
                      query: str,
                      context: Optional[str],
                      memory: Optional[TokenBufferMemory],
                      model_config: ModelConfigWithCredentialsEntity) \
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Fetch prompt
        :param node_data: node data
        :param query: query text to classify
        :param context: context
        :param memory: conversation memory; its history is rendered into the
            template by _get_prompt_template, so memory=None is passed below
        :param model_config: model config
        :return: (prompt messages, stop words)
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        # size the history budget to what fits in the model's context window
        rest_token = self._calculate_rest_token(node_data, query, model_config, context)
        prompt_template = self._get_prompt_template(node_data, query, memory, rest_token)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs={},
            query='',
            files=[],
            context=context,
            memory_config=node_data.memory,
            memory=None,
            model_config=model_config
        )
        stop = model_config.stop
        return prompt_messages, stop
    def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str,
                              model_config: ModelConfigWithCredentialsEntity,
                              context: Optional[str]) -> int:
        """
        Estimate how many context-window tokens remain for chat history.

        Renders the prompt with a placeholder history budget (2000), then
        subtracts the prompt size and the configured max_tokens from the
        model's context size. Falls back to 2000 when the context size is
        unknown.
        :param node_data: node data
        :param query: query text
        :param model_config: model config
        :param context: context
        :return: remaining token budget (never negative)
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        # render with a placeholder history budget to measure the prompt size
        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs={},
            query='',
            files=[],
            context=context,
            memory_config=node_data.memory,
            memory=None,
            model_config=model_config
        )
        rest_tokens = 2000
        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_type_instance = model_config.provider_model_bundle.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)
            curr_message_tokens = model_type_instance.get_num_tokens(
                model_config.model,
                model_config.credentials,
                prompt_messages
            )
            # find the configured max output tokens (by name or template alias)
            max_tokens = 0
            for parameter_rule in model_config.model_schema.parameter_rules:
                if (parameter_rule.name == 'max_tokens'
                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                    max_tokens = (model_config.parameters.get(parameter_rule.name)
                                  or model_config.parameters.get(parameter_rule.use_template)) or 0
            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)
        return rest_tokens
    def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str,
                             memory: Optional[TokenBufferMemory],
                             max_token_limit: int = 2000) \
            -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
        """
        Build the classification prompt for the configured model mode.

        Chat mode yields a system message plus a two-shot example exchange and
        the real query; completion mode yields a single template. Chat history
        (when memory is present) is rendered directly into the prompt text.
        :param node_data: node data
        :param query: query text to classify
        :param memory: conversation memory, or None
        :param max_token_limit: token budget for the rendered history
        :return: prompt template for the model mode
        :raises ValueError: unsupported model mode
        """
        model_mode = ModelMode.value_of(node_data.model.mode)
        classes = node_data.classes
        class_names = [class_.name for class_ in classes]
        class_names_str = ','.join(class_names)
        instruction = node_data.instruction if node_data.instruction else ''
        input_text = query
        memory_str = ''
        if memory:
            memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
                                                        message_limit=node_data.memory.window.size)
        prompt_messages = []
        if model_mode == ModelMode.CHAT:
            system_prompt_messages = ChatModelMessage(
                role=PromptMessageRole.SYSTEM,
                text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
            )
            prompt_messages.append(system_prompt_messages)
            # few-shot example 1: user input + expected assistant answer
            user_prompt_message_1 = ChatModelMessage(
                role=PromptMessageRole.USER,
                text=QUESTION_CLASSIFIER_USER_PROMPT_1
            )
            prompt_messages.append(user_prompt_message_1)
            assistant_prompt_message_1 = ChatModelMessage(
                role=PromptMessageRole.ASSISTANT,
                text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
            )
            prompt_messages.append(assistant_prompt_message_1)
            # few-shot example 2
            user_prompt_message_2 = ChatModelMessage(
                role=PromptMessageRole.USER,
                text=QUESTION_CLASSIFIER_USER_PROMPT_2
            )
            prompt_messages.append(user_prompt_message_2)
            assistant_prompt_message_2 = ChatModelMessage(
                role=PromptMessageRole.ASSISTANT,
                text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
            )
            prompt_messages.append(assistant_prompt_message_2)
            # the real query to classify
            user_prompt_message_3 = ChatModelMessage(
                role=PromptMessageRole.USER,
                text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text,
                                                              categories=class_names_str,
                                                              classification_instructions=instruction)
            )
            prompt_messages.append(user_prompt_message_3)
            return prompt_messages
        elif model_mode == ModelMode.COMPLETION:
            return CompletionModelPromptTemplate(
                text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str,
                                                                  input_text=input_text,
                                                                  categories=class_names_str,
                                                                  classification_instructions=instruction)
            )
        else:
            raise ValueError(f"Model mode {model_mode} not support.")

View File

@@ -0,0 +1,72 @@
# Prompt templates for the question classifier node.
# Chat-mode models use the SYSTEM prompt plus the USER/ASSISTANT few-shot pairs;
# completion-mode models use the single COMPLETION prompt with the same examples inlined.
QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
### Job Description
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
### Task
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output.Additionally, you need to extract the key words from the text that are related to the classification.
### Format
The input text is in the variable text_field.Categories are specified as a comma-separated list in the variable categories or left empty for automatic determination.Classification instructions may be included to improve the classification accuracy.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
"""

QUESTION_CLASSIFIER_USER_PROMPT_1 = """
{ "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],
"categories": ["Customer Service, Satisfaction, Sales, Product"],
"classification_instructions": ["classify the text based on the feedback provided by customer"]}```JSON
"""

QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """
{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],
"categories": ["Customer Service"]}```
"""

QUESTION_CLASSIFIER_USER_PROMPT_2 = """
{"input_text": ["bad service, slow to bring the food"],
"categories": ["Food Quality, Experience, Price" ],
"classification_instructions": []}```JSON
"""

# NOTE: fixed malformed example JSON — the original contained a stray extra
# quote: `"categories": ["Experience""]`.
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """
{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],
"categories": ["Experience"]}```
"""

# Rendered via str.format(); doubled braces `{{`/`}}` escape to literal braces.
# NOTE: removed leftover Python-list artifacts (wrapping single quotes and
# trailing commas) that previously leaked verbatim into the prompt.
QUESTION_CLASSIFIER_USER_PROMPT_3 = """
{{"input_text": ["{input_text}"],
"categories": ["{categories}" ],
"classification_instructions": ["{classification_instructions}"]}}```JSON
"""

QUESTION_CLASSIFIER_COMPLETION_PROMPT = """
### Job Description
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
### Task
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
### Format
The input text is in the variable text_field. Categories are specified as a comma-separated list in the variable categories or left empty for automatic determination. Classification instructions may be included to improve the classification accuracy.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],"categories": ["Customer Service, Satisfaction, Sales, Product"], "classification_instructions": ["classify the text based on the feedback provided by customer"]}}
Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"categories": ["Customer Service"]}}
User:{{"input_text": ["bad service, slow to bring the food"],"categories": ["Food Quality, Experience, Price" ], "classification_instructions": []}}
Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"categories": ["Experience"]}}
</example>
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
### User Input
{{"input_text" : ["{input_text}"], "categories" : ["{categories}"],"classification_instruction" : ["{classification_instructions}"]}}
### Assistant Output
"""

View File

@@ -0,0 +1,9 @@
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData
class StartNodeData(BaseNodeData):
    """
    Start Node Data.

    Declares the input-form variables the workflow's start node accepts;
    the raw user values for these variables are validated and cleaned by
    StartNode._run.
    """
    # Input-form variable definitions; defaults to an empty form.
    # (pydantic copies the default per instance, so the shared [] is safe here)
    variables: list[VariableEntity] = []

View File

@@ -0,0 +1,84 @@
from typing import cast
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
class StartNode(BaseNode):
    """Workflow entry node: validates user inputs and exposes system variables."""

    _node_data_cls = StartNodeData
    node_type = NodeType.START

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Clean the user inputs against the node's variable definitions and merge
        in the pool's system variables (except the conversation id) under a
        `sys.` prefix.

        :param variable_pool: variable pool holding user inputs and system variables
        :return: successful NodeRunResult whose inputs and outputs are both the cleaned values
        """
        node_data = cast(self._node_data_cls, self.node_data)

        cleaned_inputs = self._get_cleaned_inputs(node_data.variables, variable_pool.user_inputs)

        # Expose system variables as `sys.<name>`; the conversation id is
        # deliberately excluded from the node's inputs/outputs.
        for sys_var in variable_pool.system_variables:
            if sys_var == SystemVariable.CONVERSATION:
                continue
            cleaned_inputs[f'sys.{sys_var.value}'] = variable_pool.system_variables[sys_var]

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=cleaned_inputs,
            outputs=cleaned_inputs
        )

    def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict):
        """
        Validate and normalize the raw user inputs.

        Missing or empty optional variables fall back to their default (or "");
        a missing/empty required variable raises ValueError, as do non-string
        values, select values outside the configured options, and strings
        longer than max_length.
        """
        user_inputs = {} if user_inputs is None else user_inputs

        cleaned = {}
        for variable_config in variables:
            name = variable_config.variable
            value = user_inputs.get(name)

            if not value:
                # Absent or empty: required -> error, optional -> default / "".
                if variable_config.required:
                    raise ValueError(f"Input form variable {name} is required")
                cleaned[name] = variable_config.default if variable_config.default is not None else ""
                continue

            if not isinstance(value, str):
                raise ValueError(f"{name} in input form must be a string")

            if variable_config.type == VariableEntity.Type.SELECT:
                options = variable_config.options if variable_config.options is not None else []
                if value not in options:
                    raise ValueError(f"{name} in input form must be one of the following: {options}")
            elif variable_config.max_length is not None:
                max_length = variable_config.max_length
                if len(value) > max_length:
                    raise ValueError(f'{name} in input form must be less than {max_length} characters')

            # Strip NUL characters from the stored value.
            cleaned[name] = value.replace('\x00', '')

        return cleaned

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping.

        The start node references no upstream variables, so the mapping is empty.

        :param node_data: node data
        :return: empty mapping
        """
        return {}

View File

@@ -0,0 +1,12 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class TemplateTransformNodeData(BaseNodeData):
"""
Code Node Data.
"""
variables: list[VariableSelector]
template: str

View File

@@ -0,0 +1,91 @@
import os
from typing import Optional, cast
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from models.workflow import WorkflowNodeExecutionStatus
# Hard cap (in characters) on the rendered template output; overridable via
# the TEMPLATE_TRANSFORM_MAX_LENGTH environment variable.
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000'))
class TemplateTransformNode(BaseNode):
    """Renders a Jinja2 template against variables resolved from the pool."""

    _node_data_cls = TemplateTransformNodeData
    _node_type = NodeType.TEMPLATE_TRANSFORM

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.

        :param filters: filter by node config parameters.
        :return: default node configuration dict
        """
        return {
            "type": "template-transform",
            "config": {
                "variables": [
                    {
                        "variable": "arg1",
                        "value_selector": []
                    }
                ],
                "template": "{{ arg1 }}"
            }
        }

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Resolve the configured variables, render the template via the code
        executor, and return the rendered text as the `output` value.
        """
        node_data: TemplateTransformNodeData = cast(self._node_data_cls, self.node_data)

        # Resolve every declared variable from the pool.
        rendering_inputs = {
            selector.variable: variable_pool.get_variable_value(variable_selector=selector.value_selector)
            for selector in node_data.variables
        }

        # Rendering is delegated to the sandboxed code executor ('jinja2' language).
        try:
            result = CodeExecutor.execute_code(
                language='jinja2',
                code=node_data.template,
                inputs=rendering_inputs
            )
        except CodeExecutionException as e:
            return NodeRunResult(
                inputs=rendering_inputs,
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e)
            )

        rendered = result['result']
        # Guard against runaway output sizes.
        if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
            return NodeRunResult(
                inputs=rendering_inputs,
                status=WorkflowNodeExecutionStatus.FAILED,
                error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters"
            )

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=rendering_inputs,
            outputs={
                'output': rendered
            }
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping.

        :param node_data: node data
        :return: mapping of variable name -> value selector path
        """
        mapping = {}
        for selector in node_data.variables:
            mapping[selector.variable] = selector.value_selector
        return mapping

View File

View File

@@ -0,0 +1,37 @@
from typing import Literal, Union
from pydantic import BaseModel, validator
from core.workflow.entities.base_node_data_entities import BaseNodeData
# Scalar value types a tool parameter/configuration entry may hold.
ToolParameterValue = Union[str, int, float, bool]
class ToolEntity(BaseModel):
    """Identifies a concrete tool and its static (node-level) configuration."""
    provider_id: str
    provider_type: Literal['builtin', 'api']
    provider_name: str # redundancy: denormalized copy of the provider's name
    tool_name: str
    tool_label: str # redundancy: denormalized display label
    # Static configuration values keyed by parameter name.
    tool_configurations: dict[str, ToolParameterValue]
class ToolNodeData(BaseNodeData, ToolEntity):
    """
    Tool Node Schema.

    Extends the base node data with the tool binding (inherited from
    ToolEntity) and the per-parameter input configuration.
    """
    # FIX: the bare string "Tool Node Schema" previously sat *after* the nested
    # class, making it a no-op expression instead of the class docstring; it is
    # now the real docstring above.

    class ToolInput(BaseModel):
        # Parameter value: a literal, a mixed template string, or a variable selector path.
        value: Union[ToolParameterValue, list[str]]
        # 'mixed' = template string, 'variable' = selector list, 'constant' = literal scalar.
        type: Literal['mixed', 'variable', 'constant']

        @validator('type', pre=True, always=True)
        def check_type(cls, value, values):
            # Cross-field validation: the declared type must match the shape of
            # `value` (pydantic v1 populates `values` with earlier fields, and
            # `value` is declared before `type`).
            typ = value
            value = values.get('value')
            if typ == 'mixed' and not isinstance(value, str):
                raise ValueError('value must be a string')
            elif typ == 'variable' and not isinstance(value, list):
                raise ValueError('value must be a list')
            elif typ == 'constant' and not isinstance(value, ToolParameterValue):
                # NOTE: isinstance with typing.Union requires Python >= 3.10.
                raise ValueError('value must be a string, int, float, or bool')
            return typ

    # Mapping of tool parameter name -> input configuration.
    tool_parameters: dict[str, ToolInput]

View File

@@ -0,0 +1,201 @@
from os import path
from typing import cast
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
class ToolNode(BaseNode):
    """
    Tool Node

    Resolves the configured tool's runtime, builds its parameters from the
    variable pool, invokes the tool, and converts the resulting messages into
    a plain-text output plus a list of file variables.
    """
    _node_data_cls = ToolNodeData
    _node_type = NodeType.TOOL
    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run the tool node.

        :param variable_pool: variable pool used to resolve parameter values
        :return: SUCCEEDED result with 'text' and 'files' outputs, or a FAILED
                 result (with the tool info in metadata) when runtime resolution
                 or invocation raises
        """
        node_data = cast(ToolNodeData, self.node_data)
        # fetch tool icon
        tool_info = {
            'provider_type': node_data.provider_type,
            'provider_id': node_data.provider_id
        }
        # get parameters
        parameters = self._generate_parameters(variable_pool, node_data)
        # get tool runtime
        try:
            tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data)
        except Exception as e:
            # Resolution failures are reported as a failed node run, not raised.
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters,
                metadata={
                    NodeRunMetadataKey.TOOL_INFO: tool_info
                },
                error=f'Failed to get tool runtime: {str(e)}'
            )
        try:
            messages = ToolEngine.workflow_invoke(
                tool=tool_runtime,
                tool_parameters=parameters,
                user_id=self.user_id,
                workflow_id=self.workflow_id,
                workflow_tool_callback=DifyWorkflowCallbackHandler()
            )
        except Exception as e:
            # Invocation failures are likewise captured in the run result.
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters,
                metadata={
                    NodeRunMetadataKey.TOOL_INFO: tool_info
                },
                error=f'Failed to invoke tool: {str(e)}',
            )
        # convert tool messages
        plain_text, files = self._convert_tool_messages(messages)
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={
                'text': plain_text,
                'files': files
            },
            metadata={
                NodeRunMetadataKey.TOOL_INFO: tool_info
            },
            inputs=parameters
        )
    def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
        """
        Build the tool's parameter dict from the node configuration.

        'mixed' inputs are template strings with variable references resolved,
        'variable' inputs are looked up directly in the pool, and 'constant'
        inputs are passed through as-is.
        """
        result = {}
        for parameter_name in node_data.tool_parameters:
            input = node_data.tool_parameters[parameter_name]
            if input.type == 'mixed':
                result[parameter_name] = self._format_variable_template(input.value, variable_pool)
            elif input.type == 'variable':
                result[parameter_name] = variable_pool.get_variable_value(input.value)
            elif input.type == 'constant':
                result[parameter_name] = input.value
        return result
    def _format_variable_template(self, template: str, variable_pool: VariablePool) -> str:
        """
        Resolve every variable reference in `template` against the pool and
        return the rendered string.
        """
        inputs = {}
        template_parser = VariableTemplateParser(template)
        for selector in template_parser.extract_variable_selectors():
            inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
        return template_parser.format(inputs)
    def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]
        """
        # transform message and handle file storage
        messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
        )
        # extract plain text and files
        files = self._extract_tool_response_binary(messages)
        plain_text = self._extract_tool_response_text(messages)
        return plain_text, files
    def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
        """
        Extract tool response binary messages (images/blobs) as FileVar objects.
        """
        result = []
        for response in tool_response:
            if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
                response.type == ToolInvokeMessage.MessageType.IMAGE:
                url = response.message
                ext = path.splitext(url)[1]
                # NOTE(review): assumes response.meta is never None here — confirm.
                mimetype = response.meta.get('mime_type', 'image/jpeg')
                filename = response.save_as or url.split('/')[-1]
                # get tool file id
                # (the tool file id is the URL's final path segment without its extension)
                tool_file_id = url.split('/')[-1].split('.')[0]
                result.append(FileVar(
                    tenant_id=self.tenant_id,
                    type=FileType.IMAGE,
                    transfer_method=FileTransferMethod.TOOL_FILE,
                    related_id=tool_file_id,
                    filename=filename,
                    extension=ext,
                    mime_type=mimetype,
                ))
            elif response.type == ToolInvokeMessage.MessageType.BLOB:
                # get tool file id
                # NOTE(review): blobs are always typed as IMAGE below — confirm
                # non-image blobs are intended to be treated this way.
                tool_file_id = response.message.split('/')[-1].split('.')[0]
                result.append(FileVar(
                    tenant_id=self.tenant_id,
                    type=FileType.IMAGE,
                    transfer_method=FileTransferMethod.TOOL_FILE,
                    related_id=tool_file_id,
                    filename=response.save_as,
                    extension=path.splitext(response.save_as)[1],
                    mime_type=response.meta.get('mime_type', 'application/octet-stream'),
                ))
            elif response.type == ToolInvokeMessage.MessageType.LINK:
                pass # TODO:
        return result
    def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
        """
        Extract tool response text: TEXT messages verbatim, LINK messages as
        'Link: <url>'; all other message types are skipped.
        """
        return ''.join([
            f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else
            f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
            for message in tool_response
        ])
    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping.

        'mixed' inputs contribute one entry per referenced variable; 'variable'
        inputs contribute the parameter itself; constants reference nothing.

        :param node_data: node data
        :return: mapping of variable name -> value selector path
        """
        result = {}
        for parameter_name in node_data.tool_parameters:
            input = node_data.tool_parameters[parameter_name]
            if input.type == 'mixed':
                selectors = VariableTemplateParser(input.value).extract_variable_selectors()
                for selector in selectors:
                    result[selector.variable] = selector.value_selector
            elif input.type == 'variable':
                result[parameter_name] = input.value
            elif input.type == 'constant':
                pass
        return result

View File

@@ -0,0 +1,12 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
class VariableAssignerNodeData(BaseNodeData):
    """
    Variable Assigner Node Data.
    """
    # node type discriminator
    type: str = 'variable-assigner'
    # declared type of the assigned output value
    output_type: str
    # candidate variable selector paths, tried in order until one resolves
    # to a non-None value (see VariableAssignerNode._run)
    variables: list[list[str]]

View File

@@ -0,0 +1,41 @@
from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.variable_assigner.entities import VariableAssignerNodeData
from models.workflow import WorkflowNodeExecutionStatus
class VariableAssignerNode(BaseNode):
    """Forwards the first resolvable candidate variable as the node's `output`."""

    _node_data_cls = VariableAssignerNodeData
    _node_type = NodeType.VARIABLE_ASSIGNER

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Try each configured variable selector in order; the first one that
        resolves to a non-None value becomes the single `output`. When none
        resolves, inputs and outputs are both empty. Always succeeds.
        """
        node_data: VariableAssignerNodeData = cast(self._node_data_cls, self.node_data)

        outputs = {}
        inputs = {}
        for selector in node_data.variables:
            resolved = variable_pool.get_variable_value(selector)
            if resolved is None:
                continue
            outputs = {
                "output": resolved
            }
            # Key the input by the selector path minus its leading node id.
            inputs = {
                '.'.join(selector[1:]): resolved
            }
            break

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs=outputs,
            inputs=inputs
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """This node declares no selector-to-variable mapping."""
        return {}