from collections.abc import Generator
from typing import TYPE_CHECKING
import numpy as np
from openff.recharge.charges.vsite import (
VirtualSiteChargeKey,
VirtualSiteCollection,
VirtualSiteGeometryKey,
)
from openff.recharge.optimize._optimize import Objective, ObjectiveTerm
from openff.units import unit
from pympfit.gdma.storage import MoleculeGDMARecord
from pympfit.mbis.storage import MoleculeMBISRecord
# Type alias for records that can be used with MPFIT
MultipoleRecord = MoleculeGDMARecord | MoleculeMBISRecord
if TYPE_CHECKING:
import torch
[docs]
class MPFITObjectiveTerm(ObjectiveTerm):
    """Store precalculated values for multipole moment fitting.

    Computes the difference between a reference set of distributed multipole
    moments and a set computed using fixed partial charges.

    Attributes
    ----------
    multipole_record : MultipoleRecord | None
        Reference to the source multipole record (GDMA or MBIS).
    quse_masks : np.ndarray | None
        Boolean masks indicating which charges are included for each multipole
        site.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to the base term and reset the MPFIT extras."""
        super().__init__(*args, **kwargs)
        # Both attributes start unset; ``predict()`` refuses to run until they
        # have been assigned on the instance.
        self.multipole_record: MultipoleRecord | None = None
        self.quse_masks: np.ndarray | None = None

    @classmethod
    def _objective(cls) -> type["MPFITObjective"]:
        """Return the objective class associated with this term."""
        return MPFITObjective

    def predict(
        self,
        charge_parameters: "torch.Tensor",
        vsite_coordinate_parameters: "torch.Tensor | None" = None,
    ) -> list["torch.Tensor"]:
        """Predict multipole moment contributions for given charges and vsite positions.

        This method is designed for Bayesian inference and rebuilds A matrices
        with new vsite positions, computing A @ q for each multipole site.

        Parameters
        ----------
        charge_parameters : torch.Tensor
            Charge values with shape (n_atom_charges + n_vsite_charges, 1).
        vsite_coordinate_parameters : torch.Tensor, optional
            Virtual site local frame coordinates (distance, angles) being sampled.
            Shape (n_trainable_coords, 1). May be omitted only when the term
            has no trainable virtual site coordinates.

        Returns
        -------
        list[torch.Tensor]
            Predicted multipole contributions for each atom site, matching the
            shape of ``reference_values``.

        Raises
        ------
        ImportError
            If PyTorch is not installed.
        ValueError
            If called on a term without virtual sites (use SVD solver instead),
            if ``multipole_record`` or ``quse_masks`` has not been set, or if
            ``vsite_coordinate_parameters`` is omitted while the term has
            trainable virtual site coordinates.
        NotImplementedError
            If the record holds MBIS multipoles in cartesian format.
        """
        try:
            import torch
        except ImportError:
            raise ImportError(
                "predict() requires PyTorch for differentiable inference. "
                "Install with: pip install torch sphericart[torch]"
            ) from None

        from openff.recharge.charges.vsite import VirtualSiteGenerator
        from openff.recharge.utilities.tensors import append_zero

        from pympfit.mpfit.core import build_A_matrix_torch

        local_frame = self.vsite_local_coordinate_frame
        if local_frame is None or local_frame.shape[1] == 0:
            raise ValueError(
                "predict() requires virtual sites. For atom-only charge fitting, "
                "use the SVD solver directly via generate_mpfit_charge_parameter()."
            )

        record = self.multipole_record
        if record is None or self.quse_masks is None:
            # Fail with a clear message instead of an AttributeError/TypeError
            # further down when the term was not fully populated.
            raise ValueError(
                "predict() requires multipole_record and quse_masks to be set; "
                "build the term with MPFITObjective.compute_objective_terms()."
            )

        # Get settings from record (works for both GDMA and MBIS)
        if isinstance(record, MoleculeGDMARecord):
            settings = record.gdma_settings
            max_rank = settings.limit  # GDMA uses 0-based indexing
        else:
            settings = record.mbis_settings
            if settings.multipole_format == "cartesian":
                raise NotImplementedError(
                    "MPFIT against cartesian MBIS multipoles is not implemented. "
                    "Construct MBISSettings with multipole_format='spherical'."
                )
            max_rank = settings.limit - 1  # MBIS uses 1-based indexing

        r1 = settings.mpfit_inner_radius
        r2 = settings.mpfit_outer_radius

        bohr_conformer_np = unit.convert(record.conformer, unit.angstrom, unit.bohr)
        bohr_conformer = torch.from_numpy(bohr_conformer_np)
        n_atoms = bohr_conformer.shape[0]

        vsite_fixed_coords_t = torch.from_numpy(self.vsite_fixed_coords)

        if vsite_coordinate_parameters is None:
            # NOTE(review): assumes negative assignment indices mark fixed
            # (non-trainable) coordinates, which is what the ``append_zero``
            # lookup below relies on -- confirm against the base class.
            if np.any(np.asarray(self.vsite_coord_assignment_matrix) >= 0):
                raise ValueError(
                    "vsite_coordinate_parameters must be provided when the term "
                    "has trainable virtual site coordinates."
                )
            # Nothing is trainable: an empty parameter vector makes every
            # lookup resolve to the zero appended by ``append_zero``.
            vsite_coordinate_parameters = torch.zeros(
                (0, 1), dtype=vsite_fixed_coords_t.dtype
            )

        # compute new vsite Cartesian positions from local frame
        trainable = append_zero(vsite_coordinate_parameters.flatten())[
            self.vsite_coord_assignment_matrix
        ]
        vsite_local_coords = vsite_fixed_coords_t + trainable

        # Convert local coords (distance, angles) to Cartesian positions
        vsite_local_frame_t = torch.from_numpy(local_frame)
        vsite_coords_angstrom = VirtualSiteGenerator.convert_local_coordinates(
            vsite_local_coords, vsite_local_frame_t, backend="torch"
        )
        angstrom_to_bohr = unit.convert(1.0, unit.angstrom, unit.bohr)
        vsite_coords_bohr = vsite_coords_angstrom * angstrom_to_bohr
        augmented_coords = torch.cat([bohr_conformer, vsite_coords_bohr], dim=0)

        n_trainable_vsite_charges = self.vsite_charge_assignment_matrix.shape[1]
        atom_charges = charge_parameters[:n_atoms]

        if n_trainable_vsite_charges > 0:
            # redistribute vsite charge increments to parent atoms
            trainable_vsite_charges = charge_parameters[n_atoms:]
            vsite_charge_matrix_t = torch.from_numpy(
                self.vsite_charge_assignment_matrix
            )
            vsite_fixed_charges_t = torch.from_numpy(self.vsite_fixed_charges)
            charge_adjustment = (
                vsite_charge_matrix_t @ trainable_vsite_charges + vsite_fixed_charges_t
            )
        else:
            charge_adjustment = torch.from_numpy(self.vsite_fixed_charges)

        # The first n_atoms rows adjust the atom charges; the remaining rows
        # are the charges carried by the virtual sites themselves.
        atom_adjustment = charge_adjustment[:n_atoms]
        vsite_charges = charge_adjustment[n_atoms:]
        final_atom_charges = atom_charges + atom_adjustment
        all_charges = torch.cat([final_atom_charges, vsite_charges], dim=0)

        predictions = []
        for i in range(n_atoms):
            # Use stored quse_mask to ensure consistent shape with reference_values
            quse_mask = torch.from_numpy(self.quse_masks[i].astype(bool))
            masked_coords = augmented_coords[quse_mask]
            masked_charges = all_charges[quse_mask]
            site_A = build_A_matrix_torch(
                i, bohr_conformer, masked_coords, r1, r2, max_rank
            )
            predictions.append(site_A @ masked_charges)

        return predictions

    def predict_from_free_charges(
        self,
        free_charges: "torch.Tensor",
        vsite_charge_increments: "torch.Tensor",
        vsite_coordinate_parameters: "torch.Tensor | None" = None,
        formal_charge: float = 0.0,
    ) -> list["torch.Tensor"]:
        """Predict multipoles with automatic charge conservation.

        A convenience wrapper around ``predict()`` that enforces total charge
        conservation. Users provide N-1 free atom charges and this method
        computes the constrained last charge to satisfy the formal charge.

        Parameters
        ----------
        free_charges : torch.Tensor
            Free charge parameters with shape (n_atoms - 1, 1).
        vsite_charge_increments : torch.Tensor
            Virtual site charge increments with shape (n_vsite_charges, 1).
        vsite_coordinate_parameters : torch.Tensor, optional
            Virtual site local frame coordinates being sampled.
        formal_charge : float
            Total molecular charge (default 0.0 for neutral molecules).

        Returns
        -------
        list[torch.Tensor]
            Predicted multipole contributions for each atom site.

        Raises
        ------
        ImportError
            If PyTorch is not installed.
        """
        try:
            import torch
        except ImportError:
            raise ImportError(
                "predict_from_free_charges() requires PyTorch. "
                "Install with: pip install torch"
            ) from None

        # The last atom charge is whatever remains of the formal charge after
        # the free charges are summed along the first dimension.
        last_charge = formal_charge - free_charges.sum(dim=0, keepdim=True)
        atom_charges = torch.cat([free_charges, last_charge], dim=0)
        charge_parameters = torch.cat([atom_charges, vsite_charge_increments], dim=0)
        return self.predict(charge_parameters, vsite_coordinate_parameters)
