Feat/assistant app (#2086)

Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: Pascal M <11357019+perzeuss@users.noreply.github.com>
Yeuoly
2024-01-23 19:58:23 +08:00
committed by GitHub
parent 7bbe12b2bd
commit 86286e1ac8
175 changed files with 11619 additions and 1235 deletions

View File

@@ -0,0 +1,77 @@
from typing import Dict, Any
from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController
from core.helper import encrypter
class ToolConfiguration(BaseModel):
tenant_id: str
provider_controller: ToolProviderController
    def _deep_copy(self, credentials: Dict[str, str]) -> Dict[str, str]:
        """
        deep copy credentials
        """
        return {key: value for key, value in credentials.items()}
    def encrypt_tool_credentials(self, credentials: Dict[str, str]) -> Dict[str, str]:
        """
        encrypt tool credentials with tenant id

        return a deep copy of credentials with encrypted values
        """
        credentials = self._deep_copy(credentials)

        # get fields that need to be encrypted
        fields = self.provider_controller.get_credentails_schema()
for field_name, field in fields.items():
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
if field_name in credentials:
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
credentials[field_name] = encrypted
return credentials
def mask_tool_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
credentials = self._deep_copy(credentials)
        # get fields that need to be masked
fields = self.provider_controller.get_credentails_schema()
for field_name, field in fields.items():
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
if field_name in credentials:
if len(credentials[field_name]) > 6:
credentials[field_name] = \
credentials[field_name][:2] + \
'*' * (len(credentials[field_name]) - 4) +\
credentials[field_name][-2:]
else:
credentials[field_name] = '*' * len(credentials[field_name])
return credentials
def decrypt_tool_credentials(self, credentials: Dict[str, str]) -> Dict[str, str]:
"""
        decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
credentials = self._deep_copy(credentials)
        # get fields that need to be decrypted
fields = self.provider_controller.get_credentails_schema()
for field_name, field in fields.items():
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
if field_name in credentials:
                    try:
                        credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
                    except Exception:
                        # leave the value as-is if it cannot be decrypted
                        pass
return credentials
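Taken together, these three methods give a full round trip for secret handling. A minimal usage sketch (assuming a `ToolProviderController` whose credentials schema marks `api_key` as a SECRET_INPUT field; all names here are illustrative):

    config = ToolConfiguration(tenant_id='tenant-1', provider_controller=controller)
    encrypted = config.encrypt_tool_credentials({'api_key': 'sk-abc123'})  # safe to persist
    masked = config.mask_tool_credentials({'api_key': 'sk-abc123'})        # 'sk*****23', safe to display
    plain = config.decrypt_tool_credentials(encrypted)                     # {'api_key': 'sk-abc123'}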

View File

@@ -0,0 +1,21 @@
from pydantic import BaseModel
from typing import List

def serialize_base_model_array(l: List[BaseModel]) -> str:
    """
    serialize a list of BaseModel instances to JSON, shaped as
    {"__root__": [BaseModel, BaseModel, ...]}
    """
    class _BaseModel(BaseModel):
        __root__: List[BaseModel]

    return _BaseModel(__root__=l).json()

def serialize_base_model_dict(b: dict) -> str:
    """
    serialize a dict to JSON, shaped as
    {"__root__": {BaseModel}}
    """
    class _BaseModel(BaseModel):
        __root__: dict

    return _BaseModel(__root__=b).json()
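For illustration, a hedged sketch of the array helper (hypothetical `Point` model; these helpers rely on pydantic v1 `__root__` custom root types):

    class Point(BaseModel):
        x: int
        y: int

    json_str = serialize_base_model_array([Point(x=1, y=2), Point(x=3, y=4)])
    # json_str is the JSON rendering of the two points via the __root__ wrapper model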

View File

@@ -0,0 +1,341 @@
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import ToolParamter, ToolParamterOption, ApiProviderSchemaType
from core.tools.entities.common_entities import I18nObject
from core.tools.errors import ToolProviderNotFoundError, ToolNotSupportedError, \
ToolApiSchemaError
from typing import List, Tuple
from yaml import FullLoader, load
from json import loads as json_loads, dumps as json_dumps
from requests import get
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
        # set description to extra_info
        extra_info['description'] = openapi['info'].get('description', '')
if len(openapi['servers']) == 0:
raise ToolProviderNotFoundError('No server found in the openapi yaml.')
server_url = openapi['servers'][0]['url']
# list all interfaces
interfaces = []
for path, path_item in openapi['paths'].items():
methods = ['get', 'post', 'put', 'delete', 'patch', 'head', 'options', 'trace']
for method in methods:
if method in path_item:
interfaces.append({
'path': path,
'method': method,
'operation': path_item[method],
})
# get all parameters
bundles = []
for interface in interfaces:
# convert parameters
parameters = []
if 'parameters' in interface['operation']:
for parameter in interface['operation']['parameters']:
parameters.append(ToolParamter(
name=parameter['name'],
label=I18nObject(
en_US=parameter['name'],
zh_Hans=parameter['name']
),
human_description=I18nObject(
en_US=parameter.get('description', ''),
zh_Hans=parameter.get('description', '')
),
type=ToolParamter.ToolParameterType.STRING,
required=parameter.get('required', False),
form=ToolParamter.ToolParameterForm.LLM,
llm_description=parameter.get('description'),
default=parameter['default'] if 'default' in parameter else None,
))
# create tool bundle
# check if there is a request body
if 'requestBody' in interface['operation']:
request_body = interface['operation']['requestBody']
if 'content' in request_body:
for content_type, content in request_body['content'].items():
                        # skip content types that do not declare a schema
                        if 'schema' not in content:
                            continue

                        # if there is a reference, get the reference and overwrite the content
if '$ref' in content['schema']:
# get the reference
root = openapi
reference = content['schema']['$ref'].split('/')[1:]
for ref in reference:
root = root[ref]
# overwrite the content
interface['operation']['requestBody']['content'][content_type]['schema'] = root
# parse body parameters
if 'schema' in interface['operation']['requestBody']['content'][content_type]:
body_schema = interface['operation']['requestBody']['content'][content_type]['schema']
required = body_schema['required'] if 'required' in body_schema else []
properties = body_schema['properties'] if 'properties' in body_schema else {}
for name, property in properties.items():
parameters.append(ToolParamter(
name=name,
label=I18nObject(
en_US=name,
zh_Hans=name
),
human_description=I18nObject(
en_US=property['description'] if 'description' in property else '',
zh_Hans=property['description'] if 'description' in property else ''
),
type=ToolParamter.ToolParameterType.STRING,
required=name in required,
form=ToolParamter.ToolParameterForm.LLM,
llm_description=property['description'] if 'description' in property else '',
default=property['default'] if 'default' in property else None,
))
            # check if any parameter is duplicated
parameters_count = {}
for parameter in parameters:
if parameter.name not in parameters_count:
parameters_count[parameter.name] = 0
parameters_count[parameter.name] += 1
for name, count in parameters_count.items():
if count > 1:
warning['duplicated_parameter'] = f'Parameter {name} is duplicated.'
bundles.append(ApiBasedToolBundle(
server_url=server_url + interface['path'],
method=interface['method'],
summary=interface['operation']['summary'] if 'summary' in interface['operation'] else None,
operation_id=interface['operation']['operationId'],
parameters=parameters,
author='',
icon=None,
openapi=interface['operation'],
))
return bundles
@staticmethod
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
"""
parse openapi yaml to tool bundle
:param yaml: the yaml string
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
openapi: dict = load(yaml, Loader=FullLoader)
if openapi is None:
raise ToolApiSchemaError('Invalid openapi yaml.')
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_openapi_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
"""
        parse openapi json to tool bundle

        :param json: the json string
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
openapi: dict = json_loads(json)
if openapi is None:
raise ToolApiSchemaError('Invalid openapi json.')
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict:
"""
parse swagger to openapi
:param swagger: the swagger dict
:return: the openapi dict
"""
# convert swagger to openapi
info = swagger.get('info', {
'title': 'Swagger',
'description': 'Swagger',
'version': '1.0.0'
})
servers = swagger.get('servers', [])
if len(servers) == 0:
raise ToolApiSchemaError('No server found in the swagger yaml.')
openapi = {
'openapi': '3.0.0',
'info': {
'title': info.get('title', 'Swagger'),
'description': info.get('description', 'Swagger'),
'version': info.get('version', '1.0.0')
},
'servers': swagger['servers'],
'paths': {},
'components': {
'schemas': {}
}
}
# check paths
if 'paths' not in swagger or len(swagger['paths']) == 0:
raise ToolApiSchemaError('No paths found in the swagger yaml.')
# convert paths
for path, path_item in swagger['paths'].items():
openapi['paths'][path] = {}
for method, operation in path_item.items():
if 'operationId' not in operation:
raise ToolApiSchemaError(f'No operationId found in operation {method} {path}.')
if 'summary' not in operation or len(operation['summary']) == 0:
warning['missing_summary'] = f'No summary found in operation {method} {path}.'
if 'description' not in operation or len(operation['description']) == 0:
warning['missing_description'] = f'No description found in operation {method} {path}.'
openapi['paths'][path][method] = {
'operationId': operation['operationId'],
'summary': operation.get('summary', ''),
'description': operation.get('description', ''),
'parameters': operation.get('parameters', []),
'responses': operation.get('responses', {}),
}
if 'requestBody' in operation:
openapi['paths'][path][method]['requestBody'] = operation['requestBody']
        # convert definitions
        for name, definition in swagger.get('definitions', {}).items():
            openapi['components']['schemas'][name] = definition
return openapi
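    # Note: this converter expects a 'servers' list on the input document
    # (plain Swagger 2.0 uses 'host'/'basePath' instead) and relocates
    # 'definitions' to 'components.schemas' as OpenAPI 3 requires; each
    # operation keeps its operationId, summary, description, parameters,
    # responses and requestBody.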
@staticmethod
def parse_swagger_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
"""
parse swagger yaml to tool bundle
:param yaml: the yaml string
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
swagger: dict = load(yaml, Loader=FullLoader)
openapi = ApiBasedToolSchemaParser.parse_swagger_to_openapi(swagger, extra_info=extra_info, warning=warning)
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
"""
        parse swagger json to tool bundle

        :param json: the json string
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
swagger: dict = json_loads(json)
openapi = ApiBasedToolSchemaParser.parse_swagger_to_openapi(swagger, extra_info=extra_info, warning=warning)
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
"""
        parse openai plugin json to tool bundle
:param json: the json string
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
        try:
            openai_plugin = json_loads(json)
            api = openai_plugin['api']
            api_url = api['url']
            api_type = api['type']
        except Exception:
            raise ToolProviderNotFoundError('Invalid openai plugin json.')
if api_type != 'openapi':
raise ToolNotSupportedError('Only openapi is supported now.')
# get openapi yaml
response = get(api_url, headers={
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) '
}, timeout=5)
if response.status_code != 200:
raise ToolProviderNotFoundError('cannot get openapi yaml from url.')
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
@staticmethod
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> Tuple[List[ApiBasedToolBundle], str]:
"""
auto parse to tool bundle
:param content: the content
:return: tools bundle, schema_type
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
json_possible = False
content = content.strip()
if content.startswith('{') and content.endswith('}'):
json_possible = True
        if json_possible:
            try:
                return ApiBasedToolSchemaParser.parse_openapi_json_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
                    ApiProviderSchemaType.OPENAPI.value
            except Exception:
                pass

            try:
                return ApiBasedToolSchemaParser.parse_swagger_json_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
                    ApiProviderSchemaType.SWAGGER.value
            except Exception:
                pass

            try:
                return ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
                    ApiProviderSchemaType.OPENAI_PLUGIN.value
            except Exception:
                pass
        else:
            try:
                return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
                    ApiProviderSchemaType.OPENAPI.value
            except Exception:
                pass

            try:
                return ApiBasedToolSchemaParser.parse_swagger_yaml_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
                    ApiProviderSchemaType.SWAGGER.value
            except Exception:
                pass

        raise ToolApiSchemaError('Invalid api schema.')
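A minimal sketch of the auto-detection entry point (hypothetical one-operation OpenAPI YAML; any valid OpenAPI/Swagger/plugin document is handled the same way):

    SPEC = '''
    openapi: 3.0.0
    info:
      title: Echo
      description: Echo service
      version: 1.0.0
    servers:
      - url: https://example.com
    paths:
      /echo:
        get:
          operationId: echo
          summary: Echo a message
          parameters:
            - name: message
              in: query
              required: true
              description: the message to echo
    '''

    bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(SPEC)
    # schema_type is ApiProviderSchemaType.OPENAPI.value
    # bundles[0].server_url == 'https://example.com/echo'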

View File

@@ -0,0 +1,446 @@
import hashlib
import json
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from typing import Type, Any
import requests
from bs4 import BeautifulSoup, NavigableString, Comment, CData
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex
from core.chain.llm_chain import LLMChain
from core.data_loader import file_extractor
from core.data_loader.file_extractor import FileExtractor
from core.entities.application_entities import ModelConfigEntity
FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:
{text}
"""
class WebReaderToolInput(BaseModel):
url: str = Field(..., description="URL of the website to read")
    summary: bool = Field(
        default=False,
        description="Set this to true when the user's question requires a summary "
                    "of the webpage's content."
    )
cursor: int = Field(
default=0,
description="Start reading from this character."
"Use when the first response was truncated"
"and you want to continue reading the page."
"The value cannot exceed 24000.",
)
class WebReaderTool(BaseTool):
"""Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
name: str = "web_reader"
args_schema: Type[BaseModel] = WebReaderToolInput
description: str = "use this to read a website. " \
"If you can answer the question based on the information provided, " \
"there is no need to use."
page_contents: str = None
url: str = None
max_chunk_length: int = 4000
summary_chunk_tokens: int = 4000
summary_chunk_overlap: int = 0
    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
continue_reading: bool = True
model_config: ModelConfigEntity
model_parameters: dict[str, Any]
def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
try:
if not self.page_contents or self.url != url:
page_contents = get_url(url)
self.page_contents = page_contents
self.url = url
else:
page_contents = self.page_contents
        except Exception as e:
            return f'Reading this website failed, caused by: {str(e)}.'
if summary:
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=self.summary_chunk_tokens,
chunk_overlap=self.summary_chunk_overlap,
separators=self.summary_separators
)
texts = character_splitter.split_text(page_contents)
docs = [Document(page_content=t) for t in texts]
if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
return "No content found."
# only use first 5 docs
if len(docs) > 5:
docs = docs[:5]
chain = self.get_summary_chain()
try:
page_contents = chain.run(docs)
            except Exception as e:
                return f'Reading this website failed, caused by: {str(e)}.'
else:
page_contents = page_result(page_contents, cursor, self.max_chunk_length)
if self.continue_reading and len(page_contents) >= self.max_chunk_length:
page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
return page_contents
async def _arun(self, url: str) -> str:
raise NotImplementedError
def get_summary_chain(self) -> RefineDocumentsChain:
initial_chain = LLMChain(
model_config=self.model_config,
prompt=refine_prompts.PROMPT,
parameters=self.model_parameters
)
refine_chain = LLMChain(
model_config=self.model_config,
prompt=refine_prompts.REFINE_PROMPT,
parameters=self.model_parameters
)
return RefineDocumentsChain(
initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain,
document_variable_name="text",
initial_response_name="existing_answer",
callbacks=self.callbacks
)
def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
return text[cursor: cursor + max_length]
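# e.g. page_result("abcdefgh", cursor=2, max_length=3) == "cde"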
def get_url(url: str, user_agent: str = None) -> str:
"""Fetch URL and return the contents as a string."""
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
if user_agent:
headers["User-Agent"] = user_agent
supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
if head_response.status_code != 200:
return "URL returned status code {}.".format(head_response.status_code)
# check content-type
    main_content_type = head_response.headers.get('Content-Type', '').split(';')[0].strip()
if main_content_type not in supported_content_types:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES:
return FileExtractor.load_from_url(url, return_text=True)
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
a = extract_using_readabilipy(response.text)
if not a['plain_text'] or not a['plain_text'].strip():
return get_url_from_newspaper3k(url)
res = FULL_TEMPLATE.format(
title=a['title'],
authors=a['byline'],
publish_date=a['date'],
top_image="",
text=a['plain_text'] if a['plain_text'] else "",
)
return res
def get_url_from_newspaper3k(url: str) -> str:
a = Article(url)
a.download()
a.parse()
res = FULL_TEMPLATE.format(
title=a.title,
authors=a.authors,
publish_date=a.publish_date,
top_image=a.top_image,
text=a.text,
)
return res
def extract_using_readabilipy(html):
with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
f_html.write(html)
f_html.close()
html_path = f_html.name
# Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
article_json_path = html_path + ".json"
jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
with chdir(jsdir):
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
# Read output of call to Readability.parse() from JSON file and return as Python dictionary
with open(article_json_path, "r", encoding="utf-8") as json_file:
input_json = json.loads(json_file.read())
# Deleting files after processing
os.unlink(article_json_path)
os.unlink(html_path)
article_json = {
"title": None,
"byline": None,
"date": None,
"content": None,
"plain_content": None,
"plain_text": None
}
# Populate article fields from readability fields where present
if input_json:
if "title" in input_json and input_json["title"]:
article_json["title"] = input_json["title"]
if "byline" in input_json and input_json["byline"]:
article_json["byline"] = input_json["byline"]
if "date" in input_json and input_json["date"]:
article_json["date"] = input_json["date"]
if "content" in input_json and input_json["content"]:
article_json["content"] = input_json["content"]
article_json["plain_content"] = plain_content(article_json["content"], False, False)
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
if "textContent" in input_json and input_json["textContent"]:
article_json["plain_text"] = input_json["textContent"]
article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])
return article_json
def find_module_path(module_name):
for package_path in site.getsitepackages():
potential_path = os.path.join(package_path, module_name)
if os.path.exists(potential_path):
return potential_path
return None
@contextmanager
def chdir(path):
"""Change directory in context and return to original on exit"""
# From https://stackoverflow.com/a/37996581, couldn't find a built-in
original_path = os.getcwd()
os.chdir(path)
try:
yield
finally:
os.chdir(original_path)
def extract_text_blocks_as_plain_text(paragraph_html):
# Load article as DOM
soup = BeautifulSoup(paragraph_html, 'html.parser')
# Select all lists
list_elements = soup.find_all(['ul', 'ol'])
# Prefix text in all list items with "* " and make lists paragraphs
for list_element in list_elements:
plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
list_element.string = plain_items
list_element.name = "p"
# Select all text blocks
text_blocks = [s.parent for s in soup.find_all(string=True)]
text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
# Drop empty paragraphs
text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
return text_blocks
def plain_text_leaf_node(element):
# Extract all text, stripped of any child HTML elements and normalise it
plain_text = normalise_text(element.get_text())
if plain_text != "" and element.name == "li":
plain_text = "* {}, ".format(plain_text)
if plain_text == "":
plain_text = None
if "data-node-index" in element.attrs:
plain = {"node_index": element["data-node-index"], "text": plain_text}
else:
plain = {"text": plain_text}
return plain
def plain_content(readability_content, content_digests, node_indexes):
# Load article as DOM
soup = BeautifulSoup(readability_content, 'html.parser')
# Make all elements plain
elements = plain_elements(soup.contents, content_digests, node_indexes)
if node_indexes:
# Add node index attributes to nodes
elements = [add_node_indexes(element) for element in elements]
# Replace article contents with plain elements
soup.contents = elements
return str(soup)
def plain_elements(elements, content_digests, node_indexes):
# Get plain content versions of all elements
elements = [plain_element(element, content_digests, node_indexes)
for element in elements]
if content_digests:
# Add content digest attribute to nodes
elements = [add_content_digest(element) for element in elements]
return elements
def plain_element(element, content_digests, node_indexes):
# For lists, we make each item plain text
if is_leaf(element):
# For leaf node elements, extract the text content, discarding any HTML tags
# 1. Get element contents as text
plain_text = element.get_text()
# 2. Normalise the extracted text string to a canonical representation
plain_text = normalise_text(plain_text)
# 3. Update element content to be plain text
element.string = plain_text
elif is_text(element):
if is_non_printing(element):
# The simplified HTML may have come from Readability.js so might
# have non-printing text (e.g. Comment or CData). In this case, we
# keep the structure, but ensure that the string is empty.
element = type(element)("")
else:
plain_text = element.string
plain_text = normalise_text(plain_text)
element = type(element)(plain_text)
else:
# If not a leaf node or leaf type call recursively on child nodes, replacing
element.contents = plain_elements(element.contents, content_digests, node_indexes)
return element
def add_node_indexes(element, node_index="0"):
# Can't add attributes to string types
if is_text(element):
return element
# Add index to current element
element["data-node-index"] = node_index
# Add index to child elements
for local_idx, child in enumerate(
[c for c in element.contents if not is_text(c)], start=1):
# Can't add attributes to leaf string types
child_index = "{stem}.{local}".format(
stem=node_index, local=local_idx)
add_node_indexes(child, node_index=child_index)
return element
def normalise_text(text):
"""Normalise unicode and whitespace."""
# Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
text = strip_control_characters(text)
text = normalise_unicode(text)
text = normalise_whitespace(text)
return text
def strip_control_characters(text):
"""Strip out unicode control characters which might break the parsing."""
# Unicode control characters
# [Cc]: Other, Control [includes new lines]
# [Cf]: Other, Format
# [Cn]: Other, Not Assigned
# [Co]: Other, Private Use
# [Cs]: Other, Surrogate
control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
retained_chars = ['\t', '\n', '\r', '\f']
# Remove non-printing control characters
return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])
def normalise_unicode(text):
"""Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
normal_form = "NFKC"
text = unicodedata.normalize(normal_form, text)
return text
def normalise_whitespace(text):
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
text = regex.sub(r"\s+", " ", text)
# Remove leading and trailing whitespace
text = text.strip()
return text
def is_leaf(element):
return (element.name in ['p', 'li'])
def is_text(element):
return isinstance(element, NavigableString)
def is_non_printing(element):
return any(isinstance(element, _e) for _e in [Comment, CData])
def add_content_digest(element):
if not is_text(element):
element["data-content-digest"] = content_digest(element)
return element
def content_digest(element):
if is_text(element):
# Hash
trimmed_string = element.string.strip()
if trimmed_string == "":
digest = ""
else:
digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
else:
contents = element.contents
num_contents = len(contents)
if num_contents == 0:
# No hash when no child elements exist
digest = ""
elif num_contents == 1:
# If single child, use digest of child
digest = content_digest(contents[0])
else:
# Build content digest from the "non-empty" digests of child nodes
digest = hashlib.sha256()
child_digests = list(
filter(lambda x: x != "", [content_digest(content) for content in contents]))
for child in child_digests:
digest.update(child.encode('utf-8'))
digest = digest.hexdigest()
return digest
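End to end, `WebReaderTool._run` ties these helpers together: fetch with `get_url`, fall back to newspaper3k when Readability.js yields no text, then page through or summarize the result. A hedged usage sketch (the `model_config` and `model_parameters` values come from the surrounding app and are only stubs here):

    tool = WebReaderTool(model_config=app_model_config, model_parameters={})
    print(tool._run('https://example.com', summary=False, cursor=0))
    # prints the FULL_TEMPLATE rendering (title, authors, date, text), truncated
    # to max_chunk_length characters, with a continuation hint when the page is longer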