Coverage for cvpack/base_path_cv.py: 97%
30 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-09 16:14 +0000
1"""
2.. class:: BasePathCV
3 :platform: Linux, MacOS, Windows
4 :synopsis: A base class for path-related collective variables
6.. classauthor:: Charlles Abreu <craabreu@gmail.com>
8"""
10import typing as t
11from collections import OrderedDict
13import openmm
15from .collective_variable import CollectiveVariable
16from .path import Metric, deviation, progress
class BasePathCV(CollectiveVariable, openmm.CustomCVForce):
    """
    A base class for path-related collective variables

    The path is described by a sequence of milestones. Each milestone ``i``
    contributes a Gaussian-like weight ``exp(-lambda * x_i)``, where ``x_i`` is
    the squared distance to that milestone and ``lambda = 1 / (2 * sigma**2)``.
    The weights are evaluated relative to the smallest squared distance, so the
    exponents are never positive.

    Parameters
    ----------
    metric
        A measure of progress or deviation with respect to a path in CV space.
        Must be ``cvpack.path.progress`` or ``cvpack.path.deviation``.
    sigma
        The width of the Gaussian kernels (presumably in the same units as the
        square root of the squared distances — confirm with subclasses)
    squared_distances
        Expressions for the squared distance to each milestone. At least two
        milestones are required.
    variables
        A dictionary of collective variables used in the expressions for the squared
        distances

    Raises
    ------
    ValueError
        If the metric is neither progress nor deviation, or if fewer than two
        squared-distance expressions are given
    """

    def __init__(
        self,
        metric: Metric,
        sigma: float,
        squared_distances: t.Sequence[str],
        variables: t.Dict[str, CollectiveVariable],
    ) -> None:
        self._validateMetric(metric)
        n = len(squared_distances)
        if n < 2:
            # With a single milestone, "min(x0,x1)" would reference an undefined
            # x1, and the progress denominator (n - 1) would be zero.
            raise ValueError("At least two milestones are required to define a path.")
        definitions = OrderedDict(
            (f"x{i}", sqdist) for i, sqdist in enumerate(squared_distances)
        )
        definitions["lambda"] = 1 / (2 * sigma**2)
        # Running minimum: xmin{k} = min(x0, ..., x{k+1}), so xmin{n-2} is the
        # smallest squared distance over all milestones.
        definitions["xmin0"] = "min(x0,x1)"
        for i in range(n - 2):
            definitions[f"xmin{i+1}"] = f"min(xmin{i},x{i+2})"
        xmin = f"xmin{n - 2}"
        # Shift every exponent by the minimum (log-sum-exp trick): each weight
        # lies in (0, 1] and the nearest milestone gets weight exactly 1, which
        # keeps exp() from underflowing when all distances are large.
        for i in range(n):
            definitions[f"w{i}"] = f"exp(lambda*({xmin}-x{i}))"
        definitions["wsum"] = "+".join(f"w{i}" for i in range(n))
        expressions = [f"{key}={value}" for key, value in definitions.items()]
        if metric == progress:
            # Weighted average of i/(n-1), hence a value in [0, 1].
            numerator = "+".join(f"{i}*w{i}" for i in range(1, n))
            expressions.append(f"({numerator})/({n - 1}*wsum)")
        else:
            # Soft minimum: -log(sum_i exp(-lambda*x_i))/lambda, rewritten with
            # the shifted weights.
            expressions.append(f"{xmin}-log(wsum)/lambda")
        # OpenMM expects the main expression first, followed by the intermediate
        # definitions it relies on. Reversing places each definition before the
        # ones it depends on.
        super().__init__("; ".join(reversed(expressions)))
        for name, variable in variables.items():
            self.addCollectiveVariable(name, variable)

    @staticmethod
    def _validateMetric(metric: Metric) -> None:
        """Raise ``ValueError`` unless ``metric`` is progress or deviation."""
        if metric not in (progress, deviation):
            raise ValueError(
                "Invalid metric. Use 'cvpack.path.progress' or 'cvpack.path.deviation'."
            )

    def _generateName(
        self, metric: Metric, name: t.Optional[str], kind: str
    ) -> str:
        """
        Return ``name`` or, when it is None, a default name built from the
        metric and the kind of space the path lives in.

        Raises
        ------
        ValueError
            If the metric is neither progress nor deviation
        """
        self._validateMetric(metric)
        if name is None:
            return f"path_{metric.name}_in_{kind}_space"
        return name