chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -22,18 +22,22 @@ class AdvancedPromptTransform(PromptTransform):
"""
Advanced Prompt Transform for Workflow LLM Node.
"""
def __init__(self, with_variable_tmpl: bool = False) -> None:
self.with_variable_tmpl = with_variable_tmpl
def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
inputs: dict,
query: str,
files: list[FileVar],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
def get_prompt(
self,
prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
inputs: dict,
query: str,
files: list[FileVar],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None,
) -> list[PromptMessage]:
inputs = {key: str(value) for key, value in inputs.items()}
prompt_messages = []
@@ -48,7 +52,7 @@ class AdvancedPromptTransform(PromptTransform):
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config
model_config=model_config,
)
elif model_mode == ModelMode.CHAT:
prompt_messages = self._get_chat_model_prompt_messages(
@@ -60,20 +64,22 @@ class AdvancedPromptTransform(PromptTransform):
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config
model_config=model_config,
)
return prompt_messages
def _get_completion_model_prompt_messages(self,
prompt_template: CompletionModelPromptTemplate,
inputs: dict,
query: Optional[str],
files: list[FileVar],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
def _get_completion_model_prompt_messages(
self,
prompt_template: CompletionModelPromptTemplate,
inputs: dict,
query: Optional[str],
files: list[FileVar],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> list[PromptMessage]:
"""
Get completion model prompt messages.
"""
@@ -81,7 +87,7 @@ class AdvancedPromptTransform(PromptTransform):
prompt_messages = []
if prompt_template.edition_type == 'basic' or not prompt_template.edition_type:
if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
@@ -96,15 +102,13 @@ class AdvancedPromptTransform(PromptTransform):
role_prefix=role_prefix,
prompt_template=prompt_template,
prompt_inputs=prompt_inputs,
model_config=model_config
model_config=model_config,
)
if query:
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
prompt = prompt_template.format(
prompt_inputs
)
prompt = prompt_template.format(prompt_inputs)
else:
prompt = raw_prompt
prompt_inputs = inputs
@@ -122,16 +126,18 @@ class AdvancedPromptTransform(PromptTransform):
return prompt_messages
def _get_chat_model_prompt_messages(self,
prompt_template: list[ChatModelMessage],
inputs: dict,
query: Optional[str],
files: list[FileVar],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
def _get_chat_model_prompt_messages(
self,
prompt_template: list[ChatModelMessage],
inputs: dict,
query: Optional[str],
files: list[FileVar],
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None,
) -> list[PromptMessage]:
"""
Get chat model prompt messages.
"""
@@ -142,22 +148,20 @@ class AdvancedPromptTransform(PromptTransform):
for prompt_item in raw_prompt_list:
raw_prompt = prompt_item.text
if prompt_item.edition_type == 'basic' or not prompt_item.edition_type:
if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
prompt = prompt_template.format(
prompt_inputs
)
elif prompt_item.edition_type == 'jinja2':
prompt = prompt_template.format(prompt_inputs)
elif prompt_item.edition_type == "jinja2":
prompt = raw_prompt
prompt_inputs = inputs
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
else:
raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')
raise ValueError(f"Invalid edition type: {prompt_item.edition_type}")
if prompt_item.role == PromptMessageRole.USER:
prompt_messages.append(UserPromptMessage(content=prompt))
@@ -168,17 +172,14 @@ class AdvancedPromptTransform(PromptTransform):
if query and query_prompt_template:
prompt_template = PromptTemplateParser(
template=query_prompt_template,
with_variable_tmpl=self.with_variable_tmpl
template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
prompt_inputs['#sys.query#'] = query
prompt_inputs["#sys.query#"] = query
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
query = prompt_template.format(
prompt_inputs
)
query = prompt_template.format(prompt_inputs)
if memory and memory_config:
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
@@ -203,7 +204,7 @@ class AdvancedPromptTransform(PromptTransform):
last_message.content = prompt_message_contents
else:
prompt_message_contents = [TextPromptMessageContent(data='')] # not for query
prompt_message_contents = [TextPromptMessageContent(data="")] # not for query
for file in files:
prompt_message_contents.append(file.prompt_message_content)
@@ -220,38 +221,39 @@ class AdvancedPromptTransform(PromptTransform):
return prompt_messages
def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
if '#context#' in prompt_template.variable_keys:
if "#context#" in prompt_template.variable_keys:
if context:
prompt_inputs['#context#'] = context
prompt_inputs["#context#"] = context
else:
prompt_inputs['#context#'] = ''
prompt_inputs["#context#"] = ""
return prompt_inputs
def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
if '#query#' in prompt_template.variable_keys:
if "#query#" in prompt_template.variable_keys:
if query:
prompt_inputs['#query#'] = query
prompt_inputs["#query#"] = query
else:
prompt_inputs['#query#'] = ''
prompt_inputs["#query#"] = ""
return prompt_inputs
def _set_histories_variable(self, memory: TokenBufferMemory,
memory_config: MemoryConfig,
raw_prompt: str,
role_prefix: MemoryConfig.RolePrefix,
prompt_template: PromptTemplateParser,
prompt_inputs: dict,
model_config: ModelConfigWithCredentialsEntity) -> dict:
if '#histories#' in prompt_template.variable_keys:
def _set_histories_variable(
self,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
raw_prompt: str,
role_prefix: MemoryConfig.RolePrefix,
prompt_template: PromptTemplateParser,
prompt_inputs: dict,
model_config: ModelConfigWithCredentialsEntity,
) -> dict:
if "#histories#" in prompt_template.variable_keys:
if memory:
inputs = {'#histories#': '', **prompt_inputs}
inputs = {"#histories#": "", **prompt_inputs}
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
tmp_human_message = UserPromptMessage(
content=prompt_template.format(prompt_inputs)
)
tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs))
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
@@ -260,10 +262,10 @@ class AdvancedPromptTransform(PromptTransform):
memory_config=memory_config,
max_token_limit=rest_tokens,
human_prefix=role_prefix.user,
ai_prefix=role_prefix.assistant
ai_prefix=role_prefix.assistant,
)
prompt_inputs['#histories#'] = histories
prompt_inputs["#histories#"] = histories
else:
prompt_inputs['#histories#'] = ''
prompt_inputs["#histories#"] = ""
return prompt_inputs

