Coverage for cvpack/base_custom_function.py: 100%

24 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-09 16:14 +0000

1""" 

2.. class:: BaseCustomFunction 

3 :platform: Linux, MacOS, Windows 

4 :synopsis: Abstract class for collective variables defined by a custom function 

5 

6.. classauthor:: Charlles Abreu <craabreu@gmail.com> 

7 

8""" 

9 

10import typing as t 

11 

12from openmm import unit as mmunit 

13 

14from .collective_variable import CollectiveVariable 

15from .units import ScalarQuantity, VectorQuantity, value_in_md_units 

16 

17 

class BaseCustomFunction(CollectiveVariable):
    """
    Abstract class for collective variables defined by a custom function.

    The helpers below call ``addGlobalParameter``, ``addPerBondParameter``,
    ``addBond`` and ``setUsesPeriodicBoundaryConditions`` on ``self``.
    NOTE(review): these are presumably supplied by an OpenMM custom-force
    class that concrete subclasses also inherit from (hence the pylint
    ``no-member`` suppression) — confirm against the subclasses.
    """

    def _extractParameters(
        self,
        size: int,
        **parameters: t.Union[ScalarQuantity, VectorQuantity],
    ) -> t.Tuple[t.Dict[str, ScalarQuantity], t.Dict[str, t.List[ScalarQuantity]]]:
        """
        Split keyword parameters into global and per-bond parameters.

        Every value is first converted with ``value_in_md_units``. A converted
        value that is iterable becomes a per-bond parameter; anything else
        becomes a global parameter.

        Parameters
        ----------
        size
            Number of per-bond entries to take from each iterable parameter.
            Only the first ``size`` entries are kept; an iterable with fewer
            entries fails when indexed past its end.
        **parameters
            Named parameter values, either scalars or vectors.

        Returns
        -------
        tuple
            ``(global_parameters, perbond_parameters)``. The first maps names
            to single values; the second maps names to lists of ``size``
            values.
        """
        global_parameters = {}
        perbond_parameters = {}
        for name, data in parameters.items():
            data = value_in_md_units(data)
            if isinstance(data, t.Iterable):
                # Index explicitly so exactly ``size`` entries are taken.
                perbond_parameters[name] = [data[i] for i in range(size)]
            else:
                global_parameters[name] = data
        return global_parameters, perbond_parameters

    def _addParameters(
        self,
        overalls: t.Dict[str, ScalarQuantity],
        perbonds: t.Dict[str, t.List[ScalarQuantity]],
        groups: t.List[t.Tuple[int, ...]],
        pbc: bool,
        unit: mmunit.Unit,
    ) -> None:
        """
        Register parameters and bonds on this force.

        Parameters
        ----------
        overalls
            Global parameters, added with ``addGlobalParameter``.
        perbonds
            Per-bond parameters. Each name is registered with
            ``addPerBondParameter``; the i-th entry of every list goes to the
            i-th bond.
        groups
            One group per bond, passed to ``addBond`` together with that
            bond's per-bond values. ``zip`` stops at the shortest of
            ``groups`` and the per-bond lists.
        pbc
            Passed to ``setUsesPeriodicBoundaryConditions``.
        unit
            The unit of the collective variable. It must convert to the MD
            unit system with a factor of exactly 1.

        Raises
        ------
        ValueError
            If ``unit`` is not compatible with the MD unit system.
        """
        # Validate before mutating the force so that an incompatible unit
        # does not leave it half-configured.
        if (1 * unit).value_in_unit_system(mmunit.md_unit_system) != 1:
            raise ValueError(f"Unit {unit} is not compatible with the MD unit system.")
        # pylint: disable=no-member
        for name, value in overalls.items():
            self.addGlobalParameter(name, value)
        for name in perbonds:
            self.addPerBondParameter(name)
        for group, *values in zip(groups, *perbonds.values()):
            self.addBond(group, values)
        self.setUsesPeriodicBoundaryConditions(pbc)