mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
feat: regenerate in Chat, agent and Chatflow app (#7661)
This commit is contained in:
@@ -1 +1,2 @@
|
||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
@@ -109,6 +109,7 @@ class ChatMessageApi(Resource):
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -105,8 +105,6 @@ class ChatMessageListApi(Resource):
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
||||
|
||||
|
||||
|
||||
@@ -166,6 +166,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
parser.add_argument("query", type=str, required=True, location="json", default="")
|
||||
parser.add_argument("files", type=list, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
@@ -100,6 +100,7 @@ class ChatApi(InstalledAppResource):
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ class MessageListApi(InstalledAppResource):
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@@ -54,6 +54,7 @@ class MessageListApi(Resource):
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
|
||||
@@ -96,6 +96,7 @@ class ChatApi(WebApiResource):
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -57,6 +57,7 @@ class MessageListApi(WebApiResource):
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
@@ -89,7 +90,7 @@ class MessageListApi(WebApiResource):
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@@ -32,6 +32,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolRuntimeVariablePool,
|
||||
@@ -441,10 +442,12 @@ class BaseAgentRunner(AppRunner):
|
||||
.filter(
|
||||
Message.conversation_id == self.message.conversation_id,
|
||||
)
|
||||
.order_by(Message.created_at.asc())
|
||||
.order_by(Message.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
messages = list(reversed(extract_thread_messages(messages)))
|
||||
|
||||
for message in messages:
|
||||
if message.id == self.message.id:
|
||||
continue
|
||||
|
||||
@@ -121,6 +121,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
||||
@@ -127,6 +127,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
||||
@@ -128,6 +128,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
||||
@@ -218,6 +218,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency="USD",
|
||||
|
||||
@@ -122,6 +122,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
"""
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
|
||||
|
||||
class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
@@ -138,6 +139,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
"""
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
||||
@@ -149,6 +151,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
query: str
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
|
||||
@@ -11,6 +11,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from models.model import AppMode, Conversation, Message, MessageFile
|
||||
from models.workflow import WorkflowRun
|
||||
@@ -33,8 +34,17 @@ class TokenBufferMemory:
|
||||
|
||||
# fetch limited messages, and return reversed
|
||||
query = (
|
||||
db.session.query(Message.id, Message.query, Message.answer, Message.created_at, Message.workflow_run_id)
|
||||
.filter(Message.conversation_id == self.conversation.id, Message.answer != "")
|
||||
db.session.query(
|
||||
Message.id,
|
||||
Message.query,
|
||||
Message.answer,
|
||||
Message.created_at,
|
||||
Message.workflow_run_id,
|
||||
Message.parent_message_id,
|
||||
)
|
||||
.filter(
|
||||
Message.conversation_id == self.conversation.id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
|
||||
@@ -45,7 +55,12 @@ class TokenBufferMemory:
|
||||
|
||||
messages = query.limit(message_limit).all()
|
||||
|
||||
messages = list(reversed(messages))
|
||||
# instead of all messages from the conversation, we only need to extract messages
|
||||
# that belong to the thread of last message
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
thread_messages.pop(0)
|
||||
messages = list(reversed(thread_messages))
|
||||
|
||||
message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id)
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
|
||||
22
api/core/prompt/utils/extract_thread_messages.py
Normal file
22
api/core/prompt/utils/extract_thread_messages.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from constants import UUID_NIL
|
||||
|
||||
|
||||
def extract_thread_messages(messages: list[dict]) -> list[dict]:
|
||||
thread_messages = []
|
||||
next_message = None
|
||||
|
||||
for message in messages:
|
||||
if not message.parent_message_id:
|
||||
# If the message is regenerated and does not have a parent message, it is the start of a new thread
|
||||
thread_messages.append(message)
|
||||
break
|
||||
|
||||
if not next_message:
|
||||
thread_messages.append(message)
|
||||
next_message = message.parent_message_id
|
||||
else:
|
||||
if next_message in {message.id, UUID_NIL}:
|
||||
thread_messages.append(message)
|
||||
next_message = message.parent_message_id
|
||||
|
||||
return thread_messages
|
||||
@@ -75,6 +75,7 @@ message_detail_fields = {
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
}
|
||||
|
||||
feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
|
||||
|
||||
@@ -62,6 +62,7 @@ retriever_resource_fields = {
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
"""add parent_message_id to messages
|
||||
|
||||
Revision ID: d57ba9ebb251
|
||||
Revises: 675b5321501b
|
||||
Create Date: 2024-09-11 10:12:45.826265
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'd57ba9ebb251'
|
||||
down_revision = '675b5321501b'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('messages', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('parent_message_id', models.types.StringUUID(), nullable=True))
|
||||
|
||||
# Set parent_message_id for existing messages to uuid_nil() to distinguish them from new messages with actual parent IDs or NULLs
|
||||
op.execute('UPDATE messages SET parent_message_id = uuid_nil() WHERE parent_message_id IS NULL')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('messages', schema=None) as batch_op:
|
||||
batch_op.drop_column('parent_message_id')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -710,6 +710,7 @@ class Message(db.Model):
|
||||
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
parent_message_id = db.Column(StringUUID, nullable=True)
|
||||
provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0"))
|
||||
total_price = db.Column(db.Numeric(10, 7))
|
||||
currency = db.Column(db.String(255), nullable=False)
|
||||
|
||||
@@ -34,6 +34,7 @@ class MessageService:
|
||||
conversation_id: str,
|
||||
first_id: Optional[str],
|
||||
limit: int,
|
||||
order: str = "asc",
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
@@ -91,7 +92,8 @@ class MessageService:
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
if order == "asc":
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
|
||||
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
|
||||
|
||||
class TestMessage:
|
||||
def __init__(self, id, parent_message_id):
|
||||
self.id = id
|
||||
self.parent_message_id = parent_message_id
|
||||
|
||||
def __getitem__(self, item):
|
||||
return getattr(self, item)
|
||||
|
||||
|
||||
def test_extract_thread_messages_single_message():
|
||||
messages = [TestMessage(str(uuid4()), UUID_NIL)]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 1
|
||||
assert result[0] == messages[0]
|
||||
|
||||
|
||||
def test_extract_thread_messages_linear_thread():
|
||||
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id5, id4),
|
||||
TestMessage(id4, id3),
|
||||
TestMessage(id3, id2),
|
||||
TestMessage(id2, id1),
|
||||
TestMessage(id1, UUID_NIL),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 5
|
||||
assert [msg["id"] for msg in result] == [id5, id4, id3, id2, id1]
|
||||
|
||||
|
||||
def test_extract_thread_messages_branched_thread():
|
||||
id1, id2, id3, id4 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id4, id2),
|
||||
TestMessage(id3, id2),
|
||||
TestMessage(id2, id1),
|
||||
TestMessage(id1, UUID_NIL),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert [msg["id"] for msg in result] == [id4, id2, id1]
|
||||
|
||||
|
||||
def test_extract_thread_messages_empty_list():
|
||||
messages = []
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
def test_extract_thread_messages_partially_loaded():
|
||||
id0, id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id3, id2),
|
||||
TestMessage(id2, id1),
|
||||
TestMessage(id1, id0),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert [msg["id"] for msg in result] == [id3, id2, id1]
|
||||
|
||||
|
||||
def test_extract_thread_messages_legacy_messages():
|
||||
id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id3, UUID_NIL),
|
||||
TestMessage(id2, UUID_NIL),
|
||||
TestMessage(id1, UUID_NIL),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert [msg["id"] for msg in result] == [id3, id2, id1]
|
||||
|
||||
|
||||
def test_extract_thread_messages_mixed_with_legacy_messages():
|
||||
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id5, id4),
|
||||
TestMessage(id4, id2),
|
||||
TestMessage(id3, id2),
|
||||
TestMessage(id2, UUID_NIL),
|
||||
TestMessage(id1, UUID_NIL),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert [msg["id"] for msg in result] == [id5, id4, id2, id1]
|
||||
Reference in New Issue
Block a user