Source code for koa_middleware.database.local_database

import logging
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Sequence

from sqlite_utils import Database
from sqlite_utils.db import NotFoundError

from ..utils import datetime_to_isot_ms
from ..utils import postgres_http_date_to_iso

logger = logging.getLogger(__name__)

# Names exported by ``from <module> import *``.
__all__ = ["LocalCalibrationDB"]

# Minimal column layout (column name -> Python type) used whenever the
# calibration table is created; further columns are added on demand.
_MIN_SCHEMA = {
    "id": str,
    "filename": str,
}


[docs] class LocalCalibrationDB: """ Class to interact with a local SQLite calibration database using sqlite-utils. This class provides a simple interface for adding, querying, and managing calibration metadata stored as dictionaries in a SQLite database. """ def __init__(self, db_path: str, table_name: str): """ Initialize a LocalCalibrationDB instance. Parameters ---------- db_path : str Path to the SQLite database file. table_name : str Name of the table to use for storing calibration metadata. """ self.db_path = db_path self.table_name = table_name self.db = Database(db_path) if not self.table.exists(): self.table.create( _MIN_SCHEMA, pk="id", )
[docs] @contextmanager def transaction(self): """ Context manager for database transactions. Ensures that changes are committed on success or rolled back on error. """ try: with self.db.conn: yield except Exception: logger.exception(f"Transaction failed on table {self.table_name!r}, rolling back.")
[docs] def get_last_updated(self) -> str | None: """ Get the most recent last_updated timestamp from the database. Returns ------- str | None The maximum last_updated value as a string, or None if the table is empty. """ row = next( self.db.execute( f"SELECT MAX(last_updated) AS v FROM {self.table_name}" ), None, ) if len(row) == 0: logger.warning("No entries found in the calibration database.") return None if row[0] is None: logger.warning("No entries found in the calibration database.") return None return row[0]
[docs] def custom_query(self, sql: str, params: tuple = ()) -> list[dict]: """ Execute a custom SQL query. Parameters ---------- sql : str The SQL query string. params : tuple, optional Parameters to pass to the SQL query. Returns ------- list[dict] List of matching rows as dictionaries. """ if len(self) == 0: return [] rows = self.db.execute(sql, params) return [dict(r) for r in rows]
[docs] def query( self, filename : str | None = None, cal_type: str | None = None, cal_id: str | None = None, cal_version_min: str | None = None, cal_version_max: str | None = None, date_time_start: str | None = None, date_time_end: str | None = None, last_updated_start: str | None = None, last_updated_end: str | None = None, origin : str | None = None, order_by: str = 'last_updated', fetch: str = "all", ) -> list[dict] | dict | None: """ Query calibration entries from the database with common use cases. Parameters ---------- filename : str, optional Filter by calibration filename. cal_type : str, optional Filter by calibration type. cal_id : str, optional Filter by a specific calibration UUID. Overrides other filters if provided. cal_version_min : str, optional Minimum calibration version to include. Default is "001" cal_version_max : str, optional Maximum calibration version to include. Default is "999" date_time_start : str, optional Minimum datetime_obs to include. date_time_end : str, optional Maximum datetime_obs to include. last_updated_start : str, optional Minimum last_updated timestamp to include. last_updated_end : str, optional Maximum last_updated timestamp to include. origin : str, optional Filter by origin ("ANY", "LOCAL", "REMOTE"). Default is None (equivalent to "ANY"). Whether to return all matching rows or just the first one. Returns ------- list[dict] or dict or None Matching calibration entries. If fetch='first', returns a single dict or None. If fetch='all', returns a list of dicts. 
""" # Delegate to query_id() for single-ID queries if cal_id is not None: return self.query_id(cal_id) # Delegate to query_filename() for single-filename queries if filename is not None: return self.query_filename(filename) if len(self) == 0: return None if fetch == "first" else [] sql = "" params = {} if date_time_start is not None: sql += "datetime_obs >= :date_time_start AND " params["date_time_start"] = date_time_start if date_time_end is not None: sql += "datetime_obs <= :date_time_end AND " params["date_time_end"] = date_time_end if cal_type is not None: sql += "cal_type = :cal_type AND " params["cal_type"] = cal_type if cal_version_min is not None: sql += "cal_version >= :cal_version_min AND " params["cal_version_min"] = cal_version_min if cal_version_max is not None: sql += "cal_version <= :cal_version_max AND " params["cal_version_max"] = cal_version_max if last_updated_start is not None: sql += "last_updated >= :last_updated_start AND " params["last_updated_start"] = last_updated_start if last_updated_end is not None: sql += "last_updated <= :last_updated_end AND " params["last_updated_end"] = last_updated_end if origin is not None: sql += "origin = :origin AND " params["origin"] = origin # Remove trailing " AND " sql = sql.rstrip(" AND ") if fetch == "first": rows = self.rows_where( sql if sql else None, params, limit=1, order_by=order_by ) row = next(rows, None) return dict(row) if row else None output = list( self.rows_where( sql if sql else None, params, order_by=order_by ) ) return output
[docs] def query_id(self, cal_id: str) -> dict | None: """ Query a calibration entry by its unique ID. Parameters ---------- cal_id : str The unique calibration ID (UUID). Returns ------- dict or None The calibration metadata dictionary if found, otherwise None. """ try: row = self.table.get(cal_id) return dict(row) if row else None except NotFoundError as e: logger.info(f"Calibration ID {cal_id!r} not found in table {self.table_name!r}.") return None
[docs] def query_filename(self, filename: str) -> dict | None: """ Query a calibration entry by its unique ID. Parameters ---------- filename : str The unique calibration filename. Returns ------- dict or None The calibration metadata dictionary if found, otherwise None. """ row = next( self.table.rows_where( "filename = ?", [filename], ), None, ) return dict(row) if row else None
[docs] def add( self, cals: dict | Sequence[dict], alter: bool = True, ): """ Add or update calibration entries in the database. Parameters ---------- cals : dict | Sequence[dict] A single calibration metadata dictionary or a sequence of calibration metadata dictionaries to add or update. Uses upsert semantics with 'id' as primary key. alter : bool, optional Whether to automatically alter the table schema to accommodate new fields. Default is True. """ single_input = False if isinstance(cals, dict): single_input = True cals = [cals] items = [dict(item) for item in cals] if not items: return # HACK: Temporary hack to convert PostgreSQL datetime strings to ISO format. # NOTE: Fix this on the backend, convert all timestamps to YYYY-MM-DDTHH:MM:SSS.SSS. datetime_cols = ['datetime_obs', 'last_updated', 'last_processed'] for item in items: for col in datetime_cols: if col in item and item[col] is not None: item[col] = postgres_http_date_to_iso(item[col]) # Use common last updated timestamp for all entries in this batch to ensure consistency last_updated = datetime_to_isot_ms(datetime.now(timezone.utc)) for item in items: if not item.get("last_updated"): item["last_updated"] = last_updated with self.transaction(): n = len(items) if n == 1: logger.info(f"Adding to local cal DB: filename={items[0]['filename']} ID={items[0]['id']}") else: s = f"Adding {len(items)} items into local cal DB:" for i, item in enumerate(items): s += f"\n {i+1} - filename={item['filename']} ID={item['id']}" logger.info(s) self.table.insert_all( items, pk="id", alter=alter, ) if single_input: return items[0] else: return items
[docs] def delete(self, cal_id: str): """ Delete a calibration entry by its unique ID. Parameters ---------- cal_id : str The unique calibration ID (UUID) to delete. """ try: self.table.delete(cal_id) logger.info(f"Deleted calibration ID {cal_id!r} from table {self.table_name!r}.") except NotFoundError: logger.warning(f"Calibration ID {cal_id} not found in the database, cannot delete.")
def _reset(self, confirm: bool = False): """ Reset the calibration database by dropping and recreating the table. WARNING: This will delete all existing calibration metadata in the DB. """ if not confirm: logger.warning("Reset not confirmed. To reset the database, call _reset with confirm=True.") return if self.table.exists(): logger.info(f"Dropping table {self.table_name!r}...") self.table.drop() logger.info(f"Recreating table {self.table_name!r} with minimal schema.") self.table.create( _MIN_SCHEMA, pk="id", ) @property def table(self): """ Returns the calibration table object. Returns ------- sqlite_utils.db.Table The table object for the calibration metadata. """ return self.db[self.table_name]
[docs] def close(self): """ Close the database connection. """ self.db.close()
def __repr__(self): return ( f"LocalCalibrationDB(\n" f" db_path={self.db_path!r},\n" f" table_name={self.table_name!r},\n" f" entries={self.table.count}\n" f" )" ) def __len__(self): """ Return the number of entries in the calibration table. Returns ------- int The number of calibration entries in the database. """ return self.table.count @property def rows(self) -> list[dict]: """ Get all rows in the calibration table. Returns ------- Generator[dict] Generator of all calibration entries as dictionaries. Call list() on the result to get a list. """ return self.table.rows @property def rows_where(self): """ Forward function to sqlite-utils Table.rows_where method. """ return self.table.rows_where
[docs] def get_column(self, column: str) -> list[dict]: return [ row[column] for row in self.table.rows ]