feat: advanced prompt backend (#1301)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
Garfield Dai
2023-10-12 23:13:10 +08:00
committed by GitHub
parent 2d1cb076c6
commit 42a5b3ec17
61 changed files with 767 additions and 581 deletions

View File

@@ -1,79 +1,39 @@
import re
from typing import Any
from jinja2 import Environment, meta
from langchain import PromptTemplate
from langchain.formatting import StrictFormatter
REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{1,29}|#histories#|#query#|#context#)\}\}")
class JinjaPromptTemplate(PromptTemplate):
template_format: str = "jinja2"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
class PromptTemplateParser:
"""
Rules:
1. Template variables must be enclosed in `{{}}`.
2. The template variable Key can only be: letters + numbers + underscore, with a maximum length of 16 characters,
and can only start with letters and underscores.
3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
4. In addition to the above, 3 types of special template variable Keys are accepted:
`{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed.
"""
def __init__(self, template: str):
self.template = template
self.variable_keys = self.extract()
def extract(self) -> list:
# Regular expression to match the template rules
return re.findall(REGEX, self.template)
def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
def replacer(match):
key = match.group(1)
value = inputs.get(key, match.group(0)) # return original matched string if key not found
if remove_template_variables:
return PromptTemplateParser.remove_template_variables(value)
return value
return re.sub(REGEX, replacer, self.template)
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
env = Environment()
template = template.replace("{{}}", "{}")
ast = env.parse(template)
input_variables = meta.find_undeclared_variables(ast)
if "partial_variables" in kwargs:
partial_variables = kwargs["partial_variables"]
input_variables = {
var for var in input_variables if var not in partial_variables
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
class OutLinePromptTemplate(PromptTemplate):
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
input_variables = {
v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
return OneLineFormatter().format(self.template, **kwargs)
class OneLineFormatter(StrictFormatter):
def parse(self, format_string):
last_end = 0
results = []
for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string):
field_name = match.group(1)
start, end = match.span()
literal_text = format_string[last_end:start]
last_end = end
results.append((literal_text, field_name, '', None))
remaining_literal_text = format_string[last_end:]
if remaining_literal_text:
results.append((remaining_literal_text, None, None, None))
return results
def remove_template_variables(cls, text: str):
return re.sub(REGEX, r'{\1}', text)