mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-15 22:06:52 +08:00
FEAT: NEW WORKFLOW ENGINE (#3160)
Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
0
api/core/workflow/nodes/__init__.py
Normal file
0
api/core/workflow/nodes/__init__.py
Normal file
0
api/core/workflow/nodes/answer/__init__.py
Normal file
0
api/core/workflow/nodes/answer/__init__.py
Normal file
155
api/core/workflow/nodes/answer/answer_node.py
Normal file
155
api/core/workflow/nodes/answer/answer_node.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
GenerateRouteChunk,
|
||||
TextGenerateRouteChunk,
|
||||
VarGenerateRouteChunk,
|
||||
)
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnswerNode(BaseNode):
    """
    Answer node.

    Renders the node's answer template: literal text is kept as-is and every
    ``{{variable}}`` placeholder is replaced by the value read from the variable pool.
    """
    _node_data_cls = AnswerNodeData
    node_type = NodeType.ANSWER

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node

        Walks the generate routes of the answer template in order and concatenates
        the text of every chunk into the ``answer`` output.

        :param variable_pool: variable pool
        :return: succeeded run result with the rendered text under ``outputs["answer"]``
        """
        node_data = cast(self._node_data_cls, self.node_data)

        # generate routes
        generate_routes = self.extract_generate_route_from_node_data(node_data)

        answer = ''
        for part in generate_routes:
            if part.type == "var":
                part = cast(VarGenerateRouteChunk, part)
                value = variable_pool.get_variable_value(
                    variable_selector=part.value_selector
                )

                text = ''
                if isinstance(value, str | int | float):
                    text = str(value)
                elif isinstance(value, dict):
                    # other types
                    text = json.dumps(value, ensure_ascii=False)
                elif isinstance(value, FileVar):
                    # convert file to markdown
                    text = value.to_markdown()
                elif isinstance(value, list):
                    # only file items are rendered as markdown; other items are ignored here
                    for item in value:
                        if isinstance(item, FileVar):
                            text += item.to_markdown() + ' '

                    text = text.strip()

                    if not text and value:
                        # non-empty list without any file item: fall back to JSON
                        text = json.dumps(value, ensure_ascii=False)

                answer += text
            else:
                part = cast(TextGenerateRouteChunk, part)
                answer += part.text

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={
                "answer": answer
            }
        )

    @classmethod
    def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
        """
        Extract generate route selectors

        :param config: node config; its ``data`` entry is used to build the node data
        :return: ordered text / variable chunks of the answer template
        """
        node_data = cls._node_data_cls(**config.get("data", {}))
        node_data = cast(cls._node_data_cls, node_data)

        return cls.extract_generate_route_from_node_data(node_data)

    @classmethod
    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
        """
        Extract generate route from node data

        Splits the answer template into an ordered list of chunks: a
        ``VarGenerateRouteChunk`` for every recognised ``{{variable}}`` placeholder and
        a ``TextGenerateRouteChunk`` for the literal text in between. Empty text
        segments are skipped.

        :param node_data: node data object
        :return: ordered list of generate route chunks
        """
        # function-scope import keeps this block self-contained
        import re

        variable_template_parser = VariableTemplateParser(template=node_data.answer)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        value_selector_mapping = {
            variable_selector.variable: variable_selector.value_selector
            for variable_selector in variable_selectors
        }

        # format answer template
        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
        template_variable_keys = template_parser.variable_keys

        # Take the intersection of variable_keys and template_variable_keys
        variable_keys = list(set(value_selector_mapping) & set(template_variable_keys))

        template = node_data.answer
        if not variable_keys:
            # nothing to substitute: the whole template is a single text chunk
            return [TextGenerateRouteChunk(text=template)] if template else []

        # Split on the exact placeholders instead of injecting a sentinel character
        # into the template: an in-band sentinel corrupts any answer text that
        # already contains that character.
        placeholder_pattern = re.compile(
            '(' + '|'.join(
                re.escape('{{' + key + '}}')
                for key in sorted(variable_keys, key=len, reverse=True)
            ) + ')'
        )

        generate_routes = []
        # with one capturing group, re.split puts the matched placeholders at odd indexes
        for index, part in enumerate(placeholder_pattern.split(template)):
            if not part:
                continue

            if index % 2 == 1:
                var_key = part[2:-2]  # strip the surrounding '{{' and '}}'
                generate_routes.append(VarGenerateRouteChunk(
                    value_selector=value_selector_mapping[var_key]
                ))
            else:
                generate_routes.append(TextGenerateRouteChunk(
                    text=part
                ))

        return generate_routes

    @classmethod
    def _is_variable(cls, part, variable_keys):
        """
        Check whether a template part is a ``{{variable}}`` placeholder

        :param part: template part
        :param variable_keys: known variable keys
        :return: True if the part starts with ``{{`` and names a known variable key
        """
        cleaned_part = part.replace('{{', '').replace('}}', '')
        return part.startswith('{{') and cleaned_part in variable_keys

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping

        :param node_data: node data
        :return: mapping of each variable found in the answer template to its value selector
        """
        node_data = cast(cls._node_data_cls, node_data)

        variable_template_parser = VariableTemplateParser(template=node_data.answer)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        return {
            variable_selector.variable: variable_selector.value_selector
            for variable_selector in variable_selectors
        }
|
||||
34
api/core/workflow/nodes/answer/entities.py
Normal file
34
api/core/workflow/nodes/answer/entities.py
Normal file
@@ -0,0 +1,34 @@
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class AnswerNodeData(BaseNodeData):
    """
    Node data of the answer node.
    """
    # answer template text of the node
    answer: str
|
||||
|
||||
|
||||
class GenerateRouteChunk(BaseModel):
    """
    Base model for one chunk of a generate route.
    """
    # chunk kind discriminator; the subclasses in this module default it to "var" / "text"
    type: str
|
||||
|
||||
|
||||
class VarGenerateRouteChunk(GenerateRouteChunk):
    """
    Generate route chunk that refers to a variable.
    """
    type: str = "var"
    # selector of the variable whose value fills this chunk
    value_selector: list[str]
|
||||
|
||||
|
||||
class TextGenerateRouteChunk(GenerateRouteChunk):
    """
    Generate route chunk that carries literal text.
    """
    type: str = "text"
    # literal text of this chunk
    text: str
|
||||
142
api/core/workflow/nodes/base_node.py
Normal file
142
api/core/workflow/nodes/base_node.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
class UserFrom(Enum):
    """
    Origin of the user that runs a node.
    """
    ACCOUNT = "account"
    END_USER = "end-user"

    @classmethod
    def value_of(cls, value: str) -> "UserFrom":
        """
        Look up the member whose value equals the given string.

        :param value: raw value, e.g. ``"account"`` or ``"end-user"``
        :return: the matching member
        :raises ValueError: if no member has that value
        """
        member = next((candidate for candidate in cls if candidate.value == value), None)
        if member is None:
            raise ValueError(f"Invalid value: {value}")
        return member
|
||||
|
||||
|
||||
class BaseNode(ABC):
    """
    Abstract base class of workflow nodes.

    Subclasses provide ``_node_data_cls`` (the model used to parse the node's
    ``data`` config), implement ``_run`` and implement
    ``_extract_variable_selector_to_variable_mapping``.
    """
    _node_data_cls: type[BaseNodeData]
    _node_type: NodeType

    tenant_id: str
    app_id: str
    workflow_id: str
    user_id: str
    user_from: UserFrom

    node_id: str
    node_data: BaseNodeData
    node_run_result: Optional[NodeRunResult] = None

    callbacks: list[BaseWorkflowCallback]

    def __init__(self, tenant_id: str,
                 app_id: str,
                 workflow_id: str,
                 user_id: str,
                 user_from: UserFrom,
                 config: dict,
                 callbacks: Optional[list[BaseWorkflowCallback]] = None) -> None:
        """
        Initialise the node from its config.

        :param tenant_id: tenant id
        :param app_id: app id
        :param workflow_id: workflow id
        :param user_id: user id
        :param user_from: origin of the user
        :param config: node config; ``id`` is required, ``data`` is parsed with ``_node_data_cls``
        :param callbacks: workflow callbacks; ``None`` is treated as an empty list
        :raises ValueError: if the config has no (non-empty) ``id``
        """
        self.tenant_id = tenant_id
        self.app_id = app_id
        self.workflow_id = workflow_id
        self.user_id = user_id
        self.user_from = user_from

        self.node_id = config.get("id")
        if not self.node_id:
            raise ValueError("Node ID is required.")

        self.node_data = self._node_data_cls(**config.get("data", {}))
        self.callbacks = callbacks or []

    @abstractmethod
    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node

        :param variable_pool: variable pool
        :return: node run result
        """
        raise NotImplementedError

    def run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node entry

        Delegates to ``_run`` and stores the result on ``node_run_result``.

        :param variable_pool: variable pool
        :return: node run result
        """
        result = self._run(
            variable_pool=variable_pool
        )

        self.node_run_result = result
        return result

    def publish_text_chunk(self, text: str, value_selector: Optional[list[str]] = None) -> None:
        """
        Publish text chunk

        Calls ``on_node_text_chunk`` on every registered callback.

        :param text: chunk text
        :param value_selector: value selector forwarded in the chunk metadata
        :return:
        """
        if self.callbacks:
            for callback in self.callbacks:
                callback.on_node_text_chunk(
                    node_id=self.node_id,
                    text=text,
                    metadata={
                        "node_type": self.node_type,
                        "value_selector": value_selector
                    }
                )

    @classmethod
    def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping

        :param config: node config; its ``data`` entry is parsed with ``_node_data_cls``
        :return: mapping produced by ``_extract_variable_selector_to_variable_mapping``
        """
        node_data = cls._node_data_cls(**config.get("data", {}))
        return cls._extract_variable_selector_to_variable_mapping(node_data)

    @classmethod
    @abstractmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping

        :param node_data: node data
        :return: mapping of variable name to value selector
        """
        raise NotImplementedError

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.

        :param filters: filter by node config parameters.
        :return: default config; the base implementation returns an empty dict
        """
        return {}

    @property
    def node_type(self) -> NodeType:
        """
        Get node type

        NOTE(review): the subclasses visible in this file assign a class attribute
        named ``node_type`` (not ``_node_type``), which shadows this property.

        :return: value of ``_node_type``
        """
        return self._node_type
|
||||
0
api/core/workflow/nodes/code/__init__.py
Normal file
0
api/core/workflow/nodes/code/__init__.py
Normal file
348
api/core/workflow/nodes/code/code_node.py
Normal file
348
api/core/workflow/nodes/code/code_node.py
Normal file
@@ -0,0 +1,348 @@
|
||||
import os
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
# Limits applied to code-node outputs; the environment-backed ones are read once at import time.
MAX_NUMBER = int(os.getenv('CODE_MAX_NUMBER', '9223372036854775807'))
MIN_NUMBER = int(os.getenv('CODE_MIN_NUMBER', '-9223372036854775808'))
MAX_PRECISION = 20
MAX_DEPTH = 5
MAX_STRING_LENGTH = int(os.getenv('CODE_MAX_STRING_LENGTH', '80000'))
MAX_STRING_ARRAY_LENGTH = int(os.getenv('CODE_MAX_STRING_ARRAY_LENGTH', '30'))
MAX_OBJECT_ARRAY_LENGTH = int(os.getenv('CODE_MAX_OBJECT_ARRAY_LENGTH', '30'))
MAX_NUMBER_ARRAY_LENGTH = int(os.getenv('CODE_MAX_NUMBER_ARRAY_LENGTH', '1000'))

# Default code offered for a new JavaScript code node.
JAVASCRIPT_DEFAULT_CODE = """function main({arg1, arg2}) {
    return {
        result: arg1 + arg2
    }
}"""

# Default code offered for a new Python code node.
PYTHON_DEFAULT_CODE = """def main(arg1: int, arg2: int) -> dict:
    return {
        "result": arg1 + arg2,
    }"""
|
||||
|
||||
class CodeNode(BaseNode):
    """
    Code node.

    Runs the configured code through ``CodeExecutor`` with variables read from the
    variable pool, then validates the returned dict against the node's output schema.
    """
    _node_data_cls = CodeNodeData
    node_type = NodeType.CODE

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.

        :param filters: filter by node config parameters; ``code_language == "javascript"``
            selects the JavaScript template, anything else the Python one.
        :return: default node config
        """
        if filters and filters.get("code_language") == "javascript":
            code_language, code = "javascript", JAVASCRIPT_DEFAULT_CODE
        else:
            code_language, code = "python3", PYTHON_DEFAULT_CODE

        return {
            "type": "code",
            "config": {
                "variables": [
                    {
                        "variable": "arg1",
                        "value_selector": []
                    },
                    {
                        "variable": "arg2",
                        "value_selector": []
                    }
                ],
                "code_language": code_language,
                "code": code,
                "outputs": {
                    "result": {
                        "type": "string",
                        "children": None
                    }
                }
            }
        }

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run code

        :param variable_pool: variable pool
        :return: succeeded result with the validated outputs, or a failed result carrying
            the error text when execution or output validation fails
        """
        node_data: CodeNodeData = cast(self._node_data_cls, self.node_data)

        # Get code language
        code_language = node_data.code_language
        code = node_data.code

        # Get variables
        variables = {
            variable_selector.variable: variable_pool.get_variable_value(
                variable_selector=variable_selector.value_selector
            )
            for variable_selector in node_data.variables
        }

        # Run code
        try:
            result = CodeExecutor.execute_code(
                language=code_language,
                code=code,
                inputs=variables
            )

            # Transform result
            result = self._transform_result(result, node_data.outputs)
        except (CodeExecutionException, ValueError) as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e)
            )

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=variables,
            outputs=result
        )

    def _check_string(self, value: str, variable: str) -> str:
        """
        Check string

        :param value: value
        :param variable: variable name used in error messages
        :return: the value with NUL characters removed
        :raises ValueError: if the value is not a string or is longer than MAX_STRING_LENGTH
        """
        if not isinstance(value, str):
            raise ValueError(f"{variable} in output form must be a string")

        if len(value) > MAX_STRING_LENGTH:
            raise ValueError(f'{variable} in output form must be less than {MAX_STRING_LENGTH} characters')

        return value.replace('\x00', '')

    def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
        """
        Check number

        :param value: value
        :param variable: variable name used in error messages
        :return: the value unchanged
        :raises ValueError: if the value is not a number, is outside
            [MIN_NUMBER, MAX_NUMBER], or has more than MAX_PRECISION fractional digits
        """
        if not isinstance(value, int | float):
            raise ValueError(f"{variable} in output form must be a number")

        if value > MAX_NUMBER or value < MIN_NUMBER:
            raise ValueError(f'{variable} in output form is out of range.')

        if isinstance(value, float):
            # raise error if precision is too high
            # str() of a float may contain no '.' at all (e.g. '1e-07', 'nan'), so
            # partition is used instead of indexing the result of split('.')
            _, separator, fraction = str(value).partition('.')
            if separator and len(fraction) > MAX_PRECISION:
                raise ValueError(f'{variable} in output form has too high precision.')

        return value

    @staticmethod
    def _check_array(value, variable: str, max_length: int) -> list:
        """
        Check that a value is a list with at most ``max_length`` elements

        :param value: value
        :param variable: variable name used in error messages
        :param max_length: maximum allowed number of elements
        :return: the value unchanged
        :raises ValueError: if the value is not a list or is too long
        """
        if not isinstance(value, list):
            raise ValueError(
                f'Output {variable} is not an array, got {type(value)} instead.'
            )

        if len(value) > max_length:
            raise ValueError(
                f'{variable} in output form must be less than {max_length} elements.'
            )

        return value

    def _validate_untyped_array(self, items: list, variable: str, depth: int) -> None:
        """
        Validate a list output for which no schema is given

        Empty lists and lists whose first element is None are accepted without further
        checks; otherwise all elements must be numbers, all strings or all dicts.

        :param items: list value
        :param variable: variable name used in error messages
        :param depth: current nesting depth
        :raises ValueError: if the elements are of mixed or unsupported types
        """
        first_element = items[0] if len(items) > 0 else None
        if first_element is None:
            return

        if all(isinstance(item, int | float) for item in items):
            for i, item in enumerate(items):
                self._check_number(value=item, variable=f'{variable}[{i}]')
        elif all(isinstance(item, str) for item in items):
            for i, item in enumerate(items):
                self._check_string(value=item, variable=f'{variable}[{i}]')
        elif all(isinstance(item, dict) for item in items):
            for i, item in enumerate(items):
                self._transform_result(
                    result=item,
                    output_schema=None,
                    prefix=f'{variable}[{i}]',
                    depth=depth + 1
                )
        else:
            raise ValueError(
                f'Output {variable} is not a valid array. make sure all elements are of the same type.'
            )

    def _validate_untyped_result(self, result: dict, prefix: str, depth: int) -> None:
        """
        Validate a result dict for which no schema is given, by instance type

        :param result: result
        :param prefix: path of the enclosing output, used in error messages
        :param depth: current nesting depth
        :raises ValueError: if any value has an unsupported type or fails its check
        """
        for output_name, output_value in result.items():
            variable = f'{prefix}.{output_name}' if prefix else output_name
            if isinstance(output_value, dict):
                self._transform_result(
                    result=output_value,
                    output_schema=None,
                    prefix=variable,
                    depth=depth + 1
                )
            elif isinstance(output_value, int | float):
                self._check_number(value=output_value, variable=variable)
            elif isinstance(output_value, str):
                self._check_string(value=output_value, variable=variable)
            elif isinstance(output_value, list):
                self._validate_untyped_array(output_value, variable, depth)
            else:
                raise ValueError(f'Output {variable} is not a valid type.')

    def _transform_result(self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]],
                          prefix: str = '',
                          depth: int = 1) -> dict:
        """
        Transform result

        Without a schema the result is only validated by instance type and returned
        as-is. With a schema every declared output is validated against its type and
        collected into a new dict.

        :param result: result
        :param output_schema: output schema, or None to validate by instance type only
        :param prefix: path of the enclosing output, used in error messages
        :param depth: current nesting depth, limited by MAX_DEPTH
        :return: validated result
        :raises ValueError: if the result is nested too deeply or fails validation
        """
        if depth > MAX_DEPTH:
            raise ValueError("Depth limit reached, object too deep.")

        if output_schema is None:
            # validate output through instance type
            self._validate_untyped_result(result, prefix, depth)
            return result

        transformed_result = {}
        for output_name, output_config in output_schema.items():
            variable = f'{prefix}.{output_name}' if prefix else output_name
            # .get(): a missing output surfaces as a ValueError from the type checks
            # below (which _run handles) instead of an unhandled KeyError
            value = result.get(output_name)

            if output_config.type == 'object':
                # check if output is object
                if not isinstance(value, dict):
                    raise ValueError(
                        f'Output {variable} is not an object, got {type(value)} instead.'
                    )

                transformed_result[output_name] = self._transform_result(
                    result=value,
                    output_schema=output_config.children,
                    prefix=variable,
                    depth=depth + 1
                )
            elif output_config.type == 'number':
                # check if number available
                transformed_result[output_name] = self._check_number(
                    value=value,
                    variable=variable
                )
            elif output_config.type == 'string':
                # check if string available
                transformed_result[output_name] = self._check_string(
                    value=value,
                    variable=variable
                )
            elif output_config.type == 'array[number]':
                # check if array of number available
                items = self._check_array(value, variable, MAX_NUMBER_ARRAY_LENGTH)
                transformed_result[output_name] = [
                    self._check_number(value=item, variable=f'{variable}[{i}]')
                    for i, item in enumerate(items)
                ]
            elif output_config.type == 'array[string]':
                # check if array of string available
                items = self._check_array(value, variable, MAX_STRING_ARRAY_LENGTH)
                transformed_result[output_name] = [
                    self._check_string(value=item, variable=f'{variable}[{i}]')
                    for i, item in enumerate(items)
                ]
            elif output_config.type == 'array[object]':
                # check if array of object available
                items = self._check_array(value, variable, MAX_OBJECT_ARRAY_LENGTH)

                # every element must be an object before any of them is transformed
                for i, item in enumerate(items):
                    if not isinstance(item, dict):
                        raise ValueError(
                            f'Output {variable}[{i}] is not an object, got {type(item)} instead at index {i}.'
                        )

                transformed_result[output_name] = [
                    self._transform_result(
                        result=item,
                        output_schema=output_config.children,
                        prefix=f'{variable}[{i}]',
                        depth=depth + 1
                    )
                    for i, item in enumerate(items)
                ]
            else:
                raise ValueError(f'Output type {output_config.type} is not supported.')

        # check if all output parameters are validated
        # (every schema entry adds exactly one key to transformed_result)
        if len(transformed_result) != len(result):
            raise ValueError('Not all output parameters are validated.')

        return transformed_result

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping

        :param node_data: node data
        :return: mapping of each input variable name to its value selector
        """
        return {
            variable_selector.variable: variable_selector.value_selector
            for variable_selector in node_data.variables
        }
|
||||
20
api/core/workflow/nodes/code/entities.py
Normal file
20
api/core/workflow/nodes/code/entities.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
|
||||
class CodeNodeData(BaseNodeData):
    """
    Node data of the code node.
    """
    class Output(BaseModel):
        """
        Schema of a single output of the code.
        """
        type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]']
        # nested output schema, keyed by output name
        children: Optional[dict[str, 'Output']]

    # input variables handed to the code
    variables: list[VariableSelector]
    code_language: Literal['python3', 'javascript']
    code: str
    # output schema, keyed by output name
    outputs: dict[str, Output]
|
||||
0
api/core/workflow/nodes/end/__init__.py
Normal file
0
api/core/workflow/nodes/end/__init__.py
Normal file
46
api/core/workflow/nodes/end/end_node.py
Normal file
46
api/core/workflow/nodes/end/end_node.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class EndNode(BaseNode):
    """
    End node: collects the configured output variables from the variable pool.
    """
    _node_data_cls = EndNodeData
    node_type = NodeType.END

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node

        :param variable_pool: variable pool
        :return: succeeded run result; the collected values are reported as both
            ``inputs`` and ``outputs``
        """
        node_data = cast(self._node_data_cls, self.node_data)

        # resolve every declared output variable from the pool
        outputs = {
            selector.variable: variable_pool.get_variable_value(
                variable_selector=selector.value_selector
            )
            for selector in node_data.outputs
        }

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=outputs,
            outputs=outputs
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping

        :param node_data: node data
        :return: always an empty mapping for this node
        """
        return {}
|
||||
9
api/core/workflow/nodes/end/entities.py
Normal file
9
api/core/workflow/nodes/end/entities.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
|
||||
class EndNodeData(BaseNodeData):
    """
    Node data of the end node.
    """
    # variables the end node reads from the variable pool and reports as outputs
    outputs: list[VariableSelector]
|
||||
0
api/core/workflow/nodes/http_request/__init__.py
Normal file
0
api/core/workflow/nodes/http_request/__init__.py
Normal file
43
api/core/workflow/nodes/http_request/entities.py
Normal file
43
api/core/workflow/nodes/http_request/entities.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class HttpRequestNodeData(BaseNodeData):
    """
    HTTP Request Node Data.
    """
    class Authorization(BaseModel):
        """
        Authorization settings of the request.
        """
        class Config(BaseModel):
            """
            Credential settings used when the authorization type is not ``no-auth``.
            """
            type: Literal[None, 'basic', 'bearer', 'custom']
            api_key: Union[None, str]
            header: Union[None, str]

        type: Literal['no-auth', 'api-key']
        config: Optional[Config]

        @validator('config', always=True, pre=True)
        def check_config(cls, v, values):
            """
            Check config, if type is no-auth, config should be None, otherwise it should be a dict.

            :param v: raw ``config`` value
            :param values: previously validated fields
            :return: None for ``no-auth``, otherwise the raw config dict
            :raises ValueError: if a config is required but missing or not a dict
            """
            # `type` is absent from `values` when its own validation failed; .get()
            # avoids raising a KeyError from inside the validator in that case
            if values.get('type') == 'no-auth':
                return None

            if not v or not isinstance(v, dict):
                raise ValueError('config should be a dict')

            return v

    class Body(BaseModel):
        """
        Body settings of the request.
        """
        type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json']
        data: Union[None, str]

    method: Literal['get', 'post', 'put', 'patch', 'delete', 'head']
    url: str
    authorization: Authorization
    headers: str
    params: str
    body: Optional[Body]
|
||||
402
api/core/workflow/nodes/http_request/http_executor.py
Normal file
402
api/core/workflow/nodes/http_request/http_executor.py
Normal file
@@ -0,0 +1,402 @@
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from random import randint
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
import core.helper.ssrf_proxy as ssrf_proxy
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import ValueType, VariablePool
|
||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeData
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
# NOTE(review): presumably (connect, read) timeouts in seconds - confirm against the request calls
HTTP_REQUEST_DEFAULT_TIMEOUT = (10, 60)

# Response size limits in bytes, each paired with its human-readable label.
MAX_BINARY_SIZE = 10 * 1024 * 1024  # 10MB
READABLE_MAX_BINARY_SIZE = '10MB'
MAX_TEXT_SIZE = (1024 * 1024) // 10  # 0.1MB
READABLE_MAX_TEXT_SIZE = '0.1MB'
|
||||
|
||||
|
||||
class HttpExecutorResponse:
    """
    Uniform wrapper around an ``httpx.Response`` or a ``requests.Response``.
    """
    headers: dict[str, str]
    response: Union[httpx.Response, requests.Response]

    def __init__(self, response: Union[httpx.Response, requests.Response] = None):
        """
        Store the response and a plain-dict copy of its headers.

        :param response: wrapped response; for any other type the headers stay empty
        """
        if isinstance(response, (httpx.Response, requests.Response)):
            header_map = dict(response.headers.items())
        else:
            header_map = {}

        self.headers = header_map
        self.response = response

    def _checked_response(self) -> Union[httpx.Response, requests.Response]:
        """
        Return the wrapped response, rejecting unsupported types.

        :raises ValueError: if the wrapped object is neither an httpx nor a requests response
        """
        if isinstance(self.response, (httpx.Response, requests.Response)):
            return self.response

        raise ValueError(f'Invalid response type {type(self.response)}')

    @property
    def is_file(self) -> bool:
        """
        Whether the content type mentions image, audio or video.
        """
        content_type = self.get_content_type()
        return any(kind in content_type for kind in ('image', 'audio', 'video'))

    def get_content_type(self) -> str:
        """
        Return the value of the Content-Type header (name matched case-insensitively),
        or an empty string when the header is absent.
        """
        return next(
            (val for key, val in self.headers.items() if key.lower() == 'content-type'),
            ''
        )

    def extract_file(self) -> tuple[str, bytes]:
        """
        Return ``(content type, body)`` when the response is a file, else ``('', b'')``.
        """
        if not self.is_file:
            return '', b''

        return self.get_content_type(), self.body

    @property
    def content(self) -> str:
        """
        Response body as text.

        :raises ValueError: if the wrapped response has an unsupported type
        """
        return self._checked_response().text

    @property
    def body(self) -> bytes:
        """
        Response body as bytes.

        :raises ValueError: if the wrapped response has an unsupported type
        """
        return self._checked_response().content

    @property
    def status_code(self) -> int:
        """
        HTTP status code of the response.

        :raises ValueError: if the wrapped response has an unsupported type
        """
        return self._checked_response().status_code

    @property
    def size(self) -> int:
        """
        Length of the body in bytes.
        """
        return len(self.body)

    @property
    def readable_size(self) -> str:
        """
        Body size formatted as bytes, KB or MB.
        """
        byte_count = self.size
        if byte_count < 1024:
            return f'{byte_count} bytes'

        if byte_count < 1024 * 1024:
            return f'{(byte_count / 1024):.2f} KB'

        return f'{(byte_count / 1024 / 1024):.2f} MB'
|
||||
|
||||
|
||||
class HttpExecutor:
    # request target; replaced by the rendered url template in _init_template
    server_url: str
    # http method name from node_data.method; _do_http_request matches it against lowercase names
    method: str
    # authorization settings copied from node_data; read by _assembling_headers
    authorization: HttpRequestNodeData.Authorization
    # query parameters parsed from the node's "key:value" lines
    params: dict[str, Any]
    # request headers parsed from the node's "key:value" lines (plus Content-Type when a body type sets it)
    headers: dict[str, Any]
    # request payload for json / raw-text / x-www-form-urlencoded bodies
    body: Union[None, str]
    # multipart fields as {name: ('', value)}; only set for 'form-data' bodies
    files: Union[None, dict[str, Any]]
    # multipart boundary; only assigned for 'form-data' bodies
    boundary: str
    # every variable selector found in the url, params, headers and body templates
    variable_selectors: list[VariableSelector]
|
||||
|
||||
def __init__(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
|
||||
"""
|
||||
init
|
||||
"""
|
||||
self.server_url = node_data.url
|
||||
self.method = node_data.method
|
||||
self.authorization = node_data.authorization
|
||||
self.params = {}
|
||||
self.headers = {}
|
||||
self.body = None
|
||||
self.files = None
|
||||
|
||||
# init template
|
||||
self.variable_selectors = []
|
||||
self._init_template(node_data, variable_pool)
|
||||
|
||||
def _is_json_body(self, body: HttpRequestNodeData.Body):
|
||||
"""
|
||||
check if body is json
|
||||
"""
|
||||
if body and body.type == 'json':
|
||||
try:
|
||||
json.loads(body.data)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
|
||||
"""
|
||||
init template
|
||||
"""
|
||||
variable_selectors = []
|
||||
|
||||
# extract all template in url
|
||||
self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool)
|
||||
|
||||
# extract all template in params
|
||||
params, params_variable_selectors = self._format_template(node_data.params, variable_pool)
|
||||
|
||||
# fill in params
|
||||
kv_paris = params.split('\n')
|
||||
for kv in kv_paris:
|
||||
if not kv.strip():
|
||||
continue
|
||||
|
||||
kv = kv.split(':')
|
||||
if len(kv) == 2:
|
||||
k, v = kv
|
||||
elif len(kv) == 1:
|
||||
k, v = kv[0], ''
|
||||
else:
|
||||
raise ValueError(f'Invalid params {kv}')
|
||||
|
||||
self.params[k.strip()] = v
|
||||
|
||||
# extract all template in headers
|
||||
headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool)
|
||||
|
||||
# fill in headers
|
||||
kv_paris = headers.split('\n')
|
||||
for kv in kv_paris:
|
||||
if not kv.strip():
|
||||
continue
|
||||
|
||||
kv = kv.split(':')
|
||||
if len(kv) == 2:
|
||||
k, v = kv
|
||||
elif len(kv) == 1:
|
||||
k, v = kv[0], ''
|
||||
else:
|
||||
raise ValueError(f'Invalid headers {kv}')
|
||||
|
||||
self.headers[k.strip()] = v.strip()
|
||||
|
||||
# extract all template in body
|
||||
body_data_variable_selectors = []
|
||||
if node_data.body:
|
||||
# check if it's a valid JSON
|
||||
is_valid_json = self._is_json_body(node_data.body)
|
||||
|
||||
body_data = node_data.body.data or ''
|
||||
if body_data:
|
||||
body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json)
|
||||
|
||||
if node_data.body.type == 'json':
|
||||
self.headers['Content-Type'] = 'application/json'
|
||||
elif node_data.body.type == 'x-www-form-urlencoded':
|
||||
self.headers['Content-Type'] = 'application/x-www-form-urlencoded'
|
||||
|
||||
if node_data.body.type in ['form-data', 'x-www-form-urlencoded']:
|
||||
body = {}
|
||||
kv_paris = body_data.split('\n')
|
||||
for kv in kv_paris:
|
||||
if not kv.strip():
|
||||
continue
|
||||
kv = kv.split(':')
|
||||
if len(kv) == 2:
|
||||
body[kv[0].strip()] = kv[1]
|
||||
elif len(kv) == 1:
|
||||
body[kv[0].strip()] = ''
|
||||
else:
|
||||
raise ValueError(f'Invalid body {kv}')
|
||||
|
||||
if node_data.body.type == 'form-data':
|
||||
self.files = {
|
||||
k: ('', v) for k, v in body.items()
|
||||
}
|
||||
random_str = lambda n: ''.join([chr(randint(97, 122)) for _ in range(n)])
|
||||
self.boundary = f'----WebKitFormBoundary{random_str(16)}'
|
||||
|
||||
self.headers['Content-Type'] = f'multipart/form-data; boundary={self.boundary}'
|
||||
else:
|
||||
self.body = urlencode(body)
|
||||
elif node_data.body.type in ['json', 'raw-text']:
|
||||
self.body = body_data
|
||||
elif node_data.body.type == 'none':
|
||||
self.body = ''
|
||||
|
||||
self.variable_selectors = (server_url_variable_selectors + params_variable_selectors
|
||||
+ headers_variable_selectors + body_data_variable_selectors)
|
||||
|
||||
def _assembling_headers(self) -> dict[str, Any]:
|
||||
authorization = deepcopy(self.authorization)
|
||||
headers = deepcopy(self.headers) or {}
|
||||
if self.authorization.type == 'api-key':
|
||||
if self.authorization.config.api_key is None:
|
||||
raise ValueError('api_key is required')
|
||||
|
||||
if not self.authorization.config.header:
|
||||
authorization.config.header = 'Authorization'
|
||||
|
||||
if self.authorization.config.type == 'bearer':
|
||||
headers[authorization.config.header] = f'Bearer {authorization.config.api_key}'
|
||||
elif self.authorization.config.type == 'basic':
|
||||
headers[authorization.config.header] = f'Basic {authorization.config.api_key}'
|
||||
elif self.authorization.config.type == 'custom':
|
||||
headers[authorization.config.header] = authorization.config.api_key
|
||||
|
||||
return headers
|
||||
|
||||
def _validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> HttpExecutorResponse:
    """
    Wrap the raw response and enforce the size limits.

    File responses are checked against MAX_BINARY_SIZE, every other response
    against MAX_TEXT_SIZE.

    :param response: httpx or requests response
    :return: the wrapped response
    :raises ValueError: unsupported response type, or the body exceeds its limit
    """
    if not isinstance(response, httpx.Response | requests.Response):
        raise ValueError(f'Invalid response type {type(response)}')

    executor_response = HttpExecutorResponse(response)

    if executor_response.is_file and executor_response.size > MAX_BINARY_SIZE:
        raise ValueError(f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.')

    if not executor_response.is_file and executor_response.size > MAX_TEXT_SIZE:
        raise ValueError(f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.')

    return executor_response
|
||||
|
||||
def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
    """
    Send the request through ssrf_proxy using the function named after self.method.

    get / head / options are sent without a payload; post / put / delete /
    patch additionally pass self.body as data and self.files as files.

    :param headers: headers to send
    :return: response returned by ssrf_proxy
    :raises ValueError: self.method is not one of the supported lowercase method names
    """
    bodyless_methods = ('get', 'head', 'options')
    payload_methods = ('post', 'put', 'delete', 'patch')

    request_kwargs = {
        'url': self.server_url,
        'headers': headers,
        'params': self.params,
        'timeout': HTTP_REQUEST_DEFAULT_TIMEOUT,
        'follow_redirects': True
    }

    if self.method in payload_methods:
        send = getattr(ssrf_proxy, self.method)
        return send(data=self.body, files=self.files, **request_kwargs)

    if self.method in bodyless_methods:
        send = getattr(ssrf_proxy, self.method)
        return send(**request_kwargs)

    raise ValueError(f'Invalid http method {self.method}')
|
||||
|
||||
def invoke(self) -> HttpExecutorResponse:
    """
    Assemble the headers, send the request and return the validated response.
    """
    raw_response = self._do_http_request(self._assembling_headers())
    return self._validate_and_parse_response(raw_response)
|
||||
|
||||
def to_raw_request(self) -> str:
    """
    Render the request as HTTP/1.1-style text: request line, headers, blank line, payload.

    When self.files is set the payload is written as multipart parts delimited
    by self.boundary; otherwise self.body (or nothing) is appended.
    """
    target = self.server_url
    if self.params:
        target += f'?{urlencode(self.params)}'

    pieces = [f'{self.method.upper()} {target} HTTP/1.1\n']
    pieces.extend(f'{k}: {v}\n' for k, v in self._assembling_headers().items())
    pieces.append('\n')

    if not self.files:
        pieces.append(self.body or '')
        return ''.join(pieces)

    # if files, use multipart/form-data with boundary
    boundary = self.boundary
    pieces.append(f'--{boundary}')
    for k, v in self.files.items():
        pieces.append(f'\nContent-Disposition: form-data; name="{k}"\n\n')
        pieces.append(f'{v[1]}\n')
        pieces.append(f'--{boundary}')
    pieces.append('--')

    return ''.join(pieces)
|
||||
|
||||
def _format_template(self, template: str, variable_pool: VariablePool, escape_quotes: bool = False) \
        -> tuple[str, list[VariableSelector]]:
    """
    Collect the variable selectors of a template and, given a pool, substitute their values.

    :param template: template text
    :param variable_pool: variable pool; when falsy the template is returned unchanged
    :param escape_quotes: prefix every double quote in substituted values with a backslash
    :return: (rendered text, variable selectors found in the template)
    :raises ValueError: a referenced variable resolves to None
    """
    variable_template_parser = VariableTemplateParser(template=template)
    variable_selectors = variable_template_parser.extract_variable_selectors()

    # without a pool only the selectors can be reported
    if not variable_pool:
        return template, variable_selectors

    resolved = {}
    for variable_selector in variable_selectors:
        value = variable_pool.get_variable_value(
            variable_selector=variable_selector.value_selector,
            target_value_type=ValueType.STRING
        )
        if value is None:
            raise ValueError(f'Variable {variable_selector.variable} not found')

        resolved[variable_selector.variable] = value.replace('"', '\\"') if escape_quotes else value

    return variable_template_parser.format(resolved), variable_selectors
|
||||
112
api/core/workflow/nodes/http_request/http_request_node.py
Normal file
112
api/core/workflow/nodes/http_request/http_request_node.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import logging
|
||||
from mimetypes import guess_extension
|
||||
from os import path
|
||||
from typing import cast
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeData
|
||||
from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class HttpRequestNode(BaseNode):
|
||||
_node_data_cls = HttpRequestNodeData
|
||||
node_type = NodeType.HTTP_REQUEST
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
    """
    Execute the configured HTTP request and wrap the outcome in a NodeRunResult.

    :param variable_pool: variable pool
    :return: FAILED result carrying the error text when building or invoking
        the executor raises, otherwise a SUCCEEDED result with status code,
        body, headers and extracted files
    """
    node_data: HttpRequestNodeData = cast(self._node_data_cls, self.node_data)

    http_executor = None
    try:
        # init and invoke http executor
        http_executor = HttpExecutor(node_data=node_data, variable_pool=variable_pool)
        response = http_executor.invoke()
    except Exception as e:
        # the raw request can only be reported when the executor was built
        process_data = {'request': http_executor.to_raw_request()} if http_executor else {}
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            error=str(e),
            process_data=process_data
        )

    files = self.extract_files(http_executor.server_url, response)

    # the body is left empty when the response was turned into files
    outputs = {
        'status_code': response.status_code,
        'body': response.content if not files else '',
        'headers': response.headers,
        'files': files,
    }

    return NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        outputs=outputs,
        process_data={
            'request': http_executor.to_raw_request(),
        }
    )
