Coverage for cvpack/utils.py: 89%

92 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-09 16:14 +0000

1""" 

2.. class:: Utils 

3 :platform: Linux, MacOS, Windows 

4 :synopsis: Utility functions and classes for CVpack 

5 

6.. classauthor:: Charlles Abreu <craabreu@gmail.com> 

7 

8""" 

9 

10import functools 

11import inspect 

12import typing as t 

13 

14import numpy as np 

15import openmm 

16from numpy import typing as npt 

17from openmm import XmlSerializer 

18from openmm import app as mmapp 

19from openmm import unit as mmunit 

20 

21from .serialization import ( 

22 Serializable, 

23 SerializableAtom, 

24 SerializableForce, 

25 SerializableResidue, 

26) 

27from .units import Quantity, Unit, value_in_md_units 

28 

29# pylint: disable=protected-access,c-extension-no-member 

30 

31 

def evaluate_in_context(
    forces: t.Union[openmm.Force, t.Iterable[openmm.Force]], context: openmm.Context
) -> t.Union[float, t.Tuple[float, ...]]:
    """Evaluate the potential energies of OpenMM Forces in a given context.

    Each force is copied (via an XML serialization round trip) into a scratch
    system with the same number of particles as the context's system. Every copy
    is placed in its own force group. The copies are then evaluated at the
    context's current positions and periodic box vectors. The caller's force
    objects are never modified.

    Parameters
    ----------
    forces
        A single force, or an iterable of forces, to be evaluated. Any iterable is
        accepted, including one-shot iterators such as generators.
    context
        The context whose positions and periodic box vectors are used for the
        evaluation.

    Returns
    -------
    float or tuple of float
        If a single force was given, its potential energy in MD units. Otherwise,
        a tuple with the potential energy of each force, in input order.

    Notes
    -----
    Each force is assigned its own force group (0, 1, 2, ...). OpenMM only
    supports groups 0 to 31, so at most 32 forces can be evaluated per call.
    """
    is_single = isinstance(forces, openmm.Force)
    # Materialize the input: it is traversed twice below and its length is needed,
    # which would fail for generators and other one-shot iterables.
    force_list = [forces] if is_single else list(forces)
    system = openmm.System()
    # Dummy particles: only the particle count has to match the context's system.
    for _ in range(context.getSystem().getNumParticles()):
        system.addParticle(1.0)
    for group, force in enumerate(force_list):
        # Work on an independent copy so that assigning a force group here does
        # not modify the caller's force object.
        force_copy = XmlSerializer.deserialize(XmlSerializer.serialize(force))
        force_copy.setForceGroup(group)
        system.addForce(force_copy)
    state = context.getState(getPositions=True)
    scratch_context = openmm.Context(system, openmm.VerletIntegrator(1.0))
    scratch_context.setPositions(state.getPositions())
    scratch_context.setPeriodicBoxVectors(*state.getPeriodicBoxVectors())
    energies = []
    for group in range(len(force_list)):
        # Restrict the energy evaluation to the group holding a single force.
        group_state = scratch_context.getState(  # pylint: disable=unexpected-keyword-arg
            getEnergy=True, groups=1 << group
        )
        energies.append(value_in_md_units(group_state.getPotentialEnergy()))
    return energies[0] if is_single else tuple(energies)

70 

71 

def convert_to_matrix(array: npt.ArrayLike) -> t.Tuple[np.ndarray, int, int]:
    """Promote an array-like object to a two-dimensional NumPy array.

    Scalars and one-dimensional inputs are promoted with :func:`numpy.atleast_2d`.
    For example, a 1D input of length ``n`` becomes a single row of shape
    ``(1, n)``.

    Parameters
    ----------
    array : array_like
        The data to be converted.

    Returns
    -------
    numpy.ndarray
        The resulting 2D array.
    int
        Its number of rows.
    int
        Its number of columns.

    Raises
    ------
    ValueError
        If the input has more than two dimensions.
    """
    matrix = np.atleast_2d(array)
    if matrix.ndim > 2:
        raise ValueError("Array-like object cannot have more than two dimensions.")
    rows, cols = matrix.shape
    return matrix, rows, cols

