Coverage for cvpack/serialization/serialization.py: 93%
61 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-09 16:14 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-09 16:14 +0000
1"""
.. module:: serialization
   :platform: Linux, MacOS, Windows
   :synopsis: Collective Variable Serialization

.. moduleauthor:: Charlles Abreu <craabreu@gmail.com>
8"""
10import typing as t
12import openmm
13import yaml
14from openmm import app as mmapp
class Serializable(yaml.YAMLObject):
    """
    Mixin that lets a class be dumped to, and loaded from, YAML through PyYAML.

    A subclass opts in by calling :meth:`registerTag` once after it is defined.
    """

    @classmethod
    def registerTag(cls, tag: str) -> None:
        """
        Associate a YAML tag with this class and hook it into PyYAML.

        The class is registered with PyYAML's ``SafeLoader`` and ``SafeDumper``.

        Parameters
        ----------
        tag
            The YAML tag to be used for this class.
        """
        cls.yaml_tag = tag
        # Loading direction: tag -> instance of this class.
        yaml.SafeLoader.add_constructor(tag, cls.from_yaml)
        # Dumping direction: instance of this class -> tagged YAML node.
        yaml.SafeDumper.add_representer(cls, cls.to_yaml)
class SerializableAtom(Serializable):
    r"""
    A serializable version of OpenMM's Atom class.

    Parameters
    ----------
    atom
        The OpenMM atom, or another :class:`SerializableAtom`, whose name, index,
        id, element symbol, and residue index are copied into this object.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self, atom: t.Union[mmapp.topology.Atom, "SerializableAtom"]
    ) -> None:
        self.name = atom.name
        self.index = atom.index
        if isinstance(atom, mmapp.topology.Atom):
            # OpenMM permits atoms without a chemical element (e.g. extra
            # particles), in which case ``atom.element`` is None and has no
            # ``symbol`` attribute to read.
            element = atom.element
            self._element_symbol = None if element is None else element.symbol
            self._residue_index = atom.residue.index
        else:
            # Copying from another SerializableAtom: reuse its stored fields.
            self._element_symbol = atom._element_symbol
            self._residue_index = atom._residue_index
        self.id = atom.id

    def __getstate__(self) -> t.Dict[str, t.Any]:
        """Return the instance attributes as the state to be serialized."""
        return self.__dict__

    def __setstate__(self, keywords: t.Dict[str, t.Any]) -> None:
        """Restore the instance attributes from a previously saved state."""
        self.__dict__.update(keywords)

    @property
    def element(self) -> t.Optional[mmapp.Element]:
        """Return the Element of the Atom, or None if no element was stored."""
        if self._element_symbol is None:
            return None
        return mmapp.Element.getBySymbol(self._element_symbol)


SerializableAtom.registerTag("!cvpack.Atom")
class SerializableResidue(Serializable):
    r"""
    A serializable version of OpenMM's Residue class.

    Parameters
    ----------
    residue
        The OpenMM residue, or another :class:`SerializableResidue`, whose name,
        index, id, chain index, and atoms are copied into this object.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self, residue: t.Union[mmapp.topology.Residue, "SerializableResidue"]
    ) -> None:
        self.name = residue.name
        self.index = residue.index
        # An OpenMM residue exposes its chain object; a SerializableResidue only
        # keeps the chain's index.
        self._chain_index = (
            residue.chain.index
            if isinstance(residue, mmapp.topology.Residue)
            else residue._chain_index
        )
        self.id = residue.id
        self._atoms = [SerializableAtom(atom) for atom in residue.atoms()]

    def __getstate__(self) -> t.Dict[str, t.Any]:
        """Return the instance attributes as the state to be serialized."""
        return self.__dict__

    def __setstate__(self, keywords: t.Dict[str, t.Any]) -> None:
        """Restore the instance attributes from a previously saved state."""
        self.__dict__.update(keywords)

    def __len__(self) -> int:
        """Return the number of atoms stored in the Residue."""
        return len(self._atoms)

    def atoms(self):
        """Iterate over all Atoms in the Residue."""
        return iter(self._atoms)


SerializableResidue.registerTag("!cvpack.Residue")
class SerializableForce(openmm.Force, Serializable):
    """
    A serializable version of OpenMM's Force class.

    The object wraps an existing force: attributes not found on the wrapper are
    looked up on the wrapped force, and the serialized state is the XML produced
    by :class:`openmm.XmlSerializer`.

    Parameters
    ----------
    force
        The OpenMM force to be wrapped.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        force: openmm.Force,
    ) -> None:
        self.force = force
        # NOTE(review): ``this`` is presumably the SWIG handle of the underlying
        # OpenMM object; sharing it lets OpenMM treat the wrapper as that force.
        self.this = force.this

    def __getattr__(self, name: str) -> t.Any:
        """Look up on the wrapped force any attribute missing from the wrapper."""
        # Python calls __getattr__ only after normal lookup has failed. If the
        # missing attribute is "force" itself (instance created via __new__ and
        # not initialized yet), evaluating ``self.force`` below would re-enter
        # this method until a RecursionError is raised.
        if name == "force":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.force, name)

    def __getstate__(self) -> t.Dict[str, str]:
        """Return the state to be serialized: the force's XML representation."""
        return {"xml_code": openmm.XmlSerializer.serialize(self)}

    def __setstate__(self, keywords: t.Dict[str, str]) -> None:
        """Rebuild the wrapper around the force deserialized from the saved XML."""
        self.__init__(openmm.XmlSerializer.deserialize(keywords["xml_code"]))


SerializableForce.registerTag("!cvpack.Force")
def serialize(obj: t.Any, iostream: t.IO) -> None:
    """
    Write the YAML representation of a cvpack object to a text stream.

    Parameters
    ----------
    obj
        The cvpack object to be serialized
    iostream
        A text stream in write mode

    Example
    -------
    >>> import cvpack
    >>> import io
    >>> from cvpack import serialization
    >>> radius_of_gyration = cvpack.RadiusOfGyration([0, 1, 2])
    >>> iostream = io.StringIO()
    >>> serialization.serialize(radius_of_gyration, iostream)
    >>> print(iostream.getvalue())
    !cvpack.RadiusOfGyration
    group:
    - 0
    - 1
    - 2
    name: radius_of_gyration
    pbc: false
    weighByMass: false
    <BLANKLINE>
    """
    # Build the complete YAML document first, then write it in a single call.
    yaml_text = yaml.safe_dump(obj)
    iostream.write(yaml_text)
def deserialize(iostream: t.IO) -> t.Any:
    """
    Rebuild a cvpack object from the YAML text held in a stream.

    Parameters
    ----------
    iostream
        A text stream in read mode containing the object to be deserialized

    Returns
    -------
    t.Any
        An instance of any cvpack class

    Example
    -------
    >>> import cvpack
    >>> import io
    >>> from cvpack import serialization
    >>> radius_of_gyration = cvpack.RadiusOfGyration([0, 1, 2])
    >>> iostream = io.StringIO()
    >>> serialization.serialize(radius_of_gyration, iostream)
    >>> iostream.seek(0)
    0
    >>> new_object = serialization.deserialize(iostream)
    >>> type(new_object)
    <class 'cvpack.radius_of_gyration.RadiusOfGyration'>
    """
    # Read everything from the stream's current position, then parse it with the
    # safe loader.
    yaml_text = iostream.read()
    return yaml.safe_load(yaml_text)