Initial commit

John Wang
2023-05-15 08:51:32 +08:00
commit db896255d6
744 changed files with 56028 additions and 0 deletions

52
api/core/__init__.py Normal file
View File

@@ -0,0 +1,52 @@
import os
from typing import Optional
import langchain
from flask import Flask
from jieba.analyse import default_tfidf
from langchain import set_handler
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
from llama_index import IndexStructType, QueryMode
from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
from pydantic import BaseModel
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.index.keyword_table.stopwords import STOPWORDS
from core.prompt.prompt_template import OneLineFormatter
from core.vector_store.vector_store import VectorStore
from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery
class HostedOpenAICredential(BaseModel):
api_key: str
class HostedLLMCredentials(BaseModel):
openai: Optional[HostedOpenAICredential] = None
hosted_llm_credentials = HostedLLMCredentials()
def init_app(app: Flask):
formatter = OneLineFormatter()
DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
}
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
}
default_tfidf.stop_words = STOPWORDS
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True
set_handler(DifyStdOutCallbackHandler())
if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
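
A minimal sketch of how init_app above could be wired into a Flask application factory; the factory function and the placeholder config value below are illustrative and not taken from this diff:

from flask import Flask

import core


def create_app() -> Flask:
    app = Flask(__name__)
    app.config["OPENAI_API_KEY"] = "sk-placeholder"  # hypothetical config value
    core.init_app(app)  # registers the prompt formatter, query maps and hosted credentials
    return app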

View File

@@ -0,0 +1,89 @@
from typing import Optional
from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
from langchain.callbacks import CallbackManager
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.llm.llm_builder import LLMBuilder
class AgentBuilder:
@classmethod
def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
dataset_tool_callback_handler: DatasetToolCallbackHandler,
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=agent_loop_gather_callback_handler.model_name,
temperature=0,
max_tokens=1024,
callback_manager=llm_callback_manager
)
tool_callback_manager = CallbackManager([
agent_loop_gather_callback_handler,
dataset_tool_callback_handler,
DifyStdOutCallbackHandler()
])
for tool in tools:
tool.callback_manager = tool_callback_manager
prompt = cls.build_agent_prompt_template(
tools=tools,
memory=memory,
)
agent_llm_chain = LLMChain(
llm=llm,
prompt=prompt,
)
agent = cls.build_agent(agent_llm_chain=agent_llm_chain, memory=memory)
agent_callback_manager = CallbackManager(
[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
)
agent_chain = AgentExecutor.from_agent_and_tools(
tools=tools,
agent=agent,
memory=memory,
callback_manager=agent_callback_manager,
max_iterations=6,
early_stopping_method="generate",
            # with `generate`, the agent completes a final answer from its last inference once the iteration limit or request time limit is reached
)
return agent_chain
@classmethod
def build_agent_prompt_template(cls, tools, memory: Optional[BaseChatMemory]):
if memory:
prompt = ConversationalAgent.create_prompt(
tools=tools,
)
else:
prompt = ZeroShotAgent.create_prompt(
tools=tools,
)
return prompt
@classmethod
def build_agent(cls, agent_llm_chain: LLMChain, memory: Optional[BaseChatMemory]):
if memory:
agent = ConversationalAgent(
llm_chain=agent_llm_chain
)
else:
agent = ZeroShotAgent(
llm_chain=agent_llm_chain
)
return agent
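
The call shape below mirrors how MainChainBuilder.get_agent_chains (later in this commit) invokes to_agent_chain; the variable names are the ones defined there and are shown only for orientation:

agent_chain = AgentBuilder.to_agent_chain(
    tenant_id=tenant_id,
    tools=agent_tools,                # dataset tools built by DatasetToolBuilder
    memory=memory,                    # BaseChatMemory -> ConversationalAgent, None -> ZeroShotAgent
    dataset_tool_callback_handler=dataset_tool_callback_handler,
    agent_loop_gather_callback_handler=agent_loop_gather_callback_handler,
)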

View File

@@ -0,0 +1,178 @@
import logging
import time
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self.model_name = model_name
self.conversation_message_task = conversation_message_task
self._agent_loops = []
self._current_loop = None
self.current_chain = None
@property
def agent_loops(self) -> List[AgentLoop]:
return self._agent_loops
def clear_agent_loops(self) -> None:
self._agent_loops = []
self._current_loop = None
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return True
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
# serialized={'name': 'OpenAI'}
# prompts=['Answer the following questions...\nThought:']
# kwargs={}
if not self._current_loop:
            # the agent run starts with an LLM call
self._current_loop = AgentLoop(
position=len(self._agent_loops) + 1,
prompt=prompts[0],
status='llm_started',
started_at=time.perf_counter()
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing."""
# kwargs={}
if self._current_loop and self._current_loop.status == 'llm_started':
self._current_loop.status = 'llm_end'
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
self._current_loop.completion = response.generations[0][0].text
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
self._agent_loops = []
self._current_loop = None
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing."""
# kwargs={'color': 'green', 'llm_prefix': 'Thought:', 'observation_prefix': 'Observation: '}
# input_str='action-input'
# serialized={'description': 'A search engine. Useful for when you need to answer questions about current events. Input should be a search query.', 'name': 'Search'}
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
tool = action.tool
tool_input = action.tool_input
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
if self._current_loop and self._current_loop.status == 'llm_end':
self._current_loop.status = 'agent_action'
self._current_loop.thought = thought
self._current_loop.tool_name = tool
self._current_loop.tool_input = tool_input
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
# kwargs={'name': 'Search'}
# llm_prefix='Thought:'
# observation_prefix='Observation: '
# output='53 years'
if self._current_loop and self._current_loop.status == 'agent_action' and output and output != 'None':
self._current_loop.status = 'tool_end'
self._current_loop.tool_output = output
self._current_loop.completed = True
self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
self._agent_loops.append(self._current_loop)
self._current_loop = None
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
logging.error(error)
self._agent_loops = []
self._current_loop = None
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run on additional input from chains and agents."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
# Final Answer
if self._current_loop and (self._current_loop.status == 'llm_end' or self._current_loop.status == 'agent_action'):
self._current_loop.status = 'agent_finish'
self._current_loop.completed = True
self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
self._agent_loops.append(self._current_loop)
self._current_loop = None
elif not self._current_loop and self._agent_loops:
self._agent_loops[-1].status = 'agent_finish'
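
For orientation, the AgentLoop status transitions driven by the callbacks above can be summarized as a small mapping (illustrative only, not part of the diff):

# callback name -> (current status, next status)
LOOP_STATUS_TRANSITIONS = {
    "on_llm_start": (None, "llm_started"),
    "on_llm_end": ("llm_started", "llm_end"),
    "on_agent_action": ("llm_end", "agent_action"),
    "on_tool_end": ("agent_action", "tool_end"),      # loop completed and appended
    "on_agent_finish": ("llm_end", "agent_finish"),    # also reached from "agent_action"
}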

View File

@@ -0,0 +1,117 @@
import logging
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.conversation_message_task import ConversationMessageTask
class DatasetToolCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self.queries = []
self.conversation_message_task = conversation_message_task
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return True
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return True
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return False
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
tool_name = serialized.get('name')
dataset_id = tool_name[len("dataset-"):]
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str))
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
# kwargs={'name': 'Search'}
# llm_prefix='Thought:'
# observation_prefix='Observation: '
# output='53 years'
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
logging.error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
pass
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run on additional input from chains and agents."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
pass
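
on_tool_start above assumes dataset tools are named with a "dataset-" prefix followed by the dataset id; a small illustrative example (the id below is made up):

tool_name = "dataset-7d9a1c2e-1111-2222-3333-444444444444"   # hypothetical serialized['name']
dataset_id = tool_name[len("dataset-"):]                     # "7d9a1c2e-1111-2222-3333-444444444444"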

View File

@@ -0,0 +1,23 @@
from pydantic import BaseModel
class AgentLoop(BaseModel):
position: int = 1
thought: str = None
tool_name: str = None
tool_input: str = None
tool_output: str = None
prompt: str = None
prompt_tokens: int = None
completion: str = None
completion_tokens: int = None
latency: float = None
status: str = 'llm_started'
completed: bool = False
started_at: float = None
completed_at: float = None

View File

@@ -0,0 +1,16 @@
from pydantic import BaseModel
class ChainResult(BaseModel):
type: str = None
prompt: dict = None
completion: dict = None
status: str = 'chain_started'
completed: bool = False
started_at: float = None
completed_at: float = None
agent_result: dict = None
"""only when type is 'AgentExecutor'"""

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class DatasetQueryObj(BaseModel):
dataset_id: str = None
query: str = None

View File

@@ -0,0 +1,9 @@
from pydantic import BaseModel
class LLMMessage(BaseModel):
prompt: str = ''
prompt_tokens: int = 0
completion: str = ''
completion_tokens: int = 0
latency: float = 0.0

View File

@@ -0,0 +1,38 @@
from llama_index import Response
from extensions.ext_database import db
from models.dataset import DocumentSegment
class IndexToolCallbackHandler:
def __init__(self) -> None:
self._response = None
@property
def response(self) -> Response:
return self._response
def on_tool_end(self, response: Response) -> None:
"""Handle tool end."""
self._response = response
class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
"""Callback handler for dataset tool."""
def __init__(self, dataset_id: str) -> None:
super().__init__()
self.dataset_id = dataset_id
def on_tool_end(self, response: Response) -> None:
"""Handle tool end."""
for node in response.source_nodes:
index_node_id = node.node.doc_id
# add hit count to document segment
db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.index_node_id == index_node_id
).update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)

View File

@@ -0,0 +1,147 @@
import logging
import time
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
class LLMCallbackHandler(BaseCallbackHandler):
def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
conversation_message_task: ConversationMessageTask):
self.llm = llm
self.llm_message = LLMMessage()
self.start_at = None
self.conversation_message_task = conversation_message_task
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
self.start_at = time.perf_counter()
if 'Chat' in serialized['name']:
real_prompts = []
messages = []
for prompt in prompts:
role, content = prompt.split(': ', maxsplit=1)
if role == 'human':
role = 'user'
message = HumanMessage(content=content)
elif role == 'ai':
role = 'assistant'
message = AIMessage(content=content)
else:
message = SystemMessage(content=content)
real_prompt = {
"role": role,
"text": content
}
real_prompts.append(real_prompt)
messages.append(message)
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
else:
self.llm_message.prompt = [{
"role": 'user',
"text": prompts[0]
}]
self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
end_at = time.perf_counter()
self.llm_message.latency = end_at - self.start_at
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(response.generations[0][0].text)
self.llm_message.completion = response.generations[0][0].text
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.conversation_message_task.save_message(self.llm_message)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self.conversation_message_task.append_message_text(token)
self.llm_message.completion += token
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
if isinstance(error, ConversationTaskStoppedException):
if self.conversation_message_task.streaming:
end_at = time.perf_counter()
self.llm_message.latency = end_at - self.start_at
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
else:
logging.error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
pass
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
pass
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
pass
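
The chat branch of on_llm_start above expects each prompt string to carry the role and content separated by ': '; a minimal illustrative parse:

prompt = "human: What can Dify do?"                 # hypothetical prompt string from a chat model
role, content = prompt.split(': ', maxsplit=1)      # -> ('human', 'What can Dify do?')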

View File

@@ -0,0 +1,137 @@
import logging
import time
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.entity.chain_result import ChainResult
from core.constant import llm_constant
from core.conversation_message_task import ConversationMessageTask
class MainChainGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self._current_chain_result = None
self._current_chain_message = None
self.conversation_message_task = conversation_message_task
self.agent_loop_gather_callback_handler = AgentLoopGatherCallbackHandler(
llm_constant.agent_model_name,
conversation_message_task
)
def clear_chain_results(self) -> None:
self._current_chain_result = None
self._current_chain_message = None
self.agent_loop_gather_callback_handler.current_chain = None
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return True
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return True
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
if not self._current_chain_result:
self._current_chain_result = ChainResult(
type=serialized['name'],
prompt=inputs,
started_at=time.perf_counter()
)
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
if self._current_chain_result and self._current_chain_result.status == 'chain_started':
self._current_chain_result.status = 'chain_ended'
self._current_chain_result.completion = outputs
self._current_chain_result.completed = True
self._current_chain_result.completed_at = time.perf_counter()
self.conversation_message_task.on_chain_end(self._current_chain_message, self._current_chain_result)
self.clear_chain_results()
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
self.clear_chain_results()
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
pass
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
logging.error(error)
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run on additional input from chains and agents."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
pass

View File

@@ -0,0 +1,127 @@
import sys
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult
class DifyStdOutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def __init__(self, color: Optional[str] = None) -> None:
"""Initialize callback handler."""
self.color = color
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
print_text("\n[on_llm_start]\n", color='blue')
if 'Chat' in serialized['name']:
for prompt in prompts:
print_text(prompt + "\n", color='blue')
else:
print_text(prompts[0] + "\n", color='blue')
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing."""
print_text("\n[on_llm_end]\nOutput: " + str(response.generations[0][0].text) + "\nllm_output: " + str(
response.llm_output) + "\n", color='blue')
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized["name"]
print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
print_text("\n[on_chain_error]\nError: " + str(error) + "\n", color='pink')
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing."""
print_text("\n[on_tool_start] " + str(serialized), color='yellow')
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
tool = action.tool
tool_input = action.tool_input
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
print_text("\n[on_agent_action]\n" + log + "\n", color='green')
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
print_text("\n[on_tool_end]\n", color='yellow')
if observation_prefix:
print_text(f"\n{observation_prefix}")
print_text(output, color='yellow')
if llm_prefix:
print_text(f"\n{llm_prefix}")
print_text("\n")
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='yellow')
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run when agent ends."""
print_text("\n[on_text] " + text + "\n", color=color if color else self.color, end=end)
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")
class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
"""Callback handler for streaming. Only works with LLMs that support streaming."""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
sys.stdout.write(token)
sys.stdout.flush()

View File

@@ -0,0 +1,34 @@
from typing import Optional
from langchain.callbacks import CallbackManager
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain
class ChainBuilder:
@classmethod
def to_tool_chain(cls, tool, **kwargs) -> ToolChain:
return ToolChain(
tool=tool,
input_key=kwargs.get('input_key', 'input'),
output_key=kwargs.get('output_key', 'tool_output'),
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
)
@classmethod
def to_sensitive_word_avoidance_chain(cls, tool_config: dict, **kwargs) -> Optional[
SensitiveWordAvoidanceChain]:
sensitive_words = tool_config.get("words", "")
if tool_config.get("enabled", False) \
and sensitive_words:
return SensitiveWordAvoidanceChain(
sensitive_words=sensitive_words.split(","),
canned_response=tool_config.get("canned_response", ''),
output_key="sensitive_word_avoidance_output",
callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
**kwargs
)
return None
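
An illustrative tool_config for to_sensitive_word_avoidance_chain above; the key names follow the lookups in this file, while the values are placeholders:

tool_config = {
    "enabled": True,
    "words": "foo,bar",                       # comma-separated sensitive words
    "canned_response": "I can't answer that.",
}
chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)   # None if disabled or no words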

View File

@@ -0,0 +1,116 @@
from typing import Optional, List
from langchain.callbacks import SharedCallbackManager
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from core.agent.agent_builder import AgentBuilder
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.chain.chain_builder import ChainBuilder
from core.constant import llm_constant
from core.conversation_message_task import ConversationMessageTask
from core.tool.dataset_tool_builder import DatasetToolBuilder
class MainChainBuilder:
@classmethod
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
conversation_message_task: ConversationMessageTask):
first_input_key = "input"
final_output_key = "output"
chains = []
chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)
# agent mode
tool_chains, chains_output_key = cls.get_agent_chains(
tenant_id=tenant_id,
agent_mode=agent_mode,
memory=memory,
dataset_tool_callback_handler=DatasetToolCallbackHandler(conversation_message_task),
agent_loop_gather_callback_handler=chain_callback_handler.agent_loop_gather_callback_handler
)
chains += tool_chains
if chains_output_key:
final_output_key = chains_output_key
if len(chains) == 0:
return None
for chain in chains:
# do not add handler into singleton callback manager
if not isinstance(chain.callback_manager, SharedCallbackManager):
chain.callback_manager.add_handler(chain_callback_handler)
# build main chain
overall_chain = SequentialChain(
chains=chains,
input_variables=[first_input_key],
output_variables=[final_output_key],
memory=memory, # only for use the memory prompt input key
)
return overall_chain
@classmethod
def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
dataset_tool_callback_handler: DatasetToolCallbackHandler,
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
# agent mode
chains = []
if agent_mode and agent_mode.get('enabled'):
tools = agent_mode.get('tools', [])
pre_fixed_chains = []
agent_tools = []
for tool in tools:
tool_type = list(tool.keys())[0]
tool_config = list(tool.values())[0]
if tool_type == 'sensitive-word-avoidance':
chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
if chain:
pre_fixed_chains.append(chain)
elif tool_type == "dataset":
dataset_tool = DatasetToolBuilder.build_dataset_tool(
tenant_id=tenant_id,
dataset_id=tool_config.get("id"),
response_mode='no_synthesizer', # "compact"
callback_handler=dataset_tool_callback_handler
)
if dataset_tool:
agent_tools.append(dataset_tool)
# add pre-fixed chains
chains += pre_fixed_chains
if len(agent_tools) == 1:
# tool to chain
tool_chain = ChainBuilder.to_tool_chain(tool=agent_tools[0], output_key='tool_output')
chains.append(tool_chain)
elif len(agent_tools) > 1:
# build agent config
agent_chain = AgentBuilder.to_agent_chain(
tenant_id=tenant_id,
tools=agent_tools,
memory=memory,
dataset_tool_callback_handler=dataset_tool_callback_handler,
agent_loop_gather_callback_handler=agent_loop_gather_callback_handler
)
chains.append(agent_chain)
final_output_key = cls.get_chains_output_key(chains)
return chains, final_output_key
@classmethod
def get_chains_output_key(cls, chains: List[Chain]):
if len(chains) > 0:
return chains[-1].output_keys[0]
return None
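
An illustrative agent_mode dict accepted by to_langchain_components above; the structure follows the tool_type checks in get_agent_chains, and the ids are placeholders:

agent_mode = {
    "enabled": True,
    "tools": [
        {"sensitive-word-avoidance": {"enabled": True, "words": "foo,bar", "canned_response": "Blocked."}},
        {"dataset": {"id": "a-dataset-uuid"}},
    ],
}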

View File

@@ -0,0 +1,42 @@
from typing import List, Dict
from langchain.chains.base import Chain
class SensitiveWordAvoidanceChain(Chain):
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
sensitive_words: List[str] = []
canned_response: str = None
@property
def _chain_type(self) -> str:
return "sensitive_word_avoidance_chain"
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
def _check_sensitive_word(self, text: str) -> str:
for word in self.sensitive_words:
if word in text:
return self.canned_response
return text
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
text = inputs[self.input_key]
output = self._check_sensitive_word(text)
return {self.output_key: output}
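
A minimal usage sketch of the chain above, run standalone outside the main chain assembly:

chain = SensitiveWordAvoidanceChain(sensitive_words=["foo"], canned_response="Blocked.")
print(chain.run("tell me about foo"))   # -> "Blocked."
print(chain.run("hello"))               # -> "hello"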

View File

@@ -0,0 +1,42 @@
from typing import List, Dict
from langchain.chains.base import Chain
from langchain.tools import BaseTool
class ToolChain(Chain):
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
tool: BaseTool
@property
def _chain_type(self) -> str:
return "tool_chain"
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
input = inputs[self.input_key]
output = self.tool.run(input, self.verbose)
return {self.output_key: output}
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
"""Run the logic of this chain and return the output."""
input = inputs[self.input_key]
output = await self.tool.arun(input, self.verbose)
return {self.output_key: output}
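
A minimal usage sketch of ToolChain with a trivial langchain Tool; the echo tool below is illustrative:

from langchain.agents import Tool

echo_tool = Tool(name="echo", func=lambda q: q.upper(), description="Echoes the input in upper case.")
tool_chain = ToolChain(tool=echo_tool, output_key="tool_output")
print(tool_chain.run("hello"))   # -> "HELLO"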

326
api/core/completion.py Normal file
View File

@@ -0,0 +1,326 @@
from typing import Optional, List, Union
from langchain.callbacks import CallbackManager
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
ReadOnlyConversationTokenDBStringBufferSharedMemory
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message
class Completion:
@classmethod
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
user: Account, conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
"""
errors: ProviderTokenNotInitError
"""
cls.validate_query_tokens(app.tenant_id, app_model_config, query)
memory = None
if conversation:
# get memory of conversation (read-only)
memory = cls.get_memory_from_conversation(
tenant_id=app.tenant_id,
app_model_config=app_model_config,
conversation=conversation
)
inputs = conversation.inputs
conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
app_model_config=app_model_config,
user=user,
conversation=conversation,
is_override=is_override,
inputs=inputs,
query=query,
streaming=streaming
)
# build main chain include agent
main_chain = MainChainBuilder.to_langchain_components(
tenant_id=app.tenant_id,
agent_mode=app_model_config.agent_mode_dict,
memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
conversation_message_task=conversation_message_task
)
chain_output = ''
if main_chain:
chain_output = main_chain.run(query)
# run the final llm
try:
cls.run_final_llm(
tenant_id=app.tenant_id,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
chain_output=chain_output,
conversation_message_task=conversation_message_task,
memory=memory,
streaming=streaming
)
except ConversationTaskStoppedException:
return
@classmethod
def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
chain_output: str,
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
final_llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id,
model=app_model_config.model_dict,
streaming=streaming
)
# get llm prompt
prompt = cls.get_main_llm_prompt(
mode=mode,
llm=final_llm,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
chain_output=chain_output,
memory=memory
)
final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens(
final_llm=final_llm,
prompt=prompt,
mode=mode
)
response = final_llm.generate([prompt])
return response
@classmethod
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, chain_output: Optional[str],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
            Union[str, List[BaseMessage]]:
pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
if mode == 'completion':
prompt_template = OutLinePromptTemplate.from_template(
template=("Use the following pieces of [CONTEXT] to answer the question at the end. "
"If you don't know the answer, "
"just say that you don't know, don't try to make up an answer. \n"
"```\n"
"[CONTEXT]\n"
"{context}\n"
"```\n" if chain_output else "")
+ (pre_prompt + "\n" if pre_prompt else "")
+ "{query}\n"
)
if chain_output:
inputs['context'] = chain_output
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
prompt_content = prompt_template.format(
query=query,
**prompt_inputs
)
if isinstance(llm, BaseChatModel):
# use chat llm as completion model
return [HumanMessage(content=prompt_content)]
else:
return prompt_content
else:
messages: List[BaseMessage] = []
system_message = None
if pre_prompt:
# append pre prompt as system message
system_message = PromptBuilder.to_system_message(pre_prompt, inputs)
if chain_output:
# append context as system message, currently only use simple stuff prompt
context_message = PromptBuilder.to_system_message(
"""Use the following pieces of [CONTEXT] to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
```
[CONTEXT]
{context}
```""",
{'context': chain_output}
)
if not system_message:
system_message = context_message
else:
system_message.content = context_message.content + "\n\n" + system_message.content
if system_message:
messages.append(system_message)
human_inputs = {
"query": query
}
# construct main prompt
human_message = PromptBuilder.to_human_message(
prompt_content="{query}",
inputs=human_inputs
)
if memory:
# append chat histories
tmp_messages = messages.copy() + [human_message]
curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
rest_tokens = llm_constant.max_context_token_length[
memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
messages += history_messages
messages.append(human_message)
return messages
@classmethod
def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
if streaming:
callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
else:
callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]
return CallbackManager(callback_handlers)
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
max_token_limit: int) -> \
List[BaseMessage]:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
return external_context[memory_key]
@classmethod
def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
conversation: Conversation,
**kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
# only for calc token in memory
memory_llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id,
model=app_model_config.model_dict
)
# use llm config from conversation
memory = ReadOnlyConversationTokenDBBufferSharedMemory(
conversation=conversation,
llm=memory_llm,
max_token_limit=kwargs.get("max_token_limit", 2048),
memory_key=kwargs.get("memory_key", "chat_history"),
return_messages=kwargs.get("return_messages", True),
input_key=kwargs.get("input_key", "input"),
output_key=kwargs.get("output_key", "output"),
message_limit=kwargs.get("message_limit", 10),
)
return memory
@classmethod
def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id,
model=app_model_config.model_dict
)
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
max_tokens = llm.max_tokens
if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
raise LLMBadRequestError("Query is too long")
@classmethod
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
prompt: Union[str, List[BaseMessage]], mode: str):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
max_tokens = final_llm.max_tokens
if mode == 'completion' and isinstance(final_llm, BaseLLM):
prompt_tokens = final_llm.get_num_tokens(prompt)
else:
prompt_tokens = final_llm.get_messages_tokens(prompt)
if prompt_tokens + max_tokens > model_limited_tokens:
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
final_llm.max_tokens = max_tokens
@classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool):
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=app.tenant_id,
model_name='gpt-3.5-turbo',
streaming=streaming
)
# get llm prompt
original_prompt = cls.get_main_llm_prompt(
mode="completion",
llm=llm,
pre_prompt=pre_prompt,
query=message.query,
inputs=message.inputs,
chain_output=None,
memory=None
)
original_completion = message.answer.strip()
prompt = MORE_LIKE_THIS_GENERATE_PROMPT
prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
app_model_config=app_model_config,
user=user,
inputs=message.inputs,
query=message.query,
is_override=True if message.override_model_configs else False,
streaming=streaming
)
llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens(
final_llm=llm,
prompt=prompt,
mode='completion'
)
llm.generate([prompt])

View File

@@ -0,0 +1,84 @@
from decimal import Decimal
models = {
'gpt-4': 'openai', # 8,192 tokens
'gpt-4-32k': 'openai', # 32,768 tokens
'gpt-3.5-turbo': 'openai', # 4,096 tokens
'text-davinci-003': 'openai', # 4,097 tokens
'text-davinci-002': 'openai', # 4,097 tokens
'text-curie-001': 'openai', # 2,049 tokens
'text-babbage-001': 'openai', # 2,049 tokens
'text-ada-001': 'openai', # 2,049 tokens
'text-embedding-ada-002': 'openai' # 8191 tokens, 1536 dimensions
}
max_context_token_length = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'text-davinci-003': 4097,
'text-davinci-002': 4097,
'text-curie-001': 2049,
'text-babbage-001': 2049,
'text-ada-001': 2049,
'text-embedding-ada-002': 8191
}
models_by_mode = {
'chat': [
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
],
'completion': [
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
'text-davinci-003', # 4,097 tokens
        'text-davinci-002',  # 4,097 tokens
'text-curie-001', # 2,049 tokens
'text-babbage-001', # 2,049 tokens
'text-ada-001' # 2,049 tokens
],
'embedding': [
'text-embedding-ada-002' # 8191 tokens, 1536 dimensions
]
}
model_currency = 'USD'
model_prices = {
'gpt-4': {
'prompt': Decimal('0.03'),
'completion': Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': Decimal('0.06'),
'completion': Decimal('0.12')
},
'gpt-3.5-turbo': {
'prompt': Decimal('0.002'),
'completion': Decimal('0.002')
},
'text-davinci-003': {
'prompt': Decimal('0.02'),
'completion': Decimal('0.02')
},
'text-curie-001': {
'prompt': Decimal('0.002'),
'completion': Decimal('0.002')
},
'text-babbage-001': {
'prompt': Decimal('0.0005'),
'completion': Decimal('0.0005')
},
'text-ada-001': {
'prompt': Decimal('0.0004'),
'completion': Decimal('0.0004')
},
'text-embedding-ada-002': {
'usage': Decimal('0.0004'),
}
}
agent_model_name = 'text-davinci-003'
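
The prices above are USD per 1K tokens (the message accounting later in this commit, ConversationMessageTask.calc_total_price, divides token counts by 1,000 before multiplying). A small illustrative helper for the models that define 'prompt'/'completion' prices:

from decimal import Decimal


def estimate_cost(model_name: str, prompt_tokens: int, completion_tokens: int) -> Decimal:
    """Rough cost estimate based on the model_prices table above."""
    prices = model_prices[model_name]
    return (Decimal(prompt_tokens) / 1000) * prices['prompt'] \
        + (Decimal(completion_tokens) / 1000) * prices['completion']


# estimate_cost('gpt-3.5-turbo', 1200, 300) -> Decimal('0.0030')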

View File

@@ -0,0 +1,388 @@
import decimal
import json
from typing import Optional, Union
from gunicorn.config import User
from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.provider.llm_provider_service import LLMProviderService
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import OutLinePromptTemplate
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
from models.provider import ProviderType, Provider
class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool,
conversation: Optional[Conversation] = None, is_override: bool = False):
self.task_id = task_id
self.app = app
self.tenant_id = app.tenant_id
self.app_model_config = app_model_config
self.is_override = is_override
self.user = user
self.inputs = inputs
self.query = query
self.streaming = streaming
self.conversation = conversation
self.is_new_conversation = False
self.message = None
self.model_dict = self.app_model_config.model_dict
self.model_name = self.model_dict.get('name')
self.mode = app.mode
self.init()
self._pub_handler = PubHandler(
user=self.user,
task_id=self.task_id,
message=self.message,
conversation=self.conversation,
chain_pub=False, # disabled currently
agent_thought_pub=False # disabled currently
)
def init(self):
override_model_configs = None
if self.is_override:
override_model_configs = {
"model": self.app_model_config.model_dict,
"pre_prompt": self.app_model_config.pre_prompt,
"agent_mode": self.app_model_config.agent_mode_dict,
"opening_statement": self.app_model_config.opening_statement,
"suggested_questions": self.app_model_config.suggested_questions_list,
"suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
"more_like_this": self.app_model_config.more_like_this_dict,
"user_input_form": self.app_model_config.user_input_form_list,
}
introduction = ''
system_instruction = ''
system_instruction_tokens = 0
if self.mode == 'chat':
introduction = self.app_model_config.opening_statement
if introduction:
prompt_template = OutLinePromptTemplate.from_template(template=PromptBuilder.process_template(introduction))
prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
introduction = prompt_template.format(**prompt_inputs)
if self.app_model_config.pre_prompt:
pre_prompt = PromptBuilder.process_template(self.app_model_config.pre_prompt)
system_message = PromptBuilder.to_system_message(pre_prompt, self.inputs)
system_instruction = system_message.content
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
system_instruction_tokens = llm.get_messages_tokens([system_message])
if not self.conversation:
self.is_new_conversation = True
self.conversation = Conversation(
app_id=self.app_model_config.app_id,
app_model_config_id=self.app_model_config.id,
model_provider=self.model_dict.get('provider'),
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=self.mode,
name='',
inputs=self.inputs,
introduction=introduction,
system_instruction=system_instruction,
system_instruction_tokens=system_instruction_tokens,
status='normal',
from_source=('console' if isinstance(self.user, Account) else 'api'),
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
from_account_id=(self.user.id if isinstance(self.user, Account) else None),
)
db.session.add(self.conversation)
db.session.flush()
self.message = Message(
app_id=self.app_model_config.app_id,
model_provider=self.model_dict.get('provider'),
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=self.conversation.id,
inputs=self.inputs,
query=self.query,
message="",
message_tokens=0,
message_unit_price=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
provider_response_latency=0,
total_price=0,
currency=llm_constant.model_currency,
from_source=('console' if isinstance(self.user, Account) else 'api'),
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
from_account_id=(self.user.id if isinstance(self.user, Account) else None),
agent_based=self.app_model_config.agent_mode_dict.get('enabled'),
)
db.session.add(self.message)
db.session.flush()
def append_message_text(self, text: str):
self._pub_handler.pub_text(text)
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
model_name = self.app_model_config.model_dict.get('name')
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = llm_constant.model_prices[model_name]['prompt']
answer_unit_price = llm_constant.model_prices[model_name]['completion']
total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
self.message.message = llm_message.prompt
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.answer = llm_message.completion.strip() if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
self.message.provider_response_latency = llm_message.latency
self.message.total_price = total_price
self.update_provider_quota()
db.session.commit()
message_was_created.send(
self.message,
conversation=self.conversation,
is_first_message=self.is_new_conversation
)
if not by_stopped:
self._pub_handler.pub_end()
def update_provider_quota(self):
llm_provider_service = LLMProviderService(
tenant_id=self.app.tenant_id,
provider_name=self.message.model_provider,
)
provider = llm_provider_service.get_provider_db_record()
if provider and provider.provider_type == ProviderType.SYSTEM.value:
db.session.query(Provider).filter(
Provider.tenant_id == self.app.tenant_id,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + 1})
def init_chain(self, chain_result: ChainResult):
message_chain = MessageChain(
message_id=self.message.id,
type=chain_result.type,
input=json.dumps(chain_result.prompt),
output=''
)
db.session.add(message_chain)
db.session.flush()
return message_chain
def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
message_chain.output = json.dumps(chain_result.completion)
self._pub_handler.pub_chain(message_chain)
def on_agent_end(self, message_chain: MessageChain, agent_model_name: str,
agent_loop: AgentLoop):
agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_total_price = self.calc_total_price(
loop_message_tokens,
agent_message_unit_price,
loop_answer_tokens,
agent_answer_unit_price
)
message_agent_loop = MessageAgentThought(
message_id=self.message.id,
message_chain_id=message_chain.id,
position=agent_loop.position,
thought=agent_loop.thought,
tool=agent_loop.tool_name,
tool_input=agent_loop.tool_input,
observation=agent_loop.tool_output,
tool_process_data='', # currently not support
message=agent_loop.prompt,
message_token=loop_message_tokens,
message_unit_price=agent_message_unit_price,
answer=agent_loop.completion,
answer_token=loop_answer_tokens,
answer_unit_price=agent_answer_unit_price,
latency=agent_loop.latency,
tokens=agent_loop.prompt_tokens + agent_loop.completion_tokens,
total_price=loop_total_price,
currency=llm_constant.model_currency,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_agent_loop)
db.session.flush()
self._pub_handler.pub_agent_thought(message_agent_loop)
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
dataset_query = DatasetQuery(
dataset_id=dataset_query_obj.dataset_id,
content=dataset_query_obj.query,
source='app',
source_app_id=self.app.id,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(dataset_query)
def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
class PubHandler:
def __init__(self, user: Union[Account | User], task_id: str,
message: Message, conversation: Conversation,
chain_pub: bool = False, agent_thought_pub: bool = False):
self._channel = PubHandler.generate_channel_name(user, task_id)
self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id)
self._task_id = task_id
self._message = message
self._conversation = conversation
self._chain_pub = chain_pub
self._agent_thought_pub = agent_thought_pub
@classmethod
def generate_channel_name(cls, user: Union[Account | User], task_id: str):
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
return "generate_result:{}-{}".format(user_str, task_id)
@classmethod
def generate_stopped_cache_key(cls, user: Union[Account | User], task_id: str):
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
return "generate_result_stopped:{}-{}".format(user_str, task_id)
def pub_text(self, text: str):
content = {
'event': 'message',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'text': text,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_chain(self, message_chain: MessageChain):
if self._chain_pub:
content = {
'event': 'chain',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'chain_id': message_chain.id,
'type': message_chain.type,
'input': json.loads(message_chain.input),
'output': json.loads(message_chain.output),
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_agent_thought(self, message_agent_thought: MessageAgentThought):
if self._agent_thought_pub:
content = {
'event': 'agent_thought',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'chain_id': message_agent_thought.message_chain_id,
'agent_thought_id': message_agent_thought.id,
'position': message_agent_thought.position,
'thought': message_agent_thought.thought,
'tool': message_agent_thought.tool,
'tool_input': message_agent_thought.tool_input,
'observation': message_agent_thought.observation,
'answer': message_agent_thought.answer,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_end(self):
content = {
'event': 'end',
}
redis_client.publish(self._channel, json.dumps(content))
@classmethod
def pub_error(cls, user: Union[Account, User], task_id: str, e):
content = {
'error': type(e).__name__,
'description': e.description if getattr(e, 'description', None) is not None else str(e)
}
channel = cls.generate_channel_name(user, task_id)
redis_client.publish(channel, json.dumps(content))
def _is_stopped(self):
return redis_client.get(self._stopped_cache_key) is not None
@classmethod
def stop(cls, user: Union[Account, User], task_id: str):
stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
redis_client.setex(stopped_cache_key, 600, 1)
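# stop() only sets a Redis flag with a 600-second TTL; the pub_* methods above poll the flag
# via _is_stopped() after each publish, emit an 'end' event and raise
# ConversationTaskStoppedException so the generation loop can abort.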
class ConversationTaskStoppedException(Exception):
pass


@@ -0,0 +1,190 @@
from typing import Any, Dict, Optional, Sequence
import tiktoken
from llama_index.data_structs import Node
from llama_index.docstore.types import BaseDocumentStore
from llama_index.docstore.utils import json_to_doc
from llama_index.schema import BaseDocument
from sqlalchemy import func
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
class DatesetDocumentStore(BaseDocumentStore):
def __init__(
self,
dataset: Dataset,
user_id: str,
embedding_model_name: str,
document_id: Optional[str] = None,
):
self._dataset = dataset
self._user_id = user_id
self._embedding_model_name = embedding_model_name
self._document_id = document_id
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "DatesetDocumentStore":
return cls(**config_dict)
def to_dict(self) -> Dict[str, Any]:
"""Serialize to dict."""
return {
"dataset_id": self._dataset.id,
}
@property
def dateset_id(self) -> Any:
return self._dataset.id
@property
def user_id(self) -> Any:
return self._user_id
@property
def embedding_model_name(self) -> Any:
return self._embedding_model_name
@property
def docs(self) -> Dict[str, BaseDocument]:
document_segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id
).all()
output = {}
for document_segment in document_segments:
doc_id = document_segment.index_node_id
result = self.segment_to_dict(document_segment)
output[doc_id] = json_to_doc(result)
return output
def add_documents(
self, docs: Sequence[BaseDocument], allow_update: bool = True
) -> None:
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == self._document_id
).scalar()
if max_position is None:
max_position = 0
for doc in docs:
if doc.is_doc_id_none:
raise ValueError("doc_id not set")
if not isinstance(doc, Node):
raise ValueError("doc must be a Node")
segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False)
# NOTE: doc could already exist in the store, but we overwrite it
if not allow_update and segment_document:
raise ValueError(
f"doc_id {doc.get_doc_id()} already exists. "
"Set allow_update to True to overwrite."
)
# calc embedding use tokens
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text())
if not segment_document:
max_position += 1
segment_document = DocumentSegment(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
index_node_id=doc.get_doc_id(),
index_node_hash=doc.get_doc_hash(),
position=max_position,
content=doc.get_text(),
word_count=len(doc.get_text()),
tokens=tokens,
created_by=self._user_id,
)
db.session.add(segment_document)
else:
segment_document.content = doc.get_text()
segment_document.index_node_hash = doc.get_doc_hash()
segment_document.word_count = len(doc.get_text())
segment_document.tokens = tokens
db.session.commit()
def document_exists(self, doc_id: str) -> bool:
"""Check if document exists."""
result = self.get_document_segment(doc_id)
return result is not None
def get_document(
self, doc_id: str, raise_error: bool = True
) -> Optional[BaseDocument]:
document_segment = self.get_document_segment(doc_id)
if document_segment is None:
if raise_error:
raise ValueError(f"doc_id {doc_id} not found.")
else:
return None
result = self.segment_to_dict(document_segment)
return json_to_doc(result)
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
document_segment = self.get_document_segment(doc_id)
if document_segment is None:
if raise_error:
raise ValueError(f"doc_id {doc_id} not found.")
else:
return None
db.session.delete(document_segment)
db.session.commit()
def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
"""Set the hash for a given doc_id."""
document_segment = self.get_document_segment(doc_id)
if document_segment is None:
return None
document_segment.index_node_hash = doc_hash
db.session.commit()
def get_document_hash(self, doc_id: str) -> Optional[str]:
"""Get the stored hash for a document, if it exists."""
document_segment = self.get_document_segment(doc_id)
if document_segment is None:
return None
return document_segment.index_node_hash
def update_docstore(self, other: "BaseDocumentStore") -> None:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self.add_documents(list(other.docs.values()))
def get_document_segment(self, doc_id: str) -> DocumentSegment:
document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id,
DocumentSegment.index_node_id == doc_id
).first()
return document_segment
def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
return {
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"text": segment.content,
"__type__": Node.get_type()
}
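# This layout mirrors the serialized node dict that llama_index's json_to_doc expects,
# so DocumentSegment rows can be rehydrated into Node objects in docs / get_document above.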


@@ -0,0 +1,51 @@
from typing import Any, Dict, Optional, Sequence
from llama_index.docstore.types import BaseDocumentStore
from llama_index.schema import BaseDocument
class EmptyDocumentStore(BaseDocumentStore):
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
return cls()
def to_dict(self) -> Dict[str, Any]:
"""Serialize to dict."""
return {}
@property
def docs(self) -> Dict[str, BaseDocument]:
return {}
def add_documents(
self, docs: Sequence[BaseDocument], allow_update: bool = True
) -> None:
pass
def document_exists(self, doc_id: str) -> bool:
"""Check if document exists."""
return False
def get_document(
self, doc_id: str, raise_error: bool = True
) -> Optional[BaseDocument]:
return None
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
pass
def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
"""Set the hash for a given doc_id."""
pass
def get_document_hash(self, doc_id: str) -> Optional[str]:
"""Get the stored hash for a document, if it exists."""
return None
def update_docstore(self, other: "BaseDocumentStore") -> None:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self.add_documents(list(other.docs.values()))


@@ -0,0 +1,176 @@
from typing import Optional, Any, List
import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
_TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
text: str,
engine: Optional[str] = None,
openai_api_key: Optional[str] = None,
) -> List[float]:
"""Get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
text = text.replace("\n", " ")
return openai.Embedding.create(input=[text], engine=engine, api_key=openai_api_key)["data"][0]["embedding"]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key: Optional[str] = None) -> List[float]:
"""Asynchronously get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=openai_api_key))["data"][0][
"embedding"
]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
list_of_text: List[str],
engine: Optional[str] = None,
openai_api_key: Optional[str] = None
) -> List[List[float]]:
"""Get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=openai_api_key).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None
) -> List[List[float]]:
"""Asynchronously get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
class OpenAIEmbedding(BaseEmbedding):
def __init__(
self,
mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
deployment_name: Optional[str] = None,
openai_api_key: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Init params."""
super().__init__(**kwargs)
self.mode = OpenAIEmbeddingMode(mode)
self.model = OpenAIEmbeddingModelType(model)
self.deployment_name = deployment_name
self.openai_api_key = openai_api_key
@handle_llm_exceptions
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _QUERY_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _QUERY_MODE_MODEL_DICT[key]
return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key)
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings.
By default, this is a wrapper around _get_text_embedding.
Can be overridden for batch queries.
"""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = get_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
return embeddings
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = await aget_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
return embeddings


@@ -0,0 +1,120 @@
import logging
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage
from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.llm.token_calculator import TokenCalculator
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT
# gpt-3.5-turbo does not perform well for these generation tasks
generate_base_model = 'text-davinci-003'
class LLMGenerator:
@classmethod
def generate_conversation_name(cls, tenant_id: str, query, answer):
prompt = CONVERSATION_TITLE_PROMPT
prompt = prompt.format(query=query, answer=answer)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=generate_base_model,
max_tokens=50
)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
return answer.strip()
@classmethod
def generate_conversation_summary(cls, tenant_id: str, messages):
max_tokens = 200
prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = TokenCalculator.get_num_tokens(generate_base_model, prompt_with_empty_context)
rest_tokens = llm_constant.max_context_token_length[generate_base_model] - prompt_tokens - max_tokens
context = ''
for message in messages:
if not message.answer:
continue
message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
if rest_tokens - TokenCalculator.get_num_tokens(generate_base_model, context + message_qa_text) > 0:
context += message_qa_text
prompt = prompt.format(context=context)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=generate_base_model,
max_tokens=max_tokens
)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
return answer.strip()
@classmethod
def generate_introduction(cls, tenant_id: str, pre_prompt: str):
prompt = INTRODUCTION_GENERATE_PROMPT
prompt = prompt.format(prompt=pre_prompt)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=generate_base_model,
)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
return answer.strip()
@classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions()
prompt = OutLinePromptTemplate(
template="{histories}\n{format_instructions}\nquestions:\n",
input_variables=["histories"],
partial_variables={"format_instructions": format_instructions}
)
_input = prompt.format_prompt(histories=histories)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=generate_base_model,
temperature=0,
max_tokens=256
)
if isinstance(llm, BaseChatModel):
query = [HumanMessage(content=_input.to_string())]
else:
query = _input.to_string()
try:
output = llm(query)
questions = output_parser.parse(output)
except Exception:
logging.exception("Error generating suggested questions after answer")
questions = []
return questions


@@ -0,0 +1,45 @@
from langchain.callbacks import CallbackManager
from llama_index import ServiceContext, PromptHelper, LLMPredictor
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.embedding.openai_embedding import OpenAIEmbedding
from core.llm.llm_builder import LLMBuilder
class IndexBuilder:
@classmethod
def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
# set number of output tokens
num_output = 512
# only for verbose
callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='text-davinci-003',
temperature=0,
max_tokens=num_output,
callback_manager=callback_manager,
)
llm_predictor = LLMPredictor(llm=llm)
# These parameters affect how the final synthesized response is segmented.
# The number of refinement iterations during synthesis depends on whether
# the length of the segmented output exceeds max_input_size.
prompt_helper = PromptHelper(
max_input_size=3500,
num_output=num_output,
max_chunk_overlap=20
)
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=tenant_id,
model_name='text-embedding-ada-002'
)
return ServiceContext.from_defaults(
llm_predictor=llm_predictor,
prompt_helper=prompt_helper,
embed_model=OpenAIEmbedding(**model_credentials),
)
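# As a rough illustration: with max_input_size=3500 and num_output=512, about
# 3500 - 512 = 2988 tokens remain per call for the prompt template and packed context chunks.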


@@ -0,0 +1,159 @@
import re
from typing import (
Any,
Dict,
List,
Set,
Optional
)
import jieba.analyse
from core.index.keyword_table.stopwords import STOPWORDS
from llama_index.indices.query.base import IS
from llama_index import QueryMode
from llama_index.indices.base import QueryMap
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
BaseNodePostprocessor,
)
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from core.index.query.synthesizer import EnhanceResponseSynthesizer
def jieba_extract_keywords(
text_chunk: str,
max_keywords: Optional[int] = None,
expand_with_subtokens: bool = True,
) -> Set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text_chunk,
topK=max_keywords,
)
if expand_with_subtokens:
return set(expand_tokens_with_subtokens(keywords))
else:
return set(keywords)
def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
return results
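# For example, the keyword "machine learning" expands to {"machine learning", "machine",
# "learning"}, assuming neither subtoken appears in STOPWORDS.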
class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
"""GPT JIEBA Keyword Table Index.
This index uses a JIEBA keyword extractor to extract keywords from the text.
"""
def _extract_keywords(self, text: str) -> Set[str]:
"""Extract keywords from text."""
return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)
@classmethod
def get_query_map(cls) -> QueryMap:
"""Get query map."""
super_map = super().get_query_map()
super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery
return super_map
def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
"""Delete a document."""
# get set of ids that correspond to node
node_idxs_to_delete = {doc_id}
# delete node_idxs from keyword to node idxs mapping
keywords_to_delete = set()
for keyword, node_idxs in self._index_struct.table.items():
if node_idxs_to_delete.intersection(node_idxs):
self._index_struct.table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
if not self._index_struct.table[keyword]:
keywords_to_delete.add(keyword)
for keyword in keywords_to_delete:
del self._index_struct.table[keyword]
class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
"""GPT Keyword Table Index JIEBA Query.
Extracts keywords using JIEBA keyword extractor.
Used when `mode="jieba"` is passed to the `query` method of `GPTKeywordTableIndex`.
.. code-block:: python
response = index.query("<query_str>", mode="jieba")
See BaseGPTKeywordTableQuery for arguments.
"""
@classmethod
def from_args(
cls,
index_struct: IS,
service_context: ServiceContext,
docstore: Optional[BaseDocumentStore] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
verbose: bool = False,
# response synthesizer args
response_mode: ResponseMode = ResponseMode.DEFAULT,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_kwargs: Optional[Dict] = None,
use_async: bool = False,
streaming: bool = False,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
# class-specific args
**kwargs: Any,
) -> "BaseGPTIndexQuery":
response_synthesizer = EnhanceResponseSynthesizer.from_args(
service_context=service_context,
text_qa_template=text_qa_template,
refine_template=refine_template,
simple_template=simple_template,
response_mode=response_mode,
response_kwargs=response_kwargs,
use_async=use_async,
streaming=streaming,
optimizer=optimizer,
)
return cls(
index_struct=index_struct,
service_context=service_context,
response_synthesizer=response_synthesizer,
docstore=docstore,
node_postprocessors=node_postprocessors,
verbose=verbose,
**kwargs,
)
def _get_keywords(self, query_str: str) -> List[str]:
"""Extract keywords."""
return list(
jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
)


@@ -0,0 +1,90 @@
STOPWORDS = {
"during", "when", "but", "then", "further", "isn", "mustn't", "until", "own", "i", "couldn", "y", "only", "you've",
"ours", "who", "where", "ourselves", "has", "to", "was", "didn't", "themselves", "if", "against", "through", "her",
"an", "your", "can", "those", "didn", "about", "aren't", "shan't", "be", "not", "these", "again", "so", "t",
"theirs", "weren", "won't", "won", "itself", "just", "same", "while", "why", "doesn", "aren", "him", "haven",
"for", "you'll", "that", "we", "am", "d", "by", "having", "wasn't", "than", "weren't", "out", "from", "now",
"their", "too", "hadn", "o", "needn", "most", "it", "under", "needn't", "any", "some", "few", "ll", "hers", "which",
"m", "you're", "off", "other", "had", "she", "you'd", "do", "you", "does", "s", "will", "each", "wouldn't", "hasn't",
"such", "more", "whom", "she's", "my", "yours", "yourself", "of", "on", "very", "hadn't", "with", "yourselves",
"been", "ma", "them", "mightn't", "shan", "mustn", "they", "what", "both", "that'll", "how", "is", "he", "because",
"down", "haven't", "are", "no", "it's", "our", "being", "the", "or", "above", "myself", "once", "don't", "doesn't",
"as", "nor", "here", "herself", "hasn", "mightn", "have", "its", "all", "were", "ain", "this", "at", "after",
"over", "shouldn't", "into", "before", "don", "wouldn", "re", "couldn't", "wasn", "in", "should", "there",
"himself", "isn't", "should've", "doing", "ve", "shouldn", "a", "did", "and", "his", "between", "me", "up", "below",
"人民", "末##末", "", "", "", "哎呀", "哎哟", "", "", "俺们", "", "按照", "", "吧哒", "", "罢了", "", "",
"本着", "", "比方", "比如", "鄙人", "", "彼此", "", "", "别的", "别说", "", "并且", "不比", "不成", "不单", "不但",
"不独", "不管", "不光", "不过", "不仅", "不拘", "不论", "不怕", "不然", "不如", "不特", "不惟", "不问", "不只", "", "朝着",
"", "趁着", "", "", "", "除此之外", "除非", "除了", "", "此间", "此外", "", "从而", "", "", "", "但是", "",
"当着", "", "", "", "的话", "", "等等", "", "", "叮咚", "", "对于", "", "多少", "", "而况", "而且", "而是",
"而外", "而言", "而已", "尔后", "反过来", "反过来说", "反之", "非但", "非徒", "否则", "", "嘎登", "", "", "", "",
"各个", "各位", "各种", "各自", "", "根据", "", "", "故此", "固然", "关于", "", "", "果然", "果真", "", "",
"哈哈", "", "", "", "何处", "何况", "何时", "", "", "哼唷", "呼哧", "", "", "还是", "还有", "换句话说", "换言之",
"", "或是", "或者", "极了", "", "及其", "及至", "", "即便", "即或", "即令", "即若", "即使", "", "几时", "", "",
"既然", "既是", "继而", "加之", "假如", "假若", "假使", "鉴于", "", "", "较之", "", "接着", "结果", "", "紧接着",
"进而", "", "尽管", "", "经过", "", "就是", "就是说", "", "具体地说", "具体说来", "开始", "开外", "", "", "",
"可见", "可是", "可以", "况且", "", "", "来着", "", "例如", "", "", "连同", "两者", "", "", "", "另外",
"另一方面", "", "", "", "慢说", "漫说", "", "", "", "每当", "", "莫若", "", "某个", "某些", "", "",
"哪边", "哪儿", "哪个", "哪里", "哪年", "哪怕", "哪天", "哪些", "哪样", "", "那边", "那儿", "那个", "那会儿", "那里", "那么",
"那么些", "那么样", "那时", "那些", "那样", "", "乃至", "", "", "", "你们", "", "", "宁可", "宁肯", "宁愿", "",
"", "啪达", "旁人", "", "", "凭借", "", "其次", "其二", "其他", "其它", "其一", "其余", "其中", "", "起见", "岂但",
"恰恰相反", "前后", "前者", "", "然而", "然后", "然则", "", "人家", "", "任何", "任凭", "", "如此", "如果", "如何",
"如其", "如若", "如上所述", "", "若非", "若是", "", "上下", "尚且", "设若", "设使", "甚而", "甚么", "甚至", "省得", "时候",
"什么", "什么样", "使得", "", "是的", "首先", "", "谁知", "", "顺着", "似的", "", "虽然", "虽说", "虽则", "", "随着",
"", "所以", "", "他们", "他人", "", "它们", "", "她们", "", "倘或", "倘然", "倘若", "倘使", "", "", "通过", "",
"同时", "", "万一", "", "", "", "为何", "为了", "为什么", "为着", "", "嗡嗡", "", "我们", "", "呜呼", "乌乎",
"无论", "无宁", "毋宁", "", "", "相对而言", "", "", "向着", "", "", "", "沿", "沿着", "", "要不", "要不然",
"要不是", "要么", "要是", "", "也罢", "也好", "", "一般", "一旦", "一方面", "一来", "一切", "一样", "一则", "", "依照",
"", "", "以便", "以及", "以免", "以至", "以至于", "以致", "抑或", "", "因此", "因而", "因为", "", "", "",
"由此可见", "由于", "", "有的", "有关", "有些", "", "", "于是", "于是乎", "", "与此同时", "与否", "与其", "越是",
"云云", "", "再说", "再者", "", "在下", "", "咱们", "", "", "怎么", "怎么办", "怎么样", "怎样", "", "", "照着",
"", "", "这边", "这儿", "这个", "这会儿", "这就是说", "这里", "这么", "这么点儿", "这么些", "这么样", "这时", "这些", "这样",
"正如", "", "", "之类", "之所以", "之一", "只是", "只限", "只要", "只有", "", "至于", "诸位", "", "着呢", "", "自从",
"自个儿", "自各儿", "自己", "自家", "自身", "综上所述", "总的来看", "总的来说", "总的说来", "总而言之", "总之", "", "纵令",
"纵然", "纵使", "遵照", "作为", "", "", "", "", "", "", "", "喔唷", "", "", "", "~", "!", ".", ":",
"\"", "'", "(", ")", "*", "A", "", "社会主义", "--", "..", ">>", " [", " ]", "", "<", ">", "/", "\\", "|", "-", "_",
"+", "=", "&", "^", "%", "#", "@", "`", ";", "$", "", "", "——", "", "", "·", "...", "", "", "", "", "",
" ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "", "", "", "", "", "", "", "︿", "", "", "", "", "", "",
"", "", "", "啊哈", "啊呀", "啊哟", "挨次", "挨个", "挨家挨户", "挨门挨户", "挨门逐户", "挨着", "按理", "按期", "按时",
"按说", "暗地里", "暗中", "暗自", "昂然", "八成", "白白", "", "", "保管", "保险", "", "背地里", "背靠背", "倍感", "倍加",
"本人", "本身", "", "比起", "比如说", "比照", "毕竟", "", "必定", "必将", "必须", "便", "别人", "并非", "并肩", "并没",
"并没有", "并排", "并无", "勃然", "", "不必", "不常", "不大", "不但...而且", "不得", "不得不", "不得了", "不得已", "不迭",
"不定", "不对", "不妨", "不管怎样", "不会", "不仅...而且", "不仅仅", "不仅仅是", "不经意", "不可开交", "不可抗拒", "不力", "不了",
"不料", "不满", "不免", "不能不", "不起", "不巧", "不然的话", "不日", "不少", "不胜", "不时", "不是", "不同", "不能", "不要",
"不外", "不外乎", "不下", "不限", "不消", "不已", "不亦乐乎", "不由得", "不再", "不择手段", "不怎么", "不曾", "不知不觉", "不止",
"不止一次", "不至于", "", "才能", "策略地", "差不多", "差一点", "", "常常", "常言道", "常言说", "常言说得好", "长此下去",
"长话短说", "长期以来", "长线", "敞开儿", "彻夜", "陈年", "趁便", "趁机", "趁热", "趁势", "趁早", "成年", "成年累月", "成心",
"乘机", "乘胜", "乘势", "乘隙", "乘虚", "诚然", "迟早", "充分", "充其极", "充其量", "抽冷子", "", "", "", "出来", "出去",
"除此", "除此而外", "除此以外", "除开", "除去", "除却", "除外", "处处", "川流不息", "", "传说", "传闻", "串行", "", "纯粹",
"此后", "此中", "次第", "匆匆", "从不", "从此", "从此以后", "从古到今", "从古至今", "从今以后", "从宽", "从来", "从轻", "从速",
"从头", "从未", "从无到有", "从小", "从新", "从严", "从优", "从早到晚", "从中", "从重", "凑巧", "", "存心", "达旦", "打从",
"打开天窗说亮话", "", "大不了", "大大", "大抵", "大都", "大多", "大凡", "大概", "大家", "大举", "大略", "大面儿上", "大事",
"大体", "大体上", "大约", "大张旗鼓", "大致", "呆呆地", "", "", "待到", "", "单纯", "单单", "但愿", "弹指之间", "当场",
"当儿", "当即", "当口儿", "当然", "当庭", "当头", "当下", "当真", "当中", "倒不如", "倒不如说", "倒是", "到处", "到底", "到了儿",
"到目前为止", "到头", "到头来", "得起", "得天独厚", "的确", "等到", "叮当", "顶多", "", "动不动", "动辄", "陡然", "", "",
"独自", "断然", "顿时", "多次", "多多", "多多少少", "多多益善", "多亏", "多年来", "多年前", "而后", "而论", "而又", "尔等",
"二话不说", "二话没说", "反倒", "反倒是", "反而", "反手", "反之亦然", "反之则", "", "方才", "方能", "放量", "非常", "非得",
"分期", "分期分批", "分头", "奋勇", "愤然", "风雨无阻", "", "", "", "嘎嘎", "该当", "", "赶快", "赶早不赶晚", "",
"敢情", "敢于", "", "刚才", "刚好", "刚巧", "高低", "格外", "隔日", "隔夜", "个人", "各式", "", "更加", "更进一步", "更为",
"公然", "", "共总", "够瞧的", "姑且", "古来", "故而", "故意", "", "", "怪不得", "惯常", "", "光是", "归根到底",
"归根结底", "过于", "毫不", "毫无", "毫无保留地", "毫无例外", "好在", "何必", "何尝", "何妨", "何苦", "何乐而不为", "何须",
"何止", "", "很多", "很少", "轰然", "后来", "呼啦", "忽地", "忽然", "", "互相", "哗啦", "话说", "", "恍然", "", "豁然",
"", "伙同", "或多或少", "或许", "基本", "基本上", "基于", "", "极大", "极度", "极端", "极力", "极其", "极为", "急匆匆",
"即将", "即刻", "即是说", "几度", "几番", "几乎", "几经", "既...又", "继之", "加上", "加以", "间或", "简而言之", "简言之",
"简直", "", "将才", "将近", "将要", "交口", "较比", "较为", "接连不断", "接下来", "皆可", "截然", "截至", "藉以", "借此",
"借以", "届时", "", "仅仅", "", "进来", "进去", "", "近几年来", "近来", "近年来", "尽管如此", "尽可能", "尽快", "尽量",
"尽然", "尽如人意", "尽心竭力", "尽心尽力", "尽早", "精光", "经常", "", "竟然", "究竟", "就此", "就地", "就算", "居然", "局外",
"举凡", "据称", "据此", "据实", "据说", "据我所知", "据悉", "具体来说", "决不", "决非", "", "绝不", "绝顶", "绝对", "绝非",
"", "", "", "看来", "看起来", "看上去", "看样子", "可好", "可能", "恐怕", "", "快要", "来不及", "来得及", "来讲",
"来看", "拦腰", "牢牢", "", "老大", "老老实实", "老是", "累次", "累年", "理当", "理该", "理应", "", "", "立地", "立刻",
"立马", "立时", "联袂", "连连", "连日", "连日来", "连声", "连袂", "临到", "另方面", "另行", "另一个", "路经", "", "屡次",
"屡次三番", "屡屡", "缕缕", "率尔", "率然", "", "略加", "略微", "略为", "论说", "马上", "", "", "", "没有", "每逢",
"每每", "每时每刻", "猛然", "猛然间", "", "莫不", "莫非", "莫如", "默默地", "默然", "", "那末", "", "难道", "难得", "难怪",
"难说", "", "年复一年", "凝神", "偶而", "偶尔", "", "", "碰巧", "譬如", "偏偏", "", "平素", "", "迫于", "扑通",
"其后", "其实", "", "", "起初", "起来", "起首", "起头", "起先", "", "岂非", "岂止", "", "恰逢", "恰好", "恰恰", "恰巧",
"恰如", "恰似", "", "千万", "千万千万", "", "切不可", "切莫", "切切", "切勿", "", "亲口", "亲身", "亲手", "亲眼", "亲自",
"", "顷刻", "顷刻间", "顷刻之间", "请勿", "穷年累月", "取道", "", "权时", "全都", "全力", "全年", "全然", "全身心", "",
"人人", "", "仍旧", "仍然", "日复一日", "日见", "日渐", "日益", "日臻", "如常", "如此等等", "如次", "如今", "如期", "如前所述",
"如上", "如下", "", "三番两次", "三番五次", "三天两头", "瑟瑟", "沙沙", "", "上来", "上去", "一个", "", "", "\n"
}


@@ -0,0 +1,135 @@
import json
from typing import List, Optional
from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
from llama_index.data_structs import KeywordTable, Node
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.registry import load_index_struct_from_dict
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.index_builder import IndexBuilder
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
class KeywordTableIndex:
def __init__(self, dataset: Dataset):
self._dataset = dataset
def add_nodes(self, nodes: List[Node]):
llm = LLMBuilder.to_llm(
tenant_id=self._dataset.tenant_id,
model_name='fake'
)
service_context = ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
index_struct = KeywordTable()
else:
index_struct_dict = dataset_keyword_table.keyword_table_dict
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
# create index
index = GPTJIEBAKeywordTableIndex(
index_struct=index_struct,
docstore=EmptyDocumentStore(),
service_context=service_context
)
for node in nodes:
keywords = index._extract_keywords(node.get_text())
self.update_segment_keywords(node.doc_id, list(keywords))
index._index_struct.add_node(list(keywords), node)
index_struct_dict = index.index_struct.to_dict()
if not dataset_keyword_table:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
keyword_table=json.dumps(index_struct_dict)
)
db.session.add(dataset_keyword_table)
else:
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
db.session.commit()
def del_nodes(self, node_ids: List[str]):
llm = LLMBuilder.to_llm(
tenant_id=self._dataset.tenant_id,
model_name='fake'
)
service_context = ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
return
else:
index_struct_dict = dataset_keyword_table.keyword_table_dict
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
# create index
index = GPTJIEBAKeywordTableIndex(
index_struct=index_struct,
docstore=EmptyDocumentStore(),
service_context=service_context
)
for node_id in node_ids:
index.delete(node_id)
index_struct_dict = index.index_struct.to_dict()
if not dataset_keyword_table:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
keyword_table=json.dumps(index_struct_dict)
)
db.session.add(dataset_keyword_table)
else:
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
db.session.commit()
@property
def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
docstore = DatesetDocumentStore(
dataset=self._dataset,
user_id=self._dataset.created_by,
embedding_model_name="text-embedding-ada-002"
)
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
return None
index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)
return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context)
def get_keyword_table(self):
dataset_keyword_table = self._dataset.dataset_keyword_table
if dataset_keyword_table:
return dataset_keyword_table
return None
def update_segment_keywords(self, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
if document_segment:
document_segment.keywords = keywords
db.session.commit()


@@ -0,0 +1,79 @@
from typing import (
Any,
Dict,
Optional, Sequence,
)
from llama_index.indices.response.response_synthesis import ResponseSynthesizer
from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from llama_index.types import RESPONSE_TEXT_TYPE
class EnhanceResponseSynthesizer(ResponseSynthesizer):
@classmethod
def from_args(
cls,
service_context: ServiceContext,
streaming: bool = False,
use_async: bool = False,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_mode: ResponseMode = ResponseMode.DEFAULT,
response_kwargs: Optional[Dict] = None,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
) -> "ResponseSynthesizer":
response_builder: Optional[BaseResponseBuilder] = None
if response_mode != ResponseMode.NO_TEXT:
if response_mode == 'no_synthesizer':
response_builder = NoSynthesizer(
service_context=service_context,
simple_template=simple_template,
streaming=streaming,
)
else:
response_builder = get_response_builder(
service_context,
text_qa_template,
refine_template,
simple_template,
response_mode,
use_async=use_async,
streaming=streaming,
)
return cls(response_builder, response_mode, response_kwargs, optimizer)
class NoSynthesizer(BaseResponseBuilder):
def __init__(
self,
service_context: ServiceContext,
simple_template: Optional[SimpleInputPrompt] = None,
streaming: bool = False,
) -> None:
super().__init__(service_context, streaming)
async def aget_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[str] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
return "\n".join(text_chunks)
def get_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[str] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
return "\n".join(text_chunks)


@@ -0,0 +1,22 @@
from pathlib import Path
from typing import Dict
from bs4 import BeautifulSoup
from llama_index.readers.file.base_parser import BaseParser
class HTMLParser(BaseParser):
"""HTML parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
with open(file, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return text


@@ -0,0 +1,56 @@
from pathlib import Path
from typing import Dict
from flask import current_app
from llama_index.readers.file.base_parser import BaseParser
from pypdf import PdfReader
from extensions.ext_storage import storage
from models.model import UploadFile
class PDFParser(BaseParser):
"""PDF parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
if not current_app.config.get('PDF_PREVIEW', True):
return ''
plaintext_file_key = ''
plaintext_file_exists = False
if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
upload_file: UploadFile = self._parser_config['upload_file']
if upload_file.hash:
plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return text
except FileNotFoundError:
pass
text_list = []
with open(file, "rb") as fp:
# Create a PDF object
pdf = PdfReader(fp)
# Get the number of pages in the PDF document
num_pages = len(pdf.pages)
# Iterate over every page
for page in range(num_pages):
# Extract the text from the page
page_text = pdf.pages[page].extract_text()
text_list.append(page_text)
text = "\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return text
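# Extracted plaintext is cached under "upload_files/<tenant_id>/<hash>.plaintext", so a PDF
# that has already been parsed once is served from storage instead of being re-parsed.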


@@ -0,0 +1,136 @@
import json
import logging
from typing import List, Optional
from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type
from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding
class VectorIndex:
def __init__(self, dataset: Dataset):
self._dataset = dataset
def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
if not self._dataset.index_struct_dict:
index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
db.session.commit()
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
if duplicate_check:
nodes = self._filter_duplicate_nodes(index, nodes)
embedding_queue_nodes = []
embedded_nodes = []
for node in nodes:
node_hash = node.doc_hash
# if node hash in cached embedding tables, use cached embedding
embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
if embedding:
node.embedding = embedding.get_embedding()
embedded_nodes.append(node)
else:
embedding_queue_nodes.append(node)
if embedding_queue_nodes:
embedding_results = index._get_node_embedding_results(
embedding_queue_nodes,
set(),
)
# pre embed nodes for cached embedding
for embedding_result in embedding_results:
node = embedding_result.node
node.embedding = embedding_result.embedding
try:
embedding = Embedding(hash=node.doc_hash)
embedding.set_embedding(node.embedding)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except Exception:
logging.exception('Failed to add embedding to db')
continue
embedded_nodes.append(node)
self.index_insert_nodes(index, embedded_nodes)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
index.insert_nodes(nodes)
def del_nodes(self, node_ids: List[str]):
if not self._dataset.index_struct_dict:
return
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
for node_id in node_ids:
self.index_delete_node(index, node_id)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
index.delete_node(node_id)
def del_doc(self, doc_id: str):
if not self._dataset.index_struct_dict:
return
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
self.index_delete_doc(index, doc_id)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
index.delete(doc_id)
@property
def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
if not self._dataset.index_struct_dict:
return None
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
return vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
# build a new list rather than removing items from the list being iterated,
# which would silently skip the element that follows each removed node
return [node for node in nodes if not index.exists_by_node_id(node.doc_id)]

467
api/core/indexing_runner.py Normal file

@@ -0,0 +1,467 @@
import datetime
import json
import re
import tempfile
import time
from pathlib import Path
from typing import Optional, List
from langchain.text_splitter import RecursiveCharacterTextSplitter
from llama_index import SimpleDirectoryReader
from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.node_parser import SimpleNodeParser, NodeParser
from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
from llama_index.readers.file.markdown_parser import MarkdownParser
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.index.keyword_table_index import KeywordTableIndex
from core.index.readers.html_parser import HTMLParser
from core.index.readers.pdf_parser import PDFParser
from core.index.vector_index import VectorIndex
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile
class IndexingRunner:
def __init__(self, embedding_model_name: str = "text-embedding-ada-002"):
self.storage = storage
self.embedding_model_name = embedding_model_name
def run(self, document: Document):
"""Run the indexing process."""
# get dataset
dataset = Dataset.query.filter_by(
id=document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
# load file
text_docs = self._load_data(document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
first()
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# split to nodes
nodes = self._step_split(
text_docs=text_docs,
node_parser=node_parser,
dataset=dataset,
document=document,
processing_rule=processing_rule
)
# build index
self._build_index(
dataset=dataset,
document=document,
nodes=nodes
)
def run_in_splitting_status(self, document: Document):
"""Run the indexing process when the index_status is splitting."""
# get dataset
dataset = Dataset.query.filter_by(
id=document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=document.id
).all()
# session.delete() expects a single mapped instance, so remove the existing segments one by one
for document_segment in document_segments:
db.session.delete(document_segment)
db.session.commit()
# load file
text_docs = self._load_data(document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
first()
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# split to nodes
nodes = self._step_split(
text_docs=text_docs,
node_parser=node_parser,
dataset=dataset,
document=document,
processing_rule=processing_rule
)
# build index
self._build_index(
dataset=dataset,
document=document,
nodes=nodes
)
def run_in_indexing_status(self, document: Document):
"""Run the indexing process when the index_status is indexing."""
# get dataset
dataset = Dataset.query.filter_by(
id=document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=document.id
).all()
nodes = []
if document_segments:
for document_segment in document_segments:
# transform segment to node
if document_segment.status != "completed":
relationships = {
DocumentRelationship.SOURCE: document_segment.document_id,
}
previous_segment = document_segment.previous_segment
if previous_segment:
relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
next_segment = document_segment.next_segment
if next_segment:
relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
node = Node(
doc_id=document_segment.index_node_id,
doc_hash=document_segment.index_node_hash,
text=document_segment.content,
extra_info=None,
node_info=None,
relationships=relationships
)
nodes.append(node)
# build index
self._build_index(
dataset=dataset,
document=document,
nodes=nodes
)
def indexing_estimate(self, file_detail: UploadFile, tmp_processing_rule: dict) -> dict:
"""
Estimate the indexing result (total segments, tokens and price) for the document.
"""
# load data from file
text_docs = self._load_data_from_file(file_detail)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# split to nodes
nodes = self._split_to_nodes(
text_docs=text_docs,
node_parser=node_parser,
processing_rule=processing_rule
)
tokens = 0
preview_texts = []
for node in nodes:
if len(preview_texts) < 5:
preview_texts.append(node.get_text())
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
return {
"total_segments": len(nodes),
"tokens": tokens,
"total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)),
"currency": TokenCalculator.get_currency(self.embedding_model_name),
"preview": preview_texts
}
def _load_data(self, document: Document) -> List[Document]:
# load file
if document.data_source_type != "upload_file":
return []
data_source_info = document.data_source_info_dict
if not data_source_info or 'upload_file_id' not in data_source_info:
raise ValueError("no upload file found")
file_detail = db.session.query(UploadFile). \
filter(UploadFile.id == data_source_info['upload_file_id']). \
one_or_none()
text_docs = self._load_data_from_file(file_detail)
# update document status to splitting
self._update_document_index_status(
document_id=document.id,
after_indexing_status="splitting",
extra_update_params={
Document.file_id: file_detail.id,
Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
Document.parsing_completed_at: datetime.datetime.utcnow()
}
)
# replace doc id to document model id
for text_doc in text_docs:
# remove invalid symbol
text_doc.text = self.filter_string(text_doc.get_text())
text_doc.doc_id = document.id
return text_docs
def filter_string(self, text):
pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
return pattern.sub('', text)
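# The pattern strips ASCII control characters (keeping tab, LF and CR) plus DEL and the
# U+0080-U+00FF range, which typically show up as mojibake in extracted document text.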
def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
self.storage.download(upload_file.key, filepath)
file_extractor = DEFAULT_FILE_EXTRACTOR.copy()
file_extractor[".markdown"] = MarkdownParser()
file_extractor[".html"] = HTMLParser()
file_extractor[".htm"] = HTMLParser()
file_extractor[".pdf"] = PDFParser({'upload_file': upload_file})
loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor)
text_docs = loader.load_data()
return text_docs
def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
"""
Get the NodeParser object according to the processing rule.
"""
if processing_rule.mode == "custom":
# The user-defined segmentation rule
rules = json.loads(processing_rule.rules)
segmentation = rules["segmentation"]
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
raise ValueError("Custom segment length should be between 50 and 1000.")
separator = segmentation["separator"]
if not separator:
separators = ["\n\n", "", ".", " ", ""]
else:
separator = separator.replace('\\n', '\n')
separators = [separator, ""]
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=segmentation["max_tokens"],
chunk_overlap=0,
separators=separators
)
else:
# Automatic segmentation
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
chunk_overlap=0,
separators=["\n\n", "", ".", " ", ""]
)
return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True)
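# Worked example of the custom mode (hypothetical rule): {"segmentation": {"max_tokens": 500,
# "separator": "\\n"}} produces chunks of at most 500 tokens, split primarily on newlines.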
def _step_split(self, text_docs: List[Document], node_parser: NodeParser,
dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]:
"""
Split the text documents into nodes and save them to the document segment.
"""
nodes = self._split_to_nodes(
text_docs=text_docs,
node_parser=node_parser,
processing_rule=processing_rule
)
# save node to document segment
doc_store = DatesetDocumentStore(
dataset=dataset,
user_id=document.created_by,
embedding_model_name=self.embedding_model_name,
document_id=document.id
)
doc_store.add_documents(nodes)
# update document status to indexing
cur_time = datetime.datetime.utcnow()
self._update_document_index_status(
document_id=document.id,
after_indexing_status="indexing",
extra_update_params={
Document.cleaning_completed_at: cur_time,
Document.splitting_completed_at: cur_time,
}
)
# update segment status to indexing
self._update_segments_by_document(
document_id=document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.utcnow()
}
)
return nodes
def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser,
processing_rule: DatasetProcessRule) -> List[Node]:
"""
Split the text documents into nodes.
"""
all_nodes = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.get_text(), processing_rule)
text_doc.text = document_text
# parse document to nodes
nodes = node_parser.get_nodes_from_documents([text_doc])
all_nodes.extend(nodes)
return all_nodes
def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
"""
Clean the document text according to the processing rules.
"""
if processing_rule.mode == "automatic":
rules = DatasetProcessRule.AUTOMATIC_RULES
else:
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
if 'pre_processing_rules' in rules:
pre_processing_rules = rules["pre_processing_rules"]
for pre_processing_rule in pre_processing_rules:
if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
# Remove extra spaces
pattern = r'\n{3,}'
text = re.sub(pattern, '\n\n', text)
pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
text = re.sub(pattern, ' ', text)
elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
# Remove email
pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
text = re.sub(pattern, '', text)
# Remove URL
pattern = r'https?://[^\s]+'
text = re.sub(pattern, '', text)
return text
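# With "remove_extra_spaces" enabled, runs of three or more newlines collapse to a single blank
# line and runs of horizontal whitespace collapse to one space; "remove_urls_emails" blanks out
# e-mail addresses and http(s) URLs before the text is split into nodes.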
def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None:
"""
Build the index for the document.
"""
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
tokens = 0
chunk_size = 100
for i in range(0, len(nodes), chunk_size):
# check document is paused
self._check_document_paused_status(document.id)
chunk_nodes = nodes[i:i + chunk_size]
tokens += sum(
TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes
)
# save vector index
if dataset.indexing_technique == "high_quality":
vector_index.add_nodes(chunk_nodes)
# save keyword index
keyword_table_index.add_nodes(chunk_nodes)
node_ids = [node.doc_id for node in chunk_nodes]
db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id,
DocumentSegment.index_node_id.in_(node_ids),
DocumentSegment.status == "indexing"
).update({
DocumentSegment.status: "completed",
DocumentSegment.completed_at: datetime.datetime.utcnow()
})
db.session.commit()
indexing_end_at = time.perf_counter()
# update document status to completed
self._update_document_index_status(
document_id=document.id,
after_indexing_status="completed",
extra_update_params={
Document.tokens: tokens,
Document.completed_at: datetime.datetime.utcnow(),
Document.indexing_latency: indexing_end_at - indexing_start_at,
}
)
def _check_document_paused_status(self, document_id: str):
indexing_cache_key = 'document_{}_is_paused'.format(document_id)
result = redis_client.get(indexing_cache_key)
if result:
raise DocumentIsPausedException()
def _update_document_index_status(self, document_id: str, after_indexing_status: str,
extra_update_params: Optional[dict] = None) -> None:
"""
Update the document indexing status.
"""
count = Document.query.filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedException()
update_params = {
Document.indexing_status: after_indexing_status
}
if extra_update_params:
update_params.update(extra_update_params)
Document.query.filter_by(id=document_id).update(update_params)
db.session.commit()
def _update_segments_by_document(self, document_id: str, update_params: dict) -> None:
"""
Update the document segment by document id.
"""
DocumentSegment.query.filter_by(document_id=document_id).update(update_params)
db.session.commit()
class DocumentIsPausedException(Exception):
pass

55
api/core/llm/error.py Normal file

@@ -0,0 +1,55 @@
from typing import Optional
class LLMError(Exception):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None:
self.description = description
class LLMBadRequestError(LLMError):
"""Raised when the LLM returns bad request."""
description = "Bad Request"
class LLMAPIConnectionError(LLMError):
"""Raised when the LLM returns API connection error."""
description = "API Connection Error"
class LLMAPIUnavailableError(LLMError):
"""Raised when the LLM returns API unavailable error."""
description = "API Unavailable Error"
class LLMRateLimitError(LLMError):
"""Raised when the LLM returns rate limit error."""
description = "Rate Limit Error"
class LLMAuthorizationError(LLMError):
"""Raised when the LLM returns authorization error."""
description = "Authorization Error"
class ProviderTokenNotInitError(Exception):
"""
Custom exception raised when the provider token is not initialized.
"""
description = "Provider Token Not Init"
class QuotaExceededError(Exception):
"""
Custom exception raised when the quota for a provider has been exceeded.
"""
description = "Quota Exceeded"
class ModelCurrentlyNotSupportError(Exception):
"""
Custom exception raised when the model is not currently supported.
"""
description = "Model Currently Not Support"

View File

@@ -0,0 +1,51 @@
import logging
from functools import wraps
import openai
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
LLMBadRequestError
def handle_llm_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except openai.error.InvalidRequestError as e:
logging.exception("Invalid request to OpenAI API.")
raise LLMBadRequestError(str(e))
except openai.error.APIConnectionError as e:
logging.exception("Failed to connect to OpenAI API.")
raise LLMAPIConnectionError(str(e))
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
logging.exception("OpenAI service unavailable.")
raise LLMAPIUnavailableError(str(e))
except openai.error.RateLimitError as e:
raise LLMRateLimitError(str(e))
except openai.error.AuthenticationError as e:
raise LLMAuthorizationError(str(e))
return wrapper
def handle_llm_exceptions_async(func):
@wraps(func)
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except openai.error.InvalidRequestError as e:
logging.exception("Invalid request to OpenAI API.")
raise LLMBadRequestError(str(e))
except openai.error.APIConnectionError as e:
logging.exception("Failed to connect to OpenAI API.")
raise LLMAPIConnectionError(str(e))
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
logging.exception("OpenAI service unavailable.")
raise LLMAPIUnavailableError(str(e))
except openai.error.RateLimitError as e:
raise LLMRateLimitError(str(e))
except openai.error.AuthenticationError as e:
raise LLMAuthorizationError(str(e))
return wrapper
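A minimal usage sketch of the decorator above (the `ask` function and the call inside it are hypothetical, assuming the module layout of this commit): OpenAI SDK errors raised inside the wrapped function surface as the LLM* exceptions from api/core/llm/error.py.

import openai

from core.llm.error import LLMRateLimitError
from core.llm.error_handle_wraps import handle_llm_exceptions

@handle_llm_exceptions
def ask(prompt: str) -> str:
    # any OpenAI SDK call that may raise openai.error.* exceptions
    completion = openai.Completion.create(model="text-davinci-003", prompt=prompt)
    return completion.choices[0].text

try:
    ask("hello")
except LLMRateLimitError as e:
    print(e.description)  # carries the original OpenAI error message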

103
api/core/llm/llm_builder.py Normal file
View File

@@ -0,0 +1,103 @@
from typing import Union, Optional
from langchain.callbacks import CallbackManager
from langchain.llms.fake import FakeListLLM
from core.constant import llm_constant
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
class LLMBuilder:
"""
This class handles the following logic:
1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config.
2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below:
OPENAI_API_TYPE=azure
OPENAI_API_VERSION=2022-12-01
OPENAI_API_BASE=https://your-resource-name.openai.azure.com
OPENAI_API_KEY=<your Azure OpenAI API key>
3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config.
4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config.
5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config.
6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot be created through the admin interface.
7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter.
8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting.
"""
@classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
if model_name == 'fake':
return FakeListLLM(responses=[])
mode = cls.get_mode_by_model(model_name)
if mode == 'chat':
# llm_cls = StreamableAzureChatOpenAI
llm_cls = StreamableChatOpenAI
elif mode == 'completion':
llm_cls = StreamableOpenAI
else:
raise ValueError(f"model name {model_name} is not supported.")
model_credentials = cls.get_model_credentials(tenant_id, model_name)
return llm_cls(
model_name=model_name,
temperature=kwargs.get('temperature', 0),
max_tokens=kwargs.get('max_tokens', 256),
top_p=kwargs.get('top_p', 1),
frequency_penalty=kwargs.get('frequency_penalty', 0),
presence_penalty=kwargs.get('presence_penalty', 0),
callback_manager=kwargs.get('callback_manager', None),
streaming=kwargs.get('streaming', False),
# request_timeout=None
**model_credentials
)
@classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
model_name = model.get("name")
completion_params = model.get("completion_params", {})
return cls.to_llm(
tenant_id=tenant_id,
model_name=model_name,
temperature=completion_params.get('temperature', 0),
max_tokens=completion_params.get('max_tokens', 256),
top_p=completion_params.get('top_p', 0),
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1),
streaming=streaming,
callback_manager=callback_manager
)
@classmethod
def get_mode_by_model(cls, model_name: str) -> str:
if not model_name:
raise ValueError("empty model name is not supported.")
if model_name in llm_constant.models_by_mode['chat']:
return "chat"
elif model_name in llm_constant.models_by_mode['completion']:
return "completion"
else:
raise ValueError(f"model name {model_name} is not supported.")
@classmethod
def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict:
"""
Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
Raises an exception if the model_name is not found or if the provider is not found.
"""
if not model_name:
raise Exception('model name not found')
if model_name not in llm_constant.models:
raise Exception('model {} not found'.format(model_name))
model_provider = llm_constant.models[model_name]
provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
return provider_service.get_credentials(model_name)
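A hypothetical call site for LLMBuilder.to_llm (tenant id is a placeholder, and it assumes 'gpt-3.5-turbo' is registered under the 'chat' mode in llm_constant):

from core.llm.llm_builder import LLMBuilder

llm = LLMBuilder.to_llm(
    tenant_id="tenant-id",
    model_name="gpt-3.5-turbo",  # must appear in llm_constant.models_by_mode['chat']
    temperature=0.7,
    max_tokens=512,
    streaming=True,
)
# returns a StreamableChatOpenAI bound to the tenant's provider credentials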

View File

@@ -0,0 +1,15 @@
import openai
from models.provider import ProviderName
class Moderation:
def __init__(self, provider: str, api_key: str):
self.provider = provider
self.api_key = api_key
if self.provider == ProviderName.OPENAI.value:
self.client = openai.Moderation
def moderate(self, text):
return self.client.create(input=text, api_key=self.api_key)

View File

@@ -0,0 +1,23 @@
from typing import Optional
from core.llm.provider.base import BaseProvider
from models.provider import ProviderName
class AnthropicProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id)
# todo
return []
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the API credentials for Anthropic as a dictionary, for the given tenant_id.
The dictionary contains the key: anthropic_api_key.
"""
return {
'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
}
def get_provider_name(self):
return ProviderName.ANTHROPIC

View File

@@ -0,0 +1,105 @@
import json
from typing import Optional, Union
import requests
from core.llm.provider.base import BaseProvider
from models.provider import ProviderName
class AzureProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id)
url = "{}/openai/deployments?api-version={}".format(
credentials.get('openai_api_base'),
credentials.get('openai_api_version')
)
headers = {
"api-key": credentials.get('openai_api_key'),
"content-type": "application/json; charset=utf-8"
}
response = requests.get(url, headers=headers)
if response.status_code == 200:
result = response.json()
return [{
'id': deployment['id'],
'name': '{} ({})'.format(deployment['id'], deployment['model'])
} for deployment in result['data'] if deployment['status'] == 'succeeded']
else:
# TODO: optimize in future
raise Exception('Failed to get deployments from Azure OpenAI. Status code: {}'.format(response.status_code))
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the API credentials for Azure OpenAI as a dictionary.
"""
encrypted_config = self.get_provider_api_key(model_id=model_id)
config = json.loads(encrypted_config)
config['openai_api_type'] = 'azure'
config['deployment_name'] = model_id
return config
def get_provider_name(self):
return ProviderName.AZURE_OPENAI
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key()
config = json.loads(config)
except:
config = {
'openai_api_type': 'azure',
'openai_api_version': '2023-03-15-preview',
'openai_api_base': 'https://foo.microsoft.com/bar',
'openai_api_key': ''
}
if obfuscated:
if not config.get('openai_api_key'):
config = {
'openai_api_type': 'azure',
'openai_api_version': '2023-03-15-preview',
'openai_api_base': 'https://foo.microsoft.com/bar',
'openai_api_key': ''
}
config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
return config
return config
def get_token_type(self):
# TODO: change to dict when implemented
return lambda value: value
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
# TODO: implement
pass
def get_encrypted_token(self, config: Union[dict | str]):
"""
Returns the encrypted token.
"""
return json.dumps({
'openai_api_type': 'azure',
'openai_api_version': '2023-03-15-preview',
'openai_api_base': config['openai_api_base'],
'openai_api_key': self.encrypt_token(config['openai_api_key'])
})
def get_decrypted_token(self, token: str):
"""
Returns the decrypted token.
"""
config = json.loads(token)
config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
return config

View File

@@ -0,0 +1,124 @@
import base64
from abc import ABC, abstractmethod
from typing import Optional, Union
from core import hosted_llm_credentials
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.provider import Provider, ProviderType, ProviderName
class BaseProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str:
"""
Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
If the provider is not found or not valid, raises a ProviderTokenNotInitError.
"""
provider = self.get_provider(prefer_custom)
if not provider:
raise ProviderTokenNotInitError()
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
if model_id and model_id == 'gpt-4':
raise ModelCurrentlyNotSupportError()
if quota_used >= quota_limit:
raise QuotaExceededError()
return self.get_hosted_credentials()
else:
return self.get_decrypted_token(provider.encrypted_config)
def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
"""
providers = db.session.query(Provider).filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.get_provider_name().value
).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
custom_provider = None
system_provider = None
for provider in providers:
if provider.provider_type == ProviderType.CUSTOM.value:
custom_provider = provider
elif provider.provider_type == ProviderType.SYSTEM.value:
system_provider = provider
if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config:
return custom_provider
elif system_provider and system_provider.is_valid:
return system_provider
else:
return None
def get_hosted_credentials(self) -> str:
if self.get_provider_name() != ProviderName.OPENAI:
raise ProviderTokenNotInitError()
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
raise ProviderTokenNotInitError()
return hosted_llm_credentials.openai.api_key
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key()
except:
config = 'THIS-IS-A-MOCK-TOKEN'
if obfuscated:
return self.obfuscated_token(config)
return config
def obfuscated_token(self, token: str):
return token[:6] + '*' * (len(token) - 8) + token[-2:]
def get_token_type(self):
return str
def get_encrypted_token(self, config: Union[dict | str]):
return self.encrypt_token(config)
def get_decrypted_token(self, token: str):
return self.decrypt_token(token)
def encrypt_token(self, token):
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(self, token):
return rsa.decrypt(base64.b64decode(token), self.tenant_id)
@abstractmethod
def get_provider_name(self):
raise NotImplementedError
@abstractmethod
def get_credentials(self, model_id: Optional[str] = None) -> dict:
raise NotImplementedError
@abstractmethod
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
raise NotImplementedError
@abstractmethod
def config_validate(self, config: str):
raise NotImplementedError
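A worked example of BaseProvider.obfuscated_token (the key below is made up): the first six and the last two characters are kept, everything in between is masked.

token = "sk-abcdefghijklmno"                              # 18 characters, fake key
masked = token[:6] + '*' * (len(token) - 8) + token[-2:]
print(masked)                                             # sk-abc**********no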

View File

@@ -0,0 +1,2 @@
class ValidateFailedError(Exception):
description = "Provider validation failed"

View File

@@ -0,0 +1,22 @@
from typing import Optional
from core.llm.provider.base import BaseProvider
from models.provider import ProviderName
class HuggingfaceProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id)
# todo
return []
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the API credentials for Huggingface as a dictionary, for the given tenant_id.
"""
return {
'huggingface_api_key': self.get_provider_api_key(model_id=model_id)
}
def get_provider_name(self):
return ProviderName.HUGGINGFACEHUB

View File

@@ -0,0 +1,53 @@
from typing import Optional, Union
from core.llm.provider.anthropic_provider import AnthropicProvider
from core.llm.provider.azure_provider import AzureProvider
from core.llm.provider.base import BaseProvider
from core.llm.provider.huggingface_provider import HuggingfaceProvider
from core.llm.provider.openai_provider import OpenAIProvider
from models.provider import Provider
class LLMProviderService:
def __init__(self, tenant_id: str, provider_name: str):
self.provider = self.init_provider(tenant_id, provider_name)
def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider:
if provider_name == 'openai':
return OpenAIProvider(tenant_id)
elif provider_name == 'azure_openai':
return AzureProvider(tenant_id)
elif provider_name == 'anthropic':
return AnthropicProvider(tenant_id)
elif provider_name == 'huggingface':
return HuggingfaceProvider(tenant_id)
else:
raise Exception('provider {} not found'.format(provider_name))
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
return self.provider.get_models(model_id)
def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.provider.get_credentials(model_id)
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
return self.provider.get_provider_configs(obfuscated)
def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]:
return self.provider.get_provider(prefer_custom)
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
:param config:
:raises: ValidateFailedError
"""
return self.provider.config_validate(config)
def get_token_type(self):
return self.provider.get_token_type()
def get_encrypted_token(self, config: Union[dict | str]):
return self.provider.get_encrypted_token(config)
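A hypothetical call site for the provider factory above (the tenant id is a placeholder); the provider name must be one of 'openai', 'azure_openai', 'anthropic' or 'huggingface'.

from core.llm.provider.llm_provider_service import LLMProviderService

service = LLMProviderService(tenant_id="tenant-id", provider_name="openai")
credentials = service.get_credentials()   # e.g. {'openai_api_key': '...'}
models = service.get_models()             # [{'id': ..., 'name': ...}, ...]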

View File

@@ -0,0 +1,44 @@
import logging
from typing import Optional, Union
import openai
from openai.error import AuthenticationError, OpenAIError
from core.llm.moderation import Moderation
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName
class OpenAIProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id)
response = openai.Model.list(**credentials)
return [{
'id': model['id'],
'name': model['id'],
} for model in response['data']]
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the credentials for the given tenant_id and provider_name.
"""
return {
'openai_api_key': self.get_provider_api_key(model_id=model_id)
}
def get_provider_name(self):
return ProviderName.OPENAI
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
try:
Moderation(self.get_provider_name().value, config).moderate('test')
except (AuthenticationError, OpenAIError) as ex:
raise ValidateFailedError(str(ex))
except Exception as ex:
logging.exception('OpenAI config validation failed')
raise ex

View File

@@ -0,0 +1,89 @@
import requests
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
class StreamableAzureChatOpenAI(AzureChatOpenAI):
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in a list of messages.
Args:
messages: The messages to count the tokens of.
Returns:
The number of tokens in the messages.
"""
tokens_per_message = 5
tokens_per_request = 3
message_tokens = tokens_per_request
message_strs = ''
for message in messages:
message_strs += message.content
message_tokens += tokens_per_message
# calc once
message_tokens += self.get_num_tokens(message_strs)
return message_tokens
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
verbose=self.verbose
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
verbose=self.verbose
)
chat_result = await super()._agenerate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(result, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
@handle_llm_exceptions
def generate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
) -> LLMResult:
return super().generate(messages, stop)
@handle_llm_exceptions_async
async def agenerate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
) -> LLMResult:
return await super().agenerate(messages, stop)
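The estimate in get_messages_tokens above can be illustrated with a small example (message contents are made up; the per-message constants are the rough values hard-coded above, not an exact OpenAI count):

from langchain.schema import AIMessage, HumanMessage

messages = [HumanMessage(content="Hi"), AIMessage(content="Hello!"), HumanMessage(content="Bye")]
# overhead = tokens_per_request + len(messages) * tokens_per_message = 3 + 3 * 5 = 18
# total    = overhead + llm.get_num_tokens("Hi" + "Hello!" + "Bye"), counted in a single pass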

View File

@@ -0,0 +1,86 @@
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import ChatOpenAI
from typing import Optional, List
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
class StreamableChatOpenAI(ChatOpenAI):
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in a list of messages.
Args:
messages: The messages to count the tokens of.
Returns:
The number of tokens in the messages.
"""
tokens_per_message = 5
tokens_per_request = 3
message_tokens = tokens_per_request
message_strs = ''
for message in messages:
message_strs += message.content
message_tokens += tokens_per_message
# calc once
message_tokens += self.get_num_tokens(message_strs)
return message_tokens
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
)
chat_result = await super()._agenerate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(result, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
@handle_llm_exceptions
def generate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
) -> LLMResult:
return super().generate(messages, stop)
@handle_llm_exceptions_async
async def agenerate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
) -> LLMResult:
return await super().agenerate(messages, stop)

View File

@@ -0,0 +1,20 @@
from langchain.schema import LLMResult
from typing import Optional, List
from langchain import OpenAI
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
class StreamableOpenAI(OpenAI):
@handle_llm_exceptions
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
return super().generate(prompts, stop)
@handle_llm_exceptions_async
async def agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
return await super().agenerate(prompts, stop)

View File

@@ -0,0 +1,41 @@
import decimal
from typing import Optional
import tiktoken
from core.constant import llm_constant
class TokenCalculator:
@classmethod
def get_num_tokens(cls, model_name: str, text: str):
if len(text) == 0:
return 0
enc = tiktoken.encoding_for_model(model_name)
tokenized_text = enc.encode(text)
# calculate the number of tokens in the encoded text
return len(tokenized_text)
@classmethod
def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal:
if model_name in llm_constant.models_by_mode['embedding']:
unit_price = llm_constant.model_prices[model_name]['usage']
elif text_type == 'prompt':
unit_price = llm_constant.model_prices[model_name]['prompt']
elif text_type == 'completion':
unit_price = llm_constant.model_prices[model_name]['completion']
else:
raise Exception('Invalid text type')
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
@classmethod
def get_currency(cls, model_name: str):
return llm_constant.model_currency
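A worked example of the price arithmetic in get_token_price (the 0.002-per-1K unit price is hypothetical; real unit prices come from llm_constant.model_prices):

import decimal

tokens = 1500
unit_price = decimal.Decimal('0.002')   # hypothetical prompt price per 1K tokens
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(
    decimal.Decimal('0.001'), rounding=decimal.ROUND_HALF_UP)               # 1.500
total_price = tokens_per_1k * unit_price                                    # 0.003000
print(total_price.quantize(decimal.Decimal('0.0000001'),
                           rounding=decimal.ROUND_HALF_UP))                 # 0.0030000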

View File

@@ -0,0 +1,77 @@
from typing import Any, List, Dict, Union
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from extensions.ext_database import db
from models.model import Conversation, Message
class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
conversation: Conversation
human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: Union[StreamableChatOpenAI | StreamableOpenAI]
memory_key: str = "chat_history"
max_token_limit: int = 2000
message_limit: int = 10
@property
def buffer(self) -> List[BaseMessage]:
"""Message buffer of memory, loaded from the conversation's stored messages."""
# fetch limited messages desc, and return reversed
messages = db.session.query(Message).filter(
Message.conversation_id == self.conversation.id,
Message.answer_tokens > 0
).order_by(Message.created_at.desc()).limit(self.message_limit).all()
messages = list(reversed(messages))
chat_messages: List[BaseMessage] = []
for message in messages:
chat_messages.append(HumanMessage(content=message.query))
chat_messages.append(AIMessage(content=message.answer))
if not chat_messages:
return chat_messages
# prune the chat message if it exceeds the max token limit
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit and chat_messages:
pruned_memory.append(chat_messages.pop(0))
curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
return chat_messages
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
buffer: Any = self.buffer
if self.return_messages:
final_buffer: Any = buffer
else:
final_buffer = get_buffer_string(
buffer,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
return {self.memory_key: final_buffer}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Nothing should be saved or changed"""
pass
def clear(self) -> None:
"""Nothing to clear, got a memory like a vault."""
pass

View File

@@ -0,0 +1,36 @@
from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
class ReadOnlyConversationTokenDBStringBufferSharedMemory(BaseChatMemory):
memory: ReadOnlyConversationTokenDBBufferSharedMemory
@property
def memory_variables(self) -> List[str]:
"""Return memory variables."""
return self.memory.memory_variables
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Load memory variables from memory."""
buffer: Any = self.memory.buffer
final_buffer = get_buffer_string(
buffer,
human_prefix=self.memory.human_prefix,
ai_prefix=self.memory.ai_prefix,
)
return {self.memory.memory_key: final_buffer}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Nothing should be saved or changed"""
pass
def clear(self) -> None:
"""Nothing to clear, got a memory like a vault."""
pass

View File

@@ -0,0 +1,16 @@
import json
from typing import Any
from langchain.schema import BaseOutputParser
from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
def get_format_instructions(self) -> str:
return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
def parse(self, text: str) -> Any:
json_string = text.strip()
json_obj = json.loads(json_string)
return json_obj

View File

@@ -0,0 +1,37 @@
import re
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
from langchain.schema import BaseMessage
from core.prompt.prompt_template import OutLinePromptTemplate
class PromptBuilder:
@classmethod
def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = OutLinePromptTemplate.from_template(prompt_content)
system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
system_message = system_prompt_template.format(**prompt_inputs)
return system_message
@classmethod
def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = OutLinePromptTemplate.from_template(prompt_content)
ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
ai_message = ai_prompt_template.format(**prompt_inputs)
return ai_message
@classmethod
def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = OutLinePromptTemplate.from_template(prompt_content)
human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
human_message = human_prompt_template.format(**inputs)
return human_message
@classmethod
def process_template(cls, template: str):
processed_template = re.sub(r'\{(.+?)\}', r'\1', template)
processed_template = re.sub(r'\{\{(.+?)\}\}', r'{\1}', processed_template)
return processed_template
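A worked example of process_template (the template text is made up): single-braced text is flattened to plain text, while double-braced fields end up as real {placeholders} that OutLinePromptTemplate can pick up.

import re

template = "Hello {name}, context: {{context}}"
processed = re.sub(r'\{(.+?)\}', r'\1', template)
processed = re.sub(r'\{\{(.+?)\}\}', r'{\1}', processed)
print(processed)  # Hello name, context: {context}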

View File

@@ -0,0 +1,37 @@
import re
from typing import Any
from langchain import PromptTemplate
from langchain.formatting import StrictFormatter
class OutLinePromptTemplate(PromptTemplate):
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
input_variables = {
v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
class OneLineFormatter(StrictFormatter):
def parse(self, format_string):
last_end = 0
results = []
for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string):
field_name = match.group(1)
start, end = match.span()
literal_text = format_string[last_end:start]
last_end = end
results.append((literal_text, field_name, '', None))
remaining_literal_text = format_string[last_end:]
if remaining_literal_text:
results.append((remaining_literal_text, None, None, None))
return results
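A short usage sketch of the template class above (the template string is made up): only {identifier}-style fields are collected as input variables.

from core.prompt.prompt_template import OutLinePromptTemplate

prompt = OutLinePromptTemplate.from_template("Hi {name}, welcome to {place}!")
print(prompt.input_variables)                    # ['name', 'place']
print(prompt.format(name="Ada", place="Dify"))   # Hi Ada, welcome to Dify!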

View File

@@ -0,0 +1,63 @@
from llama_index import QueryKeywordExtractPrompt
CONVERSATION_TITLE_PROMPT = (
"Human:{query}\n-----\n"
"Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
"If what the human said is in Chinese, you should return a Chinese title.\n"
"If what the human said is in English, you should return an English title.\n"
"title:"
)
CONVERSATION_SUMMARY_PROMPT = (
"Please generate a short summary of the following conversation.\n"
"If the conversation is in Chinese, you should return a Chinese summary.\n"
"If the conversation is in English, you should return an English summary.\n"
"[Conversation Start]\n"
"{context}\n"
"[Conversation End]\n\n"
"summary:"
)
INTRODUCTION_GENERATE_PROMPT = (
"I am designing a product for users to interact with an AI through dialogue. "
"The Prompt given to the AI before the conversation is:\n\n"
"```\n{prompt}\n```\n\n"
"Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. "
"Do not reveal the developer's motivation or deep logic behind the Prompt, "
"but focus on building a relationship with the user:\n"
)
MORE_LIKE_THIS_GENERATE_PROMPT = (
"-----\n"
"{original_completion}\n"
"-----\n\n"
"Please use the above content as a sample for generating the result, "
"and include key information points related to the original sample in the result. "
"Try to rephrase this information in different ways and predict according to the rules below.\n\n"
"-----\n"
"{prompt}\n"
)
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that the human would ask, "
"keeping each question under 20 characters.\n"
"The output must be in JSON format following the specified schema:\n"
"[\"question1\",\"question2\",\"question3\"]\n"
)
QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
"A question is provided below. Given the question, extract up to {max_keywords} "
"keywords from the text. Focus on extracting the keywords that we can use "
"to best look up answers to the question. Avoid stopwords. "
"I am not sure which language the following question is in. "
"If the user asked the question in Chinese, please return the keywords in Chinese. "
"If the user asked the question in English, please return the keywords in English.\n"
"---------------------\n"
"{question}\n"
"---------------------\n"
"Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
)
QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
)

View File

@@ -0,0 +1,83 @@
from typing import Optional
from langchain.callbacks import CallbackManager
from llama_index.langchain_helpers.agents import IndexToolConfig
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from extensions.ext_database import db
from models.dataset import Dataset
class DatasetToolBuilder:
@classmethod
def build_dataset_tool(cls, tenant_id: str, dataset_id: str,
response_mode: str = "no_synthesizer",
callback_handler: Optional[DatasetToolCallbackHandler] = None):
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
return None
if dataset.indexing_technique == "economy":
# use keyword table query
index = KeywordTableIndex(dataset=dataset).query_index
if not index:
return None
query_kwargs = {
"mode": "default",
"response_mode": response_mode,
"query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE,
"max_keywords_per_query": 5,
# If num_chunks_per_query is too large,
# it will slow down the synthesis process due to multiple iterations of refinement.
"num_chunks_per_query": 2
}
else:
index = VectorIndex(dataset=dataset).query_index
if not index:
return None
query_kwargs = {
"mode": "default",
"response_mode": response_mode,
# If top_k is too large,
# it will slow down the synthesis process due to multiple iterations of refinement.
"similarity_top_k": 2
}
# fulfill description when it is empty
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
index_tool_config = IndexToolConfig(
index=index,
name=f"dataset-{dataset_id}",
description=description,
index_query_kwargs=query_kwargs,
tool_kwargs={
"callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()])
},
# tool_kwargs={"return_direct": True},
# return_direct: Whether to return LLM results directly or process the output data with an Output Parser
)
index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset_id)
return EnhanceLlamaIndexTool.from_tool_config(
tool_config=index_tool_config,
callback_handler=index_callback_handler
)
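A hypothetical call site for build_dataset_tool (ids are placeholders, dataset_tool_callback_handler stands for a DatasetToolCallbackHandler built elsewhere in the request flow, and the import path assumes the builder lives at core.tool.dataset_tool_builder); the builder returns None when the dataset or its index does not exist.

from core.tool.dataset_tool_builder import DatasetToolBuilder

tool = DatasetToolBuilder.build_dataset_tool(
    tenant_id="tenant-id",
    dataset_id="dataset-id",
    response_mode="no_synthesizer",
    callback_handler=dataset_tool_callback_handler,
)
if tool:
    answer = tool.run("What does the handbook say about refunds?")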

View File

@@ -0,0 +1,43 @@
from typing import Dict
from langchain.tools import BaseTool
from llama_index.indices.base import BaseGPTIndex
from llama_index.langchain_helpers.agents import IndexToolConfig
from pydantic import Field
from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler
class EnhanceLlamaIndexTool(BaseTool):
"""Tool for querying a LlamaIndex."""
# NOTE: name/description still needs to be set
index: BaseGPTIndex
query_kwargs: Dict = Field(default_factory=dict)
return_sources: bool = False
callback_handler: IndexToolCallbackHandler
@classmethod
def from_tool_config(cls, tool_config: IndexToolConfig,
callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
"""Create a tool from a tool config."""
return_sources = tool_config.tool_kwargs.pop("return_sources", False)
return cls(
index=tool_config.index,
callback_handler=callback_handler,
name=tool_config.name,
description=tool_config.description,
return_sources=return_sources,
query_kwargs=tool_config.index_query_kwargs,
**tool_config.tool_kwargs,
)
def _run(self, tool_input: str) -> str:
response = self.index.query(tool_input, **self.query_kwargs)
self.callback_handler.on_tool_end(response)
return str(response)
async def _arun(self, tool_input: str) -> str:
response = await self.index.aquery(tool_input, **self.query_kwargs)
self.callback_handler.on_tool_end(response)
return str(response)

View File

@@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
from typing import Optional
from llama_index import ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs import Node
from llama_index.vector_stores.types import VectorStore
class BaseVectorStoreClient(ABC):
@abstractmethod
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
raise NotImplementedError
@abstractmethod
def to_index_config(self, index_id: str) -> dict:
raise NotImplementedError
class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
def delete_node(self, node_id: str):
self._vector_store.delete_node(node_id)
def exists_by_node_id(self, node_id: str) -> bool:
return self._vector_store.exists_by_node_id(node_id)
class EnhanceVectorStore(ABC):
@abstractmethod
def delete_node(self, node_id: str):
pass
@abstractmethod
def exists_by_node_id(self, node_id: str) -> bool:
pass

View File

@@ -0,0 +1,147 @@
import os
from typing import cast, List
from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
from qdrant_client.http.models import Payload, Filter
import qdrant_client
from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
from llama_index.vector_stores import QdrantVectorStore
from qdrant_client.local.qdrant_local import QdrantLocal
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
class QdrantVectorStoreClient(BaseVectorStoreClient):
def __init__(self, url: str, api_key: str, root_path: str):
self._client = self.init_from_config(url, api_key, root_path)
@classmethod
def init_from_config(cls, url: str, api_key: str, root_path: str):
if url and url.startswith('path:'):
path = url.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(root_path, path)
return qdrant_client.QdrantClient(
path=path
)
else:
return qdrant_client.QdrantClient(
url=url,
api_key=api_key,
)
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
index_struct = QdrantIndexDict()
if self._client is None:
raise Exception("Vector client is not initialized.")
# {"collection_name": "Gpt_index_xxx"}
collection_name = config.get('collection_name')
if not collection_name:
raise Exception("collection_name cannot be None.")
return GPTQdrantEnhanceIndex(
service_context=service_context,
index_struct=index_struct,
vector_store=QdrantEnhanceVectorStore(
client=self._client,
collection_name=collection_name
)
)
def to_index_config(self, index_id: str) -> dict:
return {"collection_name": index_id}
class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
pass
class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
def delete_node(self, node_id: str):
"""
Delete node from the index.
:param node_id: node id
"""
from qdrant_client.http import models as rest
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=rest.Filter(
must=[
rest.FieldCondition(
key="id", match=rest.MatchValue(value=node_id)
)
]
),
)
def exists_by_node_id(self, node_id: str) -> bool:
"""
Get node from the index by node id.
:param node_id: node id
"""
self._reload_if_needed()
response = self._client.retrieve(
collection_name=self._collection_name,
ids=[node_id]
)
return len(response) > 0
def query(
self,
query: VectorStoreQuery,
) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
Args:
query (VectorStoreQuery): query
"""
query_embedding = cast(List[float], query.query_embedding)
self._reload_if_needed()
response = self._client.search(
collection_name=self._collection_name,
query_vector=query_embedding,
limit=cast(int, query.similarity_top_k),
query_filter=cast(Filter, self._build_query_filter(query)),
with_vectors=True
)
nodes = []
similarities = []
ids = []
for point in response:
payload = cast(Payload, point.payload)
node = Node(
doc_id=str(point.id),
text=payload.get("text"),
embedding=point.vector,
extra_info=payload.get("extra_info"),
relationships={
DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
},
)
nodes.append(node)
similarities.append(point.score)
ids.append(str(point.id))
return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
def _reload_if_needed(self):
if isinstance(self._client._client, QdrantLocal):
self._client._client._load()
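The two ways init_from_config interprets the Qdrant URL, as a sketch (all values are placeholders):

from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient

# Local, file-based Qdrant: the path is resolved relative to the Flask app root.
local_client = QdrantVectorStoreClient(url="path:storage/qdrant", api_key="", root_path="/app/api")

# Remote Qdrant server: the URL is used as-is together with the API key.
remote_client = QdrantVectorStoreClient(url="https://qdrant.example.com", api_key="qdrant-api-key",
                                        root_path="/app/api")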

View File

@@ -0,0 +1,61 @@
from flask import Flask
from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient
SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']
class VectorStore:
def __init__(self):
self._vector_store = None
self._client = None
def init_app(self, app: Flask):
if not app.config['VECTOR_STORE']:
return
self._vector_store = app.config['VECTOR_STORE']
if self._vector_store not in SUPPORTED_VECTOR_STORES:
raise ValueError(f"Vector store {self._vector_store} is not supported.")
if self._vector_store == 'weaviate':
self._client = WeaviateVectorStoreClient(
endpoint=app.config['WEAVIATE_ENDPOINT'],
api_key=app.config['WEAVIATE_API_KEY'],
grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED']
)
elif self._vector_store == 'qdrant':
self._client = QdrantVectorStoreClient(
url=app.config['QDRANT_URL'],
api_key=app.config['QDRANT_API_KEY'],
root_path=app.root_path
)
app.extensions['vector_store'] = self
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
vector_store_config: dict = index_struct.get('vector_store')
index = self.get_client().get_index(
service_context=service_context,
config=vector_store_config
)
return index
def to_index_struct(self, index_id: str) -> dict:
return {
"type": self._vector_store,
"vector_store": self.get_client().to_index_config(index_id)
}
def get_client(self):
if not self._client:
raise Exception("Vector store client is not initialized.")
return self._client
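A minimal sketch of how the extension is wired and what to_index_struct returns (config values are placeholders; it assumes a local, path-based Qdrant as described above):

from flask import Flask

from core.vector_store.vector_store import VectorStore

app = Flask(__name__)
app.config.update(VECTOR_STORE='qdrant', QDRANT_URL='path:storage/qdrant', QDRANT_API_KEY='')

vector_store = VectorStore()
vector_store.init_app(app)
print(vector_store.to_index_struct("Gpt_index_xxx"))
# {'type': 'qdrant', 'vector_store': {'collection_name': 'Gpt_index_xxx'}}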

View File

@@ -0,0 +1,66 @@
from llama_index.indices.query.base import IS
from typing import (
Any,
Dict,
List,
Optional
)
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
BaseNodePostprocessor,
)
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from core.index.query.synthesizer import EnhanceResponseSynthesizer
class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
@classmethod
def from_args(
cls,
index_struct: IS,
service_context: ServiceContext,
docstore: Optional[BaseDocumentStore] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
verbose: bool = False,
# response synthesizer args
response_mode: ResponseMode = ResponseMode.DEFAULT,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_kwargs: Optional[Dict] = None,
use_async: bool = False,
streaming: bool = False,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
# class-specific args
**kwargs: Any,
) -> "BaseGPTIndexQuery":
response_synthesizer = EnhanceResponseSynthesizer.from_args(
service_context=service_context,
text_qa_template=text_qa_template,
refine_template=refine_template,
simple_template=simple_template,
response_mode=response_mode,
response_kwargs=response_kwargs,
use_async=use_async,
streaming=streaming,
optimizer=optimizer,
)
return cls(
index_struct=index_struct,
service_context=service_context,
response_synthesizer=response_synthesizer,
docstore=docstore,
node_postprocessors=node_postprocessors,
verbose=verbose,
**kwargs,
)

View File

@@ -0,0 +1,258 @@
import json
import weaviate
from dataclasses import field
from typing import List, Any, Dict, Optional
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
from llama_index.vector_stores import WeaviateVectorStore
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
from llama_index.readers.weaviate.utils import (
parse_get_response,
validate_client,
)
class WeaviateVectorStoreClient(BaseVectorStoreClient):
def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool):
self._client = self.init_from_config(endpoint, api_key, grpc_enabled)
def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool):
auth_config = weaviate.auth.AuthApiKey(api_key=api_key)
weaviate.connect.connection.has_grpc = grpc_enabled
return weaviate.Client(
url=endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 15),
startup_period=None
)
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
index_struct = WeaviateIndexDict()
if self._client is None:
raise Exception("Vector client is not initialized.")
# {"class_prefix": "Gpt_index_xxx"}
class_prefix = config.get('class_prefix')
if not class_prefix:
raise Exception("class_prefix cannot be None.")
return GPTWeaviateEnhanceIndex(
service_context=service_context,
index_struct=index_struct,
vector_store=WeaviateWithSimilaritiesVectorStore(
weaviate_client=self._client,
class_prefix=class_prefix
)
)
def to_index_config(self, index_id: str) -> dict:
return {"class_prefix": index_id}
class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes."""
nodes = self.weaviate_query(
self._client,
self._class_prefix,
query,
)
nodes = nodes[: query.similarity_top_k]
node_idxs = [str(i) for i in range(len(nodes))]
similarities = []
for node in nodes:
similarities.append(node.extra_info['similarity'])
del node.extra_info['similarity']
return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)
def weaviate_query(
self,
client: Any,
class_prefix: str,
query_spec: VectorStoreQuery,
) -> List[Node]:
"""Convert to LlamaIndex list."""
validate_client(client)
class_name = _class_name(class_prefix)
prop_names = [p["name"] for p in NODE_SCHEMA]
vector = query_spec.query_embedding
# build query
query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
if query_spec.mode == VectorStoreQueryMode.DEFAULT:
_logger.debug("Using vector search")
if vector is not None:
query = query.with_near_vector(
{
"vector": vector,
}
)
elif query_spec.mode == VectorStoreQueryMode.HYBRID:
_logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
query = query.with_hybrid(
query=query_spec.query_str,
alpha=query_spec.alpha,
vector=vector,
)
query = query.with_limit(query_spec.similarity_top_k)
_logger.debug(f"Using limit of {query_spec.similarity_top_k}")
# execute query
query_result = query.do()
# parse results
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
results = [self._to_node(entry) for entry in entries]
return results
def _to_node(self, entry: Dict) -> Node:
"""Convert to Node."""
extra_info_str = entry["extra_info"]
if extra_info_str == "":
extra_info = None
else:
extra_info = json.loads(extra_info_str)
if 'certainty' in entry['_additional']:
if extra_info:
extra_info['similarity'] = entry['_additional']['certainty']
else:
extra_info = {'similarity': entry['_additional']['certainty']}
node_info_str = entry["node_info"]
if node_info_str == "":
node_info = None
else:
node_info = json.loads(node_info_str)
relationships_str = entry["relationships"]
relationships: Dict[DocumentRelationship, str]
if relationships_str == "":
relationships = {}
else:
relationships = {
DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
}
return Node(
text=entry["text"],
doc_id=entry["doc_id"],
embedding=entry["_additional"]["vector"],
extra_info=extra_info,
node_info=node_info,
relationships=relationships,
)
def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
"""Delete a document.
Args:
doc_id (str): document id
"""
delete_document(self._client, doc_id, self._class_prefix)
def delete_node(self, node_id: str):
"""
Delete node from the index.
:param node_id: node id
"""
delete_node(self._client, node_id, self._class_prefix)
def exists_by_node_id(self, node_id: str) -> bool:
"""
Get node from the index by node id.
:param node_id: node id
"""
entry = get_by_node_id(self._client, node_id, self._class_prefix)
return True if entry else False
class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
pass
def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
"""Delete entry."""
validate_client(client)
# find every entry whose ref_doc_id matches and delete it
class_name = _class_name(class_prefix)
where_filter = {
"path": ["ref_doc_id"],
"operator": "Equal",
"valueString": ref_doc_id,
}
query = (
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
)
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
for entry in entries:
client.data_object.delete(entry["_additional"]["id"], class_name)
while len(entries) > 0:
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
for entry in entries:
client.data_object.delete(entry["_additional"]["id"], class_name)
def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
"""Delete entry."""
validate_client(client)
# find every entry whose doc_id matches and delete it
class_name = _class_name(class_prefix)
where_filter = {
"path": ["doc_id"],
"operator": "Equal",
"valueString": node_id,
}
query = (
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
)
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
for entry in entries:
client.data_object.delete(entry["_additional"]["id"], class_name)
def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
"""Get entry by node id."""
validate_client(client)
# look up the entry whose doc_id matches
class_name = _class_name(class_prefix)
where_filter = {
"path": ["doc_id"],
"operator": "Equal",
"valueString": node_id,
}
query = (
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
)
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
if len(entries) == 0:
return None
return entries[0]