fix: count down thread in completion db not commit (#1267)

This commit is contained in:
takatost
2023-10-02 10:19:26 +08:00
committed by GitHub
parent 86a9dea428
commit 41d4c5b424
3 changed files with 21 additions and 27 deletions

View File

@@ -155,7 +155,7 @@ class CompletionService:
generate_worker_thread.start()
# wait for 10 minutes to close the thread
cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
return cls.compact_response(pubsub, streaming)
@@ -210,25 +210,26 @@ class CompletionService:
db.session.commit()
@classmethod
def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
# wait for 10 minutes to close the thread
timeout = 600
def close_pubsub():
sleep_iterations = 0
while sleep_iterations < timeout and worker_thread.is_alive():
if sleep_iterations > 0 and sleep_iterations % 10 == 0:
PubHandler.ping(user, generate_task_id)
with flask_app.app_context():
sleep_iterations = 0
while sleep_iterations < timeout and worker_thread.is_alive():
if sleep_iterations > 0 and sleep_iterations % 10 == 0:
PubHandler.ping(user, generate_task_id)
time.sleep(1)
sleep_iterations += 1
time.sleep(1)
sleep_iterations += 1
if worker_thread.is_alive():
PubHandler.stop(user, generate_task_id)
try:
pubsub.close()
except:
pass
if worker_thread.is_alive():
PubHandler.stop(user, generate_task_id)
try:
pubsub.close()
except:
pass
countdown_thread = threading.Thread(target=close_pubsub)
countdown_thread.start()
@@ -288,7 +289,7 @@ class CompletionService:
generate_worker_thread.start()
cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
return cls.compact_response(pubsub, streaming)
@@ -313,15 +314,14 @@ class CompletionService:
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
ModelCurrentlyNotSupportError) as e:
db.session.rollback()
PubHandler.pub_error(user, generate_task_id, e)
except LLMAuthorizationError:
db.session.rollback()
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
except Exception as e:
db.session.rollback()
logging.exception("Unknown Error in completion")
PubHandler.pub_error(user, generate_task_id, e)
finally:
db.session.commit()
@classmethod
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):