"""Store and retrieve calculated MBIS data in unified data collections."""
import functools
import warnings
from collections import defaultdict
from contextlib import AbstractContextManager, contextmanager
from openff.toolkit import Molecule, Quantity
from openff.toolkit.utils.exceptions import AtomMappingWarning
from openff.units import unit
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import create_engine, event
from sqlalchemy.orm import Session, sessionmaker
from pympfit._annotations import MP, Coordinates
from pympfit.mbis import MBISSettings
from pympfit.mbis.storage.db import (
DB_VERSION,
DBBase,
DBConformerRecord,
DBGeneralProvenance,
DBInformation,
DBMBISSettings,
DBMoleculeRecord,
DBSoftwareProvenance,
)
from pympfit.mbis.storage.exceptions import IncompatibleDBVersion
# Register an "AU" unit (with the additional names "au" and "atomic_unit") in
# the shared openff unit registry at import time; ``multipoles_quantity`` below
# wraps multipole values with this unit via ``Quantity(..., "AU")``.
# Presumably ``[]`` marks the unit as dimensionless in pint's definition
# syntax -- TODO confirm.
# NOTE(review): pint's default registry may already use "au" as the symbol of
# the astronomical unit -- confirm this definition does not clash with it.
unit.define("AU = [] = au = atomic_unit")
# [docs]  (stray documentation-link marker left by text extraction; not code)
class MoleculeMBISRecord(BaseModel):
    """Record containing MBIS results for a molecule conformer.

    Includes molecule information, conformer coordinates, MBIS settings
    provenance, and multipole values for each atom.
    """

    # ``Coordinates`` / ``MP`` are project-defined annotated types, so pydantic
    # is told to accept field types it has no built-in schema for.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    tagged_smiles: str = Field(
        ...,
        description="The tagged SMILES patterns (SMARTS) which encodes both the "
        "molecule stored in this record, a map between the atoms and the molecule and "
        "their coordinates.",
    )
    conformer: Coordinates = Field(
        ...,
        description="The coordinates [Angstrom] of this conformer with "
        "shape=(n_atoms, 3).",
    )
    multipoles: MP = Field(
        ...,
        description="The multipole moments [AU] for each atom in the molecule.",
    )
    mbis_settings: MBISSettings = Field(
        ..., description="The settings used to generate the MBIS stored in this record."
    )

    @property
    def conformer_quantity(self) -> Quantity:
        """Return ``conformer`` wrapped in a ``Quantity`` with unit angstrom."""
        return Quantity(self.conformer, "angstrom")

    @property
    def multipoles_quantity(self) -> Quantity:
        """Return ``multipoles`` wrapped in a ``Quantity`` with unit ``AU``."""
        return Quantity(self.multipoles, "AU")

    @classmethod
    def from_molecule(
        cls,
        molecule: Molecule,
        conformer: Quantity,
        multipoles: Quantity,
        mbis_settings: MBISSettings,
    ) -> "MoleculeMBISRecord":
        """Create a new record from an existing molecule.

        Takes care of generating the atom-mapped SMILES representation
        (isomeric, with explicit hydrogens) that is stored as
        ``tagged_smiles``.

        Parameters
        ----------
        molecule
            The molecule to store in the record.
        conformer
            The coordinates [Angstrom] of this conformer with shape=(n_atoms, 3).
        multipoles
            The multipole moments [AU] for each atom in the molecule.
        mbis_settings
            The settings used to generate the MBIS stored in this record.

        Returns
        -------
            The created record. It is built through ``cls`` so that calling
            this constructor on a subclass yields an instance of that subclass.
        """
        tagged_smiles = molecule.to_smiles(
            isomeric=True, explicit_hydrogens=True, mapped=True
        )

        return cls(
            tagged_smiles=tagged_smiles,
            conformer=conformer,
            multipoles=multipoles,
            mbis_settings=mbis_settings,
        )
