improve: modernizing validation by migrating pydantic from 1.x to 2.x (#4592)

This commit is contained in:
Bowen Liang
2024-06-14 01:05:37 +08:00
committed by GitHub
parent e8afc416dd
commit f976740b57
87 changed files with 697 additions and 300 deletions

View File

@@ -4,7 +4,7 @@ from typing import Any, Optional
from uuid import uuid4
from flask import current_app
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException, connections
from core.rag.datasource.entity.embedding import Embeddings
@@ -28,7 +28,7 @@ class MilvusConfig(BaseModel):
batch_size: int = 100
database: str = "default"
@root_validator()
@model_validator(mode='before')
def validate_config(cls, values: dict) -> dict:
if not values.get('host'):
raise ValueError("config MILVUS_HOST is required")

View File

@@ -6,7 +6,7 @@ from uuid import UUID, uuid4
from flask import current_app
from numpy import ndarray
from pgvecto_rs.sqlalchemy import Vector
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, model_validator
from sqlalchemy import Float, String, create_engine, insert, select, text
from sqlalchemy import text as sql_text
from sqlalchemy.dialects import postgresql
@@ -31,7 +31,7 @@ class PgvectoRSConfig(BaseModel):
password: str
database: str
@root_validator()
@model_validator(mode='before')
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config PGVECTO_RS_HOST is required")

View File

@@ -6,7 +6,7 @@ from typing import Any
import psycopg2.extras
import psycopg2.pool
from flask import current_app
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, model_validator
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -24,7 +24,7 @@ class PGVectorConfig(BaseModel):
password: str
database: str
@root_validator()
@model_validator(mode='before')
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config PGVECTOR_HOST is required")

View File

@@ -40,9 +40,9 @@ if TYPE_CHECKING:
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
api_key: Optional[str] = None
timeout: float = 20
root_path: Optional[str]
root_path: Optional[str] = None
grpc_port: int = 6334
prefer_grpc: bool = False

View File

@@ -3,7 +3,7 @@ import uuid
from typing import Any, Optional
from flask import current_app
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, model_validator
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
@@ -33,7 +33,7 @@ class RelytConfig(BaseModel):
password: str
database: str
@root_validator()
@model_validator(mode='before')
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config RELYT_HOST is required")

View File

@@ -4,7 +4,7 @@ from typing import Any
import sqlalchemy
from flask import current_app
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, model_validator
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base
@@ -27,7 +27,7 @@ class TiDBVectorConfig(BaseModel):
password: str
database: str
@root_validator()
@model_validator(mode='before')
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config TIDB_VECTOR_HOST is required")

View File

@@ -5,7 +5,7 @@ from typing import Any, Optional
import requests
import weaviate
from flask import current_app
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, model_validator
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
@@ -19,10 +19,10 @@ from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
api_key: Optional[str] = None
batch_size: int = 100
@root_validator()
@model_validator(mode='before')
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")

View File

@@ -14,7 +14,7 @@ from io import BufferedReader, BytesIO
from pathlib import PurePath
from typing import Any, Optional, Union
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, ConfigDict, model_validator
PathLike = Union[str, PurePath]
@@ -29,7 +29,7 @@ class Blob(BaseModel):
Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
"""
data: Union[bytes, str, None] # Raw data
data: Union[bytes, str, None] = None # Raw data
mimetype: Optional[str] = None # Not to be confused with a file extension
encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string
# Location where the original content was found
@@ -37,17 +37,15 @@ class Blob(BaseModel):
# Useful for situations where downstream code assumes it must work with file paths
# rather than in-memory content.
path: Optional[PathLike] = None
class Config:
arbitrary_types_allowed = True
frozen = True
model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
@property
def source(self) -> Optional[str]:
"""The source location of the blob as string if known otherwise none."""
return str(self.path) if self.path else None
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
"""Verify that either data or path is provided."""
if "data" not in values and "path" not in values:

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from models.dataset import Document
from models.model import UploadFile
@@ -13,9 +13,7 @@ class NotionInfo(BaseModel):
notion_page_type: str
document: Document = None
tenant_id: str
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data) -> None:
super().__init__(**data)
@@ -29,9 +27,7 @@ class ExtractSetting(BaseModel):
upload_file: UploadFile = None
notion_info: NotionInfo = None
document_model: str = None
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data) -> None:
super().__init__(**data)