"""Utilities for storing MBIS data in a SQLite database."""
import abc
import math
from typing import TypeVar
from sqlalchemy import Column, ForeignKey, Integer, PickleType, String, UniqueConstraint
from sqlalchemy.orm import Query, Session, declarative_base, relationship
from pympfit.mbis import MBISSettings
# Declarative base that every ORM record class in this module derives from.
DBBase = declarative_base()

# Type variables used in the ``_UniqueMixin`` signatures: the plain Python
# object being stored, and the ORM record class that represents it.
_InstanceType = TypeVar("_InstanceType")
_DBInstanceType = TypeVar("_DBInstanceType")

# NOTE(review): presumably the schema version recorded in the ``db_info``
# table that the provenance foreign keys point at -- confirm where it is
# written and whether it must be bumped on schema changes.
DB_VERSION = 1

# Scale factor used by ``_float_to_db_int`` / ``_db_int_to_float`` to keep
# floats in ``Integer`` columns; 1e5 retains five decimal places.
_DB_FLOAT_PRECISION = 100000.0
def _float_to_db_int(value: float) -> int:
    """Convert a float into the fixed-point integer stored in the database.

    The value is scaled by ``_DB_FLOAT_PRECISION`` and rounded to the nearest
    integer.  Rounding (rather than flooring) matters because the scaled
    product of a decimal value is often a hair below the intended integer in
    binary floating point: ``0.29 * 100000.0`` evaluates to
    ``28999.999999999996``, which a floor would silently store as ``28999``
    and read back as ``0.28999``.

    Args:
        value: The floating point number to store.

    Returns:
        ``value`` scaled by ``_DB_FLOAT_PRECISION`` as a Python ``int``.
    """
    # ``int(...)`` keeps the return type a plain int even if ``value`` is a
    # float subclass whose ``__round__`` returns something else.
    return int(round(value * _DB_FLOAT_PRECISION))
def _db_int_to_float(value: int) -> float:
    """Convert a fixed-point database integer back into a float.

    This undoes the scaling applied by ``_float_to_db_int``: the stored
    integer is divided by ``_DB_FLOAT_PRECISION``.
    """
    restored = value / _DB_FLOAT_PRECISION
    return restored
class _UniqueMixin:
    """Mixin for record classes that must not be duplicated in the database.

    Subclasses supply ``_hash``, ``_query`` and ``_instance_to_db``; the
    ``unique`` entry point uses them to return one shared record per
    distinct instance.
    """

    @classmethod
    @abc.abstractmethod
    def _hash(cls, instance: _InstanceType) -> int:
        """Return an integer hash identifying ``instance``.

        ``unique`` combines it with the class to form its cache key.
        """
        raise NotImplementedError()

    @classmethod
    @abc.abstractmethod
    def _query(cls, db: Session, instance: _InstanceType) -> Query:
        """Build the query that locates stored copies of ``instance``."""
        raise NotImplementedError()

    @classmethod
    @abc.abstractmethod
    def _instance_to_db(cls, instance: _InstanceType) -> _DBInstanceType:
        """Create the database record that represents ``instance``."""
        raise NotImplementedError()

    @classmethod
    def unique(cls, db: Session, instance: _InstanceType) -> _DBInstanceType:
        """Return the database record for ``instance``, creating it if needed.

        Lookup order: a per-session cache stored on ``db`` as
        ``_unique_cache``, then the result of ``_query``; only if both miss
        is a new record built with ``_instance_to_db`` and added to ``db``.
        The returned record is cached for subsequent calls on the same
        session.
        """
        seen = getattr(db, "_unique_cache", None)
        if seen is None:
            # First use on this session: attach an empty cache to it.
            seen = {}
            db._unique_cache = seen

        cache_key = (cls, cls._hash(instance))
        try:
            return seen[cache_key]
        except KeyError:
            pass

        # Autoflush is disabled while the lookup query runs.
        with db.no_autoflush:
            record = cls._query(db, instance).first()
            if not record:
                record = cls._instance_to_db(instance)
                db.add(record)

        seen[cache_key] = record
        return record
