feat: Integrate WaterCrawl.dev as a new knowledge base provider (#16396)
Co-authored-by: crazywoola <427733928@qq.com>
Committed by: GitHub
Parent: 0afad94378
Commit: f54905e685
@@ -14,7 +14,12 @@ class WebsiteCrawlApi(Resource):
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument(
-            "provider", type=str, choices=["firecrawl", "jinareader"], required=True, nullable=True, location="json"
+            "provider",
+            type=str,
+            choices=["firecrawl", "watercrawl", "jinareader"],
+            required=True,
+            nullable=True,
+            location="json",
         )
         parser.add_argument("url", type=str, required=True, nullable=True, location="json")
         parser.add_argument("options", type=dict, required=True, nullable=True, location="json")

@@ -34,7 +39,9 @@ class WebsiteCrawlStatusApi(Resource):
     @account_initialization_required
     def get(self, job_id: str):
         parser = reqparse.RequestParser()
-        parser.add_argument("provider", type=str, choices=["firecrawl", "jinareader"], required=True, location="args")
+        parser.add_argument(
+            "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
+        )
         args = parser.parse_args()
         # get crawl status
         try:
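Not part of the diff, but for orientation: with "watercrawl" added to the provider choices above, a crawl request from the console would carry a JSON body roughly like the sketch below. The option keys are assumptions taken from what WaterCrawlProvider.crawl_url reads later in this commit (crawl_sub_pages, limit, depth, only_main_content); treat them as illustrative, not a documented schema.

# Hypothetical request body for the website-crawl endpoint after this change.
payload = {
    "provider": "watercrawl",  # new choice introduced by this commit
    "url": "https://example.com",
    "options": {
        "crawl_sub_pages": True,  # option names assumed from WaterCrawlProvider.crawl_url
        "limit": 10,
        "depth": 2,
        "only_main_content": True,
    },
}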
@@ -26,6 +26,7 @@ from core.rag.extractor.unstructured.unstructured_msg_extractor import Unstructu
 from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
 from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
 from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
+from core.rag.extractor.watercrawl.extractor import WaterCrawlWebExtractor
 from core.rag.extractor.word_extractor import WordExtractor
 from core.rag.models.document import Document
 from extensions.ext_storage import storage

@@ -183,6 +184,15 @@ class ExtractProcessor:
                     only_main_content=extract_setting.website_info.only_main_content,
                 )
                 return extractor.extract()
+            elif extract_setting.website_info.provider == "watercrawl":
+                extractor = WaterCrawlWebExtractor(
+                    url=extract_setting.website_info.url,
+                    job_id=extract_setting.website_info.job_id,
+                    tenant_id=extract_setting.website_info.tenant_id,
+                    mode=extract_setting.website_info.mode,
+                    only_main_content=extract_setting.website_info.only_main_content,
+                )
+                return extractor.extract()
             elif extract_setting.website_info.provider == "jinareader":
                 extractor = JinaReaderWebExtractor(
                     url=extract_setting.website_info.url,
api/core/rag/extractor/watercrawl/client.py (new file, 161 lines)
@@ -0,0 +1,161 @@
import json
from collections.abc import Generator
from typing import Union
from urllib.parse import urljoin

import requests
from requests import Response


class BaseAPIClient:
    def __init__(self, api_key, base_url):
        self.api_key = api_key
        self.base_url = base_url
        self.session = self.init_session()

    def init_session(self):
        session = requests.Session()
        session.headers.update({"X-API-Key": self.api_key})
        session.headers.update({"Content-Type": "application/json"})
        session.headers.update({"Accept": "application/json"})
        session.headers.update({"User-Agent": "WaterCrawl-Plugin"})
        session.headers.update({"Accept-Language": "en-US"})
        return session

    def _get(self, endpoint: str, query_params: dict | None = None, **kwargs):
        return self.session.get(urljoin(self.base_url, endpoint), params=query_params, **kwargs)

    def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
        return self.session.post(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)

    def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
        return self.session.put(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)

    def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs):
        return self.session.delete(urljoin(self.base_url, endpoint), params=query_params, **kwargs)

    def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
        return self.session.patch(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)


class WaterCrawlAPIClient(BaseAPIClient):
    def __init__(self, api_key, base_url: str | None = "https://app.watercrawl.dev/"):
        super().__init__(api_key, base_url)

    def process_eventstream(self, response: Response, download: bool = False) -> Generator:
        for line in response.iter_lines():
            line = line.decode("utf-8")
            if line.startswith("data:"):
                line = line[5:].strip()
                data = json.loads(line)
                if data["type"] == "result" and download:
                    data["data"] = self.download_result(data["data"])
                yield data

    def process_response(self, response: Response) -> dict | bytes | list | None | Generator:
        response.raise_for_status()
        if response.status_code == 204:
            return None
        if response.headers.get("Content-Type") == "application/json":
            return response.json() or {}

        if response.headers.get("Content-Type") == "application/octet-stream":
            return response.content

        if response.headers.get("Content-Type") == "text/event-stream":
            return self.process_eventstream(response)

        raise Exception(f"Unknown response type: {response.headers.get('Content-Type')}")

    def get_crawl_requests_list(self, page: int | None = None, page_size: int | None = None):
        query_params = {"page": page or 1, "page_size": page_size or 10}
        return self.process_response(
            self._get(
                "/api/v1/core/crawl-requests/",
                query_params=query_params,
            )
        )

    def get_crawl_request(self, item_id: str):
        return self.process_response(
            self._get(
                f"/api/v1/core/crawl-requests/{item_id}/",
            )
        )

    def create_crawl_request(
        self,
        url: Union[list, str] | None = None,
        spider_options: dict | None = None,
        page_options: dict | None = None,
        plugin_options: dict | None = None,
    ):
        data = {
            # 'urls': url if isinstance(url, list) else [url],
            "url": url,
            "options": {
                "spider_options": spider_options or {},
                "page_options": page_options or {},
                "plugin_options": plugin_options or {},
            },
        }
        return self.process_response(
            self._post(
                "/api/v1/core/crawl-requests/",
                data=data,
            )
        )

    def stop_crawl_request(self, item_id: str):
        return self.process_response(
            self._delete(
                f"/api/v1/core/crawl-requests/{item_id}/",
            )
        )

    def download_crawl_request(self, item_id: str):
        return self.process_response(
            self._get(
                f"/api/v1/core/crawl-requests/{item_id}/download/",
            )
        )

    def monitor_crawl_request(self, item_id: str, prefetched=False) -> Generator:
        query_params = {"prefetched": str(prefetched).lower()}
        generator = self.process_response(
            self._get(f"/api/v1/core/crawl-requests/{item_id}/status/", stream=True, query_params=query_params),
        )
        if not isinstance(generator, Generator):
            raise ValueError("Generator expected")
        yield from generator

    def get_crawl_request_results(
        self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict | None = None
    ):
        query_params = query_params or {}
        query_params.update({"page": page or 1, "page_size": page_size or 25})
        return self.process_response(
            self._get(f"/api/v1/core/crawl-requests/{item_id}/results/", query_params=query_params)
        )

    def scrape_url(
        self,
        url: str,
        page_options: dict | None = None,
        plugin_options: dict | None = None,
        sync: bool = True,
        prefetched: bool = True,
    ):
        response_result = self.create_crawl_request(url=url, page_options=page_options, plugin_options=plugin_options)
        if not sync:
            return response_result

        for event_data in self.monitor_crawl_request(response_result["uuid"], prefetched):
            if event_data["type"] == "result":
                return event_data["data"]

    def download_result(self, result_object: dict):
        response = requests.get(result_object["result"])
        response.raise_for_status()
        result_object["result"] = response.json()
        return result_object
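For orientation only (not part of the commit): the client above is self-contained enough to exercise directly. A minimal sketch, assuming a valid WaterCrawl API key, the default base_url, and no error handling:

from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient

client = WaterCrawlAPIClient(api_key="wc-xxxx")  # placeholder key
request = client.create_crawl_request(url="https://example.com", page_options={"only_main_content": True})

# Stream server-sent status events until the crawl reports a result.
for event in client.monitor_crawl_request(request["uuid"], prefetched=True):
    if event["type"] == "result":
        print(event["data"])  # prefetched result object for one crawled page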
api/core/rag/extractor/watercrawl/extractor.py (new file, 57 lines)
@@ -0,0 +1,57 @@
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from services.website_service import WebsiteService


class WaterCrawlWebExtractor(BaseExtractor):
    """
    Crawl and scrape websites and return their content as clean, LLM-ready markdown.

    Args:
        url: The URL to crawl or scrape.
        job_id: The WaterCrawl crawl request ID (used in 'crawl' mode).
        tenant_id: The tenant whose WaterCrawl credentials are used.
        mode: The mode of operation. Defaults to 'crawl'. Options are 'crawl' and 'scrape'.
        only_main_content: Only return the main content of the page, excluding headers, navs, footers, etc.
    """

    def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True):
        """Initialize with url, job_id, tenant_id and mode."""
        self._url = url
        self.job_id = job_id
        self.tenant_id = tenant_id
        self.mode = mode
        self.only_main_content = only_main_content

    def extract(self) -> list[Document]:
        """Extract content from the URL."""
        documents = []
        if self.mode == "crawl":
            crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "watercrawl", self._url, self.tenant_id)
            if crawl_data is None:
                return []
            document = Document(
                page_content=crawl_data.get("markdown", ""),
                metadata={
                    "source_url": crawl_data.get("source_url"),
                    "description": crawl_data.get("description"),
                    "title": crawl_data.get("title"),
                },
            )
            documents.append(document)
        elif self.mode == "scrape":
            scrape_data = WebsiteService.get_scrape_url_data(
                "watercrawl", self._url, self.tenant_id, self.only_main_content
            )

            document = Document(
                page_content=scrape_data.get("markdown", ""),
                metadata={
                    "source_url": scrape_data.get("source_url"),
                    "description": scrape_data.get("description"),
                    "title": scrape_data.get("title"),
                },
            )
            documents.append(document)
        return documents
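Illustrative sketch, not part of the commit: constructing the extractor by hand looks roughly like this. In practice ExtractProcessor builds it from extract_setting.website_info, and both modes go through WebsiteService, so the tenant's WaterCrawl data-source credentials must already be configured; the IDs below are placeholders.

from core.rag.extractor.watercrawl.extractor import WaterCrawlWebExtractor

extractor = WaterCrawlWebExtractor(
    url="https://example.com/docs",  # placeholder
    job_id="crawl-request-uuid",     # placeholder WaterCrawl crawl request id
    tenant_id="tenant-uuid",         # placeholder
    mode="crawl",
    only_main_content=True,
)
documents = extractor.extract()  # list[Document]: markdown page_content plus source_url/title/description metadata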
api/core/rag/extractor/watercrawl/provider.py (new file, 117 lines)
@@ -0,0 +1,117 @@
from collections.abc import Generator
from datetime import datetime
from typing import Any

from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient


class WaterCrawlProvider:
    def __init__(self, api_key, base_url: str | None = None):
        self.client = WaterCrawlAPIClient(api_key, base_url)

    def crawl_url(self, url, options: dict | Any = None) -> dict:
        options = options or {}
        spider_options = {
            "max_depth": 1,
            "page_limit": 1,
            "allowed_domains": [],
            "exclude_paths": [],
            "include_paths": [],
        }
        if options.get("crawl_sub_pages", True):
            spider_options["page_limit"] = options.get("limit", 1)
            spider_options["max_depth"] = options.get("depth", 1)
            spider_options["include_paths"] = options.get("includes", "").split(",") if options.get("includes") else []
            spider_options["exclude_paths"] = options.get("excludes", "").split(",") if options.get("excludes") else []

        wait_time = options.get("wait_time", 1000)
        page_options = {
            "exclude_tags": options.get("exclude_tags", "").split(",") if options.get("exclude_tags") else [],
            "include_tags": options.get("include_tags", "").split(",") if options.get("include_tags") else [],
            "wait_time": max(1000, wait_time),  # minimum wait time is 1 second
            "include_html": False,
            "only_main_content": options.get("only_main_content", True),
            "include_links": False,
            "timeout": 15000,
            "accept_cookies_selector": "#cookies-accept",
            "locale": "en-US",
            "actions": [],
        }
        result = self.client.create_crawl_request(url=url, spider_options=spider_options, page_options=page_options)

        return {"status": "active", "job_id": result.get("uuid")}

    def get_crawl_status(self, crawl_request_id) -> dict:
        response = self.client.get_crawl_request(crawl_request_id)
        data = []
        if response["status"] in ["new", "running"]:
            status = "active"
        else:
            status = "completed"
            data = list(self._get_results(crawl_request_id))

        time_str = response.get("duration")
        time_consuming: float = 0
        if time_str:
            time_obj = datetime.strptime(time_str, "%H:%M:%S.%f")
            time_consuming = (
                time_obj.hour * 3600 + time_obj.minute * 60 + time_obj.second + time_obj.microsecond / 1_000_000
            )

        return {
            "status": status,
            "job_id": response.get("uuid"),
            "total": response.get("options", {}).get("spider_options", {}).get("page_limit", 1),
            "current": response.get("number_of_documents", 0),
            "data": data,
            "time_consuming": time_consuming,
        }

    def get_crawl_url_data(self, job_id, url) -> dict | None:
        if not job_id:
            return self.scrape_url(url)

        for result in self._get_results(
            job_id,
            {
                # filter by url
                "url": url
            },
        ):
            return result

        return None

    def scrape_url(self, url: str) -> dict:
        response = self.client.scrape_url(url=url, sync=True, prefetched=True)
        return self._structure_data(response)

    def _structure_data(self, result_object: dict) -> dict:
        if isinstance(result_object.get("result", {}), str):
            raise ValueError("Invalid result object. Expected a dictionary.")

        metadata = result_object.get("result", {}).get("metadata", {})
        return {
            "title": metadata.get("og:title") or metadata.get("title"),
            "description": metadata.get("description"),
            "source_url": result_object.get("url"),
            "markdown": result_object.get("result", {}).get("markdown"),
        }

    def _get_results(self, crawl_request_id: str, query_params: dict | None = None) -> Generator[dict, None, None]:
        page = 0
        page_size = 100

        query_params = query_params or {}
        query_params.update({"prefetched": "true"})
        while True:
            page += 1
            response = self.client.get_crawl_request_results(crawl_request_id, page, page_size, query_params)
            if not response["results"]:
                break

            for result in response["results"]:
                yield self._structure_data(result)

            if response["next"] is None:
                break
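A rough end-to-end sketch of the provider above (not in the diff): start a crawl, poll its status, then read the structured results. The option keys mirror what crawl_url consumes; the API key and the polling interval are placeholders.

import time

from core.rag.extractor.watercrawl.provider import WaterCrawlProvider

provider = WaterCrawlProvider(api_key="wc-xxxx")  # placeholder key, default base_url
job = provider.crawl_url("https://example.com", {"crawl_sub_pages": True, "limit": 5, "depth": 1})

status = provider.get_crawl_status(job["job_id"])
while status["status"] == "active":
    time.sleep(5)  # arbitrary polling interval
    status = provider.get_crawl_status(job["job_id"])

for page in status["data"]:  # results already normalized by _structure_data
    print(page["source_url"], page["title"])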
@@ -17,6 +17,10 @@ class ApiKeyAuthFactory:
                 from services.auth.firecrawl.firecrawl import FirecrawlAuth

                 return FirecrawlAuth
+            case AuthType.WATERCRAWL:
+                from services.auth.watercrawl.watercrawl import WatercrawlAuth
+
+                return WatercrawlAuth
             case AuthType.JINA:
                 from services.auth.jina.jina import JinaAuth

@@ -3,4 +3,5 @@ from enum import StrEnum

 class AuthType(StrEnum):
     FIRECRAWL = "firecrawl"
+    WATERCRAWL = "watercrawl"
     JINA = "jinareader"
api/services/auth/watercrawl/__init__.py (new file, empty)
api/services/auth/watercrawl/watercrawl.py (new file, 44 lines)
@@ -0,0 +1,44 @@
import json
from urllib.parse import urljoin

import requests

from services.auth.api_key_auth_base import ApiKeyAuthBase


class WatercrawlAuth(ApiKeyAuthBase):
    def __init__(self, credentials: dict):
        super().__init__(credentials)
        auth_type = credentials.get("auth_type")
        if auth_type != "x-api-key":
            raise ValueError("Invalid auth type, WaterCrawl auth type must be x-api-key")
        self.api_key = credentials.get("config", {}).get("api_key", None)
        self.base_url = credentials.get("config", {}).get("base_url", "https://app.watercrawl.dev")

        if not self.api_key:
            raise ValueError("No API key provided")

    def validate_credentials(self):
        headers = self._prepare_headers()
        url = urljoin(self.base_url, "/api/v1/core/crawl-requests/")
        response = self._get_request(url, headers)
        if response.status_code == 200:
            return True
        else:
            self._handle_error(response)

    def _prepare_headers(self):
        return {"Content-Type": "application/json", "X-API-KEY": self.api_key}

    def _get_request(self, url, headers):
        return requests.get(url, headers=headers)

    def _handle_error(self, response):
        if response.status_code in {402, 409, 500}:
            error_message = response.json().get("error", "Unknown error occurred")
            raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
        else:
            if response.text:
                error_message = json.loads(response.text).get("error", "Unknown error occurred")
                raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
            raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
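For reference, not part of the commit: the credentials dict this class expects can be inferred from its constructor (an auth_type plus a config block). Values below are placeholders.

from services.auth.watercrawl.watercrawl import WatercrawlAuth

credentials = {
    "auth_type": "x-api-key",
    "config": {
        "api_key": "wc-xxxx",                      # placeholder
        "base_url": "https://app.watercrawl.dev",  # optional; this is the default
    },
}

auth = WatercrawlAuth(credentials)
auth.validate_credentials()  # True on HTTP 200, raises on authorization errors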
@@ -7,6 +7,7 @@ from flask_login import current_user # type: ignore

 from core.helper import encrypter
 from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
+from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
 from services.auth.api_key_auth_service import ApiKeyAuthService

@@ -59,6 +60,13 @@ class WebsiteService:
             time = str(datetime.datetime.now().timestamp())
             redis_client.setex(website_crawl_time_cache_key, 3600, time)
             return {"status": "active", "job_id": job_id}
+        elif provider == "watercrawl":
+            # decrypt api_key
+            api_key = encrypter.decrypt_token(
+                tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
+            )
+            return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).crawl_url(url, options)
+
         elif provider == "jinareader":
             api_key = encrypter.decrypt_token(
                 tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")

@@ -116,6 +124,14 @@ class WebsiteService:
             time_consuming = abs(end_time - float(start_time))
             crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
             redis_client.delete(website_crawl_time_cache_key)
+        elif provider == "watercrawl":
+            # decrypt api_key
+            api_key = encrypter.decrypt_token(
+                tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
+            )
+            crawl_status_data = WaterCrawlProvider(
+                api_key, credentials.get("config").get("base_url", None)
+            ).get_crawl_status(job_id)
         elif provider == "jinareader":
             api_key = encrypter.decrypt_token(
                 tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")

@@ -180,6 +196,11 @@ class WebsiteService:
                 if item.get("source_url") == url:
                     return dict(item)
             return None
+        elif provider == "watercrawl":
+            api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
+            return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).get_crawl_url_data(
+                job_id, url
+            )
         elif provider == "jinareader":
             if not job_id:
                 response = requests.get(

@@ -223,5 +244,8 @@ class WebsiteService:
             params = {"onlyMainContent": only_main_content}
             result = firecrawl_app.scrape_url(url, params)
             return result
+        elif provider == "watercrawl":
+            api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
+            return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).scrape_url(url)
         else:
             raise ValueError("Invalid provider")