Mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git (synced 2025-12-09 19:06:51 +08:00)
feat: mypy for all type check (#10921)
@@ -14,7 +14,7 @@ class NotionInfo(BaseModel):
     notion_workspace_id: str
     notion_obj_id: str
     notion_page_type: str
-    document: Document = None
+    document: Optional[Document] = None
     tenant_id: str
     model_config = ConfigDict(arbitrary_types_allowed=True)

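The hunk above is the standard fix for an implicit-Optional default: under mypy's default no-implicit-optional behaviour, `document: Document = None` is rejected because None is not a Document. A minimal, self-contained sketch of the pattern; the toy Document model below is illustrative, not the project's class:

from typing import Optional

from pydantic import BaseModel, ConfigDict


class Document(BaseModel):  # stand-in for the project's Document model
    page_content: str = ""


class NotionInfo(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    notion_workspace_id: str
    notion_obj_id: str
    notion_page_type: str
    # "document: Document = None" is flagged because the default None does not
    # match the annotated type; Optional[...] makes the missing value explicit.
    document: Optional[Document] = None
    tenant_id: str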
@@ -1,7 +1,7 @@
 """Abstract interface for document loader implementations."""

 import os
-from typing import Optional
+from typing import Optional, cast

 import pandas as pd
 from openpyxl import load_workbook
@@ -47,7 +47,7 @@ class ExcelExtractor(BaseExtractor):
                 for col_index, (k, v) in enumerate(row.items()):
                     if pd.notna(v):
                         cell = sheet.cell(
-                            row=index + 2, column=col_index + 1
+                            row=cast(int, index) + 2, column=col_index + 1
                         )  # +2 to account for header and 1-based index
                         if cell.hyperlink:
                             value = f"[{v}]({cell.hyperlink.target})"
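The `cast(int, index)` above is needed because `DataFrame.iterrows()` is typed as yielding a Hashable index label, so arithmetic on it fails the check even when the label is an integer at runtime. A small sketch of the same pattern, assuming pandas (and its stubs) are installed; the DataFrame is illustrative:

from typing import cast

import pandas as pd

df = pd.DataFrame({"name": ["a", "b"], "link": ["x", "y"]})

for index, row in df.iterrows():
    # iterrows() yields the label as Hashable; cast() is a runtime no-op that
    # only narrows the type for the checker so "+ 2" type-checks.
    excel_row = cast(int, index) + 2  # +2 for the header row and 1-based indexing
    print(excel_row, row["name"])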
@@ -60,8 +60,8 @@ class ExcelExtractor(BaseExtractor):

         elif file_extension == ".xls":
             excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
-            for sheet_name in excel_file.sheet_names:
-                df = excel_file.parse(sheet_name=sheet_name)
+            for excel_sheet_name in excel_file.sheet_names:
+                df = excel_file.parse(sheet_name=excel_sheet_name)
                 df.dropna(how="all", inplace=True)

                 for _, row in df.iterrows():
@@ -10,6 +10,7 @@ from core.rag.extractor.csv_extractor import CSVExtractor
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.excel_extractor import ExcelExtractor
+from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
 from core.rag.extractor.html_extractor import HtmlExtractor
 from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor
@@ -66,9 +67,13 @@ class ExtractProcessor:
                 filename_match = re.search(r'filename="([^"]+)"', content_disposition)
                 if filename_match:
                     filename = unquote(filename_match.group(1))
-                    suffix = "." + re.search(r"\.(\w+)$", filename).group(1)
-
-            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+                    match = re.search(r"\.(\w+)$", filename)
+                    if match:
+                        suffix = "." + match.group(1)
+                    else:
+                        suffix = ""
+            # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
+            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
             Path(file_path).write_bytes(response.content)
             extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
             if return_text:
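`re.search()` returns `Optional[re.Match[str]]`, so chaining `.group(1)` directly is a mypy error (union-attr). The hunk replaces the one-liner with an explicit None check and a fallback suffix; the same pattern in isolation, with an illustrative filename:

import re

filename = "report.final.pdf"  # illustrative value

match = re.search(r"\.(\w+)$", filename)
if match:
    suffix = "." + match.group(1)  # mypy knows match is not None here
else:
    suffix = ""                    # fallback when the filename has no extension
print(suffix)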
@@ -89,15 +94,20 @@ class ExtractProcessor:
         if extract_setting.datasource_type == DatasourceType.FILE.value:
             with tempfile.TemporaryDirectory() as temp_dir:
                 if not file_path:
+                    assert extract_setting.upload_file is not None, "upload_file is required"
                     upload_file: UploadFile = extract_setting.upload_file
                     suffix = Path(upload_file.key).suffix
-                    file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+                    # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
+                    file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
                     storage.download(upload_file.key, file_path)
                 input_file = Path(file_path)
                 file_extension = input_file.suffix.lower()
                 etl_type = dify_config.ETL_TYPE
                 unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
                 unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
+                assert unstructured_api_url is not None, "unstructured_api_url is required"
+                assert unstructured_api_key is not None, "unstructured_api_key is required"
+                extractor: Optional[BaseExtractor] = None
                 if etl_type == "Unstructured":
                     if file_extension in {".xlsx", ".xls"}:
                         extractor = ExcelExtractor(file_path)
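The added `assert ... is not None` lines are type-narrowing guards: the upload file and the Unstructured settings are Optional in their models, and mypy will not allow attribute access on them, or passing them where a non-None value is required, until the None case is ruled out. A toy sketch of the idiom (names are illustrative, not the project's):

from typing import Optional


def build_storage_key(upload_file_key: Optional[str]) -> str:
    # assert narrows Optional[str] to str for the rest of the function; at runtime
    # it fails fast with a clear message instead of a later AttributeError.
    assert upload_file_key is not None, "upload_file is required"
    return upload_file_key.lower()


print(build_storage_key("IMAGE_FILES/ABC.PNG"))

The `extractor: Optional[BaseExtractor] = None` pre-declaration serves the same goal: it gives the variable one stable declared type across the many branches that assign different extractor subclasses.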
@@ -156,6 +166,7 @@ class ExtractProcessor:
                         extractor = TextExtractor(file_path, autodetect_encoding=True)
                 return extractor.extract()
         elif extract_setting.datasource_type == DatasourceType.NOTION.value:
+            assert extract_setting.notion_info is not None, "notion_info is required"
             extractor = NotionExtractor(
                 notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
                 notion_obj_id=extract_setting.notion_info.notion_obj_id,
@@ -165,6 +176,7 @@ class ExtractProcessor:
             )
             return extractor.extract()
         elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
+            assert extract_setting.website_info is not None, "website_info is required"
             if extract_setting.website_info.provider == "firecrawl":
                 extractor = FirecrawlWebExtractor(
                     url=extract_setting.website_info.url,
@@ -1,5 +1,6 @@
 import json
 import time
+from typing import cast

 import requests

@@ -20,9 +21,9 @@ class FirecrawlApp:
         json_data.update(params)
         response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data)
         if response.status_code == 200:
-            response = response.json()
-            if response["success"] == True:
-                data = response["data"]
+            response_data = response.json()
+            if response_data["success"] == True:
+                data = response_data["data"]
                 return {
                     "title": data.get("metadata").get("title"),
                     "description": data.get("metadata").get("description"),
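Renaming the reassigned variable matters because mypy pins `response` to the type of its first assignment (`requests.Response`); once the parsed JSON is stored under the same name, indexing it like a dict is reported against the Response type. A sketch of the pattern, with an illustrative function name:

import requests


def scrape(url: str) -> dict:
    response = requests.post(url, json={})
    # Keeping the Response object and the parsed body in separate names gives
    # each variable a single, consistent type for the checker.
    response_data = response.json()
    if response.status_code == 200 and response_data.get("success"):
        return {"data": response_data.get("data")}
    raise RuntimeError(response_data.get("error", "Unknown error occurred"))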
@@ -30,7 +31,7 @@ class FirecrawlApp:
                     "markdown": data.get("markdown"),
                 }
             else:
-                raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
+                raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}')

         elif response.status_code in {402, 409, 500}:
             error_message = response.json().get("error", "Unknown error occurred")
@@ -46,9 +47,11 @@ class FirecrawlApp:
         response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers)
         if response.status_code == 200:
             job_id = response.json().get("jobId")
-            return job_id
+            return cast(str, job_id)
         else:
             self._handle_error(response, "start crawl job")
+            # FIXME: unreachable code for mypy
+            return ""  # unreachable

     def check_crawl_status(self, job_id) -> dict:
         headers = self._prepare_headers()
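The `return ""  # unreachable` sentinel (and the `return {}` one in a later hunk) works around mypy's missing-return error: the checker cannot tell that `_handle_error()` always raises. A hedged alternative, not what this commit does, is to annotate the helper with `typing.NoReturn` so the branch is provably terminal; the names below are simplified stand-ins:

from typing import NoReturn


def _handle_error(action: str) -> NoReturn:
    """Always raises, and advertises that fact to the type checker."""
    raise RuntimeError(f"Failed to {action}")


def crawl_url(ok: bool) -> str:
    if ok:
        return "job-123"
    # With NoReturn on the helper, mypy treats everything after this call as
    # unreachable, so no placeholder return value is needed.
    _handle_error("start crawl job")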
@@ -64,9 +67,9 @@ class FirecrawlApp:
                 for item in data:
                     if isinstance(item, dict) and "metadata" in item and "markdown" in item:
                         url_data = {
-                            "title": item.get("metadata").get("title"),
-                            "description": item.get("metadata").get("description"),
-                            "source_url": item.get("metadata").get("sourceURL"),
+                            "title": item.get("metadata", {}).get("title"),
+                            "description": item.get("metadata", {}).get("description"),
+                            "source_url": item.get("metadata", {}).get("sourceURL"),
                             "markdown": item.get("markdown"),
                         }
                         url_data_list.append(url_data)
@@ -92,6 +95,8 @@ class FirecrawlApp:

         else:
             self._handle_error(response, "check crawl status")
+            # FIXME: unreachable code for mypy
+            return {}  # unreachable

     def _prepare_headers(self):
         return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
@@ -1,6 +1,6 @@
 """Abstract interface for document loader implementations."""

-from bs4 import BeautifulSoup
+from bs4 import BeautifulSoup  # type: ignore

 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
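beautifulsoup4 has historically shipped without a `py.typed` marker, so a strict mypy run reports the import as missing stubs; the per-import `# type: ignore` above is the narrowest way to silence that (installing the separate `types-beautifulsoup4` stubs would be the stricter alternative). A minimal usage sketch, assuming beautifulsoup4 is installed:

from bs4 import BeautifulSoup  # type: ignore

# With the import ignored, BeautifulSoup is treated as Any, so calls on it are
# unchecked; the rest of the module still benefits from type checking.
soup = BeautifulSoup("<p>hello world</p>", "html.parser")
text: str = soup.get_text()
print(text)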
@@ -23,6 +23,7 @@ class HtmlExtractor(BaseExtractor):
         return [Document(page_content=self._load_as_text())]

     def _load_as_text(self) -> str:
+        text: str = ""
         with open(self._file_path, "rb") as fp:
             soup = BeautifulSoup(fp, "html.parser")
             text = soup.get_text()
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, Optional
+from typing import Any, Optional, cast

 import requests

@@ -78,6 +78,7 @@ class NotionExtractor(BaseExtractor):

     def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]:
         """Get all the pages from a Notion database."""
+        assert self._notion_access_token is not None, "Notion access token is required"
         res = requests.post(
             DATABASE_URL_TMPL.format(database_id=database_id),
             headers={
@@ -96,6 +97,7 @@ class NotionExtractor(BaseExtractor):
         for result in data["results"]:
             properties = result["properties"]
             data = {}
+            value: Any
             for property_name, property_value in properties.items():
                 type = property_value["type"]
                 if type == "multi_select":
@@ -130,6 +132,7 @@ class NotionExtractor(BaseExtractor):
         return [Document(page_content="\n".join(database_content))]

     def _get_notion_block_data(self, page_id: str) -> list[str]:
+        assert self._notion_access_token is not None, "Notion access token is required"
         result_lines_arr = []
         start_cursor = None
         block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id)
@@ -184,6 +187,7 @@ class NotionExtractor(BaseExtractor):

     def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
         """Read a block."""
+        assert self._notion_access_token is not None, "Notion access token is required"
         result_lines_arr = []
         start_cursor = None
         block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
@@ -242,6 +246,7 @@ class NotionExtractor(BaseExtractor):

     def _read_table_rows(self, block_id: str) -> str:
         """Read table rows."""
+        assert self._notion_access_token is not None, "Notion access token is required"
         done = False
         result_lines_arr = []
         start_cursor = None
@@ -296,7 +301,7 @@ class NotionExtractor(BaseExtractor):
         result_lines = "\n".join(result_lines_arr)
         return result_lines

-    def update_last_edited_time(self, document_model: DocumentModel):
+    def update_last_edited_time(self, document_model: Optional[DocumentModel]):
         if not document_model:
             return

@@ -309,6 +314,7 @@ class NotionExtractor(BaseExtractor):
         db.session.commit()

     def get_notion_last_edited_time(self) -> str:
+        assert self._notion_access_token is not None, "Notion access token is required"
         obj_id = self._notion_obj_id
         page_type = self._notion_page_type
         if page_type == "database":
@@ -330,7 +336,7 @@ class NotionExtractor(BaseExtractor):
         )

         data = res.json()
-        return data["last_edited_time"]
+        return cast(str, data["last_edited_time"])

     @classmethod
     def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
@@ -349,4 +355,4 @@ class NotionExtractor(BaseExtractor):
                 f"and notion workspace {notion_workspace_id}"
             )

-        return data_source_binding.access_token
+        return cast(str, data_source_binding.access_token)
@@ -1,7 +1,7 @@
 """Abstract interface for document loader implementations."""

 from collections.abc import Iterator
-from typing import Optional
+from typing import Optional, cast

 from core.rag.extractor.blob.blob import Blob
 from core.rag.extractor.extractor_base import BaseExtractor
@@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
         plaintext_file_exists = False
         if self._file_cache_key:
             try:
-                text = storage.load(self._file_cache_key).decode("utf-8")
+                text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
                 plaintext_file_exists = True
                 return [Document(page_content=text)]
             except FileNotFoundError:
@@ -53,7 +53,7 @@ class PdfExtractor(BaseExtractor):

     def parse(self, blob: Blob) -> Iterator[Document]:
         """Lazily parse the blob."""
-        import pypdfium2
+        import pypdfium2  # type: ignore

         with blob.as_bytes_io() as file_path:
             pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
@@ -1,7 +1,7 @@
 import base64
 import logging

-from bs4 import BeautifulSoup
+from bs4 import BeautifulSoup  # type: ignore

 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
@@ -30,6 +30,9 @@ class UnstructuredEpubExtractor(BaseExtractor):
         if self._api_url:
             from unstructured.partition.api import partition_via_api

+            if self._api_key is None:
+                raise ValueError("api_key is required")
+
             elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
         else:
             from unstructured.partition.epub import partition_epub
@@ -27,9 +27,11 @@ class UnstructuredPPTExtractor(BaseExtractor):
             elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
         else:
             raise NotImplementedError("Unstructured API Url is not configured")
-        text_by_page = {}
+        text_by_page: dict[int, str] = {}
         for element in elements:
             page = element.metadata.page_number
+            if page is None:
+                continue
             text = element.text
             if page in text_by_page:
                 text_by_page[page] += "\n" + text
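Two typing issues are fixed in this hunk: an empty dict literal gets an explicit `dict[int, str]` annotation (strict mypy otherwise reports that it needs a type annotation), and `element.metadata.page_number`, which is Optional, is skipped when it is None before being used as a key. The same shape in isolation, with illustrative data:

from typing import Optional

# Without the annotation an empty dict gives mypy nothing to infer from, and
# the later indexed assignments are flagged.
text_by_page: dict[int, str] = {}

elements: list[tuple[Optional[int], str]] = [(1, "intro"), (None, "orphan"), (1, "more")]
for page, text in elements:
    if page is None:  # page_number can be None; unpaged elements are skipped
        continue
    if page in text_by_page:
        text_by_page[page] += "\n" + text
    else:
        text_by_page[page] = text

print(text_by_page)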
@@ -29,14 +29,15 @@ class UnstructuredPPTXExtractor(BaseExtractor):
             from unstructured.partition.pptx import partition_pptx

             elements = partition_pptx(filename=self._file_path)
-        text_by_page = {}
+        text_by_page: dict[int, str] = {}
         for element in elements:
             page = element.metadata.page_number
             text = element.text
-            if page in text_by_page:
-                text_by_page[page] += "\n" + text
-            else:
-                text_by_page[page] = text
+            if page is not None:
+                if page in text_by_page:
+                    text_by_page[page] += "\n" + text
+                else:
+                    text_by_page[page] = text

         combined_texts = list(text_by_page.values())
         documents = []
@@ -89,6 +89,8 @@ class WordExtractor(BaseExtractor):
                 response = ssrf_proxy.get(url)
                 if response.status_code == 200:
                     image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
+                    if image_ext is None:
+                        continue
                     file_uuid = str(uuid.uuid4())
                     file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
                     mime_type, _ = mimetypes.guess_type(file_key)
@@ -97,6 +99,8 @@ class WordExtractor(BaseExtractor):
                     continue
             else:
                 image_ext = rel.target_ref.split(".")[-1]
+                if image_ext is None:
+                    continue
                 # user uuid as file name
                 file_uuid = str(uuid.uuid4())
                 file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
@@ -226,6 +230,8 @@ class WordExtractor(BaseExtractor):
                     if x_child is None:
                         continue
                     if x.tag.endswith("instrText"):
+                        if x.text is None:
+                            continue
                         for i in url_pattern.findall(x.text):
                             hyperlinks_url = str(i)
                 except Exception as e:
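The `if ... is None: continue` guards added to WordExtractor follow the same narrowing idea: `mimetypes.guess_extension()` and an XML element's `.text` are both Optional, and using them in string concatenation or regex matching without a check fails strict mypy. A small sketch of the guess_extension case; the error fallback is illustrative, since the real code simply skips the image:

import mimetypes
import uuid

content_type = "image/png"  # illustrative header value

# guess_extension() returns Optional[str]; rule out None before concatenating.
image_ext = mimetypes.guess_extension(content_type)
if image_ext is None:
    raise SystemExit("unsupported content type, skipping")

file_key = "image_files/tenant-1/" + str(uuid.uuid4()) + image_ext
print(file_key)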