Coverage for cvpack/utils.py: 89%
92 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-09 16:14 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-09 16:14 +0000
1"""
2.. class:: Utils
3 :platform: Linux, MacOS, Windows
4 :synopsis: Utility functions and classes for CVpack
6.. classauthor:: Charlles Abreu <craabreu@gmail.com>
8"""
10import functools
11import inspect
12import typing as t
14import numpy as np
15import openmm
16from numpy import typing as npt
17from openmm import XmlSerializer
18from openmm import app as mmapp
19from openmm import unit as mmunit
21from .serialization import (
22 Serializable,
23 SerializableAtom,
24 SerializableForce,
25 SerializableResidue,
26)
27from .units import Quantity, Unit, value_in_md_units
29# pylint: disable=protected-access,c-extension-no-member
def evaluate_in_context(
    forces: t.Union[openmm.Force, t.Iterable[openmm.Force]], context: openmm.Context
) -> t.Union[float, t.Tuple[float, ...]]:
    """Evaluate the potential energies of OpenMM Forces in a given context.

    Each force is copied through XML serialization into a scratch system that has
    as many particles as the system of the given context. The scratch context
    receives the positions and periodic box vectors of the given context, and
    every copied force is evaluated separately through its own force group.

    Parameters
    ----------
    forces
        A single force or an iterable of forces to be evaluated. Any iterable is
        accepted, including one-shot iterators such as generators. The i-th force
        is assigned to force group i, so the number of forces is bounded by the
        number of force groups supported by OpenMM.
    context
        The context that provides the positions and the periodic box vectors
        used in the evaluation.

    Returns
    -------
    float or tuple of float
        The potential energy of the force, if a single force was given, or a tuple
        with the potential energy of each force, in the order they were given.
        Values are whatever ``value_in_md_units`` returns for the state's
        potential energy.
    """
    is_single = isinstance(forces, openmm.Force)
    # Materialize the input once: it is traversed here and counted below, which
    # would fail for a generator or any other one-shot iterable.
    forces = [forces] if is_single else list(forces)
    system = openmm.System()
    # Only the particle count must match; a placeholder mass is used because no
    # dynamics is run in the scratch context.
    for _ in range(context.getSystem().getNumParticles()):
        system.addParticle(1.0)
    for i, force in enumerate(forces):
        # Work on a copy so the force group of the caller's object is untouched.
        force_copy = XmlSerializer.deserialize(XmlSerializer.serialize(force))
        force_copy.setForceGroup(i)
        system.addForce(force_copy)
    state = context.getState(getPositions=True)
    # A separate name keeps the caller's context distinguishable from the
    # scratch one.
    scratch_context = openmm.Context(system, openmm.VerletIntegrator(1.0))
    scratch_context.setPositions(state.getPositions())
    scratch_context.setPeriodicBoxVectors(*state.getPeriodicBoxVectors())
    energies = []
    for i in range(len(forces)):
        group_state = scratch_context.getState(  # pylint: disable=unexpected-keyword-arg
            getEnergy=True, groups=1 << i
        )
        energies.append(value_in_md_units(group_state.getPotentialEnergy()))
    return energies[0] if is_single else tuple(energies)
def convert_to_matrix(array: npt.ArrayLike) -> t.Tuple[np.ndarray, int, int]:
    """Turn a scalar, 1D, or 2D array-like object into a 2D numpy array.

    Parameters
    ----------
    array : array_like
        The object to be converted. Inputs with fewer than two dimensions are
        promoted to two dimensions by ``numpy.atleast_2d``.

    Returns
    -------
    numpy.ndarray
        The 2D numpy array.
    int
        The number of rows in the array.
    int
        The number of columns in the array.

    Raises
    ------
    ValueError
        If the input has more than two dimensions.
    """
    matrix = np.atleast_2d(array)
    # atleast_2d guarantees ndim >= 2, so only the upper bound needs checking.
    if matrix.ndim > 2:
        raise ValueError("Array-like object cannot have more than two dimensions.")
    numrows, numcols = matrix.shape
    return matrix, numrows, numcols
def get_single_force_state(
    force: openmm.Force,
    context: openmm.Context,
    allowReinitialization: bool = False,
    **kwargs: bool,
) -> openmm.State:
    """
    Get an OpenMM State containing the potential energy and/or force values computed
    from a single force object.

    Parameters
    ----------
    force
        The force object from which the state should be extracted.
    context
        The context from which the state should be extracted.
    allowReinitialization
        If True, the force group of the given force will be temporarily changed to a
        group that is not used by any other force in the system, if necessary. The
        context is reinitialized (preserving its state) both when the group is
        changed and when it is restored.

    Keyword Args
    ------------
    **kwargs
        Additional keyword arguments to be passed to the `getState` method, except for
        the `groups` argument.

    Returns
    -------
    openmm.State
        The state containing the requested values.

    Raises
    ------
    RuntimeError
        If this force is not present in the given context.
    ValueError
        If the force shares its force group with another force in the system and
        `allowReinitialization` is False.
    """
    system = context.getSystem()
    forces_and_groups = [(f, f.getForceGroup()) for f in system.getForces()]
    # Wrapper objects are compared through their `this` attribute rather than by
    # Python identity.
    if not any(f.this == force.this for f, _ in forces_and_groups):
        raise RuntimeError("This force is not present in the given context.")
    self_group = force.getForceGroup()
    other_groups = {g for f, g in forces_and_groups if f.this != force.this}
    if self_group not in other_groups:
        # The force is alone in its group, so the group mask isolates it.
        return context.getState(groups=1 << self_group, **kwargs)
    if not allowReinitialization:
        raise ValueError("Context reinitialization required, but not allowed.")
    # NOTE(review): `_setUnusedForceGroup` is defined outside this file; it is
    # expected to move the force to a free group and return that group's index.
    new_group = force._setUnusedForceGroup(system)
    try:
        context.reinitialize(preserveState=True)
        state = context.getState(groups=1 << new_group, **kwargs)
    finally:
        # Always restore the original group, even if the evaluation above fails,
        # so the caller's force is never left in the temporary group.
        force.setForceGroup(self_group)
        context.reinitialize(preserveState=True)
    return state
