class MPFITObjectiveTerm(ObjectiveTerm):
    """Precomputed quantities for one MPFIT multipole-fitting objective term.

    A term captures the gap between a reference set of distributed multipole
    moments (e.g. from GDMA) and the moments reproduced by a set of fixed
    partial charges. The ``predict`` and ``loss`` functions describe how the
    stored values are consumed.
    """

    @classmethod
    def _objective(cls) -> type["MPFITObjective"]:
        # Objective class paired with this term. MPFITObjective is defined
        # after this class, hence the string forward reference in the
        # annotation; the name itself is only looked up at call time.
        objective_cls = MPFITObjective
        return objective_cls
class MPFITObjective(Objective):
    """Compute contributions to the MPFIT least squares objective function.

    Contains helper functions for capturing the deviation of multipole moments
    computed using molecular partial charges from GDMA calculations.
    """

    @classmethod
    def _objective_term(cls) -> type[MPFITObjectiveTerm]:
        # Term class instantiated by ``compute_objective_terms``.
        return MPFITObjectiveTerm

    @staticmethod
    def _to_object_array(items: list) -> np.ndarray:
        """Pack ``items`` into a 1-D ``dtype=object`` array with one slot per item.

        ``np.array(items, dtype=object)`` only produces a 1-D array of arrays
        when the items are ragged. If every item has the same shape it instead
        builds an N-D object array whose entries are Python scalars, so e.g. a
        stored boolean mask can no longer be used for boolean indexing and a
        stored matrix loses its float dtype. Filling a pre-allocated array
        element by element keeps each original array intact regardless of shape.
        """
        packed = np.empty(len(items), dtype=object)
        for index, item in enumerate(items):
            packed[index] = item
        return packed

    @classmethod
    def extract_arrays(
        cls,
        gdma_record: MoleculeGDMARecord,
    ) -> dict:
        """Extract numerical arrays from a single GDMA record.

        Parameters
        ----------
        gdma_record
            Record providing ``tagged_smiles``, ``conformer`` (in angstrom),
            flat ``multipoles`` and ``gdma_settings``.

        Returns
        -------
        dict
            ``bohr_conformer``: conformer coordinates converted to bohr.
            ``multipoles``: the flat multipoles converted by
            ``_convert_flat_to_hierarchical`` up to rank ``settings.limit``.
            ``rvdw``: per-atom fitting radius (``settings.mpfit_atom_radius``).
            ``lmax``: per-atom maximum rank (``settings.limit``) as floats.
            ``r1`` / ``r2``: ``settings.mpfit_inner_radius`` /
            ``settings.mpfit_outer_radius``.
            ``maxl``: maximum multipole rank (``settings.limit``).
            ``n_atoms``: number of atoms in the molecule.
        """
        from openff.toolkit import Molecule
        from pympfit.mpfit.core import _convert_flat_to_hierarchical

        molecule = Molecule.from_mapped_smiles(
            gdma_record.tagged_smiles, allow_undefined_stereo=True
        )
        n_atoms = molecule.n_atoms
        settings = gdma_record.gdma_settings

        bohr_conformer = unit.convert(gdma_record.conformer, unit.angstrom, unit.bohr)
        multipoles = _convert_flat_to_hierarchical(
            gdma_record.multipoles, n_atoms, settings.limit
        )

        return {
            "bohr_conformer": bohr_conformer,
            "multipoles": multipoles,
            # NOTE(review): radius units are assumed to match bohr_conformer
            # (bohr) since they are compared against bohr distances — confirm.
            "rvdw": np.full(n_atoms, settings.mpfit_atom_radius),
            "lmax": np.full(n_atoms, settings.limit, dtype=float),
            "r1": settings.mpfit_inner_radius,
            "r2": settings.mpfit_outer_radius,
            "maxl": settings.limit,
            "n_atoms": n_atoms,
        }

    @classmethod
    def compute_objective_terms(
        cls,
        gdma_records: list[MoleculeGDMARecord],
        vsite_collection: VirtualSiteCollection | None = None,
        _vsite_charge_parameter_keys: list[VirtualSiteChargeKey] | None = None,
        _vsite_coordinate_parameter_keys: list[VirtualSiteGeometryKey] | None = None,
        return_quse_masks: bool = False,
    ) -> Generator[
        tuple[MPFITObjectiveTerm, dict] | MPFITObjectiveTerm, None, None
    ]:
        """Pre-calculates the terms that contribute to the total objective function.

        For every record, and for every atom (multipole site) in it, the
        charges within ``rvdw`` of that site are selected (all charges if none
        are) and a per-site design matrix and reference vector are built with
        ``build_A_matrix`` / ``build_b_vector``.

        Parameters
        ----------
        gdma_records
            Records to build one objective term each for.
        vsite_collection
            Must be ``None``; virtual sites are not supported for MPFIT.
        _vsite_charge_parameter_keys, _vsite_coordinate_parameter_keys
            Accepted for interface compatibility; not used.
        return_quse_masks
            If true, yield ``(term, {"quse_masks": masks})`` instead of ``term``.

        Yields
        ------
        One objective term per record, whose design matrix and reference
        values are 1-D object arrays holding one per-site array each. When
        requested, ``quse_masks`` is a 1-D object array whose i-th entry is the
        boolean mask over atoms selecting the charges used for site i.

        Raises
        ------
        NotImplementedError
            If ``vsite_collection`` is not ``None`` (raised on first iteration).
        """
        from pympfit.mpfit.core import build_A_matrix, build_b_vector

        # Check once, before any per-record matrices are built.
        if vsite_collection is not None:
            raise NotImplementedError("Virtual sites are not supported for MPFIT")

        for gdma_record in gdma_records:
            arrays = cls.extract_arrays(gdma_record)
            bohr_conformer = arrays["bohr_conformer"]
            multipoles = arrays["multipoles"]
            rvdw = arrays["rvdw"]
            r1 = arrays["r1"]
            r2 = arrays["r2"]
            max_rank = arrays["maxl"]
            n_atoms = arrays["n_atoms"]

            design_matrices = []
            reference_values = []
            quse_masks = []

            for i in range(n_atoms):
                # Distances from multipole site i to every atom.
                rqm = np.linalg.norm(bohr_conformer[i] - bohr_conformer, axis=1)
                quse_mask = rqm < rvdw[i]
                if not quse_mask.any():
                    # No charges within range: fall back to using all of them.
                    quse_mask = np.ones(n_atoms, dtype=bool)

                masked_charge_conformer = bohr_conformer[quse_mask]
                # Size the buffers AFTER the fallback so they always match the
                # number of charge positions handed to the builders.
                qsites = masked_charge_conformer.shape[0]

                site_A = build_A_matrix(
                    i,
                    bohr_conformer,
                    masked_charge_conformer,
                    r1,
                    r2,
                    max_rank,
                    np.zeros((qsites, qsites)),
                )
                site_b = build_b_vector(
                    i,
                    bohr_conformer,
                    masked_charge_conformer,
                    r1,
                    r2,
                    max_rank,
                    multipoles,
                    np.zeros(qsites),
                )

                # Kept for later use by the solver.
                quse_masks.append(quse_mask)
                design_matrices.append(site_A)
                reference_values.append(site_b)

            objective_term = cls._objective_term()(
                cls._to_object_array(design_matrices),
                None,  # vsite_charge_assignment_matrix
                None,  # vsite_fixed_charges
                None,  # vsite_coord_assignment_matrix
                None,  # vsite_fixed_coords
                None,  # vsite_local_coordinate_frame
                None,  # grid_coordinates not needed for MPFIT
                cls._to_object_array(reference_values),
            )

            if return_quse_masks:
                yield objective_term, {"quse_masks": cls._to_object_array(quse_masks)}
            else:
                yield objective_term