Coverage for cvpack/reporting/state_data_reporter.py: 96%
24 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-09 16:14 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-09 16:14 +0000
1"""
2.. class:: StateDataReporter
3 :platform: Linux, MacOS, Windows
4 :synopsis: An extension of the OpenMM `StateDataReporter`_ class to include writers
6.. moduleauthor:: Charlles Abreu <craabreu@gmail.com>
8"""
10import io
11import typing as t
13import openmm
14from openmm import app as mmapp
16from .custom_writer import CustomWriter
class StateDataReporter(mmapp.StateDataReporter):
    """
    An extended version of OpenMM's `openmm.app.StateDataReporter`_ class that includes
    custom writers for reporting additional simulation data.

    .. _openmm.app.StateDataReporter: http://docs.openmm.org/latest/api-python/
        generated/openmm.app.statedatareporter.StateDataReporter.html

    A custom writer is an object that includes the following methods:

    1. **getHeaders**: returns a list of strings containing the headers to be added
       to the report. It must have the following signature:

    .. code-block::

        def getHeaders(self) -> List[str]:
            pass

    2. **getValues**: returns a list of floats containing the values to be added to
       the report at a given time step. It must have the following signature:

    .. code-block::

        def getValues(self, simulation: openmm.app.Simulation) -> List[float]:
            pass

    3. **initialize** (optional): performs any necessary setup before the first report.
       If present, it must have the following signature:

    .. code-block::

        def initialize(self, simulation: openmm.app.Simulation) -> None:
            pass

    The columns produced by the custom writers are inserted after the standard
    OpenMM columns, but before the speed, elapsed-time, and remaining-time columns
    (whichever of those are enabled).

    Parameters
    ----------
    file
        The file to write to. This can be a file name or a file object.
    reportInterval
        The interval (in time steps) at which to report data.
    writers
        A sequence of custom writers. Any iterable is accepted; it is copied into
        a tuple at construction time.
    **kwargs
        Additional keyword arguments to be passed to the
        `openmm.app.StateDataReporter`_ constructor.

    Raises
    ------
    TypeError
        If any item in ``writers`` is not an instance of ``CustomWriter``.

    Examples
    --------
    >>> import cvpack
    >>> import openmm
    >>> from math import pi
    >>> from cvpack import reporting
    >>> from openmm import app, unit
    >>> from sys import stdout
    >>> from openmmtools import testsystems
    >>> model = testsystems.AlanineDipeptideVacuum()
    >>> phi = cvpack.Torsion(6, 8, 14, 16, name="phi")
    >>> psi = cvpack.Torsion(8, 14, 16, 18, name="psi")
    >>> umbrella = cvpack.MetaCollectiveVariable(
    ...     f"0.5*kappa*(min(dphi,{2*pi}-dphi)^2+min(dpsi,{2*pi}-dpsi)^2)"
    ...     "; dphi=abs(phi-phi0); dpsi=abs(psi-psi0)",
    ...     [phi, psi],
    ...     unit.kilojoules_per_mole,
    ...     name="umbrella",
    ...     kappa=100 * unit.kilojoules_per_mole/unit.radian**2,
    ...     phi0=5*pi/6 * unit.radian,
    ...     psi0=-5*pi/6 * unit.radian,
    ... )
    >>> reporter = reporting.StateDataReporter(
    ...     stdout,
    ...     100,
    ...     writers=[
    ...         reporting.CVWriter(umbrella, value=True, emass=True),
    ...         reporting.MetaCVWriter(
    ...             umbrella,
    ...             values=["phi", "psi"],
    ...             emasses=["phi", "psi"],
    ...             parameters=["phi0", "psi0"],
    ...             derivatives=["phi0", "psi0"],
    ...         ),
    ...     ],
    ...     step=True,
    ... )
    >>> integrator = openmm.LangevinIntegrator(
    ...     300 * unit.kelvin,
    ...     1 / unit.picosecond,
    ...     2 * unit.femtosecond,
    ... )
    >>> integrator.setRandomNumberSeed(1234)
    >>> umbrella.addToSystem(model.system)
    >>> simulation = app.Simulation(model.topology, model.system, integrator)
    >>> simulation.context.setPositions(model.positions)
    >>> simulation.context.setVelocitiesToTemperature(300 * unit.kelvin, 5678)
    >>> simulation.reporters.append(reporter)
    >>> simulation.step(1000)  # doctest: +SKIP
    #"Step","umbrella (kJ/mol)",...,"d[umbrella]/d[psi0] (kJ/(mol rad))"
    100,11.26...,40.371...
    200,7.463...,27.910...
    300,2.558...,-12.74...
    400,6.199...,3.9768...
    500,8.827...,41.878...
    600,3.761...,25.262...
    700,3.388...,25.342...
    800,1.071...,11.349...
    900,8.586...,37.380...
    1000,5.84...,31.159...
    """

    def __init__(
        self,
        file: t.Union[str, io.TextIOBase],
        reportInterval: int,
        writers: t.Sequence[CustomWriter] = (),
        **kwargs,
    ) -> None:
        # Materialize first: validating a one-shot iterable (e.g. a generator)
        # would otherwise exhaust it, leaving no writers to report.
        writers = tuple(writers)
        # Validate before the base constructor runs. When ``file`` is a path the
        # base class may open it, and raising afterwards would leak that handle.
        if not all(isinstance(w, CustomWriter) for w in writers):
            raise TypeError("All items in writers must satisfy the Writer protocol")
        super().__init__(file, reportInterval, **kwargs)
        self._writers = writers
        # Number of trailing base-class columns (speed, elapsed time, remaining
        # time) that custom columns must be placed in front of. Assumes OpenMM
        # emits these flags' columns last, in that order.
        self._back_steps = sum([self._speed, self._elapsedTime, self._remainingTime])

    def _expand(self, sequence: list, addition: t.Iterable) -> list:
        """
        Return a new list with every group in ``addition`` spliced into ``sequence``
        just before its last ``self._back_steps`` entries.

        Parameters
        ----------
        sequence
            The base-class headers or values.
        addition
            An iterable of groups (lists, tuples, or other iterables), one per writer.

        Returns
        -------
        list
            ``sequence[:pos]`` followed by all items of all groups, then
            ``sequence[pos:]``, where ``pos = len(sequence) - self._back_steps``.
        """
        pos = len(sequence) - self._back_steps
        expanded = list(sequence[:pos])
        # ``extend`` is linear overall and accepts any iterable group, unlike
        # ``sum(groups, list)``, which is quadratic and rejects tuples.
        for group in addition:
            expanded.extend(group)
        expanded.extend(sequence[pos:])
        return expanded

    def _initializeConstants(self, simulation: mmapp.Simulation) -> None:
        """
        Run the base-class initialization, then call ``initialize(simulation)``
        on each writer that defines it.
        """
        super()._initializeConstants(simulation)
        for writer in self._writers:
            if hasattr(writer, "initialize"):
                writer.initialize(simulation)

    def _constructHeaders(self) -> t.List[str]:
        """Return the base-class headers with each writer's headers inserted."""
        return self._expand(
            super()._constructHeaders(),
            (w.getHeaders() for w in self._writers),
        )

    def _constructReportValues(
        self, simulation: mmapp.Simulation, state: openmm.State
    ) -> t.List[float]:
        """Return the base-class report values with each writer's values inserted."""
        return self._expand(
            super()._constructReportValues(simulation, state),
            (w.getValues(simulation) for w in self._writers),
        )