|
||||
|
||||
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[str, list[str]]:
    """
    Extract variable selector to variable mapping
    :param node_data: node data
    :return: {variable: value selector} for every selector the executor finds;
        {} when extraction fails (the failure is logged)
    """
    try:
        # no variable pool is passed, so templates are only scanned, not rendered
        selectors = HttpExecutor(node_data=node_data).variable_selectors
        return {selector.variable: selector.value_selector for selector in selectors}
    except Exception as e:
        logging.exception(f"Failed to extract variable selector to variable mapping: {e}")
        return {}
|
||||
|
||||
def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]:
    """
    Store an image response as a tool file and describe it as a FileVar.

    :param url: request url; its basename becomes the file name
    :param response: executed response
    :return: a one-element list for image responses, otherwise an empty list
    """
    mimetype, file_binary = response.extract_file()

    # only image payloads are turned into files
    if 'image' not in mimetype:
        return []

    # file name from the url, extension from the mime type when it can be guessed
    filename = path.basename(url)
    extension = guess_extension(mimetype) or '.bin'

    tool_file = ToolFileManager.create_file_by_raw(
        user_id=self.user_id,
        tenant_id=self.tenant_id,
        conversation_id=None,
        file_binary=file_binary,
        mimetype=mimetype,
    )

    return [FileVar(
        tenant_id=self.tenant_id,
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.TOOL_FILE,
        related_id=tool_file.id,
        filename=filename,
        extension=extension,
        mime_type=mimetype,
    )]
