Initial commit
52  api/core/__init__.py  Normal file
@@ -0,0 +1,52 @@
import os
from typing import Optional

import langchain
from flask import Flask
from jieba.analyse import default_tfidf
from langchain import set_handler
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
from llama_index import IndexStructType, QueryMode
from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
from pydantic import BaseModel

from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.index.keyword_table.stopwords import STOPWORDS
from core.prompt.prompt_template import OneLineFormatter
from core.vector_store.vector_store import VectorStore
from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery


class HostedOpenAICredential(BaseModel):
    api_key: str


class HostedLLMCredentials(BaseModel):
    openai: Optional[HostedOpenAICredential] = None


hosted_llm_credentials = HostedLLMCredentials()


def init_app(app: Flask):
    formatter = OneLineFormatter()
    DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
    }
    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
    }

    default_tfidf.stop_words = STOPWORDS

    if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
        langchain.verbose = True
        set_handler(DifyStdOutCallbackHandler())

    if app.config.get("OPENAI_API_KEY"):
        hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
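Usage sketch (not part of the commit): how this module's init_app hook might be wired into a Flask application factory. The create_app factory and config loading below are assumptions for illustration.

from flask import Flask

from core import init_app as init_core


def create_app() -> Flask:
    # hypothetical app factory; the config source is an assumption
    app = Flask(__name__)
    app.config.from_prefixed_env()  # Flask >= 2.1
    init_core(app)  # patches prompt formatter/query maps, reads OPENAI_API_KEY
    return app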
89  api/core/agent/agent_builder.py  Normal file
@@ -0,0 +1,89 @@
from typing import Optional

from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
from langchain.callbacks import CallbackManager
from langchain.memory.chat_memory import BaseChatMemory

from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.llm.llm_builder import LLMBuilder


class AgentBuilder:
    @classmethod
    def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
                       dataset_tool_callback_handler: DatasetToolCallbackHandler,
                       agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
        llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=agent_loop_gather_callback_handler.model_name,
            temperature=0,
            max_tokens=1024,
            callback_manager=llm_callback_manager
        )

        tool_callback_manager = CallbackManager([
            agent_loop_gather_callback_handler,
            dataset_tool_callback_handler,
            DifyStdOutCallbackHandler()
        ])

        for tool in tools:
            tool.callback_manager = tool_callback_manager

        prompt = cls.build_agent_prompt_template(
            tools=tools,
            memory=memory,
        )

        agent_llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
        )

        agent = cls.build_agent(agent_llm_chain=agent_llm_chain, memory=memory)

        agent_callback_manager = CallbackManager(
            [agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
        )

        agent_chain = AgentExecutor.from_agent_and_tools(
            tools=tools,
            agent=agent,
            memory=memory,
            callback_manager=agent_callback_manager,
            max_iterations=6,
            early_stopping_method="generate",
            # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
        )

        return agent_chain

    @classmethod
    def build_agent_prompt_template(cls, tools, memory: Optional[BaseChatMemory]):
        if memory:
            prompt = ConversationalAgent.create_prompt(
                tools=tools,
            )
        else:
            prompt = ZeroShotAgent.create_prompt(
                tools=tools,
            )

        return prompt

    @classmethod
    def build_agent(cls, agent_llm_chain: LLMChain, memory: Optional[BaseChatMemory]):
        if memory:
            agent = ConversationalAgent(
                llm_chain=agent_llm_chain
            )
        else:
            agent = ZeroShotAgent(
                llm_chain=agent_llm_chain
            )

        return agent
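Usage sketch (not part of the commit): how to_agent_chain might be invoked; the tenant id, tools, and conversation_message_task below are placeholders.

loop_handler = AgentLoopGatherCallbackHandler('text-davinci-003', conversation_message_task)
dataset_handler = DatasetToolCallbackHandler(conversation_message_task)

agent_chain = AgentBuilder.to_agent_chain(
    tenant_id='tenant-1',                  # placeholder
    tools=dataset_tools,                   # e.g. tools built by DatasetToolBuilder
    memory=None,                           # ZeroShotAgent is used when no memory is given
    dataset_tool_callback_handler=dataset_handler,
    agent_loop_gather_callback_handler=loop_handler,
)
answer = agent_chain.run("What does my dataset say about pricing?")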
178  api/core/callback_handler/agent_loop_gather_callback_handler.py  Normal file
@@ -0,0 +1,178 @@
import logging
import time

from typing import Any, Dict, List, Union, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult

from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask


class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
    """Callback handler that gathers agent loops (thought/tool/observation rounds)."""

    def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
        """Initialize callback handler."""
        self.model_name = model_name
        self.conversation_message_task = conversation_message_task
        self._agent_loops = []
        self._current_loop = None
        self.current_chain = None

    @property
    def agent_loops(self) -> List[AgentLoop]:
        return self._agent_loops

    def clear_agent_loops(self) -> None:
        self._agent_loops = []
        self._current_loop = None

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Start a new loop when the agent begins with an LLM query."""
        # serialized={'name': 'OpenAI'}
        # prompts=['Answer the following questions...\nThought:']
        # kwargs={}
        if not self._current_loop:
            # agent starts with an LLM query
            self._current_loop = AgentLoop(
                position=len(self._agent_loops) + 1,
                prompt=prompts[0],
                status='llm_started',
                started_at=time.perf_counter()
            )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Record the completion and token usage of the current loop."""
        # kwargs={}
        if self._current_loop and self._current_loop.status == 'llm_started':
            self._current_loop.status = 'llm_end'
            self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
            self._current_loop.completion = response.generations[0][0].text
            self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        logging.error(error)
        self._agent_loops = []
        self._current_loop = None

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        logging.error(error)

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        # kwargs={'color': 'green', 'llm_prefix': 'Thought:', 'observation_prefix': 'Observation: '}
        # input_str='action-input'
        # serialized={'description': 'A search engine. Useful for when you need to answer questions about current events. Input should be a search query.', 'name': 'Search'}
        pass

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Run on agent action."""
        tool = action.tool
        tool_input = action.tool_input
        action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
        thought = action.log[:action_name_position].strip() if action.log else ''

        if self._current_loop and self._current_loop.status == 'llm_end':
            self._current_loop.status = 'agent_action'
            self._current_loop.thought = thought
            self._current_loop.tool_name = tool
            self._current_loop.tool_input = tool_input

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Complete the current loop with the tool's observation."""
        # kwargs={'name': 'Search'}
        # llm_prefix='Thought:'
        # observation_prefix='Observation: '
        # output='53 years'

        if self._current_loop and self._current_loop.status == 'agent_action' and output and output != 'None':
            self._current_loop.status = 'tool_end'
            self._current_loop.tool_output = output
            self._current_loop.completed = True
            self._current_loop.completed_at = time.perf_counter()
            self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at

            self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)

            self._agent_loops.append(self._current_loop)
            self._current_loop = None

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        logging.error(error)
        self._agent_loops = []
        self._current_loop = None

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Optional[str],
    ) -> None:
        """Run on additional input from chains and agents."""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        """Run on agent end (final answer)."""
        if self._current_loop and (self._current_loop.status == 'llm_end' or self._current_loop.status == 'agent_action'):
            self._current_loop.status = 'agent_finish'
            self._current_loop.completed = True
            self._current_loop.completed_at = time.perf_counter()
            self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at

            self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)

            self._agent_loops.append(self._current_loop)
            self._current_loop = None
        elif not self._current_loop and self._agent_loops:
            self._agent_loops[-1].status = 'agent_finish'
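For orientation, the loop status lifecycle implemented above (reconstructed from the handler's code; not documented in the commit):

# AgentLoop.status transitions in AgentLoopGatherCallbackHandler:
#
#   'llm_started'   on_llm_start creates the loop
#   'llm_end'       on_llm_end records the completion + token usage
#   'agent_action'  on_agent_action records thought / tool / tool_input
#   'tool_end'      on_tool_end stores the observation and reports the loop
#   'agent_finish'  on_agent_finish closes the final loop (no tool call)
#
# Completed loops are flushed via conversation_message_task.on_agent_end();
# on_llm_error / on_tool_error discard all gathered loops.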
117  api/core/callback_handler/dataset_tool_callback_handler.py  Normal file
@@ -0,0 +1,117 @@
import logging

from typing import Any, Dict, List, Union, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult

from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.conversation_message_task import ConversationMessageTask


class DatasetToolCallbackHandler(BaseCallbackHandler):
    """Callback handler that records dataset tool queries."""

    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
        """Initialize callback handler."""
        self.queries = []
        self.conversation_message_task = conversation_message_task

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return True

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return True

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return False

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        # dataset tool names follow the "dataset-<dataset_id>" convention
        tool_name = serialized.get('name')
        dataset_id = tool_name[len("dataset-"):]
        self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str))

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        # kwargs={'name': 'Search'}
        # llm_prefix='Thought:'
        # observation_prefix='Observation: '
        # output='53 years'
        pass

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        logging.error(error)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        pass

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        pass

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        logging.error(error)

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        pass

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Optional[str],
    ) -> None:
        """Run on additional input from chains and agents."""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        """Run on agent end."""
        pass
23  api/core/callback_handler/entity/agent_loop.py  Normal file
@@ -0,0 +1,23 @@
from pydantic import BaseModel


class AgentLoop(BaseModel):
    position: int = 1

    thought: str = None
    tool_name: str = None
    tool_input: str = None
    tool_output: str = None

    prompt: str = None
    prompt_tokens: int = None
    completion: str = None
    completion_tokens: int = None

    latency: float = None

    status: str = 'llm_started'
    completed: bool = False

    started_at: float = None
    completed_at: float = None
16  api/core/callback_handler/entity/chain_result.py  Normal file
@@ -0,0 +1,16 @@
from pydantic import BaseModel


class ChainResult(BaseModel):
    type: str = None
    prompt: dict = None
    completion: dict = None

    status: str = 'chain_started'
    completed: bool = False

    started_at: float = None
    completed_at: float = None

    agent_result: dict = None
    """only when type is 'AgentExecutor'"""
6  api/core/callback_handler/entity/dataset_query.py  Normal file
@@ -0,0 +1,6 @@
from pydantic import BaseModel


class DatasetQueryObj(BaseModel):
    dataset_id: str = None
    query: str = None
9  api/core/callback_handler/entity/llm_message.py  Normal file
@@ -0,0 +1,9 @@
from pydantic import BaseModel


class LLMMessage(BaseModel):
    prompt: str = ''
    prompt_tokens: int = 0
    completion: str = ''
    completion_tokens: int = 0
    latency: float = 0.0
38  api/core/callback_handler/index_tool_callback_handler.py  Normal file
@@ -0,0 +1,38 @@
from llama_index import Response

from extensions.ext_database import db
from models.dataset import DocumentSegment


class IndexToolCallbackHandler:

    def __init__(self) -> None:
        self._response = None

    @property
    def response(self) -> Response:
        return self._response

    def on_tool_end(self, response: Response) -> None:
        """Handle tool end."""
        self._response = response


class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
    """Callback handler for dataset tool."""

    def __init__(self, dataset_id: str) -> None:
        super().__init__()
        self.dataset_id = dataset_id

    def on_tool_end(self, response: Response) -> None:
        """Handle tool end."""
        for node in response.source_nodes:
            index_node_id = node.node.doc_id

            # add hit count to document segment
            db.session.query(DocumentSegment).filter(
                DocumentSegment.dataset_id == self.dataset_id,
                DocumentSegment.index_node_id == index_node_id
            ).update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
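Note on the update above: the hit count is incremented inside the UPDATE statement itself rather than read-modify-written in Python, so concurrent queries cannot lose counts. A sketch of the difference (hypothetical, not part of the commit):

# racy read-modify-write (what the handler avoids):
#   segment = db.session.query(DocumentSegment).get(segment_id)
#   segment.hit_count += 1   # two concurrent requests may write the same value
#
# atomic server-side increment (what the handler does):
#   .update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
#           synchronize_session=False)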
147  api/core/callback_handler/llm_callback_handler.py  Normal file
@@ -0,0 +1,147 @@
import logging
import time
from typing import Any, Dict, List, Union, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage

from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI


class LLMCallbackHandler(BaseCallbackHandler):

    def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                 conversation_message_task: ConversationMessageTask):
        self.llm = llm
        self.llm_message = LLMMessage()
        self.start_at = None
        self.conversation_message_task = conversation_message_task

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        self.start_at = time.perf_counter()

        if 'Chat' in serialized['name']:
            real_prompts = []
            messages = []
            for prompt in prompts:
                role, content = prompt.split(': ', maxsplit=1)
                if role == 'human':
                    role = 'user'
                    message = HumanMessage(content=content)
                elif role == 'ai':
                    role = 'assistant'
                    message = AIMessage(content=content)
                else:
                    message = SystemMessage(content=content)

                real_prompt = {
                    "role": role,
                    "text": content
                }
                real_prompts.append(real_prompt)
                messages.append(message)

            self.llm_message.prompt = real_prompts
            self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
        else:
            self.llm_message.prompt = [{
                "role": 'user',
                "text": prompts[0]
            }]

            self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        end_at = time.perf_counter()
        self.llm_message.latency = end_at - self.start_at

        if not self.conversation_message_task.streaming:
            self.conversation_message_task.append_message_text(response.generations[0][0].text)
            self.llm_message.completion = response.generations[0][0].text
            self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
        else:
            self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)

        self.conversation_message_task.save_message(self.llm_message)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.conversation_message_task.append_message_text(token)
        self.llm_message.completion += token

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Save the partial message if the task was stopped; otherwise log the error."""
        if isinstance(error, ConversationTaskStoppedException):
            if self.conversation_message_task.streaming:
                end_at = time.perf_counter()
                self.llm_message.latency = end_at - self.start_at
                self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
                self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
        else:
            logging.error(error)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        pass

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        pass

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        pass

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        pass

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        pass

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Optional[str],
    ) -> None:
        pass

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        pass
137  api/core/callback_handler/main_chain_gather_callback_handler.py  Normal file
@@ -0,0 +1,137 @@
import logging
import time

from typing import Any, Dict, List, Union, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult

from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.entity.chain_result import ChainResult
from core.constant import llm_constant
from core.conversation_message_task import ConversationMessageTask


class MainChainGatherCallbackHandler(BaseCallbackHandler):
    """Callback handler that gathers the results of the main chain."""

    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
        """Initialize callback handler."""
        self._current_chain_result = None
        self._current_chain_message = None
        self.conversation_message_task = conversation_message_task
        self.agent_loop_gather_callback_handler = AgentLoopGatherCallbackHandler(
            llm_constant.agent_model_name,
            conversation_message_task
        )

    def clear_chain_results(self) -> None:
        self._current_chain_result = None
        self._current_chain_message = None
        self.agent_loop_gather_callback_handler.current_chain = None

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return True

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return True

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Begin tracking a chain run."""
        if not self._current_chain_result:
            self._current_chain_result = ChainResult(
                type=serialized['name'],
                prompt=inputs,
                started_at=time.perf_counter()
            )
            self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
            self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Finish tracking and persist the chain result."""
        if self._current_chain_result and self._current_chain_result.status == 'chain_started':
            self._current_chain_result.status = 'chain_ended'
            self._current_chain_result.completion = outputs
            self._current_chain_result.completed = True
            self._current_chain_result.completed_at = time.perf_counter()

            self.conversation_message_task.on_chain_end(self._current_chain_message, self._current_chain_result)

            self.clear_chain_results()

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        logging.error(error)
        self.clear_chain_results()

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        logging.error(error)

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        pass

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        pass

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        pass

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        logging.error(error)

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Optional[str],
    ) -> None:
        """Run on additional input from chains and agents."""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        """Run on agent end."""
        pass
127  api/core/callback_handler/std_out_callback_handler.py  Normal file
@@ -0,0 +1,127 @@
import sys
from typing import Any, Dict, List, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult


class DifyStdOutCallbackHandler(BaseCallbackHandler):
    """Callback Handler that prints to std out."""

    def __init__(self, color: Optional[str] = None) -> None:
        """Initialize callback handler."""
        self.color = color

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Print out the prompts."""
        print_text("\n[on_llm_start]\n", color='blue')

        if 'Chat' in serialized['name']:
            for prompt in prompts:
                print_text(prompt + "\n", color='blue')
        else:
            print_text(prompts[0] + "\n", color='blue')

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Print out the completion and llm output."""
        print_text("\n[on_llm_end]\nOutput: " + str(response.generations[0][0].text) + "\nllm_output: " + str(
            response.llm_output) + "\n", color='blue')

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Print out the LLM error."""
        print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Print out that we are entering a chain."""
        class_name = serialized["name"]
        print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Print out that we finished a chain."""
        print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Print out the chain error."""
        print_text("\n[on_chain_error]\nError: " + str(error) + "\n", color='pink')

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Print out the tool being started."""
        print_text("\n[on_tool_start] " + str(serialized), color='yellow')

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Run on agent action."""
        tool = action.tool
        tool_input = action.tool_input
        action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
        thought = action.log[:action_name_position].strip() if action.log else ''

        log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
        print_text("\n[on_agent_action]\n" + log + "\n", color='green')

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """If not the final action, print out observation."""
        print_text("\n[on_tool_end]\n", color='yellow')
        if observation_prefix:
            print_text(f"\n{observation_prefix}")
        print_text(output, color='yellow')
        if llm_prefix:
            print_text(f"\n{llm_prefix}")
        print_text("\n")

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Print out the tool error."""
        print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='yellow')

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Optional[str],
    ) -> None:
        """Print out additional input from chains and agents."""
        print_text("\n[on_text] " + text + "\n", color=color if color else self.color, end=end)

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Run on agent end."""
        print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")


class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
    """Callback handler for streaming. Only works with LLMs that support streaming."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        sys.stdout.write(token)
        sys.stdout.flush()
34  api/core/chain/chain_builder.py  Normal file
@@ -0,0 +1,34 @@
from typing import Optional

from langchain.callbacks import CallbackManager

from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain


class ChainBuilder:
    @classmethod
    def to_tool_chain(cls, tool, **kwargs) -> ToolChain:
        return ToolChain(
            tool=tool,
            input_key=kwargs.get('input_key', 'input'),
            output_key=kwargs.get('output_key', 'tool_output'),
            callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
        )

    @classmethod
    def to_sensitive_word_avoidance_chain(cls, tool_config: dict, **kwargs) -> Optional[SensitiveWordAvoidanceChain]:
        sensitive_words = tool_config.get("words", "")
        if tool_config.get("enabled", False) \
                and sensitive_words:
            return SensitiveWordAvoidanceChain(
                sensitive_words=sensitive_words.split(","),
                canned_response=tool_config.get("canned_response", ''),
                output_key="sensitive_word_avoidance_output",
                callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
                **kwargs
            )

        return None
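Usage sketch (not part of the commit): a plausible tool_config for the sensitive-word chain; the keys are the ones read above, the values are illustrative.

tool_config = {
    "enabled": True,
    "words": "foo,bar",                           # comma-separated blocklist
    "canned_response": "I can't help with that.",
}

chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
# returns None when disabled or when the word list is empty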
116  api/core/chain/main_chain_builder.py  Normal file
@@ -0,0 +1,116 @@
from typing import Optional, List

from langchain.callbacks import SharedCallbackManager
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory

from core.agent.agent_builder import AgentBuilder
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.chain.chain_builder import ChainBuilder
from core.constant import llm_constant
from core.conversation_message_task import ConversationMessageTask
from core.tool.dataset_tool_builder import DatasetToolBuilder


class MainChainBuilder:
    @classmethod
    def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
                                conversation_message_task: ConversationMessageTask):
        first_input_key = "input"
        final_output_key = "output"

        chains = []

        chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)

        # agent mode
        tool_chains, chains_output_key = cls.get_agent_chains(
            tenant_id=tenant_id,
            agent_mode=agent_mode,
            memory=memory,
            dataset_tool_callback_handler=DatasetToolCallbackHandler(conversation_message_task),
            agent_loop_gather_callback_handler=chain_callback_handler.agent_loop_gather_callback_handler
        )
        chains += tool_chains

        if chains_output_key:
            final_output_key = chains_output_key

        if len(chains) == 0:
            return None

        for chain in chains:
            # do not add the handler to the singleton callback manager
            if not isinstance(chain.callback_manager, SharedCallbackManager):
                chain.callback_manager.add_handler(chain_callback_handler)

        # build the main chain
        overall_chain = SequentialChain(
            chains=chains,
            input_variables=[first_input_key],
            output_variables=[final_output_key],
            memory=memory,  # used only to resolve the memory prompt input key
        )

        return overall_chain

    @classmethod
    def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
                         dataset_tool_callback_handler: DatasetToolCallbackHandler,
                         agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
        # agent mode
        chains = []
        if agent_mode and agent_mode.get('enabled'):
            tools = agent_mode.get('tools', [])

            pre_fixed_chains = []
            agent_tools = []
            for tool in tools:
                tool_type = list(tool.keys())[0]
                tool_config = list(tool.values())[0]
                if tool_type == 'sensitive-word-avoidance':
                    chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
                    if chain:
                        pre_fixed_chains.append(chain)
                elif tool_type == "dataset":
                    dataset_tool = DatasetToolBuilder.build_dataset_tool(
                        tenant_id=tenant_id,
                        dataset_id=tool_config.get("id"),
                        response_mode='no_synthesizer',  # "compact"
                        callback_handler=dataset_tool_callback_handler
                    )

                    if dataset_tool:
                        agent_tools.append(dataset_tool)

            # add pre-fixed chains
            chains += pre_fixed_chains

            if len(agent_tools) == 1:
                # tool to chain
                tool_chain = ChainBuilder.to_tool_chain(tool=agent_tools[0], output_key='tool_output')
                chains.append(tool_chain)
            elif len(agent_tools) > 1:
                # build agent config
                agent_chain = AgentBuilder.to_agent_chain(
                    tenant_id=tenant_id,
                    tools=agent_tools,
                    memory=memory,
                    dataset_tool_callback_handler=dataset_tool_callback_handler,
                    agent_loop_gather_callback_handler=agent_loop_gather_callback_handler
                )

                chains.append(agent_chain)

        final_output_key = cls.get_chains_output_key(chains)

        return chains, final_output_key

    @classmethod
    def get_chains_output_key(cls, chains: List[Chain]):
        if len(chains) > 0:
            return chains[-1].output_keys[0]
        return None
42  api/core/chain/sensitive_word_avoidance_chain.py  Normal file
@@ -0,0 +1,42 @@
from typing import List, Dict

from langchain.chains.base import Chain


class SensitiveWordAvoidanceChain(Chain):
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    sensitive_words: List[str] = []
    canned_response: str = None

    @property
    def _chain_type(self) -> str:
        return "sensitive_word_avoidance_chain"

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    def _check_sensitive_word(self, text: str) -> str:
        for word in self.sensitive_words:
            if word in text:
                return self.canned_response
        return text

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        text = inputs[self.input_key]
        output = self._check_sensitive_word(text)
        return {self.output_key: output}
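Behavioral sketch of the chain above (inputs and words are made up):

chain = SensitiveWordAvoidanceChain(
    sensitive_words=["forbidden"],
    canned_response="Sorry, I can't answer that.",
)

chain.run("tell me about the forbidden topic")
# -> "Sorry, I can't answer that."  (any matching word triggers the canned response)
chain.run("an ordinary question")
# -> "an ordinary question"         (text passes through unchanged)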
42  api/core/chain/tool_chain.py  Normal file
@@ -0,0 +1,42 @@
from typing import List, Dict

from langchain.chains.base import Chain
from langchain.tools import BaseTool


class ToolChain(Chain):
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    tool: BaseTool

    @property
    def _chain_type(self) -> str:
        return "tool_chain"

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        input = inputs[self.input_key]
        output = self.tool.run(input, self.verbose)
        return {self.output_key: output}

    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Run the logic of this chain and return the output."""
        input = inputs[self.input_key]
        output = await self.tool.arun(input, self.verbose)
        return {self.output_key: output}
326  api/core/completion.py  Normal file
@@ -0,0 +1,326 @@
from typing import Optional, List, Union

from langchain.callbacks import CallbackManager
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
    DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
    ReadOnlyConversationTokenDBStringBufferSharedMemory
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Account, conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
        """
        errors: ProviderTokenNotInitError
        """
        cls.validate_query_tokens(app.tenant_id, app_model_config, query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation
            )

            inputs = conversation.inputs

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming
        )

        # build the main chain, including the agent
        main_chain = MainChainBuilder.to_langchain_components(
            tenant_id=app.tenant_id,
            agent_mode=app_model_config.agent_mode_dict,
            memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
            conversation_message_task=conversation_message_task
        )

        chain_output = ''
        if main_chain:
            chain_output = main_chain.run(query)

        # run the final llm
        try:
            cls.run_final_llm(
                tenant_id=app.tenant_id,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                chain_output=chain_output,
                conversation_message_task=conversation_message_task,
                memory=memory,
                streaming=streaming
            )
        except ConversationTaskStoppedException:
            return

    @classmethod
    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                      chain_output: str,
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
        final_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        prompt = cls.get_main_llm_prompt(
            mode=mode,
            llm=final_llm,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            chain_output=chain_output,
            memory=memory
        )

        final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=final_llm,
            prompt=prompt,
            mode=mode
        )

        response = final_llm.generate([prompt])

        return response

    @classmethod
    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
                            chain_output: Optional[str],
                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
            Union[str, List[BaseMessage]]:
        pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
        if mode == 'completion':
            prompt_template = OutLinePromptTemplate.from_template(
                template=("Use the following pieces of [CONTEXT] to answer the question at the end. "
                          "If you don't know the answer, "
                          "just say that you don't know, don't try to make up an answer. \n"
                          "```\n"
                          "[CONTEXT]\n"
                          "{context}\n"
                          "```\n" if chain_output else "")
                         + (pre_prompt + "\n" if pre_prompt else "")
                         + "{query}\n"
            )

            if chain_output:
                inputs['context'] = chain_output

            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            prompt_content = prompt_template.format(
                query=query,
                **prompt_inputs
            )

            if isinstance(llm, BaseChatModel):
                # use chat llm as completion model
                return [HumanMessage(content=prompt_content)]
            else:
                return prompt_content
        else:
            messages: List[BaseMessage] = []

            system_message = None
            if pre_prompt:
                # append pre prompt as system message
                system_message = PromptBuilder.to_system_message(pre_prompt, inputs)

            if chain_output:
                # append context as system message, currently only use simple stuff prompt
                context_message = PromptBuilder.to_system_message(
                    """Use the following pieces of [CONTEXT] to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
```
[CONTEXT]
{context}
```""",
                    {'context': chain_output}
                )

                if not system_message:
                    system_message = context_message
                else:
                    system_message.content = context_message.content + "\n\n" + system_message.content

            if system_message:
                messages.append(system_message)

            human_inputs = {
                "query": query
            }

            # construct main prompt
            human_message = PromptBuilder.to_human_message(
                prompt_content="{query}",
                inputs=human_inputs
            )

            if memory:
                # append chat histories
                tmp_messages = messages.copy() + [human_message]
                curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
                rest_tokens = llm_constant.max_context_token_length[
                    memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
                history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
                messages += history_messages

            messages.append(human_message)

            return messages

    @classmethod
    def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                                 streaming: bool,
                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
        if streaming:
            callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
        else:
            callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]

        return CallbackManager(callback_handlers)

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> List[BaseMessage]:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
        # only for calculating tokens in memory
        memory_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            llm=memory_llm,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
        llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
        max_tokens = llm.max_tokens

        if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
            raise LLMBadRequestError("Query is too long")

    @classmethod
    def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                              prompt: Union[str, List[BaseMessage]], mode: str):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
        model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
        max_tokens = final_llm.max_tokens

        if mode == 'completion' and isinstance(final_llm, BaseLLM):
            prompt_tokens = final_llm.get_num_tokens(prompt)
        else:
            prompt_tokens = final_llm.get_messages_tokens(prompt)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)
            final_llm.max_tokens = max_tokens

    @classmethod
    def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                app_model_config: AppModelConfig, user: Account, streaming: bool):
        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=app.tenant_id,
            model_name='gpt-3.5-turbo',
            streaming=streaming
        )

        # get llm prompt
        original_prompt = cls.get_main_llm_prompt(
            mode="completion",
            llm=llm,
            pre_prompt=pre_prompt,
            query=message.query,
            inputs=message.inputs,
            chain_output=None,
            memory=None
        )

        original_completion = message.answer.strip()

        prompt = MORE_LIKE_THIS_GENERATE_PROMPT
        prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            inputs=message.inputs,
            query=message.query,
            is_override=True if message.override_model_configs else False,
            streaming=streaming
        )

        llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=llm,
            prompt=prompt,
            mode='completion'
        )

        llm.generate([prompt])
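A worked example of the token budgeting in validate_query_tokens and recale_llm_max_tokens (numbers are illustrative; the limits come from llm_constant.max_context_token_length):

model_limit = 4096        # gpt-3.5-turbo context window
max_tokens = 512          # completion budget configured on the LLM

# validate_query_tokens: the query must fit beside the completion budget
query_tokens = 3700       # hypothetical llm.get_num_tokens(query)
# 4096 - 512 - 3700 = -116 < 0  ->  LLMBadRequestError("Query is too long")

# recale_llm_max_tokens: for a prompt that fits, shrink max_tokens to the remainder
prompt_tokens = 3900      # hypothetical size of the final prompt
if prompt_tokens + max_tokens > model_limit:
    max_tokens = max(model_limit - prompt_tokens, 16)   # -> 196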
84
api/core/constant/llm_constant.py
Normal file
84
api/core/constant/llm_constant.py
Normal file
@@ -0,0 +1,84 @@
from decimal import Decimal

models = {
    'gpt-4': 'openai',  # 8,192 tokens
    'gpt-4-32k': 'openai',  # 32,768 tokens
    'gpt-3.5-turbo': 'openai',  # 4,096 tokens
    'text-davinci-003': 'openai',  # 4,097 tokens
    'text-davinci-002': 'openai',  # 4,097 tokens
    'text-curie-001': 'openai',  # 2,049 tokens
    'text-babbage-001': 'openai',  # 2,049 tokens
    'text-ada-001': 'openai',  # 2,049 tokens
    'text-embedding-ada-002': 'openai'  # 8,191 tokens, 1,536 dimensions
}

max_context_token_length = {
    'gpt-4': 8192,
    'gpt-4-32k': 32768,
    'gpt-3.5-turbo': 4096,
    'text-davinci-003': 4097,
    'text-davinci-002': 4097,
    'text-curie-001': 2049,
    'text-babbage-001': 2049,
    'text-ada-001': 2049,
    'text-embedding-ada-002': 8191
}

models_by_mode = {
    'chat': [
        'gpt-4',  # 8,192 tokens
        'gpt-4-32k',  # 32,768 tokens
        'gpt-3.5-turbo',  # 4,096 tokens
    ],
    'completion': [
        'gpt-4',  # 8,192 tokens
        'gpt-4-32k',  # 32,768 tokens
        'gpt-3.5-turbo',  # 4,096 tokens
        'text-davinci-003',  # 4,097 tokens
        'text-davinci-002',  # 4,097 tokens
        'text-curie-001',  # 2,049 tokens
        'text-babbage-001',  # 2,049 tokens
        'text-ada-001'  # 2,049 tokens
    ],
    'embedding': [
        'text-embedding-ada-002'  # 8,191 tokens, 1,536 dimensions
    ]
}

model_currency = 'USD'

model_prices = {
    'gpt-4': {
        'prompt': Decimal('0.03'),
        'completion': Decimal('0.06'),
    },
    'gpt-4-32k': {
        'prompt': Decimal('0.06'),
        'completion': Decimal('0.12')
    },
    'gpt-3.5-turbo': {
        'prompt': Decimal('0.002'),
        'completion': Decimal('0.002')
    },
    'text-davinci-003': {
        'prompt': Decimal('0.02'),
        'completion': Decimal('0.02')
    },
    'text-curie-001': {
        'prompt': Decimal('0.002'),
        'completion': Decimal('0.002')
    },
    'text-babbage-001': {
        'prompt': Decimal('0.0005'),
        'completion': Decimal('0.0005')
    },
    'text-ada-001': {
        'prompt': Decimal('0.0004'),
        'completion': Decimal('0.0004')
    },
    'text-embedding-ada-002': {
        'usage': Decimal('0.0004'),
    }
}

agent_model_name = 'text-davinci-003'
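
A minimal sketch (not part of the source) of how these tables combine to price a call; prices are quoted per 1,000 tokens, so counts are scaled first. The token counts are hypothetical.

from decimal import Decimal

prompt_tokens = Decimal(1200)      # hypothetical usage
completion_tokens = Decimal(300)

price = model_prices['gpt-3.5-turbo']
total = (prompt_tokens / 1000) * price['prompt'] + (completion_tokens / 1000) * price['completion']
print(total, model_currency)  # -> 0.0030 USD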
388
api/core/conversation_message_task.py
Normal file
@@ -0,0 +1,388 @@
import decimal
import json
from typing import Optional, Union

from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.provider.llm_provider_service import LLMProviderService
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import OutLinePromptTemplate
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
from models.provider import ProviderType, Provider


class ConversationMessageTask:
    def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
                 inputs: dict, query: str, streaming: bool,
                 conversation: Optional[Conversation] = None, is_override: bool = False):
        self.task_id = task_id

        self.app = app
        self.tenant_id = app.tenant_id
        self.app_model_config = app_model_config
        self.is_override = is_override

        self.user = user
        self.inputs = inputs
        self.query = query
        self.streaming = streaming

        self.conversation = conversation
        self.is_new_conversation = False

        self.message = None

        self.model_dict = self.app_model_config.model_dict
        self.model_name = self.model_dict.get('name')
        self.mode = app.mode

        self.init()

        self._pub_handler = PubHandler(
            user=self.user,
            task_id=self.task_id,
            message=self.message,
            conversation=self.conversation,
            chain_pub=False,  # disabled currently
            agent_thought_pub=False  # disabled currently
        )

    def init(self):
        override_model_configs = None
        if self.is_override:
            override_model_configs = {
                "model": self.app_model_config.model_dict,
                "pre_prompt": self.app_model_config.pre_prompt,
                "agent_mode": self.app_model_config.agent_mode_dict,
                "opening_statement": self.app_model_config.opening_statement,
                "suggested_questions": self.app_model_config.suggested_questions_list,
                "suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
                "more_like_this": self.app_model_config.more_like_this_dict,
                "user_input_form": self.app_model_config.user_input_form_list,
            }

        introduction = ''
        system_instruction = ''
        system_instruction_tokens = 0
        if self.mode == 'chat':
            introduction = self.app_model_config.opening_statement
            if introduction:
                prompt_template = OutLinePromptTemplate.from_template(template=PromptBuilder.process_template(introduction))
                prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
                introduction = prompt_template.format(**prompt_inputs)

            if self.app_model_config.pre_prompt:
                pre_prompt = PromptBuilder.process_template(self.app_model_config.pre_prompt)
                system_message = PromptBuilder.to_system_message(pre_prompt, self.inputs)
                system_instruction = system_message.content
                llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
                system_instruction_tokens = llm.get_messages_tokens([system_message])

        if not self.conversation:
            self.is_new_conversation = True
            self.conversation = Conversation(
                app_id=self.app_model_config.app_id,
                app_model_config_id=self.app_model_config.id,
                model_provider=self.model_dict.get('provider'),
                model_id=self.model_name,
                override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                mode=self.mode,
                name='',
                inputs=self.inputs,
                introduction=introduction,
                system_instruction=system_instruction,
                system_instruction_tokens=system_instruction_tokens,
                status='normal',
                from_source=('console' if isinstance(self.user, Account) else 'api'),
                from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
                from_account_id=(self.user.id if isinstance(self.user, Account) else None),
            )

            db.session.add(self.conversation)
            db.session.flush()

        self.message = Message(
            app_id=self.app_model_config.app_id,
            model_provider=self.model_dict.get('provider'),
            model_id=self.model_name,
            override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
            conversation_id=self.conversation.id,
            inputs=self.inputs,
            query=self.query,
            message="",
            message_tokens=0,
            message_unit_price=0,
            answer="",
            answer_tokens=0,
            answer_unit_price=0,
            provider_response_latency=0,
            total_price=0,
            currency=llm_constant.model_currency,
            from_source=('console' if isinstance(self.user, Account) else 'api'),
            from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
            from_account_id=(self.user.id if isinstance(self.user, Account) else None),
            agent_based=self.app_model_config.agent_mode_dict.get('enabled'),
        )

        db.session.add(self.message)
        db.session.flush()

    def append_message_text(self, text: str):
        self._pub_handler.pub_text(text)

    def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
        model_name = self.app_model_config.model_dict.get('name')

        message_tokens = llm_message.prompt_tokens
        answer_tokens = llm_message.completion_tokens
        message_unit_price = llm_constant.model_prices[model_name]['prompt']
        answer_unit_price = llm_constant.model_prices[model_name]['completion']

        total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)

        self.message.message = llm_message.prompt
        self.message.message_tokens = message_tokens
        self.message.message_unit_price = message_unit_price
        self.message.answer = llm_message.completion.strip() if llm_message.completion else ''
        self.message.answer_tokens = answer_tokens
        self.message.answer_unit_price = answer_unit_price
        self.message.provider_response_latency = llm_message.latency
        self.message.total_price = total_price

        self.update_provider_quota()

        db.session.commit()

        message_was_created.send(
            self.message,
            conversation=self.conversation,
            is_first_message=self.is_new_conversation
        )

        if not by_stopped:
            self._pub_handler.pub_end()

    def update_provider_quota(self):
        llm_provider_service = LLMProviderService(
            tenant_id=self.app.tenant_id,
            provider_name=self.message.model_provider,
        )

        provider = llm_provider_service.get_provider_db_record()
        if provider and provider.provider_type == ProviderType.SYSTEM.value:
            db.session.query(Provider).filter(
                Provider.tenant_id == self.app.tenant_id,
                Provider.quota_limit > Provider.quota_used
            ).update({'quota_used': Provider.quota_used + 1})

    def init_chain(self, chain_result: ChainResult):
        message_chain = MessageChain(
            message_id=self.message.id,
            type=chain_result.type,
            input=json.dumps(chain_result.prompt),
            output=''
        )

        db.session.add(message_chain)
        db.session.flush()

        return message_chain

    def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
        message_chain.output = json.dumps(chain_result.completion)

        self._pub_handler.pub_chain(message_chain)

    def on_agent_end(self, message_chain: MessageChain, agent_model_name: str,
                     agent_loop: AgentLoop):
        agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
        agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']

        loop_message_tokens = agent_loop.prompt_tokens
        loop_answer_tokens = agent_loop.completion_tokens

        loop_total_price = self.calc_total_price(
            loop_message_tokens,
            agent_message_unit_price,
            loop_answer_tokens,
            agent_answer_unit_price
        )

        message_agent_loop = MessageAgentThought(
            message_id=self.message.id,
            message_chain_id=message_chain.id,
            position=agent_loop.position,
            thought=agent_loop.thought,
            tool=agent_loop.tool_name,
            tool_input=agent_loop.tool_input,
            observation=agent_loop.tool_output,
            tool_process_data='',  # not supported yet
            message=agent_loop.prompt,
            message_token=loop_message_tokens,
            message_unit_price=agent_message_unit_price,
            answer=agent_loop.completion,
            answer_token=loop_answer_tokens,
            answer_unit_price=agent_answer_unit_price,
            latency=agent_loop.latency,
            tokens=agent_loop.prompt_tokens + agent_loop.completion_tokens,
            total_price=loop_total_price,
            currency=llm_constant.model_currency,
            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
            created_by=self.user.id
        )

        db.session.add(message_agent_loop)
        db.session.flush()

        self._pub_handler.pub_agent_thought(message_agent_loop)

    def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
        dataset_query = DatasetQuery(
            dataset_id=dataset_query_obj.dataset_id,
            content=dataset_query_obj.query,
            source='app',
            source_app_id=self.app.id,
            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
            created_by=self.user.id
        )

        db.session.add(dataset_query)

    def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
        message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                                  rounding=decimal.ROUND_HALF_UP)
        answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                                rounding=decimal.ROUND_HALF_UP)

        total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)


class PubHandler:
    def __init__(self, user: Union[Account, EndUser], task_id: str,
                 message: Message, conversation: Conversation,
                 chain_pub: bool = False, agent_thought_pub: bool = False):
        self._channel = PubHandler.generate_channel_name(user, task_id)
        self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id)

        self._task_id = task_id
        self._message = message
        self._conversation = conversation
        self._chain_pub = chain_pub
        self._agent_thought_pub = agent_thought_pub

    @classmethod
    def generate_channel_name(cls, user: Union[Account, EndUser], task_id: str):
        user_str = ('account-' + user.id) if isinstance(user, Account) else ('end-user-' + user.id)
        return "generate_result:{}-{}".format(user_str, task_id)

    @classmethod
    def generate_stopped_cache_key(cls, user: Union[Account, EndUser], task_id: str):
        user_str = ('account-' + user.id) if isinstance(user, Account) else ('end-user-' + user.id)
        return "generate_result_stopped:{}-{}".format(user_str, task_id)

    def pub_text(self, text: str):
        content = {
            'event': 'message',
            'data': {
                'task_id': self._task_id,
                'message_id': self._message.id,
                'text': text,
                'mode': self._conversation.mode,
                'conversation_id': self._conversation.id
            }
        }

        redis_client.publish(self._channel, json.dumps(content))

        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_chain(self, message_chain: MessageChain):
        if self._chain_pub:
            content = {
                'event': 'chain',
                'data': {
                    'task_id': self._task_id,
                    'message_id': self._message.id,
                    'chain_id': message_chain.id,
                    'type': message_chain.type,
                    'input': json.loads(message_chain.input),
                    'output': json.loads(message_chain.output),
                    'mode': self._conversation.mode,
                    'conversation_id': self._conversation.id
                }
            }

            redis_client.publish(self._channel, json.dumps(content))

        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_agent_thought(self, message_agent_thought: MessageAgentThought):
        if self._agent_thought_pub:
            content = {
                'event': 'agent_thought',
                'data': {
                    'task_id': self._task_id,
                    'message_id': self._message.id,
                    'chain_id': message_agent_thought.message_chain_id,
                    'agent_thought_id': message_agent_thought.id,
                    'position': message_agent_thought.position,
                    'thought': message_agent_thought.thought,
                    'tool': message_agent_thought.tool,
                    'tool_input': message_agent_thought.tool_input,
                    'observation': message_agent_thought.observation,
                    'answer': message_agent_thought.answer,
                    'mode': self._conversation.mode,
                    'conversation_id': self._conversation.id
                }
            }

            redis_client.publish(self._channel, json.dumps(content))

        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_end(self):
        content = {
            'event': 'end',
        }

        redis_client.publish(self._channel, json.dumps(content))

    @classmethod
    def pub_error(cls, user: Union[Account, EndUser], task_id: str, e):
        content = {
            'error': type(e).__name__,
            'description': e.description if getattr(e, 'description', None) is not None else str(e)
        }

        channel = cls.generate_channel_name(user, task_id)
        redis_client.publish(channel, json.dumps(content))

    def _is_stopped(self):
        return redis_client.get(self._stopped_cache_key) is not None

    @classmethod
    def stop(cls, user: Union[Account, EndUser], task_id: str):
        stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
        redis_client.setex(stopped_cache_key, 600, 1)


class ConversationTaskStoppedException(Exception):
    pass
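
PubHandler streams events over Redis pub/sub; a minimal consumer sketch (not part of the source; the ids are placeholders) shows the wire protocol: 'message' events carry incremental text, and an 'end' event closes the stream.

import json

import redis

r = redis.Redis()
pubsub = r.pubsub()
# Channel name format per PubHandler.generate_channel_name; ids are hypothetical.
pubsub.subscribe('generate_result:account-<account_id>-<task_id>')

for item in pubsub.listen():
    if item['type'] != 'message':
        continue  # skip subscribe confirmations
    event = json.loads(item['data'])
    if event.get('event') == 'message':
        print(event['data']['text'], end='')  # incremental completion text
    elif event.get('event') == 'end':
        break  # PubHandler.pub_end() signals completion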
190
api/core/docstore/dataset_docstore.py
Normal file
@@ -0,0 +1,190 @@
from typing import Any, Dict, Optional, Sequence

import tiktoken
from llama_index.data_structs import Node
from llama_index.docstore.types import BaseDocumentStore
from llama_index.docstore.utils import json_to_doc
from llama_index.schema import BaseDocument
from sqlalchemy import func

from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment


class DatesetDocumentStore(BaseDocumentStore):
    def __init__(
        self,
        dataset: Dataset,
        user_id: str,
        embedding_model_name: str,
        document_id: Optional[str] = None,
    ):
        self._dataset = dataset
        self._user_id = user_id
        self._embedding_model_name = embedding_model_name
        self._document_id = document_id

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "DatesetDocumentStore":
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dict."""
        return {
            "dataset_id": self._dataset.id,
        }

    @property
    def dateset_id(self) -> Any:
        return self._dataset.id

    @property
    def user_id(self) -> Any:
        return self._user_id

    @property
    def embedding_model_name(self) -> Any:
        return self._embedding_model_name

    @property
    def docs(self) -> Dict[str, BaseDocument]:
        document_segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id
        ).all()

        output = {}
        for document_segment in document_segments:
            doc_id = document_segment.index_node_id
            result = self.segment_to_dict(document_segment)
            output[doc_id] = json_to_doc(result)

        return output

    def add_documents(
        self, docs: Sequence[BaseDocument], allow_update: bool = True
    ) -> None:
        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
            DocumentSegment.document_id == self._document_id
        ).scalar()

        if max_position is None:
            max_position = 0

        for doc in docs:
            if doc.is_doc_id_none:
                raise ValueError("doc_id not set")

            if not isinstance(doc, Node):
                raise ValueError("doc must be a Node")

            segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False)

            # NOTE: doc could already exist in the store, but we overwrite it
            if not allow_update and segment_document:
                raise ValueError(
                    f"doc_id {doc.get_doc_id()} already exists. "
                    "Set allow_update to True to overwrite."
                )

            # calculate the number of tokens the embedding call will consume
            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text())

            if not segment_document:
                max_position += 1

                segment_document = DocumentSegment(
                    tenant_id=self._dataset.tenant_id,
                    dataset_id=self._dataset.id,
                    document_id=self._document_id,
                    index_node_id=doc.get_doc_id(),
                    index_node_hash=doc.get_doc_hash(),
                    position=max_position,
                    content=doc.get_text(),
                    word_count=len(doc.get_text()),
                    tokens=tokens,
                    created_by=self._user_id,
                )
                db.session.add(segment_document)
            else:
                segment_document.content = doc.get_text()
                segment_document.index_node_hash = doc.get_doc_hash()
                segment_document.word_count = len(doc.get_text())
                segment_document.tokens = tokens

            db.session.commit()

    def document_exists(self, doc_id: str) -> bool:
        """Check if document exists."""
        result = self.get_document_segment(doc_id)
        return result is not None

    def get_document(
        self, doc_id: str, raise_error: bool = True
    ) -> Optional[BaseDocument]:
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            if raise_error:
                raise ValueError(f"doc_id {doc_id} not found.")
            else:
                return None

        result = self.segment_to_dict(document_segment)
        return json_to_doc(result)

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            if raise_error:
                raise ValueError(f"doc_id {doc_id} not found.")
            else:
                return None

        db.session.delete(document_segment)
        db.session.commit()

    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
        """Set the hash for a given doc_id."""
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            return None

        document_segment.index_node_hash = doc_hash
        db.session.commit()

    def get_document_hash(self, doc_id: str) -> Optional[str]:
        """Get the stored hash for a document, if it exists."""
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
            return None

        return document_segment.index_node_hash

    def update_docstore(self, other: "BaseDocumentStore") -> None:
        """Update docstore.

        Args:
            other (BaseDocumentStore): docstore to update from

        """
        self.add_documents(list(other.docs.values()))

    def get_document_segment(self, doc_id: str) -> DocumentSegment:
        document_segment = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id,
            DocumentSegment.index_node_id == doc_id
        ).first()

        return document_segment

    def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
        return {
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "text": segment.content,
            "__type__": Node.get_type()
        }
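
A small sketch (not part of the source) of the segment-to-Node round trip that docs and get_document rely on. The segment row is stubbed with hypothetical values, and the exact fields json_to_doc accepts depend on the installed llama_index version.

from types import SimpleNamespace

from llama_index.data_structs import Node
from llama_index.docstore.utils import json_to_doc

segment = SimpleNamespace(  # stub carrying only the fields segment_to_dict reads
    index_node_id='node-1',
    index_node_hash='abc123',
    content='some segment text',
)

doc_dict = {
    "doc_id": segment.index_node_id,
    "doc_hash": segment.index_node_hash,
    "text": segment.content,
    "__type__": Node.get_type(),
}

node = json_to_doc(doc_dict)  # rebuilds a Node, as DatesetDocumentStore.docs does
print(node.get_text())        # -> 'some segment text'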
51
api/core/docstore/empty_docstore.py
Normal file
@@ -0,0 +1,51 @@
from typing import Any, Dict, Optional, Sequence
from llama_index.docstore.types import BaseDocumentStore
from llama_index.schema import BaseDocument


class EmptyDocumentStore(BaseDocumentStore):
    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
        return cls()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dict."""
        return {}

    @property
    def docs(self) -> Dict[str, BaseDocument]:
        return {}

    def add_documents(
        self, docs: Sequence[BaseDocument], allow_update: bool = True
    ) -> None:
        pass

    def document_exists(self, doc_id: str) -> bool:
        """Check if document exists."""
        return False

    def get_document(
        self, doc_id: str, raise_error: bool = True
    ) -> Optional[BaseDocument]:
        return None

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        pass

    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
        """Set the hash for a given doc_id."""
        pass

    def get_document_hash(self, doc_id: str) -> Optional[str]:
        """Get the stored hash for a document, if it exists."""
        return None

    def update_docstore(self, other: "BaseDocumentStore") -> None:
        """Update docstore.

        Args:
            other (BaseDocumentStore): docstore to update from

        """
        self.add_documents(list(other.docs.values()))
176
api/core/embedding/openai_embedding.py
Normal file
@@ -0,0 +1,176 @@
from typing import Optional, Any, List

import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
    _TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
        text: str,
        engine: Optional[str] = None,
        openai_api_key: Optional[str] = None,
) -> List[float]:
    """Get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    text = text.replace("\n", " ")
    return openai.Embedding.create(input=[text], engine=engine, api_key=openai_api_key)["data"][0]["embedding"]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key: Optional[str] = None) -> List[float]:
    """Asynchronously get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=openai_api_key))["data"][0][
        "embedding"
    ]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
        list_of_text: List[str],
        engine: Optional[str] = None,
        openai_api_key: Optional[str] = None
) -> List[List[float]]:
    """Get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=openai_api_key).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
        list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None
) -> List[List[float]]:
    """Asynchronously get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


class OpenAIEmbedding(BaseEmbedding):

    def __init__(
            self,
            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
            deployment_name: Optional[str] = None,
            openai_api_key: Optional[str] = None,
            **kwargs: Any,
    ) -> None:
        """Init params."""
        super().__init__(**kwargs)
        self.mode = OpenAIEmbeddingMode(mode)
        self.model = OpenAIEmbeddingModelType(model)
        self.deployment_name = deployment_name
        self.openai_api_key = openai_api_key

    @handle_llm_exceptions
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _QUERY_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _QUERY_MODE_MODEL_DICT[key]
        return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Can be overridden for batch queries.

        """
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        embeddings = get_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
        return embeddings

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        embeddings = await aget_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
        return embeddings
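
A minimal usage sketch (not part of the source; the key is a placeholder, and it assumes llama_index's BaseEmbedding exposes the public get_query_embedding wrapper around _get_query_embedding). With the defaults, the mode/model dictionaries resolve to the text-embedding-ada-002 engine, which returns 1,536-dimensional vectors.

embed_model = OpenAIEmbedding(openai_api_key='sk-...')  # placeholder key

vector = embed_model.get_query_embedding("What is Dify?")
print(len(vector))  # -> 1536 for text-embedding-ada-002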
120
api/core/generator/llm_generator.py
Normal file
@@ -0,0 +1,120 @@
import logging

from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage

from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.llm.token_calculator import TokenCalculator

from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT


# gpt-3.5-turbo does not work well for these generation tasks
generate_base_model = 'text-davinci-003'


class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query, answer):
        prompt = CONVERSATION_TITLE_PROMPT
        prompt = prompt.format(query=query, answer=answer)
        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            max_tokens=50
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()

    @classmethod
    def generate_conversation_summary(cls, tenant_id: str, messages):
        max_tokens = 200

        prompt = CONVERSATION_SUMMARY_PROMPT
        prompt_with_empty_context = prompt.format(context='')
        prompt_tokens = TokenCalculator.get_num_tokens(generate_base_model, prompt_with_empty_context)
        rest_tokens = llm_constant.max_context_token_length[generate_base_model] - prompt_tokens - max_tokens

        context = ''
        for message in messages:
            if not message.answer:
                continue

            message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
            if rest_tokens - TokenCalculator.get_num_tokens(generate_base_model, context + message_qa_text) > 0:
                context += message_qa_text

        prompt = prompt.format(context=context)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            max_tokens=max_tokens
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()

    @classmethod
    def generate_introduction(cls, tenant_id: str, pre_prompt: str):
        prompt = INTRODUCTION_GENERATE_PROMPT
        prompt = prompt.format(prompt=pre_prompt)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()

    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt = OutLinePromptTemplate(
            template="{histories}\n{format_instructions}\nquestions:\n",
            input_variables=["histories"],
            partial_variables={"format_instructions": format_instructions}
        )

        _input = prompt.format_prompt(histories=histories)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            temperature=0,
            max_tokens=256
        )

        if isinstance(llm, BaseChatModel):
            query = [HumanMessage(content=_input.to_string())]
        else:
            query = _input.to_string()

        try:
            output = llm(query)
            questions = output_parser.parse(output)
        except Exception:
            logging.exception("Error generating suggested questions after answer")
            questions = []

        return questions
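
generate_conversation_summary budgets tokens before packing history into the prompt; a worked sketch of the arithmetic (the prompt size is a hypothetical count, not measured from the source):

max_context = 4097   # text-davinci-003, per llm_constant.max_context_token_length
prompt_tokens = 120  # hypothetical size of CONVERSATION_SUMMARY_PROMPT with an empty context
max_tokens = 200     # reserved for the generated summary

rest_tokens = max_context - prompt_tokens - max_tokens
print(rest_tokens)   # -> 3777 tokens left for packed "Human:.../AI:..." history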
45
api/core/index/index_builder.py
Normal file
@@ -0,0 +1,45 @@
from langchain.callbacks import CallbackManager
from llama_index import ServiceContext, PromptHelper, LLMPredictor
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.embedding.openai_embedding import OpenAIEmbedding
from core.llm.llm_builder import LLMBuilder


class IndexBuilder:
    @classmethod
    def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
        # set the number of output tokens
        num_output = 512

        # only used for verbose logging
        callback_manager = CallbackManager([DifyStdOutCallbackHandler()])

        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='text-davinci-003',
            temperature=0,
            max_tokens=num_output,
            callback_manager=callback_manager,
        )

        llm_predictor = LLMPredictor(llm=llm)

        # These parameters affect how the final synthesized response is segmented.
        # The number of refinement iterations in the synthesis process depends
        # on whether the length of the segmented output exceeds max_input_size.
        prompt_helper = PromptHelper(
            max_input_size=3500,
            num_output=num_output,
            max_chunk_overlap=20
        )

        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=tenant_id,
            model_name='text-embedding-ada-002'
        )

        return ServiceContext.from_defaults(
            llm_predictor=llm_predictor,
            prompt_helper=prompt_helper,
            embed_model=OpenAIEmbedding(**model_credentials),
        )
159
api/core/index/keyword_table/jieba_keyword_table.py
Normal file
@@ -0,0 +1,159 @@
import re
from typing import (
    Any,
    Dict,
    List,
    Set,
    Optional
)

import jieba.analyse

from core.index.keyword_table.stopwords import STOPWORDS
from llama_index.indices.query.base import IS
from llama_index import QueryMode
from llama_index.indices.base import QueryMap
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
    BaseNodePostprocessor,
)
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)

from core.index.query.synthesizer import EnhanceResponseSynthesizer


def jieba_extract_keywords(
        text_chunk: str,
        max_keywords: Optional[int] = None,
        expand_with_subtokens: bool = True,
) -> Set[str]:
    """Extract keywords with jieba's TF-IDF extractor."""
    keywords = jieba.analyse.extract_tags(
        sentence=text_chunk,
        topK=max_keywords,
    )

    if expand_with_subtokens:
        return set(expand_tokens_with_subtokens(keywords))
    else:
        return set(keywords)


def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
    """Get subtokens from a list of tokens, filtering out stopwords."""
    results = set()
    for token in tokens:
        results.add(token)
        sub_tokens = re.findall(r"\w+", token)
        if len(sub_tokens) > 1:
            results.update({w for w in sub_tokens if w not in STOPWORDS})

    return results


class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
    """GPT JIEBA Keyword Table Index.

    This index uses a jieba keyword extractor to extract keywords from the text.

    """

    def _extract_keywords(self, text: str) -> Set[str]:
        """Extract keywords from text."""
        return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)

    @classmethod
    def get_query_map(cls) -> QueryMap:
        """Get query map."""
        super_map = super().get_query_map()
        super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery
        return super_map

    def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document."""
        # get the set of ids that correspond to the node
        node_idxs_to_delete = {doc_id}

        # delete node_idxs from the keyword-to-node-idxs mapping
        keywords_to_delete = set()
        for keyword, node_idxs in self._index_struct.table.items():
            if node_idxs_to_delete.intersection(node_idxs):
                self._index_struct.table[keyword] = node_idxs.difference(
                    node_idxs_to_delete
                )
                if not self._index_struct.table[keyword]:
                    keywords_to_delete.add(keyword)

        for keyword in keywords_to_delete:
            del self._index_struct.table[keyword]


class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
    """GPT Keyword Table Index JIEBA Query.

    Extracts keywords using the jieba keyword extractor.
    Set when `mode="jieba"` in the `query` method of `GPTKeywordTableIndex`.

    .. code-block:: python

        response = index.query("<query_str>", mode="jieba")

    See BaseGPTKeywordTableQuery for arguments.

    """

    @classmethod
    def from_args(
        cls,
        index_struct: IS,
        service_context: ServiceContext,
        docstore: Optional[BaseDocumentStore] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        verbose: bool = False,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.DEFAULT,
        text_qa_template: Optional[QuestionAnswerPrompt] = None,
        refine_template: Optional[RefinePrompt] = None,
        simple_template: Optional[SimpleInputPrompt] = None,
        response_kwargs: Optional[Dict] = None,
        use_async: bool = False,
        streaming: bool = False,
        optimizer: Optional[BaseTokenUsageOptimizer] = None,
        # class-specific args
        **kwargs: Any,
    ) -> "BaseGPTKeywordTableQuery":
        response_synthesizer = EnhanceResponseSynthesizer.from_args(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            simple_template=simple_template,
            response_mode=response_mode,
            response_kwargs=response_kwargs,
            use_async=use_async,
            streaming=streaming,
            optimizer=optimizer,
        )
        return cls(
            index_struct=index_struct,
            service_context=service_context,
            response_synthesizer=response_synthesizer,
            docstore=docstore,
            node_postprocessors=node_postprocessors,
            verbose=verbose,
            **kwargs,
        )

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        return list(
            jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
        )
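
A small sketch of the extractor (not part of the source); the exact keyword set depends on jieba's bundled dictionary and TF-IDF weights, so the output shown is illustrative.

keywords = jieba_extract_keywords("Dify 是一个 LLM 应用开发平台", max_keywords=5)
print(keywords)  # e.g. {'Dify', 'LLM', '应用', '开发', '平台'} (illustrative)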
90
api/core/index/keyword_table/stopwords.py
Normal file
@@ -0,0 +1,90 @@
STOPWORDS = {
    "during", "when", "but", "then", "further", "isn", "mustn't", "until", "own", "i", "couldn", "y", "only", "you've",
    "ours", "who", "where", "ourselves", "has", "to", "was", "didn't", "themselves", "if", "against", "through", "her",
    "an", "your", "can", "those", "didn", "about", "aren't", "shan't", "be", "not", "these", "again", "so", "t",
    "theirs", "weren", "won't", "won", "itself", "just", "same", "while", "why", "doesn", "aren", "him", "haven",
    "for", "you'll", "that", "we", "am", "d", "by", "having", "wasn't", "than", "weren't", "out", "from", "now",
    "their", "too", "hadn", "o", "needn", "most", "it", "under", "needn't", "any", "some", "few", "ll", "hers", "which",
    "m", "you're", "off", "other", "had", "she", "you'd", "do", "you", "does", "s", "will", "each", "wouldn't", "hasn't",
    "such", "more", "whom", "she's", "my", "yours", "yourself", "of", "on", "very", "hadn't", "with", "yourselves",
    "been", "ma", "them", "mightn't", "shan", "mustn", "they", "what", "both", "that'll", "how", "is", "he", "because",
    "down", "haven't", "are", "no", "it's", "our", "being", "the", "or", "above", "myself", "once", "don't", "doesn't",
    "as", "nor", "here", "herself", "hasn", "mightn", "have", "its", "all", "were", "ain", "this", "at", "after",
    "over", "shouldn't", "into", "before", "don", "wouldn", "re", "couldn't", "wasn", "in", "should", "there",
    "himself", "isn't", "should've", "doing", "ve", "shouldn", "a", "did", "and", "his", "between", "me", "up", "below",
    "人民", "末##末", "啊", "阿", "哎", "哎呀", "哎哟", "唉", "俺", "俺们", "按", "按照", "吧", "吧哒", "把", "罢了", "被", "本",
    "本着", "比", "比方", "比如", "鄙人", "彼", "彼此", "边", "别", "别的", "别说", "并", "并且", "不比", "不成", "不单", "不但",
    "不独", "不管", "不光", "不过", "不仅", "不拘", "不论", "不怕", "不然", "不如", "不特", "不惟", "不问", "不只", "朝", "朝着",
    "趁", "趁着", "乘", "冲", "除", "除此之外", "除非", "除了", "此", "此间", "此外", "从", "从而", "打", "待", "但", "但是", "当",
    "当着", "到", "得", "的", "的话", "等", "等等", "地", "第", "叮咚", "对", "对于", "多", "多少", "而", "而况", "而且", "而是",
    "而外", "而言", "而已", "尔后", "反过来", "反过来说", "反之", "非但", "非徒", "否则", "嘎", "嘎登", "该", "赶", "个", "各",
    "各个", "各位", "各种", "各自", "给", "根据", "跟", "故", "故此", "固然", "关于", "管", "归", "果然", "果真", "过", "哈",
    "哈哈", "呵", "和", "何", "何处", "何况", "何时", "嘿", "哼", "哼唷", "呼哧", "乎", "哗", "还是", "还有", "换句话说", "换言之",
    "或", "或是", "或者", "极了", "及", "及其", "及至", "即", "即便", "即或", "即令", "即若", "即使", "几", "几时", "己", "既",
    "既然", "既是", "继而", "加之", "假如", "假若", "假使", "鉴于", "将", "较", "较之", "叫", "接着", "结果", "借", "紧接着",
    "进而", "尽", "尽管", "经", "经过", "就", "就是", "就是说", "据", "具体地说", "具体说来", "开始", "开外", "靠", "咳", "可",
    "可见", "可是", "可以", "况且", "啦", "来", "来着", "离", "例如", "哩", "连", "连同", "两者", "了", "临", "另", "另外",
    "另一方面", "论", "嘛", "吗", "慢说", "漫说", "冒", "么", "每", "每当", "们", "莫若", "某", "某个", "某些", "拿", "哪",
    "哪边", "哪儿", "哪个", "哪里", "哪年", "哪怕", "哪天", "哪些", "哪样", "那", "那边", "那儿", "那个", "那会儿", "那里", "那么",
    "那么些", "那么样", "那时", "那些", "那样", "乃", "乃至", "呢", "能", "你", "你们", "您", "宁", "宁可", "宁肯", "宁愿", "哦",
    "呕", "啪达", "旁人", "呸", "凭", "凭借", "其", "其次", "其二", "其他", "其它", "其一", "其余", "其中", "起", "起见", "岂但",
    "恰恰相反", "前后", "前者", "且", "然而", "然后", "然则", "让", "人家", "任", "任何", "任凭", "如", "如此", "如果", "如何",
    "如其", "如若", "如上所述", "若", "若非", "若是", "啥", "上下", "尚且", "设若", "设使", "甚而", "甚么", "甚至", "省得", "时候",
    "什么", "什么样", "使得", "是", "是的", "首先", "谁", "谁知", "顺", "顺着", "似的", "虽", "虽然", "虽说", "虽则", "随", "随着",
    "所", "所以", "他", "他们", "他人", "它", "它们", "她", "她们", "倘", "倘或", "倘然", "倘若", "倘使", "腾", "替", "通过", "同",
    "同时", "哇", "万一", "往", "望", "为", "为何", "为了", "为什么", "为着", "喂", "嗡嗡", "我", "我们", "呜", "呜呼", "乌乎",
    "无论", "无宁", "毋宁", "嘻", "吓", "相对而言", "像", "向", "向着", "嘘", "呀", "焉", "沿", "沿着", "要", "要不", "要不然",
    "要不是", "要么", "要是", "也", "也罢", "也好", "一", "一般", "一旦", "一方面", "一来", "一切", "一样", "一则", "依", "依照",
    "矣", "以", "以便", "以及", "以免", "以至", "以至于", "以致", "抑或", "因", "因此", "因而", "因为", "哟", "用", "由",
    "由此可见", "由于", "有", "有的", "有关", "有些", "又", "于", "于是", "于是乎", "与", "与此同时", "与否", "与其", "越是",
    "云云", "哉", "再说", "再者", "在", "在下", "咱", "咱们", "则", "怎", "怎么", "怎么办", "怎么样", "怎样", "咋", "照", "照着",
    "者", "这", "这边", "这儿", "这个", "这会儿", "这就是说", "这里", "这么", "这么点儿", "这么些", "这么样", "这时", "这些", "这样",
    "正如", "吱", "之", "之类", "之所以", "之一", "只是", "只限", "只要", "只有", "至", "至于", "诸位", "着", "着呢", "自", "自从",
    "自个儿", "自各儿", "自己", "自家", "自身", "综上所述", "总的来看", "总的来说", "总的说来", "总而言之", "总之", "纵", "纵令",
    "纵然", "纵使", "遵照", "作为", "兮", "呃", "呗", "咚", "咦", "喏", "啐", "喔唷", "嗬", "嗯", "嗳", "~", "!", ".", ":",
    "\"", "'", "(", ")", "*", "A", "白", "社会主义", "--", "..", ">>", " [", " ]", "", "<", ">", "/", "\\", "|", "-", "_",
    "+", "=", "&", "^", "%", "#", "@", "`", ";", "$", "(", ")", "——", "—", "¥", "·", "...", "‘", "’", "〉", "〈", "…",
    " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "二",
    "三", "四", "五", "六", "七", "八", "九", "零", ">", "<", "@", "#", "$", "%", "︿", "&", "*", "+", "~", "|", "[",
    "]", "{", "}", "啊哈", "啊呀", "啊哟", "挨次", "挨个", "挨家挨户", "挨门挨户", "挨门逐户", "挨着", "按理", "按期", "按时",
    "按说", "暗地里", "暗中", "暗自", "昂然", "八成", "白白", "半", "梆", "保管", "保险", "饱", "背地里", "背靠背", "倍感", "倍加",
    "本人", "本身", "甭", "比起", "比如说", "比照", "毕竟", "必", "必定", "必将", "必须", "便", "别人", "并非", "并肩", "并没",
    "并没有", "并排", "并无", "勃然", "不", "不必", "不常", "不大", "不但...而且", "不得", "不得不", "不得了", "不得已", "不迭",
    "不定", "不对", "不妨", "不管怎样", "不会", "不仅...而且", "不仅仅", "不仅仅是", "不经意", "不可开交", "不可抗拒", "不力", "不了",
    "不料", "不满", "不免", "不能不", "不起", "不巧", "不然的话", "不日", "不少", "不胜", "不时", "不是", "不同", "不能", "不要",
    "不外", "不外乎", "不下", "不限", "不消", "不已", "不亦乐乎", "不由得", "不再", "不择手段", "不怎么", "不曾", "不知不觉", "不止",
    "不止一次", "不至于", "才", "才能", "策略地", "差不多", "差一点", "常", "常常", "常言道", "常言说", "常言说得好", "长此下去",
    "长话短说", "长期以来", "长线", "敞开儿", "彻夜", "陈年", "趁便", "趁机", "趁热", "趁势", "趁早", "成年", "成年累月", "成心",
    "乘机", "乘胜", "乘势", "乘隙", "乘虚", "诚然", "迟早", "充分", "充其极", "充其量", "抽冷子", "臭", "初", "出", "出来", "出去",
    "除此", "除此而外", "除此以外", "除开", "除去", "除却", "除外", "处处", "川流不息", "传", "传说", "传闻", "串行", "纯", "纯粹",
    "此后", "此中", "次第", "匆匆", "从不", "从此", "从此以后", "从古到今", "从古至今", "从今以后", "从宽", "从来", "从轻", "从速",
    "从头", "从未", "从无到有", "从小", "从新", "从严", "从优", "从早到晚", "从中", "从重", "凑巧", "粗", "存心", "达旦", "打从",
    "打开天窗说亮话", "大", "大不了", "大大", "大抵", "大都", "大多", "大凡", "大概", "大家", "大举", "大略", "大面儿上", "大事",
    "大体", "大体上", "大约", "大张旗鼓", "大致", "呆呆地", "带", "殆", "待到", "单", "单纯", "单单", "但愿", "弹指之间", "当场",
    "当儿", "当即", "当口儿", "当然", "当庭", "当头", "当下", "当真", "当中", "倒不如", "倒不如说", "倒是", "到处", "到底", "到了儿",
    "到目前为止", "到头", "到头来", "得起", "得天独厚", "的确", "等到", "叮当", "顶多", "定", "动不动", "动辄", "陡然", "都", "独",
    "独自", "断然", "顿时", "多次", "多多", "多多少少", "多多益善", "多亏", "多年来", "多年前", "而后", "而论", "而又", "尔等",
    "二话不说", "二话没说", "反倒", "反倒是", "反而", "反手", "反之亦然", "反之则", "方", "方才", "方能", "放量", "非常", "非得",
    "分期", "分期分批", "分头", "奋勇", "愤然", "风雨无阻", "逢", "弗", "甫", "嘎嘎", "该当", "概", "赶快", "赶早不赶晚", "敢",
    "敢情", "敢于", "刚", "刚才", "刚好", "刚巧", "高低", "格外", "隔日", "隔夜", "个人", "各式", "更", "更加", "更进一步", "更为",
    "公然", "共", "共总", "够瞧的", "姑且", "古来", "故而", "故意", "固", "怪", "怪不得", "惯常", "光", "光是", "归根到底",
    "归根结底", "过于", "毫不", "毫无", "毫无保留地", "毫无例外", "好在", "何必", "何尝", "何妨", "何苦", "何乐而不为", "何须",
    "何止", "很", "很多", "很少", "轰然", "后来", "呼啦", "忽地", "忽然", "互", "互相", "哗啦", "话说", "还", "恍然", "会", "豁然",
    "活", "伙同", "或多或少", "或许", "基本", "基本上", "基于", "极", "极大", "极度", "极端", "极力", "极其", "极为", "急匆匆",
    "即将", "即刻", "即是说", "几度", "几番", "几乎", "几经", "既...又", "继之", "加上", "加以", "间或", "简而言之", "简言之",
    "简直", "见", "将才", "将近", "将要", "交口", "较比", "较为", "接连不断", "接下来", "皆可", "截然", "截至", "藉以", "借此",
    "借以", "届时", "仅", "仅仅", "谨", "进来", "进去", "近", "近几年来", "近来", "近年来", "尽管如此", "尽可能", "尽快", "尽量",
    "尽然", "尽如人意", "尽心竭力", "尽心尽力", "尽早", "精光", "经常", "竟", "竟然", "究竟", "就此", "就地", "就算", "居然", "局外",
    "举凡", "据称", "据此", "据实", "据说", "据我所知", "据悉", "具体来说", "决不", "决非", "绝", "绝不", "绝顶", "绝对", "绝非",
    "均", "喀", "看", "看来", "看起来", "看上去", "看样子", "可好", "可能", "恐怕", "快", "快要", "来不及", "来得及", "来讲",
    "来看", "拦腰", "牢牢", "老", "老大", "老老实实", "老是", "累次", "累年", "理当", "理该", "理应", "历", "立", "立地", "立刻",
    "立马", "立时", "联袂", "连连", "连日", "连日来", "连声", "连袂", "临到", "另方面", "另行", "另一个", "路经", "屡", "屡次",
    "屡次三番", "屡屡", "缕缕", "率尔", "率然", "略", "略加", "略微", "略为", "论说", "马上", "蛮", "满", "没", "没有", "每逢",
    "每每", "每时每刻", "猛然", "猛然间", "莫", "莫不", "莫非", "莫如", "默默地", "默然", "呐", "那末", "奈", "难道", "难得", "难怪",
    "难说", "内", "年复一年", "凝神", "偶而", "偶尔", "怕", "砰", "碰巧", "譬如", "偏偏", "乒", "平素", "颇", "迫于", "扑通",
    "其后", "其实", "奇", "齐", "起初", "起来", "起首", "起头", "起先", "岂", "岂非", "岂止", "迄", "恰逢", "恰好", "恰恰", "恰巧",
    "恰如", "恰似", "千", "千万", "千万千万", "切", "切不可", "切莫", "切切", "切勿", "窃", "亲口", "亲身", "亲手", "亲眼", "亲自",
    "顷", "顷刻", "顷刻间", "顷刻之间", "请勿", "穷年累月", "取道", "去", "权时", "全都", "全力", "全年", "全然", "全身心", "然",
    "人人", "仍", "仍旧", "仍然", "日复一日", "日见", "日渐", "日益", "日臻", "如常", "如此等等", "如次", "如今", "如期", "如前所述",
    "如上", "如下", "汝", "三番两次", "三番五次", "三天两头", "瑟瑟", "沙沙", "上", "上来", "上去", "一个", "月", "日", "\n"
}
135
api/core/index/keyword_table_index.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
|
||||
from llama_index.data_structs import KeywordTable, Node
|
||||
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
|
||||
from llama_index.indices.registry import load_index_struct_from_dict
|
||||
|
||||
from core.docstore.dataset_docstore import DatesetDocumentStore
|
||||
from core.docstore.empty_docstore import EmptyDocumentStore
|
||||
from core.index.index_builder import IndexBuilder
|
||||
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
|
||||
from core.llm.llm_builder import LLMBuilder
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
|
||||
|
||||
|
||||
class KeywordTableIndex:
|
||||
|
||||
def __init__(self, dataset: Dataset):
|
||||
self._dataset = dataset
|
||||
|
||||
def add_nodes(self, nodes: List[Node]):
|
||||
llm = LLMBuilder.to_llm(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
model_name='fake'
|
||||
)
|
||||
|
||||
service_context = ServiceContext.from_defaults(
|
||||
llm_predictor=LLMPredictor(llm=llm),
|
||||
embed_model=OpenAIEmbedding()
|
||||
)
|
||||
|
||||
dataset_keyword_table = self.get_keyword_table()
|
||||
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
|
||||
index_struct = KeywordTable()
|
||||
else:
|
||||
index_struct_dict = dataset_keyword_table.keyword_table_dict
|
||||
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
|
||||
|
||||
# create index
|
||||
index = GPTJIEBAKeywordTableIndex(
|
||||
index_struct=index_struct,
|
||||
docstore=EmptyDocumentStore(),
|
||||
service_context=service_context
|
||||
)
|
||||
|
||||
for node in nodes:
|
||||
keywords = index._extract_keywords(node.get_text())
|
||||
self.update_segment_keywords(node.doc_id, list(keywords))
|
||||
index._index_struct.add_node(list(keywords), node)
|
||||
|
||||
index_struct_dict = index.index_struct.to_dict()
|
||||
|
||||
if not dataset_keyword_table:
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
dataset_id=self._dataset.id,
|
||||
keyword_table=json.dumps(index_struct_dict)
|
||||
)
|
||||
db.session.add(dataset_keyword_table)
|
||||
else:
|
||||
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def del_nodes(self, node_ids: List[str]):
|
||||
llm = LLMBuilder.to_llm(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
model_name='fake'
|
||||
)
|
||||
|
||||
service_context = ServiceContext.from_defaults(
|
||||
llm_predictor=LLMPredictor(llm=llm),
|
||||
embed_model=OpenAIEmbedding()
|
||||
)
|
||||
|
||||
dataset_keyword_table = self.get_keyword_table()
|
||||
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
|
||||
return
|
||||
else:
|
||||
index_struct_dict = dataset_keyword_table.keyword_table_dict
|
||||
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
|
||||
|
||||
# create index
|
||||
index = GPTJIEBAKeywordTableIndex(
|
||||
index_struct=index_struct,
|
||||
docstore=EmptyDocumentStore(),
|
||||
service_context=service_context
|
||||
)
|
||||
|
||||
for node_id in node_ids:
|
||||
index.delete(node_id)
|
||||
|
||||
index_struct_dict = index.index_struct.to_dict()
|
||||
|
||||
if not dataset_keyword_table:
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
dataset_id=self._dataset.id,
|
||||
keyword_table=json.dumps(index_struct_dict)
|
||||
)
|
||||
db.session.add(dataset_keyword_table)
|
||||
else:
|
||||
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@property
|
||||
def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
|
||||
docstore = DatesetDocumentStore(
|
||||
dataset=self._dataset,
|
||||
user_id=self._dataset.created_by,
|
||||
embedding_model_name="text-embedding-ada-002"
|
||||
)
|
||||
|
||||
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
|
||||
|
||||
dataset_keyword_table = self.get_keyword_table()
|
||||
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
|
||||
return None
|
||||
|
||||
index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)
|
||||
|
||||
return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context)
|
||||
|
||||
def get_keyword_table(self):
|
||||
dataset_keyword_table = self._dataset.dataset_keyword_table
|
||||
if dataset_keyword_table:
|
||||
return dataset_keyword_table
|
||||
return None
|
||||
|
||||
def update_segment_keywords(self, node_id: str, keywords: List[str]):
|
||||
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
|
||||
if document_segment:
|
||||
document_segment.keywords = keywords
|
||||
db.session.commit()
|
||||
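A minimal usage sketch of the class above, assuming a Dataset row has already been loaded elsewhere (the dataset variable and node contents here are hypothetical):

# Hypothetical usage sketch: wire a Dataset row into the keyword index.
# Assumes `dataset` was loaded elsewhere, e.g. Dataset.query.get(dataset_id).
from llama_index.data_structs import Node

index = KeywordTableIndex(dataset=dataset)

# add_nodes extracts jieba keywords per node, stores them on the
# matching DocumentSegment row, and persists the table as JSON.
index.add_nodes([Node(text="一个示例段落", doc_id="node-1")])

# Deletion is symmetric and also rewrites the persisted table.
index.del_nodes(["node-1"])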
79
api/core/index/query/synthesizer.py
Normal file
@@ -0,0 +1,79 @@
from typing import (
    Any,
    Dict,
    Optional,
    Sequence,
)

from llama_index.indices.response.response_synthesis import ResponseSynthesizer
from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)
from llama_index.types import RESPONSE_TEXT_TYPE


class EnhanceResponseSynthesizer(ResponseSynthesizer):
    @classmethod
    def from_args(
        cls,
        service_context: ServiceContext,
        streaming: bool = False,
        use_async: bool = False,
        text_qa_template: Optional[QuestionAnswerPrompt] = None,
        refine_template: Optional[RefinePrompt] = None,
        simple_template: Optional[SimpleInputPrompt] = None,
        response_mode: ResponseMode = ResponseMode.DEFAULT,
        response_kwargs: Optional[Dict] = None,
        optimizer: Optional[BaseTokenUsageOptimizer] = None,
    ) -> "ResponseSynthesizer":
        response_builder: Optional[BaseResponseBuilder] = None
        if response_mode != ResponseMode.NO_TEXT:
            if response_mode == 'no_synthesizer':
                response_builder = NoSynthesizer(
                    service_context=service_context,
                    simple_template=simple_template,
                    streaming=streaming,
                )
            else:
                response_builder = get_response_builder(
                    service_context,
                    text_qa_template,
                    refine_template,
                    simple_template,
                    response_mode,
                    use_async=use_async,
                    streaming=streaming,
                )
        return cls(response_builder, response_mode, response_kwargs, optimizer)


class NoSynthesizer(BaseResponseBuilder):
    def __init__(
        self,
        service_context: ServiceContext,
        simple_template: Optional[SimpleInputPrompt] = None,
        streaming: bool = False,
    ) -> None:
        super().__init__(service_context, streaming)

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[str] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        return "\n".join(text_chunks)

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[str] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        return "\n".join(text_chunks)
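As a rough illustration of the branch above: passing the string 'no_synthesizer' (rather than a ResponseMode member) selects the pass-through builder, which simply joins the retrieved chunks instead of calling the LLM. A sketch, assuming a configured ServiceContext already exists:

# Hedged sketch: `service_context` is assumed to be configured elsewhere.
synthesizer = EnhanceResponseSynthesizer.from_args(
    service_context=service_context,
    response_mode='no_synthesizer',  # string sentinel handled above, not a ResponseMode member
)
# With NoSynthesizer, the "response" for text_chunks=["chunk one", "chunk two"]
# is just "chunk one\nchunk two" — no synthesis call is made.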
22
api/core/index/readers/html_parser.py
Normal file
@@ -0,0 +1,22 @@
from pathlib import Path
from typing import Dict

from bs4 import BeautifulSoup
from llama_index.readers.file.base_parser import BaseParser


class HTMLParser(BaseParser):
    """HTML parser."""

    def _init_parser(self) -> Dict:
        """Init parser."""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        """Parse file."""
        with open(file, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

        return text
56
api/core/index/readers/pdf_parser.py
Normal file
@@ -0,0 +1,56 @@
from pathlib import Path
from typing import Dict

from flask import current_app
from llama_index.readers.file.base_parser import BaseParser
from pypdf import PdfReader

from extensions.ext_storage import storage
from models.model import UploadFile


class PDFParser(BaseParser):
    """PDF parser."""

    def _init_parser(self) -> Dict:
        """Init parser."""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        """Parse file."""
        if not current_app.config.get('PDF_PREVIEW', True):
            return ''

        plaintext_file_key = ''
        plaintext_file_exists = False
        if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
            upload_file: UploadFile = self._parser_config['upload_file']
            if upload_file.hash:
                plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
                try:
                    text = storage.load(plaintext_file_key).decode('utf-8')
                    plaintext_file_exists = True
                    return text
                except FileNotFoundError:
                    pass

        text_list = []
        with open(file, "rb") as fp:
            # create a PDF reader object
            pdf = PdfReader(fp)

            # get the number of pages in the PDF document
            num_pages = len(pdf.pages)

            # iterate over every page and extract its text
            for page in range(num_pages):
                page_text = pdf.pages[page].extract_text()
                text_list.append(page_text)
        text = "\n".join(text_list)

        # save the plaintext file for caching
        if not plaintext_file_exists and plaintext_file_key:
            storage.save(plaintext_file_key, text.encode('utf-8'))

        return text
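The plaintext cache key above is derived from the tenant id and the file's content hash, so a PDF is only parsed once per identical upload. A worked example of the key's shape (all values hypothetical):

# Hypothetical values to show the cache-key shape used above.
tenant_id = "8f2d-tenant"
file_hash = "a1b2c3"
plaintext_file_key = 'upload_files/' + tenant_id + '/' + file_hash + '.plaintext'
# -> "upload_files/8f2d-tenant/a1b2c3.plaintext"
# The first parse pays the pypdf cost; later parses of the same file
# hit storage.load() and return the cached plaintext immediately.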
136
api/core/index/vector_index.py
Normal file
@@ -0,0 +1,136 @@
import json
import logging
from typing import List, Optional

from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type

from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding


class VectorIndex:

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
        if not self._dataset.index_struct_dict:
            index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
            self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
            db.session.commit()

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        if duplicate_check:
            nodes = self._filter_duplicate_nodes(index, nodes)

        embedding_queue_nodes = []
        embedded_nodes = []
        for node in nodes:
            node_hash = node.doc_hash

            # if the node hash is in the cached embedding table, reuse the cached embedding
            embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
            if embedding:
                node.embedding = embedding.get_embedding()
                embedded_nodes.append(node)
            else:
                embedding_queue_nodes.append(node)

        if embedding_queue_nodes:
            embedding_results = index._get_node_embedding_results(
                embedding_queue_nodes,
                set(),
            )

            # cache the freshly computed embeddings before inserting
            for embedding_result in embedding_results:
                node = embedding_result.node
                node.embedding = embedding_result.embedding

                try:
                    embedding = Embedding(hash=node.doc_hash)
                    embedding.set_embedding(node.embedding)
                    db.session.add(embedding)
                    db.session.commit()
                except IntegrityError:
                    # another worker cached the same hash concurrently
                    db.session.rollback()
                    continue
                except Exception:
                    logging.exception('Failed to add embedding to db')
                    continue

                embedded_nodes.append(node)

        self.index_insert_nodes(index, embedded_nodes)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
        index.insert_nodes(nodes)

    def del_nodes(self, node_ids: List[str]):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        for node_id in node_ids:
            self.index_delete_node(index, node_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
        index.delete_node(node_id)

    def del_doc(self, doc_id: str):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        self.index_delete_doc(index, doc_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
        index.delete(doc_id)

    @property
    def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
        if not self._dataset.index_struct_dict:
            return None

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        return vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

    def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
        # iterate over a copy: removing items from the list being iterated skips elements
        for node in nodes[:]:
            node_id = node.doc_id
            exists_duplicate_node = index.exists_by_node_id(node_id)
            if exists_duplicate_node:
                nodes.remove(node)

        return nodes
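The tenacity decorator used on the insert/delete methods above retries only on ReadTimeout, makes at most three attempts, and re-raises the final failure. A self-contained sketch of the same policy, independent of the class:

# Standalone sketch of the retry policy used above: retry only on
# requests.ReadTimeout, at most 3 attempts, surface the original error.
from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt

@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def flaky_insert():
    raise ReadTimeout("simulated vector store timeout")

try:
    flaky_insert()
except ReadTimeout:
    print("gave up after 3 attempts")  # reraise=True re-raises the last ReadTimeout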
467
api/core/indexing_runner.py
Normal file
@@ -0,0 +1,467 @@
import datetime
import json
import re
import tempfile
import time
from pathlib import Path
from typing import Optional, List

from langchain.text_splitter import RecursiveCharacterTextSplitter
from llama_index import SimpleDirectoryReader
from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.node_parser import SimpleNodeParser, NodeParser
from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
from llama_index.readers.file.markdown_parser import MarkdownParser

from core.docstore.dataset_docstore import DatesetDocumentStore
from core.index.keyword_table_index import KeywordTableIndex
from core.index.readers.html_parser import HTMLParser
from core.index.readers.pdf_parser import PDFParser
from core.index.vector_index import VectorIndex
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile


class IndexingRunner:

    def __init__(self, embedding_model_name: str = "text-embedding-ada-002"):
        self.storage = storage
        self.embedding_model_name = embedding_model_name

    def run(self, document: Document):
        """Run the indexing process."""
        # get dataset
        dataset = Dataset.query.filter_by(
            id=document.dataset_id
        ).first()

        if not dataset:
            raise ValueError("no dataset found")

        # load file
        text_docs = self._load_data(document)

        # get the process rule
        processing_rule = db.session.query(DatasetProcessRule). \
            filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
            first()

        # get node parser for splitting
        node_parser = self._get_node_parser(processing_rule)

        # split to nodes
        nodes = self._step_split(
            text_docs=text_docs,
            node_parser=node_parser,
            dataset=dataset,
            document=document,
            processing_rule=processing_rule
        )

        # build index
        self._build_index(
            dataset=dataset,
            document=document,
            nodes=nodes
        )

    def run_in_splitting_status(self, document: Document):
        """Run the indexing process when the index_status is splitting."""
        # get dataset
        dataset = Dataset.query.filter_by(
            id=document.dataset_id
        ).first()

        if not dataset:
            raise ValueError("no dataset found")

        # get the existing document_segment list and delete it
        document_segments = DocumentSegment.query.filter_by(
            dataset_id=dataset.id,
            document_id=document.id
        ).all()
        # delete row by row: session.delete() expects a single mapped
        # instance, not the list returned by .all()
        for document_segment in document_segments:
            db.session.delete(document_segment)
        db.session.commit()
        # load file
        text_docs = self._load_data(document)

        # get the process rule
        processing_rule = db.session.query(DatasetProcessRule). \
            filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
            first()

        # get node parser for splitting
        node_parser = self._get_node_parser(processing_rule)

        # split to nodes
        nodes = self._step_split(
            text_docs=text_docs,
            node_parser=node_parser,
            dataset=dataset,
            document=document,
            processing_rule=processing_rule
        )

        # build index
        self._build_index(
            dataset=dataset,
            document=document,
            nodes=nodes
        )

    def run_in_indexing_status(self, document: Document):
        """Run the indexing process when the index_status is indexing."""
        # get dataset
        dataset = Dataset.query.filter_by(
            id=document.dataset_id
        ).first()

        if not dataset:
            raise ValueError("no dataset found")

        # transform any segment that has not completed indexing back into a node
        document_segments = DocumentSegment.query.filter_by(
            dataset_id=dataset.id,
            document_id=document.id
        ).all()
        nodes = []
        if document_segments:
            for document_segment in document_segments:
                # transform segment to node
                if document_segment.status != "completed":
                    relationships = {
                        DocumentRelationship.SOURCE: document_segment.document_id,
                    }

                    previous_segment = document_segment.previous_segment
                    if previous_segment:
                        relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id

                    next_segment = document_segment.next_segment
                    if next_segment:
                        relationships[DocumentRelationship.NEXT] = next_segment.index_node_id

                    node = Node(
                        doc_id=document_segment.index_node_id,
                        doc_hash=document_segment.index_node_hash,
                        text=document_segment.content,
                        extra_info=None,
                        node_info=None,
                        relationships=relationships
                    )
                    nodes.append(node)

        # build index
        self._build_index(
            dataset=dataset,
            document=document,
            nodes=nodes
        )

    def indexing_estimate(self, file_detail: UploadFile, tmp_processing_rule: dict) -> dict:
        """
        Estimate the indexing cost for the document.
        """
        # load data from file
        text_docs = self._load_data_from_file(file_detail)

        processing_rule = DatasetProcessRule(
            mode=tmp_processing_rule["mode"],
            rules=json.dumps(tmp_processing_rule["rules"])
        )

        # get node parser for splitting
        node_parser = self._get_node_parser(processing_rule)

        # split to nodes
        nodes = self._split_to_nodes(
            text_docs=text_docs,
            node_parser=node_parser,
            processing_rule=processing_rule
        )

        tokens = 0
        preview_texts = []
        for node in nodes:
            if len(preview_texts) < 5:
                preview_texts.append(node.get_text())

            tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())

        return {
            "total_segments": len(nodes),
            "tokens": tokens,
            "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
            "currency": TokenCalculator.get_currency(self.embedding_model_name),
            "preview": preview_texts
        }

    def _load_data(self, document: Document) -> List[Document]:
        # load file
        if document.data_source_type != "upload_file":
            return []

        data_source_info = document.data_source_info_dict
        if not data_source_info or 'upload_file_id' not in data_source_info:
            raise ValueError("no upload file found")

        file_detail = db.session.query(UploadFile). \
            filter(UploadFile.id == data_source_info['upload_file_id']). \
            one_or_none()

        text_docs = self._load_data_from_file(file_detail)

        # update document status to splitting
        self._update_document_index_status(
            document_id=document.id,
            after_indexing_status="splitting",
            extra_update_params={
                Document.file_id: file_detail.id,
                Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
                Document.parsing_completed_at: datetime.datetime.utcnow()
            }
        )

        # replace doc id with the document model id
        for text_doc in text_docs:
            # remove invalid symbols
            text_doc.text = self.filter_string(text_doc.get_text())
            text_doc.doc_id = document.id

        return text_docs

    def filter_string(self, text):
        pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
        return pattern.sub('', text)

    def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            self.storage.download(upload_file.key, filepath)

            file_extractor = DEFAULT_FILE_EXTRACTOR.copy()
            file_extractor[".markdown"] = MarkdownParser()
            file_extractor[".html"] = HTMLParser()
            file_extractor[".htm"] = HTMLParser()
            file_extractor[".pdf"] = PDFParser({'upload_file': upload_file})

            loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor)
            text_docs = loader.load_data()

            return text_docs

    def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
        """
        Get the NodeParser object according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # the user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
                raise ValueError("Custom segment length should be between 50 and 1000.")

            separator = segmentation["separator"]
            if not separator:
                separators = ["\n\n", "。", ".", " ", ""]
            else:
                separator = separator.replace('\\n', '\n')
                separators = [separator, ""]

            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=0,
                separators=separators
            )
        else:
            # automatic segmentation
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=0,
                separators=["\n\n", "。", ".", " ", ""]
            )

        return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True)

    def _step_split(self, text_docs: List[Document], node_parser: NodeParser,
                    dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]:
        """
        Split the text documents into nodes and save them to the document segments.
        """
        nodes = self._split_to_nodes(
            text_docs=text_docs,
            node_parser=node_parser,
            processing_rule=processing_rule
        )

        # save nodes to document segments
        doc_store = DatesetDocumentStore(
            dataset=dataset,
            user_id=document.created_by,
            embedding_model_name=self.embedding_model_name,
            document_id=document.id
        )

        doc_store.add_documents(nodes)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=document.id,
            after_indexing_status="indexing",
            extra_update_params={
                Document.cleaning_completed_at: cur_time,
                Document.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            document_id=document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )

        return nodes

    def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser,
                        processing_rule: DatasetProcessRule) -> List[Node]:
        """
        Split the text documents into nodes.
        """
        all_nodes = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.get_text(), processing_rule)
            text_doc.text = document_text

            # parse document to nodes
            nodes = node_parser.get_nodes_from_documents([text_doc])

            all_nodes.extend(nodes)

        return all_nodes

    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}

        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # remove emails
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # remove URLs
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)

        return text

    def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None:
        """
        Build the index for the document.
        """
        vector_index = VectorIndex(dataset=dataset)
        keyword_table_index = KeywordTableIndex(dataset=dataset)

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 100
        for i in range(0, len(nodes), chunk_size):
            # check whether the document is paused
            self._check_document_paused_status(document.id)
            chunk_nodes = nodes[i:i + chunk_size]

            tokens += sum(
                TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes
            )

            # save vector index
            if dataset.indexing_technique == "high_quality":
                vector_index.add_nodes(chunk_nodes)

            # save keyword index
            keyword_table_index.add_nodes(chunk_nodes)

            node_ids = [node.doc_id for node in chunk_nodes]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == document.id,
                DocumentSegment.index_node_id.in_(node_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
                DocumentSegment.completed_at: datetime.datetime.utcnow()
            })

            db.session.commit()

        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=document.id,
            after_indexing_status="completed",
            extra_update_params={
                Document.tokens: tokens,
                Document.completed_at: datetime.datetime.utcnow(),
                Document.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )

    def _check_document_paused_status(self, document_id: str):
        indexing_cache_key = 'document_{}_is_paused'.format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedException()

    def _update_document_index_status(self, document_id: str, after_indexing_status: str,
                                      extra_update_params: Optional[dict] = None) -> None:
        """
        Update the document indexing status.
        """
        count = Document.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedException()

        update_params = {
            Document.indexing_status: after_indexing_status
        }

        if extra_update_params:
            update_params.update(extra_update_params)

        Document.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    def _update_segments_by_document(self, document_id: str, update_params: dict) -> None:
        """
        Update the document segments by document id.
        """
        DocumentSegment.query.filter_by(document_id=document_id).update(update_params)
        db.session.commit()


class DocumentIsPausedException(Exception):
    pass
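The batching in _build_index above walks the node list in fixed windows of 100, checking the pause flag and committing segment status once per window. A stripped-down sketch of that slicing pattern:

# Standalone sketch of the chunking loop used in _build_index:
# process items in windows of `chunk_size`, one commit per window.
nodes = list(range(250))  # stand-in for Node objects
chunk_size = 100
for i in range(0, len(nodes), chunk_size):
    chunk = nodes[i:i + chunk_size]
    print(f"indexing nodes {i}..{i + len(chunk) - 1}")
# -> windows 0..99, 100..199, 200..249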
55
api/core/llm/error.py
Normal file
@@ -0,0 +1,55 @@
from typing import Optional


class LLMError(Exception):
    """Base class for all LLM exceptions."""
    description: Optional[str] = None

    def __init__(self, description: Optional[str] = None) -> None:
        self.description = description


class LLMBadRequestError(LLMError):
    """Raised when the LLM returns a bad request."""
    description = "Bad Request"


class LLMAPIConnectionError(LLMError):
    """Raised when the LLM returns an API connection error."""
    description = "API Connection Error"


class LLMAPIUnavailableError(LLMError):
    """Raised when the LLM returns an API unavailable error."""
    description = "API Unavailable Error"


class LLMRateLimitError(LLMError):
    """Raised when the LLM returns a rate limit error."""
    description = "Rate Limit Error"


class LLMAuthorizationError(LLMError):
    """Raised when the LLM returns an authorization error."""
    description = "Authorization Error"


class ProviderTokenNotInitError(Exception):
    """
    Custom exception raised when the provider token has not been initialized.
    """
    description = "Provider Token Not Init"


class QuotaExceededError(Exception):
    """
    Custom exception raised when the quota for a provider has been exceeded.
    """
    description = "Quota Exceeded"


class ModelCurrentlyNotSupportError(Exception):
    """
    Custom exception raised when the model is not currently supported.
    """
    description = "Model Currently Not Support"
51
api/core/llm/error_handle_wraps.py
Normal file
@@ -0,0 +1,51 @@
import logging
from functools import wraps

import openai

from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
    LLMBadRequestError


def handle_llm_exceptions(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except openai.error.InvalidRequestError as e:
            logging.exception("Invalid request to OpenAI API.")
            raise LLMBadRequestError(str(e))
        except openai.error.APIConnectionError as e:
            logging.exception("Failed to connect to OpenAI API.")
            raise LLMAPIConnectionError(str(e))
        except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
            logging.exception("OpenAI service unavailable.")
            raise LLMAPIUnavailableError(str(e))
        except openai.error.RateLimitError as e:
            raise LLMRateLimitError(str(e))
        except openai.error.AuthenticationError as e:
            raise LLMAuthorizationError(str(e))

    return wrapper


def handle_llm_exceptions_async(func):
    @wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except openai.error.InvalidRequestError as e:
            logging.exception("Invalid request to OpenAI API.")
            raise LLMBadRequestError(str(e))
        except openai.error.APIConnectionError as e:
            logging.exception("Failed to connect to OpenAI API.")
            raise LLMAPIConnectionError(str(e))
        except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
            logging.exception("OpenAI service unavailable.")
            raise LLMAPIUnavailableError(str(e))
        except openai.error.RateLimitError as e:
            raise LLMRateLimitError(str(e))
        except openai.error.AuthenticationError as e:
            raise LLMAuthorizationError(str(e))

    return wrapper
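A usage sketch for the wrapper above, assuming a caller that hits the OpenAI SDK (the function body here is hypothetical; the decorator translates any openai.error.* raised inside it into the LLMError hierarchy from core/llm/error.py):

# Hypothetical caller, using the pre-1.0 openai SDK this module targets.
import openai
from core.llm.error import LLMRateLimitError

@handle_llm_exceptions
def complete(prompt: str) -> str:
    response = openai.Completion.create(model="text-davinci-003", prompt=prompt)
    return response["choices"][0]["text"]

try:
    complete("Hello")
except LLMRateLimitError:
    pass  # e.g. back off and retry later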
103
api/core/llm/llm_builder.py
Normal file
@@ -0,0 +1,103 @@
from typing import Union, Optional

from langchain.callbacks import CallbackManager
from langchain.llms.fake import FakeListLLM

from core.constant import llm_constant
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI


class LLMBuilder:
    """
    This class handles the following logic:
    1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config.
    2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below:
       OPENAI_API_TYPE=azure
       OPENAI_API_VERSION=2022-12-01
       OPENAI_API_BASE=https://your-resource-name.openai.azure.com
       OPENAI_API_KEY=<your Azure OpenAI API key>
    3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config.
    4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config.
    5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config.
    6. Providers with the provider_type 'CUSTOM' can be created through the admin interface; 'System' providers cannot.
    7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter.
    8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting.
    """

    @classmethod
    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
        if model_name == 'fake':
            return FakeListLLM(responses=[])

        mode = cls.get_mode_by_model(model_name)
        if mode == 'chat':
            # llm_cls = StreamableAzureChatOpenAI
            llm_cls = StreamableChatOpenAI
        elif mode == 'completion':
            llm_cls = StreamableOpenAI
        else:
            raise ValueError(f"model name {model_name} is not supported.")

        model_credentials = cls.get_model_credentials(tenant_id, model_name)

        return llm_cls(
            model_name=model_name,
            temperature=kwargs.get('temperature', 0),
            max_tokens=kwargs.get('max_tokens', 256),
            top_p=kwargs.get('top_p', 1),
            frequency_penalty=kwargs.get('frequency_penalty', 0),
            presence_penalty=kwargs.get('presence_penalty', 0),
            callback_manager=kwargs.get('callback_manager', None),
            streaming=kwargs.get('streaming', False),
            # request_timeout=None
            **model_credentials
        )

    @classmethod
    def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
                          callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
        model_name = model.get("name")
        completion_params = model.get("completion_params", {})

        return cls.to_llm(
            tenant_id=tenant_id,
            model_name=model_name,
            temperature=completion_params.get('temperature', 0),
            max_tokens=completion_params.get('max_tokens', 256),
            top_p=completion_params.get('top_p', 0),
            frequency_penalty=completion_params.get('frequency_penalty', 0.1),
            presence_penalty=completion_params.get('presence_penalty', 0.1),
            streaming=streaming,
            callback_manager=callback_manager
        )

    @classmethod
    def get_mode_by_model(cls, model_name: str) -> str:
        if not model_name:
            raise ValueError("empty model name is not supported.")

        if model_name in llm_constant.models_by_mode['chat']:
            return "chat"
        elif model_name in llm_constant.models_by_mode['completion']:
            return "completion"
        else:
            raise ValueError(f"model name {model_name} is not supported.")

    @classmethod
    def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict:
        """
        Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
        Raises an exception if the model_name is not found or if the provider is not found.
        """
        if not model_name:
            raise Exception('model name not found')

        if model_name not in llm_constant.models:
            raise Exception('model {} not found'.format(model_name))

        model_provider = llm_constant.models[model_name]

        provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
        return provider_service.get_credentials(model_name)
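A hedged usage sketch for the builder above; the tenant id is hypothetical, and 'gpt-3.5-turbo' is assumed to be registered under llm_constant.models_by_mode['chat'], in which case to_llm resolves the tenant's credentials and returns a StreamableChatOpenAI:

# Hypothetical usage: tenant_id is a placeholder value.
llm = LLMBuilder.to_llm(
    tenant_id="tenant-123",
    model_name="gpt-3.5-turbo",  # assumed to be a registered chat-mode model
    temperature=0.7,
    max_tokens=512,
    streaming=False,
)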
15
api/core/llm/moderation.py
Normal file
@@ -0,0 +1,15 @@
import openai
from models.provider import ProviderName


class Moderation:

    def __init__(self, provider: str, api_key: str):
        self.provider = provider
        self.api_key = api_key

        if self.provider == ProviderName.OPENAI.value:
            self.client = openai.Moderation

    def moderate(self, text):
        return self.client.create(input=text, api_key=self.api_key)
23
api/core/llm/provider/anthropic_provider.py
Normal file
@@ -0,0 +1,23 @@
from typing import Optional

from core.llm.provider.base import BaseProvider
from models.provider import ProviderName


class AnthropicProvider(BaseProvider):
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        credentials = self.get_credentials(model_id)
        # todo
        return []

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the API credentials for Anthropic as a dictionary, for the given tenant_id.
        The dictionary contains the key: anthropic_api_key.
        """
        return {
            'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
        }

    def get_provider_name(self):
        return ProviderName.ANTHROPIC
105
api/core/llm/provider/azure_provider.py
Normal file
@@ -0,0 +1,105 @@
import json
from typing import Optional, Union

import requests

from core.llm.provider.base import BaseProvider
from models.provider import ProviderName


class AzureProvider(BaseProvider):
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        credentials = self.get_credentials(model_id)
        url = "{}/openai/deployments?api-version={}".format(
            credentials.get('openai_api_base'),
            credentials.get('openai_api_version')
        )

        headers = {
            "api-key": credentials.get('openai_api_key'),
            "content-type": "application/json; charset=utf-8"
        }

        response = requests.get(url, headers=headers)

        if response.status_code == 200:
            result = response.json()
            return [{
                'id': deployment['id'],
                'name': '{} ({})'.format(deployment['id'], deployment['model'])
            } for deployment in result['data'] if deployment['status'] == 'succeeded']
        else:
            # TODO: optimize in future
            raise Exception('Failed to get deployments from Azure OpenAI. Status code: {}'.format(response.status_code))

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the API credentials for Azure OpenAI as a dictionary.
        """
        encrypted_config = self.get_provider_api_key(model_id=model_id)
        config = json.loads(encrypted_config)
        config['openai_api_type'] = 'azure'
        config['deployment_name'] = model_id
        return config

    def get_provider_name(self):
        return ProviderName.AZURE_OPENAI

    def get_provider_configs(self, obfuscated: bool = False) -> Union[str, dict]:
        """
        Returns the provider configs.
        """
        try:
            config = self.get_provider_api_key()
            config = json.loads(config)
        except Exception:
            config = {
                'openai_api_type': 'azure',
                'openai_api_version': '2023-03-15-preview',
                'openai_api_base': 'https://foo.microsoft.com/bar',
                'openai_api_key': ''
            }

        if obfuscated:
            if not config.get('openai_api_key'):
                config = {
                    'openai_api_type': 'azure',
                    'openai_api_version': '2023-03-15-preview',
                    'openai_api_base': 'https://foo.microsoft.com/bar',
                    'openai_api_key': ''
                }

            config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
            return config

        return config

    def get_token_type(self):
        # TODO: change to dict when implemented
        return lambda value: value

    def config_validate(self, config: Union[dict, str]):
        """
        Validates the given config.
        """
        # TODO: implement
        pass

    def get_encrypted_token(self, config: Union[dict, str]):
        """
        Returns the encrypted token.
        """
        return json.dumps({
            'openai_api_type': 'azure',
            'openai_api_version': '2023-03-15-preview',
            'openai_api_base': config['openai_api_base'],
            'openai_api_key': self.encrypt_token(config['openai_api_key'])
        })

    def get_decrypted_token(self, token: str):
        """
        Returns the decrypted token.
        """
        config = json.loads(token)
        config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
        return config
124
api/core/llm/provider/base.py
Normal file
@@ -0,0 +1,124 @@
import base64
from abc import ABC, abstractmethod
from typing import Optional, Union

from core import hosted_llm_credentials
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.provider import Provider, ProviderType, ProviderName


class BaseProvider(ABC):
    def __init__(self, tenant_id: str):
        self.tenant_id = tenant_id

    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str:
        """
        Returns the decrypted API key for the given tenant_id and provider_name.
        If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
        If the provider is not found or not valid, raises a ProviderTokenNotInitError.
        """
        provider = self.get_provider(prefer_custom)
        if not provider:
            raise ProviderTokenNotInitError()

        if provider.provider_type == ProviderType.SYSTEM.value:
            quota_used = provider.quota_used if provider.quota_used is not None else 0
            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0

            if model_id and model_id == 'gpt-4':
                raise ModelCurrentlyNotSupportError()

            if quota_used >= quota_limit:
                raise QuotaExceededError()

            return self.get_hosted_credentials()
        else:
            return self.get_decrypted_token(provider.encrypted_config)

    def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
        """
        Returns the Provider instance for the given tenant_id and provider_name.
        If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
        """
        providers = db.session.query(Provider).filter(
            Provider.tenant_id == self.tenant_id,
            Provider.provider_name == self.get_provider_name().value
        ).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()

        custom_provider = None
        system_provider = None

        for provider in providers:
            if provider.provider_type == ProviderType.CUSTOM.value:
                custom_provider = provider
            elif provider.provider_type == ProviderType.SYSTEM.value:
                system_provider = provider

        if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config:
            return custom_provider
        elif system_provider and system_provider.is_valid:
            return system_provider
        else:
            return None

    def get_hosted_credentials(self) -> str:
        if self.get_provider_name() != ProviderName.OPENAI:
            raise ProviderTokenNotInitError()

        if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
            raise ProviderTokenNotInitError()

        return hosted_llm_credentials.openai.api_key

    def get_provider_configs(self, obfuscated: bool = False) -> Union[str, dict]:
        """
        Returns the provider configs.
        """
        try:
            config = self.get_provider_api_key()
        except Exception:
            config = 'THIS-IS-A-MOCK-TOKEN'

        if obfuscated:
            return self.obfuscated_token(config)

        return config

    def obfuscated_token(self, token: str):
        # keep the first 6 and last 2 characters, mask the rest
        return token[:6] + '*' * (len(token) - 8) + token[-2:]

    def get_token_type(self):
        return str

    def get_encrypted_token(self, config: Union[dict, str]):
        return self.encrypt_token(config)

    def get_decrypted_token(self, token: str):
        return self.decrypt_token(token)

    def encrypt_token(self, token):
        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
        return base64.b64encode(encrypted_token).decode()

    def decrypt_token(self, token):
        return rsa.decrypt(base64.b64decode(token), self.tenant_id)

    @abstractmethod
    def get_provider_name(self):
        raise NotImplementedError

    @abstractmethod
    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        raise NotImplementedError

    @abstractmethod
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        raise NotImplementedError

    @abstractmethod
    def config_validate(self, config: str):
        raise NotImplementedError
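The obfuscation above keeps the first six and last two characters of the token and masks everything in between. A worked example with a hypothetical 20-character key:

# Worked example of obfuscated_token (key value is made up):
token = "sk-abc1234567890wxyz"           # len(token) == 20
masked = token[:6] + '*' * (len(token) - 8) + token[-2:]
print(masked)  # "sk-abc************yz" — 6 + 12 + 2 characters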
2
api/core/llm/provider/errors.py
Normal file
@@ -0,0 +1,2 @@
class ValidateFailedError(Exception):
    description = "Provider Validate failed"
22
api/core/llm/provider/huggingface_provider.py
Normal file
@@ -0,0 +1,22 @@
from typing import Optional

from core.llm.provider.base import BaseProvider
from models.provider import ProviderName


class HuggingfaceProvider(BaseProvider):
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        credentials = self.get_credentials(model_id)
        # todo
        return []

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the API credentials for Huggingface as a dictionary, for the given tenant_id.
        """
        return {
            'huggingface_api_key': self.get_provider_api_key(model_id=model_id)
        }

    def get_provider_name(self):
        return ProviderName.HUGGINGFACEHUB
53
api/core/llm/provider/llm_provider_service.py
Normal file
@@ -0,0 +1,53 @@
from typing import Optional, Union

from core.llm.provider.anthropic_provider import AnthropicProvider
from core.llm.provider.azure_provider import AzureProvider
from core.llm.provider.base import BaseProvider
from core.llm.provider.huggingface_provider import HuggingfaceProvider
from core.llm.provider.openai_provider import OpenAIProvider
from models.provider import Provider


class LLMProviderService:

    def __init__(self, tenant_id: str, provider_name: str):
        self.provider = self.init_provider(tenant_id, provider_name)

    def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider:
        if provider_name == 'openai':
            return OpenAIProvider(tenant_id)
        elif provider_name == 'azure_openai':
            return AzureProvider(tenant_id)
        elif provider_name == 'anthropic':
            return AnthropicProvider(tenant_id)
        elif provider_name == 'huggingface':
            return HuggingfaceProvider(tenant_id)
        else:
            raise Exception('provider {} not found'.format(provider_name))

    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        return self.provider.get_models(model_id)

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        return self.provider.get_credentials(model_id)

    def get_provider_configs(self, obfuscated: bool = False) -> Union[str, dict]:
        return self.provider.get_provider_configs(obfuscated)

    def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]:
        return self.provider.get_provider(prefer_custom)

    def config_validate(self, config: Union[dict, str]):
        """
        Validates the given config.

        :param config:
        :raises: ValidateFailedError
        """
        return self.provider.config_validate(config)

    def get_token_type(self):
        return self.provider.get_token_type()

    def get_encrypted_token(self, config: Union[dict, str]):
        return self.provider.get_encrypted_token(config)
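A hedged usage sketch for the facade above: resolve a provider by name and fetch its credentials (the tenant id is hypothetical, and the returned shapes are the ones the provider classes above document):

# Hypothetical usage: tenant_id is a placeholder value.
service = LLMProviderService(tenant_id="tenant-123", provider_name="openai")
credentials = service.get_credentials()                   # {'openai_api_key': '...'}
configs = service.get_provider_configs(obfuscated=True)   # masked, safe to display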
44
api/core/llm/provider/openai_provider.py
Normal file
@@ -0,0 +1,44 @@
import logging
from typing import Optional, Union

import openai
from openai.error import AuthenticationError, OpenAIError

from core.llm.moderation import Moderation
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName


class OpenAIProvider(BaseProvider):
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        credentials = self.get_credentials(model_id)
        response = openai.Model.list(**credentials)

        return [{
            'id': model['id'],
            'name': model['id'],
        } for model in response['data']]

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the credentials for the given tenant_id and provider_name.
        """
        return {
            'openai_api_key': self.get_provider_api_key(model_id=model_id)
        }

    def get_provider_name(self):
        return ProviderName.OPENAI

    def config_validate(self, config: Union[dict, str]):
        """
        Validates the given config.
        """
        try:
            Moderation(self.get_provider_name().value, config).moderate('test')
        except (AuthenticationError, OpenAIError) as ex:
            raise ValidateFailedError(str(ex))
        except Exception as ex:
            logging.exception('OpenAI config validation failed')
            raise ex
89
api/core/llm/streamable_azure_chat_open_ai.py
Normal file
@@ -0,0 +1,89 @@
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableAzureChatOpenAI(AzureChatOpenAI):
    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in a list of messages.

        Args:
            messages: The messages to count the tokens of.

        Returns:
            The number of tokens in the messages.
        """
        tokens_per_message = 5
        tokens_per_request = 3

        message_tokens = tokens_per_request
        message_strs = ''
        for message in messages:
            message_strs += message.content
            message_tokens += tokens_per_message

        # tokenize all message contents in a single pass
        message_tokens += self.get_num_tokens(message_strs)

        return message_tokens

    def _generate(
            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        self.callback_manager.on_llm_start(
            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
            verbose=self.verbose
        )

        chat_result = super()._generate(messages, stop)

        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )
        self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result

    async def _agenerate(
            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
                verbose=self.verbose
            )
        else:
            self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
                verbose=self.verbose
            )

        # await the async generation path rather than calling the sync one
        chat_result = await super()._agenerate(messages, stop)

        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )

        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
        else:
            self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result

    @handle_llm_exceptions
    def generate(
            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return super().generate(messages, stop)

    @handle_llm_exceptions_async
    async def agenerate(
            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return await super().agenerate(messages, stop)
86
api/core/llm/streamable_chat_open_ai.py
Normal file
@@ -0,0 +1,86 @@
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import ChatOpenAI
from typing import Optional, List

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableChatOpenAI(ChatOpenAI):

    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in a list of messages.

        Args:
            messages: The messages to count the tokens of.

        Returns:
            The number of tokens in the messages.
        """
        tokens_per_message = 5
        tokens_per_request = 3

        message_tokens = tokens_per_request
        message_strs = ''
        for message in messages:
            message_strs += message.content
            message_tokens += tokens_per_message

        # tokenize all message contents in a single pass
        message_tokens += self.get_num_tokens(message_strs)

        return message_tokens

    def _generate(
            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        self.callback_manager.on_llm_start(
            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
        )

        chat_result = super()._generate(messages, stop)

        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )
        self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result

    async def _agenerate(
            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
            )
        else:
            self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
            )

        # await the async generation path rather than calling the sync one
        chat_result = await super()._agenerate(messages, stop)

        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )

        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
        else:
            self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result

    @handle_llm_exceptions
    def generate(
            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return super().generate(messages, stop)

    @handle_llm_exceptions_async
    async def agenerate(
            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return await super().agenerate(messages, stop)
20
api/core/llm/streamable_open_ai.py
Normal file
@@ -0,0 +1,20 @@
from langchain.schema import LLMResult
from typing import Optional, List
from langchain import OpenAI

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableOpenAI(OpenAI):

    @handle_llm_exceptions
    def generate(
            self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return super().generate(prompts, stop)

    @handle_llm_exceptions_async
    async def agenerate(
            self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return await super().agenerate(prompts, stop)
41
api/core/llm/token_calculator.py
Normal file
@@ -0,0 +1,41 @@
import decimal
from typing import Optional

import tiktoken

from core.constant import llm_constant


class TokenCalculator:
    @classmethod
    def get_num_tokens(cls, model_name: str, text: str):
        if len(text) == 0:
            return 0

        enc = tiktoken.encoding_for_model(model_name)

        tokenized_text = enc.encode(text)

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    @classmethod
    def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal:
        if model_name in llm_constant.models_by_mode['embedding']:
            unit_price = llm_constant.model_prices[model_name]['usage']
        elif text_type == 'prompt':
            unit_price = llm_constant.model_prices[model_name]['prompt']
        elif text_type == 'completion':
            unit_price = llm_constant.model_prices[model_name]['completion']
        else:
            raise Exception('Invalid text type')

        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1k * unit_price
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    @classmethod
    def get_currency(cls, model_name: str):
        return llm_constant.model_currency
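# A hedged sketch of the pricing math above, with an assumed example unit
# price of USD 0.002 per 1K tokens (not a value read from llm_constant):
if __name__ == "__main__":
    example_unit_price = decimal.Decimal('0.002')
    tokens_per_1k = (decimal.Decimal(1234) / 1000).quantize(decimal.Decimal('0.001'),
                                                            rounding=decimal.ROUND_HALF_UP)
    print(tokens_per_1k * example_unit_price)  # 0.002468, before the final quantize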
77
api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
Normal file
@@ -0,0 +1,77 @@
from typing import Any, List, Dict, Union

from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage

from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from extensions.ext_database import db
from models.model import Conversation, Message


class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
    conversation: Conversation
    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    llm: Union[StreamableChatOpenAI, StreamableOpenAI]
    memory_key: str = "chat_history"
    max_token_limit: int = 2000
    message_limit: int = 10

    @property
    def buffer(self) -> List[BaseMessage]:
        """String buffer of memory."""
        # fetch limited messages in descending order, then return them reversed
        messages = db.session.query(Message).filter(
            Message.conversation_id == self.conversation.id,
            Message.answer_tokens > 0
        ).order_by(Message.created_at.desc()).limit(self.message_limit).all()

        messages = list(reversed(messages))

        chat_messages: List[BaseMessage] = []
        for message in messages:
            chat_messages.append(HumanMessage(content=message.query))
            chat_messages.append(AIMessage(content=message.answer))

        if not chat_messages:
            return chat_messages

        # prune the chat messages if they exceed the max token limit
        curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
        if curr_buffer_length > self.max_token_limit:
            pruned_memory = []
            while curr_buffer_length > self.max_token_limit and chat_messages:
                pruned_memory.append(chat_messages.pop(0))
                curr_buffer_length = self.llm.get_messages_tokens(chat_messages)

        return chat_messages

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return history buffer."""
        buffer: Any = self.buffer
        if self.return_messages:
            final_buffer: Any = buffer
        else:
            final_buffer = get_buffer_string(
                buffer,
                human_prefix=self.human_prefix,
                ai_prefix=self.ai_prefix,
            )
        return {self.memory_key: final_buffer}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Nothing should be saved or changed"""
        pass

    def clear(self) -> None:
        """Nothing to clear, got a memory like a vault."""
        pass
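# A hedged sketch of the pruning loop in buffer(), with a stand-in token
# counter in place of llm.get_messages_tokens so the oldest-first eviction
# is visible without a database or model:
if __name__ == "__main__":
    def count_tokens(msgs):  # stand-in: assume tokens ~ characters
        return sum(len(m) for m in msgs)

    history = ["a" * 1200, "b" * 600, "c" * 500]  # oldest first
    while count_tokens(history) > 2000 and history:
        history.pop(0)  # evict the oldest message first
    print(len(history))  # 2 -- the newest messages survive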
36
api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py
Normal file
@@ -0,0 +1,36 @@
from typing import Any, List, Dict

from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel

from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory


class ReadOnlyConversationTokenDBStringBufferSharedMemory(BaseChatMemory):
    memory: ReadOnlyConversationTokenDBBufferSharedMemory

    @property
    def memory_variables(self) -> List[str]:
        """Return memory variables."""
        return self.memory.memory_variables

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Load memory variables from memory."""
        buffer: Any = self.memory.buffer

        final_buffer = get_buffer_string(
            buffer,
            human_prefix=self.memory.human_prefix,
            ai_prefix=self.memory.ai_prefix,
        )

        return {self.memory.memory_key: final_buffer}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Nothing should be saved or changed"""
        pass

    def clear(self) -> None:
        """Nothing to clear, got a memory like a vault."""
        pass
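# A hedged sketch of what this wrapper returns: get_buffer_string flattens
# message objects into a plain "Human: ... / AI: ..." transcript.
if __name__ == "__main__":
    from langchain.schema import HumanMessage, AIMessage

    print(get_buffer_string(
        [HumanMessage(content="hi"), AIMessage(content="hello")],
        human_prefix="Human",
        ai_prefix="AI",
    ))  # "Human: hi\nAI: hello"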
@@ -0,0 +1,16 @@
import json
from typing import Any

from langchain.schema import BaseOutputParser
from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT


class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):

    def get_format_instructions(self) -> str:
        return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

    def parse(self, text: str) -> Any:
        json_string = text.strip()
        json_obj = json.loads(json_string)
        return json_obj
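# A hedged sketch: the parser expects the completion to be a bare JSON array
# of question strings, matching the instruction prompt it hands out.
if __name__ == "__main__":
    parser = SuggestedQuestionsAfterAnswerOutputParser()
    print(parser.parse('["q1", "q2", "q3"]'))  # ['q1', 'q2', 'q3']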
37
api/core/prompt/prompt_builder.py
Normal file
@@ -0,0 +1,37 @@
import re

from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
from langchain.schema import BaseMessage

from core.prompt.prompt_template import OutLinePromptTemplate


class PromptBuilder:
    @classmethod
    def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
        prompt_template = OutLinePromptTemplate.from_template(prompt_content)
        system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
        prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
        system_message = system_prompt_template.format(**prompt_inputs)
        return system_message

    @classmethod
    def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
        prompt_template = OutLinePromptTemplate.from_template(prompt_content)
        ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
        prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
        ai_message = ai_prompt_template.format(**prompt_inputs)
        return ai_message

    @classmethod
    def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
        prompt_template = OutLinePromptTemplate.from_template(prompt_content)
        human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
        human_message = human_prompt_template.format(**inputs)
        return human_message

    @classmethod
    def process_template(cls, template: str):
        processed_template = re.sub(r'\{(.+?)\}', r'\1', template)
        processed_template = re.sub(r'\{\{(.+?)\}\}', r'{\1}', processed_template)
        return processed_template
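# A hedged sketch of the brace rewriting above: single-brace variables are
# unwrapped, and double-brace escapes collapse to a literal single brace.
if __name__ == "__main__":
    print(PromptBuilder.process_template("{name} is {{fixed}}"))  # name is {fixed}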
37
api/core/prompt/prompt_template.py
Normal file
@@ -0,0 +1,37 @@
import re
from typing import Any

from langchain import PromptTemplate
from langchain.formatting import StrictFormatter


class OutLinePromptTemplate(PromptTemplate):
    @classmethod
    def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
        """Load a prompt template from a template."""
        input_variables = {
            v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None
        }
        return cls(
            input_variables=list(sorted(input_variables)), template=template, **kwargs
        )


class OneLineFormatter(StrictFormatter):
    def parse(self, format_string):
        last_end = 0
        results = []
        for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string):
            field_name = match.group(1)
            start, end = match.span()

            literal_text = format_string[last_end:start]
            last_end = end

            results.append((literal_text, field_name, '', None))

        remaining_literal_text = format_string[last_end:]
        if remaining_literal_text:
            results.append((remaining_literal_text, None, None, None))

        return results
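# A hedged sketch: parse() only recognises simple {identifier} fields, so a
# template yields (literal, field, spec, conversion) tuples like these:
if __name__ == "__main__":
    print(OneLineFormatter().parse("Hello {name}!"))
    # [('Hello ', 'name', '', None), ('!', None, None, None)]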
63
api/core/prompt/prompts.py
Normal file
@@ -0,0 +1,63 @@
from llama_index import QueryKeywordExtractPrompt

CONVERSATION_TITLE_PROMPT = (
    "Human:{query}\n-----\n"
    "Help me summarize the intent of what the human said and provide a title; the title should not exceed 20 words.\n"
    "If what the human said is in Chinese, you should return a Chinese title.\n"
    "If what the human said is in English, you should return an English title.\n"
    "title:"
)

CONVERSATION_SUMMARY_PROMPT = (
    "Please generate a short summary of the following conversation.\n"
    "If the conversation is in Chinese, you should return a Chinese summary.\n"
    "If the conversation is in English, you should return an English summary.\n"
    "[Conversation Start]\n"
    "{context}\n"
    "[Conversation End]\n\n"
    "summary:"
)

INTRODUCTION_GENERATE_PROMPT = (
    "I am designing a product for users to interact with an AI through dialogue. "
    "The Prompt given to the AI before the conversation is:\n\n"
    "```\n{prompt}\n```\n\n"
    "Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. "
    "Do not reveal the developer's motivation or deep logic behind the Prompt, "
    "but focus on building a relationship with the user:\n"
)

MORE_LIKE_THIS_GENERATE_PROMPT = (
    "-----\n"
    "{original_completion}\n"
    "-----\n\n"
    "Please use the above content as a sample for generating the result, "
    "and include key information points related to the original sample in the result. "
    "Try to rephrase this information in different ways and predict according to the rules below.\n\n"
    "-----\n"
    "{prompt}\n"
)

SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
    "Please help me predict the three most likely questions that a human would ask, "
    "keeping each question under 20 characters.\n"
    "The output must be in JSON format following the specified schema:\n"
    "[\"question1\",\"question2\",\"question3\"]\n"
)

QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
    "A question is provided below. Given the question, extract up to {max_keywords} "
    "keywords from the text. Focus on extracting the keywords that we can use "
    "to best look up answers to the question. Avoid stopwords.\n"
    "I am not sure which language the following question is in. "
    "If the user asked the question in Chinese, please return the keywords in Chinese. "
    "If the user asked the question in English, please return the keywords in English.\n"
    "---------------------\n"
    "{question}\n"
    "---------------------\n"
    "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
)

QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
    QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
)
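# A hedged sketch: apart from QUERY_KEYWORD_EXTRACT_TEMPLATE, the constants
# above are plain strings with str.format-style placeholders.
if __name__ == "__main__":
    print(CONVERSATION_TITLE_PROMPT.format(query="How do I reset my password?"))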
83
api/core/tool/dataset_tool_builder.py
Normal file
@@ -0,0 +1,83 @@
from typing import Optional

from langchain.callbacks import CallbackManager
from llama_index.langchain_helpers.agents import IndexToolConfig

from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from extensions.ext_database import db
from models.dataset import Dataset


class DatasetToolBuilder:
    @classmethod
    def build_dataset_tool(cls, tenant_id: str, dataset_id: str,
                           response_mode: str = "no_synthesizer",
                           callback_handler: Optional[DatasetToolCallbackHandler] = None):
        # get dataset from dataset id
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == tenant_id,
            Dataset.id == dataset_id
        ).first()

        if not dataset:
            return None

        if dataset.indexing_technique == "economy":
            # use keyword table query
            index = KeywordTableIndex(dataset=dataset).query_index

            if not index:
                return None

            query_kwargs = {
                "mode": "default",
                "response_mode": response_mode,
                "query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE,
                "max_keywords_per_query": 5,
                # If num_chunks_per_query is too large,
                # it will slow down the synthesis process due to multiple iterations of refinement.
                "num_chunks_per_query": 2
            }
        else:
            index = VectorIndex(dataset=dataset).query_index

            if not index:
                return None

            query_kwargs = {
                "mode": "default",
                "response_mode": response_mode,
                # If top_k is too large,
                # it will slow down the synthesis process due to multiple iterations of refinement.
                "similarity_top_k": 2
            }

        # fulfill description when it is empty
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        index_tool_config = IndexToolConfig(
            index=index,
            name=f"dataset-{dataset_id}",
            description=description,
            index_query_kwargs=query_kwargs,
            tool_kwargs={
                "callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()])
            },
            # tool_kwargs={"return_direct": True},
            # return_direct: Whether to return LLM results directly or process the output data with an Output Parser
        )

        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset_id)

        return EnhanceLlamaIndexTool.from_tool_config(
            tool_config=index_tool_config,
            callback_handler=index_callback_handler
        )
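# A hedged usage sketch (ids are placeholders and a Flask app context with a
# configured database is assumed, so it stays commented out):
#
#   tool = DatasetToolBuilder.build_dataset_tool(
#       tenant_id="<tenant-id>",
#       dataset_id="<dataset-id>",
#       response_mode="no_synthesizer",
#       callback_handler=callback_handler,  # a pre-built DatasetToolCallbackHandler
#   )
#   if tool:
#       print(tool.name)  # dataset-<dataset-id>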
43
api/core/tool/llama_index_tool.py
Normal file
@@ -0,0 +1,43 @@
from typing import Dict

from langchain.tools import BaseTool
from llama_index.indices.base import BaseGPTIndex
from llama_index.langchain_helpers.agents import IndexToolConfig
from pydantic import Field

from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler


class EnhanceLlamaIndexTool(BaseTool):
    """Tool for querying a LlamaIndex."""

    # NOTE: name/description still needs to be set
    index: BaseGPTIndex
    query_kwargs: Dict = Field(default_factory=dict)
    return_sources: bool = False
    callback_handler: IndexToolCallbackHandler

    @classmethod
    def from_tool_config(cls, tool_config: IndexToolConfig,
                         callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
        """Create a tool from a tool config."""
        return_sources = tool_config.tool_kwargs.pop("return_sources", False)
        return cls(
            index=tool_config.index,
            callback_handler=callback_handler,
            name=tool_config.name,
            description=tool_config.description,
            return_sources=return_sources,
            query_kwargs=tool_config.index_query_kwargs,
            **tool_config.tool_kwargs,
        )

    def _run(self, tool_input: str) -> str:
        response = self.index.query(tool_input, **self.query_kwargs)
        self.callback_handler.on_tool_end(response)
        return str(response)

    async def _arun(self, tool_input: str) -> str:
        response = await self.index.aquery(tool_input, **self.query_kwargs)
        self.callback_handler.on_tool_end(response)
        return str(response)
34
api/core/vector_store/base.py
Normal file
@@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
from typing import Optional

from llama_index import ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs import Node
from llama_index.vector_stores.types import VectorStore


class BaseVectorStoreClient(ABC):
    @abstractmethod
    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        raise NotImplementedError

    @abstractmethod
    def to_index_config(self, index_id: str) -> dict:
        raise NotImplementedError


class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
    def delete_node(self, node_id: str):
        self._vector_store.delete_node(node_id)

    def exists_by_node_id(self, node_id: str) -> bool:
        return self._vector_store.exists_by_node_id(node_id)


class EnhanceVectorStore(ABC):
    @abstractmethod
    def delete_node(self, node_id: str):
        pass

    @abstractmethod
    def exists_by_node_id(self, node_id: str) -> bool:
        pass
147
api/core/vector_store/qdrant_vector_store_client.py
Normal file
@@ -0,0 +1,147 @@
import os
from typing import cast, List

from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
from qdrant_client.http.models import Payload, Filter

import qdrant_client
from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
from llama_index.vector_stores import QdrantVectorStore
from qdrant_client.local.qdrant_local import QdrantLocal

from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore


class QdrantVectorStoreClient(BaseVectorStoreClient):

    def __init__(self, url: str, api_key: str, root_path: str):
        self._client = self.init_from_config(url, api_key, root_path)

    @classmethod
    def init_from_config(cls, url: str, api_key: str, root_path: str):
        if url and url.startswith('path:'):
            path = url.replace('path:', '')
            if not os.path.isabs(path):
                path = os.path.join(root_path, path)

            return qdrant_client.QdrantClient(
                path=path
            )
        else:
            return qdrant_client.QdrantClient(
                url=url,
                api_key=api_key,
            )

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        index_struct = QdrantIndexDict()

        if self._client is None:
            raise Exception("Vector client is not initialized.")

        # {"collection_name": "Gpt_index_xxx"}
        collection_name = config.get('collection_name')
        if not collection_name:
            raise Exception("collection_name cannot be None.")

        return GPTQdrantEnhanceIndex(
            service_context=service_context,
            index_struct=index_struct,
            vector_store=QdrantEnhanceVectorStore(
                client=self._client,
                collection_name=collection_name
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        return {"collection_name": index_id}


class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
    pass


class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
    def delete_node(self, node_id: str):
        """
        Delete node from the index.

        :param node_id: node id
        """
        from qdrant_client.http import models as rest

        self._reload_if_needed()

        self._client.delete(
            collection_name=self._collection_name,
            points_selector=rest.Filter(
                must=[
                    rest.FieldCondition(
                        key="id", match=rest.MatchValue(value=node_id)
                    )
                ]
            ),
        )

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Get node from the index by node id.

        :param node_id: node id
        """
        self._reload_if_needed()

        response = self._client.retrieve(
            collection_name=self._collection_name,
            ids=[node_id]
        )

        return len(response) > 0

    def query(
        self,
        query: VectorStoreQuery,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): query
        """
        query_embedding = cast(List[float], query.query_embedding)

        self._reload_if_needed()

        response = self._client.search(
            collection_name=self._collection_name,
            query_vector=query_embedding,
            limit=cast(int, query.similarity_top_k),
            query_filter=cast(Filter, self._build_query_filter(query)),
            with_vectors=True
        )

        nodes = []
        similarities = []
        ids = []
        for point in response:
            payload = cast(Payload, point.payload)
            node = Node(
                doc_id=str(point.id),
                text=payload.get("text"),
                embedding=point.vector,
                extra_info=payload.get("extra_info"),
                relationships={
                    DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
                },
            )
            nodes.append(node)
            similarities.append(point.score)
            ids.append(str(point.id))

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

    def _reload_if_needed(self):
        if isinstance(self._client._client, QdrantLocal):
            self._client._client._load()
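# A hedged sketch of the two connection modes handled by init_from_config
# (values are placeholders): a "path:" URL selects Qdrant's embedded local
# mode, resolved against root_path when relative; anything else is treated
# as a server URL.
#
#   local = QdrantVectorStoreClient(url="path:storage/qdrant", api_key="", root_path="/app/api")
#   remote = QdrantVectorStoreClient(url="https://qdrant.example.com", api_key="<key>", root_path="/app/api")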
61
api/core/vector_store/vector_store.py
Normal file
@@ -0,0 +1,61 @@
from flask import Flask
from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt

from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient

SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']


class VectorStore:

    def __init__(self):
        self._vector_store = None
        self._client = None

    def init_app(self, app: Flask):
        if not app.config['VECTOR_STORE']:
            return

        self._vector_store = app.config['VECTOR_STORE']
        if self._vector_store not in SUPPORTED_VECTOR_STORES:
            raise ValueError(f"Vector store {self._vector_store} is not supported.")

        if self._vector_store == 'weaviate':
            self._client = WeaviateVectorStoreClient(
                endpoint=app.config['WEAVIATE_ENDPOINT'],
                api_key=app.config['WEAVIATE_API_KEY'],
                grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED']
            )
        elif self._vector_store == 'qdrant':
            self._client = QdrantVectorStoreClient(
                url=app.config['QDRANT_URL'],
                api_key=app.config['QDRANT_API_KEY'],
                root_path=app.root_path
            )

        app.extensions['vector_store'] = self

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
        vector_store_config: dict = index_struct.get('vector_store')
        index = self.get_client().get_index(
            service_context=service_context,
            config=vector_store_config
        )

        return index

    def to_index_struct(self, index_id: str) -> dict:
        return {
            "type": self._vector_store,
            "vector_store": self.get_client().to_index_config(index_id)
        }

    def get_client(self):
        if not self._client:
            raise Exception("Vector store client is not initialized.")

        return self._client
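# A hedged wiring sketch with assumed config values (requires the matching
# client package, so it stays commented out):
#
#   app = Flask(__name__)
#   app.config.update(VECTOR_STORE='qdrant', QDRANT_URL='path:storage/qdrant', QDRANT_API_KEY='')
#   vector_store = VectorStore()
#   vector_store.init_app(app)
#   print(vector_store.to_index_struct("Gpt_index_xxx"))
#   # {'type': 'qdrant', 'vector_store': {'collection_name': 'Gpt_index_xxx'}}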
66
api/core/vector_store/vector_store_index_query.py
Normal file
@@ -0,0 +1,66 @@
from llama_index.indices.query.base import IS
from typing import (
    Any,
    Dict,
    List,
    Optional
)

from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
    BaseNodePostprocessor,
)
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)

from core.index.query.synthesizer import EnhanceResponseSynthesizer


class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
    @classmethod
    def from_args(
        cls,
        index_struct: IS,
        service_context: ServiceContext,
        docstore: Optional[BaseDocumentStore] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        verbose: bool = False,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.DEFAULT,
        text_qa_template: Optional[QuestionAnswerPrompt] = None,
        refine_template: Optional[RefinePrompt] = None,
        simple_template: Optional[SimpleInputPrompt] = None,
        response_kwargs: Optional[Dict] = None,
        use_async: bool = False,
        streaming: bool = False,
        optimizer: Optional[BaseTokenUsageOptimizer] = None,
        # class-specific args
        **kwargs: Any,
    ) -> "EnhanceGPTVectorStoreIndexQuery":
        response_synthesizer = EnhanceResponseSynthesizer.from_args(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            simple_template=simple_template,
            response_mode=response_mode,
            response_kwargs=response_kwargs,
            use_async=use_async,
            streaming=streaming,
            optimizer=optimizer,
        )
        return cls(
            index_struct=index_struct,
            service_context=service_context,
            response_synthesizer=response_synthesizer,
            docstore=docstore,
            node_postprocessors=node_postprocessors,
            verbose=verbose,
            **kwargs,
        )
258
api/core/vector_store/weaviate_vector_store_client.py
Normal file
@@ -0,0 +1,258 @@
import json
import weaviate
from typing import List, Any, Dict, Optional

from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
from llama_index.vector_stores import WeaviateVectorStore
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
from llama_index.readers.weaviate.utils import (
    parse_get_response,
    validate_client,
)


class WeaviateVectorStoreClient(BaseVectorStoreClient):

    def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool):
        self._client = self.init_from_config(endpoint, api_key, grpc_enabled)

    def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool):
        auth_config = weaviate.auth.AuthApiKey(api_key=api_key)

        weaviate.connect.connection.has_grpc = grpc_enabled

        return weaviate.Client(
            url=endpoint,
            auth_client_secret=auth_config,
            timeout_config=(5, 15),
            startup_period=None
        )

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        index_struct = WeaviateIndexDict()

        if self._client is None:
            raise Exception("Vector client is not initialized.")

        # {"class_prefix": "Gpt_index_xxx"}
        class_prefix = config.get('class_prefix')
        if not class_prefix:
            raise Exception("class_prefix cannot be None.")

        return GPTWeaviateEnhanceIndex(
            service_context=service_context,
            index_struct=index_struct,
            vector_store=WeaviateWithSimilaritiesVectorStore(
                weaviate_client=self._client,
                class_prefix=class_prefix
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        return {"class_prefix": index_id}


class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
    def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes."""
        nodes = self.weaviate_query(
            self._client,
            self._class_prefix,
            query,
        )
        nodes = nodes[: query.similarity_top_k]
        node_idxs = [str(i) for i in range(len(nodes))]

        similarities = []
        for node in nodes:
            similarities.append(node.extra_info['similarity'])
            del node.extra_info['similarity']

        return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)

    def weaviate_query(
        self,
        client: Any,
        class_prefix: str,
        query_spec: VectorStoreQuery,
    ) -> List[Node]:
        """Convert to LlamaIndex list."""
        validate_client(client)

        class_name = _class_name(class_prefix)
        prop_names = [p["name"] for p in NODE_SCHEMA]
        vector = query_spec.query_embedding

        # build query
        query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
        if query_spec.mode == VectorStoreQueryMode.DEFAULT:
            _logger.debug("Using vector search")
            if vector is not None:
                query = query.with_near_vector(
                    {
                        "vector": vector,
                    }
                )
        elif query_spec.mode == VectorStoreQueryMode.HYBRID:
            _logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
            query = query.with_hybrid(
                query=query_spec.query_str,
                alpha=query_spec.alpha,
                vector=vector,
            )
        query = query.with_limit(query_spec.similarity_top_k)
        _logger.debug(f"Using limit of {query_spec.similarity_top_k}")

        # execute query
        query_result = query.do()

        # parse results
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[class_name]
        results = [self._to_node(entry) for entry in entries]
        return results

    def _to_node(self, entry: Dict) -> Node:
        """Convert to Node."""
        extra_info_str = entry["extra_info"]
        if extra_info_str == "":
            extra_info = None
        else:
            extra_info = json.loads(extra_info_str)

        if 'certainty' in entry['_additional']:
            if extra_info:
                extra_info['similarity'] = entry['_additional']['certainty']
            else:
                extra_info = {'similarity': entry['_additional']['certainty']}

        node_info_str = entry["node_info"]
        if node_info_str == "":
            node_info = None
        else:
            node_info = json.loads(node_info_str)

        relationships_str = entry["relationships"]
        relationships: Dict[DocumentRelationship, str]
        if relationships_str == "":
            relationships = {}
        else:
            relationships = {
                DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
            }

        return Node(
            text=entry["text"],
            doc_id=entry["doc_id"],
            embedding=entry["_additional"]["vector"],
            extra_info=extra_info,
            node_info=node_info,
            relationships=relationships,
        )

    def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document.

        Args:
            doc_id (str): document id

        """
        delete_document(self._client, doc_id, self._class_prefix)

    def delete_node(self, node_id: str):
        """
        Delete node from the index.

        :param node_id: node id
        """
        delete_node(self._client, node_id, self._class_prefix)

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Get node from the index by node id.

        :param node_id: node id
        """
        entry = get_by_node_id(self._client, node_id, self._class_prefix)
        return True if entry else False


class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
    pass


def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
    """Delete all entries belonging to a document."""
    validate_client(client)
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["ref_doc_id"],
        "operator": "Equal",
        "valueString": ref_doc_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    for entry in entries:
        client.data_object.delete(entry["_additional"]["id"], class_name)

    # keep querying and deleting until no matching entries remain
    while len(entries) > 0:
        query_result = query.do()
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[class_name]
        for entry in entries:
            client.data_object.delete(entry["_additional"]["id"], class_name)


def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
    """Delete the entry for a single node."""
    validate_client(client)
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    for entry in entries:
        client.data_object.delete(entry["_additional"]["id"], class_name)


def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
    """Fetch the entry for a node id, or None if it does not exist."""
    validate_client(client)
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    if len(entries) == 0:
        return None

    return entries[0]