Mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git, synced 2025-12-24 18:23:07 +08:00
Feat/assistant app (#2086)
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: Pascal M <11357019+perzeuss@users.noreply.github.com>
api/core/tools/utils/configration.py (new file, 77 lines)
@@ -0,0 +1,77 @@
from typing import Dict, Any

from pydantic import BaseModel

from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController
from core.helper import encrypter

class ToolConfiguration(BaseModel):
    tenant_id: str
    provider_controller: ToolProviderController

    def _deep_copy(self, credentails: Dict[str, str]) -> Dict[str, str]:
        """
        deep copy credentials
        """
        return {key: value for key, value in credentails.items()}

    def encrypt_tool_credentials(self, credentails: Dict[str, str]) -> Dict[str, str]:
        """
        encrypt tool credentials with tenant id

        return a deep copy of credentials with encrypted values
        """
        credentials = self._deep_copy(credentails)

        # get fields that need to be encrypted
        fields = self.provider_controller.get_credentails_schema()
        for field_name, field in fields.items():
            if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
                if field_name in credentials:
                    encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
                    credentials[field_name] = encrypted

        return credentials

    def mask_tool_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any]:
        """
        mask tool credentials

        return a deep copy of credentials with masked values
        """
        credentials = self._deep_copy(credentials)

        # get fields that need to be masked
        fields = self.provider_controller.get_credentails_schema()
        for field_name, field in fields.items():
            if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
                if field_name in credentials:
                    if len(credentials[field_name]) > 6:
                        credentials[field_name] = \
                            credentials[field_name][:2] + \
                            '*' * (len(credentials[field_name]) - 4) + \
                            credentials[field_name][-2:]
                    else:
                        credentials[field_name] = '*' * len(credentials[field_name])

        return credentials

    def decrypt_tool_credentials(self, credentials: Dict[str, str]) -> Dict[str, str]:
        """
        decrypt tool credentials with tenant id

        return a deep copy of credentials with decrypted values
        """
        credentials = self._deep_copy(credentials)

        # get fields that need to be decrypted
        fields = self.provider_controller.get_credentails_schema()
        for field_name, field in fields.items():
            if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
                if field_name in credentials:
                    try:
                        credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
                    except:
                        pass

        return credentials
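The masking rule above keeps the first and last two characters of a secret and replaces the rest with asterisks; values of six characters or fewer are masked entirely. A standalone sketch of just that rule (illustrative only; the real method also needs a ToolProviderController and the tenant-scoped encrypter):

def mask_secret(value: str) -> str:
    # Same per-field rule mask_tool_credentials applies to SECRET_INPUT values.
    if len(value) > 6:
        return value[:2] + '*' * (len(value) - 4) + value[-2:]
    return '*' * len(value)

assert mask_secret('sk-abcdef123456') == 'sk***********56'
assert mask_secret('secret') == '******'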
api/core/tools/utils/encoder.py (new file, 21 lines)
@@ -0,0 +1,21 @@
from pydantic import BaseModel
from enum import Enum
from typing import List

def serialize_base_model_array(l: List[BaseModel]) -> str:
    class _BaseModel(BaseModel):
        __root__: List[BaseModel]

    """
    {"__root__": [BaseModel, BaseModel, ...]}
    """
    return _BaseModel(__root__=l).json()

def serialize_base_model_dict(b: dict) -> str:
    class _BaseModel(BaseModel):
        __root__: dict

    """
    {"__root__": {BaseModel}}
    """
    return _BaseModel(__root__=b).json()
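A usage sketch of the array serializer, assuming the pydantic v1 behaviour this module relies on and a made-up `Point` model (not part of the commit):

from pydantic import BaseModel

from core.tools.utils.encoder import serialize_base_model_array

class Point(BaseModel):
    # hypothetical model, only for illustration
    x: int
    y: int

# Wraps the list under a "__root__" key, roughly:
# {"__root__": [{"x": 1, "y": 2}, {"x": 3, "y": 4}]}
print(serialize_base_model_array([Point(x=1, y=2), Point(x=3, y=4)]))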
api/core/tools/utils/parser.py (new file, 341 lines)
@@ -0,0 +1,341 @@

from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import ToolParamter, ToolParamterOption, ApiProviderSchemaType
from core.tools.entities.common_entities import I18nObject
from core.tools.errors import ToolProviderNotFoundError, ToolNotSupportedError, \
    ToolApiSchemaError

from typing import List, Tuple

from yaml import FullLoader, load
from json import loads as json_loads, dumps as json_dumps
from requests import get

class ApiBasedToolSchemaParser:
    @staticmethod
    def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        # set description to extra_info
        if 'description' in openapi['info']:
            extra_info['description'] = openapi['info']['description']
        else:
            extra_info['description'] = ''

        if len(openapi['servers']) == 0:
            raise ToolProviderNotFoundError('No server found in the openapi yaml.')

        server_url = openapi['servers'][0]['url']

        # list all interfaces
        interfaces = []
        for path, path_item in openapi['paths'].items():
            methods = ['get', 'post', 'put', 'delete', 'patch', 'head', 'options', 'trace']
            for method in methods:
                if method in path_item:
                    interfaces.append({
                        'path': path,
                        'method': method,
                        'operation': path_item[method],
                    })

        # get all parameters
        bundles = []
        for interface in interfaces:
            # convert parameters
            parameters = []
            if 'parameters' in interface['operation']:
                for parameter in interface['operation']['parameters']:
                    parameters.append(ToolParamter(
                        name=parameter['name'],
                        label=I18nObject(
                            en_US=parameter['name'],
                            zh_Hans=parameter['name']
                        ),
                        human_description=I18nObject(
                            en_US=parameter.get('description', ''),
                            zh_Hans=parameter.get('description', '')
                        ),
                        type=ToolParamter.ToolParameterType.STRING,
                        required=parameter.get('required', False),
                        form=ToolParamter.ToolParameterForm.LLM,
                        llm_description=parameter.get('description'),
                        default=parameter['default'] if 'default' in parameter else None,
                    ))
            # create tool bundle
            # check if there is a request body
            if 'requestBody' in interface['operation']:
                request_body = interface['operation']['requestBody']
                if 'content' in request_body:
                    for content_type, content in request_body['content'].items():
                        # skip content types that carry no schema
                        if 'schema' not in content:
                            continue

                        # if there is a reference, resolve it and overwrite the content
                        if '$ref' in content['schema']:
                            # get the reference
                            root = openapi
                            reference = content['schema']['$ref'].split('/')[1:]
                            for ref in reference:
                                root = root[ref]
                            # overwrite the content
                            interface['operation']['requestBody']['content'][content_type]['schema'] = root
                        # parse body parameters
                        if 'schema' in interface['operation']['requestBody']['content'][content_type]:
                            body_schema = interface['operation']['requestBody']['content'][content_type]['schema']
                            required = body_schema['required'] if 'required' in body_schema else []
                            properties = body_schema['properties'] if 'properties' in body_schema else {}
                            for name, property in properties.items():
                                parameters.append(ToolParamter(
                                    name=name,
                                    label=I18nObject(
                                        en_US=name,
                                        zh_Hans=name
                                    ),
                                    human_description=I18nObject(
                                        en_US=property['description'] if 'description' in property else '',
                                        zh_Hans=property['description'] if 'description' in property else ''
                                    ),
                                    type=ToolParamter.ToolParameterType.STRING,
                                    required=name in required,
                                    form=ToolParamter.ToolParameterForm.LLM,
                                    llm_description=property['description'] if 'description' in property else '',
                                    default=property['default'] if 'default' in property else None,
                                ))

            # check if parameters is duplicated
            parameters_count = {}
            for parameter in parameters:
                if parameter.name not in parameters_count:
                    parameters_count[parameter.name] = 0
                parameters_count[parameter.name] += 1
            for name, count in parameters_count.items():
                if count > 1:
                    warning['duplicated_parameter'] = f'Parameter {name} is duplicated.'

            bundles.append(ApiBasedToolBundle(
                server_url=server_url + interface['path'],
                method=interface['method'],
                summary=interface['operation']['summary'] if 'summary' in interface['operation'] else None,
                operation_id=interface['operation']['operationId'],
                parameters=parameters,
                author='',
                icon=None,
                openapi=interface['operation'],
            ))

        return bundles

    @staticmethod
    def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
        """
        parse openapi yaml to tool bundle

        :param yaml: the yaml string
        :return: the tool bundle
        """
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        openapi: dict = load(yaml, Loader=FullLoader)
        if openapi is None:
            raise ToolApiSchemaError('Invalid openapi yaml.')
        return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)

    @staticmethod
    def parse_openapi_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
        """
        parse openapi json to tool bundle

        :param json: the json string
        :return: the tool bundle
        """
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        openapi: dict = json_loads(json)
        if openapi is None:
            raise ToolApiSchemaError('Invalid openapi json.')
        return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)

    @staticmethod
    def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict:
        """
        parse swagger to openapi

        :param swagger: the swagger dict
        :return: the openapi dict
        """
        # convert swagger to openapi
        info = swagger.get('info', {
            'title': 'Swagger',
            'description': 'Swagger',
            'version': '1.0.0'
        })

        servers = swagger.get('servers', [])

        if len(servers) == 0:
            raise ToolApiSchemaError('No server found in the swagger yaml.')

        openapi = {
            'openapi': '3.0.0',
            'info': {
                'title': info.get('title', 'Swagger'),
                'description': info.get('description', 'Swagger'),
                'version': info.get('version', '1.0.0')
            },
            'servers': swagger['servers'],
            'paths': {},
            'components': {
                'schemas': {}
            }
        }

        # check paths
        if 'paths' not in swagger or len(swagger['paths']) == 0:
            raise ToolApiSchemaError('No paths found in the swagger yaml.')

        # convert paths
        for path, path_item in swagger['paths'].items():
            openapi['paths'][path] = {}
            for method, operation in path_item.items():
                if 'operationId' not in operation:
                    raise ToolApiSchemaError(f'No operationId found in operation {method} {path}.')

                if 'summary' not in operation or len(operation['summary']) == 0:
                    warning['missing_summary'] = f'No summary found in operation {method} {path}.'

                if 'description' not in operation or len(operation['description']) == 0:
                    warning['missing_description'] = f'No description found in operation {method} {path}.'

                openapi['paths'][path][method] = {
                    'operationId': operation['operationId'],
                    'summary': operation.get('summary', ''),
                    'description': operation.get('description', ''),
                    'parameters': operation.get('parameters', []),
                    'responses': operation.get('responses', {}),
                }

                if 'requestBody' in operation:
                    openapi['paths'][path][method]['requestBody'] = operation['requestBody']

        # convert definitions
        for name, definition in swagger['definitions'].items():
            openapi['components']['schemas'][name] = definition

        return openapi

    @staticmethod
    def parse_swagger_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
        """
        parse swagger yaml to tool bundle

        :param yaml: the yaml string
        :return: the tool bundle
        """
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        swagger: dict = load(yaml, Loader=FullLoader)

        openapi = ApiBasedToolSchemaParser.parse_swagger_to_openapi(swagger, extra_info=extra_info, warning=warning)
        return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)

    @staticmethod
    def parse_swagger_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
        """
        parse swagger json to tool bundle

        :param json: the json string
        :return: the tool bundle
        """
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        swagger: dict = json_loads(json)

        openapi = ApiBasedToolSchemaParser.parse_swagger_to_openapi(swagger, extra_info=extra_info, warning=warning)
        return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)

    @staticmethod
    def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
        """
        parse openai plugin json to tool bundle

        :param json: the json string
        :return: the tool bundle
        """
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        try:
            openai_plugin = json_loads(json)
            api = openai_plugin['api']
            api_url = api['url']
            api_type = api['type']
        except:
            raise ToolProviderNotFoundError('Invalid openai plugin json.')

        if api_type != 'openapi':
            raise ToolNotSupportedError('Only openapi is supported now.')

        # get openapi yaml
        response = get(api_url, headers={
            'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) '
        }, timeout=5)

        if response.status_code != 200:
            raise ToolProviderNotFoundError('cannot get openapi yaml from url.')

        return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)

    @staticmethod
    def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> Tuple[List[ApiBasedToolBundle], str]:
        """
        auto parse to tool bundle

        :param content: the content
        :return: tools bundle, schema_type
        """
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        json_possible = False
        content = content.strip()

        if content.startswith('{') and content.endswith('}'):
            json_possible = True

        if json_possible:
            try:
                return ApiBasedToolSchemaParser.parse_openapi_json_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
                    ApiProviderSchemaType.OPENAPI.value
            except:
                pass

            try:
                return ApiBasedToolSchemaParser.parse_swagger_json_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
                    ApiProviderSchemaType.SWAGGER.value
            except:
                pass
            try:
                return ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
                    ApiProviderSchemaType.OPENAI_PLUGIN.value
            except:
                pass
        else:
            try:
                return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
                    ApiProviderSchemaType.OPENAPI.value
            except:
                pass

            try:
                return ApiBasedToolSchemaParser.parse_swagger_yaml_to_tool_bundle(content, extra_info=extra_info, warning=warning), \
                    ApiProviderSchemaType.SWAGGER.value
            except:
                pass

        raise ToolApiSchemaError('Invalid api schema.')
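A usage sketch of the auto parser with a minimal, made-up OpenAPI document (the spec, URL, and operation names here are illustrative; the schema-type string returned is whatever ApiProviderSchemaType.OPENAPI carries):

from core.tools.utils.parser import ApiBasedToolSchemaParser

MINIMAL_OPENAPI_YAML = """
openapi: 3.0.0
info:
  title: Weather
  description: A tiny demo spec
  version: 1.0.0
servers:
  - url: https://api.example.com
paths:
  /forecast:
    get:
      operationId: getForecast
      summary: Get a forecast
      parameters:
        - name: city
          in: query
          required: true
          description: City name
"""

warning = {}
bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(MINIMAL_OPENAPI_YAML, warning=warning)
# Expect one bundle for GET https://api.example.com/forecast with a single
# LLM-form string parameter named 'city', and schema_type set to the OpenAPI value.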
api/core/tools/utils/web_reader_tool.py (new file, 446 lines)
@@ -0,0 +1,446 @@
import hashlib
import json
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from typing import Type, Any

import requests
from bs4 import BeautifulSoup, NavigableString, Comment, CData
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex

from core.chain.llm_chain import LLMChain
from core.data_loader import file_extractor
from core.data_loader.file_extractor import FileExtractor
from core.entities.application_entities import ModelConfigEntity

FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:

{text}
"""


class WebReaderToolInput(BaseModel):
    url: str = Field(..., description="URL of the website to read")
    summary: bool = Field(
        default=False,
        description="When the user's question requires extracting the summarizing content of the webpage, "
                    "set it to true."
    )
    cursor: int = Field(
        default=0,
        description="Start reading from this character."
                    "Use when the first response was truncated"
                    "and you want to continue reading the page."
                    "The value cannot exceed 24000.",
    )


class WebReaderTool(BaseTool):
    """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""

    name: str = "web_reader"
    args_schema: Type[BaseModel] = WebReaderToolInput
    description: str = "use this to read a website. " \
                       "If you can answer the question based on the information provided, " \
                       "there is no need to use."
    page_contents: str = None
    url: str = None
    max_chunk_length: int = 4000
    summary_chunk_tokens: int = 4000
    summary_chunk_overlap: int = 0
    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
    continue_reading: bool = True
    model_config: ModelConfigEntity
    model_parameters: dict[str, Any]

    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
        try:
            if not self.page_contents or self.url != url:
                page_contents = get_url(url)
                self.page_contents = page_contents
                self.url = url
            else:
                page_contents = self.page_contents
        except Exception as e:
            return f'Read this website failed, caused by: {str(e)}.'

        if summary:
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=self.summary_chunk_tokens,
                chunk_overlap=self.summary_chunk_overlap,
                separators=self.summary_separators
            )

            texts = character_splitter.split_text(page_contents)
            docs = [Document(page_content=t) for t in texts]

            if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
                return "No content found."

            # only use first 5 docs
            if len(docs) > 5:
                docs = docs[:5]

            chain = self.get_summary_chain()
            try:
                page_contents = chain.run(docs)
            except Exception as e:
                return f'Read this website failed, caused by: {str(e)}.'
        else:
            page_contents = page_result(page_contents, cursor, self.max_chunk_length)

            if self.continue_reading and len(page_contents) >= self.max_chunk_length:
                page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
                                 f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
                                 f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."

        return page_contents

    async def _arun(self, url: str) -> str:
        raise NotImplementedError

    def get_summary_chain(self) -> RefineDocumentsChain:
        initial_chain = LLMChain(
            model_config=self.model_config,
            prompt=refine_prompts.PROMPT,
            parameters=self.model_parameters
        )
        refine_chain = LLMChain(
            model_config=self.model_config,
            prompt=refine_prompts.REFINE_PROMPT,
            parameters=self.model_parameters
        )
        return RefineDocumentsChain(
            initial_llm_chain=initial_chain,
            refine_llm_chain=refine_chain,
            document_variable_name="text",
            initial_response_name="existing_answer",
            callbacks=self.callbacks
        )

def page_result(text: str, cursor: int, max_length: int) -> str:
    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
    return text[cursor: cursor + max_length]


def get_url(url: str, user_agent: str = None) -> str:
    """Fetch URL and return the contents as a string."""
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    }
    if user_agent:
        headers["User-Agent"] = user_agent

    supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]

    head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))

    if head_response.status_code != 200:
        return "URL returned status code {}.".format(head_response.status_code)

    # check content-type
    main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip()
    if main_content_type not in supported_content_types:
        return "Unsupported content-type [{}] of URL.".format(main_content_type)

    if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES:
        return FileExtractor.load_from_url(url, return_text=True)

    response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
    a = extract_using_readabilipy(response.text)

    if not a['plain_text'] or not a['plain_text'].strip():
        return get_url_from_newspaper3k(url)

    res = FULL_TEMPLATE.format(
        title=a['title'],
        authors=a['byline'],
        publish_date=a['date'],
        top_image="",
        text=a['plain_text'] if a['plain_text'] else "",
    )

    return res


def get_url_from_newspaper3k(url: str) -> str:
    a = Article(url)
    a.download()
    a.parse()

    res = FULL_TEMPLATE.format(
        title=a.title,
        authors=a.authors,
        publish_date=a.publish_date,
        top_image=a.top_image,
        text=a.text,
    )

    return res


def extract_using_readabilipy(html):
    with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
        f_html.write(html)
        f_html.close()
    html_path = f_html.name

    # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
    article_json_path = html_path + ".json"
    jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
    with chdir(jsdir):
        subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])

    # Read output of call to Readability.parse() from JSON file and return as Python dictionary
    with open(article_json_path, "r", encoding="utf-8") as json_file:
        input_json = json.loads(json_file.read())

    # Deleting files after processing
    os.unlink(article_json_path)
    os.unlink(html_path)

    article_json = {
        "title": None,
        "byline": None,
        "date": None,
        "content": None,
        "plain_content": None,
        "plain_text": None
    }
    # Populate article fields from readability fields where present
    if input_json:
        if "title" in input_json and input_json["title"]:
            article_json["title"] = input_json["title"]
        if "byline" in input_json and input_json["byline"]:
            article_json["byline"] = input_json["byline"]
        if "date" in input_json and input_json["date"]:
            article_json["date"] = input_json["date"]
        if "content" in input_json and input_json["content"]:
            article_json["content"] = input_json["content"]
            article_json["plain_content"] = plain_content(article_json["content"], False, False)
            article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
        if "textContent" in input_json and input_json["textContent"]:
            article_json["plain_text"] = input_json["textContent"]
            article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])

    return article_json


def find_module_path(module_name):
    for package_path in site.getsitepackages():
        potential_path = os.path.join(package_path, module_name)
        if os.path.exists(potential_path):
            return potential_path

    return None

@contextmanager
def chdir(path):
    """Change directory in context and return to original on exit"""
    # From https://stackoverflow.com/a/37996581, couldn't find a built-in
    original_path = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original_path)


def extract_text_blocks_as_plain_text(paragraph_html):
    # Load article as DOM
    soup = BeautifulSoup(paragraph_html, 'html.parser')
    # Select all lists
    list_elements = soup.find_all(['ul', 'ol'])
    # Prefix text in all list items with "* " and make lists paragraphs
    for list_element in list_elements:
        plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
        list_element.string = plain_items
        list_element.name = "p"
    # Select all text blocks
    text_blocks = [s.parent for s in soup.find_all(string=True)]
    text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
    # Drop empty paragraphs
    text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
    return text_blocks


def plain_text_leaf_node(element):
    # Extract all text, stripped of any child HTML elements and normalise it
    plain_text = normalise_text(element.get_text())
    if plain_text != "" and element.name == "li":
        plain_text = "* {}, ".format(plain_text)
    if plain_text == "":
        plain_text = None
    if "data-node-index" in element.attrs:
        plain = {"node_index": element["data-node-index"], "text": plain_text}
    else:
        plain = {"text": plain_text}
    return plain


def plain_content(readability_content, content_digests, node_indexes):
    # Load article as DOM
    soup = BeautifulSoup(readability_content, 'html.parser')
    # Make all elements plain
    elements = plain_elements(soup.contents, content_digests, node_indexes)
    if node_indexes:
        # Add node index attributes to nodes
        elements = [add_node_indexes(element) for element in elements]
    # Replace article contents with plain elements
    soup.contents = elements
    return str(soup)


def plain_elements(elements, content_digests, node_indexes):
    # Get plain content versions of all elements
    elements = [plain_element(element, content_digests, node_indexes)
                for element in elements]
    if content_digests:
        # Add content digest attribute to nodes
        elements = [add_content_digest(element) for element in elements]
    return elements


def plain_element(element, content_digests, node_indexes):
    # For lists, we make each item plain text
    if is_leaf(element):
        # For leaf node elements, extract the text content, discarding any HTML tags
        # 1. Get element contents as text
        plain_text = element.get_text()
        # 2. Normalise the extracted text string to a canonical representation
        plain_text = normalise_text(plain_text)
        # 3. Update element content to be plain text
        element.string = plain_text
    elif is_text(element):
        if is_non_printing(element):
            # The simplified HTML may have come from Readability.js so might
            # have non-printing text (e.g. Comment or CData). In this case, we
            # keep the structure, but ensure that the string is empty.
            element = type(element)("")
        else:
            plain_text = element.string
            plain_text = normalise_text(plain_text)
            element = type(element)(plain_text)
    else:
        # If not a leaf node or leaf type call recursively on child nodes, replacing
        element.contents = plain_elements(element.contents, content_digests, node_indexes)
    return element


def add_node_indexes(element, node_index="0"):
    # Can't add attributes to string types
    if is_text(element):
        return element
    # Add index to current element
    element["data-node-index"] = node_index
    # Add index to child elements
    for local_idx, child in enumerate(
            [c for c in element.contents if not is_text(c)], start=1):
        # Can't add attributes to leaf string types
        child_index = "{stem}.{local}".format(
            stem=node_index, local=local_idx)
        add_node_indexes(child, node_index=child_index)
    return element


def normalise_text(text):
    """Normalise unicode and whitespace."""
    # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
    text = strip_control_characters(text)
    text = normalise_unicode(text)
    text = normalise_whitespace(text)
    return text


def strip_control_characters(text):
    """Strip out unicode control characters which might break the parsing."""
    # Unicode control characters
    #   [Cc]: Other, Control [includes new lines]
    #   [Cf]: Other, Format
    #   [Cn]: Other, Not Assigned
    #   [Co]: Other, Private Use
    #   [Cs]: Other, Surrogate
    control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
    retained_chars = ['\t', '\n', '\r', '\f']

    # Remove non-printing control characters
    return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])


def normalise_unicode(text):
    """Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
    normal_form = "NFKC"
    text = unicodedata.normalize(normal_form, text)
    return text


def normalise_whitespace(text):
    """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
    text = regex.sub(r"\s+", " ", text)
    # Remove leading and trailing whitespace
    text = text.strip()
    return text

def is_leaf(element):
    return (element.name in ['p', 'li'])


def is_text(element):
    return isinstance(element, NavigableString)


def is_non_printing(element):
    return any(isinstance(element, _e) for _e in [Comment, CData])


def add_content_digest(element):
    if not is_text(element):
        element["data-content-digest"] = content_digest(element)
    return element


def content_digest(element):
    if is_text(element):
        # Hash
        trimmed_string = element.string.strip()
        if trimmed_string == "":
            digest = ""
        else:
            digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
    else:
        contents = element.contents
        num_contents = len(contents)
        if num_contents == 0:
            # No hash when no child elements exist
            digest = ""
        elif num_contents == 1:
            # If single child, use digest of child
            digest = content_digest(contents[0])
        else:
            # Build content digest from the "non-empty" digests of child nodes
            digest = hashlib.sha256()
            child_digests = list(
                filter(lambda x: x != "", [content_digest(content) for content in contents]))
            for child in child_digests:
                digest.update(child.encode('utf-8'))
            digest = digest.hexdigest()
    return digest
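A small sketch of the paging and whitespace helpers above (pure functions, so they run without the network, readabilipy, or node dependencies the rest of the module needs):

from core.tools.utils.web_reader_tool import page_result, normalise_whitespace

text = "x" * 10000                                         # stand-in for extracted page text
first = page_result(text, cursor=0, max_length=4000)       # characters 0..3999
second = page_result(text, cursor=4000, max_length=4000)   # continue where the first chunk ended
assert len(first) == len(second) == 4000

# Runs of whitespace collapse to single spaces, then the ends are stripped.
assert normalise_whitespace("hello\n\n  world\t") == "hello world"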