mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-16 06:16:53 +08:00
166
api/core/rag/extractor/blod/blod.py
Normal file
166
api/core/rag/extractor/blod/blod.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Schema for Blobs and Blob Loaders.
|
||||
|
||||
The goal is to facilitate decoupling of content loading from content parsing code.
|
||||
|
||||
In addition, content loading code should provide a lazy loading interface by default.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import mimetypes
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Iterable, Mapping
|
||||
from io import BufferedReader, BytesIO
|
||||
from pathlib import PurePath
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
PathLike = Union[str, PurePath]
|
||||
|
||||
|
||||
class Blob(BaseModel):
|
||||
"""A blob is used to represent raw data by either reference or value.
|
||||
|
||||
Provides an interface to materialize the blob in different representations, and
|
||||
help to decouple the development of data loaders from the downstream parsing of
|
||||
the raw data.
|
||||
|
||||
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
|
||||
"""
|
||||
|
||||
data: Union[bytes, str, None] # Raw data
|
||||
mimetype: Optional[str] = None # Not to be confused with a file extension
|
||||
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
|
||||
# Location where the original content was found
|
||||
# Represent location on the local file system
|
||||
# Useful for situations where downstream code assumes it must work with file paths
|
||||
# rather than in-memory content.
|
||||
path: Optional[PathLike] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
frozen = True
|
||||
|
||||
@property
|
||||
def source(self) -> Optional[str]:
|
||||
"""The source location of the blob as string if known otherwise none."""
|
||||
return str(self.path) if self.path else None
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Verify that either data or path is provided."""
|
||||
if "data" not in values and "path" not in values:
|
||||
raise ValueError("Either data or path must be provided")
|
||||
return values
|
||||
|
||||
def as_string(self) -> str:
|
||||
"""Read data as a string."""
|
||||
if self.data is None and self.path:
|
||||
with open(str(self.path), encoding=self.encoding) as f:
|
||||
return f.read()
|
||||
elif isinstance(self.data, bytes):
|
||||
return self.data.decode(self.encoding)
|
||||
elif isinstance(self.data, str):
|
||||
return self.data
|
||||
else:
|
||||
raise ValueError(f"Unable to get string for blob {self}")
|
||||
|
||||
def as_bytes(self) -> bytes:
|
||||
"""Read data as bytes."""
|
||||
if isinstance(self.data, bytes):
|
||||
return self.data
|
||||
elif isinstance(self.data, str):
|
||||
return self.data.encode(self.encoding)
|
||||
elif self.data is None and self.path:
|
||||
with open(str(self.path), "rb") as f:
|
||||
return f.read()
|
||||
else:
|
||||
raise ValueError(f"Unable to get bytes for blob {self}")
|
||||
|
||||
@contextlib.contextmanager
|
||||
def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]:
|
||||
"""Read data as a byte stream."""
|
||||
if isinstance(self.data, bytes):
|
||||
yield BytesIO(self.data)
|
||||
elif self.data is None and self.path:
|
||||
with open(str(self.path), "rb") as f:
|
||||
yield f
|
||||
else:
|
||||
raise NotImplementedError(f"Unable to convert blob {self}")
|
||||
|
||||
@classmethod
|
||||
def from_path(
|
||||
cls,
|
||||
path: PathLike,
|
||||
*,
|
||||
encoding: str = "utf-8",
|
||||
mime_type: Optional[str] = None,
|
||||
guess_type: bool = True,
|
||||
) -> Blob:
|
||||
"""Load the blob from a path like object.
|
||||
|
||||
Args:
|
||||
path: path like object to file to be read
|
||||
encoding: Encoding to use if decoding the bytes into a string
|
||||
mime_type: if provided, will be set as the mime-type of the data
|
||||
guess_type: If True, the mimetype will be guessed from the file extension,
|
||||
if a mime-type was not provided
|
||||
|
||||
Returns:
|
||||
Blob instance
|
||||
"""
|
||||
if mime_type is None and guess_type:
|
||||
_mimetype = mimetypes.guess_type(path)[0] if guess_type else None
|
||||
else:
|
||||
_mimetype = mime_type
|
||||
# We do not load the data immediately, instead we treat the blob as a
|
||||
# reference to the underlying data.
|
||||
return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path)
|
||||
|
||||
@classmethod
|
||||
def from_data(
|
||||
cls,
|
||||
data: Union[str, bytes],
|
||||
*,
|
||||
encoding: str = "utf-8",
|
||||
mime_type: Optional[str] = None,
|
||||
path: Optional[str] = None,
|
||||
) -> Blob:
|
||||
"""Initialize the blob from in-memory data.
|
||||
|
||||
Args:
|
||||
data: the in-memory data associated with the blob
|
||||
encoding: Encoding to use if decoding the bytes into a string
|
||||
mime_type: if provided, will be set as the mime-type of the data
|
||||
path: if provided, will be set as the source from which the data came
|
||||
|
||||
Returns:
|
||||
Blob instance
|
||||
"""
|
||||
return cls(data=data, mimetype=mime_type, encoding=encoding, path=path)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Define the blob representation."""
|
||||
str_repr = f"Blob {id(self)}"
|
||||
if self.source:
|
||||
str_repr += f" {self.source}"
|
||||
return str_repr
|
||||
|
||||
|
||||
class BlobLoader(ABC):
|
||||
"""Abstract interface for blob loaders implementation.
|
||||
|
||||
Implementer should be able to load raw content from a datasource system according
|
||||
to some criteria and return the raw content lazily as a stream of blobs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def yield_blobs(
|
||||
self,
|
||||
) -> Iterable[Blob]:
|
||||
"""A lazy loader for raw data represented by LangChain's Blob object.
|
||||
|
||||
Returns:
|
||||
A generator over blobs
|
||||
"""
|
||||
71
api/core/rag/extractor/csv_extractor.py
Normal file
71
api/core/rag/extractor/csv_extractor.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
import csv
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class CSVExtractor(BaseExtractor):
|
||||
"""Load CSV files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False,
|
||||
source_column: Optional[str] = None,
|
||||
csv_args: Optional[dict] = None,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
self.source_column = source_column
|
||||
self.csv_args = csv_args or {}
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load data into document objects."""
|
||||
try:
|
||||
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
|
||||
docs = self._read_from_file(csvfile)
|
||||
except UnicodeDecodeError as e:
|
||||
if self._autodetect_encoding:
|
||||
detected_encodings = detect_filze_encodings(self._file_path)
|
||||
for encoding in detected_encodings:
|
||||
try:
|
||||
with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
|
||||
docs = self._read_from_file(csvfile)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Error loading {self._file_path}") from e
|
||||
|
||||
return docs
|
||||
|
||||
def _read_from_file(self, csvfile) -> list[Document]:
|
||||
docs = []
|
||||
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
|
||||
for i, row in enumerate(csv_reader):
|
||||
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
|
||||
try:
|
||||
source = (
|
||||
row[self.source_column]
|
||||
if self.source_column is not None
|
||||
else ''
|
||||
)
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Source column '{self.source_column}' not found in CSV file."
|
||||
)
|
||||
metadata = {"source": source, "row": i}
|
||||
doc = Document(page_content=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
6
api/core/rag/extractor/entity/datasource_type.py
Normal file
6
api/core/rag/extractor/entity/datasource_type.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DatasourceType(Enum):
|
||||
FILE = "upload_file"
|
||||
NOTION = "notion_import"
|
||||
36
api/core/rag/extractor/entity/extract_setting.py
Normal file
36
api/core/rag/extractor/entity/extract_setting.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.dataset import Document
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
class NotionInfo(BaseModel):
|
||||
"""
|
||||
Notion import info.
|
||||
"""
|
||||
notion_workspace_id: str
|
||||
notion_obj_id: str
|
||||
notion_page_type: str
|
||||
document: Document = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class ExtractSetting(BaseModel):
|
||||
"""
|
||||
Model class for provider response.
|
||||
"""
|
||||
datasource_type: str
|
||||
upload_file: UploadFile = None
|
||||
notion_info: NotionInfo = None
|
||||
document_model: str = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
50
api/core/rag/extractor/excel_extractor.py
Normal file
50
api/core/rag/extractor/excel_extractor.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from typing import Optional
|
||||
|
||||
from openpyxl.reader.excel import load_workbook
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class ExcelExtractor(BaseExtractor):
|
||||
"""Load Excel files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load from file path."""
|
||||
data = []
|
||||
keys = []
|
||||
wb = load_workbook(filename=self._file_path, read_only=True)
|
||||
# loop over all sheets
|
||||
for sheet in wb:
|
||||
if 'A1:A1' == sheet.calculate_dimension():
|
||||
sheet.reset_dimensions()
|
||||
for row in sheet.iter_rows(values_only=True):
|
||||
if all(v is None for v in row):
|
||||
continue
|
||||
if keys == []:
|
||||
keys = list(map(str, row))
|
||||
else:
|
||||
row_dict = dict(zip(keys, list(map(str, row))))
|
||||
row_dict = {k: v for k, v in row_dict.items() if v}
|
||||
item = ''.join(f'{k}:{v};' for k, v in row_dict.items())
|
||||
document = Document(page_content=item, metadata={'source': self._file_path})
|
||||
data.append(document)
|
||||
|
||||
return data
|
||||
139
api/core/rag/extractor/extract_processor.py
Normal file
139
api/core/rag/extractor/extract_processor.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
|
||||
from core.rag.extractor.csv_extractor import CSVExtractor
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.excel_extractor import ExcelExtractor
|
||||
from core.rag.extractor.html_extractor import HtmlExtractor
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from core.rag.extractor.pdf_extractor import PdfExtractor
|
||||
from core.rag.extractor.text_extractor import TextExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
|
||||
from core.rag.extractor.word_extractor import WordExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
|
||||
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
|
||||
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
|
||||
|
||||
class ExtractProcessor:
|
||||
@classmethod
|
||||
def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \
|
||||
-> Union[list[Document], str]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
upload_file=upload_file,
|
||||
document_model='text_model'
|
||||
)
|
||||
if return_text:
|
||||
delimiter = '\n'
|
||||
return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)])
|
||||
else:
|
||||
return cls.extract(extract_setting, is_automatic)
|
||||
|
||||
@classmethod
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
|
||||
response = requests.get(url, headers={
|
||||
"User-Agent": USER_AGENT
|
||||
})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(url).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(response.content)
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
document_model='text_model'
|
||||
)
|
||||
if return_text:
|
||||
delimiter = '\n'
|
||||
return delimiter.join([document.page_content for document in cls.extract(
|
||||
extract_setting=extract_setting, file_path=file_path)])
|
||||
else:
|
||||
return cls.extract(extract_setting=extract_setting, file_path=file_path)
|
||||
|
||||
@classmethod
|
||||
def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False,
|
||||
file_path: str = None) -> list[Document]:
|
||||
if extract_setting.datasource_type == DatasourceType.FILE.value:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
if not file_path:
|
||||
upload_file: UploadFile = extract_setting.upload_file
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
storage.download(upload_file.key, file_path)
|
||||
input_file = Path(file_path)
|
||||
file_extension = input_file.suffix.lower()
|
||||
etl_type = current_app.config['ETL_TYPE']
|
||||
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
|
||||
if etl_type == 'Unstructured':
|
||||
if file_extension == '.xlsx':
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
extractor = PdfExtractor(file_path)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \
|
||||
else MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension in ['.docx']:
|
||||
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.csv':
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension == '.msg':
|
||||
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.eml':
|
||||
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.ppt':
|
||||
extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.pptx':
|
||||
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.xml':
|
||||
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
|
||||
else:
|
||||
# txt
|
||||
extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \
|
||||
else TextExtractor(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
if file_extension == '.xlsx':
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
extractor = PdfExtractor(file_path)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension in ['.docx']:
|
||||
extractor = WordExtractor(file_path)
|
||||
elif file_extension == '.csv':
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
# txt
|
||||
extractor = TextExtractor(file_path, autodetect_encoding=True)
|
||||
return extractor.extract()
|
||||
elif extract_setting.datasource_type == DatasourceType.NOTION.value:
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
|
||||
notion_obj_id=extract_setting.notion_info.notion_obj_id,
|
||||
notion_page_type=extract_setting.notion_info.notion_page_type,
|
||||
document_model=extract_setting.notion_info.document
|
||||
)
|
||||
return extractor.extract()
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")
|
||||
12
api/core/rag/extractor/extractor_base.py
Normal file
12
api/core/rag/extractor/extractor_base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseExtractor(ABC):
|
||||
"""Interface for extract files.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def extract(self):
|
||||
raise NotImplementedError
|
||||
|
||||
46
api/core/rag/extractor/helpers.py
Normal file
46
api/core/rag/extractor/helpers.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Document loader helpers."""
|
||||
|
||||
import concurrent.futures
|
||||
from typing import NamedTuple, Optional, cast
|
||||
|
||||
|
||||
class FileEncoding(NamedTuple):
|
||||
"""A file encoding as the NamedTuple."""
|
||||
|
||||
encoding: Optional[str]
|
||||
"""The encoding of the file."""
|
||||
confidence: float
|
||||
"""The confidence of the encoding."""
|
||||
language: Optional[str]
|
||||
"""The language of the file."""
|
||||
|
||||
|
||||
def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]:
|
||||
"""Try to detect the file encoding.
|
||||
|
||||
Returns a list of `FileEncoding` tuples with the detected encodings ordered
|
||||
by confidence.
|
||||
|
||||
Args:
|
||||
file_path: The path to the file to detect the encoding for.
|
||||
timeout: The timeout in seconds for the encoding detection.
|
||||
"""
|
||||
import chardet
|
||||
|
||||
def read_and_detect(file_path: str) -> list[dict]:
|
||||
with open(file_path, "rb") as f:
|
||||
rawdata = f.read()
|
||||
return cast(list[dict], chardet.detect_all(rawdata))
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(read_and_detect, file_path)
|
||||
try:
|
||||
encodings = future.result(timeout=timeout)
|
||||
except concurrent.futures.TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Timeout reached while detecting encoding for {file_path}"
|
||||
)
|
||||
|
||||
if all(encoding["encoding"] is None for encoding in encodings):
|
||||
raise RuntimeError(f"Could not detect encoding for {file_path}")
|
||||
return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]
|
||||
71
api/core/rag/extractor/html_extractor.py
Normal file
71
api/core/rag/extractor/html_extractor.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.extractor.helpers import detect_file_encodings
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class HtmlExtractor(BaseExtractor):
|
||||
"""Load html files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False,
|
||||
source_column: Optional[str] = None,
|
||||
csv_args: Optional[dict] = None,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
self.source_column = source_column
|
||||
self.csv_args = csv_args or {}
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load data into document objects."""
|
||||
try:
|
||||
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
|
||||
docs = self._read_from_file(csvfile)
|
||||
except UnicodeDecodeError as e:
|
||||
if self._autodetect_encoding:
|
||||
detected_encodings = detect_file_encodings(self._file_path)
|
||||
for encoding in detected_encodings:
|
||||
try:
|
||||
with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
|
||||
docs = self._read_from_file(csvfile)
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Error loading {self._file_path}") from e
|
||||
|
||||
return docs
|
||||
|
||||
def _read_from_file(self, csvfile) -> list[Document]:
|
||||
docs = []
|
||||
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
|
||||
for i, row in enumerate(csv_reader):
|
||||
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
|
||||
try:
|
||||
source = (
|
||||
row[self.source_column]
|
||||
if self.source_column is not None
|
||||
else ''
|
||||
)
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Source column '{self.source_column}' not found in CSV file."
|
||||
)
|
||||
metadata = {"source": source, "row": i}
|
||||
doc = Document(page_content=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
122
api/core/rag/extractor/markdown_extractor.py
Normal file
122
api/core/rag/extractor/markdown_extractor.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
import re
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.extractor.helpers import detect_file_encodings
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class MarkdownExtractor(BaseExtractor):
|
||||
"""Load Markdown files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
remove_hyperlinks: bool = True,
|
||||
remove_images: bool = True,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = True,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._remove_hyperlinks = remove_hyperlinks
|
||||
self._remove_images = remove_images
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load from file path."""
|
||||
tups = self.parse_tups(self._file_path)
|
||||
documents = []
|
||||
for header, value in tups:
|
||||
value = value.strip()
|
||||
if header is None:
|
||||
documents.append(Document(page_content=value))
|
||||
else:
|
||||
documents.append(Document(page_content=f"\n\n{header}\n{value}"))
|
||||
|
||||
return documents
|
||||
|
||||
def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]:
|
||||
"""Convert a markdown file to a dictionary.
|
||||
|
||||
The keys are the headers and the values are the text under each header.
|
||||
|
||||
"""
|
||||
markdown_tups: list[tuple[Optional[str], str]] = []
|
||||
lines = markdown_text.split("\n")
|
||||
|
||||
current_header = None
|
||||
current_text = ""
|
||||
|
||||
for line in lines:
|
||||
header_match = re.match(r"^#+\s", line)
|
||||
if header_match:
|
||||
if current_header is not None:
|
||||
markdown_tups.append((current_header, current_text))
|
||||
|
||||
current_header = line
|
||||
current_text = ""
|
||||
else:
|
||||
current_text += line + "\n"
|
||||
markdown_tups.append((current_header, current_text))
|
||||
|
||||
if current_header is not None:
|
||||
# pass linting, assert keys are defined
|
||||
markdown_tups = [
|
||||
(re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
|
||||
for key, value in markdown_tups
|
||||
]
|
||||
else:
|
||||
markdown_tups = [
|
||||
(key, re.sub("\n", "", value)) for key, value in markdown_tups
|
||||
]
|
||||
|
||||
return markdown_tups
|
||||
|
||||
def remove_images(self, content: str) -> str:
|
||||
"""Get a dictionary of a markdown file from its path."""
|
||||
pattern = r"!{1}\[\[(.*)\]\]"
|
||||
content = re.sub(pattern, "", content)
|
||||
return content
|
||||
|
||||
def remove_hyperlinks(self, content: str) -> str:
|
||||
"""Get a dictionary of a markdown file from its path."""
|
||||
pattern = r"\[(.*?)\]\((.*?)\)"
|
||||
content = re.sub(pattern, r"\1", content)
|
||||
return content
|
||||
|
||||
def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]:
|
||||
"""Parse file into tuples."""
|
||||
content = ""
|
||||
try:
|
||||
with open(filepath, encoding=self._encoding) as f:
|
||||
content = f.read()
|
||||
except UnicodeDecodeError as e:
|
||||
if self._autodetect_encoding:
|
||||
detected_encodings = detect_file_encodings(filepath)
|
||||
for encoding in detected_encodings:
|
||||
try:
|
||||
with open(filepath, encoding=encoding.encoding) as f:
|
||||
content = f.read()
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Error loading {filepath}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error loading {filepath}") from e
|
||||
|
||||
if self._remove_hyperlinks:
|
||||
content = self.remove_hyperlinks(content)
|
||||
|
||||
if self._remove_images:
|
||||
content = self.remove_images(content)
|
||||
|
||||
return self.markdown_to_tups(content)
|
||||
366
api/core/rag/extractor/notion_extractor.py
Normal file
366
api/core/rag/extractor/notion_extractor.py
Normal file
@@ -0,0 +1,366 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
from flask_login import current_user
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document as DocumentModel
|
||||
from models.source import DataSourceBinding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
|
||||
SEARCH_URL = "https://api.notion.com/v1/search"
|
||||
|
||||
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
|
||||
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
|
||||
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
|
||||
|
||||
|
||||
class NotionExtractor(BaseExtractor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
notion_workspace_id: str,
|
||||
notion_obj_id: str,
|
||||
notion_page_type: str,
|
||||
document_model: Optional[DocumentModel] = None,
|
||||
notion_access_token: Optional[str] = None
|
||||
):
|
||||
self._notion_access_token = None
|
||||
self._document_model = document_model
|
||||
self._notion_workspace_id = notion_workspace_id
|
||||
self._notion_obj_id = notion_obj_id
|
||||
self._notion_page_type = notion_page_type
|
||||
if notion_access_token:
|
||||
self._notion_access_token = notion_access_token
|
||||
else:
|
||||
self._notion_access_token = self._get_access_token(current_user.current_tenant_id,
|
||||
self._notion_workspace_id)
|
||||
if not self._notion_access_token:
|
||||
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
|
||||
if integration_token is None:
|
||||
raise ValueError(
|
||||
"Must specify `integration_token` or set environment "
|
||||
"variable `NOTION_INTEGRATION_TOKEN`."
|
||||
)
|
||||
|
||||
self._notion_access_token = integration_token
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
self.update_last_edited_time(
|
||||
self._document_model
|
||||
)
|
||||
|
||||
text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
|
||||
|
||||
return text_docs
|
||||
|
||||
def _load_data_as_documents(
|
||||
self, notion_obj_id: str, notion_page_type: str
|
||||
) -> list[Document]:
|
||||
docs = []
|
||||
if notion_page_type == 'database':
|
||||
# get all the pages in the database
|
||||
page_text_documents = self._get_notion_database_data(notion_obj_id)
|
||||
docs.extend(page_text_documents)
|
||||
elif notion_page_type == 'page':
|
||||
page_text_list = self._get_notion_block_data(notion_obj_id)
|
||||
for page_text in page_text_list:
|
||||
docs.append(Document(page_content=page_text))
|
||||
else:
|
||||
raise ValueError("notion page type not supported")
|
||||
|
||||
return docs
|
||||
|
||||
def _get_notion_database_data(
|
||||
self, database_id: str, query_dict: dict[str, Any] = {}
|
||||
) -> list[Document]:
|
||||
"""Get all the pages from a Notion database."""
|
||||
res = requests.post(
|
||||
DATABASE_URL_TMPL.format(database_id=database_id),
|
||||
headers={
|
||||
"Authorization": "Bearer " + self._notion_access_token,
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=query_dict,
|
||||
)
|
||||
|
||||
data = res.json()
|
||||
|
||||
database_content_list = []
|
||||
if 'results' not in data or data["results"] is None:
|
||||
return []
|
||||
for result in data["results"]:
|
||||
properties = result['properties']
|
||||
data = {}
|
||||
for property_name, property_value in properties.items():
|
||||
type = property_value['type']
|
||||
if type == 'multi_select':
|
||||
value = []
|
||||
multi_select_list = property_value[type]
|
||||
for multi_select in multi_select_list:
|
||||
value.append(multi_select['name'])
|
||||
elif type == 'rich_text' or type == 'title':
|
||||
if len(property_value[type]) > 0:
|
||||
value = property_value[type][0]['plain_text']
|
||||
else:
|
||||
value = ''
|
||||
elif type == 'select' or type == 'status':
|
||||
if property_value[type]:
|
||||
value = property_value[type]['name']
|
||||
else:
|
||||
value = ''
|
||||
else:
|
||||
value = property_value[type]
|
||||
data[property_name] = value
|
||||
row_dict = {k: v for k, v in data.items() if v}
|
||||
row_content = ''
|
||||
for key, value in row_dict.items():
|
||||
if isinstance(value, dict):
|
||||
value_dict = {k: v for k, v in value.items() if v}
|
||||
value_content = ''.join(f'{k}:{v} ' for k, v in value_dict.items())
|
||||
row_content = row_content + f'{key}:{value_content}\n'
|
||||
else:
|
||||
row_content = row_content + f'{key}:{value}\n'
|
||||
document = Document(page_content=row_content)
|
||||
database_content_list.append(document)
|
||||
|
||||
return database_content_list
|
||||
|
||||
def _get_notion_block_data(self, page_id: str) -> list[str]:
|
||||
result_lines_arr = []
|
||||
cur_block_id = page_id
|
||||
while True:
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
||||
query_dict: dict[str, Any] = {}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
block_url,
|
||||
headers={
|
||||
"Authorization": "Bearer " + self._notion_access_token,
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=query_dict
|
||||
)
|
||||
data = res.json()
|
||||
# current block's heading
|
||||
heading = ''
|
||||
for result in data["results"]:
|
||||
result_type = result["type"]
|
||||
result_obj = result[result_type]
|
||||
cur_result_text_arr = []
|
||||
if result_type == 'table':
|
||||
result_block_id = result["id"]
|
||||
text = self._read_table_rows(result_block_id)
|
||||
text += "\n\n"
|
||||
result_lines_arr.append(text)
|
||||
else:
|
||||
if "rich_text" in result_obj:
|
||||
for rich_text in result_obj["rich_text"]:
|
||||
# skip if doesn't have text object
|
||||
if "text" in rich_text:
|
||||
text = rich_text["text"]["content"]
|
||||
cur_result_text_arr.append(text)
|
||||
if result_type in HEADING_TYPE:
|
||||
heading = text
|
||||
|
||||
result_block_id = result["id"]
|
||||
has_children = result["has_children"]
|
||||
block_type = result["type"]
|
||||
if has_children and block_type != 'child_page':
|
||||
children_text = self._read_block(
|
||||
result_block_id, num_tabs=1
|
||||
)
|
||||
cur_result_text_arr.append(children_text)
|
||||
|
||||
cur_result_text = "\n".join(cur_result_text_arr)
|
||||
cur_result_text += "\n\n"
|
||||
if result_type in HEADING_TYPE:
|
||||
result_lines_arr.append(cur_result_text)
|
||||
else:
|
||||
result_lines_arr.append(f'{heading}\n{cur_result_text}')
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
else:
|
||||
cur_block_id = data["next_cursor"]
|
||||
return result_lines_arr
|
||||
|
||||
def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
|
||||
"""Read a block."""
|
||||
result_lines_arr = []
|
||||
cur_block_id = block_id
|
||||
while True:
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
||||
query_dict: dict[str, Any] = {}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
block_url,
|
||||
headers={
|
||||
"Authorization": "Bearer " + self._notion_access_token,
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=query_dict
|
||||
)
|
||||
data = res.json()
|
||||
if 'results' not in data or data["results"] is None:
|
||||
break
|
||||
heading = ''
|
||||
for result in data["results"]:
|
||||
result_type = result["type"]
|
||||
result_obj = result[result_type]
|
||||
cur_result_text_arr = []
|
||||
if result_type == 'table':
|
||||
result_block_id = result["id"]
|
||||
text = self._read_table_rows(result_block_id)
|
||||
result_lines_arr.append(text)
|
||||
else:
|
||||
if "rich_text" in result_obj:
|
||||
for rich_text in result_obj["rich_text"]:
|
||||
# skip if doesn't have text object
|
||||
if "text" in rich_text:
|
||||
text = rich_text["text"]["content"]
|
||||
prefix = "\t" * num_tabs
|
||||
cur_result_text_arr.append(prefix + text)
|
||||
if result_type in HEADING_TYPE:
|
||||
heading = text
|
||||
result_block_id = result["id"]
|
||||
has_children = result["has_children"]
|
||||
block_type = result["type"]
|
||||
if has_children and block_type != 'child_page':
|
||||
children_text = self._read_block(
|
||||
result_block_id, num_tabs=num_tabs + 1
|
||||
)
|
||||
cur_result_text_arr.append(children_text)
|
||||
|
||||
cur_result_text = "\n".join(cur_result_text_arr)
|
||||
if result_type in HEADING_TYPE:
|
||||
result_lines_arr.append(cur_result_text)
|
||||
else:
|
||||
result_lines_arr.append(f'{heading}\n{cur_result_text}')
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
else:
|
||||
cur_block_id = data["next_cursor"]
|
||||
|
||||
result_lines = "\n".join(result_lines_arr)
|
||||
return result_lines
|
||||
|
||||
def _read_table_rows(self, block_id: str) -> str:
|
||||
"""Read table rows."""
|
||||
done = False
|
||||
result_lines_arr = []
|
||||
cur_block_id = block_id
|
||||
while not done:
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
||||
query_dict: dict[str, Any] = {}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
block_url,
|
||||
headers={
|
||||
"Authorization": "Bearer " + self._notion_access_token,
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=query_dict
|
||||
)
|
||||
data = res.json()
|
||||
# get table headers text
|
||||
table_header_cell_texts = []
|
||||
tabel_header_cells = data["results"][0]['table_row']['cells']
|
||||
for tabel_header_cell in tabel_header_cells:
|
||||
if tabel_header_cell:
|
||||
for table_header_cell_text in tabel_header_cell:
|
||||
text = table_header_cell_text["text"]["content"]
|
||||
table_header_cell_texts.append(text)
|
||||
# get table columns text and format
|
||||
results = data["results"]
|
||||
for i in range(len(results) - 1):
|
||||
column_texts = []
|
||||
tabel_column_cells = data["results"][i + 1]['table_row']['cells']
|
||||
for j in range(len(tabel_column_cells)):
|
||||
if tabel_column_cells[j]:
|
||||
for table_column_cell_text in tabel_column_cells[j]:
|
||||
column_text = table_column_cell_text["text"]["content"]
|
||||
column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
|
||||
|
||||
cur_result_text = "\n".join(column_texts)
|
||||
result_lines_arr.append(cur_result_text)
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
done = True
|
||||
break
|
||||
else:
|
||||
cur_block_id = data["next_cursor"]
|
||||
|
||||
result_lines = "\n".join(result_lines_arr)
|
||||
return result_lines
|
||||
|
||||
def update_last_edited_time(self, document_model: DocumentModel):
|
||||
if not document_model:
|
||||
return
|
||||
|
||||
last_edited_time = self.get_notion_last_edited_time()
|
||||
data_source_info = document_model.data_source_info_dict
|
||||
data_source_info['last_edited_time'] = last_edited_time
|
||||
update_params = {
|
||||
DocumentModel.data_source_info: json.dumps(data_source_info)
|
||||
}
|
||||
|
||||
DocumentModel.query.filter_by(id=document_model.id).update(update_params)
|
||||
db.session.commit()
|
||||
|
||||
def get_notion_last_edited_time(self) -> str:
|
||||
obj_id = self._notion_obj_id
|
||||
page_type = self._notion_page_type
|
||||
if page_type == 'database':
|
||||
retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
|
||||
else:
|
||||
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
|
||||
|
||||
query_dict: dict[str, Any] = {}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
retrieve_page_url,
|
||||
headers={
|
||||
"Authorization": "Bearer " + self._notion_access_token,
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=query_dict
|
||||
)
|
||||
|
||||
data = res.json()
|
||||
return data["last_edited_time"]
|
||||
|
||||
@classmethod
|
||||
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
|
||||
if not data_source_binding:
|
||||
raise Exception(f'No notion data source binding found for tenant {tenant_id} '
|
||||
f'and notion workspace {notion_workspace_id}')
|
||||
|
||||
return data_source_binding.access_token
|
||||
72
api/core/rag/extractor/pdf_extractor.py
Normal file
72
api/core/rag/extractor/pdf_extractor.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from collections.abc import Iterator
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.blod.blod import Blob
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class PdfExtractor(BaseExtractor):
|
||||
"""Load pdf files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
file_cache_key: Optional[str] = None
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._file_cache_key = file_cache_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
plaintext_file_key = ''
|
||||
plaintext_file_exists = False
|
||||
if self._file_cache_key:
|
||||
try:
|
||||
text = storage.load(self._file_cache_key).decode('utf-8')
|
||||
plaintext_file_exists = True
|
||||
return [Document(page_content=text)]
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
documents = list(self.load())
|
||||
text_list = []
|
||||
for document in documents:
|
||||
text_list.append(document.page_content)
|
||||
text = "\n\n".join(text_list)
|
||||
|
||||
# save plaintext file for caching
|
||||
if not plaintext_file_exists and plaintext_file_key:
|
||||
storage.save(plaintext_file_key, text.encode('utf-8'))
|
||||
|
||||
return documents
|
||||
|
||||
def load(
|
||||
self,
|
||||
) -> Iterator[Document]:
|
||||
"""Lazy load given path as pages."""
|
||||
blob = Blob.from_path(self._file_path)
|
||||
yield from self.parse(blob)
|
||||
|
||||
def parse(self, blob: Blob) -> Iterator[Document]:
|
||||
"""Lazily parse the blob."""
|
||||
import pypdfium2
|
||||
|
||||
with blob.as_bytes_io() as file_path:
|
||||
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
|
||||
try:
|
||||
for page_number, page in enumerate(pdf_reader):
|
||||
text_page = page.get_textpage()
|
||||
content = text_page.get_text_range()
|
||||
text_page.close()
|
||||
page.close()
|
||||
metadata = {"source": blob.source, "page": page_number}
|
||||
yield Document(page_content=content, metadata=metadata)
|
||||
finally:
|
||||
pdf_reader.close()
|
||||
50
api/core/rag/extractor/text_extractor.py
Normal file
50
api/core/rag/extractor/text_extractor.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.extractor.helpers import detect_file_encodings
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class TextExtractor(BaseExtractor):
|
||||
"""Load text files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load from file path."""
|
||||
text = ""
|
||||
try:
|
||||
with open(self._file_path, encoding=self._encoding) as f:
|
||||
text = f.read()
|
||||
except UnicodeDecodeError as e:
|
||||
if self._autodetect_encoding:
|
||||
detected_encodings = detect_file_encodings(self._file_path)
|
||||
for encoding in detected_encodings:
|
||||
try:
|
||||
with open(self._file_path, encoding=encoding.encoding) as f:
|
||||
text = f.read()
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"Error loading {self._file_path}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error loading {self._file_path}") from e
|
||||
|
||||
metadata = {"source": self._file_path}
|
||||
return [Document(page_content=text, metadata=metadata)]
|
||||
@@ -0,0 +1,61 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredWordExtractor(BaseExtractor):
|
||||
"""Loader that uses unstructured to load word documents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.__version__ import __version__ as __unstructured_version__
|
||||
from unstructured.file_utils.filetype import FileType, detect_filetype
|
||||
|
||||
unstructured_version = tuple(
|
||||
[int(x) for x in __unstructured_version__.split(".")]
|
||||
)
|
||||
# check the file extension
|
||||
try:
|
||||
import magic # noqa: F401
|
||||
|
||||
is_doc = detect_filetype(self._file_path) == FileType.DOC
|
||||
except ImportError:
|
||||
_, extension = os.path.splitext(str(self._file_path))
|
||||
is_doc = extension == ".doc"
|
||||
|
||||
if is_doc and unstructured_version < (0, 4, 11):
|
||||
raise ValueError(
|
||||
f"You are on unstructured version {__unstructured_version__}. "
|
||||
"Partitioning .doc files is only supported in unstructured>=0.4.11. "
|
||||
"Please upgrade the unstructured package and try again."
|
||||
)
|
||||
|
||||
if is_doc:
|
||||
from unstructured.partition.doc import partition_doc
|
||||
|
||||
elements = partition_doc(filename=self._file_path)
|
||||
else:
|
||||
from unstructured.partition.docx import partition_docx
|
||||
|
||||
elements = partition_docx(filename=self._file_path)
|
||||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
return documents
|
||||
@@ -0,0 +1,51 @@
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredEmailExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.email import partition_email
|
||||
elements = partition_email(filename=self._file_path, api_url=self._api_url)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
for element in elements:
|
||||
element_text = element.text.strip()
|
||||
|
||||
padding_needed = 4 - len(element_text) % 4
|
||||
element_text += '=' * padding_needed
|
||||
|
||||
element_decode = base64.b64decode(element_text)
|
||||
soup = BeautifulSoup(element_decode.decode('utf-8'), 'html.parser')
|
||||
element.text = soup.get_text()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
return documents
|
||||
@@ -0,0 +1,47 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredMarkdownExtractor(BaseExtractor):
|
||||
"""Load md files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
|
||||
remove_hyperlinks: Whether to remove hyperlinks from the text.
|
||||
|
||||
remove_images: Whether to remove images from the text.
|
||||
|
||||
encoding: File encoding to use. If `None`, the file will be loaded
|
||||
with the default system encoding.
|
||||
|
||||
autodetect_encoding: Whether to try to autodetect the file encoding
|
||||
if the specified encoding fails.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.md import partition_md
|
||||
|
||||
elements = partition_md(filename=self._file_path, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@@ -0,0 +1,37 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredMsgExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.msg import partition_msg
|
||||
|
||||
elements = partition_msg(filename=self._file_path, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredPPTExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
|
||||
elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
|
||||
text_by_page = {}
|
||||
for element in elements:
|
||||
page = element.metadata.page_number
|
||||
text = element.text
|
||||
if page in text_by_page:
|
||||
text_by_page[page] += "\n" + text
|
||||
else:
|
||||
text_by_page[page] = text
|
||||
|
||||
combined_texts = list(text_by_page.values())
|
||||
documents = []
|
||||
for combined_text in combined_texts:
|
||||
text = combined_text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
return documents
|
||||
@@ -0,0 +1,45 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredPPTXExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
|
||||
text_by_page = {}
|
||||
for element in elements:
|
||||
page = element.metadata.page_number
|
||||
text = element.text
|
||||
if page in text_by_page:
|
||||
text_by_page[page] += "\n" + text
|
||||
else:
|
||||
text_by_page[page] = text
|
||||
|
||||
combined_texts = list(text_by_page.values())
|
||||
documents = []
|
||||
for combined_text in combined_texts:
|
||||
text = combined_text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@@ -0,0 +1,37 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredTextExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.text import partition_text
|
||||
|
||||
elements = partition_text(filename=self._file_path, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@@ -0,0 +1,37 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredXmlExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.xml import partition_xml
|
||||
|
||||
elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
62
api/core/rag/extractor/word_extractor.py
Normal file
62
api/core/rag/extractor/word_extractor.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
import os
|
||||
import tempfile
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class WordExtractor(BaseExtractor):
|
||||
"""Load pdf files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str):
|
||||
"""Initialize with file path."""
|
||||
self.file_path = file_path
|
||||
if "~" in self.file_path:
|
||||
self.file_path = os.path.expanduser(self.file_path)
|
||||
|
||||
# If the file is a web path, download it to a temporary file, and use that
|
||||
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
|
||||
r = requests.get(self.file_path)
|
||||
|
||||
if r.status_code != 200:
|
||||
raise ValueError(
|
||||
"Check the url of your file; returned status code %s"
|
||||
% r.status_code
|
||||
)
|
||||
|
||||
self.web_path = self.file_path
|
||||
self.temp_file = tempfile.NamedTemporaryFile()
|
||||
self.temp_file.write(r.content)
|
||||
self.file_path = self.temp_file.name
|
||||
elif not os.path.isfile(self.file_path):
|
||||
raise ValueError("File path %s is not a valid file or url" % self.file_path)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, "temp_file"):
|
||||
self.temp_file.close()
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load given path as single page."""
|
||||
import docx2txt
|
||||
|
||||
return [
|
||||
Document(
|
||||
page_content=docx2txt.process(self.file_path),
|
||||
metadata={"source": self.file_path},
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_url(url: str) -> bool:
|
||||
"""Check if the url is valid."""
|
||||
parsed = urlparse(url)
|
||||
return bool(parsed.netloc) and bool(parsed.scheme)
|
||||
Reference in New Issue
Block a user