Coverage for cvpack/serialization/serialization.py: 93%

61 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-09 16:14 +0000

1""" 

2.. module:: serialization 

3 :platform: Linux, MacOS, Windows 

4 :synopsis: Collective Variable Serialization 

5 

6.. moduleauthor:: Charlles Abreu <craabreu@gmail.com> 

7 

8""" 

9 

10import typing as t 

11 

12import openmm 

13import yaml 

14from openmm import app as mmapp 

15 

16 

class Serializable(yaml.YAMLObject):
    """
    Mixin that lets a class be dumped and loaded with PyYAML's safe API.

    Each concrete subclass calls :meth:`registerTag` once, right after its
    definition, to bind a YAML tag to the class.
    """

    @classmethod
    def registerTag(cls, tag: str) -> None:
        """
        Bind a YAML tag to this class for PyYAML's safe dumper and loader.

        Parameters
        ----------
        tag
            The YAML tag under which instances of this class are written and
            from which they are reconstructed.
        """
        cls.yaml_tag = tag
        # The tag is assigned after class creation, so registration with the
        # safe dumper/loader is done explicitly here.
        dumper, loader = yaml.SafeDumper, yaml.SafeLoader
        dumper.add_representer(cls, cls.to_yaml)
        loader.add_constructor(tag, cls.from_yaml)

35 

36 

class SerializableAtom(Serializable):
    r"""
    A serializable version of OpenMM's Atom class.

    Only plain data is kept: the atom's name, index and id, the symbol of its
    element, and the index of the residue that contains it.

    Parameters
    ----------
    atom
        An OpenMM atom or another SerializableAtom to copy.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self, atom: t.Union[mmapp.topology.Atom, "SerializableAtom"]
    ) -> None:
        self.name = atom.name
        self.index = atom.index
        if isinstance(atom, mmapp.topology.Atom):
            # OpenMM topology atoms may have no element (``element is None``,
            # e.g. extra particles); store None rather than failing on
            # ``None.symbol``.
            element = atom.element
            self._element_symbol = None if element is None else element.symbol
            self._residue_index = atom.residue.index
        else:
            self._element_symbol = atom._element_symbol
            self._residue_index = atom._residue_index
        self.id = atom.id

    def __getstate__(self) -> t.Dict[str, t.Any]:
        # The instance attributes themselves are the serialized state.
        return self.__dict__

    def __setstate__(self, keywords: t.Dict[str, t.Any]) -> None:
        # Restore the attributes produced by __getstate__.
        self.__dict__.update(keywords)

    @property
    def element(self) -> t.Optional[mmapp.Element]:
        """Return the Element of the Atom, or None if the atom has no element."""
        if self._element_symbol is None:
            return None
        return mmapp.Element.getBySymbol(self._element_symbol)


SerializableAtom.registerTag("!cvpack.Atom")

68 

69 

class SerializableResidue(Serializable):
    r"""
    A serializable version of OpenMM's Residue class.

    Parameters
    ----------
    residue
        An OpenMM residue or another SerializableResidue to copy. Every atom
        it contains is converted into a :class:`SerializableAtom`.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self, residue: t.Union[mmapp.topology.Residue, "SerializableResidue"]
    ) -> None:
        self.name = residue.name
        self.index = residue.index
        # An OpenMM residue references its chain object; the serializable
        # copy only keeps the chain index.
        self._chain_index = (
            residue.chain.index
            if isinstance(residue, mmapp.topology.Residue)
            else residue._chain_index
        )
        self.id = residue.id
        self._atoms = [SerializableAtom(atom) for atom in residue.atoms()]

    def __getstate__(self) -> t.Dict[str, t.Any]:
        # The instance attributes themselves are the serialized state.
        return vars(self)

    def __setstate__(self, keywords: t.Dict[str, t.Any]) -> None:
        # Restore the attributes produced by __getstate__.
        vars(self).update(keywords)

    def __len__(self) -> int:
        """Return the number of atoms in the residue."""
        return len(self._atoms)

    def atoms(self) -> t.Iterator["SerializableAtom"]:
        """Iterate over all Atoms in the Residue."""
        return iter(self._atoms)


SerializableResidue.registerTag("!cvpack.Residue")

100 

101 

class SerializableForce(openmm.Force, Serializable):
    """
    A serializable wrapper around an OpenMM Force.

    Attribute lookups that fail on the wrapper are forwarded to the wrapped
    force. The serialized state is the force's OpenMM XML representation,
    stored under the key ``xml_code``.

    Parameters
    ----------
    force
        The OpenMM Force to be wrapped.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        force: openmm.Force,
    ) -> None:
        self.force = force
        # Reuse the wrapped force's SWIG handle, presumably so that OpenMM
        # calls receiving this wrapper (e.g. XmlSerializer.serialize) operate
        # on the wrapped C++ object.
        self.this = force.this

    def __getattr__(self, name: str) -> t.Any:
        # Only reached when normal lookup fails. If ``force`` has not been set
        # yet (an instance created via __new__ before __init__/__setstate__
        # ran), reading ``self.force`` would re-enter this method and recurse
        # until RecursionError; raise a plain AttributeError instead.
        if name == "force":
            raise AttributeError(name)
        return getattr(self.force, name)

    def __getstate__(self) -> t.Dict[str, str]:
        # Serialize the force through OpenMM's XML serializer.
        return {"xml_code": openmm.XmlSerializer.serialize(self)}

    def __setstate__(self, keywords: t.Dict[str, str]) -> None:
        # Rebuild the wrapped force from its XML and re-initialize the wrapper.
        self.__init__(openmm.XmlSerializer.deserialize(keywords["xml_code"]))


SerializableForce.registerTag("!cvpack.Force")

123 

124 

def serialize(obj: t.Any, iostream: t.IO) -> None:
    """
    Serialize a cvpack object as YAML into a text stream.

    The complete YAML document is produced with PyYAML's safe dumper before
    anything is written, so ``iostream`` is left untouched if dumping fails.

    Parameters
    ----------
    obj
        The cvpack object to be serialized
    iostream
        A text stream in write mode

    Example
    -------
    >>> import cvpack
    >>> import io
    >>> from cvpack import serialization
    >>> radius_of_gyration = cvpack.RadiusOfGyration([0, 1, 2])
    >>> iostream = io.StringIO()
    >>> serialization.serialize(radius_of_gyration, iostream)
    >>> print(iostream.getvalue())
    !cvpack.RadiusOfGyration
    group:
    - 0
    - 1
    - 2
    name: radius_of_gyration
    pbc: false
    weighByMass: false
    <BLANKLINE>
    """
    document = yaml.safe_dump(obj)
    iostream.write(document)

156 

157 

def deserialize(iostream: t.IO) -> t.Any:
    """
    Deserialize a cvpack object from a YAML text stream.

    The remaining content of ``iostream`` is read in full and parsed with
    PyYAML's safe loader.

    Parameters
    ----------
    iostream
        A text stream in read mode containing the object to be deserialized

    Returns
    -------
    t.Any
        An instance of any cvpack class

    Example
    -------
    >>> import cvpack
    >>> import io
    >>> from cvpack import serialization
    >>> radius_of_gyration = cvpack.RadiusOfGyration([0, 1, 2])
    >>> iostream = io.StringIO()
    >>> serialization.serialize(radius_of_gyration, iostream)
    >>> iostream.seek(0)
    0
    >>> new_object = serialization.deserialize(iostream)
    >>> type(new_object)
    <class 'cvpack.radius_of_gyration.RadiusOfGyration'>
    """
    content = iostream.read()
    return yaml.safe_load(content)