diff --git a/app/__pycache__/main.cpython-313.pyc b/app/__pycache__/main.cpython-313.pyc index d5ae5f7..bcae2a4 100644 Binary files a/app/__pycache__/main.cpython-313.pyc and b/app/__pycache__/main.cpython-313.pyc differ diff --git a/app/main.py b/app/main.py index 5d97673..7bb9b66 100644 --- a/app/main.py +++ b/app/main.py @@ -3,7 +3,7 @@ from typing import List, Dict, Tuple, Sequence from starlette.responses import StreamingResponse from annotated_types import IsDigit -from fastapi import FastAPI, File, HTTPException, Path, UploadFile, Request, Form +from fastapi import FastAPI, File, HTTPException, UploadFile, Request, Form from fastapi.responses import FileResponse # import multiprocessing @@ -21,7 +21,9 @@ import re import os +import signal import mariadb +import sys import filetype @@ -41,6 +43,14 @@ logging.basicConfig( debug = log.debug info = log.info error = log.error +critical = log.critical + + +def exception_handler(etype, value, tb): + log.exception(f"Uncought Exception: {value}") + + +sys.excepthook = exception_handler db = mariadb.connect( host=os.environ.get("DB_HOST", "db"), @@ -117,12 +127,40 @@ def _sql_quarry( ) +def sql_connector_is_active(connector: mariadb.Connection) -> bool: + try: + connector.ping() + except mariadb.Error as e: + return False + return True + + +def sql_connect(connector: mariadb.Connection) -> mariadb.Connection: + try: + connector = mariadb.connect( + host=os.environ.get("DB_HOST", "db"), + user=os.environ.get("DB_USER", "user"), + password=os.environ.get("DB_PASSWORD", "DBPASSWORD"), + database=os.environ.get("DB_DATABASE", "Unizeug"), + ) + except mariadb.Error as e: + critical( + f"Cannot reconnect to Database {os.environ.get('DB_DATABASE', 'Unizeug')} on {os.environ.get('DB_HOST', 'db')}. Got Mariadb Error: {e}" + ) + os.kill(os.getpid(), signal.SIGTERM) + raise HTTPException(500, detail="Database failed") + return connector + + def sql( querry: str, data: Tuple[str | int, ...] | str | int = (), return_result: bool = True, commit: bool = False, ) -> List[Tuple]: + global db + if not sql_connector_is_active(db): + db = sql_connect(db) cur = db.cursor(dictionary=False) return _sql_quarry(cur, querry, data, return_result, commit) @@ -133,6 +171,10 @@ def sqlT( return_result: bool = True, commit: bool = False, ) -> List[Dict]: + global db + if not sql_connector_is_active(db): + db = sql_connect(db) + cur = db.cursor(dictionary=True) return _sql_quarry(cur, querry, data, return_result, commit)