[docs]
class MPFITObjective(Objective):
"""Compute contributions to the MPFIT least squares objective function.
Contains helper functions for capturing the deviation of multipole moments
computed using molecular partial charges from the reference distributed
multipoles of GDMA or MBIS calculations.
"""
@classmethod
def _objective_term(cls) -> type[MPFITObjectiveTerm]:
    """Return the objective term class used by this objective."""
    term_type = MPFITObjectiveTerm
    return term_type
[docs]
@classmethod
def compute_objective_terms(
    cls,
    multipole_records: list[MultipoleRecord],
    vsite_collection: VirtualSiteCollection | None = None,
    _vsite_charge_parameter_keys: list[VirtualSiteChargeKey] | None = None,
    _vsite_coordinate_parameter_keys: list[VirtualSiteGeometryKey] | None = None,
    return_quse_masks: bool = False,
    fit_limit: int | None = None,
) -> Generator[tuple[MPFITObjectiveTerm, dict] | MPFITObjectiveTerm, None, None]:
    """Pre-calculates the terms that contribute to the total objective function.

    One objective term is produced per record. For every atom site a design
    matrix and a reference vector are built from the charge sites selected
    by that site's mask.

    Parameters
    ----------
    multipole_records
        The multipole records (GDMA or MBIS) to build objective terms for.
    vsite_collection
        Optional virtual site collection. When provided, virtual site
        positions are generated for each record and appended to the atom
        coordinates as additional charge sites.
    _vsite_charge_parameter_keys
        Keys of the virtual site charge parameters forwarded to
        ``_compute_vsite_charge_terms``; ``None`` is treated as an empty list.
    _vsite_coordinate_parameter_keys
        Keys of the virtual site coordinate parameters forwarded to
        ``_compute_vsite_coord_terms``; ``None`` is treated as an empty list.
    return_quse_masks
        If True, yield ``(term, info)`` tuples where ``info`` holds the
        ``"quse_masks"`` array and the ``"n_vsites"`` count; otherwise yield
        only the term.
    fit_limit
        Optional maximum multipole rank (0-indexed) for fitting. See
        ``extract_arrays`` for details.

    Yields
    ------
    MPFITObjectiveTerm or tuple[MPFITObjectiveTerm, dict]
        The objective term for each record, with ``multipole_record`` and
        ``quse_masks`` set on it.
    """
    from pympfit.mpfit.core import build_A_matrix, build_b_vector

    for multipole_record in multipole_records:
        arrays = cls.extract_arrays(multipole_record, fit_limit=fit_limit)
        bohr_conformer = arrays["bohr_conformer"]
        multipoles = arrays["multipoles"]
        rvdw = arrays["rvdw"]
        r1 = arrays["r1"]
        r2 = arrays["r2"]
        max_rank = arrays["maxl"]
        n_atoms = arrays["n_atoms"]

        if vsite_collection is not None:
            from openff.recharge.charges.vsite import VirtualSiteGenerator
            from openff.toolkit import Molecule

            molecule = Molecule.from_mapped_smiles(
                multipole_record.tagged_smiles, allow_undefined_stereo=True
            )
            conformer_angstrom = multipole_record.conformer
            vsite_positions = VirtualSiteGenerator.generate_positions(
                molecule, vsite_collection, conformer_angstrom * unit.angstrom
            )
            vsite_positions_bohr = unit.convert(
                vsite_positions.m_as(unit.angstrom), unit.angstrom, unit.bohr
            )
            n_vsites = vsite_positions_bohr.shape[0]

            (vsite_charge_assignment_matrix, vsite_fixed_charges) = (
                cls._compute_vsite_charge_terms(
                    molecule, vsite_collection, _vsite_charge_parameter_keys or []
                )
            )
            (
                vsite_coord_assignment_matrix,
                vsite_fixed_coords,
                vsite_local_coordinate_frame,
            ) = cls._compute_vsite_coord_terms(
                molecule,
                conformer_angstrom,
                vsite_collection,
                _vsite_coordinate_parameter_keys or [],
            )

            augmented_coords_bohr = np.vstack(
                [bohr_conformer, vsite_positions_bohr]
            )
            # Extend the radii to cover the virtual sites as well. Only the
            # entries for atom sites (index < n_atoms) are read below.
            rvdw = np.concatenate([rvdw, np.full(n_vsites, r1)])
        else:
            n_vsites = 0
            augmented_coords_bohr = bohr_conformer
            vsite_charge_assignment_matrix = None
            vsite_fixed_charges = None
            vsite_coord_assignment_matrix = None
            vsite_fixed_coords = None
            vsite_local_coordinate_frame = None

        atom_charge_design_matrices = []
        reference_values = []
        quse_masks = []

        for i in range(n_atoms):
            # Select the charge sites closer to multipole site i than rvdw[i].
            rqm = np.linalg.norm(augmented_coords_bohr - bohr_conformer[i], axis=1)
            quse_mask = rqm < rvdw[i]
            if not quse_mask.any():
                # No site falls inside the cutoff: fall back to every site.
                quse_mask = np.ones(augmented_coords_bohr.shape[0], dtype=bool)
            quse_masks.append(quse_mask)

            masked_charge_conformer = augmented_coords_bohr[quse_mask]

            # Size the work arrays from the sites that are actually used, so
            # they match the fallback selection too.
            qsites = masked_charge_conformer.shape[0]
            site_A = np.zeros((qsites, qsites))
            site_b = np.zeros(qsites)

            site_A = build_A_matrix(
                i,
                bohr_conformer,
                masked_charge_conformer,
                r1,
                r2,
                max_rank,
                site_A,
            )
            site_b = build_b_vector(
                i,
                bohr_conformer,
                masked_charge_conformer,
                r1,
                r2,
                max_rank,
                multipoles,
                site_b,
            )
            atom_charge_design_matrices.append(site_A)
            reference_values.append(site_b)

        # Per-site arrays can differ in size, hence the object dtype.
        atom_charge_design_matrix = np.array(
            atom_charge_design_matrices, dtype=object
        )
        reference_values = np.array(reference_values, dtype=object)
        quse_masks = np.array(quse_masks, dtype=object)

        # Base class assertion requires all vsite fields to be non-None
        grid_placeholder = np.empty((0, 3)) if n_vsites > 0 else None

        objective_term = cls._objective_term()(
            atom_charge_design_matrix,
            vsite_charge_assignment_matrix,
            vsite_fixed_charges,
            vsite_coord_assignment_matrix,
            vsite_fixed_coords,
            vsite_local_coordinate_frame,
            grid_placeholder,
            reference_values,
        )
        objective_term.multipole_record = multipole_record
        objective_term.quse_masks = quse_masks

        if return_quse_masks:
            yield objective_term, {"quse_masks": quse_masks, "n_vsites": n_vsites}
        else:
            yield objective_term