improve: modernizing validation by migrating pydantic from 1.x to 2.x (#4592)

Bowen Liang
2024-06-14 01:05:37 +08:00
committed by GitHub
parent e8afc416dd
commit f976740b57
87 changed files with 697 additions and 300 deletions

View File

@@ -58,7 +58,7 @@ class AgentChatAppRunner(AppRunner):
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
@@ -69,8 +69,8 @@ class AgentChatAppRunner(AppRunner):
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
model_instance = ModelInstance(
-provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
-model=application_generate_entity.model_config.model
+provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
+model=application_generate_entity.model_conf.model
)
memory = TokenBufferMemory(
@@ -83,7 +83,7 @@ class AgentChatAppRunner(AppRunner):
# memory(optional)
prompt_messages, _ = self.organize_prompt_messages(
app_record=app_record,
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
@@ -152,7 +152,7 @@ class AgentChatAppRunner(AppRunner):
# memory(optional), external data, dataset context(optional)
prompt_messages, _ = self.organize_prompt_messages(
app_record=app_record,
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
@@ -182,12 +182,12 @@ class AgentChatAppRunner(AppRunner):
# init model instance
model_instance = ModelInstance(
-provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
-model=application_generate_entity.model_config.model
+provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
+model=application_generate_entity.model_conf.model
)
prompt_message, _ = self.organize_prompt_messages(
app_record=app_record,
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
@@ -225,7 +225,7 @@ class AgentChatAppRunner(AppRunner):
application_generate_entity=application_generate_entity,
conversation=conversation,
app_config=app_config,
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
config=agent_entity,
queue_manager=queue_manager,
message=message,
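
The hunks above (and the matching ones in the other app runners below) are a single mechanical rename: `application_generate_entity.model_config` becomes `application_generate_entity.model_conf`. The rename is forced by Pydantic 2, which reserves the `model_config` attribute for the model's own configuration, so it can no longer double as a data field. A minimal sketch of the collision, using illustrative stand-ins rather than the project's real entities:

```python
from pydantic import BaseModel, ConfigDict


class ModelConf(BaseModel):
    """Stand-in for ModelConfigWithCredentialsEntity."""
    provider: str = "openai"
    model: str = "gpt-4o"


class GenerateEntity(BaseModel):
    # In Pydantic 2 this assignment *configures* the model; a Pydantic 1 data
    # field named `model_config` would collide with it, hence the rename below.
    # (Emptying protected_namespaces also silences the warning about fields
    # starting with "model_"; see the entities file further down.)
    model_config = ConfigDict(protected_namespaces=())

    # the former `model_config` field, living on under a new name
    model_conf: ModelConf


entity = GenerateEntity(model_conf=ModelConf())
print(entity.model_conf.model)  # call sites change from .model_config to .model_conf
```
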

View File

@@ -100,7 +100,7 @@ class AppQueueManager:
:param pub_from:
:return:
"""
-self._check_for_sqlalchemy_models(event.dict())
+self._check_for_sqlalchemy_models(event.model_dump())
self._publish(event, pub_from)
@abstractmethod
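
Pydantic 2 renames the serialization helpers: `.dict()` and `.json()` survive only as deprecated shims for `.model_dump()` and `.model_dump_json()`. A minimal sketch of the call-site change, with a hypothetical event class rather than the project's `AppQueueEvent`:

```python
from pydantic import BaseModel


class ExampleEvent(BaseModel):
    task_id: str
    text: str = ""


event = ExampleEvent(task_id="t-1", text="hello")

# Pydantic 1 style, deprecated in 2.x:
#   payload = event.dict()

# Pydantic 2 replacement used by the queue manager above:
payload = event.model_dump()
print(payload)  # {'task_id': 't-1', 'text': 'hello'}
```
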

View File

@@ -218,7 +218,7 @@ class AppRunner:
index = 0
for token in text:
chunk = LLMResultChunk(
-model=app_generate_entity.model_config.model,
+model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
@@ -237,7 +237,7 @@ class AppRunner:
queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
-model=app_generate_entity.model_config.model,
+model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage()

View File

@@ -54,7 +54,7 @@ class ChatAppRunner(AppRunner):
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
@@ -65,8 +65,8 @@ class ChatAppRunner(AppRunner):
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
model_instance = ModelInstance(
-provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
-model=application_generate_entity.model_config.model
+provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
+model=application_generate_entity.model_conf.model
)
memory = TokenBufferMemory(
@@ -79,7 +79,7 @@ class ChatAppRunner(AppRunner):
# memory(optional)
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
@@ -159,7 +159,7 @@ class ChatAppRunner(AppRunner):
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
config=app_config.dataset,
query=query,
invoke_from=application_generate_entity.invoke_from,
@@ -173,7 +173,7 @@ class ChatAppRunner(AppRunner):
# memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
@@ -194,21 +194,21 @@ class ChatAppRunner(AppRunner):
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
prompt_messages=prompt_messages
)
# Invoke model
model_instance = ModelInstance(
-provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
-model=application_generate_entity.model_config.model
+provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
+model=application_generate_entity.model_conf.model
)
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
-model_parameters=application_generate_entity.model_config.parameters,
+model_parameters=application_generate_entity.model_conf.parameters,
stop=stop,
stream=application_generate_entity.stream,
user=application_generate_entity.user_id,

View File

@@ -50,7 +50,7 @@ class CompletionAppRunner(AppRunner):
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
@@ -61,7 +61,7 @@ class CompletionAppRunner(AppRunner):
# Include: prompt template, inputs, query(optional), files(optional)
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
@@ -119,7 +119,7 @@ class CompletionAppRunner(AppRunner):
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
config=dataset_config,
query=query,
invoke_from=application_generate_entity.invoke_from,
@@ -132,7 +132,7 @@ class CompletionAppRunner(AppRunner):
# memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
@@ -152,21 +152,21 @@ class CompletionAppRunner(AppRunner):
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(
-model_config=application_generate_entity.model_config,
+model_config=application_generate_entity.model_conf,
prompt_messages=prompt_messages
)
# Invoke model
model_instance = ModelInstance(
-provider_model_bundle=application_generate_entity.model_config.provider_model_bundle,
-model=application_generate_entity.model_config.model
+provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
+model=application_generate_entity.model_conf.model
)
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
-model_parameters=application_generate_entity.model_config.parameters,
+model_parameters=application_generate_entity.model_conf.parameters,
stop=stop,
stream=application_generate_entity.stream,
user=application_generate_entity.user_id,

View File

@@ -158,8 +158,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_id = None
else:
app_model_config_id = app_config.app_model_config_id
-model_provider = application_generate_entity.model_config.provider
-model_id = application_generate_entity.model_config.model
+model_provider = application_generate_entity.model_conf.provider
+model_id = application_generate_entity.model_conf.model
override_model_configs = None
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \
and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]:

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Optional
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
@@ -62,6 +62,9 @@ class ModelConfigWithCredentialsEntity(BaseModel):
parameters: dict[str, Any] = {}
stop: list[str] = []
+# pydantic configs
+model_config = ConfigDict(protected_namespaces=())
class AppGenerateEntity(BaseModel):
"""
@@ -93,10 +96,13 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
"""
# app config
app_config: EasyUIBasedAppConfig
-model_config: ModelConfigWithCredentialsEntity
+model_conf: ModelConfigWithCredentialsEntity
query: Optional[str] = None
+# pydantic configs
+model_config = ConfigDict(protected_namespaces=())
class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""

View File

@@ -1,14 +1,14 @@
from enum import Enum
from typing import Any, Optional
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, field_validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
-class QueueEvent(Enum):
+class QueueEvent(str, Enum):
"""
QueueEvent enum
"""
@@ -47,14 +47,14 @@ class QueueLLMChunkEvent(AppQueueEvent):
"""
QueueLLMChunkEvent entity
"""
-event = QueueEvent.LLM_CHUNK
+event: QueueEvent = QueueEvent.LLM_CHUNK
chunk: LLMResultChunk
class QueueIterationStartEvent(AppQueueEvent):
"""
QueueIterationStartEvent entity
"""
-event = QueueEvent.ITERATION_START
+event: QueueEvent = QueueEvent.ITERATION_START
node_id: str
node_type: NodeType
node_data: BaseNodeData
@@ -68,16 +68,17 @@ class QueueIterationNextEvent(AppQueueEvent):
"""
QueueIterationNextEvent entity
"""
-event = QueueEvent.ITERATION_NEXT
+event: QueueEvent = QueueEvent.ITERATION_NEXT
index: int
node_id: str
node_type: NodeType
node_run_index: int
-output: Optional[Any] # output for the current iteration
+output: Optional[Any] = None # output for the current iteration
-@validator('output', pre=True, always=True)
+@classmethod
+@field_validator('output', mode='before')
def set_output(cls, v):
"""
Set output
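
The validator migrates from Pydantic 1's `@validator('output', pre=True, always=True)` to `@field_validator('output', mode='before')`. Two caveats worth noting: the Pydantic 2 docs place `@field_validator` outermost with `@classmethod` beneath it, and `always=True` has no direct equivalent (a Pydantic 2 validator only runs on defaults if the field sets `validate_default=True`). A runnable sketch with a pass-through body, since the original `set_output` logic is truncated in the hunk:

```python
from typing import Any, Optional

from pydantic import BaseModel, field_validator


class IterationNextEvent(BaseModel):
    index: int
    # Pydantic 2 no longer treats bare Optional fields as defaulting to None,
    # hence the explicit "= None" added above.
    output: Optional[Any] = None

    @field_validator("output", mode="before")
    @classmethod
    def set_output(cls, v):
        # Simplified stand-in for the original body.
        return v


print(IterationNextEvent(index=0, output="chunk").output)  # chunk
```
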
@@ -92,7 +93,7 @@ class QueueIterationCompletedEvent(AppQueueEvent):
"""
QueueIterationCompletedEvent entity
"""
-event = QueueEvent.ITERATION_COMPLETED
+event:QueueEvent = QueueEvent.ITERATION_COMPLETED
node_id: str
node_type: NodeType
@@ -104,7 +105,7 @@ class QueueTextChunkEvent(AppQueueEvent):
"""
QueueTextChunkEvent entity
"""
-event = QueueEvent.TEXT_CHUNK
+event: QueueEvent = QueueEvent.TEXT_CHUNK
text: str
metadata: Optional[dict] = None
@@ -113,7 +114,7 @@ class QueueAgentMessageEvent(AppQueueEvent):
"""
QueueMessageEvent entity
"""
-event = QueueEvent.AGENT_MESSAGE
+event: QueueEvent = QueueEvent.AGENT_MESSAGE
chunk: LLMResultChunk
@@ -121,7 +122,7 @@ class QueueMessageReplaceEvent(AppQueueEvent):
"""
QueueMessageReplaceEvent entity
"""
-event = QueueEvent.MESSAGE_REPLACE
+event: QueueEvent = QueueEvent.MESSAGE_REPLACE
text: str
@@ -129,7 +130,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""
QueueRetrieverResourcesEvent entity
"""
-event = QueueEvent.RETRIEVER_RESOURCES
+event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict]
@@ -137,7 +138,7 @@ class QueueAnnotationReplyEvent(AppQueueEvent):
"""
QueueAnnotationReplyEvent entity
"""
-event = QueueEvent.ANNOTATION_REPLY
+event: QueueEvent = QueueEvent.ANNOTATION_REPLY
message_annotation_id: str
@@ -145,7 +146,7 @@ class QueueMessageEndEvent(AppQueueEvent):
"""
QueueMessageEndEvent entity
"""
-event = QueueEvent.MESSAGE_END
+event: QueueEvent = QueueEvent.MESSAGE_END
llm_result: Optional[LLMResult] = None
@@ -153,28 +154,28 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
"""
QueueAdvancedChatMessageEndEvent entity
"""
-event = QueueEvent.ADVANCED_CHAT_MESSAGE_END
+event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END
class QueueWorkflowStartedEvent(AppQueueEvent):
"""
QueueWorkflowStartedEvent entity
"""
-event = QueueEvent.WORKFLOW_STARTED
+event: QueueEvent = QueueEvent.WORKFLOW_STARTED
class QueueWorkflowSucceededEvent(AppQueueEvent):
"""
QueueWorkflowSucceededEvent entity
"""
-event = QueueEvent.WORKFLOW_SUCCEEDED
+event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
class QueueWorkflowFailedEvent(AppQueueEvent):
"""
QueueWorkflowFailedEvent entity
"""
-event = QueueEvent.WORKFLOW_FAILED
+event: QueueEvent = QueueEvent.WORKFLOW_FAILED
error: str
@@ -182,7 +183,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
"""
QueueNodeStartedEvent entity
"""
-event = QueueEvent.NODE_STARTED
+event: QueueEvent = QueueEvent.NODE_STARTED
node_id: str
node_type: NodeType
@@ -195,7 +196,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
"""
QueueNodeSucceededEvent entity
"""
-event = QueueEvent.NODE_SUCCEEDED
+event: QueueEvent = QueueEvent.NODE_SUCCEEDED
node_id: str
node_type: NodeType
@@ -213,7 +214,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity
"""
-event = QueueEvent.NODE_FAILED
+event: QueueEvent = QueueEvent.NODE_FAILED
node_id: str
node_type: NodeType
@@ -230,7 +231,7 @@ class QueueAgentThoughtEvent(AppQueueEvent):
"""
QueueAgentThoughtEvent entity
"""
-event = QueueEvent.AGENT_THOUGHT
+event: QueueEvent = QueueEvent.AGENT_THOUGHT
agent_thought_id: str
@@ -238,7 +239,7 @@ class QueueMessageFileEvent(AppQueueEvent):
"""
QueueAgentThoughtEvent entity
"""
-event = QueueEvent.MESSAGE_FILE
+event: QueueEvent = QueueEvent.MESSAGE_FILE
message_file_id: str
@@ -246,15 +247,15 @@ class QueueErrorEvent(AppQueueEvent):
"""
QueueErrorEvent entity
"""
-event = QueueEvent.ERROR
-error: Any
+event: QueueEvent = QueueEvent.ERROR
+error: Any = None
class QueuePingEvent(AppQueueEvent):
"""
QueuePingEvent entity
"""
-event = QueueEvent.PING
+event: QueueEvent = QueueEvent.PING
class QueueStopEvent(AppQueueEvent):
@@ -270,7 +271,7 @@ class QueueStopEvent(AppQueueEvent):
OUTPUT_MODERATION = "output-moderation"
INPUT_MODERATION = "input-moderation"
-event = QueueEvent.STOP
+event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy
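
Two Pydantic 2 rules drive the rest of this file's changes: class-level defaults must carry a type annotation (an un-annotated `event = ...` now fails with "A non-annotated attribute was detected"), and `Optional`/`Any` fields are no longer implicitly optional, so `error` gains an explicit `= None`. A sketch of both, with hypothetical enum values:

```python
from enum import Enum
from typing import Any

from pydantic import BaseModel


class QueueEvent(str, Enum):
    ERROR = "error"  # hypothetical value; the real members are not shown in the hunks


class QueueErrorEvent(BaseModel):
    # annotated default: required by Pydantic 2 for class-level field defaults
    event: QueueEvent = QueueEvent.ERROR
    # without "= None" this field would now be required
    error: Any = None


print(QueueErrorEvent().model_dump(mode="json"))  # {'event': 'error', 'error': None}
```
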

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Optional
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -118,9 +118,7 @@ class ErrorStreamResponse(StreamResponse):
"""
event: StreamEvent = StreamEvent.ERROR
err: Exception
-class Config:
-arbitrary_types_allowed = True
+model_config = ConfigDict(arbitrary_types_allowed=True)
class MessageStreamResponse(StreamResponse):
@@ -360,7 +358,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
-pre_iteration_output: Optional[Any]
+pre_iteration_output: Optional[Any] = None
extras: dict = {}
event: StreamEvent = StreamEvent.ITERATION_NEXT
@@ -379,12 +377,12 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
-outputs: Optional[dict]
+outputs: Optional[dict] = None
created_at: int
extras: dict = None
inputs: dict = None
status: WorkflowNodeExecutionStatus
-error: Optional[str]
+error: Optional[str] = None
elapsed_time: float
total_tokens: int
finished_at: int
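
Pydantic 2 replaces the nested `class Config` with a `model_config = ConfigDict(...)` assignment (the nested class still works but is deprecated), and `arbitrary_types_allowed=True` remains the switch that lets a non-Pydantic type such as `Exception` sit on a field. A sketch with a simplified response model:

```python
from pydantic import BaseModel, ConfigDict


class ErrorStreamResponse(BaseModel):
    # Exception is not a type Pydantic can validate on its own, so the model
    # must explicitly allow arbitrary types.
    err: Exception

    # Pydantic 1:
    #     class Config:
    #         arbitrary_types_allowed = True
    # Pydantic 2:
    model_config = ConfigDict(arbitrary_types_allowed=True)


resp = ErrorStreamResponse(err=ValueError("boom"))
print(type(resp.err).__name__)  # ValueError
```
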

View File

@@ -16,7 +16,7 @@ class HostingModerationFeature:
:param prompt_messages: prompt messages
:return:
"""
-model_config = application_generate_entity.model_config
+model_config = application_generate_entity.model_conf
text = ""
for prompt_message in prompt_messages:

View File

@@ -85,7 +85,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:param stream: stream
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
-self._model_config = application_generate_entity.model_config
+self._model_config = application_generate_entity.model_conf
self._conversation = conversation
self._message = message