|
||||
0
api/core/workflow/nodes/if_else/__init__.py
Normal file
0
api/core/workflow/nodes/if_else/__init__.py
Normal file
26
api/core/workflow/nodes/if_else/entities.py
Normal file
26
api/core/workflow/nodes/if_else/entities.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class IfElseNodeData(BaseNodeData):
    """
    If/Else Node Data.
    """
    class Condition(BaseModel):
        """
        Condition entity
        """
        # selector of the variable whose value is tested
        variable_selector: list[str]
        comparison_operator: Literal[
            # for string or array
            "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
            # for number
            "=", "≠", ">", "<", "≥", "≤", "null", "not null"
        ]
        # value to compare against, kept as text; optional (defaults to None)
        value: Optional[str] = None

    # how the individual condition results are combined
    logical_operator: Literal["and", "or"] = "and"
    conditions: list[Condition]
|
||||
398
api/core/workflow/nodes/if_else/if_else_node.py
Normal file
398
api/core/workflow/nodes/if_else/if_else_node.py
Normal file
@@ -0,0 +1,398 @@
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class IfElseNode(BaseNode):
|
||||
_node_data_cls = IfElseNodeData
|
||||
node_type = NodeType.IF_ELSE
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
    """
    Run node
    :param variable_pool: variable pool
    :return: FAILED result when resolving or comparing raises; otherwise a
        SUCCEEDED result whose edge_source_handle is "true" or "false"
    """
    node_data = cast(self._node_data_cls, self.node_data)

    node_inputs = {"conditions": []}
    process_datas = {"condition_results": []}

    # operators comparing the actual value with an expected value
    binary_checks = {
        "contains": self._assert_contains,
        "not contains": self._assert_not_contains,
        "start with": self._assert_start_with,
        "end with": self._assert_end_with,
        "is": self._assert_is,
        "is not": self._assert_is_not,
        "=": self._assert_equal,
        "≠": self._assert_not_equal,
        ">": self._assert_greater_than,
        "<": self._assert_less_than,
        "≥": self._assert_greater_than_or_equal,
        "≤": self._assert_less_than_or_equal,
    }
    # operators that only inspect the actual value
    unary_checks = {
        "empty": self._assert_empty,
        "not empty": self._assert_not_empty,
        "null": self._assert_null,
        "not null": self._assert_not_null,
    }

    try:
        logical_operator = node_data.logical_operator

        # resolve every condition first so the inputs are complete before comparing
        input_conditions = [
            {
                "actual_value": variable_pool.get_variable_value(
                    variable_selector=condition.variable_selector
                ),
                "expected_value": condition.value,
                "comparison_operator": condition.comparison_operator
            }
            for condition in node_data.conditions
        ]
        node_inputs["conditions"] = input_conditions

        for input_condition in input_conditions:
            operator = input_condition["comparison_operator"]
            actual_value = input_condition["actual_value"]

            if operator in binary_checks:
                compare_result = binary_checks[operator](actual_value, input_condition["expected_value"])
            elif operator in unary_checks:
                compare_result = unary_checks[operator](actual_value)
            else:
                # unknown operators are skipped and leave no result
                continue

            process_datas["condition_results"].append({
                **input_condition,
                "result": compare_result
            })
    except Exception as e:
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            inputs=node_inputs,
            process_data=process_datas,
            error=str(e)
        )

    verdicts = [condition["result"] for condition in process_datas["condition_results"]]
    if logical_operator == "and":
        compare_result = False not in verdicts
    else:
        compare_result = True in verdicts

    return NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        inputs=node_inputs,
        process_data=process_datas,
        edge_source_handle="true" if compare_result else "false",
        outputs={
            "result": compare_result
        }
    )
