mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-09 19:06:51 +08:00
Feat/firecrawl data source (#5232)
Co-authored-by: Nicolas <nicolascamara29@gmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
@@ -339,7 +339,7 @@ class IndexingRunner:
|
||||
def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
|
||||
-> list[Document]:
|
||||
# load file
|
||||
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
|
||||
if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
|
||||
return []
|
||||
|
||||
data_source_info = dataset_document.data_source_info_dict
|
||||
@@ -375,6 +375,23 @@ class IndexingRunner:
|
||||
document_model=dataset_document.doc_form
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
|
||||
elif dataset_document.data_source_type == 'website_crawl':
|
||||
if (not data_source_info or 'provider' not in data_source_info
|
||||
or 'url' not in data_source_info or 'job_id' not in data_source_info):
|
||||
raise ValueError("no website import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="website_crawl",
|
||||
website_info={
|
||||
"provider": data_source_info['provider'],
|
||||
"job_id": data_source_info['job_id'],
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
"url": data_source_info['url'],
|
||||
"mode": data_source_info['mode'],
|
||||
"only_main_content": data_source_info['only_main_content']
|
||||
},
|
||||
document_model=dataset_document.doc_form
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
|
||||
# update document status to splitting
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
|
||||
@@ -124,7 +124,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
default=float(credentials.get('presence_penalty', 0)),
|
||||
min=-2,
|
||||
max=2
|
||||
)
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(cred_with_endpoint.get('input_price', 0)),
|
||||
|
||||
@@ -4,3 +4,4 @@ from enum import Enum
|
||||
class DatasourceType(Enum):
|
||||
FILE = "upload_file"
|
||||
NOTION = "notion_import"
|
||||
WEBSITE = "website_crawl"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from models.dataset import Document
|
||||
@@ -19,14 +21,33 @@ class NotionInfo(BaseModel):
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class WebsiteInfo(BaseModel):
|
||||
"""
|
||||
website import info.
|
||||
"""
|
||||
provider: str
|
||||
job_id: str
|
||||
url: str
|
||||
mode: str
|
||||
tenant_id: str
|
||||
only_main_content: bool = False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class ExtractSetting(BaseModel):
|
||||
"""
|
||||
Model class for provider response.
|
||||
"""
|
||||
datasource_type: str
|
||||
upload_file: UploadFile = None
|
||||
notion_info: NotionInfo = None
|
||||
document_model: str = None
|
||||
upload_file: Optional[UploadFile]
|
||||
notion_info: Optional[NotionInfo]
|
||||
website_info: Optional[WebsiteInfo]
|
||||
document_model: Optional[str]
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
|
||||
@@ -11,6 +11,7 @@ from core.rag.extractor.csv_extractor import CSVExtractor
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.excel_extractor import ExcelExtractor
|
||||
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
|
||||
from core.rag.extractor.html_extractor import HtmlExtractor
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
@@ -154,5 +155,17 @@ class ExtractProcessor:
|
||||
tenant_id=extract_setting.notion_info.tenant_id,
|
||||
)
|
||||
return extractor.extract()
|
||||
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
|
||||
if extract_setting.website_info.provider == 'firecrawl':
|
||||
extractor = FirecrawlWebExtractor(
|
||||
url=extract_setting.website_info.url,
|
||||
job_id=extract_setting.website_info.job_id,
|
||||
tenant_id=extract_setting.website_info.tenant_id,
|
||||
mode=extract_setting.website_info.mode,
|
||||
only_main_content=extract_setting.website_info.only_main_content
|
||||
)
|
||||
return extractor.extract()
|
||||
else:
|
||||
raise ValueError(f"Unsupported website provider: {extract_setting.website_info.provider}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")
|
||||
|
||||
132
api/core/rag/extractor/firecrawl/firecrawl_app.py
Normal file
132
api/core/rag/extractor/firecrawl/firecrawl_app.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class FirecrawlApp:
|
||||
def __init__(self, api_key=None, base_url=None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url or 'https://api.firecrawl.dev'
|
||||
if self.api_key is None and self.base_url == 'https://api.firecrawl.dev':
|
||||
raise ValueError('No API key provided')
|
||||
|
||||
def scrape_url(self, url, params=None) -> dict:
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
json_data = {'url': url}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = requests.post(
|
||||
f'{self.base_url}/v0/scrape',
|
||||
headers=headers,
|
||||
json=json_data
|
||||
)
|
||||
if response.status_code == 200:
|
||||
response = response.json()
|
||||
if response['success'] == True:
|
||||
data = response['data']
|
||||
return {
|
||||
'title': data.get('metadata').get('title'),
|
||||
'description': data.get('metadata').get('description'),
|
||||
'source_url': data.get('metadata').get('sourceURL'),
|
||||
'markdown': data.get('markdown')
|
||||
}
|
||||
else:
|
||||
raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
|
||||
|
||||
elif response.status_code in [402, 409, 500]:
|
||||
error_message = response.json().get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}')
|
||||
else:
|
||||
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}')
|
||||
|
||||
def crawl_url(self, url, params=None) -> str:
|
||||
start_time = time.time()
|
||||
headers = self._prepare_headers()
|
||||
json_data = {'url': url}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers)
|
||||
if response.status_code == 200:
|
||||
job_id = response.json().get('jobId')
|
||||
return job_id
|
||||
else:
|
||||
self._handle_error(response, 'start crawl job')
|
||||
|
||||
def check_crawl_status(self, job_id) -> dict:
|
||||
headers = self._prepare_headers()
|
||||
response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers)
|
||||
if response.status_code == 200:
|
||||
crawl_status_response = response.json()
|
||||
if crawl_status_response.get('status') == 'completed':
|
||||
total = crawl_status_response.get('total', 0)
|
||||
if total == 0:
|
||||
raise Exception('Failed to check crawl status. Error: No page found')
|
||||
data = crawl_status_response.get('data', [])
|
||||
url_data_list = []
|
||||
for item in data:
|
||||
if isinstance(item, dict) and 'metadata' in item and 'markdown' in item:
|
||||
url_data = {
|
||||
'title': item.get('metadata').get('title'),
|
||||
'description': item.get('metadata').get('description'),
|
||||
'source_url': item.get('metadata').get('sourceURL'),
|
||||
'markdown': item.get('markdown')
|
||||
}
|
||||
url_data_list.append(url_data)
|
||||
if url_data_list:
|
||||
file_key = 'website_files/' + job_id + '.txt'
|
||||
if storage.exists(file_key):
|
||||
storage.delete(file_key)
|
||||
storage.save(file_key, json.dumps(url_data_list).encode('utf-8'))
|
||||
return {
|
||||
'status': 'completed',
|
||||
'total': crawl_status_response.get('total'),
|
||||
'current': crawl_status_response.get('current'),
|
||||
'data': url_data_list
|
||||
}
|
||||
|
||||
else:
|
||||
return {
|
||||
'status': crawl_status_response.get('status'),
|
||||
'total': crawl_status_response.get('total'),
|
||||
'current': crawl_status_response.get('current'),
|
||||
'data': []
|
||||
}
|
||||
|
||||
else:
|
||||
self._handle_error(response, 'check crawl status')
|
||||
|
||||
def _prepare_headers(self):
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
|
||||
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5):
|
||||
for attempt in range(retries):
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
if response.status_code == 502:
|
||||
time.sleep(backoff_factor * (2 ** attempt))
|
||||
else:
|
||||
return response
|
||||
return response
|
||||
|
||||
def _get_request(self, url, headers, retries=3, backoff_factor=0.5):
|
||||
for attempt in range(retries):
|
||||
response = requests.get(url, headers=headers)
|
||||
if response.status_code == 502:
|
||||
time.sleep(backoff_factor * (2 ** attempt))
|
||||
else:
|
||||
return response
|
||||
return response
|
||||
|
||||
def _handle_error(self, response, action):
|
||||
error_message = response.json().get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}')
|
||||
|
||||
|
||||
60
api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py
Normal file
60
api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from services.website_service import WebsiteService
|
||||
|
||||
|
||||
class FirecrawlWebExtractor(BaseExtractor):
|
||||
"""
|
||||
Crawl and scrape websites and return content in clean llm-ready markdown.
|
||||
|
||||
|
||||
Args:
|
||||
url: The URL to scrape.
|
||||
api_key: The API key for Firecrawl.
|
||||
base_url: The base URL for the Firecrawl API. Defaults to 'https://api.firecrawl.dev'.
|
||||
mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
mode: str = 'crawl',
|
||||
only_main_content: bool = False
|
||||
):
|
||||
"""Initialize with url, api_key, base_url and mode."""
|
||||
self._url = url
|
||||
self.job_id = job_id
|
||||
self.tenant_id = tenant_id
|
||||
self.mode = mode
|
||||
self.only_main_content = only_main_content
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Extract content from the URL."""
|
||||
documents = []
|
||||
if self.mode == 'crawl':
|
||||
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id)
|
||||
if crawl_data is None:
|
||||
return []
|
||||
document = Document(page_content=crawl_data.get('markdown', ''),
|
||||
metadata={
|
||||
'source_url': crawl_data.get('source_url'),
|
||||
'description': crawl_data.get('description'),
|
||||
'title': crawl_data.get('title')
|
||||
}
|
||||
)
|
||||
documents.append(document)
|
||||
elif self.mode == 'scrape':
|
||||
scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id,
|
||||
self.only_main_content)
|
||||
|
||||
document = Document(page_content=scrape_data.get('markdown', ''),
|
||||
metadata={
|
||||
'source_url': scrape_data.get('source_url'),
|
||||
'description': scrape_data.get('description'),
|
||||
'title': scrape_data.get('title')
|
||||
}
|
||||
)
|
||||
documents.append(document)
|
||||
return documents
|
||||
@@ -9,7 +9,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document as DocumentModel
|
||||
from models.source import DataSourceBinding
|
||||
from models.source import DataSourceOauthBinding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -345,12 +345,12 @@ class NotionExtractor(BaseExtractor):
|
||||
|
||||
@classmethod
|
||||
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
|
||||
DataSourceOauthBinding.tenant_id == tenant_id,
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user