Fix/workflow tool incorrect parameter configurations (#3402)

Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
Yeuoly
2024-04-12 15:46:34 +08:00
committed by GitHub
parent f7f8ef257c
commit 64e395d6cf
3 changed files with 72 additions and 13 deletions

View File

@@ -243,8 +243,21 @@ class Tool(BaseModel, ABC):
tool_parameters[parameter.name] = float(tool_parameters[parameter.name])
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter.name], bool):
tool_parameters[parameter.name] = bool(tool_parameters[parameter.name])
# check if it is a string
if isinstance(tool_parameters[parameter.name], str):
# check true false
if tool_parameters[parameter.name].lower() in ['true', 'false']:
tool_parameters[parameter.name] = tool_parameters[parameter.name].lower() == 'true'
# check 1 0
elif tool_parameters[parameter.name] in ['1', '0']:
tool_parameters[parameter.name] = tool_parameters[parameter.name] == '1'
else:
tool_parameters[parameter.name] = bool(tool_parameters[parameter.name])
elif isinstance(tool_parameters[parameter.name], int | float):
tool_parameters[parameter.name] = tool_parameters[parameter.name] != 0
else:
tool_parameters[parameter.name] = bool(tool_parameters[parameter.name])
return tool_parameters
@abstractmethod

View File

@@ -1,10 +1,9 @@
from typing import Literal, Union
from typing import Any, Literal, Union
from pydantic import BaseModel, validator
from core.workflow.entities.base_node_data_entities import BaseNodeData
ToolParameterValue = Union[str, int, float, bool]
class ToolEntity(BaseModel):
provider_id: str
@@ -12,11 +11,23 @@ class ToolEntity(BaseModel):
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
tool_configurations: dict[str, ToolParameterValue]
tool_configurations: dict[str, Any]
@validator('tool_configurations', pre=True, always=True)
def validate_tool_configurations(cls, value, values):
if not isinstance(value, dict):
raise ValueError('tool_configurations must be a dictionary')
for key in values.get('tool_configurations', {}).keys():
value = values.get('tool_configurations', {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f'{key} must be a string')
return value
class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(BaseModel):
value: Union[ToolParameterValue, list[str]]
value: Union[Any, list[str]]
type: Literal['mixed', 'variable', 'constant']
@validator('type', pre=True, always=True)
@@ -25,12 +36,16 @@ class ToolNodeData(BaseNodeData, ToolEntity):
value = values.get('value')
if typ == 'mixed' and not isinstance(value, str):
raise ValueError('value must be a string')
elif typ == 'variable' and not isinstance(value, list):
raise ValueError('value must be a list')
elif typ == 'constant' and not isinstance(value, ToolParameterValue):
elif typ == 'variable':
if not isinstance(value, list):
raise ValueError('value must be a list')
for val in value:
if not isinstance(val, str):
raise ValueError('value must be a list of strings')
elif typ == 'constant' and not isinstance(value, str | int | float | bool):
raise ValueError('value must be a string, int, float, or bool')
return typ
"""
Tool Node Schema
"""