class DBMBISSettings(_UniqueMixin, DBBase):
    """Database representation of MBISSettings.

    Most attributes are copied verbatim from the ``MBISSettings`` instance.
    The three MPFIT radii are floats on the instance and are persisted as
    fixed-point integers via ``_float_to_db_int`` / ``_db_int_to_float``.

    Every stored column takes part in ``_hash`` and ``_query`` so that
    ``unique`` never hands back a row whose values differ from the requested
    settings.
    """

    __tablename__ = "mbis_settings"
    # NOTE(review): the constraint covers only (basis, method) while
    # ``_query`` matches on every column.  Two settings sharing a basis and
    # method but differing elsewhere are treated as distinct by ``unique``
    # and would presumably be rejected by this constraint on flush --
    # confirm whether the constraint should cover more columns.
    __table_args__ = (UniqueConstraint("basis", "method"),)

    id = Column(Integer, primary_key=True, index=True)

    basis = Column(String, index=True, nullable=False)
    method = Column(String, index=True, nullable=False)

    limit = Column(Integer, nullable=False)
    e_convergence = Column(Integer, nullable=False)
    d_convergence = Column(Integer, nullable=False)
    dft_radial_points = Column(Integer, nullable=False)
    dft_spherical_points = Column(Integer, nullable=False)
    max_radial_moment = Column(Integer, nullable=False)
    max_moment = Column(Integer, nullable=False)
    mbis_d_convergence = Column(Integer, nullable=False)
    mbis_radial_points = Column(Integer, nullable=False)
    mbis_spherical_points = Column(Integer, nullable=False)

    guess = Column(String, nullable=False)
    multipole_units = Column(String, nullable=False)
    multipole_format = Column(String, nullable=False)

    # MPFIT specific parameters stored as integers
    mpfit_inner_radius = Column(Integer, nullable=False)
    mpfit_outer_radius = Column(Integer, nullable=False)
    mpfit_atom_radius = Column(Integer, nullable=False)

    # Attributes copied unchanged between MBISSettings and this table.
    _DIRECT_FIELDS = (
        "basis",
        "method",
        "limit",
        "e_convergence",
        "d_convergence",
        "dft_radial_points",
        "dft_spherical_points",
        "max_radial_moment",
        "max_moment",
        "mbis_d_convergence",
        "mbis_radial_points",
        "mbis_spherical_points",
        "guess",
        "multipole_units",
        "multipole_format",
    )
    # Float attributes persisted as fixed-point integers.
    _FIXED_POINT_FIELDS = (
        "mpfit_inner_radius",
        "mpfit_outer_radius",
        "mpfit_atom_radius",
    )

    @classmethod
    def _column_values(cls, instance: MBISSettings) -> dict:
        """Return the column-name -> stored-value mapping for ``instance``.

        This is the single source of truth shared by ``_hash``, ``_query``
        and ``_instance_to_db`` so the three cannot drift apart.
        """
        values = {name: getattr(instance, name) for name in cls._DIRECT_FIELDS}
        for name in cls._FIXED_POINT_FIELDS:
            values[name] = _float_to_db_int(getattr(instance, name))
        return values

    @classmethod
    def _hash(cls, instance: MBISSettings) -> int:
        """Return a hash over every stored value of ``instance``.

        The radii are hashed in their fixed-point form so that settings
        which map to the same row also share a cache key.
        """
        return hash(tuple(cls._column_values(instance).values()))

    @classmethod
    def _query(cls, db: Session, instance: MBISSettings) -> Query:
        """Return a query matching rows equal to ``instance`` on every column."""
        return db.query(DBMBISSettings).filter_by(**cls._column_values(instance))

    @classmethod
    def _instance_to_db(cls, instance: MBISSettings) -> "DBMBISSettings":
        """Build a new, not-yet-added record from ``instance``."""
        return DBMBISSettings(**cls._column_values(instance))

    @classmethod
    def db_to_instance(cls, db_instance: "DBMBISSettings") -> MBISSettings:
        """Convert a database record to a MBISSettings instance.

        Directly stored columns are passed through unchanged; the MPFIT
        radii are converted back from fixed-point integers to floats.
        """
        kwargs = {name: getattr(db_instance, name) for name in cls._DIRECT_FIELDS}
        for name in cls._FIXED_POINT_FIELDS:
            kwargs[name] = _db_int_to_float(getattr(db_instance, name))
        return MBISSettings(**kwargs)
class DBMoleculeRecord(DBBase):
    """Database representation of a molecule record.

    Rows live in the ``molecules`` table and are keyed by SMILES string.
    """

    __tablename__ = "molecules"

    # SMILES string; serves as the primary key.
    smiles = Column(String, primary_key=True, index=True)

    # Related ``DBConformerRecord`` rows (that class is declared outside
    # this definition and resolved by name).
    conformers = relationship("DBConformerRecord")
class DBGeneralProvenance(DBBase):
    """Database representation of general provenance information.

    Each row of the ``general_provenance`` table is one string key/value
    pair, with the key acting as the primary key.
    """

    __tablename__ = "general_provenance"

    key = Column(String, primary_key=True, index=True, unique=True)
    value = Column(String, nullable=False)

    # Foreign key onto the ``version`` column of the ``db_info`` table.
    parent_id = Column(Integer, ForeignKey("db_info.version"))
class DBSoftwareProvenance(DBBase):
    """Database representation of software provenance information.

    Each row of the ``software_provenance`` table is one string key/value
    pair, with the key acting as the primary key.
    """

    __tablename__ = "software_provenance"

    key = Column(String, primary_key=True, index=True, unique=True)
    value = Column(String, nullable=False)

    # Foreign key onto the ``version`` column of the ``db_info`` table.
    parent_id = Column(Integer, ForeignKey("db_info.version"))