Fix/type-error (#11240)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2024-12-02 10:24:21 +08:00
committed by GitHub
parent 9601102885
commit c34bdb74e6
11 changed files with 117 additions and 161 deletions

View File

@@ -1,5 +1,6 @@
import uuid
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional
from core.agent.entities import AgentEntity
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
@@ -85,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict:
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict:
"""
Validate for agent chat app model config

View File

@@ -1,8 +1,8 @@
import logging
import threading
import uuid
from collections.abc import Generator
from typing import Any, Literal, Union, overload
from collections.abc import Generator, Mapping
from typing import Any, Union
from flask import Flask, current_app
from pydantic import ValidationError
@@ -28,34 +28,15 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: dict,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: Literal[True] = True,
) -> Generator[dict, None, None]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[dict, None, None]]:
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]:
"""
Generate App response.
@@ -65,7 +46,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
if not stream:
if not streaming:
raise ValueError("Agent Chat App does not support blocking mode")
if not args.get("query"):
@@ -96,7 +77,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# validate config
override_model_config_dict = AgentChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config")
tenant_id=app_model.tenant_id,
config=args["model_config"],
)
# always enable retriever resource in debugger mode
@@ -141,7 +123,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=stream,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
call_depth=0,
@@ -182,7 +164,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream,
stream=streaming,
)
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)