|
||||
|
||||
def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert contains
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str | list):
|
||||
raise ValueError('Invalid actual value type: string or array')
|
||||
|
||||
if expected_value not in actual_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert not contains
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return True
|
||||
|
||||
if not isinstance(actual_value, str | list):
|
||||
raise ValueError('Invalid actual value type: string or array')
|
||||
|
||||
if expected_value in actual_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert start with
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if not actual_value.startswith(expected_value):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert end with
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if not actual_value.endswith(expected_value):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert is
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if actual_value != expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert is not
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if actual_value == expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_empty(self, actual_value: Optional[str]) -> bool:
|
||||
"""
|
||||
Assert empty
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
|
||||
"""
|
||||
Assert not empty
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if actual_value:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value != expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert not equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value == expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert greater than
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value <= expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert less than
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value >= expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert greater than or equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value < expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert less than or equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value > expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_null(self, actual_value: Optional[int | float]) -> bool:
|
||||
"""
|
||||
Assert null
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
|
||||
"""
|
||||
Assert not null
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
51
api/core/workflow/nodes/knowledge_retrieval/entities.py
Normal file
51
api/core/workflow/nodes/knowledge_retrieval/entities.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
    """
    Reranking Model Config.
    """
    # presumably the model provider's identifier; confirm against the caller
    provider: str
    # presumably the reranking model's name; confirm against the caller
    model: str
|
||||
|
||||
|
||||
class MultipleRetrievalConfig(BaseModel):
    """
    Multiple Retrieval Config.
    """
    # presumably the maximum number of results to keep; confirm against the retrieval code
    top_k: int
    # optional score cut-off; may be None
    score_threshold: Optional[float]
    # reranking model settings (provider and model)
    reranking_model: RerankingModelConfig
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
    """
    Model Config.
    """
    provider: str
    name: str
    # NOTE(review): the accepted mode values are not visible here; confirm against the model runtime
    mode: str
    # free-form parameters; empty by default
    completion_params: dict[str, Any] = {}
|
||||
|
||||
|
||||
class SingleRetrievalConfig(BaseModel):
    """
    Single Retrieval Config.
    """
    # model settings (provider, name, mode, completion params)
    model: ModelConfig
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeData(BaseNodeData):
    """
    Knowledge retrieval Node Data.
    """
    type: str = 'knowledge-retrieval'
    # selector of the variable that holds the query text
    query_variable_selector: list[str]
    # ids of the datasets to search
    dataset_ids: list[str]
    retrieval_mode: Literal['single', 'multiple']
    # presumably used when retrieval_mode is 'multiple'; confirm against the node implementation
    multiple_retrieval_config: Optional[MultipleRetrievalConfig]
    # presumably used when retrieval_mode is 'single'; confirm against the node implementation
    single_retrieval_config: Optional[SingleRetrievalConfig]
|
||||
@@ -0,0 +1,443 @@
|
||||
import threading
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||
from core.workflow.nodes.knowledge_retrieval.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
# Module-level retrieval settings: semantic search, no reranking, top 2 results,
# score threshold disabled.
# NOTE(review): presumably the fallback for datasets without their own retrieval
# model; the code that reads it is not visible here, confirm.
default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enabled': False
}
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(BaseNode):
|
||||
_node_data_cls = KnowledgeRetrievalNodeData
|
||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
    """
    Run node

    Reads the query from the variable pool, retrieves matching knowledge and
    wraps the outcome in a NodeRunResult. Any exception raised while
    retrieving is reported as a FAILED result instead of propagating.

    :param variable_pool: variable pool
    :return: SUCCEEDED result with outputs {'result': [...]}, or a FAILED result with `error` set
    """
    data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)

    # resolve the query value through the node's configured selector
    query = variable_pool.get_variable_value(variable_selector=data.query_variable_selector)
    node_inputs = {'query': query}

    # nothing to search for without a query
    if not query:
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            inputs=node_inputs,
            error="Query is required."
        )

    try:
        retrieved = self._fetch_dataset_retriever(node_data=data, query=query)
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=node_inputs,
            process_data=None,
            outputs={'result': retrieved}
        )
    except Exception as e:
        # surface the failure on the node result rather than aborting the caller
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            inputs=node_inputs,
            error=str(e)
        )
|
||||
|
||||
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
    """
    Retrieve documents for the query from the node's datasets and describe each hit.

    Datasets are looked up for the current tenant; datasets that do not exist or
    have no available documents are skipped. Retrieval then runs in single or
    multiple mode according to `node_data.retrieval_mode`, and every matching
    segment is turned into a dict with 'metadata', 'title' and 'content'.

    :param node_data: node data
    :param query: query
    :return: one source dict per retrieved segment, in retrieval order; empty when nothing was found
    """
    dataset_ids = node_data.dataset_ids

    available_datasets = []
    for dataset_id in dataset_ids:
        # get dataset from dataset id, restricted to the current tenant
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == dataset_id
        ).first()

        # skip datasets that are missing or contain nothing to search
        if not dataset or dataset.available_document_count == 0:
            continue

        available_datasets.append(dataset)

    all_documents = []
    if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
        all_documents = self._single_retrieve(available_datasets, node_data, query)
    elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
        all_documents = self._multiple_retrieve(available_datasets, node_data, query)

    if not all_documents:
        return []

    # doc_id -> score, only for hits that carry a truthy score
    document_score_list = {
        item.metadata['doc_id']: item.metadata['score']
        for item in all_documents
        if item.metadata.get('score')
    }

    index_node_ids = [document.metadata['doc_id'] for document in all_documents]
    segments = DocumentSegment.query.filter(
        DocumentSegment.dataset_id.in_(dataset_ids),
        DocumentSegment.completed_at.isnot(None),
        DocumentSegment.status == 'completed',
        DocumentSegment.enabled == True,  # noqa: E712 - SQL expression, not a Python comparison
        DocumentSegment.index_node_id.in_(index_node_ids)
    ).all()
    if not segments:
        return []

    # the DB returns segments in arbitrary order; restore the retrieval order
    index_node_id_to_position = {node_id: position for position, node_id in enumerate(index_node_ids)}
    sorted_segments = sorted(
        segments,
        key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float('inf'))
    )

    context_list = []
    # 1-based position of each emitted source. It must live outside the loop:
    # initialising it per segment gave every source position 1.
    resource_number = 1
    for segment in sorted_segments:
        dataset = Dataset.query.filter_by(
            id=segment.dataset_id
        ).first()
        document = Document.query.filter(
            Document.id == segment.document_id,
            Document.enabled == True,  # noqa: E712
            Document.archived == False,  # noqa: E712
        ).first()

        # segments whose dataset or enabled, non-archived document is gone are dropped
        if not (dataset and document):
            continue

        source = {
            'metadata': {
                '_source': 'knowledge',
                'position': resource_number,
                'dataset_id': dataset.id,
                'dataset_name': dataset.name,
                'document_id': document.id,
                'document_name': document.name,
                'document_data_source_type': document.data_source_type,
                'segment_id': segment.id,
                'retriever_from': 'workflow',
                'score': document_score_list.get(segment.index_node_id, None),
                'segment_hit_count': segment.hit_count,
                'segment_word_count': segment.word_count,
                'segment_position': segment.position,
                'segment_index_node_hash': segment.index_node_hash,
            },
            'title': document.name
        }
        # segments with an answer are rendered as a question/answer pair
        if segment.answer:
            source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
        else:
            source['content'] = segment.content
        context_list.append(source)
        resource_number += 1

    return context_list
|
||||
|
||||
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
    """
    Extract variable selector to variable mapping

    :param node_data: node data
    :return: mapping with the single key 'query' pointing at the node's query variable selector
    """
    typed_node_data = cast(cls._node_data_cls, node_data)
    return {'query': typed_node_data.query_variable_selector}
|
||||
|
||||
def _single_retrieve(self, available_datasets, node_data, query):
    """
    Let the configured LLM choose one dataset for the query and retrieve from it.

    Every available dataset is offered to the model as a parameterless tool whose
    name is the dataset id. Models whose schema advertises tool calling are routed
    through FunctionCallMultiDatasetRouter, all others through
    ReactMultiDatasetRouter. The chosen dataset is then searched with its own
    retrieval settings (or `default_retrieval_model` when it has none).

    :param available_datasets: candidate datasets
    :param node_data: node data
    :param query: query
    :return: retrieved documents; an empty list when no schema, no choice or no dataset is available
    """
    tools = []
    for dataset in available_datasets:
        description = dataset.description \
            or 'useful for when you want to answer queries about the ' + dataset.name
        # tool descriptions are passed to the model on a single line
        description = description.replace('\n', '').replace('\r', '')
        tools.append(PromptMessageTool(
            name=dataset.id,
            description=description,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        ))

    # fetch model config
    model_instance, model_config = self._fetch_model_config(node_data)

    # get model schema to find out whether the model supports tool calling
    model_type_instance = cast(LargeLanguageModel, model_config.provider_model_bundle.model_type_instance)
    model_schema = model_type_instance.get_model_schema(
        model=model_config.model,
        credentials=model_config.credentials
    )
    if not model_schema:
        # every other path returns a list; keep the return type consistent
        return []

    features = model_schema.features or []
    supports_tool_call = ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features

    if supports_tool_call:
        dataset_id = FunctionCallMultiDatasetRouter().invoke(query, tools, model_config, model_instance)
    else:
        dataset_id = ReactMultiDatasetRouter().invoke(query, tools, node_data, model_config, model_instance,
                                                      self.user_id, self.tenant_id)
    if not dataset_id:
        return []

    # The id comes back from the model, so it is not trusted: restrict the
    # lookup to the current tenant, as the other dataset lookups in this node do.
    dataset = db.session.query(Dataset).filter(
        Dataset.tenant_id == self.tenant_id,
        Dataset.id == dataset_id
    ).first()
    if not dataset:
        return []

    # get retrieval model config
    retrieval_model_config = dataset.retrieval_model \
        if dataset.retrieval_model else default_retrieval_model

    top_k = retrieval_model_config['top_k']
    retrival_method = retrieval_model_config['search_method']
    reranking_model = retrieval_model_config['reranking_model'] \
        if retrieval_model_config['reranking_enable'] else None

    # the score threshold only applies when explicitly enabled
    score_threshold = .0
    if retrieval_model_config.get("score_threshold_enabled"):
        score_threshold = retrieval_model_config.get("score_threshold")

    results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
                                        query=query,
                                        top_k=top_k, score_threshold=score_threshold,
                                        reranking_model=reranking_model)
    self._on_query(query, [dataset_id])
    if results:
        self._on_retrival_end(results)
    return results
|
||||
|
||||
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
    ModelInstance, ModelConfigWithCredentialsEntity]:
    """
    Fetch model config

    Resolves the LLM named in `node_data.single_retrieval_config.model` for the
    current tenant, checks that it is usable and bundles it with its credentials.

    :param node_data: node data
    :return: (model instance, model config with credentials)
    :raises ValueError: the model or its schema does not exist, or no mode is configured
    :raises ProviderTokenNotInitError: model status is NO_CONFIGURE
    :raises ModelCurrentlyNotSupportError: model status is NO_PERMISSION
    :raises QuotaExceededError: model status is QUOTA_EXCEEDED
    """
    node_model = node_data.single_retrieval_config.model
    model_name = node_model.name
    provider_name = node_model.provider

    model_instance = ModelManager().get_model_instance(
        tenant_id=self.tenant_id,
        model_type=ModelType.LLM,
        provider=provider_name,
        model=model_name
    )

    provider_model_bundle = model_instance.provider_model_bundle
    model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
    model_credentials = model_instance.credentials

    # check model
    provider_model = provider_model_bundle.configuration.get_provider_model(
        model=model_name,
        model_type=ModelType.LLM
    )
    if provider_model is None:
        raise ValueError(f"Model {model_name} not exist.")

    status = provider_model.status
    if status == ModelStatus.NO_CONFIGURE:
        raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
    if status == ModelStatus.NO_PERMISSION:
        raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
    if status == ModelStatus.QUOTA_EXCEEDED:
        raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

    # model config
    completion_params = node_model.completion_params
    # 'stop' is taken out of the node's own dict in place (not a copy), so later
    # readers of these completion params no longer see it.
    stop = completion_params.pop('stop', [])

    # get model mode
    model_mode = node_model.mode
    if not model_mode:
        raise ValueError("LLM mode is required.")

    model_schema = model_type_instance.get_model_schema(
        model_name,
        model_credentials
    )
    if not model_schema:
        raise ValueError(f"Model {model_name} not exist.")

    model_config = ModelConfigWithCredentialsEntity(
        provider=provider_name,
        model=model_name,
        model_schema=model_schema,
        mode=model_mode,
        provider_model_bundle=provider_model_bundle,
        credentials=model_credentials,
        parameters=completion_params,
        stop=stop,
    )
    return model_instance, model_config