View File

@@ -17,12 +17,14 @@ class AgentHistoryPromptTransform(PromptTransform):
"""
History Prompt Transform for Agent App
"""
def __init__(self,
model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage],
history_messages: list[PromptMessage],
memory: Optional[TokenBufferMemory] = None,
):
def __init__(
self,
model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage],
history_messages: list[PromptMessage],
memory: Optional[TokenBufferMemory] = None,
):
self.model_config = model_config
self.prompt_messages = prompt_messages
self.history_messages = history_messages
@@ -45,9 +47,7 @@ class AgentHistoryPromptTransform(PromptTransform):
model_type_instance = cast(LargeLanguageModel, model_type_instance)
curr_message_tokens = model_type_instance.get_num_tokens(
self.memory.model_instance.model,
self.memory.model_instance.credentials,
self.history_messages
self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages
)
if curr_message_tokens <= max_token_limit:
return self.history_messages
@@ -63,9 +63,7 @@ class AgentHistoryPromptTransform(PromptTransform):
# a message is start with UserPromptMessage
if isinstance(prompt_message, UserPromptMessage):
curr_message_tokens = model_type_instance.get_num_tokens(
self.memory.model_instance.model,
self.memory.model_instance.credentials,
prompt_messages
self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages
)
# if current message token is overflow, drop all the prompts in current message and break
if curr_message_tokens > max_token_limit:

View File

