mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 19:06:51 +08:00
Initial commit
This commit is contained in:
@@ -0,0 +1,77 @@
|
||||
from typing import Any, List, Dict, Union
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
|
||||
|
||||
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
|
||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message
|
||||
|
||||
|
||||
class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
|
||||
conversation: Conversation
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
llm: Union[StreamableChatOpenAI | StreamableOpenAI]
|
||||
memory_key: str = "chat_history"
|
||||
max_token_limit: int = 2000
|
||||
message_limit: int = 10
|
||||
|
||||
@property
|
||||
def buffer(self) -> List[BaseMessage]:
|
||||
"""String buffer of memory."""
|
||||
# fetch limited messages desc, and return reversed
|
||||
messages = db.session.query(Message).filter(
|
||||
Message.conversation_id == self.conversation.id,
|
||||
Message.answer_tokens > 0
|
||||
).order_by(Message.created_at.desc()).limit(self.message_limit).all()
|
||||
|
||||
messages = list(reversed(messages))
|
||||
|
||||
chat_messages: List[BaseMessage] = []
|
||||
for message in messages:
|
||||
chat_messages.append(HumanMessage(content=message.query))
|
||||
chat_messages.append(AIMessage(content=message.answer))
|
||||
|
||||
if not chat_messages:
|
||||
return chat_messages
|
||||
|
||||
# prune the chat message if it exceeds the max token limit
|
||||
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit and chat_messages:
|
||||
pruned_memory.append(chat_messages.pop(0))
|
||||
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
|
||||
|
||||
return chat_messages
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
buffer: Any = self.buffer
|
||||
if self.return_messages:
|
||||
final_buffer: Any = buffer
|
||||
else:
|
||||
final_buffer = get_buffer_string(
|
||||
buffer,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
return {self.memory_key: final_buffer}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Nothing should be saved or changed"""
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear, got a memory like a vault."""
|
||||
pass
|
||||
@@ -0,0 +1,36 @@
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel
|
||||
|
||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
|
||||
|
||||
class ReadOnlyConversationTokenDBStringBufferSharedMemory(BaseChatMemory):
|
||||
memory: ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Return memory variables."""
|
||||
return self.memory.memory_variables
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Load memory variables from memory."""
|
||||
buffer: Any = self.memory.buffer
|
||||
|
||||
final_buffer = get_buffer_string(
|
||||
buffer,
|
||||
human_prefix=self.memory.human_prefix,
|
||||
ai_prefix=self.memory.ai_prefix,
|
||||
)
|
||||
|
||||
return {self.memory.memory_key: final_buffer}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Nothing should be saved or changed"""
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear, got a memory like a vault."""
|
||||
pass
|
||||
Reference in New Issue
Block a user