|
||||
|
||||
def _multiple_retrieve(self, available_datasets, node_data, query):
    """
    Search every available dataset in its own thread, then rerank the combined hits.

    :param available_datasets: datasets to search
    :param node_data: node data; `multiple_retrieval_config` supplies top_k, the reranking model and the score threshold
    :param query: query
    :return: documents returned by the rerank runner
    """
    collected = []
    workers = []
    for dataset in available_datasets:
        worker = threading.Thread(
            target=self._retriever,
            kwargs={
                # hand the real app object to the thread so it can open its own app context
                'flask_app': current_app._get_current_object(),
                'dataset_id': dataset.id,
                'query': query,
                'top_k': node_data.multiple_retrieval_config.top_k,
                # shared list that every worker appends its hits to
                'all_documents': collected,
            }
        )
        workers.append(worker)
        worker.start()

    for worker in workers:
        worker.join()

    # do rerank for searched documents
    rerank_model_instance = ModelManager().get_model_instance(
        tenant_id=self.tenant_id,
        provider=node_data.multiple_retrieval_config.reranking_model.provider,
        model_type=ModelType.RERANK,
        model=node_data.multiple_retrieval_config.reranking_model.model
    )
    reranked = RerankRunner(rerank_model_instance).run(
        query,
        collected,
        node_data.multiple_retrieval_config.score_threshold,
        node_data.multiple_retrieval_config.top_k
    )

    self._on_query(query, [dataset.id for dataset in available_datasets])
    if reranked:
        self._on_retrival_end(reranked)
    return reranked
|
||||
|
||||
def _on_retrival_end(self, documents: list[Document]) -> None:
    """
    Handle retrival end.

    Increments `hit_count` on the document segment behind every retrieved document.

    :param documents: retrieved documents; each must carry metadata['doc_id'] and may carry metadata['dataset_id']
    """
    for document in documents:
        metadata = document.metadata

        # match on the index node id, narrowed to the dataset when it is known
        criteria = [DocumentSegment.index_node_id == metadata['doc_id']]
        if 'dataset_id' in metadata:
            criteria.append(DocumentSegment.dataset_id == metadata['dataset_id'])

        # add hit count to document segment
        db.session.query(DocumentSegment).filter(*criteria).update(
            {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
            synchronize_session=False
        )

        # NOTE(review): per-document commit reconstructed from a view that lost indentation - confirm
        db.session.commit()
|
||||
|
||||
def _on_query(self, query: str, dataset_ids: list[str]) -> None:
    """
    Handle query.

    Records one DatasetQuery row per dataset for a non-empty query.

    :param query: query text; nothing is recorded when it is empty
    :param dataset_ids: ids of the datasets the query was run against
    """
    if not query:
        return

    for dataset_id in dataset_ids:
        db.session.add(DatasetQuery(
            dataset_id=dataset_id,
            content=query,
            source='app',
            source_app_id=self.app_id,
            created_by_role=self.user_from.value,
            created_by=self.user_id
        ))

    # NOTE(review): single commit after the loop reconstructed from a view that lost indentation - confirm
    db.session.commit()
|
||||
|
||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
    """
    Thread target: search one dataset and append its hits to `all_documents`.

    :param flask_app: Flask application whose app context is entered for the DB access
    :param dataset_id: id of the dataset to search (looked up for the current tenant)
    :param query: query
    :param top_k: maximum number of hits to request
    :param all_documents: shared output list, extended in place
    :return: an empty list when the dataset does not exist, otherwise None; callers use `all_documents`
    """
    with flask_app.app_context():
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == dataset_id
        ).first()

        if not dataset:
            return []

        # get retrieval model, if the model is not set, use the default
        retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

        if dataset.indexing_technique == "economy":
            # use keyword table query
            documents = RetrievalService.retrieve(retrival_method='keyword_search',
                                                  dataset_id=dataset.id,
                                                  query=query,
                                                  top_k=top_k
                                                  )
        elif top_k > 0:
            # `.get` matches how _single_retrieve reads the same setting and
            # tolerates stored configs that lack the key
            score_threshold = retrieval_model.get('score_threshold') \
                if retrieval_model.get('score_threshold_enabled') else None
            reranking_model = retrieval_model['reranking_model'] \
                if retrieval_model['reranking_enable'] else None
            # retrieval source
            documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
                                                  dataset_id=dataset.id,
                                                  query=query,
                                                  top_k=top_k,
                                                  score_threshold=score_threshold,
                                                  reranking_model=reranking_model
                                                  )
        else:
            documents = None

        # Guard both branches: an empty or missing result must not raise inside
        # the worker thread, where the error would be lost.
        if documents:
            all_documents.extend(documents)
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
|
||||
|
||||
|
||||
class FunctionCallMultiDatasetRouter:
    """Chooses a dataset for a query by offering the datasets to an LLM as callable tools."""

    def invoke(
            self,
            query: str,
            dataset_tools: list[PromptMessageTool],
            model_config: ModelConfigWithCredentialsEntity,
            model_instance: ModelInstance,

    ) -> Union[str, None]:
        """Given input, decide which dataset tool to use.

        :param query: user query sent to the model
        :param dataset_tools: one tool per candidate dataset
        :param model_config: model config (not read by this method)
        :param model_instance: model instance used for the call
        :return: name of the chosen tool; None when there are no tools, the model
            called no tool, or the model call failed
        """
        if not dataset_tools:
            return None
        if len(dataset_tools) == 1:
            # nothing to choose between, so skip the model call
            return dataset_tools[0].name

        try:
            prompt_messages = [
                SystemPromptMessage(content='You are a helpful AI assistant.'),
                UserPromptMessage(content=query)
            ]
            result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                tools=dataset_tools,
                stream=False,
                model_parameters={
                    'temperature': 0.2,
                    'top_p': 0.3,
                    'max_tokens': 1500
                }
            )
            if result.message.tool_calls:
                # the first tool call names the selected dataset
                return result.message.tool_calls[0].function.name
            return None
        except Exception:
            # Routing stays best-effort (the caller treats None as "no dataset"),
            # but the failure is recorded instead of vanishing.
            import logging
            logging.getLogger(__name__).warning(
                "Function-call dataset routing failed; no dataset selected.", exc_info=True
            )
            return None
|
||||
@@ -0,0 +1,254 @@
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
from langchain.schema import AgentAction
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
|
||||
from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||
Valid "action" values: "Final Answer" or {tool_names}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $INPUT
|
||||
}}
|
||||
```
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
Action:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}}
|
||||
```"""
|
||||
|
||||
|
||||
class ReactMultiDatasetRouter:
|
||||
|
||||
def invoke(
        self,
        query: str,
        dataset_tools: list[PromptMessageTool],
        node_data: KnowledgeRetrievalNodeData,
        model_config: ModelConfigWithCredentialsEntity,
        model_instance: ModelInstance,
        user_id: str,
        tenant_id: str,

) -> Union[str, None]:
    """Given input, decide which dataset tool to use.

    :param query: user query
    :param dataset_tools: one tool per candidate dataset
    :param node_data: node data, forwarded to the LLM invocation
    :param model_config: model config; its mode selects the prompt style
    :param model_instance: model instance used for the call
    :param user_id: user id forwarded to the model call
    :param tenant_id: tenant id used for quota deduction
    :return: name of the chosen tool; None when there are no tools, the model
        picked none, or routing failed
    """
    if not dataset_tools:
        return None
    if len(dataset_tools) == 1:
        # nothing to choose between, so skip the model call
        return dataset_tools[0].name

    try:
        return self._react_invoke(query=query, node_data=node_data, model_config=model_config,
                                  model_instance=model_instance,
                                  tools=dataset_tools, user_id=user_id, tenant_id=tenant_id)
    except Exception:
        # Routing stays best-effort (the caller treats None as "no dataset"),
        # but the failure is recorded instead of vanishing.
        import logging
        logging.getLogger(__name__).warning(
            "ReAct dataset routing failed; no dataset selected.", exc_info=True
        )
        return None
|
||||
|
||||
def _react_invoke(
        self,
        query: str,
        node_data: KnowledgeRetrievalNodeData,
        model_config: ModelConfigWithCredentialsEntity,
        model_instance: ModelInstance,
        tools: Sequence[PromptMessageTool],
        user_id: str,
        tenant_id: str,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
) -> Union[str, None]:
    """
    Ask the model, with a ReAct-style prompt, which dataset tool fits the query.

    :param query: user query
    :param node_data: node data, forwarded to the LLM invocation
    :param model_config: model config; mode "chat" selects the chat prompt, anything else the completion prompt
    :param model_instance: model instance
    :param tools: candidate dataset tools
    :param user_id: user id forwarded to the model call
    :param tenant_id: tenant id used for quota deduction
    :param prefix: text placed before the tool list
    :param suffix: text placed after the format instructions (chat prompt only)
    :param human_message_template: forwarded to the chat prompt builder
    :param format_instructions: format instructions template
    :return: tool named by the parsed agent action, or None when the output is not an action
    """
    is_chat_model = model_config.mode == "chat"
    if is_chat_model:
        prompt = self.create_chat_prompt(
            query=query,
            tools=tools,
            prefix=prefix,
            suffix=suffix,
            human_message_template=human_message_template,
            format_instructions=format_instructions,
        )
    else:
        # NOTE(review): `query` is not passed on this path, and the transform
        # below receives inputs={} and query='' - confirm the completion prompt
        # ever gets the user's question.
        prompt = self.create_completion_prompt(
            tools=tools,
            prefix=prefix,
            format_instructions=format_instructions,
            input_variables=None
        )

    prompt_messages = AdvancedPromptTransform().get_prompt(
        prompt_template=prompt,
        inputs={},
        query='',
        files=[],
        context='',
        memory_config=None,
        memory=None,
        model_config=model_config
    )

    # generation stops at the observation marker, passed as a stop sequence
    result_text, usage = self._invoke_llm(
        node_data=node_data,
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        stop=['Observation:'],
        user_id=user_id,
        tenant_id=tenant_id
    )

    agent_decision = StructuredChatOutputParser().parse(result_text)
    return agent_decision.tool if isinstance(agent_decision, AgentAction) else None
|
||||
|
||||
def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData,
                model_instance: ModelInstance,
                prompt_messages: list[PromptMessage],
                stop: list[str], user_id: str, tenant_id: str) -> tuple[str, LLMUsage]:
    """
    Invoke large language model

    Streams a completion, collects its text and usage, then deducts the usage
    from the tenant's quota.

    :param node_data: node data; its single-retrieval completion params are the model parameters
    :param model_instance: model instance
    :param prompt_messages: prompt messages
    :param stop: stop
    :param user_id: user id forwarded to the model call
    :param tenant_id: tenant whose quota is charged
    :return: (full response text, usage)
    """
    stream = model_instance.invoke_llm(
        prompt_messages=prompt_messages,
        model_parameters=node_data.single_retrieval_config.model.completion_params,
        stop=stop,
        stream=True,
        user=user_id,
    )

    # drain the stream into the final text and usage
    text, usage = self._handle_invoke_result(invoke_result=stream)

    # deduct quota
    LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)

    return text, usage
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
:return:
|
||||
"""
|
||||
model = None
|
||||
prompt_messages = []
|
||||
full_text = ''
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
text = result.delta.message.content
|
||||
full_text += text
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
|
||||
if not prompt_messages:
|
||||
prompt_messages = result.prompt_messages
|
||||
|
||||
if not usage and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
|
||||
if not usage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
return full_text, usage
|
||||
|
||||
def create_chat_prompt(
        self,
        query: str,
        tools: Sequence[PromptMessageTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
) -> list[ChatModelMessage]:
    """
    Build the chat-style routing prompt: one system message and one user message.

    :param query: user query; becomes the text of the user message
    :param tools: candidate dataset tools listed in the system message
    :param prefix: text placed before the tool list
    :param suffix: text placed after the format instructions
    :param human_message_template: accepted for signature compatibility; not used when building the messages
    :param format_instructions: template formatted with the quoted tool names
    :return: [system message, user message]
    """
    formatted_tools = "\n".join(
        f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}"
        for tool in tools
    )

    # dict.fromkeys removes duplicates while keeping first-seen order, so the
    # prompt text is identical from run to run; iterating a set of strings is not.
    unique_tool_names = dict.fromkeys(tool.name for tool in tools)
    tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
    format_instructions = format_instructions.format(tool_names=tool_names)

    template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])

    return [
        ChatModelMessage(
            role=PromptMessageRole.SYSTEM,
            text=template
        ),
        ChatModelMessage(
            role=PromptMessageRole.USER,
            text=query
        ),
    ]