94 

95 

def get_single_force_state(
    force: openmm.Force,
    context: openmm.Context,
    allowReinitialization: bool = False,
    **kwargs: bool,
) -> openmm.State:
    """
    Get an OpenMM State containing the potential energy and/or force values computed
    from a single force object.

    If the force is the only member of its force group, the state is obtained
    directly. Otherwise, and only when ``allowReinitialization`` is True, the force
    is temporarily moved to an unused group. The context is reinitialized, the
    state is obtained, and then the original group is restored and the context is
    reinitialized again.

    Parameters
    ----------
    force
        The force object from which the state should be extracted.
    context
        The context from which the state should be extracted.
    allowReinitialization
        If True, the force group of the given force will be temporarily changed to a
        group that is not used by any other force in the system, if necessary.

    Keyword Args
    ------------
    **kwargs
        Additional keyword arguments to be passed to the `getState` method, except for
        the `groups` argument.

    Returns
    -------
    openmm.State
        The state containing the requested values.

    Raises
    ------
    RuntimeError
        If this force is not present in the given context.
    ValueError
        If the force shares its force group with other forces in the context's
        system and ``allowReinitialization`` is False.
    """
    forces_and_groups = [
        (f, f.getForceGroup()) for f in context.getSystem().getForces()
    ]
    # Compare the underlying SWIG pointers (`this`) to test object identity.
    if not any(f.this == force.this for f, _ in forces_and_groups):
        raise RuntimeError("This force is not present in the given context.")
    self_group = force.getForceGroup()
    other_groups = {g for f, g in forces_and_groups if f.this != force.this}
    if self_group not in other_groups:
        # The force is alone in its group: no reinitialization is needed.
        return context.getState(groups=1 << self_group, **kwargs)
    if not allowReinitialization:
        raise ValueError("Context reinitialization required, but not allowed.")
    # NOTE(review): `_setUnusedForceGroup` is defined elsewhere; it is expected to
    # move `force` to a group unused by the system and return that group number.
    new_group = force._setUnusedForceGroup(context.getSystem())
    try:
        context.reinitialize(preserveState=True)
        return context.getState(groups=1 << new_group, **kwargs)
    finally:
        # Restore the original group even if the state request fails, so the
        # system and context are left as they were found.
        force.setForceGroup(self_group)
        context.reinitialize(preserveState=True)

149 

150 

def compute_effective_mass(
    force: openmm.Force, context: openmm.Context, allowReinitialization: bool = False
) -> float:
    r"""
    Compute the effective mass of an :OpenMM:`Force` at a given :OpenMM:`Context`.

    The effective mass is computed as

    .. math::

        m_\mathrm{eff} = \left(
            \sum_{i:\, m_i > 0} \frac{\|\mathbf{F}_i\|^2}{m_i}
        \right)^{-1}

    where :math:`\mathbf{F}_i` is the force exerted on particle :math:`i` by the
    given force alone, and :math:`m_i` is that particle's mass in the context's
    system. Massless particles are skipped because OpenMM integrators never move
    them.

    Parameters
    ----------
    force
        The force object from which the effective mass should be computed
    context
        The context at which the force's effective mass should be evaluated
    allowReinitialization
        If True, the force group of the given force will be temporarily changed to a
        group that is not used by any other force in the system, if necessary.

    Returns
    -------
    float
        The effective mass of the force at the given context, or ``numpy.inf`` if
        no particle with nonzero mass experiences a nonzero force.
    """
    state = get_single_force_state(
        force, context, allowReinitialization, getForces=True
    )
    # Raw SWIG accessor. Presumably it returns the mass as a plain float, without
    # the unit wrapping done by System.getParticleMass (confirm).
    get_mass = functools.partial(
        openmm._openmm.System_getParticleMass, context.getSystem()
    )
    force_vectors = state.getForces(asNumpy=True)._value
    squared_forces = np.sum(np.square(force_vectors), axis=1)
    # Only particles subject to a nonzero force contribute, so only their masses
    # are queried.
    nonzeros = np.nonzero(squared_forces)[0]
    masses = np.fromiter(
        map(get_mass, nonzeros), dtype=np.float64, count=nonzeros.size
    )
    # Exclude massless (immobile) particles. Otherwise they would add infinite
    # terms to the sum, collapsing the effective mass to zero.
    mobile = masses > 0.0
    if not np.any(mobile):
        return np.inf
    return 1.0 / np.sum(squared_forces[nonzeros[mobile]] / masses[mobile])

