feat: enable mypy type checking for the whole codebase (#10921)

Author: yihong
Date: 2024-12-24 18:38:51 +08:00
Committed by: GitHub
Parent: c91e8b1737
Commit: 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -14,7 +14,7 @@ class NotionInfo(BaseModel):
notion_workspace_id: str
notion_obj_id: str
notion_page_type: str
document: Document = None
document: Optional[Document] = None
tenant_id: str
model_config = ConfigDict(arbitrary_types_allowed=True)
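This hunk shows the commit's most common fix: a field annotated `Document` but defaulting to `None`, which mypy rejects as an incompatible assignment. A minimal, self-contained sketch of the pattern, assuming pydantic v2; `Item` and `Payload` are invented names, not from the repo:

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict


class Payload:
    """Stand-in for an arbitrary non-pydantic class such as Document."""

    def __init__(self, text: str) -> None:
        self.text = text


class Item(BaseModel):
    # `payload: Payload = None` would fail mypy: None is not a valid Payload.
    payload: Optional[Payload] = None
    tenant_id: str
    # Needed because Payload is neither a pydantic model nor a standard type.
    model_config = ConfigDict(arbitrary_types_allowed=True)


item = Item(tenant_id="t1")                       # payload defaults to None
item2 = Item(tenant_id="t2", payload=Payload("hello"))
```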

View File

@@ -1,7 +1,7 @@
"""Abstract interface for document loader implementations."""
import os
from typing import Optional
from typing import Optional, cast
import pandas as pd
from openpyxl import load_workbook
@@ -47,7 +47,7 @@ class ExcelExtractor(BaseExtractor):
for col_index, (k, v) in enumerate(row.items()):
if pd.notna(v):
cell = sheet.cell(
row=index + 2, column=col_index + 1
row=cast(int, index) + 2, column=col_index + 1
) # +2 to account for header and 1-based index
if cell.hyperlink:
value = f"[{v}]({cell.hyperlink.target})"
@@ -60,8 +60,8 @@ class ExcelExtractor(BaseExtractor):
elif file_extension == ".xls":
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
for sheet_name in excel_file.sheet_names:
df = excel_file.parse(sheet_name=sheet_name)
for excel_sheet_name in excel_file.sheet_names:
df = excel_file.parse(sheet_name=excel_sheet_name)
df.dropna(how="all", inplace=True)
for _, row in df.iterrows():
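For reference, pandas types `DataFrame.iterrows()` as yielding a `Hashable` index, so `index + 2` does not type-check even when the index is numerically an integer; `cast(int, ...)` is the workaround the hunk uses. A rough sketch with illustrative data:

```python
from typing import cast

import pandas as pd

df = pd.DataFrame({"name": ["a", "b"], "value": [1, 2]})

for index, row in df.iterrows():
    # iterrows() yields (Hashable, Series); cast tells mypy the default
    # RangeIndex value really is an int so the arithmetic type-checks.
    excel_row = cast(int, index) + 2  # +2 for the header row and 1-based indexing
    print(excel_row, row["name"], row["value"])
```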

View File

@@ -10,6 +10,7 @@ from core.rag.extractor.csv_extractor import CSVExtractor
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.excel_extractor import ExcelExtractor
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
from core.rag.extractor.html_extractor import HtmlExtractor
from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor
@@ -66,9 +67,13 @@ class ExtractProcessor:
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
suffix = "." + re.search(r"\.(\w+)$", filename).group(1)
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
match = re.search(r"\.(\w+)$", filename)
if match:
suffix = "." + match.group(1)
else:
suffix = ""
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
if return_text:
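`re.search` is annotated as returning `Optional[Match[str]]`, so calling `.group()` on the result directly fails under mypy; the hunk binds the match and branches on it first. A small standalone sketch of that guard (the helper name is invented):

```python
import re


def file_suffix(filename: str) -> str:
    """Return the extension (with leading dot) or '' when there is none."""
    match = re.search(r"\.(\w+)$", filename)
    if match:  # narrows Optional[Match[str]] -> Match[str]
        return "." + match.group(1)
    return ""


print(file_suffix("report.pdf"))  # ".pdf"
print(file_suffix("README"))      # ""
```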
@@ -89,15 +94,20 @@ class ExtractProcessor:
if extract_setting.datasource_type == DatasourceType.FILE.value:
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
assert extract_setting.upload_file is not None, "upload_file is required"
upload_file: UploadFile = extract_setting.upload_file
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
etl_type = dify_config.ETL_TYPE
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
assert unstructured_api_url is not None, "unstructured_api_url is required"
assert unstructured_api_key is not None, "unstructured_api_key is required"
extractor: Optional[BaseExtractor] = None
if etl_type == "Unstructured":
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
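The added `assert ... is not None` lines are type narrowing: config attributes declared `Optional[str]` are narrowed to `str` before being used where a plain `str` is expected. A minimal sketch of the idea; the `Settings` class here is illustrative, not Dify's config object:

```python
from typing import Optional


class Settings:
    """Illustrative stand-in for a config object with optional values."""

    UNSTRUCTURED_API_URL: Optional[str] = None
    UNSTRUCTURED_API_KEY: Optional[str] = None


def build_endpoint(settings: Settings) -> str:
    url = settings.UNSTRUCTURED_API_URL
    key = settings.UNSTRUCTURED_API_KEY
    # Without these asserts mypy reports an incompatible "Optional[str]"
    # where "str" is expected at the f-string return below.
    assert url is not None, "unstructured_api_url is required"
    assert key is not None, "unstructured_api_key is required"
    return f"{url}?api_key={key}"  # both narrowed to str here


settings = Settings()
settings.UNSTRUCTURED_API_URL = "https://api.example.com"
settings.UNSTRUCTURED_API_KEY = "secret"
print(build_endpoint(settings))
```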
@@ -156,6 +166,7 @@ class ExtractProcessor:
extractor = TextExtractor(file_path, autodetect_encoding=True)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.NOTION.value:
assert extract_setting.notion_info is not None, "notion_info is required"
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
notion_obj_id=extract_setting.notion_info.notion_obj_id,
@@ -165,6 +176,7 @@ class ExtractProcessor:
)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
assert extract_setting.website_info is not None, "website_info is required"
if extract_setting.website_info.provider == "firecrawl":
extractor = FirecrawlWebExtractor(
url=extract_setting.website_info.url,
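Pre-declaring `extractor: Optional[BaseExtractor] = None` gives the long if/elif chain a single declared type, so mypy does not pin the variable to whichever concrete extractor happens to be assigned first. A reduced sketch with invented parser classes:

```python
from abc import ABC, abstractmethod
from typing import Optional


class BaseParser(ABC):
    @abstractmethod
    def parse(self) -> str: ...


class CsvParser(BaseParser):
    def parse(self) -> str:
        return "csv"


class PdfParser(BaseParser):
    def parse(self) -> str:
        return "pdf"


def pick_parser(extension: str) -> str:
    # Without the annotation, mypy would infer `CsvParser` from the first
    # assignment and reject the later `PdfParser` assignment.
    parser: Optional[BaseParser] = None
    if extension == ".csv":
        parser = CsvParser()
    elif extension == ".pdf":
        parser = PdfParser()
    if parser is None:
        raise ValueError(f"unsupported extension: {extension}")
    return parser.parse()


print(pick_parser(".pdf"))
```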

View File

@@ -1,5 +1,6 @@
import json
import time
from typing import cast
import requests
@@ -20,9 +21,9 @@ class FirecrawlApp:
json_data.update(params)
response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data)
if response.status_code == 200:
response = response.json()
if response["success"] == True:
data = response["data"]
response_data = response.json()
if response_data["success"] == True:
data = response_data["data"]
return {
"title": data.get("metadata").get("title"),
"description": data.get("metadata").get("description"),
@@ -30,7 +31,7 @@ class FirecrawlApp:
"markdown": data.get("markdown"),
}
else:
raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}')
elif response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
@@ -46,9 +47,11 @@ class FirecrawlApp:
response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers)
if response.status_code == 200:
job_id = response.json().get("jobId")
return job_id
return cast(str, job_id)
else:
self._handle_error(response, "start crawl job")
# FIXME: unreachable code for mypy
return "" # unreachable
def check_crawl_status(self, job_id) -> dict:
headers = self._prepare_headers()
@@ -64,9 +67,9 @@ class FirecrawlApp:
for item in data:
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
url_data = {
"title": item.get("metadata").get("title"),
"description": item.get("metadata").get("description"),
"source_url": item.get("metadata").get("sourceURL"),
"title": item.get("metadata", {}).get("title"),
"description": item.get("metadata", {}).get("description"),
"source_url": item.get("metadata", {}).get("sourceURL"),
"markdown": item.get("markdown"),
}
url_data_list.append(url_data)
@@ -92,6 +95,8 @@ class FirecrawlApp:
else:
self._handle_error(response, "check crawl status")
# FIXME: unreachable code for mypy
return {} # unreachable
def _prepare_headers(self):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

View File

@@ -1,6 +1,6 @@
"""Abstract interface for document loader implementations."""
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup # type: ignore
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
@@ -23,6 +23,7 @@ class HtmlExtractor(BaseExtractor):
return [Document(page_content=self._load_as_text())]
def _load_as_text(self) -> str:
text: str = ""
with open(self._file_path, "rb") as fp:
soup = BeautifulSoup(fp, "html.parser")
text = soup.get_text()
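`beautifulsoup4` ships no inline type information, so strict mypy flags the bare import unless the `types-beautifulsoup4` stubs are installed; `# type: ignore` silences just that line. A tiny sketch:

```python
from bs4 import BeautifulSoup  # type: ignore

# Alternative to the ignore comment: install the `types-beautifulsoup4`
# stubs and drop it, so the import is fully type-checked.

html = "<html><body><p>Hello, world.</p></body></html>"
soup = BeautifulSoup(html, "html.parser")
text: str = soup.get_text()
print(text)
```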

View File

@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Optional
from typing import Any, Optional, cast
import requests
@@ -78,6 +78,7 @@ class NotionExtractor(BaseExtractor):
def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]:
"""Get all the pages from a Notion database."""
assert self._notion_access_token is not None, "Notion access token is required"
res = requests.post(
DATABASE_URL_TMPL.format(database_id=database_id),
headers={
@@ -96,6 +97,7 @@ class NotionExtractor(BaseExtractor):
for result in data["results"]:
properties = result["properties"]
data = {}
value: Any
for property_name, property_value in properties.items():
type = property_value["type"]
if type == "multi_select":
@@ -130,6 +132,7 @@ class NotionExtractor(BaseExtractor):
return [Document(page_content="\n".join(database_content))]
def _get_notion_block_data(self, page_id: str) -> list[str]:
assert self._notion_access_token is not None, "Notion access token is required"
result_lines_arr = []
start_cursor = None
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id)
@@ -184,6 +187,7 @@ class NotionExtractor(BaseExtractor):
def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
"""Read a block."""
assert self._notion_access_token is not None, "Notion access token is required"
result_lines_arr = []
start_cursor = None
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
@@ -242,6 +246,7 @@ class NotionExtractor(BaseExtractor):
def _read_table_rows(self, block_id: str) -> str:
"""Read table rows."""
assert self._notion_access_token is not None, "Notion access token is required"
done = False
result_lines_arr = []
start_cursor = None
@@ -296,7 +301,7 @@ class NotionExtractor(BaseExtractor):
result_lines = "\n".join(result_lines_arr)
return result_lines
def update_last_edited_time(self, document_model: DocumentModel):
def update_last_edited_time(self, document_model: Optional[DocumentModel]):
if not document_model:
return
@@ -309,6 +314,7 @@ class NotionExtractor(BaseExtractor):
db.session.commit()
def get_notion_last_edited_time(self) -> str:
assert self._notion_access_token is not None, "Notion access token is required"
obj_id = self._notion_obj_id
page_type = self._notion_page_type
if page_type == "database":
@@ -330,7 +336,7 @@ class NotionExtractor(BaseExtractor):
)
data = res.json()
return data["last_edited_time"]
return cast(str, data["last_edited_time"])
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
@@ -349,4 +355,4 @@ class NotionExtractor(BaseExtractor):
f"and notion workspace {notion_workspace_id}"
)
return data_source_binding.access_token
return cast(str, data_source_binding.access_token)
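`cast` is a no-op at runtime; here it only tells mypy that a value pulled from a JSON payload (or an untyped ORM column such as `access_token`) really is a `str`, so a function declared `-> str` passes under settings that warn on returning `Any`. A reduced sketch:

```python
from typing import Any, cast


def last_edited_time(payload: dict[str, Any]) -> str:
    # payload["last_edited_time"] is typed Any; cast() changes nothing at
    # runtime but lets mypy accept it as the declared `str` return type.
    return cast(str, payload["last_edited_time"])


print(last_edited_time({"last_edited_time": "2024-12-24T10:38:51.000Z"}))
```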

View File

@@ -1,7 +1,7 @@
"""Abstract interface for document loader implementations."""
from collections.abc import Iterator
from typing import Optional
from typing import Optional, cast
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
@@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
plaintext_file_exists = False
if self._file_cache_key:
try:
text = storage.load(self._file_cache_key).decode("utf-8")
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
@@ -53,7 +53,7 @@ class PdfExtractor(BaseExtractor):
def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
import pypdfium2
import pypdfium2 # type: ignore
with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)

View File

@@ -1,7 +1,7 @@
import base64
import logging
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup # type: ignore
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

View File

@@ -30,6 +30,9 @@ class UnstructuredEpubExtractor(BaseExtractor):
if self._api_url:
from unstructured.partition.api import partition_via_api
if self._api_key is None:
raise ValueError("api_key is required")
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
from unstructured.partition.epub import partition_epub
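Because the API key is user-supplied configuration rather than an internal invariant, the commit raises `ValueError` instead of asserting: an explicit raise narrows `Optional[str]` to `str` for mypy and still fails loudly under `python -O`, which strips asserts. A minimal sketch; `partition_stub` is an invented placeholder, not the unstructured API:

```python
from typing import Optional


def partition_stub(filename: str, api_url: str, api_key: str) -> list[str]:
    """Invented placeholder with the same required-str signature."""
    return [f"{filename} via {api_url}"]


def extract(file_path: str, api_url: str, api_key: Optional[str]) -> list[str]:
    if api_key is None:
        # Narrows Optional[str] -> str and gives a clear runtime error.
        raise ValueError("api_key is required")
    return partition_stub(filename=file_path, api_url=api_url, api_key=api_key)


print(extract("book.epub", "https://api.example.com", "secret"))
```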

View File

@@ -27,9 +27,11 @@ class UnstructuredPPTExtractor(BaseExtractor):
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
raise NotImplementedError("Unstructured API Url is not configured")
text_by_page = {}
text_by_page: dict[int, str] = {}
for element in elements:
page = element.metadata.page_number
if page is None:
continue
text = element.text
if page in text_by_page:
text_by_page[page] += "\n" + text
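An empty `{}` literal gives mypy nothing to infer from, producing a "Need type annotation" error, hence the explicit `dict[int, str]`; the `page is None` check is needed because the element metadata exposes `page_number` as optional. A compact sketch with an invented `Element` stand-in:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Element:
    """Illustrative stand-in for an unstructured element."""

    page_number: Optional[int]
    text: str


def group_by_page(elements: list[Element]) -> dict[int, str]:
    # A bare `{}` triggers: Need type annotation for "text_by_page".
    text_by_page: dict[int, str] = {}
    for element in elements:
        page = element.page_number
        if page is None:  # page_number is Optional[int]
            continue
        if page in text_by_page:
            text_by_page[page] += "\n" + element.text
        else:
            text_by_page[page] = element.text
    return text_by_page


print(group_by_page([Element(1, "a"), Element(1, "b"), Element(None, "skip")]))
```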

View File

@@ -29,14 +29,15 @@ class UnstructuredPPTXExtractor(BaseExtractor):
from unstructured.partition.pptx import partition_pptx
elements = partition_pptx(filename=self._file_path)
text_by_page = {}
text_by_page: dict[int, str] = {}
for element in elements:
page = element.metadata.page_number
text = element.text
if page in text_by_page:
text_by_page[page] += "\n" + text
else:
text_by_page[page] = text
if page is not None:
if page in text_by_page:
text_by_page[page] += "\n" + text
else:
text_by_page[page] = text
combined_texts = list(text_by_page.values())
documents = []

View File

@@ -89,6 +89,8 @@ class WordExtractor(BaseExtractor):
response = ssrf_proxy.get(url)
if response.status_code == 200:
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
if image_ext is None:
continue
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
mime_type, _ = mimetypes.guess_type(file_key)
@@ -97,6 +99,8 @@ class WordExtractor(BaseExtractor):
continue
else:
image_ext = rel.target_ref.split(".")[-1]
if image_ext is None:
continue
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
@@ -226,6 +230,8 @@ class WordExtractor(BaseExtractor):
if x_child is None:
continue
if x.tag.endswith("instrText"):
if x.text is None:
continue
for i in url_pattern.findall(x.text):
hyperlinks_url = str(i)
except Exception as e:
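`mimetypes.guess_extension()` is typed `Optional[str]`, and an element's `.text` can likewise be `None`, so both now get an explicit check before string concatenation, with `continue` keeping the skip-and-move-on behaviour. A small sketch of the extension case (the helper name is invented):

```python
import mimetypes
import uuid
from typing import Optional


def image_file_key(tenant_id: str, content_type: str) -> Optional[str]:
    image_ext = mimetypes.guess_extension(content_type)
    if image_ext is None:
        # guess_extension() -> Optional[str]; bail out instead of
        # concatenating None into the key below.
        return None
    return "image_files/" + tenant_id + "/" + str(uuid.uuid4()) + image_ext


print(image_file_key("tenant-1", "image/png"))               # ends in .png
print(image_file_key("tenant-1", "application/x-unknown"))   # None
```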