mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
improve: mordernizing validation by migrating pydantic from 1.x to 2.x (#4592)
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pymilvus import MilvusClient, MilvusException, connections
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
@@ -28,7 +28,7 @@ class MilvusConfig(BaseModel):
|
||||
batch_size: int = 100
|
||||
database: str = "default"
|
||||
|
||||
@root_validator()
|
||||
@model_validator(mode='before')
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values.get('host'):
|
||||
raise ValueError("config MILVUS_HOST is required")
|
||||
|
||||
@@ -6,7 +6,7 @@ from uuid import UUID, uuid4
|
||||
from flask import current_app
|
||||
from numpy import ndarray
|
||||
from pgvecto_rs.sqlalchemy import Vector
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Float, String, create_engine, insert, select, text
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects import postgresql
|
||||
@@ -31,7 +31,7 @@ class PgvectoRSConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@root_validator()
|
||||
@model_validator(mode='before')
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
raise ValueError("config PGVECTO_RS_HOST is required")
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
@@ -24,7 +24,7 @@ class PGVectorConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@root_validator()
|
||||
@model_validator(mode='before')
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config PGVECTOR_HOST is required")
|
||||
|
||||
@@ -40,9 +40,9 @@ if TYPE_CHECKING:
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
api_key: Optional[str] = None
|
||||
timeout: float = 20
|
||||
root_path: Optional[str]
|
||||
root_path: Optional[str] = None
|
||||
grpc_port: int = 6334
|
||||
prefer_grpc: bool = False
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects.postgresql import JSON, TEXT
|
||||
@@ -33,7 +33,7 @@ class RelytConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@root_validator()
|
||||
@model_validator(mode='before')
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
raise ValueError("config RELYT_HOST is required")
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
|
||||
import sqlalchemy
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.orm import Session, declarative_base
|
||||
@@ -27,7 +27,7 @@ class TiDBVectorConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@root_validator()
|
||||
@model_validator(mode='before')
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
raise ValueError("config TIDB_VECTOR_HOST is required")
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, Optional
|
||||
import requests
|
||||
import weaviate
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@@ -19,10 +19,10 @@ from models.dataset import Dataset
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str]
|
||||
api_key: Optional[str] = None
|
||||
batch_size: int = 100
|
||||
|
||||
@root_validator()
|
||||
@model_validator(mode='before')
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['endpoint']:
|
||||
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
||||
|
||||
@@ -14,7 +14,7 @@ from io import BufferedReader, BytesIO
|
||||
from pathlib import PurePath
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
PathLike = Union[str, PurePath]
|
||||
|
||||
@@ -29,7 +29,7 @@ class Blob(BaseModel):
|
||||
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
|
||||
"""
|
||||
|
||||
data: Union[bytes, str, None] # Raw data
|
||||
data: Union[bytes, str, None] = None # Raw data
|
||||
mimetype: Optional[str] = None # Not to be confused with a file extension
|
||||
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
|
||||
# Location where the original content was found
|
||||
@@ -37,17 +37,15 @@ class Blob(BaseModel):
|
||||
# Useful for situations where downstream code assumes it must work with file paths
|
||||
# rather than in-memory content.
|
||||
path: Optional[PathLike] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
frozen = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
|
||||
|
||||
@property
|
||||
def source(self) -> Optional[str]:
|
||||
"""The source location of the blob as string if known otherwise none."""
|
||||
return str(self.path) if self.path else None
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
"""Verify that either data or path is provided."""
|
||||
if "data" not in values and "path" not in values:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from models.dataset import Document
|
||||
from models.model import UploadFile
|
||||
@@ -13,9 +13,7 @@ class NotionInfo(BaseModel):
|
||||
notion_page_type: str
|
||||
document: Document = None
|
||||
tenant_id: str
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
@@ -29,9 +27,7 @@ class ExtractSetting(BaseModel):
|
||||
upload_file: UploadFile = None
|
||||
notion_info: NotionInfo = None
|
||||
document_model: str = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
Reference in New Issue
Block a user