@@ -9,27 +9,31 @@ class ChatModelMessage(BaseModel):
"""
Chat Message.
"""
text: str
role: PromptMessageRole
edition_type: Optional[Literal['basic', 'jinja2']] = None
edition_type: Optional[Literal["basic", "jinja2"]] = None
class CompletionModelPromptTemplate(BaseModel):
"""
Completion Model Prompt Template.
"""
text: str
edition_type: Optional[Literal['basic', 'jinja2']] = None
edition_type: Optional[Literal["basic", "jinja2"]] = None
class MemoryConfig(BaseModel):
"""
Memory Config.
"""
class RolePrefix(BaseModel):
"""
Role Prefix.
"""
user: str
assistant: str
@@ -37,6 +41,7 @@ class MemoryConfig(BaseModel):
"""
Window Config.
"""
enabled: bool
size: Optional[int] = None

View File

@@ -7,39 +7,18 @@ CHAT_APP_COMPLETION_PROMPT_CONFIG = {
"prompt": {
"text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "
},
"conversation_histories_role": {
"user_prefix": "Human",
"assistant_prefix": "Assistant"
}
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
},
"stop": ["Human:"]
"stop": ["Human:"],
}
CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "system",
"text": "{{#pre_prompt#}}"
}]
}
}
CHAT_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}}
COMPLETION_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "user",
"text": "{{#pre_prompt#}}"
}]
}
}
COMPLETION_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}}
COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}"
}
},
"stop": ["Human:"]
"completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
"stop": ["Human:"],
}
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
@@ -47,37 +26,20 @@ BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
"prompt": {
"text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"
},
"conversation_histories_role": {
"user_prefix": "用户",
"assistant_prefix": "助手"
}
"conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"},
},
"stop": ["用户:"]
"stop": ["用户:"],
}
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "system",
"text": "{{#pre_prompt#}}"
}]
}
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}
}
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "user",
"text": "{{#pre_prompt#}}"
}]
}
"chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}
}
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}"
}
},
"stop": ["用户:"]
"completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
"stop": ["用户:"],
}

View File

@@ -9,75 +9,78 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
class PromptTransform:
def _append_chat_histories(self, memory: TokenBufferMemory,
memory_config: MemoryConfig,
prompt_messages: list[PromptMessage],
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
def _append_chat_histories(
self,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
prompt_messages: list[PromptMessage],
model_config: ModelConfigWithCredentialsEntity,
) -> list[PromptMessage]:
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
prompt_messages.extend(histories)
return prompt_messages
def _calculate_rest_token(self, prompt_messages: list[PromptMessage],
model_config: ModelConfigWithCredentialsEntity) -> int:
def _calculate_rest_token(
self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _get_history_messages_from_memory(self, memory: TokenBufferMemory,
memory_config: MemoryConfig,
max_token_limit: int,
human_prefix: Optional[str] = None,
ai_prefix: Optional[str] = None) -> str:
def _get_history_messages_from_memory(
self,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
max_token_limit: int,
human_prefix: Optional[str] = None,
ai_prefix: Optional[str] = None,
) -> str:
"""Get memory messages."""
kwargs = {
"max_token_limit": max_token_limit
}
kwargs = {"max_token_limit": max_token_limit}
if human_prefix:
kwargs['human_prefix'] = human_prefix
kwargs["human_prefix"] = human_prefix
if ai_prefix:
kwargs['ai_prefix'] = ai_prefix
kwargs["ai_prefix"] = ai_prefix
if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
kwargs['message_limit'] = memory_config.window.size
kwargs["message_limit"] = memory_config.window.size
return memory.get_history_prompt_text(
**kwargs
)
return memory.get_history_prompt_text(**kwargs)
def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory,
memory_config: MemoryConfig,
max_token_limit: int) -> list[PromptMessage]:
def _get_history_messages_list_from_memory(
self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
) -> list[PromptMessage]:
"""Get memory messages."""
return memory.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=memory_config.window.size
if (memory_config.window.enabled
and memory_config.window.size is not None
and memory_config.window.size > 0)
else None
if (
memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0
)
else None,
)