# [docs]  (stray documentation-link marker left by text extraction; not code)
class MoleculeMBISStore:
    """Store and retrieve MBIS results for molecules in multiple conformers.

    This class currently can only store the data in a SQLite database.
    """

    def __init__(
        self,
        database_path: str = "mbis-store.sqlite",
        cache_size: None | int = None,
    ) -> None:
        """Initialize the MBIS store.

        Creates the database tables if they do not exist yet and records (or
        validates) the database schema version.

        Parameters
        ----------
        database_path
            The path to the SQLite database to store to and retrieve data from.
        cache_size
            The SQLite page cache size in KiB (e.g. ``20000`` is ~20MB). The
            value is negated before being handed to ``PRAGMA cache_size``, and
            SQLite interprets a negative argument as a size in KiB rather than
            a page count. When a truthy value is given, ``synchronous = OFF``
            and ``journal_mode = MEMORY`` are set as well, which is faster but
            less safe. ``None`` (or ``0``) leaves the SQLite defaults alone.

        Raises
        ------
        IncompatibleDBVersion
            If the version stored in an existing database differs from
            ``DB_VERSION``.
        """
        self._database_url = f"sqlite:///{database_path}"
        self._engine = create_engine(self._database_url, echo=False)

        if cache_size:
            # The listener has to be attached *before* anything opens a
            # connection (``create_all`` below does): a connection that was
            # created earlier and is kept by the engine's pool would otherwise
            # be reused without ever receiving these PRAGMAs.
            @event.listens_for(self._engine, "connect")
            def set_sqlite_pragma(
                dbapi_connection: object, _connection_record: object
            ) -> None:
                cursor = dbapi_connection.cursor()
                try:
                    # Negative value => cache size expressed in KiB.
                    cursor.execute(f"PRAGMA cache_size = -{cache_size}")
                    # Improves speed but is less safe.
                    cursor.execute("PRAGMA synchronous = OFF")
                    # Use in-memory journaling.
                    cursor.execute("PRAGMA journal_mode = MEMORY")
                finally:
                    cursor.close()

        DBBase.metadata.create_all(self._engine)

        self._session_maker = sessionmaker(
            autocommit=False, autoflush=False, bind=self._engine
        )

        # Validate the DB version if present, or add one if not.
        with self._get_session() as db:
            db_info = db.query(DBInformation).first()

            if db_info is None:
                db_info = DBInformation(version=DB_VERSION)
                db.add(db_info)

            if db_info.version != DB_VERSION:
                raise IncompatibleDBVersion(db_info.version, DB_VERSION)

    @property
    def db_version(self) -> int:
        """The version stored in the database's information row."""
        with self._get_session() as db:
            db_info = db.query(DBInformation).first()
            return db_info.version

    @property
    def general_provenance(self) -> dict[str, str]:
        """The general provenance key/value pairs stored in the database."""
        with self._get_session() as db:
            db_info = db.query(DBInformation).first()
            return {
                provenance.key: provenance.value
                for provenance in db_info.general_provenance
            }

    @property
    def software_provenance(self) -> dict[str, str]:
        """The software provenance key/value pairs stored in the database."""
        with self._get_session() as db:
            db_info = db.query(DBInformation).first()
            return {
                provenance.key: provenance.value
                for provenance in db_info.software_provenance
            }

    def set_provenance(
        self,
        general_provenance: dict[str, str],
        software_provenance: dict[str, str],
    ) -> None:
        """Set the stores provenance information.

        Any provenance previously attached to the information row is replaced
        by the entries given here.

        Parameters
        ----------
        general_provenance
            A dictionary storing provenance about the store such as the author,
            which QCArchive data set it was generated from, when it was generated
            etc.
        software_provenance
            A dictionary storing the provenance of the software and packages used
            to generate the data in the store.
        """
        with self._get_session() as db:
            db_info: DBInformation = db.query(DBInformation).first()

            db_info.general_provenance = [
                DBGeneralProvenance(key=key, value=value)
                for key, value in general_provenance.items()
            ]
            db_info.software_provenance = [
                DBSoftwareProvenance(key=key, value=value)
                for key, value in software_provenance.items()
            ]

    @contextmanager
    def _get_session(self) -> AbstractContextManager[Session]:
        """Yield a new session that is committed on success.

        If the ``with`` body (or the commit itself) raises, the session is
        rolled back and the exception re-raised. The session is closed in
        every case.
        """
        session = self._session_maker()

        try:
            yield session
            session.commit()
        except BaseException:
            # BaseException on purpose: also roll back on KeyboardInterrupt etc.
            session.rollback()
            raise
        finally:
            session.close()

    @classmethod
    def _db_records_to_model(
        cls, db_records: list[DBMoleculeRecord]
    ) -> list[MoleculeMBISRecord]:
        """Map a set of database records into their corresponding data models.

        One ``MoleculeMBISRecord`` is produced per conformer of each molecule
        record, in iteration order.

        Parameters
        ----------
        db_records
            The records to map.

        Returns
        -------
            The mapped data models.
        """
        # noinspection PyTypeChecker
        return [
            MoleculeMBISRecord(
                tagged_smiles=db_conformer.tagged_smiles,
                conformer=db_conformer.coordinates,
                multipoles=db_conformer.multipoles,
                mbis_settings=DBMBISSettings.db_to_instance(db_conformer.mbis_settings),
            )
            for db_record in db_records
            for db_conformer in db_record.conformers
        ]

    @classmethod
    def _store_smiles_records(
        cls, db: Session, smiles: str, records: list[MoleculeMBISRecord]
    ) -> DBMoleculeRecord:
        """Store a set of records which all store information for the same molecule.

        The conformers are appended to the existing molecule row for ``smiles``
        if there is one; otherwise a new molecule row is created and added to
        the session.

        Parameters
        ----------
        db
            The current database session.
        smiles
            The smiles representation of the molecule.
        records
            The records to store.

        Returns
        -------
            The database molecule record the conformers were attached to.
        """
        db_record = (
            db.query(DBMoleculeRecord).filter(DBMoleculeRecord.smiles == smiles).first()
        )
        is_new_molecule = db_record is None

        if is_new_molecule:
            db_record = DBMoleculeRecord(smiles=smiles)

        # noinspection PyTypeChecker
        # noinspection PyUnresolvedReferences
        db_record.conformers.extend(
            DBConformerRecord(
                tagged_smiles=record.tagged_smiles,
                coordinates=record.conformer,
                multipoles=record.multipoles,
                mbis_settings=DBMBISSettings.unique(db, record.mbis_settings),
            )
            for record in records
        )

        if is_new_molecule:
            db.add(db_record)

        return db_record

    @classmethod
    @functools.lru_cache(maxsize=10000)
    def _tagged_to_canonical_smiles(cls, tagged_smiles: str) -> str:
        """Convert a smiles pattern with atom indices to canonical smiles.

        Results are memoised (up to 10000 entries) because the same pattern is
        typically converted once per conformer.

        Parameters
        ----------
        tagged_smiles
            The tagged smiles pattern to convert.

        Returns
        -------
            The canonical smiles pattern (non-isomeric, implicit hydrogens,
            no atom map).
        """
        # The atom map is discarded on purpose, so the toolkit's warning about
        # it is suppressed here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=AtomMappingWarning)

            molecule = Molecule.from_smiles(tagged_smiles, allow_undefined_stereo=True)

            return molecule.to_smiles(
                isomeric=False, explicit_hydrogens=False, mapped=False
            )

    def store(self, *records: MoleculeMBISRecord) -> None:
        """Store the MBIS calculated for a given molecule in the data store.

        Each record is re-validated and the records are grouped by the
        canonical SMILES derived from their ``tagged_smiles`` before being
        written in a single session.

        Parameters
        ----------
        records
            The records to store.
        """
        # Validate and re-partition the records by their smiles patterns.
        records_by_smiles: dict[str, list[MoleculeMBISRecord]] = defaultdict(list)

        for record in records:
            validated_record = MoleculeMBISRecord(**record.model_dump())
            smiles = self._tagged_to_canonical_smiles(validated_record.tagged_smiles)
            records_by_smiles[smiles].append(validated_record)

        # Store the records.
        with self._get_session() as db:
            for smiles, smiles_records in records_by_smiles.items():
                self._store_smiles_records(db, smiles, smiles_records)

    def retrieve(
        self,
        smiles: str | None = None,
        basis: str | None = None,
        method: str | None = None,
    ) -> list[MoleculeMBISRecord]:
        """Retrieve records stored in this data store.

        Optionally filters according to a set of criteria.

        Parameters
        ----------
        smiles
            If given, only records of this molecule are returned. The pattern
            is converted to canonical SMILES before it is compared.
        basis
            If given, only records whose MBIS settings use this basis are
            returned.
        method
            If given, only records whose MBIS settings use this method are
            returned.

        Returns
        -------
            The matching records, one per stored conformer.
        """
        with self._get_session() as db:
            db_records = db.query(DBMoleculeRecord)

            if smiles is not None:
                smiles = self._tagged_to_canonical_smiles(smiles)
                db_records = db_records.filter(DBMoleculeRecord.smiles == smiles)

            if basis is not None or method is not None:
                # The settings live on the conformers, so both joins are needed
                # as soon as either settings filter is requested.
                db_records = db_records.join(DBConformerRecord)
                db_records = db_records.join(
                    DBMBISSettings, DBConformerRecord.mbis_settings
                )

                if basis is not None:
                    db_records = db_records.filter(DBMBISSettings.basis == basis)
                if method is not None:
                    db_records = db_records.filter(DBMBISSettings.method == method)

            db_records = db_records.all()

            # Converted inside the session because conformers and settings are
            # read from the ORM objects.
            records = self._db_records_to_model(db_records)

            # The query selects *molecules* having at least one matching
            # conformer; the mapping above expands every conformer of those
            # molecules, so non-matching conformers are dropped here.
            if basis is not None:
                records = [
                    record for record in records if record.mbis_settings.basis == basis
                ]
            if method is not None:
                records = [
                    record
                    for record in records
                    if record.mbis_settings.method == method
                ]

            return records

    # NOTE: keep this method last in the class body. Its name shadows the
    # builtin ``list`` inside the class namespace, so any ``list[...]``
    # annotation on a method defined after it would resolve to this function.
    def list(self) -> list[str]:
        """List the molecules which exist in and may be retrieved from the store."""
        with self._get_session() as db:
            return [smiles for (smiles,) in db.query(DBMoleculeRecord.smiles).all()]