mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 10:56:52 +08:00
improve: modernizing validation by migrating pydantic from 1.x to 2.x (#4592)
This commit is contained in:
@@ -43,9 +43,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
self._init_react_state(query)
|
||||
|
||||
# check model mode
|
||||
if 'Observation' not in app_generate_entity.model_config.stop:
|
||||
if app_generate_entity.model_config.provider not in self._ignore_observation_providers:
|
||||
app_generate_entity.model_config.stop.append('Observation')
|
||||
if 'Observation' not in app_generate_entity.model_conf.stop:
|
||||
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
|
||||
app_generate_entity.model_conf.stop.append('Observation')
|
||||
|
||||
app_config = self.app_config
|
||||
|
||||
@@ -109,9 +109,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
# invoke model
|
||||
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_config.parameters,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
tools=[],
|
||||
stop=app_generate_entity.model_config.stop,
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
callbacks=[],
|
||||
@@ -141,8 +141,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
action = chunk
|
||||
# detect action
|
||||
scratchpad.agent_response += json.dumps(chunk.dict())
|
||||
scratchpad.action_str = json.dumps(chunk.dict())
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
scratchpad.action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.action = action
|
||||
else:
|
||||
scratchpad.agent_response += chunk
|
||||
|
||||
@@ -84,9 +84,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_config.parameters,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
tools=prompt_messages_tools,
|
||||
stop=app_generate_entity.model_config.stop,
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=self.stream_tool_call,
|
||||
user=self.user_id,
|
||||
callbacks=[],
|
||||
|
||||
@@ -58,7 +58,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
# Not Include: memory, external data, dataset context
|
||||
self.get_pre_calculate_rest_tokens(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
@@ -69,8 +69,8 @@ class AgentChatAppRunner(AppRunner):
|
||||
if application_generate_entity.conversation_id:
|
||||
# get memory of conversation (read-only)
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
|
||||
model=application_generate_entity.model_config.model
|
||||
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
|
||||
model=application_generate_entity.model_conf.model
|
||||
)
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
@@ -83,7 +83,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
# memory(optional)
|
||||
prompt_messages, _ = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
@@ -152,7 +152,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
# memory(optional), external data, dataset context(optional)
|
||||
prompt_messages, _ = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
@@ -182,12 +182,12 @@ class AgentChatAppRunner(AppRunner):
|
||||
|
||||
# init model instance
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
|
||||
model=application_generate_entity.model_config.model
|
||||
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
|
||||
model=application_generate_entity.model_conf.model
|
||||
)
|
||||
prompt_message, _ = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
@@ -225,7 +225,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
app_config=app_config,
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
config=agent_entity,
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
|
||||
@@ -100,7 +100,7 @@ class AppQueueManager:
|
||||
:param pub_from:
|
||||
:return:
|
||||
"""
|
||||
self._check_for_sqlalchemy_models(event.dict())
|
||||
self._check_for_sqlalchemy_models(event.model_dump())
|
||||
self._publish(event, pub_from)
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -218,7 +218,7 @@ class AppRunner:
|
||||
index = 0
|
||||
for token in text:
|
||||
chunk = LLMResultChunk(
|
||||
model=app_generate_entity.model_config.model,
|
||||
model=app_generate_entity.model_conf.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
@@ -237,7 +237,7 @@ class AppRunner:
|
||||
queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=app_generate_entity.model_config.model,
|
||||
model=app_generate_entity.model_conf.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=usage if usage else LLMUsage.empty_usage()
|
||||
|
||||
@@ -54,7 +54,7 @@ class ChatAppRunner(AppRunner):
|
||||
# Not Include: memory, external data, dataset context
|
||||
self.get_pre_calculate_rest_tokens(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
@@ -65,8 +65,8 @@ class ChatAppRunner(AppRunner):
|
||||
if application_generate_entity.conversation_id:
|
||||
# get memory of conversation (read-only)
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
|
||||
model=application_generate_entity.model_config.model
|
||||
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
|
||||
model=application_generate_entity.model_conf.model
|
||||
)
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
@@ -79,7 +79,7 @@ class ChatAppRunner(AppRunner):
|
||||
# memory(optional)
|
||||
prompt_messages, stop = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
@@ -159,7 +159,7 @@ class ChatAppRunner(AppRunner):
|
||||
app_id=app_record.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_record.tenant_id,
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
config=app_config.dataset,
|
||||
query=query,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
@@ -173,7 +173,7 @@ class ChatAppRunner(AppRunner):
|
||||
# memory(optional), external data, dataset context(optional)
|
||||
prompt_messages, stop = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
@@ -194,21 +194,21 @@ class ChatAppRunner(AppRunner):
|
||||
|
||||
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
self.recalc_llm_max_tokens(
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
# Invoke model
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
|
||||
model=application_generate_entity.model_config.model
|
||||
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
|
||||
model=application_generate_entity.model_conf.model
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=application_generate_entity.model_config.parameters,
|
||||
model_parameters=application_generate_entity.model_conf.parameters,
|
||||
stop=stop,
|
||||
stream=application_generate_entity.stream,
|
||||
user=application_generate_entity.user_id,
|
||||
|
||||
@@ -50,7 +50,7 @@ class CompletionAppRunner(AppRunner):
|
||||
# Not Include: memory, external data, dataset context
|
||||
self.get_pre_calculate_rest_tokens(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
@@ -61,7 +61,7 @@ class CompletionAppRunner(AppRunner):
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
prompt_messages, stop = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
@@ -119,7 +119,7 @@ class CompletionAppRunner(AppRunner):
|
||||
app_id=app_record.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_record.tenant_id,
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
config=dataset_config,
|
||||
query=query,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
@@ -132,7 +132,7 @@ class CompletionAppRunner(AppRunner):
|
||||
# memory(optional), external data, dataset context(optional)
|
||||
prompt_messages, stop = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
@@ -152,21 +152,21 @@ class CompletionAppRunner(AppRunner):
|
||||
|
||||
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
self.recalc_llm_max_tokens(
|
||||
model_config=application_generate_entity.model_config,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
# Invoke model
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
|
||||
model=application_generate_entity.model_config.model
|
||||
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
|
||||
model=application_generate_entity.model_conf.model
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=application_generate_entity.model_config.parameters,
|
||||
model_parameters=application_generate_entity.model_conf.parameters,
|
||||
stop=stop,
|
||||
stream=application_generate_entity.stream,
|
||||
user=application_generate_entity.user_id,
|
||||
|
||||
@@ -158,8 +158,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
model_id = None
|
||||
else:
|
||||
app_model_config_id = app_config.app_model_config_id
|
||||
model_provider = application_generate_entity.model_config.provider
|
||||
model_id = application_generate_entity.model_config.model
|
||||
model_provider = application_generate_entity.model_conf.provider
|
||||
model_id = application_generate_entity.model_conf.model
|
||||
override_model_configs = None
|
||||
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \
|
||||
and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
@@ -62,6 +62,9 @@ class ModelConfigWithCredentialsEntity(BaseModel):
|
||||
parameters: dict[str, Any] = {}
|
||||
stop: list[str] = []
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class AppGenerateEntity(BaseModel):
|
||||
"""
|
||||
@@ -93,10 +96,13 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
# app config
|
||||
app_config: EasyUIBasedAppConfig
|
||||
model_config: ModelConfigWithCredentialsEntity
|
||||
model_conf: ModelConfigWithCredentialsEntity
|
||||
|
||||
query: Optional[str] = None
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
"""
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
|
||||
|
||||
class QueueEvent(Enum):
|
||||
class QueueEvent(str, Enum):
|
||||
"""
|
||||
QueueEvent enum
|
||||
"""
|
||||
@@ -47,14 +47,14 @@ class QueueLLMChunkEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueLLMChunkEvent entity
|
||||
"""
|
||||
event = QueueEvent.LLM_CHUNK
|
||||
event: QueueEvent = QueueEvent.LLM_CHUNK
|
||||
chunk: LLMResultChunk
|
||||
|
||||
class QueueIterationStartEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueIterationStartEvent entity
|
||||
"""
|
||||
event = QueueEvent.ITERATION_START
|
||||
event: QueueEvent = QueueEvent.ITERATION_START
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
@@ -68,16 +68,17 @@ class QueueIterationNextEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueIterationNextEvent entity
|
||||
"""
|
||||
event = QueueEvent.ITERATION_NEXT
|
||||
event: QueueEvent = QueueEvent.ITERATION_NEXT
|
||||
|
||||
index: int
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
|
||||
node_run_index: int
|
||||
output: Optional[Any] # output for the current iteration
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
|
||||
@validator('output', pre=True, always=True)
|
||||
@classmethod
|
||||
@field_validator('output', mode='before')
|
||||
def set_output(cls, v):
|
||||
"""
|
||||
Set output
|
||||
@@ -92,7 +93,7 @@ class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueIterationCompletedEvent entity
|
||||
"""
|
||||
event = QueueEvent.ITERATION_COMPLETED
|
||||
event:QueueEvent = QueueEvent.ITERATION_COMPLETED
|
||||
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
@@ -104,7 +105,7 @@ class QueueTextChunkEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueTextChunkEvent entity
|
||||
"""
|
||||
event = QueueEvent.TEXT_CHUNK
|
||||
event: QueueEvent = QueueEvent.TEXT_CHUNK
|
||||
text: str
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
@@ -113,7 +114,7 @@ class QueueAgentMessageEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueMessageEvent entity
|
||||
"""
|
||||
event = QueueEvent.AGENT_MESSAGE
|
||||
event: QueueEvent = QueueEvent.AGENT_MESSAGE
|
||||
chunk: LLMResultChunk
|
||||
|
||||
|
||||
@@ -121,7 +122,7 @@ class QueueMessageReplaceEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueMessageReplaceEvent entity
|
||||
"""
|
||||
event = QueueEvent.MESSAGE_REPLACE
|
||||
event: QueueEvent = QueueEvent.MESSAGE_REPLACE
|
||||
text: str
|
||||
|
||||
|
||||
@@ -129,7 +130,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueRetrieverResourcesEvent entity
|
||||
"""
|
||||
event = QueueEvent.RETRIEVER_RESOURCES
|
||||
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
|
||||
retriever_resources: list[dict]
|
||||
|
||||
|
||||
@@ -137,7 +138,7 @@ class QueueAnnotationReplyEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueAnnotationReplyEvent entity
|
||||
"""
|
||||
event = QueueEvent.ANNOTATION_REPLY
|
||||
event: QueueEvent = QueueEvent.ANNOTATION_REPLY
|
||||
message_annotation_id: str
|
||||
|
||||
|
||||
@@ -145,7 +146,7 @@ class QueueMessageEndEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueMessageEndEvent entity
|
||||
"""
|
||||
event = QueueEvent.MESSAGE_END
|
||||
event: QueueEvent = QueueEvent.MESSAGE_END
|
||||
llm_result: Optional[LLMResult] = None
|
||||
|
||||
|
||||
@@ -153,28 +154,28 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueAdvancedChatMessageEndEvent entity
|
||||
"""
|
||||
event = QueueEvent.ADVANCED_CHAT_MESSAGE_END
|
||||
event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END
|
||||
|
||||
|
||||
class QueueWorkflowStartedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueWorkflowStartedEvent entity
|
||||
"""
|
||||
event = QueueEvent.WORKFLOW_STARTED
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
|
||||
|
||||
|
||||
class QueueWorkflowSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueWorkflowSucceededEvent entity
|
||||
"""
|
||||
event = QueueEvent.WORKFLOW_SUCCEEDED
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
|
||||
|
||||
|
||||
class QueueWorkflowFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueWorkflowFailedEvent entity
|
||||
"""
|
||||
event = QueueEvent.WORKFLOW_FAILED
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_FAILED
|
||||
error: str
|
||||
|
||||
|
||||
@@ -182,7 +183,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeStartedEvent entity
|
||||
"""
|
||||
event = QueueEvent.NODE_STARTED
|
||||
event: QueueEvent = QueueEvent.NODE_STARTED
|
||||
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
@@ -195,7 +196,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeSucceededEvent entity
|
||||
"""
|
||||
event = QueueEvent.NODE_SUCCEEDED
|
||||
event: QueueEvent = QueueEvent.NODE_SUCCEEDED
|
||||
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
@@ -213,7 +214,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeFailedEvent entity
|
||||
"""
|
||||
event = QueueEvent.NODE_FAILED
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
@@ -230,7 +231,7 @@ class QueueAgentThoughtEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueAgentThoughtEvent entity
|
||||
"""
|
||||
event = QueueEvent.AGENT_THOUGHT
|
||||
event: QueueEvent = QueueEvent.AGENT_THOUGHT
|
||||
agent_thought_id: str
|
||||
|
||||
|
||||
@@ -238,7 +239,7 @@ class QueueMessageFileEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueAgentThoughtEvent entity
|
||||
"""
|
||||
event = QueueEvent.MESSAGE_FILE
|
||||
event: QueueEvent = QueueEvent.MESSAGE_FILE
|
||||
message_file_id: str
|
||||
|
||||
|
||||
@@ -246,15 +247,15 @@ class QueueErrorEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueErrorEvent entity
|
||||
"""
|
||||
event = QueueEvent.ERROR
|
||||
error: Any
|
||||
event: QueueEvent = QueueEvent.ERROR
|
||||
error: Any = None
|
||||
|
||||
|
||||
class QueuePingEvent(AppQueueEvent):
|
||||
"""
|
||||
QueuePingEvent entity
|
||||
"""
|
||||
event = QueueEvent.PING
|
||||
event: QueueEvent = QueueEvent.PING
|
||||
|
||||
|
||||
class QueueStopEvent(AppQueueEvent):
|
||||
@@ -270,7 +271,7 @@ class QueueStopEvent(AppQueueEvent):
|
||||
OUTPUT_MODERATION = "output-moderation"
|
||||
INPUT_MODERATION = "input-moderation"
|
||||
|
||||
event = QueueEvent.STOP
|
||||
event: QueueEvent = QueueEvent.STOP
|
||||
stopped_by: StopBy
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@@ -118,9 +118,7 @@ class ErrorStreamResponse(StreamResponse):
|
||||
"""
|
||||
event: StreamEvent = StreamEvent.ERROR
|
||||
err: Exception
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class MessageStreamResponse(StreamResponse):
|
||||
@@ -360,7 +358,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
||||
title: str
|
||||
index: int
|
||||
created_at: int
|
||||
pre_iteration_output: Optional[Any]
|
||||
pre_iteration_output: Optional[Any] = None
|
||||
extras: dict = {}
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||
@@ -379,12 +377,12 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
outputs: Optional[dict]
|
||||
outputs: Optional[dict] = None
|
||||
created_at: int
|
||||
extras: dict = None
|
||||
inputs: dict = None
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: Optional[str]
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
finished_at: int
|
||||
|
||||
@@ -16,7 +16,7 @@ class HostingModerationFeature:
|
||||
:param prompt_messages: prompt messages
|
||||
:return:
|
||||
"""
|
||||
model_config = application_generate_entity.model_config
|
||||
model_config = application_generate_entity.model_conf
|
||||
|
||||
text = ""
|
||||
for prompt_message in prompt_messages:
|
||||
|
||||
@@ -85,7 +85,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
:param stream: stream
|
||||
"""
|
||||
super().__init__(application_generate_entity, queue_manager, user, stream)
|
||||
self._model_config = application_generate_entity.model_config
|
||||
self._model_config = application_generate_entity.model_conf
|
||||
self._conversation = conversation
|
||||
self._message = message
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ def print_text(
|
||||
class DifyAgentCallbackHandler(BaseModel):
|
||||
"""Callback Handler that prints to std out."""
|
||||
color: Optional[str] = ''
|
||||
current_loop = 1
|
||||
current_loop: int = 1
|
||||
|
||||
def __init__(self, color: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -17,7 +17,7 @@ class PromptMessageFileType(enum.Enum):
|
||||
|
||||
class PromptMessageFile(BaseModel):
|
||||
type: PromptMessageFileType
|
||||
data: Any
|
||||
data: Any = None
|
||||
|
||||
|
||||
class ImagePromptMessageFile(PromptMessageFile):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType, ProviderModel
|
||||
@@ -77,3 +77,6 @@ class DefaultModelEntity(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
provider: DefaultModelProviderEntity
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@@ -6,7 +6,7 @@ from collections.abc import Iterator
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
from core.entities.provider_entities import (
|
||||
@@ -54,6 +54,9 @@ class ProviderConfiguration(BaseModel):
|
||||
custom_configuration: CustomConfiguration
|
||||
model_settings: list[ModelSettings]
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -1019,7 +1022,6 @@ class ProviderModelBundle(BaseModel):
|
||||
provider_instance: ModelProvider
|
||||
model_type_instance: AIModel
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True,
|
||||
protected_namespaces=())
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import ProviderQuotaType
|
||||
@@ -27,6 +27,9 @@ class RestrictModel(BaseModel):
|
||||
base_model_name: Optional[str] = None
|
||||
model_type: ModelType
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class QuotaConfiguration(BaseModel):
|
||||
"""
|
||||
@@ -65,6 +68,9 @@ class CustomModelConfiguration(BaseModel):
|
||||
model_type: ModelType
|
||||
credentials: dict
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class CustomConfiguration(BaseModel):
|
||||
"""
|
||||
|
||||
@@ -16,7 +16,7 @@ class ExtensionModule(enum.Enum):
|
||||
|
||||
|
||||
class ModuleExtension(BaseModel):
|
||||
extension_class: Any
|
||||
extension_class: Any = None
|
||||
name: str
|
||||
label: Optional[dict] = None
|
||||
form_schema: Optional[list] = None
|
||||
|
||||
@@ -28,8 +28,8 @@ class CodeExecutionException(Exception):
|
||||
|
||||
class CodeExecutionResponse(BaseModel):
|
||||
class Data(BaseModel):
|
||||
stdout: Optional[str]
|
||||
error: Optional[str]
|
||||
stdout: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
code: int
|
||||
message: str
|
||||
@@ -88,7 +88,7 @@ class CodeExecutor:
|
||||
}
|
||||
|
||||
if dependencies:
|
||||
data['dependencies'] = [dependency.dict() for dependency in dependencies]
|
||||
data['dependencies'] = [dependency.model_dump() for dependency in dependencies]
|
||||
|
||||
try:
|
||||
response = post(str(url), json=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT)
|
||||
|
||||
@@ -25,7 +25,7 @@ class CodeNodeProvider(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def get_default_available_packages(cls) -> list[dict]:
|
||||
return [p.dict() for p in CodeExecutor.list_dependencies(cls.get_language())]
|
||||
return [p.model_dump() for p in CodeExecutor.list_dependencies(cls.get_language())]
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls) -> dict:
|
||||
|
||||
@@ -4,12 +4,10 @@ from abc import ABC, abstractmethod
|
||||
from base64 import b64encode
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.helper.code_executor.entities import CodeDependency
|
||||
|
||||
|
||||
class TemplateTransformer(ABC, BaseModel):
|
||||
class TemplateTransformer(ABC):
|
||||
_code_placeholder: str = '{{code}}'
|
||||
_inputs_placeholder: str = '{{inputs}}'
|
||||
_result_tag: str = '<<RESULT>>'
|
||||
|
||||
@@ -550,7 +550,7 @@ class IndexingRunner:
|
||||
document_qa_list = self.format_split_text(response)
|
||||
qa_documents = []
|
||||
for result in document_qa_list:
|
||||
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
|
||||
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.model_copy())
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(result['question'])
|
||||
qa_document.metadata['answer'] = result['answer']
|
||||
|
||||
@@ -2,7 +2,7 @@ from abc import ABC
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class PromptMessageRole(Enum):
|
||||
@@ -123,6 +123,13 @@ class AssistantPromptMessage(PromptMessage):
|
||||
type: str
|
||||
function: ToolCallFunction
|
||||
|
||||
@field_validator('id', mode='before')
|
||||
def transform_id_to_str(cls, value) -> str:
|
||||
if not isinstance(value, str):
|
||||
return str(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.ASSISTANT
|
||||
tool_calls: list[ToolCall] = []
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
|
||||
@@ -148,9 +148,7 @@ class ProviderModel(BaseModel):
|
||||
fetch_from: FetchFrom
|
||||
model_properties: dict[ModelPropertyKey, Any]
|
||||
deprecated: bool = False
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ParameterRule(BaseModel):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel
|
||||
@@ -122,8 +122,8 @@ class ProviderEntity(BaseModel):
|
||||
provider_credential_schema: Optional[ProviderCredentialSchema] = None
|
||||
model_credential_schema: Optional[ModelCredentialSchema] = None
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def to_simple_provider(self) -> SimpleProviderEntity:
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,8 @@ import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.helper.position_helper import get_position_map, sort_by_position_map
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
||||
@@ -28,6 +30,9 @@ class AIModel(ABC):
|
||||
model_schemas: list[AIModelEntity] = None
|
||||
started_at: float = 0
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@abstractmethod
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
|
||||
@@ -6,6 +6,8 @@ from abc import abstractmethod
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.callbacks.logging_callback import LoggingCallback
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
@@ -34,6 +36,9 @@ class LargeLanguageModel(AIModel):
|
||||
"""
|
||||
model_type: ModelType = ModelType.LLM
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
|
||||
@@ -2,6 +2,8 @@ import time
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
@@ -12,6 +14,9 @@ class ModerationModel(AIModel):
|
||||
"""
|
||||
model_type: ModelType = ModelType.MODERATION
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
text: str, user: Optional[str] = None) \
|
||||
-> bool:
|
||||
|
||||
@@ -2,6 +2,8 @@ import os
|
||||
from abc import abstractmethod
|
||||
from typing import IO, Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
@@ -12,6 +14,9 @@ class Speech2TextModel(AIModel):
|
||||
"""
|
||||
model_type: ModelType = ModelType.SPEECH2TEXT
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from abc import abstractmethod
|
||||
from typing import IO, Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
@@ -11,6 +13,9 @@ class Text2ImageModel(AIModel):
|
||||
"""
|
||||
model_type: ModelType = ModelType.TEXT2IMG
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict, prompt: str,
|
||||
model_parameters: dict, user: Optional[str] = None) \
|
||||
-> list[IO[bytes]]:
|
||||
|
||||
@@ -2,6 +2,8 @@ import time
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
@@ -13,6 +15,9 @@ class TextEmbeddingModel(AIModel):
|
||||
"""
|
||||
model_type: ModelType = ModelType.TEXT_EMBEDDING
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
|
||||
@@ -4,6 +4,8 @@ import uuid
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
@@ -15,6 +17,9 @@ class TTSModel(AIModel):
|
||||
"""
|
||||
model_type: ModelType = ModelType.TTS
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
|
||||
user: Optional[str] = None):
|
||||
"""
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
|
||||
@@ -19,11 +19,7 @@ class ModelProviderExtension(BaseModel):
|
||||
provider_instance: ModelProvider
|
||||
name: str
|
||||
position: Optional[int] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ModelProviderFactory:
|
||||
|
||||
@@ -12,9 +12,9 @@ from typing import Any, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.color import Color
|
||||
from pydantic.networks import AnyUrl, NameEmail
|
||||
from pydantic.types import SecretBytes, SecretStr
|
||||
from pydantic_extra_types.color import Color
|
||||
|
||||
from ._compat import PYDANTIC_V2, Url, _model_dump
|
||||
|
||||
|
||||
@@ -6,4 +6,4 @@ def dump_model(model: BaseModel) -> dict:
|
||||
if hasattr(pydantic, 'model_dump'):
|
||||
return pydantic.model_dump(model)
|
||||
else:
|
||||
return model.dict()
|
||||
return model.model_dump()
|
||||
|
||||
@@ -51,7 +51,7 @@ class ApiModeration(Moderation):
|
||||
query=query
|
||||
)
|
||||
|
||||
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.dict())
|
||||
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
|
||||
return ModerationInputsResult(**result)
|
||||
|
||||
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
|
||||
@@ -66,7 +66,7 @@ class ApiModeration(Moderation):
|
||||
text=text
|
||||
)
|
||||
|
||||
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.dict())
|
||||
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
|
||||
return ModerationOutputsResult(**result)
|
||||
|
||||
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
|
||||
|
||||
@@ -4,7 +4,7 @@ import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import QueueMessageReplaceEvent
|
||||
@@ -33,9 +33,7 @@ class OutputModeration(BaseModel):
|
||||
buffer: str = ''
|
||||
is_final_chunk: bool = False
|
||||
final_output: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def should_direct_output(self):
|
||||
return self.final_output is not None
|
||||
|
||||
@@ -11,7 +11,7 @@ class ChatModelMessage(BaseModel):
|
||||
"""
|
||||
text: str
|
||||
role: PromptMessageRole
|
||||
edition_type: Optional[Literal['basic', 'jinja2']]
|
||||
edition_type: Optional[Literal['basic', 'jinja2']] = None
|
||||
|
||||
|
||||
class CompletionModelPromptTemplate(BaseModel):
|
||||
@@ -19,7 +19,7 @@ class CompletionModelPromptTemplate(BaseModel):
|
||||
Completion Model Prompt Template.
|
||||
"""
|
||||
text: str
|
||||
edition_type: Optional[Literal['basic', 'jinja2']]
|
||||
edition_type: Optional[Literal['basic', 'jinja2']] = None
|
||||
|
||||
|
||||
class MemoryConfig(BaseModel):
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pymilvus import MilvusClient, MilvusException, connections
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
@@ -28,7 +28,7 @@ class MilvusConfig(BaseModel):
|
||||
batch_size: int = 100
|
||||
database: str = "default"
|
||||
|
||||
@root_validator()
|
||||
@model_validator(mode='before')
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values.get('host'):
|
||||
raise ValueError("config MILVUS_HOST is required")
|
||||
|
||||
@@ -6,7 +6,7 @@ from uuid import UUID, uuid4
|
||||
from flask import current_app
|
||||
from numpy import ndarray
|
||||
from pgvecto_rs.sqlalchemy import Vector
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Float, String, create_engine, insert, select, text
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects import postgresql
|
||||
@@ -31,7 +31,7 @@ class PgvectoRSConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@root_validator()
|
||||
@model_validator(mode='before')
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
raise ValueError("config PGVECTO_RS_HOST is required")
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
@@ -24,7 +24,7 @@ class PGVectorConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@root_validator()
|
||||
@model_validator(mode='before')
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config PGVECTOR_HOST is required")
|
||||
|
||||
@@ -40,9 +40,9 @@ if TYPE_CHECKING:
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
api_key: Optional[str] = None
|
||||
timeout: float = 20
|
||||
root_path: Optional[str]
|
||||
root_path: Optional[str] = None
|
||||
grpc_port: int = 6334
|
||||
prefer_grpc: bool = False
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects.postgresql import JSON, TEXT
|
||||
@@ -33,7 +33,7 @@ class RelytConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@root_validator()
|
||||
@model_validator(mode='before')
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
raise ValueError("config RELYT_HOST is required")
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
|
||||
import sqlalchemy
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.orm import Session, declarative_base
|
||||
@@ -27,7 +27,7 @@ class TiDBVectorConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@root_validator()
|
||||
@model_validator(mode='before')
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
raise ValueError("config TIDB_VECTOR_HOST is required")
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, Optional
|
||||
import requests
|
||||
import weaviate
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@@ -19,10 +19,10 @@ from models.dataset import Dataset
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
api_key: Optional[str] = None
|
||||
batch_size: int = 100
|
||||
|
||||
@root_validator()
|
||||
@model_validator(mode='before')
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['endpoint']:
|
||||
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
||||
|
||||
@@ -14,7 +14,7 @@ from io import BufferedReader, BytesIO
|
||||
from pathlib import PurePath
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
PathLike = Union[str, PurePath]
|
||||
|
||||
@@ -29,7 +29,7 @@ class Blob(BaseModel):
|
||||
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
|
||||
"""
|
||||
|
||||
data: Union[bytes, str, None] # Raw data
|
||||
data: Union[bytes, str, None] = None # Raw data
|
||||
mimetype: Optional[str] = None # Not to be confused with a file extension
|
||||
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
|
||||
# Location where the original content was found
|
||||
@@ -37,17 +37,15 @@ class Blob(BaseModel):
|
||||
# Useful for situations where downstream code assumes it must work with file paths
|
||||
# rather than in-memory content.
|
||||
path: Optional[PathLike] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
frozen = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
|
||||
|
||||
@property
|
||||
def source(self) -> Optional[str]:
|
||||
"""The source location of the blob as string if known otherwise none."""
|
||||
return str(self.path) if self.path else None
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Verify that either data or path is provided."""
|
||||
if "data" not in values and "path" not in values:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from models.dataset import Document
|
||||
from models.model import UploadFile
|
||||
@@ -13,9 +13,7 @@ class NotionInfo(BaseModel):
|
||||
notion_page_type: str
|
||||
document: Document = None
|
||||
tenant_id: str
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
@@ -29,9 +27,7 @@ class ExtractSetting(BaseModel):
|
||||
upload_file: UploadFile = None
|
||||
notion_info: NotionInfo = None
|
||||
document_model: str = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -13,7 +13,7 @@ class UserTool(BaseModel):
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[ToolParameter]]
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
labels: list[str] = None
|
||||
|
||||
UserToolProviderTypeLiteral = Optional[Literal[
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
@@ -116,6 +116,14 @@ class ToolParameterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
@classmethod
|
||||
@field_validator('value', mode='before')
|
||||
def transform_id_to_str(cls, value) -> str:
|
||||
if isinstance(value, bool):
|
||||
return str(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
class ToolParameter(BaseModel):
|
||||
class ToolParameterType(str, Enum):
|
||||
@@ -278,7 +286,7 @@ class ToolRuntimeVariablePool(BaseModel):
|
||||
'conversation_id': self.conversation_id,
|
||||
'user_id': self.user_id,
|
||||
'tenant_id': self.tenant_id,
|
||||
'pool': [variable.dict() for variable in self.pool],
|
||||
'pool': [variable.model_dump() for variable in self.pool],
|
||||
}
|
||||
|
||||
def set_text(self, tool_name: str, name: str, value: str) -> None:
|
||||
|
||||
@@ -4,7 +4,7 @@ from hmac import new as hmac_new
|
||||
from json import loads as json_loads
|
||||
from threading import Lock
|
||||
from time import sleep, time
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from httpx import get, post
|
||||
from requests import get as requests_get
|
||||
@@ -22,9 +22,9 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
|
||||
_api_base_url = URL('https://co.aippt.cn/api')
|
||||
_api_token_cache = {}
|
||||
_api_token_cache_lock = Lock()
|
||||
_api_token_cache_lock:Optional[Lock] = None
|
||||
_style_cache = {}
|
||||
_style_cache_lock = Lock()
|
||||
_style_cache_lock:Optional[Lock] = None
|
||||
|
||||
_task = {}
|
||||
_task_type_map = {
|
||||
@@ -32,6 +32,11 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
'markdown': 7,
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self._api_token_cache_lock = Lock()
|
||||
self._style_cache_lock = Lock()
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
Invokes the AIPPT generate tool with the given user ID and tool parameters.
|
||||
|
||||
@@ -44,14 +44,10 @@ class ArxivAPIWrapper(BaseModel):
|
||||
arxiv.run("tree of thought llm)
|
||||
"""
|
||||
|
||||
arxiv_search = arxiv.Search #: :meta private:
|
||||
arxiv_exceptions = (
|
||||
arxiv.ArxivError,
|
||||
arxiv.UnexpectedEmptyPageError,
|
||||
arxiv.HTTPError,
|
||||
) # :meta private:
|
||||
arxiv_search: type[arxiv.Search] = arxiv.Search #: :meta private:
|
||||
arxiv_http_error: tuple[type[Exception]] = (arxiv.ArxivError, arxiv.UnexpectedEmptyPageError, arxiv.HTTPError)
|
||||
top_k_results: int = 3
|
||||
ARXIV_MAX_QUERY_LENGTH = 300
|
||||
ARXIV_MAX_QUERY_LENGTH: int = 300
|
||||
load_max_docs: int = 100
|
||||
load_all_available_meta: bool = False
|
||||
doc_content_chars_max: Optional[int] = 4000
|
||||
@@ -73,7 +69,7 @@ class ArxivAPIWrapper(BaseModel):
|
||||
results = self.arxiv_search( # type: ignore
|
||||
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
|
||||
).results()
|
||||
except self.arxiv_exceptions as ex:
|
||||
except arxiv_http_error as ex:
|
||||
return f"Arxiv exception: {ex}"
|
||||
docs = [
|
||||
f"Published: {result.updated.date()}\n"
|
||||
|
||||
@@ -8,7 +8,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class BingSearchTool(BuiltinTool):
|
||||
url = 'https://api.bing.microsoft.com/v7.0/search'
|
||||
url: str = 'https://api.bing.microsoft.com/v7.0/search'
|
||||
|
||||
def _invoke_bing(self,
|
||||
user_id: str,
|
||||
|
||||
@@ -15,7 +15,7 @@ class BraveSearchWrapper(BaseModel):
|
||||
"""The API key to use for the Brave search engine."""
|
||||
search_kwargs: dict = Field(default_factory=dict)
|
||||
"""Additional keyword arguments to pass to the search request."""
|
||||
base_url = "https://api.search.brave.com/res/v1/web/search"
|
||||
base_url: str = "https://api.search.brave.com/res/v1/web/search"
|
||||
"""The base URL for the Brave search engine."""
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
@@ -58,8 +58,8 @@ class BraveSearchWrapper(BaseModel):
|
||||
class BraveSearch(BaseModel):
|
||||
"""Tool that queries the BraveSearch."""
|
||||
|
||||
name = "brave_search"
|
||||
description = (
|
||||
name: str = "brave_search"
|
||||
description: str = (
|
||||
"a search engine. "
|
||||
"useful for when you need to answer questions about current events."
|
||||
" input should be a search query."
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class DuckDuckGoSearchAPIWrapper(BaseModel):
|
||||
"""Wrapper for DuckDuckGo Search API.
|
||||
|
||||
Free and does not require any setup.
|
||||
"""
|
||||
|
||||
region: Optional[str] = "wt-wt"
|
||||
safesearch: str = "moderate"
|
||||
time: Optional[str] = "y"
|
||||
max_results: int = 5
|
||||
|
||||
def get_snippets(self, query: str) -> list[str]:
|
||||
"""Run query through DuckDuckGo and return concatenated results."""
|
||||
from duckduckgo_search import DDGS
|
||||
|
||||
with DDGS() as ddgs:
|
||||
results = ddgs.text(
|
||||
query,
|
||||
region=self.region,
|
||||
safesearch=self.safesearch,
|
||||
timelimit=self.time,
|
||||
)
|
||||
if results is None:
|
||||
return ["No good DuckDuckGo Search Result was found"]
|
||||
snippets = []
|
||||
for i, res in enumerate(results, 1):
|
||||
if res is not None:
|
||||
snippets.append(res["body"])
|
||||
if len(snippets) == self.max_results:
|
||||
break
|
||||
return snippets
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
snippets = self.get_snippets(query)
|
||||
return " ".join(snippets)
|
||||
|
||||
def results(
|
||||
self, query: str, num_results: int, backend: str = "api"
|
||||
) -> list[dict[str, str]]:
|
||||
"""Run query through DuckDuckGo and return metadata.
|
||||
|
||||
Args:
|
||||
query: The query to search for.
|
||||
num_results: The number of results to return.
|
||||
|
||||
Returns:
|
||||
A list of dictionaries with the following keys:
|
||||
snippet - The description of the result.
|
||||
title - The title of the result.
|
||||
link - The link to the result.
|
||||
"""
|
||||
from duckduckgo_search import DDGS
|
||||
|
||||
with DDGS() as ddgs:
|
||||
results = ddgs.text(
|
||||
query,
|
||||
region=self.region,
|
||||
safesearch=self.safesearch,
|
||||
timelimit=self.time,
|
||||
backend=backend,
|
||||
)
|
||||
if results is None:
|
||||
return [{"Result": "No good DuckDuckGo Search Result was found"}]
|
||||
|
||||
def to_metadata(result: dict) -> dict[str, str]:
|
||||
if backend == "news":
|
||||
return {
|
||||
"date": result["date"],
|
||||
"title": result["title"],
|
||||
"snippet": result["body"],
|
||||
"source": result["source"],
|
||||
"link": result["url"],
|
||||
}
|
||||
return {
|
||||
"snippet": result["body"],
|
||||
"title": result["title"],
|
||||
"link": result["href"],
|
||||
}
|
||||
|
||||
formatted_results = []
|
||||
for i, res in enumerate(results, 1):
|
||||
if res is not None:
|
||||
formatted_results.append(to_metadata(res))
|
||||
if len(formatted_results) == num_results:
|
||||
break
|
||||
return formatted_results
|
||||
|
||||
|
||||
class DuckDuckGoSearchRun(BaseModel):
|
||||
"""Tool that queries the DuckDuckGo search API."""
|
||||
|
||||
name: str = "duckduckgo_search"
|
||||
description: str = (
|
||||
"A wrapper around DuckDuckGo Search. "
|
||||
"Useful for when you need to answer questions about current events. "
|
||||
"Input should be a search query."
|
||||
)
|
||||
api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
|
||||
default_factory=DuckDuckGoSearchAPIWrapper
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
) -> str:
|
||||
"""Use the tool."""
|
||||
return self.api_wrapper.run(query)
|
||||
|
||||
|
||||
class DuckDuckGoSearchResults(BaseModel):
|
||||
"""Tool that queries the DuckDuckGo search API and gets back json."""
|
||||
|
||||
name: str = "DuckDuckGo Results JSON"
|
||||
description: str = (
|
||||
"A wrapper around Duck Duck Go Search. "
|
||||
"Useful for when you need to answer questions about current events. "
|
||||
"Input should be a search query. Output is a JSON array of the query results"
|
||||
)
|
||||
num_results: int = 4
|
||||
api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
|
||||
default_factory=DuckDuckGoSearchAPIWrapper
|
||||
)
|
||||
backend: str = "api"
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
) -> str:
|
||||
"""Use the tool."""
|
||||
res = self.api_wrapper.results(query, self.num_results, backend=self.backend)
|
||||
res_strs = [", ".join([f"{k}: {v}" for k, v in d.items()]) for d in res]
|
||||
return ", ".join([f"[{rs}]" for rs in res_strs])
|
||||
|
||||
class DuckDuckGoInput(BaseModel):
|
||||
query: str = Field(..., description="Search query.")
|
||||
|
||||
class DuckDuckGoSearchTool(BuiltinTool):
|
||||
"""
|
||||
Tool for performing a search using DuckDuckGo search engine.
|
||||
"""
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
Invoke the DuckDuckGo search tool.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user invoking the tool.
|
||||
tool_parameters (dict[str, Any]): The parameters for the tool invocation.
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation.
|
||||
"""
|
||||
query = tool_parameters.get('query', '')
|
||||
|
||||
if not query:
|
||||
return self.create_text_message('Please input query')
|
||||
|
||||
tool = DuckDuckGoSearchRun(args_schema=DuckDuckGoInput)
|
||||
|
||||
result = tool._run(query)
|
||||
|
||||
return self.create_text_message(self.summary(user_id=user_id, content=result))
|
||||
|
||||
>>>>>>> 4c2ba442b (missing type in DuckDuckGoSearchAPIWrapper)
|
||||
@@ -69,10 +69,10 @@ parameters:
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
default: false
|
||||
|
||||
@@ -28,15 +28,15 @@ class PubMedAPIWrapper(BaseModel):
|
||||
if False: the `metadata` gets only the most informative fields.
|
||||
"""
|
||||
|
||||
base_url_esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
|
||||
base_url_efetch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
|
||||
max_retry = 5
|
||||
sleep_time = 0.2
|
||||
base_url_esearch: str = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
|
||||
base_url_efetch: str = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
|
||||
max_retry: int = 5
|
||||
sleep_time: float = 0.2
|
||||
|
||||
# Default values for the parameters
|
||||
top_k_results: int = 3
|
||||
load_max_docs: int = 25
|
||||
ARXIV_MAX_QUERY_LENGTH = 300
|
||||
ARXIV_MAX_QUERY_LENGTH: int = 300
|
||||
doc_content_chars_max: int = 2000
|
||||
load_all_available_meta: bool = False
|
||||
email: str = "your_email@example.com"
|
||||
@@ -160,8 +160,8 @@ class PubMedAPIWrapper(BaseModel):
|
||||
class PubmedQueryRun(BaseModel):
|
||||
"""Tool that searches the PubMed API."""
|
||||
|
||||
name = "PubMed"
|
||||
description = (
|
||||
name: str = "PubMed"
|
||||
description: str = (
|
||||
"A wrapper around PubMed.org "
|
||||
"Useful for when you need to answer questions about Physics, Mathematics, "
|
||||
"Computer Science, Quantitative Biology, Quantitative Finance, Statistics, "
|
||||
|
||||
@@ -12,7 +12,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class QRCodeGeneratorTool(BuiltinTool):
|
||||
error_correction_levels = {
|
||||
error_correction_levels: dict[str, int] = {
|
||||
'L': ERROR_CORRECT_L, # <=7%
|
||||
'M': ERROR_CORRECT_M, # <=15%
|
||||
'Q': ERROR_CORRECT_Q, # <=25%
|
||||
|
||||
@@ -24,21 +24,21 @@ class SearXNGSearchTool(BuiltinTool):
|
||||
Tool for performing a search using SearXNG engine.
|
||||
"""
|
||||
|
||||
SEARCH_TYPE = {
|
||||
SEARCH_TYPE: dict[str, str] = {
|
||||
"page": "general",
|
||||
"news": "news",
|
||||
"image": "images",
|
||||
# "video": "videos",
|
||||
# "file": "files"
|
||||
}
|
||||
LINK_FILED = {
|
||||
LINK_FILED: dict[str, str] = {
|
||||
"page": "url",
|
||||
"news": "url",
|
||||
"image": "img_src",
|
||||
# "video": "iframe_src",
|
||||
# "file": "magnetlink"
|
||||
}
|
||||
TEXT_FILED = {
|
||||
TEXT_FILED: dict[str, str] = {
|
||||
"page": "content",
|
||||
"news": "content",
|
||||
"image": "img_src",
|
||||
|
||||
@@ -11,7 +11,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
|
||||
"""
|
||||
This class is responsible for providing the stable diffusion tool.
|
||||
"""
|
||||
model_endpoint_map = {
|
||||
model_endpoint_map: dict[str, str] = {
|
||||
'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
|
||||
'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
|
||||
'core': 'https://api.stability.ai/v2beta/stable-image/generate/core',
|
||||
|
||||
@@ -98,11 +98,11 @@ parameters:
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
default: true
|
||||
- name: pagesize
|
||||
|
||||
@@ -64,14 +64,14 @@ parameters:
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
pt_BR: Yes
|
||||
pt_BR: 'Yes'
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
pt_BR: No
|
||||
pt_BR: 'No'
|
||||
default: false
|
||||
- name: include_answer
|
||||
type: boolean
|
||||
@@ -88,14 +88,14 @@ parameters:
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
pt_BR: Yes
|
||||
pt_BR: 'Yes'
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
pt_BR: No
|
||||
pt_BR: 'No'
|
||||
default: false
|
||||
- name: include_raw_content
|
||||
type: boolean
|
||||
@@ -112,14 +112,14 @@ parameters:
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
pt_BR: Yes
|
||||
pt_BR: 'Yes'
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
pt_BR: No
|
||||
pt_BR: 'No'
|
||||
default: false
|
||||
- name: max_results
|
||||
type: number
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
@@ -15,7 +15,7 @@ class TwilioAPIWrapper(BaseModel):
|
||||
named parameters to the constructor.
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
client: Any = None #: :meta private:
|
||||
account_sid: Optional[str] = None
|
||||
"""Twilio account string identifier."""
|
||||
auth_token: Optional[str] = None
|
||||
@@ -32,7 +32,8 @@ class TwilioAPIWrapper(BaseModel):
|
||||
must be empty.
|
||||
"""
|
||||
|
||||
@validator("client", pre=True, always=True)
|
||||
@classmethod
|
||||
@field_validator('client', mode='before')
|
||||
def set_validator(cls, values: dict) -> dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
|
||||
@@ -51,10 +51,10 @@ parameters:
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
default: false
|
||||
|
||||
@@ -32,10 +32,10 @@ class ApiTool(Tool):
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
identity=self.identity.copy() if self.identity else None,
|
||||
identity=self.identity.model_copy() if self.identity else None,
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
|
||||
description=self.description.model_copy() if self.description else None,
|
||||
api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
|
||||
runtime=Tool.Runtime(**runtime)
|
||||
)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from abc import abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from msal_extensions.persistence import ABC
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
|
||||
@@ -17,9 +17,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
|
||||
@@ -3,7 +3,8 @@ from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileVar
|
||||
@@ -28,8 +29,12 @@ class Tool(BaseModel, ABC):
|
||||
description: ToolDescription = None
|
||||
is_team_authorization: bool = False
|
||||
|
||||
@validator('parameters', pre=True, always=True)
|
||||
def set_parameters(cls, v, values):
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@classmethod
|
||||
@field_validator('parameters', mode='before')
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
|
||||
return v or []
|
||||
|
||||
class Runtime(BaseModel):
|
||||
@@ -65,9 +70,9 @@ class Tool(BaseModel, ABC):
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
identity=self.identity.copy() if self.identity else None,
|
||||
identity=self.identity.model_copy() if self.identity else None,
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
description=self.description.model_copy() if self.description else None,
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
class Output(BaseModel):
|
||||
type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]']
|
||||
children: Optional[dict[str, 'Output']]
|
||||
children: Optional[dict[str, 'Output']] = None
|
||||
|
||||
variables: list[VariableSelector]
|
||||
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
@@ -14,15 +14,18 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
class Authorization(BaseModel):
|
||||
# TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
|
||||
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
|
||||
class Config(BaseModel):
|
||||
type: Literal[None, 'basic', 'bearer', 'custom']
|
||||
api_key: Union[None, str]
|
||||
header: Union[None, str]
|
||||
api_key: Union[None, str] = None
|
||||
header: Union[None, str] = None
|
||||
|
||||
type: Literal['no-auth', 'api-key']
|
||||
config: Optional[Config]
|
||||
|
||||
@validator('config', always=True, pre=True)
|
||||
@classmethod
|
||||
@field_validator('config', mode='before')
|
||||
def check_config(cls, v, values):
|
||||
"""
|
||||
Check config, if type is no-auth, config should be None, otherwise it should be a dict.
|
||||
@@ -37,7 +40,7 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
|
||||
class Body(BaseModel):
|
||||
type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json']
|
||||
data: Union[None, str]
|
||||
data: Union[None, str] = None
|
||||
|
||||
class Timeout(BaseModel):
|
||||
connect: Optional[int] = MAX_CONNECT_TIMEOUT
|
||||
@@ -50,5 +53,5 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
headers: str
|
||||
params: str
|
||||
body: Optional[Body]
|
||||
timeout: Optional[Timeout]
|
||||
timeout: Optional[Timeout] = None
|
||||
mask_authorization_header: Optional[bool] = True
|
||||
|
||||
@@ -39,7 +39,7 @@ class HttpRequestNode(BaseNode):
|
||||
"type": "none"
|
||||
},
|
||||
"timeout": {
|
||||
**HTTP_REQUEST_DEFAULT_TIMEOUT.dict(),
|
||||
**HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(),
|
||||
"max_connect_timeout": MAX_CONNECT_TIMEOUT,
|
||||
"max_read_timeout": MAX_READ_TIMEOUT,
|
||||
"max_write_timeout": MAX_WRITE_TIMEOUT,
|
||||
|
||||
@@ -7,7 +7,7 @@ class IterationNodeData(BaseIterationNodeData):
|
||||
"""
|
||||
Iteration Node Data.
|
||||
"""
|
||||
parent_loop_id: Optional[str] # redundant field, not used currently
|
||||
parent_loop_id: Optional[str] = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class MultipleRetrievalConfig(BaseModel):
|
||||
Multiple Retrieval Config.
|
||||
"""
|
||||
top_k: int
|
||||
score_threshold: Optional[float]
|
||||
score_threshold: Optional[float] = None
|
||||
reranking_model: RerankingModelConfig
|
||||
|
||||
|
||||
@@ -47,5 +47,5 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
query_variable_selector: list[str]
|
||||
dataset_ids: list[str]
|
||||
retrieval_mode: Literal['single', 'multiple']
|
||||
multiple_retrieval_config: Optional[MultipleRetrievalConfig]
|
||||
single_retrieval_config: Optional[SingleRetrievalConfig]
|
||||
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
|
||||
single_retrieval_config: Optional[SingleRetrievalConfig] = None
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
@@ -21,12 +21,13 @@ class ParameterConfig(BaseModel):
|
||||
"""
|
||||
name: str
|
||||
type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]']
|
||||
options: Optional[list[str]]
|
||||
options: Optional[list[str]] = None
|
||||
description: str
|
||||
required: bool
|
||||
|
||||
@validator('name', pre=True, always=True)
|
||||
def validate_name(cls, value):
|
||||
@classmethod
|
||||
@field_validator('name', mode='before')
|
||||
def validate_name(cls, value) -> str:
|
||||
if not value:
|
||||
raise ValueError('Parameter name is required')
|
||||
if value in ['__reason', '__is_success']:
|
||||
@@ -40,12 +41,13 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
query: list[str]
|
||||
parameters: list[ParameterConfig]
|
||||
instruction: Optional[str]
|
||||
memory: Optional[MemoryConfig]
|
||||
instruction: Optional[str] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
reasoning_mode: Literal['function_call', 'prompt']
|
||||
|
||||
@validator('reasoning_mode', pre=True, always=True)
|
||||
def set_reasoning_mode(cls, v):
|
||||
@classmethod
|
||||
@field_validator('reasoning_mode', mode='before')
|
||||
def set_reasoning_mode(cls, v) -> str:
|
||||
return v or 'function_call'
|
||||
|
||||
def get_parameter_json_schema(self) -> dict:
|
||||
|
||||
@@ -32,5 +32,5 @@ class QuestionClassifierNodeData(BaseNodeData):
|
||||
type: str = 'question-classifier'
|
||||
model: ModelConfig
|
||||
classes: list[ClassConfig]
|
||||
instruction: Optional[str]
|
||||
memory: Optional[MemoryConfig]
|
||||
instruction: Optional[str] = None
|
||||
memory: Optional[MemoryConfig] = None
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
@@ -13,13 +14,14 @@ class ToolEntity(BaseModel):
|
||||
tool_label: str # redundancy
|
||||
tool_configurations: dict[str, Any]
|
||||
|
||||
@validator('tool_configurations', pre=True, always=True)
|
||||
def validate_tool_configurations(cls, value, values):
|
||||
@classmethod
|
||||
@field_validator('tool_configurations', mode='before')
|
||||
def validate_tool_configurations(cls, value, values: ValidationInfo) -> dict[str, Any]:
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError('tool_configurations must be a dictionary')
|
||||
|
||||
for key in values.get('tool_configurations', {}).keys():
|
||||
value = values.get('tool_configurations', {}).get(key)
|
||||
for key in values.data.get('tool_configurations', {}).keys():
|
||||
value = values.data.get('tool_configurations', {}).get(key)
|
||||
if not isinstance(value, str | int | float | bool):
|
||||
raise ValueError(f'{key} must be a string')
|
||||
|
||||
@@ -30,10 +32,11 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal['mixed', 'variable', 'constant']
|
||||
|
||||
@validator('type', pre=True, always=True)
|
||||
def check_type(cls, value, values):
|
||||
@classmethod
|
||||
@field_validator('type', mode='before')
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = values.get('value')
|
||||
value = validation_info.data.get('value')
|
||||
if typ == 'mixed' and not isinstance(value, str):
|
||||
raise ValueError('value must be a string')
|
||||
elif typ == 'variable':
|
||||
@@ -45,7 +48,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
elif typ == 'constant' and not isinstance(value, str | int | float | bool):
|
||||
raise ValueError('value must be a string, int, float, or bool')
|
||||
return typ
|
||||
|
||||
|
||||
"""
|
||||
Tool Node Schema
|
||||
"""
|
||||
|
||||
@@ -30,4 +30,4 @@ class VariableAssignerNodeData(BaseNodeData):
|
||||
type: str = 'variable-assigner'
|
||||
output_type: str
|
||||
variables: list[list[str]]
|
||||
advanced_settings: Optional[AdvancedSettings]
|
||||
advanced_settings: Optional[AdvancedSettings] = None
|
||||
@@ -592,7 +592,7 @@ class WorkflowEngineManager:
|
||||
node_data=current_iteration_node.node_data,
|
||||
inputs=workflow_run_state.current_iteration_state.inputs,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
metadata=workflow_run_state.current_iteration_state.metadata.dict()
|
||||
metadata=workflow_run_state.current_iteration_state.metadata.model_dump()
|
||||
)
|
||||
|
||||
# add steps
|
||||
|
||||
Reference in New Issue
Block a user