|
||||
|
||||
def create_completion_prompt(
        self,
        tools: Sequence[PromptMessageTool],
        prefix: str = PREFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
    """Create prompt in the style of the zero shot agent.

    Args:
        tools: List of tools the agent will have access to, used to format the
            prompt.
        prefix: String to put before the list of tools.
        format_instructions: Template formatted with the comma-separated tool names.
        input_variables: List of input variables the final prompt will expect;
            defaults to ["input", "agent_scratchpad"].

    Returns:
        A PromptTemplate with the template assembled from the pieces here.
    """
    # NOTE(review): line breaks in this string are reconstructed from a view that lost indentation - confirm
    suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""

    described_tools = "\n".join(f"{tool.name}: {tool.description}" for tool in tools)
    joined_names = ", ".join(tool.name for tool in tools)
    sections = [
        prefix,
        described_tools,
        format_instructions.format(tool_names=joined_names),
        suffix,
    ]

    expected_variables = ["input", "agent_scratchpad"] if input_variables is None else input_variables
    return PromptTemplate(template="\n\n".join(sections), input_variables=expected_variables)
|
||||
0
api/core/workflow/nodes/llm/__init__.py
Normal file
0
api/core/workflow/nodes/llm/__init__.py
Normal file
49
api/core/workflow/nodes/llm/entities.py
Normal file
49
api/core/workflow/nodes/llm/entities.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
    """
    Model Config.

    Names the model an LLM node runs against and the parameters sent with each call.
    """
    # model provider identifier
    provider: str
    # model name
    name: str
    # model mode
    mode: str
    # passed as `model_parameters` when the LLM node invokes the model
    completion_params: dict[str, Any] = {}
|
||||
|
||||
|
||||
class ContextConfig(BaseModel):
    """
    Context Config.

    Context settings of an LLM node.
    """
    # whether context is enabled for the node
    enabled: bool
    # selector of the variable holding the context
    # (presumably only consulted when `enabled` is true - verify against the reader)
    variable_selector: Optional[list[str]] = None
|
||||
|
||||
|
||||
class VisionConfig(BaseModel):
    """
    Vision Config.

    Vision settings of an LLM node.
    """
    class Configs(BaseModel):
        """
        Configs.

        Detail options for vision input.
        """
        # detail level; only 'low' and 'high' are accepted
        detail: Literal['low', 'high']

    # whether vision is enabled for the node
    enabled: bool
    # optional detail options; None when not configured
    configs: Optional[Configs] = None
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
    """
    LLM Node Data.

    Configuration of one LLM node in a workflow.
    """
    # model to invoke and its completion parameters
    model: ModelConfig
    # either a list of chat messages or a single completion prompt template
    prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
    # memory settings; None when the node does not use memory
    memory: Optional[MemoryConfig] = None
    # context settings
    context: ContextConfig
    # vision settings
    vision: VisionConfig
|
||||
554
api/core/workflow/nodes/llm/llm_node.py
Normal file
554
api/core/workflow/nodes/llm/llm_node.py
Normal file
@@ -0,0 +1,554 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class LLMNode(BaseNode):
|
||||
_node_data_cls = LLMNodeData
|
||||
node_type = NodeType.LLM
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
    """
    Run node

    Gathers inputs, files, context, model config, memory and prompt messages,
    invokes the LLM and reports the text and usage. Any exception raised along
    the way produces a FAILED result carrying whatever inputs and process data
    had been collected up to that point.

    :param variable_pool: variable pool
    :return: SUCCEEDED result with outputs 'text' and 'usage', or a FAILED result with `error` set
    """
    data = cast(self._node_data_cls, self.node_data)

    # both stay None until their stage is reached, so a failure shows how far the run got
    node_inputs = None
    process_data = None

    try:
        # fetch variables and fetch values from variable pool
        inputs = self._fetch_inputs(data, variable_pool)
        node_inputs = {}

        # fetch files
        files: list[FileVar] = self._fetch_files(data, variable_pool)
        if files:
            node_inputs['#files#'] = [file.to_dict() for file in files]

        # fetch context value
        context = self._fetch_context(data, variable_pool)
        if context:
            node_inputs['#context#'] = context

        # fetch model config
        model_instance, model_config = self._fetch_model_config(data.model)

        # fetch memory
        memory = self._fetch_memory(data.memory, variable_pool, model_instance)

        # the system query is only supplied when the node has memory configured
        if data.memory:
            query = variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
        else:
            query = None

        # fetch prompt messages
        prompt_messages, stop = self._fetch_prompt_messages(
            node_data=data,
            query=query,
            inputs=inputs,
            files=files,
            context=context,
            memory=memory,
            model_config=model_config
        )

        process_data = {
            'model_mode': model_config.mode,
            'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                model_mode=model_config.mode,
                prompt_messages=prompt_messages
            )
        }

        # handle invoke result
        result_text, usage = self._invoke_llm(
            node_data_model=data.model,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            stop=stop
        )
    except Exception as e:
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            error=str(e),
            inputs=node_inputs,
            process_data=process_data
        )

    return NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        inputs=node_inputs,
        process_data=process_data,
        outputs={
            'text': result_text,
            'usage': jsonable_encoder(usage)
        },
        metadata={
            NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
            NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
            NodeRunMetadataKey.CURRENCY: usage.currency
        }
    )
|
||||
|
||||
def _invoke_llm(self, node_data_model: ModelConfig,
                model_instance: ModelInstance,
                prompt_messages: list[PromptMessage],
                stop: list[str]) -> tuple[str, LLMUsage]:
    """
    Call the LLM in streaming mode, collect its reply and charge the quota.

    :param node_data_model: model settings of the node; its ``completion_params``
        are sent as the model parameters
    :param model_instance: model instance that performs the invocation
    :param prompt_messages: prompt messages to send
    :param stop: stop sequences
    :return: tuple of (generated text, usage)
    """
    # NOTE(review): the session is closed before the model call, presumably so
    # that no DB connection is held while the stream is consumed -- confirm.
    db.session.close()

    stream = model_instance.invoke_llm(
        prompt_messages=prompt_messages,
        model_parameters=node_data_model.completion_params,
        stop=stop,
        stream=True,
        user=self.user_id,
    )

    # drain the stream into the final text and the usage record
    generated_text, llm_usage = self._handle_invoke_result(invoke_result=stream)

    # deduct quota for this invocation
    self.deduct_llm_quota(
        tenant_id=self.tenant_id,
        model_instance=model_instance,
        usage=llm_usage,
    )

    return generated_text, llm_usage
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
:return:
|
||||
"""
|
||||
model = None
|
||||
prompt_messages = []
|
||||
full_text = ''
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
text = result.delta.message.content
|
||||
full_text += text
|
||||
|
||||
self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
|
||||
if not prompt_messages:
|
||||
prompt_messages = result.prompt_messages
|
||||
|
||||
if not usage and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
|
||||
if not usage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
return full_text, usage
|
||||
|
||||
def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
    """
    Resolve every variable referenced by the node's prompt template.

    :param node_data: node data holding the prompt template
    :param variable_pool: variable pool the referenced values are read from
    :return: mapping of variable name to its value in the pool
    :raises ValueError: if a referenced variable resolves to ``None``
    """
    template = node_data.prompt_template

    # texts to scan: one per message for a list template, the single text for
    # a completion template, nothing for any other template type
    if isinstance(template, list):
        texts = [message.text for message in template]
    elif isinstance(template, CompletionModelPromptTemplate):
        texts = [template.text]
    else:
        texts = []

    selectors = []
    for text in texts:
        selectors.extend(VariableTemplateParser(template=text).extract_variable_selectors())

    resolved = {}
    for selector in selectors:
        value = variable_pool.get_variable_value(selector.value_selector)
        if value is None:
            raise ValueError(f'Variable {selector.variable} not found')

        resolved[selector.variable] = value

    return resolved
|
||||
|
||||
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
|
||||
"""
|
||||
Fetch files
|
||||
:param node_data: node data
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
if not node_data.vision.enabled:
|
||||
return []
|
||||
|
||||
files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
|
||||
if not files:
|
||||
return []
|
||||
|
||||
return files
|
||||
|
||||
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
|
||||
"""
|
||||
Fetch context
|
||||
:param node_data: node data
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
if not node_data.context.enabled:
|
||||
return None
|
||||
|
||||
if not node_data.context.variable_selector:
|
||||
return None
|
||||
|
||||
context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
|
||||
if context_value:
|
||||
if isinstance(context_value, str):
|
||||
return context_value
|
||||
elif isinstance(context_value, list):
|
||||
context_str = ''
|
||||
original_retriever_resource = []
|
||||
for item in context_value:
|
||||
if 'content' not in item:
|
||||
raise ValueError(f'Invalid context structure: {item}')
|
||||
|
||||
context_str += item['content'] + '\n'
|
||||
|
||||
retriever_resource = self._convert_to_original_retriever_resource(item)
|
||||
if retriever_resource:
|
||||
original_retriever_resource.append(retriever_resource)
|
||||
|
||||
if self.callbacks:
|
||||
for callback in self.callbacks:
|
||||
callback.on_event(
|
||||
event=QueueRetrieverResourcesEvent(
|
||||
retriever_resources=original_retriever_resource
|
||||
)
|
||||
)
|
||||
|
||||
return context_str.strip()
|
||||
|
||||
return None
|
||||
|
||||
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
|
||||
"""
|
||||
Convert to original retriever resource, temp.
|
||||
:param context_dict: context dict
|
||||
:return:
|
||||
"""
|
||||
if ('metadata' in context_dict and '_source' in context_dict['metadata']
|
||||
and context_dict['metadata']['_source'] == 'knowledge'):
|
||||
metadata = context_dict.get('metadata', {})
|
||||
source = {
|
||||
'position': metadata.get('position'),
|
||||
'dataset_id': metadata.get('dataset_id'),
|
||||
'dataset_name': metadata.get('dataset_name'),
|
||||
'document_id': metadata.get('document_id'),
|
||||
'document_name': metadata.get('document_name'),
|
||||
'data_source_type': metadata.get('document_data_source_type'),
|
||||
'segment_id': metadata.get('segment_id'),
|
||||
'retriever_from': metadata.get('retriever_from'),
|
||||
'score': metadata.get('score'),
|
||||
'hit_count': metadata.get('segment_hit_count'),
|
||||
'word_count': metadata.get('segment_word_count'),
|
||||
'segment_position': metadata.get('segment_position'),
|
||||
'index_node_hash': metadata.get('segment_index_node_hash'),
|
||||
'content': context_dict.get('content'),
|
||||
}
|
||||
|
||||
return source
|
||||
|
||||
return None
|
||||
|
||||
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
    """
    Load the model instance for the node and assemble its runtime configuration.

    :param node_data_model: model settings of the node (provider, name, mode,
        completion params). NOTE: a ``stop`` entry is removed from its
        ``completion_params`` dict in place and returned separately.
    :return: tuple of (model instance, model config with credentials)
    :raises ValueError: if the provider model or its schema cannot be found, or
        the mode is empty
    :raises ProviderTokenNotInitError: if the model status is ``NO_CONFIGURE``
    :raises ModelCurrentlyNotSupportError: if the model status is ``NO_PERMISSION``
    :raises QuotaExceededError: if the model status is ``QUOTA_EXCEEDED``
    """
    provider_name = node_data_model.provider
    model_name = node_data_model.name

    model_instance = ModelManager().get_model_instance(
        tenant_id=self.tenant_id,
        model_type=ModelType.LLM,
        provider=provider_name,
        model=model_name
    )

    bundle = model_instance.provider_model_bundle
    llm = cast(LargeLanguageModel, model_instance.model_type_instance)
    credentials = model_instance.credentials

    # check model
    provider_model = bundle.configuration.get_provider_model(
        model=model_name,
        model_type=ModelType.LLM
    )
    if provider_model is None:
        raise ValueError(f"Model {model_name} not exist.")

    status = provider_model.status
    if status == ModelStatus.NO_CONFIGURE:
        raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
    if status == ModelStatus.NO_PERMISSION:
        raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
    if status == ModelStatus.QUOTA_EXCEEDED:
        raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

    # model config: 'stop' is popped out of the node's own dict (not a copy),
    # so the parameters stored on the node no longer contain it afterwards
    completion_params = node_data_model.completion_params
    stop = completion_params.pop('stop', [])

    # get model mode
    mode = node_data_model.mode
    if not mode:
        raise ValueError("LLM mode is required.")

    schema = llm.get_model_schema(model_name, credentials)
    if not schema:
        raise ValueError(f"Model {model_name} not exist.")

    model_config = ModelConfigWithCredentialsEntity(
        provider=provider_name,
        model=model_name,
        model_schema=schema,
        mode=mode,
        provider_model_bundle=bundle,
        credentials=credentials,
        parameters=completion_params,
        stop=stop,
    )
    return model_instance, model_config
|
||||
|
||||
def _fetch_memory(self, node_data_memory: Optional[MemoryConfig],
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
|
||||
"""
|
||||
Fetch memory
|
||||
:param node_data_memory: node data memory
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
if not node_data_memory:
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
|
||||
if conversation_id is None:
|
||||
return None
|
||||
|
||||
# get conversation
|
||||
conversation = db.session.query(Conversation).filter(
|
||||
Conversation.app_id == self.app_id,
|
||||
Conversation.id == conversation_id
|
||||
).first()
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
)
|
||||
|
||||
return memory
|
||||
|
||||
def _fetch_prompt_messages(self, node_data: LLMNodeData,
                           query: Optional[str],
                           inputs: dict[str, str],
                           files: list[FileVar],
                           context: Optional[str],
                           memory: Optional[TokenBufferMemory],
                           model_config: ModelConfigWithCredentialsEntity) \
        -> tuple[list[PromptMessage], Optional[list[str]]]:
    """
    Render the node's prompt template into prompt messages.

    :param node_data: node data providing the prompt template and memory config
    :param query: user query; a falsy value is passed on as an empty string
    :param inputs: resolved template variables
    :param files: files to attach
    :param context: context text, if any
    :param memory: conversation memory, if any
    :param model_config: model config; also supplies the stop sequences
    :return: tuple of (prompt messages, stop sequences from the model config)
    """
    transform = AdvancedPromptTransform(with_variable_tmpl=True)
    rendered_messages = transform.get_prompt(
        prompt_template=node_data.prompt_template,
        inputs=inputs,
        query=query or '',
        files=files,
        context=context,
        memory_config=node_data.memory,
        memory=memory,
        model_config=model_config
    )

    return rendered_messages, model_config.stop
|
||||
|
||||
@classmethod
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
    """
    Add the cost of one LLM call to the tenant's used system quota.

    Nothing is deducted when the provider is not used as a SYSTEM provider, when
    the current quota configuration has a limit of -1, or when no quota unit
    is found for the current quota type.

    :param tenant_id: tenant id
    :param model_instance: model instance that was invoked
    :param usage: usage of the call; ``total_tokens`` is charged for token quotas
    :return:
    """
    provider_configuration = model_instance.provider_model_bundle.configuration
    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
        return

    system_configuration = provider_configuration.system_configuration
    current_quota_type = system_configuration.current_quota_type

    # first quota configuration matching the quota type currently in use
    active_quota = next(
        (config for config in system_configuration.quota_configurations
         if config.quota_type == current_quota_type),
        None
    )

    quota_unit = None
    if active_quota is not None:
        if active_quota.quota_limit == -1:
            return
        quota_unit = active_quota.quota_unit

    used_quota = None
    if quota_unit:
        if quota_unit == QuotaUnit.TOKENS:
            used_quota = usage.total_tokens
        elif quota_unit == QuotaUnit.CREDITS:
            # credit quotas charge 20 for models whose name contains 'gpt-4', else 1
            used_quota = 20 if 'gpt-4' in model_instance.model else 1
        else:
            used_quota = 1

    if used_quota is None:
        return

    # the filter only matches rows that still have quota left (limit > used)
    db.session.query(Provider).filter(
        Provider.tenant_id == tenant_id,
        Provider.provider_name == model_instance.provider,
        Provider.provider_type == ProviderType.SYSTEM.value,
        Provider.quota_type == system_configuration.current_quota_type.value,
        Provider.quota_limit > Provider.quota_used
    ).update({'quota_used': Provider.quota_used + used_quota})
    db.session.commit()