185 

186 

def preprocess_args(func: t.Callable) -> t.Callable:
    """
    A decorator that converts instances of unserializable classes to their
    serializable counterparts.

    Every argument bound to the decorated function's signature is converted
    recursively:

    - NumPy integer and floating scalars become Python ``int`` and ``float``.
    - OpenMM quantities and units become :class:`Quantity` and :class:`Unit`.
    - OpenMM atoms, residues, and forces become their serializable wrappers.
    - Sequences (except strings) and dicts are rebuilt with the same type, with
      their items (or values) converted.

    Parameters
    ----------
    func
        The function to be decorated.

    Returns
    -------
    The decorated function.

    Example
    -------
    >>> from collections import namedtuple
    >>> from cvpack import units, utils
    >>> from openmm import unit as mmunit
    >>> @utils.preprocess_args
    ... def function(data):
    ...     return data
    >>> assert isinstance(function(mmunit.angstrom), units.Unit)
    >>> assert isinstance(function(5 * mmunit.angstrom), units.Quantity)
    >>> seq = [mmunit.angstrom, mmunit.nanometer]
    >>> assert isinstance(function(seq), list)
    >>> assert all(isinstance(item, units.Unit) for item in function(seq))
    >>> dct = {"length": 3 * mmunit.angstrom, "time": 2 * mmunit.picosecond}
    >>> assert isinstance(function(dct), dict)
    >>> assert all(isinstance(item, units.Quantity) for item in function(dct).values())
    >>> Pair = namedtuple("Pair", ["first", "second"])
    >>> pair = function(Pair(1 * mmunit.angstrom, 2 * mmunit.angstrom))
    >>> assert isinstance(pair, Pair)
    >>> assert all(isinstance(item, units.Quantity) for item in pair)
    """
    signature = inspect.signature(func)

    def convert(data: t.Any) -> t.Any:  # pylint: disable=too-many-return-statements
        if isinstance(data, np.integer):
            return int(data)
        if isinstance(data, np.floating):
            return float(data)
        if isinstance(data, mmunit.Quantity):
            return Quantity(data)
        if isinstance(data, mmunit.Unit):
            return Unit(data)
        if isinstance(data, mmapp.Atom):
            return SerializableAtom(data)
        if isinstance(data, mmapp.Residue):
            return SerializableResidue(data)
        if isinstance(data, openmm.Force) and not isinstance(data, Serializable):
            return SerializableForce(data)
        if isinstance(data, tuple) and hasattr(data, "_fields"):
            # Named tuples take their items as separate positional arguments, so
            # they cannot be rebuilt from a single iterable like other sequences.
            return type(data)(*map(convert, data))
        if isinstance(data, t.Sequence) and not isinstance(data, str):
            return type(data)(map(convert, data))
        if isinstance(data, t.Dict):
            return type(data)((key, convert(value)) for key, value in data.items())
        return data

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        # Reassigning existing keys does not resize the mapping, so it is safe to
        # do while iterating over it.
        for name, data in bound.arguments.items():
            bound.arguments[name] = convert(data)
        return func(*bound.args, **bound.kwargs)

    return wrapper