mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
Introduce Plugins (#13836)
Signed-off-by: yihong0618 <zouzou0208@gmail.com> Signed-off-by: -LAN- <laipz8200@outlook.com> Signed-off-by: xhe <xw897002528@gmail.com> Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: takatost <takatost@gmail.com> Co-authored-by: kurokobo <kuro664@gmail.com> Co-authored-by: Novice Lee <novicelee@NoviPro.local> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: AkaraChen <akarachen@outlook.com> Co-authored-by: Yi <yxiaoisme@gmail.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com> Co-authored-by: AkaraChen <85140972+AkaraChen@users.noreply.github.com> Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: Novice <857526207@qq.com> Co-authored-by: Hiroki Nagai <82458324+nagaihiroki-git@users.noreply.github.com> Co-authored-by: Gen Sato <52241300+halogen22@users.noreply.github.com> Co-authored-by: eux <euxuuu@gmail.com> Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com> Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com> Co-authored-by: lotsik <lotsik@mail.ru> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: gakkiyomi <gakkiyomi@aliyun.com> Co-authored-by: CN-P5 <heibai2006@gmail.com> Co-authored-by: CN-P5 <heibai2006@qq.com> Co-authored-by: Chuehnone <1897025+chuehnone@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Boris Feld <lothiraldan@gmail.com> Co-authored-by: mbo <himabo@gmail.com> Co-authored-by: mabo <mabo@aeyes.ai> Co-authored-by: Warren Chen <warren.chen830@gmail.com> Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com> Co-authored-by: jiandanfeng <chenjh3@wangsu.com> Co-authored-by: zhu-an <70234959+xhdd123321@users.noreply.github.com> Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com> Co-authored-by: 海狸大師 <86974027+yenslife@users.noreply.github.com> Co-authored-by: Xu Song <xusong.vip@gmail.com> Co-authored-by: rayshaw001 <396301947@163.com> Co-authored-by: Ding Jiatong <dingjiatong@gmail.com> Co-authored-by: Bowen Liang <liangbowen@gf.com.cn> Co-authored-by: JasonVV <jasonwangiii@outlook.com> Co-authored-by: le0zh <newlight@qq.com> Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com> Co-authored-by: k-zaku <zaku99@outlook.jp> Co-authored-by: luckylhb90 <luckylhb90@gmail.com> Co-authored-by: hobo.l <hobo.l@binance.com> Co-authored-by: jiangbo721 <365065261@qq.com> Co-authored-by: 刘江波 <jiangbo721@163.com> Co-authored-by: Shun Miyazawa <34241526+miya@users.noreply.github.com> Co-authored-by: EricPan <30651140+Egfly@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: sino <sino2322@gmail.com> Co-authored-by: Jhvcc <37662342+Jhvcc@users.noreply.github.com> Co-authored-by: lowell <lowell.hu@zkteco.in> Co-authored-by: Boris Polonsky <BorisPolonsky@users.noreply.github.com> Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com> Co-authored-by: Ademílson Tonato <ademilson.tonato@refurbed.com> Co-authored-by: IWAI, Masaharu <iwaim.sub@gmail.com> Co-authored-by: Yueh-Po Peng (Yabi) <94939112+y10ab1@users.noreply.github.com> Co-authored-by: Jason <ggbbddjm@gmail.com> Co-authored-by: Xin Zhang <sjhpzx@gmail.com> Co-authored-by: yjc980121 <3898524+yjc980121@users.noreply.github.com> Co-authored-by: heyszt <36215648+hieheihei@users.noreply.github.com> Co-authored-by: Abdullah AlOsaimi <osaimiacc@gmail.com> Co-authored-by: Abdullah AlOsaimi <189027247+osaimi@users.noreply.github.com> Co-authored-by: Yingchun Lai <laiyingchun@apache.org> Co-authored-by: Hash Brown <hi@xzd.me> Co-authored-by: zuodongxu <192560071+zuodongxu@users.noreply.github.com> Co-authored-by: Masashi Tomooka <tmokmss@users.noreply.github.com> Co-authored-by: aplio <ryo.091219@gmail.com> Co-authored-by: Obada Khalili <54270856+obadakhalili@users.noreply.github.com> Co-authored-by: Nam Vu <zuzoovn@gmail.com> Co-authored-by: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com> Co-authored-by: TechnoHouse <13776377+deephbz@users.noreply.github.com> Co-authored-by: Riddhimaan-Senapati <114703025+Riddhimaan-Senapati@users.noreply.github.com> Co-authored-by: MaFee921 <31881301+2284730142@users.noreply.github.com> Co-authored-by: te-chan <t-nakanome@sakura-is.co.jp> Co-authored-by: HQidea <HQidea@users.noreply.github.com> Co-authored-by: Joshbly <36315710+Joshbly@users.noreply.github.com> Co-authored-by: xhe <xw897002528@gmail.com> Co-authored-by: weiwenyan-dev <154779315+weiwenyan-dev@users.noreply.github.com> Co-authored-by: ex_wenyan.wei <ex_wenyan.wei@tcl.com> Co-authored-by: engchina <12236799+engchina@users.noreply.github.com> Co-authored-by: engchina <atjapan2015@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: 呆萌闷油瓶 <253605712@qq.com> Co-authored-by: Kemal <kemalmeler@outlook.com> Co-authored-by: Lazy_Frog <4590648+lazyFrogLOL@users.noreply.github.com> Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Co-authored-by: Steven sun <98230804+Tuyohai@users.noreply.github.com> Co-authored-by: steven <sunzwj@digitalchina.com> Co-authored-by: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com> Co-authored-by: Katy Tao <34019945+KatyTao@users.noreply.github.com> Co-authored-by: depy <42985524+h4ckdepy@users.noreply.github.com> Co-authored-by: 胡春东 <gycm520@gmail.com> Co-authored-by: Junjie.M <118170653@qq.com> Co-authored-by: MuYu <mr.muzea@gmail.com> Co-authored-by: Naoki Takashima <39912547+takatea@users.noreply.github.com> Co-authored-by: Summer-Gu <37869445+gubinjie@users.noreply.github.com> Co-authored-by: Fei He <droxer.he@gmail.com> Co-authored-by: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Co-authored-by: Yuanbo Li <ybalbert@amazon.com> Co-authored-by: douxc <7553076+douxc@users.noreply.github.com> Co-authored-by: liuzhenghua <1090179900@qq.com> Co-authored-by: Wu Jiayang <62842862+Wu-Jiayang@users.noreply.github.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: kimjion <45935338+kimjion@users.noreply.github.com> Co-authored-by: AugNSo <song.tiankai@icloud.com> Co-authored-by: llinvokerl <38915183+llinvokerl@users.noreply.github.com> Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com> Co-authored-by: Vasu Negi <vasu-negi@users.noreply.github.com> Co-authored-by: Hundredwz <1808096180@qq.com> Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
@@ -32,19 +31,16 @@ from core.model_runtime.entities import (
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolRuntimeVariablePool,
|
||||
)
|
||||
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||
from models.tools import ToolConversationVariables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -62,11 +58,9 @@ class BaseAgentRunner(AppRunner):
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
user_id: str,
|
||||
model_instance: ModelInstance,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance,
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.application_generate_entity = application_generate_entity
|
||||
@@ -79,8 +73,6 @@ class BaseAgentRunner(AppRunner):
|
||||
self.user_id = user_id
|
||||
self.memory = memory
|
||||
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
|
||||
self.variables_pool = variables_pool
|
||||
self.db_variables_pool = db_variables
|
||||
self.model_instance = model_instance
|
||||
|
||||
# init callback
|
||||
@@ -141,11 +133,10 @@ class BaseAgentRunner(AppRunner):
|
||||
agent_tool=tool,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
)
|
||||
tool_entity.load_variables(self.variables_pool)
|
||||
|
||||
assert tool_entity.entity.description
|
||||
message_tool = PromptMessageTool(
|
||||
name=tool.tool_name,
|
||||
description=tool_entity.description.llm if tool_entity.description else "",
|
||||
description=tool_entity.entity.description.llm,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@@ -153,7 +144,7 @@ class BaseAgentRunner(AppRunner):
|
||||
},
|
||||
)
|
||||
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
@@ -186,9 +177,11 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
convert dataset retriever tool to prompt message tool
|
||||
"""
|
||||
assert tool.entity.description
|
||||
|
||||
prompt_tool = PromptMessageTool(
|
||||
name=tool.identity.name if tool.identity else "unknown",
|
||||
description=tool.description.llm if tool.description else "",
|
||||
name=tool.entity.identity.name,
|
||||
description=tool.entity.description.llm,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@@ -234,8 +227,7 @@ class BaseAgentRunner(AppRunner):
|
||||
# save prompt tool
|
||||
prompt_messages_tools.append(prompt_tool)
|
||||
# save tool entity
|
||||
if dataset_tool.identity is not None:
|
||||
tool_instances[dataset_tool.identity.name] = dataset_tool
|
||||
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
|
||||
|
||||
return tool_instances, prompt_messages_tools
|
||||
|
||||
@@ -320,24 +312,24 @@ class BaseAgentRunner(AppRunner):
|
||||
def save_agent_thought(
|
||||
self,
|
||||
agent_thought: MessageAgentThought,
|
||||
tool_name: str,
|
||||
tool_input: Union[str, dict],
|
||||
thought: str,
|
||||
tool_name: str | None,
|
||||
tool_input: Union[str, dict, None],
|
||||
thought: str | None,
|
||||
observation: Union[str, dict, None],
|
||||
tool_invoke_meta: Union[str, dict, None],
|
||||
answer: str,
|
||||
answer: str | None,
|
||||
messages_ids: list[str],
|
||||
llm_usage: LLMUsage | None = None,
|
||||
):
|
||||
"""
|
||||
Save agent thought
|
||||
"""
|
||||
queried_thought = (
|
||||
updated_agent_thought = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
|
||||
)
|
||||
if not queried_thought:
|
||||
raise ValueError(f"Agent thought {agent_thought.id} not found")
|
||||
agent_thought = queried_thought
|
||||
if not updated_agent_thought:
|
||||
raise ValueError("agent thought not found")
|
||||
agent_thought = updated_agent_thought
|
||||
|
||||
if thought:
|
||||
agent_thought.thought = thought
|
||||
@@ -349,39 +341,39 @@ class BaseAgentRunner(AppRunner):
|
||||
if isinstance(tool_input, dict):
|
||||
try:
|
||||
tool_input = json.dumps(tool_input, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
tool_input = json.dumps(tool_input)
|
||||
|
||||
agent_thought.tool_input = tool_input
|
||||
updated_agent_thought.tool_input = tool_input
|
||||
|
||||
if observation:
|
||||
if isinstance(observation, dict):
|
||||
try:
|
||||
observation = json.dumps(observation, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
observation = json.dumps(observation)
|
||||
|
||||
agent_thought.observation = observation
|
||||
updated_agent_thought.observation = observation
|
||||
|
||||
if answer:
|
||||
agent_thought.answer = answer
|
||||
|
||||
if messages_ids is not None and len(messages_ids) > 0:
|
||||
agent_thought.message_files = json.dumps(messages_ids)
|
||||
updated_agent_thought.message_files = json.dumps(messages_ids)
|
||||
|
||||
if llm_usage:
|
||||
agent_thought.message_token = llm_usage.prompt_tokens
|
||||
agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||
agent_thought.answer_token = llm_usage.completion_tokens
|
||||
agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||
agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||
agent_thought.tokens = llm_usage.total_tokens
|
||||
agent_thought.total_price = llm_usage.total_price
|
||||
updated_agent_thought.message_token = llm_usage.prompt_tokens
|
||||
updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||
updated_agent_thought.answer_token = llm_usage.completion_tokens
|
||||
updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||
updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||
updated_agent_thought.tokens = llm_usage.total_tokens
|
||||
updated_agent_thought.total_price = llm_usage.total_price
|
||||
|
||||
# check if tool labels is not empty
|
||||
labels = agent_thought.tool_labels or {}
|
||||
tools = agent_thought.tool.split(";") if agent_thought.tool else []
|
||||
labels = updated_agent_thought.tool_labels or {}
|
||||
tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
|
||||
for tool in tools:
|
||||
if not tool:
|
||||
continue
|
||||
@@ -392,42 +384,20 @@ class BaseAgentRunner(AppRunner):
|
||||
else:
|
||||
labels[tool] = {"en_US": tool, "zh_Hans": tool}
|
||||
|
||||
agent_thought.tool_labels_str = json.dumps(labels)
|
||||
updated_agent_thought.tool_labels_str = json.dumps(labels)
|
||||
|
||||
if tool_invoke_meta is not None:
|
||||
if isinstance(tool_invoke_meta, dict):
|
||||
try:
|
||||
tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
tool_invoke_meta = json.dumps(tool_invoke_meta)
|
||||
|
||||
agent_thought.tool_meta_str = tool_invoke_meta
|
||||
updated_agent_thought.tool_meta_str = tool_invoke_meta
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
|
||||
"""
|
||||
convert tool variables to db variables
|
||||
"""
|
||||
queried_variables = (
|
||||
db.session.query(ToolConversationVariables)
|
||||
.filter(
|
||||
ToolConversationVariables.conversation_id == self.message.conversation_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not queried_variables:
|
||||
return
|
||||
|
||||
db_variables = queried_variables
|
||||
|
||||
db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize agent history
|
||||
@@ -464,11 +434,11 @@ class BaseAgentRunner(AppRunner):
|
||||
tool_call_response: list[ToolPromptMessage] = []
|
||||
try:
|
||||
tool_inputs = json.loads(agent_thought.tool_input)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
tool_inputs = {tool: {} for tool in tools}
|
||||
try:
|
||||
tool_responses = json.loads(agent_thought.observation)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
tool_responses = dict.fromkeys(tools, agent_thought.observation)
|
||||
|
||||
for tool in tools:
|
||||
@@ -515,7 +485,11 @@ class BaseAgentRunner(AppRunner):
|
||||
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||
if not files:
|
||||
return UserPromptMessage(content=message.query)
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
if message.app_model_config:
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
else:
|
||||
file_extra_config = None
|
||||
|
||||
if not file_extra_config:
|
||||
return UserPromptMessage(content=message.query)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
@@ -18,8 +18,8 @@ from core.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Message
|
||||
|
||||
@@ -27,11 +27,11 @@ from models.model import Message
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ["wenxin"]
|
||||
_historic_prompt_messages: list[PromptMessage] | None = None
|
||||
_agent_scratchpad: list[AgentScratchpadUnit] | None = None
|
||||
_instruction: str = "" # FIXME this must be str for now
|
||||
_query: str | None = None
|
||||
_prompt_messages_tools: list[PromptMessageTool] = []
|
||||
_historic_prompt_messages: list[PromptMessage]
|
||||
_agent_scratchpad: list[AgentScratchpadUnit]
|
||||
_instruction: str
|
||||
_query: str
|
||||
_prompt_messages_tools: Sequence[PromptMessageTool]
|
||||
|
||||
def run(
|
||||
self,
|
||||
@@ -42,6 +42,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
"""
|
||||
Run Cot agent application
|
||||
"""
|
||||
|
||||
app_generate_entity = self.application_generate_entity
|
||||
self._repack_app_generate_entity(app_generate_entity)
|
||||
self._init_react_state(query)
|
||||
@@ -54,17 +55,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
app_generate_entity.model_conf.stop.append("Observation")
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config.agent
|
||||
|
||||
# init instruction
|
||||
inputs = inputs or {}
|
||||
instruction = app_config.prompt_template.simple_prompt_template
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs)
|
||||
instruction = app_config.prompt_template.simple_prompt_template or ""
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
self._prompt_messages_tools = prompt_messages_tools
|
||||
|
||||
function_call_state = True
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
@@ -116,14 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
if not isinstance(chunks, Generator):
|
||||
raise ValueError("Expected streaming response from LLM")
|
||||
|
||||
# check llm result
|
||||
if not chunks:
|
||||
raise ValueError("failed to invoke llm")
|
||||
|
||||
usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
usage_dict: dict[str, Optional[LLMUsage]] = {}
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
@@ -143,25 +139,25 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
action = chunk
|
||||
# detect action
|
||||
if scratchpad.agent_response is not None:
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
scratchpad.action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.action = action
|
||||
else:
|
||||
if scratchpad.agent_response is not None:
|
||||
scratchpad.agent_response += chunk
|
||||
if scratchpad.thought is not None:
|
||||
scratchpad.thought += chunk
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += chunk
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought += chunk
|
||||
yield LLMResultChunk(
|
||||
model=self.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint="",
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
|
||||
)
|
||||
if scratchpad.thought is not None:
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
if self._agent_scratchpad is not None:
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
# get llm usage
|
||||
if "usage" in usage_dict:
|
||||
@@ -256,8 +252,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
)
|
||||
if self.variables_pool is not None and self.db_variables_pool is not None:
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
@@ -275,7 +269,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
def _handle_invoke_action(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
tool_instances: dict[str, Tool],
|
||||
tool_instances: Mapping[str, Tool],
|
||||
message_file_ids: list[str],
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> tuple[str, ToolInvokeMeta]:
|
||||
@@ -315,11 +309,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
)
|
||||
|
||||
# publish files
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as is not None and self.variables_pool:
|
||||
# FIXME the save_as type is confusing, it should be a string or not
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))
|
||||
|
||||
for message_file_id in message_files:
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
@@ -342,7 +332,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
for key, value in inputs.items():
|
||||
try:
|
||||
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return instruction
|
||||
@@ -379,7 +369,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
return message
|
||||
|
||||
def _organize_historic_prompt_messages(
|
||||
self, current_session_messages: Optional[list[PromptMessage]] = None
|
||||
self, current_session_messages: list[PromptMessage] | None = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
organize historic prompt messages
|
||||
@@ -391,8 +381,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
for message in self.history_prompt_messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
if not current_scratchpad:
|
||||
if not isinstance(message.content, str | None):
|
||||
raise NotImplementedError("expected str type")
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
@@ -411,9 +400,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
except:
|
||||
pass
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if not current_scratchpad:
|
||||
continue
|
||||
if isinstance(message.content, str):
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad.observation = message.content
|
||||
else:
|
||||
raise NotImplementedError("expected str type")
|
||||
|
||||
@@ -19,8 +19,8 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
"""
|
||||
Organize system prompt
|
||||
"""
|
||||
if not self.app_config.agent:
|
||||
raise ValueError("Agent configuration is not set")
|
||||
assert self.app_config.agent
|
||||
assert self.app_config.agent.prompt
|
||||
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
if not prompt_entity:
|
||||
@@ -83,8 +83,10 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Thought: {unit.thought}\n\n"
|
||||
if unit.action_str:
|
||||
assistant_message.content += f"Action: {unit.action_str}\n\n"
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
|
||||
|
||||
class AgentToolEntity(BaseModel):
|
||||
"""
|
||||
Agent Tool Entity.
|
||||
"""
|
||||
|
||||
provider_type: Literal["builtin", "api", "workflow"]
|
||||
provider_type: ToolProviderType
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any] = {}
|
||||
plugin_unique_identifier: str | None = None
|
||||
|
||||
|
||||
class AgentPromptEntity(BaseModel):
|
||||
@@ -66,7 +69,7 @@ class AgentEntity(BaseModel):
|
||||
Agent Entity.
|
||||
"""
|
||||
|
||||
class Strategy(Enum):
|
||||
class Strategy(StrEnum):
|
||||
"""
|
||||
Agent Strategy.
|
||||
"""
|
||||
@@ -78,5 +81,13 @@ class AgentEntity(BaseModel):
|
||||
model: str
|
||||
strategy: Strategy
|
||||
prompt: Optional[AgentPromptEntity] = None
|
||||
tools: list[AgentToolEntity] | None = None
|
||||
tools: Optional[list[AgentToolEntity]] = None
|
||||
max_iteration: int = 5
|
||||
|
||||
|
||||
class AgentInvokeMessage(ToolInvokeMessage):
|
||||
"""
|
||||
Agent Invoke Message.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -46,18 +46,20 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
final_answer = ""
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
@@ -107,7 +109,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
current_llm_usage = None
|
||||
|
||||
if self.stream_tool_call and isinstance(chunks, Generator):
|
||||
if isinstance(chunks, Generator):
|
||||
is_first_chunk = True
|
||||
for chunk in chunks:
|
||||
if is_first_chunk:
|
||||
@@ -124,7 +126,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
except json.JSONDecodeError:
|
||||
# ensure ascii to avoid encoding error
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
@@ -140,7 +142,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
current_llm_usage = chunk.delta.usage
|
||||
|
||||
yield chunk
|
||||
elif not self.stream_tool_call and isinstance(chunks, LLMResult):
|
||||
else:
|
||||
result = chunks
|
||||
# check if there is any tool call
|
||||
if self.check_blocking_tool_calls(result):
|
||||
@@ -151,7 +153,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
except json.JSONDecodeError:
|
||||
# ensure ascii to avoid encoding error
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
@@ -183,8 +185,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
usage=result.usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"invalid chunks type: {type(chunks)}")
|
||||
|
||||
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
if tool_calls:
|
||||
@@ -243,15 +243,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
message_id=self.message.id,
|
||||
conversation_id=self.conversation.id,
|
||||
)
|
||||
# publish files
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as:
|
||||
if self.variables_pool:
|
||||
self.variables_pool.set_file(
|
||||
tool_name=tool_call_name, value=message_file_id, name=save_as
|
||||
)
|
||||
|
||||
for message_file_id in message_files:
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
@@ -303,8 +300,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
if self.variables_pool and self.db_variables_pool:
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
@@ -335,9 +330,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
return True
|
||||
return False
|
||||
|
||||
def extract_tool_calls(
|
||||
self, llm_result_chunk: LLMResultChunk
|
||||
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""
|
||||
Extract tool calls from llm result chunk
|
||||
|
||||
@@ -360,7 +353,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
return tool_calls
|
||||
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""
|
||||
Extract blocking tool calls from llm result
|
||||
|
||||
@@ -383,9 +376,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _init_system_message(
|
||||
self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
|
||||
) -> list[PromptMessage]:
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
|
||||
89
api/core/agent/plugin_entities.py
Normal file
89
api/core/agent/plugin_entities.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.entities.parameter_entities import CommonParameterType
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
as_normal_type,
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
ToolProviderIdentity,
|
||||
)
|
||||
|
||||
|
||||
class AgentStrategyProviderIdentity(ToolProviderIdentity):
|
||||
"""
|
||||
Inherits from ToolProviderIdentity, without any additional fields.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentStrategyParameter(PluginParameter):
|
||||
class AgentStrategyParameterType(enum.StrEnum):
|
||||
"""
|
||||
Keep all the types from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = CommonParameterType.STRING.value
|
||||
NUMBER = CommonParameterType.NUMBER.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
FILE = CommonParameterType.FILE.value
|
||||
FILES = CommonParameterType.FILES.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
|
||||
|
||||
def init_frontend_parameter(self, value: Any):
|
||||
return init_frontend_parameter(self, self.type, value)
|
||||
|
||||
|
||||
class AgentStrategyProviderEntity(BaseModel):
|
||||
identity: AgentStrategyProviderIdentity
|
||||
plugin_id: Optional[str] = Field(None, description="The id of the plugin")
|
||||
|
||||
|
||||
class AgentStrategyIdentity(ToolIdentity):
|
||||
"""
|
||||
Inherits from ToolIdentity, without any additional fields.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentStrategyEntity(BaseModel):
|
||||
identity: AgentStrategyIdentity
|
||||
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
|
||||
description: I18nObject = Field(..., description="The description of the agent strategy")
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]:
|
||||
return v or []
|
||||
|
||||
|
||||
class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity):
|
||||
strategies: list[AgentStrategyEntity] = Field(default_factory=list)
|
||||
42
api/core/agent/strategy/base.py
Normal file
42
api/core/agent/strategy/base.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
|
||||
|
||||
class BaseAgentStrategy(ABC):
|
||||
"""
|
||||
Agent Strategy
|
||||
"""
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the agent strategy.
|
||||
"""
|
||||
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
"""
|
||||
Get the parameters for the agent strategy.
|
||||
"""
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
pass
|
||||
59
api/core/agent/strategy/plugin.py
Normal file
59
api/core/agent/strategy/plugin.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
||||
from core.agent.strategy.base import BaseAgentStrategy
|
||||
from core.plugin.manager.agent import PluginAgentManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
|
||||
|
||||
class PluginAgentStrategy(BaseAgentStrategy):
|
||||
"""
|
||||
Agent Strategy
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
declaration: AgentStrategyEntity
|
||||
|
||||
def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
|
||||
self.tenant_id = tenant_id
|
||||
self.declaration = declaration
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
return self.declaration.parameters
|
||||
|
||||
def initialize_parameters(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Initialize the parameters for the agent strategy.
|
||||
"""
|
||||
for parameter in self.declaration.parameters:
|
||||
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
|
||||
return params
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the agent strategy.
|
||||
"""
|
||||
manager = PluginAgentManager()
|
||||
|
||||
initialized_params = self.initialize_parameters(params)
|
||||
params = convert_parameters_to_plugin_format(initialized_params)
|
||||
|
||||
yield from manager.invoke(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
agent_provider=self.declaration.identity.provider,
|
||||
agent_strategy=self.declaration.identity.name,
|
||||
agent_params=params,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
Reference in New Issue
Block a user