def compute_effective_mass(
    force: openmm.Force, context: openmm.Context, allowReinitialization: bool = False
) -> float:
    r"""
    Compute the effective mass of an :OpenMM:`Force` at a given :OpenMM:`Context`.

    The value returned is

    .. math::

        m_\mathrm{eff} = \left(
            \sum_i \frac{\|\mathbf{F}_i\|^2}{m_i}
        \right)^{-1}

    where :math:`\mathbf{F}_i` is the force vector that this force object alone
    exerts on particle :math:`i` and :math:`m_i` is the mass of that particle in
    the context's system. Only particles with a nonzero force vector enter the sum.

    Parameters
    ----------
    force
        The force object from which the effective mass should be computed
    context
        The context at which the force's effective mass should be evaluated
    allowReinitialization
        If True, the force group of the given force will be temporarily changed to a
        group that is not used by any other force in the system, if necessary.

    Returns
    -------
    float
        The effective mass of the force at the given context, or infinity if the
        force vector is zero for every particle.
    """
    single_force_state = get_single_force_state(
        force, context, allowReinitialization, getForces=True
    )
    system = context.getSystem()
    # Work with the bare numerical array held by the returned quantity.
    vectors = single_force_state.getForces(asNumpy=True)._value
    norms_squared = np.square(vectors).sum(axis=1)
    (active,) = np.nonzero(norms_squared)
    if active.size == 0:
        return np.inf
    # The low-level getter is called directly, one particle at a time, and only
    # for the particles that actually contribute to the sum.
    masses = np.fromiter(
        (openmm._openmm.System_getParticleMass(system, index) for index in active),
        dtype=np.float64,
    )
    return 1.0 / np.sum(norms_squared[active] / masses)
def preprocess_args(func: t.Callable) -> t.Callable:
    """
    A decorator that converts instances of unserializable classes to their
    serializable counterparts.

    Every argument explicitly passed to the decorated function is converted
    before the call. Default values of parameters that were not passed are left
    as they are. The conversions are:

    - numpy integer and floating scalars become `int` and `float`;
    - `openmm.unit` quantities and units become `Quantity` and `Unit`;
    - `openmm.app` atoms and residues become `SerializableAtom` and
      `SerializableResidue`;
    - an `openmm.Force` that is not already `Serializable` becomes a
      `SerializableForce`;
    - non-string sequences and dictionaries are rebuilt with the same type and
      with their items (dictionary values only, not keys) converted recursively.
      Named tuples are rebuilt field by field and `range` objects are passed
      through unchanged.

    Anything else is passed through unchanged.

    Parameters
    ----------
    func
        The function to be decorated.

    Returns
    -------
        The decorated function.

    Example
    -------
    >>> from cvpack import units, utils
    >>> from openmm import unit as mmunit
    >>> @utils.preprocess_args
    ... def function(data):
    ...     return data
    >>> assert isinstance(function(mmunit.angstrom), units.Unit)
    >>> assert isinstance(function(5 * mmunit.angstrom), units.Quantity)
    >>> seq = [mmunit.angstrom, mmunit.nanometer]
    >>> assert isinstance(function(seq), list)
    >>> assert all(isinstance(item, units.Unit) for item in function(seq))
    >>> dct = {"length": 3 * mmunit.angstrom, "time": 2 * mmunit.picosecond}
    >>> assert isinstance(function(dct), dict)
    >>> assert all(isinstance(item, units.Quantity) for item in function(dct).values())
    """
    signature = inspect.signature(func)

    def convert(data: t.Any) -> t.Any:  # pylint: disable=too-many-return-statements
        if isinstance(data, np.integer):
            return int(data)
        if isinstance(data, np.floating):
            return float(data)
        if isinstance(data, mmunit.Quantity):
            return Quantity(data)
        if isinstance(data, mmunit.Unit):
            return Unit(data)
        if isinstance(data, mmapp.Atom):
            return SerializableAtom(data)
        if isinstance(data, mmapp.Residue):
            return SerializableResidue(data)
        if isinstance(data, openmm.Force) and not isinstance(data, Serializable):
            return SerializableForce(data)
        # A range cannot be rebuilt from an iterable of its items, and its items
        # are plain ints that need no conversion, so it is passed through.
        if isinstance(data, t.Sequence) and not isinstance(data, (str, range)):
            if isinstance(data, tuple) and hasattr(type(data), "_make"):
                # Named tuple constructors take one positional argument per
                # field; `_make` is their documented from-iterable constructor.
                return type(data)._make(map(convert, data))
            return type(data)(map(convert, data))
        if isinstance(data, t.Dict):
            return type(data)((key, convert(value)) for key, value in data.items())
        return data

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        # Only existing keys are reassigned, so the mapping can be updated
        # while it is being iterated.
        for name, data in bound.arguments.items():
            bound.arguments[name] = convert(data)
        return func(*bound.args, **bound.kwargs)

    return wrapper