|
||||
|
||||
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
    """
    Map every variable used by the node to the selector it is read from.

    Covers the variables referenced in the prompt template, plus ``#context#``
    when context is enabled and ``#files#`` when vision is enabled.

    :param node_data: node data
    :return: mapping of variable name to value selector
    """
    llm_node_data = cast(cls._node_data_cls, node_data)
    template = llm_node_data.prompt_template

    # a list template contributes one text per message, otherwise a single text
    if isinstance(template, list):
        texts = [message.text for message in template]
    else:
        texts = [template.text]

    selectors = []
    for text in texts:
        selectors.extend(VariableTemplateParser(template=text).extract_variable_selectors())

    mapping = {selector.variable: selector.value_selector for selector in selectors}

    if llm_node_data.context.enabled:
        mapping['#context#'] = llm_node_data.context.variable_selector

    if llm_node_data.vision.enabled:
        mapping['#files#'] = ['sys', SystemVariable.FILES.value]

    return mapping
|
||||
|
||||
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
    """
    Return the default configuration of the LLM node.

    :param filters: filter by node config parameters (not used by this node)
    :return: dict with the node type and the default chat / completion prompt templates
    """
    chat_model_template = {
        "prompts": [
            {
                "role": "system",
                "text": "You are a helpful AI assistant."
            }
        ]
    }

    completion_prompt_text = (
        "Here is the chat histories between human and assistant, inside "
        "<histories></histories> XML tags.\n\n<histories>\n{{"
        "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
    )
    completion_model_template = {
        "conversation_histories_role": {
            "user_prefix": "Human",
            "assistant_prefix": "Assistant"
        },
        "prompt": {
            "text": completion_prompt_text
        },
        "stop": ["Human:"]
    }

    return {
        "type": "llm",
        "config": {
            "prompt_templates": {
                "chat_model": chat_model_template,
                "completion_model": completion_model_template
            }
        }
    }
|
||||
36
api/core/workflow/nodes/question_classifier/entities.py
Normal file
36
api/core/workflow/nodes/question_classifier/entities.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
    """
    Model settings of a question classifier node.
    """
    # name of the model provider
    provider: str
    # name of the model
    name: str
    # model mode as a plain string
    mode: str
    # parameters forwarded to the model invocation
    completion_params: dict[str, Any] = {}
|
||||
|
||||
|
||||
class ClassConfig(BaseModel):
    """
    One class (category) a question can be assigned to.
    """
    # identifier of the class
    id: str
    # display name of the class
    name: str
|
||||
|
||||
|
||||
class QuestionClassifierNodeData(BaseNodeData):
    """
    Question Classifier Node Data.
    """
    # selector of the variable that holds the text to classify
    query_variable_selector: list[str]
    type: str = 'question-classifier'
    # model used for the classification
    model: ModelConfig
    # classes the query can be assigned to
    classes: list[ClassConfig]
    # optional extra classification instructions; explicitly defaulted so the
    # field may be omitted from the node config
    instruction: Optional[str] = None
    # optional conversation memory config; explicitly defaulted for the same reason
    memory: Optional[MemoryConfig] = None
|
||||
@@ -0,0 +1,253 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from core.workflow.nodes.question_classifier.template_prompts import (
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
|
||||
QUESTION_CLASSIFIER_COMPLETION_PROMPT,
|
||||
QUESTION_CLASSIFIER_SYSTEM_PROMPT,
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_1,
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_2,
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_3,
|
||||
)
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class QuestionClassifierNode(LLMNode):
    """
    Workflow node that asks an LLM to assign one of the configured classes to a
    query and reports the id of the chosen class as the edge source handle.
    """
    _node_data_cls = QuestionClassifierNodeData
    node_type = NodeType.QUESTION_CLASSIFIER

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Classify the query variable and return the chosen class.

        :param variable_pool: variable pool
        :return: node run result; ``outputs['class_name']`` is the chosen class name and
            ``edge_source_handle`` the id of the matching configured class (``None``
            when the name matches no configured class)
        """
        node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)

        # extract variables
        query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
        variables = {
            'query': query
        }
        # fetch model config
        model_instance, model_config = self._fetch_model_config(node_data.model)
        # fetch memory
        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
        # fetch prompt messages
        prompt_messages, stop = self._fetch_prompt(
            node_data=node_data,
            context='',
            query=query,
            memory=memory,
            model_config=model_config
        )

        # handle invoke result
        result_text, usage = self._invoke_llm(
            node_data_model=node_data.model,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            stop=stop
        )

        categories = self._parse_categories(result_text)
        if categories is None:
            # unparseable reply: fall back to the configured class names,
            # which makes the first configured class the result
            logging.error(f"Failed to parse result text: {result_text}")
            categories = [class_.name for class_ in node_data.classes]

        # an empty list (model returned no category) yields '' instead of an IndexError
        category = categories[0] if categories else ''

        try:
            process_data = {
                'model_mode': model_config.mode,
                'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                    model_mode=model_config.mode,
                    prompt_messages=prompt_messages
                ),
                'usage': jsonable_encoder(usage),
                'topics': category
            }
            outputs = {
                'class_name': category
            }
            classes_map = {class_.name: class_.id for class_ in node_data.classes}

            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=variables,
                process_data=process_data,
                outputs=outputs,
                edge_source_handle=classes_map.get(category)
            )

        except ValueError as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e)
            )

    @staticmethod
    def _parse_categories(result_text: str) -> Optional[list[str]]:
        """
        Extract the ``categories`` list from the raw model reply.

        The JSON object is located by its outermost braces, so code fences of any
        casing (or other text around the object) do not break parsing.

        :param result_text: raw text returned by the model
        :return: the string entries of ``categories`` (a single string is wrapped in a
            list, a missing key gives an empty list), or ``None`` when the reply
            contains no parseable JSON object
        """
        start = result_text.find('{')
        end = result_text.rfind('}')
        if start == -1 or end < start:
            return None

        try:
            parsed = json.loads(result_text[start:end + 1])
        except ValueError:
            # json.JSONDecodeError is a ValueError subclass
            return None

        if not isinstance(parsed, dict):
            return None

        categories = parsed.get('categories', [])
        if isinstance(categories, str):
            return [categories]
        if isinstance(categories, list):
            return [category for category in categories if isinstance(category, str)]
        return []

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping

        :param node_data: node data
        :return: mapping with the single entry ``query`` -> query variable selector
        """
        node_data = cast(cls._node_data_cls, node_data)
        return {'query': node_data.query_variable_selector}

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.

        :param filters: filter by node config parameters.
        :return:
        """
        return {
            "type": "question-classifier",
            "config": {
                "instructions": ""
            }
        }

    def _fetch_prompt(self, node_data: QuestionClassifierNodeData,
                      query: str,
                      context: Optional[str],
                      memory: Optional[TokenBufferMemory],
                      model_config: ModelConfigWithCredentialsEntity) \
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Fetch prompt

        :param node_data: node data
        :param query: text to classify
        :param context: context
        :param memory: memory; its history is rendered into the template text
        :param model_config: model config
        :return: tuple of (prompt messages, stop sequences from the model config)
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        rest_token = self._calculate_rest_token(node_data, query, model_config, context)
        prompt_template = self._get_prompt_template(node_data, query, memory, rest_token)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs={},
            query='',
            files=[],
            context=context,
            memory_config=node_data.memory,
            # history is already embedded in the template, so no memory is passed here
            memory=None,
            model_config=model_config
        )
        stop = model_config.stop

        return prompt_messages, stop

    def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str,
                              model_config: ModelConfigWithCredentialsEntity,
                              context: Optional[str]) -> int:
        """
        Compute how many tokens are left for the conversation history.

        :param node_data: node data
        :param query: text to classify
        :param model_config: model config
        :param context: context
        :return: context size minus ``max_tokens`` minus the tokens of the history-free
            prompt (never below 0), or 2000 when the model schema has no context size
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        # render the prompt without history to measure its own token cost
        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs={},
            query='',
            files=[],
            context=context,
            memory_config=node_data.memory,
            memory=None,
            model_config=model_config
        )
        rest_tokens = 2000

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_type_instance = model_config.provider_model_bundle.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)

            curr_message_tokens = model_type_instance.get_num_tokens(
                model_config.model,
                model_config.credentials,
                prompt_messages
            )

            max_tokens = 0
            for parameter_rule in model_config.model_schema.parameter_rules:
                if (parameter_rule.name == 'max_tokens'
                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                    max_tokens = (model_config.parameters.get(parameter_rule.name)
                                  or model_config.parameters.get(parameter_rule.use_template)) or 0

            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)

        return rest_tokens

    def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str,
                             memory: Optional[TokenBufferMemory],
                             max_token_limit: int = 2000) \
            -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
        """
        Build the classification prompt template for the node's model mode.

        :param node_data: node data (classes, instruction, model mode, memory config)
        :param query: text to classify
        :param memory: memory whose history text is embedded in the prompt, if any
        :param max_token_limit: token budget for the history text
        :return: list of chat messages for chat mode, a completion template for
            completion mode
        :raises ValueError: if the model mode is neither chat nor completion
        """
        model_mode = ModelMode.value_of(node_data.model.mode)
        class_names_str = ','.join(class_.name for class_ in node_data.classes)
        instruction = node_data.instruction if node_data.instruction else ''
        input_text = query
        memory_str = ''
        if memory:
            memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
                                                        message_limit=node_data.memory.window.size)

        if model_mode == ModelMode.CHAT:
            # system prompt, two fixed few-shot exchanges, then the actual request
            return [
                ChatModelMessage(
                    role=PromptMessageRole.SYSTEM,
                    text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
                ),
                ChatModelMessage(
                    role=PromptMessageRole.USER,
                    text=QUESTION_CLASSIFIER_USER_PROMPT_1
                ),
                ChatModelMessage(
                    role=PromptMessageRole.ASSISTANT,
                    text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
                ),
                ChatModelMessage(
                    role=PromptMessageRole.USER,
                    text=QUESTION_CLASSIFIER_USER_PROMPT_2
                ),
                ChatModelMessage(
                    role=PromptMessageRole.ASSISTANT,
                    text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
                ),
                ChatModelMessage(
                    role=PromptMessageRole.USER,
                    text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text,
                                                                  categories=class_names_str,
                                                                  classification_instructions=instruction)
                ),
            ]
        elif model_mode == ModelMode.COMPLETION:
            return CompletionModelPromptTemplate(
                text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str,
                                                                  input_text=input_text,
                                                                  categories=class_names_str,
                                                                  classification_instructions=instruction)
            )

        else:
            raise ValueError(f"Model mode {model_mode} not support.")
