From 9134849744613222021246605a4d27d6b8c7d1cf Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Wed, 3 Jan 2024 13:02:56 +0800
Subject: [PATCH] fix: remove tiktoken from text splitter (#1876)

---
 api/core/indexing_runner.py              | 12 +++++----
 api/core/spiltter/fixed_text_splitter.py | 34 +++++++++++++++++++++---
 2 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index d7e0843c8..1c87432c4 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -5,12 +5,12 @@ import re
 import threading
 import time
 import uuid
-from typing import Optional, List, cast
+from typing import Optional, List, cast, Type, Union, Literal, AbstractSet, Collection, Any
 
 from flask import current_app, Flask
 from flask_login import current_user
 from langchain.schema import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
+from langchain.text_splitter import TextSplitter, TS, TokenTextSplitter
 from sqlalchemy.orm.exc import ObjectDeletedError
 
 from core.data_loader.file_extractor import FileExtractor
@@ -23,7 +23,8 @@ from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.entities.model_entities import ModelType, PriceType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
+from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
+from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter, EnhanceRecursiveCharacterTextSplitter
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
@@ -502,7 +503,8 @@ class IndexingRunner:
             if separator:
                 separator = separator.replace('\\n', '\n')
 
-            character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
+
+            character_splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
                 chunk_size=segmentation["max_tokens"],
                 chunk_overlap=0,
                 fixed_separator=separator,
@@ -510,7 +512,7 @@
             )
         else:
             # Automatic segmentation
-            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder(
                 chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                 chunk_overlap=0,
                 separators=["\n\n", "。", ".", " ", ""]
diff --git a/api/core/spiltter/fixed_text_splitter.py b/api/core/spiltter/fixed_text_splitter.py
index aaaf8e5a1..bddaad292 100644
--- a/api/core/spiltter/fixed_text_splitter.py
+++ b/api/core/spiltter/fixed_text_splitter.py
@@ -7,10 +7,38 @@ from typing import (
     Optional,
 )
 
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter, TS, Type, Union, AbstractSet, Literal, Collection
+from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
 
 
-class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
+class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
+    """
+    This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
+    """
+    @classmethod
+    def from_gpt2_encoder(
+        cls: Type[TS],
+        encoding_name: str = "gpt2",
+        model_name: Optional[str] = None,
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
+        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        **kwargs: Any,
+    ):
+        def _token_encoder(text: str) -> int:
+            return GPT2Tokenizer.get_num_tokens(text)
+
+        if issubclass(cls, TokenTextSplitter):
+            extra_kwargs = {
+                "encoding_name": encoding_name,
+                "model_name": model_name,
+                "allowed_special": allowed_special,
+                "disallowed_special": disallowed_special,
+            }
+            kwargs = {**kwargs, **extra_kwargs}
+
+        return cls(length_function=_token_encoder, **kwargs)
+
+class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
     def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
         """Create a new TextSplitter."""
         super().__init__(**kwargs)
@@ -65,4 +93,4 @@ class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
         if _good_splits:
             merged_text = self._merge_splits(_good_splits, separator)
             final_chunks.extend(merged_text)
-        return final_chunks
+        return final_chunks
\ No newline at end of file
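
A minimal usage sketch of the new from_gpt2_encoder classmethod, not part of the patch. It assumes the patched module path above is importable and that GPT2Tokenizer.get_num_tokens(text) returns an int token count, as the patch implies; the chunk size and the document_text variable are illustrative placeholders:

    # Token-length-based splitting without tiktoken: from_gpt2_encoder installs
    # GPT2Tokenizer.get_num_tokens as the splitter's length_function, so chunk_size
    # is measured in GPT-2 tokens rather than characters.
    from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter

    splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
        chunk_size=500,          # illustrative token limit per chunk
        chunk_overlap=0,
        fixed_separator="\n\n",  # user-defined separator tried first
        separators=["\n\n", "。", ".", " ", ""],
    )
    chunks = splitter.split_text(document_text)  # document_text: any long string

Since langchain's TextSplitter measures splits through its length_function, swapping in the GPT-2 token count keeps token-based chunk limits while dropping the tiktoken dependency.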