Coverage for cvpack/base_path_cv.py: 97%

30 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-09 16:14 +0000

1""" 

2.. class:: BasePathCV 

3 :platform: Linux, MacOS, Windows 

4 :synopsis: A base class for path-related collective variables 

5 

6.. classauthor:: Charlles Abreu <craabreu@gmail.com> 

7 

8""" 

9 

10import typing as t 

11from collections import OrderedDict 

12 

13import openmm 

14 

15from .collective_variable import CollectiveVariable 

16from .path import Metric, deviation, progress 

17 

18 

class BasePathCV(CollectiveVariable, openmm.CustomCVForce):
    """
    A base class for path-related collective variables

    Parameters
    ----------
    metric
        A measure of progress or deviation with respect to a path in CV space
    sigma
        The width of the Gaussian kernels
    squared_distances
        Expressions for the squared distance to each milestone
    variables
        A dictionary of collective variables used in the expressions for the squared
        distances

    Raises
    ------
    ValueError
        If fewer than two squared-distance expressions (milestones) are given
    """

    def __init__(
        self,
        metric: Metric,
        sigma: float,
        squared_distances: t.Sequence[str],
        variables: t.Dict[str, CollectiveVariable],
    ) -> None:
        n = len(squared_distances)
        # The running-minimum chain below is seeded with min(x0,x1), so a path needs
        # at least two milestones. Fail early with a clear message instead of
        # building an expression that references undefined variables.
        if n < 2:
            raise ValueError("At least two milestones are required to define a path.")
        definitions = OrderedDict(
            {f"x{i}": sqdist for i, sqdist in enumerate(squared_distances)}
        )
        definitions["lambda"] = 1 / (2 * sigma**2)
        # Pairwise running minimum: xmin{n-2} ends up as min(x0, ..., x{n-1}).
        definitions["xmin0"] = "min(x0,x1)"
        for i in range(n - 2):
            definitions[f"xmin{i+1}"] = f"min(xmin{i},x{i+2})"
        # Shift every exponent by the minimum squared distance (log-sum-exp trick):
        # the largest weight is exp(0) = 1, so wsum >= 1 and never underflows to 0.
        for i in range(n):
            definitions[f"w{i}"] = f"exp(lambda*(xmin{n - 2}-x{i}))"
        definitions["wsum"] = "+".join(f"w{i}" for i in range(n))
        expressions = [f"{key}={value}" for key, value in definitions.items()]
        if metric == progress:
            # Weighted average of milestone indices 0..n-1, divided by n-1 so that
            # the result lies in [0, 1].
            numerator = "+".join(f"{i}*w{i}" for i in range(1, n))
            expressions.append(f"({numerator})/({n - 1}*wsum)")
        else:
            # -log(sum_i exp(-lambda*x_i))/lambda, rewritten in terms of the
            # shifted weights: xmin - log(wsum)/lambda.
            expressions.append(f"xmin{n - 2}-log(wsum)/lambda")
        # Reverse so the final expression comes first and its definitions follow,
        # matching OpenMM's "expression; var=...; var=..." syntax.
        super().__init__("; ".join(reversed(expressions)))
        for name, variable in variables.items():
            self.addCollectiveVariable(name, variable)

    def _generateName(
        self, metric: Metric, name: t.Optional[str], kind: str
    ) -> str:
        """
        Validate the metric and return a name for the collective variable.

        Parameters
        ----------
        metric
            Must be ``progress`` or ``deviation``
        name
            An explicit name, or ``None`` to derive a default one
        kind
            A label for the space the path lives in, embedded in the default name

        Returns
        -------
        str
            ``name`` if given, otherwise ``path_<metric.name>_in_<kind>_space``

        Raises
        ------
        ValueError
            If ``metric`` is neither ``progress`` nor ``deviation``
        """
        if metric not in (progress, deviation):
            raise ValueError(
                "Invalid metric. Use 'cvpack.path.progress' or 'cvpack.path.deviation'."
            )
        if name is None:
            return f"path_{metric.name}_in_{kind}_space"
        return name