|
||||
@@ -0,0 +1,72 @@
|
||||
|
||||
|
||||
# System message of the chat-mode classifier prompt. Rendered with
# str.format(histories=...), so any literal brace added here must be doubled.
QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
### Job Description
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
### Task
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
### Format
The input text is in the variable text_field. Categories are specified as a comma-separated list in the variable categories or left empty for automatic determination. Classification instructions may be included to improve the classification accuracy.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
"""

# Few-shot exchange 1 (user turn). Used verbatim, NOT passed through str.format,
# hence the single braces.
QUESTION_CLASSIFIER_USER_PROMPT_1 = """
{ "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],
"categories": ["Customer Service, Satisfaction, Sales, Product"],
"classification_instructions": ["classify the text based on the feedback provided by customer"]}```JSON
"""

# Few-shot exchange 1 (assistant turn). Used verbatim.
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """
{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],
"categories": ["Customer Service"]}```
"""

# Few-shot exchange 2 (user turn). Used verbatim.
QUESTION_CLASSIFIER_USER_PROMPT_2 = """
{"input_text": ["bad service, slow to bring the food"],
"categories": ["Food Quality, Experience, Price" ],
"classification_instructions": []}```JSON
"""

# Few-shot exchange 2 (assistant turn). Used verbatim; must stay valid JSON
# because it demonstrates the exact reply format to the model.
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """
{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],
"categories": ["Experience"]}```
"""

# The actual request (user turn). Rendered with str.format(input_text=...,
# categories=..., classification_instructions=...); literal braces are doubled.
QUESTION_CLASSIFIER_USER_PROMPT_3 = """
{{"input_text": ["{input_text}"],
"categories": ["{categories}" ],
"classification_instructions": ["{classification_instructions}"]}}```JSON
"""

# Single-prompt variant for completion-mode models. Rendered with
# str.format(histories=..., input_text=..., categories=...,
# classification_instructions=...); literal braces are doubled.
QUESTION_CLASSIFIER_COMPLETION_PROMPT = """
### Job Description
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
### Task
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
### Format
The input text is in the variable text_field. Categories are specified as a comma-separated list in the variable categories or left empty for automatic determination. Classification instructions may be included to improve the classification accuracy.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],"categories": ["Customer Service, Satisfaction, Sales, Product"], "classification_instructions": ["classify the text based on the feedback provided by customer"]}}
Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"categories": ["Customer Service"]}}
User:{{"input_text": ["bad service, slow to bring the food"],"categories": ["Food Quality, Experience, Price" ], "classification_instructions": []}}
Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"categories": ["Experience"]}}
</example>
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
### User Input
{{"input_text" : ["{input_text}"], "categories" : ["{categories}"],"classification_instructions" : ["{classification_instructions}"]}}
### Assistant Output
"""
|
||||
0
api/core/workflow/nodes/start/__init__.py
Normal file
0
api/core/workflow/nodes/start/__init__.py
Normal file
9
api/core/workflow/nodes/start/entities.py
Normal file
9
api/core/workflow/nodes/start/entities.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class StartNodeData(BaseNodeData):
    """
    Node data of the start node.
    """
    # input form variables declared on the start node (none by default)
    variables: list[VariableEntity] = []
|
||||
84
api/core/workflow/nodes/start/start_node.py
Normal file
84
api/core/workflow/nodes/start/start_node.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from typing import cast
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class StartNode(BaseNode):
    """
    Start node: validates the user inputs against the declared variables and
    publishes them, together with the system variables, as its result.
    """
    _node_data_cls = StartNodeData
    node_type = NodeType.START

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node

        :param variable_pool: variable pool; its ``user_inputs`` and ``system_variables`` are read
        :return: succeeded result whose inputs and outputs are the same cleaned mapping
        """
        node_data = cast(self._node_data_cls, self.node_data)

        # Validate / default the user inputs against the declared variables.
        payload = self._get_cleaned_inputs(node_data.variables, variable_pool.user_inputs)

        # Expose every system variable except CONVERSATION under a "sys." prefix.
        for key in variable_pool.system_variables:
            if key != SystemVariable.CONVERSATION:
                payload['sys.' + key.value] = variable_pool.system_variables[key]

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=payload,
            outputs=payload
        )

    def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict):
        """
        Validate the user inputs against the variable configs.

        Missing or falsy values fall back to the configured default (or an
        empty string); present values must be strings, must be one of the
        options for SELECT variables, and otherwise must not be longer than
        ``max_length`` when one is configured. NUL characters are stripped.

        :param variables: variable configs declared on the node
        :param user_inputs: raw user inputs, ``None`` is treated as empty
        :return: dict of variable name to cleaned value
        :raises ValueError: when a required value is missing or a value is invalid
        """
        submitted = {} if user_inputs is None else user_inputs
        cleaned = {}

        for variable_config in variables:
            variable = variable_config.variable
            value = submitted[variable] if variable in submitted else None

            # Absent and falsy values are handled the same way.
            if not value:
                if variable_config.required:
                    raise ValueError(f"Input form variable {variable} is required")
                cleaned[variable] = "" if variable_config.default is None else variable_config.default
                continue

            if not isinstance(value, str):
                raise ValueError(f"{variable} in input form must be a string")

            if variable_config.type == VariableEntity.Type.SELECT:
                options = [] if variable_config.options is None else variable_config.options
                if value not in options:
                    raise ValueError(f"{variable} in input form must be one of the following: {options}")
            elif variable_config.max_length is not None:
                max_length = variable_config.max_length
                if len(value) > max_length:
                    raise ValueError(f'{variable} in input form must be less than {max_length} characters')

            cleaned[variable] = value.replace('\x00', '')

        return cleaned

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping

        :param node_data: node data
        :return: always an empty mapping for this node
        """
        return {}
|
||||
12
api/core/workflow/nodes/template_transform/entities.py
Normal file
12
api/core/workflow/nodes/template_transform/entities.py
Normal file
@@ -0,0 +1,12 @@
|
||||
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
|
||||
class TemplateTransformNodeData(BaseNodeData):
    """
    Template Transform Node Data.
    """
    # Selectors of the variables handed to the template.
    variables: list[VariableSelector]
    # Template source text.
    template: str
|
||||
@@ -0,0 +1,91 @@
|
||||
import os
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
# Maximum allowed length of a template transform result; read once at import
# time from the TEMPLATE_TRANSFORM_MAX_LENGTH environment variable (default 80000).
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000'))
|
||||
|
||||
class TemplateTransformNode(BaseNode):
    """
    Node that renders a template from variables taken out of the variable pool.
    """
    _node_data_cls = TemplateTransformNodeData
    _node_type = NodeType.TEMPLATE_TRANSFORM

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.

        :param filters: filter by node config parameters (not used here).
        :return: default config with a single ``arg1`` variable echoed by the template
        """
        default_variable = {
            "variable": "arg1",
            "value_selector": []
        }
        return {
            "type": "template-transform",
            "config": {
                "variables": [default_variable],
                "template": "{{ arg1 }}"
            }
        }

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node

        :param variable_pool: variable pool used to resolve the template variables
        :return: FAILED result on execution error or oversized output,
            otherwise SUCCEEDED with the rendered text under ``output``
        """
        node_data: TemplateTransformNodeData = cast(self._node_data_cls, self.node_data)

        # Resolve every declared variable from the pool.
        variables = {
            selector.variable: variable_pool.get_variable_value(
                variable_selector=selector.value_selector
            )
            for selector in node_data.variables
        }

        # Render the template through the code executor.
        try:
            result = CodeExecutor.execute_code(
                language='jinja2',
                code=node_data.template,
                inputs=variables
            )
        except CodeExecutionException as e:
            return NodeRunResult(
                inputs=variables,
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e)
            )

        rendered = result['result']
        if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
            return NodeRunResult(
                inputs=variables,
                status=WorkflowNodeExecutionStatus.FAILED,
                error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters"
            )

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=variables,
            outputs={
                'output': rendered
            }
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping

        :param node_data: node data
        :return: mapping of variable name to its value selector
        """
        mapping = {}
        for selector in node_data.variables:
            mapping[selector.variable] = selector.value_selector
        return mapping
|
||||
0
api/core/workflow/nodes/tool/__init__.py
Normal file
0
api/core/workflow/nodes/tool/__init__.py
Normal file
37
api/core/workflow/nodes/tool/entities.py
Normal file
37
api/core/workflow/nodes/tool/entities.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
# Scalar types a tool parameter / tool configuration value may take.
ToolParameterValue = Union[str, int, float, bool]
|
||||
|
||||
class ToolEntity(BaseModel):
    """
    Identifies a tool by provider and name and carries its configuration values.
    """
    provider_id: str
    provider_type: Literal['builtin', 'api']
    provider_name: str  # redundancy
    tool_name: str
    tool_label: str  # redundancy
    # Configuration values keyed by name.
    tool_configurations: dict[str, ToolParameterValue]
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
    """
    Tool Node Schema
    """

    class ToolInput(BaseModel):
        """
        A single tool parameter input.

        ``type`` decides which kind of ``value`` is accepted:
        ``mixed`` requires a string, ``variable`` requires a list and
        ``constant`` requires a string, int, float or bool.
        """
        value: Union[ToolParameterValue, list[str]]
        type: Literal['mixed', 'variable', 'constant']

        @validator('type', pre=True, always=True)
        def check_type(cls, value, values):
            """
            Check that the already-validated ``value`` field matches ``type``.

            :param value: the raw ``type`` being validated
            :param values: previously validated fields; ``value`` is read from it
            :return: the ``type`` unchanged
            :raises ValueError: when ``value`` does not fit the given ``type``
            """
            typ = value
            value = values.get('value')
            if typ == 'mixed' and not isinstance(value, str):
                raise ValueError('value must be a string')
            elif typ == 'variable' and not isinstance(value, list):
                raise ValueError('value must be a list')
            # A tuple of concrete types is used instead of the ToolParameterValue
            # alias: a typing.Union is only accepted by isinstance() from
            # Python 3.10 on and raises TypeError on earlier versions.
            elif typ == 'constant' and not isinstance(value, (str, int, float, bool)):
                raise ValueError('value must be a string, int, float, or bool')
            return typ

    # Tool parameter inputs keyed by parameter name.
    tool_parameters: dict[str, ToolInput]
|
||||
201
api/core/workflow/nodes/tool/tool_node.py
Normal file
201
api/core/workflow/nodes/tool/tool_node.py
Normal file
@@ -0,0 +1,201 @@
|
||||
from os import path
|
||||
from typing import cast
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
    """
    Tool Node: invokes a tool and exposes its text and file output.
    """
    _node_data_cls = ToolNodeData
    _node_type = NodeType.TOOL

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run the tool node

        Errors raised while getting the tool runtime or while invoking the
        tool are turned into a FAILED result instead of propagating.

        :param variable_pool: variable pool used to resolve the tool parameters
        :return: node run result; on success the outputs hold ``text`` and ``files``
        """
        node_data = cast(ToolNodeData, self.node_data)

        # Provider details are attached to every result as tool info metadata.
        tool_info = {
            'provider_type': node_data.provider_type,
            'provider_id': node_data.provider_id
        }
        metadata = {
            NodeRunMetadataKey.TOOL_INFO: tool_info
        }

        parameters = self._generate_parameters(variable_pool, node_data)

        # Step 1: get the tool runtime.
        try:
            tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data)
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters,
                metadata=metadata,
                error=f'Failed to get tool runtime: {str(e)}'
            )

        # Step 2: invoke the tool.
        try:
            messages = ToolEngine.workflow_invoke(
                tool=tool_runtime,
                tool_parameters=parameters,
                user_id=self.user_id,
                workflow_id=self.workflow_id,
                workflow_tool_callback=DifyWorkflowCallbackHandler()
            )
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters,
                metadata=metadata,
                error=f'Failed to invoke tool: {str(e)}',
            )

        # Step 3: convert the tool messages into text and files.
        plain_text, files = self._convert_tool_messages(messages)

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={
                'text': plain_text,
                'files': files
            },
            metadata=metadata,
            inputs=parameters
        )

    def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
        """
        Generate parameters

        ``mixed`` inputs are formatted as variable templates, ``variable``
        inputs are looked up in the pool and ``constant`` inputs are used as is.

        :param variable_pool: variable pool to resolve values from
        :param node_data: node data holding ``tool_parameters``
        :return: dict of parameter name to resolved value
        """
        resolved = {}
        for parameter_name, tool_input in node_data.tool_parameters.items():
            kind = tool_input.type
            if kind == 'mixed':
                resolved[parameter_name] = self._format_variable_template(tool_input.value, variable_pool)
            elif kind == 'variable':
                resolved[parameter_name] = variable_pool.get_variable_value(tool_input.value)
            elif kind == 'constant':
                resolved[parameter_name] = tool_input.value

        return resolved

    def _format_variable_template(self, template: str, variable_pool: VariablePool) -> str:
        """
        Format variable template

        :param template: template text containing variable selectors
        :param variable_pool: variable pool to resolve the selectors from
        :return: the formatted template
        """
        template_parser = VariableTemplateParser(template)
        inputs = {
            selector.variable: variable_pool.get_variable_value(selector.value_selector)
            for selector in template_parser.extract_variable_selectors()
        }

        return template_parser.format(inputs)

    def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]

        :param messages: messages returned by the tool invocation
        :return: tuple of plain text and extracted files
        """
        # Transform the messages first (this step also handles file storage).
        transformed = ToolFileMessageTransformer.transform_tool_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
        )

        # Files are extracted before the text, as in the original order.
        files = self._extract_tool_response_binary(transformed)
        plain_text = self._extract_tool_response_text(transformed)

        return plain_text, files

    def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
        """
        Extract tool response binary

        IMAGE_LINK / IMAGE and BLOB messages become FileVar entries; the tool
        file id is the last path segment of the message up to its first dot.

        :param tool_response: transformed tool messages
        :return: list of file variables
        """
        message_types = ToolInvokeMessage.MessageType
        image_types = (message_types.IMAGE_LINK, message_types.IMAGE)
        files = []

        for response in tool_response:
            if response.type in image_types:
                url = response.message
                ext = path.splitext(url)[1]
                mimetype = response.meta.get('mime_type', 'image/jpeg')
                last_segment = url.split('/')[-1]
                # Fall back to the last URL segment when no save_as name is given.
                filename = response.save_as or last_segment

                files.append(FileVar(
                    tenant_id=self.tenant_id,
                    type=FileType.IMAGE,
                    transfer_method=FileTransferMethod.TOOL_FILE,
                    related_id=last_segment.split('.')[0],
                    filename=filename,
                    extension=ext,
                    mime_type=mimetype,
                ))
            elif response.type == message_types.BLOB:
                tool_file_id = response.message.split('/')[-1].split('.')[0]
                # NOTE(review): blobs are also tagged FileType.IMAGE - confirm this is intended.
                files.append(FileVar(
                    tenant_id=self.tenant_id,
                    type=FileType.IMAGE,
                    transfer_method=FileTransferMethod.TOOL_FILE,
                    related_id=tool_file_id,
                    filename=response.save_as,
                    extension=path.splitext(response.save_as)[1],
                    mime_type=response.meta.get('mime_type', 'application/octet-stream'),
                ))
            # TODO: LINK messages are not converted into files yet.

        return files

    def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
        """
        Extract tool response text

        TEXT messages contribute their message, LINK messages contribute the
        message prefixed with ``Link: ``; each is followed by a newline.
        Other message types contribute nothing.

        :param tool_response: transformed tool messages
        :return: concatenated text
        """
        message_types = ToolInvokeMessage.MessageType
        chunks = []
        for message in tool_response:
            if message.type == message_types.TEXT:
                chunks.append(f'{message.message}\n')
            elif message.type == message_types.LINK:
                chunks.append(f'Link: {message.message}\n')

        return ''.join(chunks)

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping

        :param node_data: node data
        :return: mapping of variable name to value selector; ``constant`` inputs add nothing
        """
        mapping = {}
        for parameter_name, tool_input in node_data.tool_parameters.items():
            if tool_input.type == 'mixed':
                # Every selector referenced in the template is mapped by its own name.
                for selector in VariableTemplateParser(tool_input.value).extract_variable_selectors():
                    mapping[selector.variable] = selector.value_selector
            elif tool_input.type == 'variable':
                mapping[parameter_name] = tool_input.value

        return mapping
|
||||
12
api/core/workflow/nodes/variable_assigner/entities.py
Normal file
12
api/core/workflow/nodes/variable_assigner/entities.py
Normal file
@@ -0,0 +1,12 @@
|
||||
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class VariableAssignerNodeData(BaseNodeData):
    """
    Variable Assigner Node Data.
    """
    type: str = 'variable-assigner'
    output_type: str
    # Candidate variable selectors, each a list of strings.
    variables: list[list[str]]
|
||||
@@ -0,0 +1,41 @@
|
||||
from typing import cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.variable_assigner.entities import VariableAssignerNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode):
    """
    Node that outputs the first candidate variable whose value is not None.
    """
    _node_data_cls = VariableAssignerNodeData
    _node_type = NodeType.VARIABLE_ASSIGNER

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node

        :param variable_pool: variable pool the candidate selectors are resolved from
        :return: succeeded result; inputs and outputs stay empty when every candidate is None
        """
        node_data: VariableAssignerNodeData = cast(self._node_data_cls, self.node_data)

        outputs, inputs = {}, {}
        for selector in node_data.variables:
            value = variable_pool.get_variable_value(selector)
            if value is None:
                continue

            # First hit wins; the input key is the selector without its first element.
            outputs["output"] = value
            inputs['.'.join(selector[1:])] = value
            break

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs=outputs,
            inputs=inputs
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping

        :param node_data: node data
        :return: always an empty mapping for this node
        """
        return {}
|
||||
Reference in New Issue
Block a user