View File

@@ -22,11 +22,11 @@ if TYPE_CHECKING:
class ModelMode(enum.Enum):
COMPLETION = 'completion'
CHAT = 'chat'
COMPLETION = "completion"
CHAT = "chat"
@classmethod
def value_of(cls, value: str) -> 'ModelMode':
def value_of(cls, value: str) -> "ModelMode":
"""
Get value of given mode.
@@ -36,7 +36,7 @@ class ModelMode(enum.Enum):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
raise ValueError(f"invalid mode value {value}")
prompt_file_contents = {}
@@ -47,16 +47,17 @@ class SimplePromptTransform(PromptTransform):
Simple Prompt Transform for Chatbot App Basic Mode.
"""
def get_prompt(self,
app_mode: AppMode,
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
query: str,
files: list["FileVar"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) -> \
tuple[list[PromptMessage], Optional[list[str]]]:
def get_prompt(
self,
app_mode: AppMode,
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
query: str,
files: list["FileVar"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
inputs = {key: str(value) for key, value in inputs.items()}
model_mode = ModelMode.value_of(model_config.mode)
@@ -69,7 +70,7 @@ class SimplePromptTransform(PromptTransform):
files=files,
context=context,
memory=memory,
model_config=model_config
model_config=model_config,
)
else:
prompt_messages, stops = self._get_completion_model_prompt_messages(
@@ -80,19 +81,21 @@ class SimplePromptTransform(PromptTransform):
files=files,
context=context,
memory=memory,
model_config=model_config
model_config=model_config,
)
return prompt_messages, stops
def get_prompt_str_and_rules(self, app_mode: AppMode,
model_config: ModelConfigWithCredentialsEntity,
pre_prompt: str,
inputs: dict,
query: Optional[str] = None,
context: Optional[str] = None,
histories: Optional[str] = None,
) -> tuple[str, dict]:
def get_prompt_str_and_rules(
self,
app_mode: AppMode,
model_config: ModelConfigWithCredentialsEntity,
pre_prompt: str,
inputs: dict,
query: Optional[str] = None,
context: Optional[str] = None,
histories: Optional[str] = None,
) -> tuple[str, dict]:
# get prompt template
prompt_template_config = self.get_prompt_template(
app_mode=app_mode,
@@ -101,74 +104,75 @@ class SimplePromptTransform(PromptTransform):
pre_prompt=pre_prompt,
has_context=context is not None,
query_in_prompt=query is not None,
with_memory_prompt=histories is not None
with_memory_prompt=histories is not None,
)
variables = {k: inputs[k] for k in prompt_template_config['custom_variable_keys'] if k in inputs}
variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs}
for v in prompt_template_config['special_variable_keys']:
for v in prompt_template_config["special_variable_keys"]:
# support #context#, #query# and #histories#
if v == '#context#':
variables['#context#'] = context if context else ''
elif v == '#query#':
variables['#query#'] = query if query else ''
elif v == '#histories#':
variables['#histories#'] = histories if histories else ''
if v == "#context#":
variables["#context#"] = context if context else ""
elif v == "#query#":
variables["#query#"] = query if query else ""
elif v == "#histories#":
variables["#histories#"] = histories if histories else ""
prompt_template = prompt_template_config['prompt_template']
prompt_template = prompt_template_config["prompt_template"]
prompt = prompt_template.format(variables)
return prompt, prompt_template_config['prompt_rules']
return prompt, prompt_template_config["prompt_rules"]
def get_prompt_template(self, app_mode: AppMode,
provider: str,
model: str,
pre_prompt: str,
has_context: bool,
query_in_prompt: bool,
with_memory_prompt: bool = False) -> dict:
prompt_rules = self._get_prompt_rule(
app_mode=app_mode,
provider=provider,
model=model
)
def get_prompt_template(
self,
app_mode: AppMode,
provider: str,
model: str,
pre_prompt: str,
has_context: bool,
query_in_prompt: bool,
with_memory_prompt: bool = False,
) -> dict:
prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)
custom_variable_keys = []
special_variable_keys = []
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt' and has_context:
prompt += prompt_rules['context_prompt']
special_variable_keys.append('#context#')
elif order == 'pre_prompt' and pre_prompt:
prompt += pre_prompt + '\n'
prompt = ""
for order in prompt_rules["system_prompt_orders"]:
if order == "context_prompt" and has_context:
prompt += prompt_rules["context_prompt"]
special_variable_keys.append("#context#")
elif order == "pre_prompt" and pre_prompt:
prompt += pre_prompt + "\n"
pre_prompt_template = PromptTemplateParser(template=pre_prompt)
custom_variable_keys = pre_prompt_template.variable_keys
elif order == 'histories_prompt' and with_memory_prompt:
prompt += prompt_rules['histories_prompt']
special_variable_keys.append('#histories#')
elif order == "histories_prompt" and with_memory_prompt:
prompt += prompt_rules["histories_prompt"]
special_variable_keys.append("#histories#")
if query_in_prompt:
prompt += prompt_rules.get('query_prompt', '{{#query#}}')
special_variable_keys.append('#query#')
prompt += prompt_rules.get("query_prompt", "{{#query#}}")
special_variable_keys.append("#query#")
return {
"prompt_template": PromptTemplateParser(template=prompt),
"custom_variable_keys": custom_variable_keys,
"special_variable_keys": special_variable_keys,
"prompt_rules": prompt_rules
"prompt_rules": prompt_rules,
}
def _get_chat_model_prompt_messages(self, app_mode: AppMode,
pre_prompt: str,
inputs: dict,
query: str,
context: Optional[str],
files: list["FileVar"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
def _get_chat_model_prompt_messages(
self,
app_mode: AppMode,
pre_prompt: str,
inputs: dict,
query: str,
context: Optional[str],
files: list["FileVar"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
prompt_messages = []
# get prompt
@@ -178,7 +182,7 @@ class SimplePromptTransform(PromptTransform):
pre_prompt=pre_prompt,
inputs=inputs,
query=None,
context=context
context=context,
)
if prompt and query:
@@ -193,7 +197,7 @@ class SimplePromptTransform(PromptTransform):
)
),
prompt_messages=prompt_messages,
model_config=model_config
model_config=model_config,
)
if query:
@@ -203,15 +207,17 @@ class SimplePromptTransform(PromptTransform):
return prompt_messages, None
def _get_completion_model_prompt_messages(self, app_mode: AppMode,
pre_prompt: str,
inputs: dict,
query: str,
context: Optional[str],
files: list["FileVar"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
def _get_completion_model_prompt_messages(
self,
app_mode: AppMode,
pre_prompt: str,
inputs: dict,
query: str,
context: Optional[str],
files: list["FileVar"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
# get prompt
prompt, prompt_rules = self.get_prompt_str_and_rules(
app_mode=app_mode,
@@ -219,13 +225,11 @@ class SimplePromptTransform(PromptTransform):
pre_prompt=pre_prompt,
inputs=inputs,
query=query,
context=context
context=context,
)
if memory:
tmp_human_message = UserPromptMessage(
content=prompt
)
tmp_human_message = UserPromptMessage(content=prompt)
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
histories = self._get_history_messages_from_memory(
@@ -236,8 +240,8 @@ class SimplePromptTransform(PromptTransform):
)
),
max_token_limit=rest_tokens,
human_prefix=prompt_rules.get('human_prefix', 'Human'),
ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant')
human_prefix=prompt_rules.get("human_prefix", "Human"),
ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
)
# get prompt
@@ -248,10 +252,10 @@ class SimplePromptTransform(PromptTransform):
inputs=inputs,
query=query,
context=context,
histories=histories
histories=histories,
)
stops = prompt_rules.get('stops')
stops = prompt_rules.get("stops")
if stops is not None and len(stops) == 0:
stops = None
@@ -277,22 +281,18 @@ class SimplePromptTransform(PromptTransform):
:param model: model name
:return:
"""
prompt_file_name = self._prompt_file_name(
app_mode=app_mode,
provider=provider,
model=model
)
prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model)
# Check if the prompt file is already loaded
if prompt_file_name in prompt_file_contents:
return prompt_file_contents[prompt_file_name]
# Get the absolute path of the subdirectory
prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates')
json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')
prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json")
# Open the JSON file and read its content
with open(json_file_path, encoding='utf-8') as json_file:
with open(json_file_path, encoding="utf-8") as json_file:
content = json.load(json_file)
# Store the content of the prompt file
@@ -303,21 +303,21 @@ class SimplePromptTransform(PromptTransform):
def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
# baichuan
is_baichuan = False
if provider == 'baichuan':
if provider == "baichuan":
is_baichuan = True
else:
baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
if provider in baichuan_supported_providers and 'baichuan' in model.lower():
if provider in baichuan_supported_providers and "baichuan" in model.lower():
is_baichuan = True
if is_baichuan:
if app_mode == AppMode.COMPLETION:
return 'baichuan_completion'
return "baichuan_completion"
else:
return 'baichuan_chat'
return "baichuan_chat"
# common
if app_mode == AppMode.COMPLETION:
return 'common_completion'
return "common_completion"
else:
return 'common_chat'
return "common_chat"

View File

@@ -25,26 +25,29 @@ class PromptMessageUtil:
tool_calls = []
for prompt_message in prompt_messages:
if prompt_message.role == PromptMessageRole.USER:
role = 'user'
role = "user"
elif prompt_message.role == PromptMessageRole.ASSISTANT:
role = 'assistant'
role = "assistant"
if isinstance(prompt_message, AssistantPromptMessage):
tool_calls = [{
'id': tool_call.id,
'type': 'function',
'function': {
'name': tool_call.function.name,
'arguments': tool_call.function.arguments,
tool_calls = [
{
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
}
} for tool_call in prompt_message.tool_calls]
for tool_call in prompt_message.tool_calls
]
elif prompt_message.role == PromptMessageRole.SYSTEM:
role = 'system'
role = "system"
elif prompt_message.role == PromptMessageRole.TOOL:
role = 'tool'
role = "tool"
else:
continue
text = ''
text = ""
files = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
@@ -53,27 +56,25 @@ class PromptMessageUtil:
text += content.data
else:
content = cast(ImagePromptMessageContent, content)
files.append({
"type": 'image',
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
"detail": content.detail.value
})
files.append(
{
"type": "image",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"detail": content.detail.value,
}
)
else:
text = prompt_message.content
prompt = {
"role": role,
"text": text,
"files": files
}
prompt = {"role": role, "text": text, "files": files}
if tool_calls:
prompt['tool_calls'] = tool_calls
prompt["tool_calls"] = tool_calls
prompts.append(prompt)
else:
prompt_message = prompt_messages[0]
text = ''
text = ""
files = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
@@ -82,21 +83,23 @@ class PromptMessageUtil:
text += content.data
else:
content = cast(ImagePromptMessageContent, content)
files.append({
"type": 'image',
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
"detail": content.detail.value
})
files.append(
{
"type": "image",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"detail": content.detail.value,
}
)
else:
text = prompt_message.content
params = {
"role": 'user',
"role": "user",
"text": text,
}
if files:
params['files'] = files
params["files"] = files
prompts.append(params)

View File

@@ -38,8 +38,8 @@ class PromptTemplateParser:
return value
prompt = re.sub(self.regex, replacer, self.template)
return re.sub(r'<\|.*?\|>', '', prompt)
return re.sub(r"<\|.*?\|>", "", prompt)
@classmethod
def remove_template_variables(cls, text: str, with_variable_tmpl: bool = False):
return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r'{\1}', text)
return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r"{\1}", text)