Coverage for cvpack/units/units.py: 95%
56 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-09 16:14 +0000
1"""
2.. module:: units
3 :platform: Linux, MacOS, Windows
4 :synopsis: Units of measurement for CVPack.
6.. classauthor:: Charlles Abreu <craabreu@gmail.com>
8"""
10import ast
11import typing as t
12from numbers import Real
14import numpy as np
15import openmm
16from openmm import unit as mmunit
18from ..serialization import Serializable
# Scalar input: an OpenMM quantity or a bare real number.
ScalarQuantity = t.Union[
    mmunit.Quantity,
    Real,
]

# Vector input: an OpenMM quantity, a NumPy array, or an OpenMM Vec3.
VectorQuantity = t.Union[
    mmunit.Quantity,
    np.ndarray,
    openmm.Vec3,
]

# Matrix input: an OpenMM quantity, a NumPy array, or a sequence of
# Vec3 objects / NumPy arrays (presumably one per row -- confirm with callers).
MatrixQuantity = t.Union[
    mmunit.Quantity,
    np.ndarray,
    t.Sequence[openmm.Vec3],
    t.Sequence[np.ndarray],
]
class Unit(mmunit.Unit, Serializable, ast.NodeTransformer):
    """
    Extension of the OpenMM Unit class to allow serialization and deserialization.

    A unit can be built from a string expression written with ``openmm.unit``
    names (e.g. ``"nanometer**2"`` or ``"kilojoule/mole"``), from another OpenMM
    unit, or from a dictionary of base/scaled units to exponents (the form
    accepted by ``openmm.unit.Unit``).

    Parameters
    ----------
    data
        The data to be used to create the unit.

    Raises
    ------
    TypeError
        If a string expression does not evaluate to a unit.
    """

    def __init__(self, data: t.Union[str, mmunit.Unit, dict]) -> None:
        if isinstance(data, str):
            # ``visit_Name`` rewrites every bare name ``x`` into ``mmunit.x``, so
            # the compiled expression needs nothing but ``mmunit`` in scope.
            # Builtins are withheld because these strings may come from
            # serialized files (see ``__setstate__``).
            # NOTE(review): this is still ``eval``; attribute chains are not
            # filtered, so do not feed it strings from untrusted sources.
            expression = self.visit(ast.parse(data, mode="eval"))
            code = compile(ast.fix_missing_locations(expression), "", mode="eval")
            namespace = {"__builtins__": {}, "mmunit": mmunit}
            parsed = eval(code, namespace)  # pylint: disable=eval-used
            if not isinstance(parsed, mmunit.Unit):
                raise TypeError(
                    f"Expression {data!r} does not evaluate to a unit "
                    f"(got {type(parsed).__name__})"
                )
            data = parsed
        if isinstance(data, mmunit.Unit):
            # Decompose into the {base-or-scaled unit: exponent} mapping that
            # the parent constructor takes.
            data = dict(data.iter_base_or_scaled_units())
        super().__init__(data)

    def __repr__(self) -> str:
        return self.get_symbol()

    def __getstate__(self) -> t.Dict[str, str]:
        # The string form is re-parsed by ``__init__`` on deserialization.
        return {"data": str(self)}

    def __setstate__(self, keywords: t.Dict[str, str]) -> None:
        self.__init__(keywords["data"])

    def visit_Name(  # pylint: disable=invalid-name
        self, node: ast.Name
    ) -> ast.Attribute:
        """
        Visit a Name node and transform it into an Attribute node.

        A bare name ``x`` becomes ``mmunit.x``, so unit expressions are resolved
        against the ``openmm.unit`` module.

        Parameters
        ----------
        node
            The node to be visited and transformed.

        Returns
        -------
        ast.Attribute
            The ``mmunit.<name>`` attribute node.
        """
        attribute = ast.Attribute(
            value=ast.Name(id="mmunit", ctx=ast.Load()), attr=node.id, ctx=ast.Load()
        )
        # Keep the original source position so compile errors point at the name.
        return ast.copy_location(attribute, node)


Unit.registerTag("!cvpack.Unit")
class Quantity(mmunit.Quantity, Serializable):
    """
    Extension of the OpenMM Quantity class to allow serialization and deserialization.

    For serialization, the state holds the raw value and the string form of the
    unit; on restore the string is turned back into a :class:`Unit`.
    """

    def __init__(self, *args: t.Any) -> None:
        wraps_quantity = len(args) == 1 and mmunit.is_quantity(args[0])
        if not wraps_quantity:
            # Any other call signature goes straight to openmm.unit.Quantity.
            super().__init__(*args)
            return
        # Re-wrap an existing quantity so that its unit becomes a cvpack Unit.
        (other,) = args
        super().__init__(other.value_in_unit(other.unit), Unit(other.unit))

    def __repr__(self):
        return str(self)

    def __getstate__(self) -> t.Dict[str, t.Any]:
        state = {"value": self.value}
        state["unit"] = str(self.unit)
        return state

    def __setstate__(self, keywords: t.Dict[str, t.Any]) -> None:
        raw_value = keywords["value"]
        restored_unit = Unit(keywords["unit"])
        self.__init__(raw_value, restored_unit)

    @property
    def value(self) -> t.Any:
        """The raw value of the quantity, without its unit."""
        return self._value

    def in_md_units(self) -> t.Any:  # pylint: disable=invalid-name
        """The bare value of the quantity expressed in the MD unit system."""
        return self.value_in_unit_system(mmunit.md_unit_system)


Quantity.registerTag("!cvpack.Quantity")
def in_md_units(
    quantity: t.Union[ScalarQuantity, VectorQuantity, MatrixQuantity]
) -> Quantity:
    """
    Return a quantity in the MD unit system (e.g. mass in Da, distance in
    nm, time in ps, temperature in K, energy in kJ/mol, angle in rad).

    Parameters
    ----------
    quantity
        The quantity to be converted. A value without units (e.g. a number,
        a NumPy array, or a Vec3) is treated as dimensionless.

    Returns
    -------
    Quantity
        The quantity in the MD unit system, always as a cvpack ``Quantity``.
    """
    if not mmunit.is_quantity(quantity):
        # Build the Quantity explicitly instead of via ``quantity * unit``: that
        # product is produced by OpenMM's operators and yields a plain
        # ``openmm.unit.Quantity`` rather than this serializable subclass.
        return Quantity(quantity, Unit("dimensionless"))
    unit_in_md_system = quantity.unit.in_unit_system(mmunit.md_unit_system)
    if 1 * quantity.unit / unit_in_md_system == 1:
        # Conversion factor of one: the unit is already an MD unit, so the
        # quantity is only re-wrapped and its value is left untouched.
        return Quantity(quantity)
    return Quantity(quantity.in_units_of(unit_in_md_system))
def value_in_md_units(
    quantity: t.Union[ScalarQuantity, VectorQuantity, MatrixQuantity]
) -> t.Any:
    """
    Return the value of a quantity in the MD unit system (e.g. mass in Da, distance in
    nm, time in ps, temperature in K, energy in kJ/mol, angle in rad).

    Parameters
    ----------
    quantity
        The quantity whose value is wanted. Inputs that are not OpenMM
        quantities are returned unchanged.

    Returns
    -------
    Any
        The bare value of the quantity expressed in the MD unit system.
    """
    if not mmunit.is_quantity(quantity):
        # No unit attached: nothing to convert.
        return quantity
    return quantity.value_in_unit_system(mmunit.md_unit_system)