Coverage for cvpack/reporting/state_data_reporter.py: 96%

24 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-09 16:14 +0000

"""
.. class:: StateDataReporter
   :platform: Linux, MacOS, Windows
   :synopsis: An extension of the OpenMM `StateDataReporter`_ class to include writers

.. moduleauthor:: Charlles Abreu <craabreu@gmail.com>

"""

9 

10import io 

11import typing as t 

12 

13import openmm 

14from openmm import app as mmapp 

15 

16from .custom_writer import CustomWriter 

17 

18 

class StateDataReporter(mmapp.StateDataReporter):
    """
    An extended version of OpenMM's `openmm.app.StateDataReporter`_ class that includes
    custom writers for reporting additional simulation data.

    .. _openmm.app.StateDataReporter: http://docs.openmm.org/latest/api-python/
        generated/openmm.app.statedatareporter.StateDataReporter.html

    A custom writer is an object that includes the following methods:

    1. **getHeaders**: returns a list of strings containing the headers to be added
       to the report. It must have the following signature:

    .. code-block::

        def getHeaders(self) -> List[str]:
            pass

    2. **getValues**: returns a list of floats containing the values to be added to
       the report at a given time step. It must have the following signature:

    .. code-block::

        def getValues(self, simulation: openmm.app.Simulation) -> List[float]:
            pass

    3. **initialize** (optional): performs any necessary setup before the first report.
       If present, it must have the following signature:

    .. code-block::

        def initialize(self, simulation: openmm.app.Simulation) -> None:
            pass

    The custom columns are inserted after the standard columns but before the
    trailing speed, elapsed-time, and remaining-time columns, whenever any of
    those are requested.

    Parameters
    ----------
    file
        The file to write to. This can be a file name or a file object.
    reportInterval
        The interval (in time steps) at which to report data.
    writers
        A sequence of custom writers. Any iterable is accepted; it is stored
        internally as a tuple.
    **kwargs
        Additional keyword arguments to be passed to the
        `openmm.app.StateDataReporter`_ constructor.

    Raises
    ------
    TypeError
        If any item in ``writers`` is not a ``CustomWriter``. The check runs
        before the base-class constructor, so no output file is opened when it
        fails.

    Examples
    --------
    >>> import cvpack
    >>> import openmm
    >>> from math import pi
    >>> from cvpack import reporting
    >>> from openmm import app, unit
    >>> from sys import stdout
    >>> from openmmtools import testsystems
    >>> model = testsystems.AlanineDipeptideVacuum()
    >>> phi = cvpack.Torsion(6, 8, 14, 16, name="phi")
    >>> psi = cvpack.Torsion(8, 14, 16, 18, name="psi")
    >>> umbrella = cvpack.MetaCollectiveVariable(
    ...     f"0.5*kappa*(min(dphi,{2*pi}-dphi)^2+min(dpsi,{2*pi}-dpsi)^2)"
    ...     "; dphi=abs(phi-phi0); dpsi=abs(psi-psi0)",
    ...     [phi, psi],
    ...     unit.kilojoules_per_mole,
    ...     name="umbrella",
    ...     kappa=100 * unit.kilojoules_per_mole/unit.radian**2,
    ...     phi0=5*pi/6 * unit.radian,
    ...     psi0=-5*pi/6 * unit.radian,
    ... )
    >>> reporter = reporting.StateDataReporter(
    ...     stdout,
    ...     100,
    ...     writers=[
    ...         reporting.CVWriter(umbrella, value=True, emass=True),
    ...         reporting.MetaCVWriter(
    ...             umbrella,
    ...             values=["phi", "psi"],
    ...             emasses=["phi", "psi"],
    ...             parameters=["phi0", "psi0"],
    ...             derivatives=["phi0", "psi0"],
    ...         ),
    ...     ],
    ...     step=True,
    ... )
    >>> integrator = openmm.LangevinIntegrator(
    ...     300 * unit.kelvin,
    ...     1 / unit.picosecond,
    ...     2 * unit.femtosecond,
    ... )
    >>> integrator.setRandomNumberSeed(1234)
    >>> umbrella.addToSystem(model.system)
    >>> simulation = app.Simulation(model.topology, model.system, integrator)
    >>> simulation.context.setPositions(model.positions)
    >>> simulation.context.setVelocitiesToTemperature(300 * unit.kelvin, 5678)
    >>> simulation.reporters.append(reporter)
    >>> simulation.step(1000)  # doctest: +SKIP
    #"Step","umbrella (kJ/mol)",...,"d[umbrella]/d[psi0] (kJ/(mol rad))"
    100,11.26...,40.371...
    200,7.463...,27.910...
    300,2.558...,-12.74...
    400,6.199...,3.9768...
    500,8.827...,41.878...
    600,3.761...,25.262...
    700,3.388...,25.342...
    800,1.071...,11.349...
    900,8.586...,37.380...
    1000,5.84...,31.159...
    """

    def __init__(
        self,
        file: t.Union[str, io.TextIOBase],
        reportInterval: int,
        writers: t.Sequence[CustomWriter] = (),
        **kwargs,
    ) -> None:
        # Materialize first so that a one-shot iterable (e.g. a generator) is not
        # exhausted by the validation below and then stored empty.
        writers = tuple(writers)
        # Validate before the base-class constructor runs: when ``file`` is a
        # path, the base class may open (and truncate) it, and a failure after
        # that point would leak the open handle.
        if not all(isinstance(w, CustomWriter) for w in writers):
            raise TypeError("All items in writers must satisfy the Writer protocol")
        super().__init__(file, reportInterval, **kwargs)
        self._writers = writers
        # Count of trailing base-class columns (speed, elapsed time, remaining
        # time) that must stay after the custom columns; each flag is a bool.
        self._back_steps = sum([self._speed, self._elapsedTime, self._remainingTime])

    def _expand(self, sequence: list, addition: t.Iterable) -> list:
        """
        Return a new list with the items of every sequence in ``addition``
        inserted before the last ``self._back_steps`` entries of ``sequence``.

        ``sequence`` itself is not modified. Each element of ``addition`` may be
        any iterable (list, tuple, ...); items are concatenated in order.
        """
        pos = len(sequence) - self._back_steps
        # Linear concatenation; ``sum`` over lists is quadratic and rejects tuples.
        expanded = sequence[:pos]
        for items in addition:
            expanded.extend(items)
        expanded.extend(sequence[pos:])
        return expanded

    def _initializeConstants(self, simulation: mmapp.Simulation) -> None:
        """
        Run the base-class initialization, then call ``initialize(simulation)``
        on every writer that defines it.
        """
        super()._initializeConstants(simulation)
        for writer in self._writers:
            if hasattr(writer, "initialize"):
                writer.initialize(simulation)

    def _constructHeaders(self) -> t.List[str]:
        """
        Return the base-class headers with each writer's ``getHeaders()``
        output inserted before the trailing timing columns.
        """
        return self._expand(
            super()._constructHeaders(),
            (w.getHeaders() for w in self._writers),
        )

    def _constructReportValues(
        self, simulation: mmapp.Simulation, state: openmm.State
    ) -> t.List[float]:
        """
        Return the base-class report values with each writer's
        ``getValues(simulation)`` output inserted before the trailing timing
        columns, in the same order as the headers.
        """
        return self._expand(
            super()._constructReportValues(simulation, state),
            (w.getValues(simulation) for w in self._writers),
        )