Coverage for cvpack/base_rmsd_content.py: 92%

53 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-09 16:14 +0000

1""" 

2.. class:: BaseRMSDContent 

3 :platform: Linux, MacOS, Windows 

4 :synopsis: Secondary-structure RMSD content of a sequence of residues 

5 

6.. classauthor:: Charlles Abreu <craabreu@gmail.com> 

7 

8""" 

9 

10import typing as t 

11from importlib import resources 

12 

13import numpy as np 

14import openmm 

15from openmm import app as mmapp 

16 

17from .collective_variable import CollectiveVariable 

18from .rmsd import RMSD 

19from .units import ScalarQuantity, value_in_md_units 

20 

21 

class BaseRMSDContent(CollectiveVariable, openmm.CustomCVForce):
    """
    Abstract class for secondary-structure RMSD content of a sequence of `n` residues.

    The collective variable is the sum, over all residue blocks, of a step
    function evaluated at ``x = rmsd/thresholdRMSD``, where ``rmsd`` is an
    :class:`RMSD` collective variable built from the atoms of the block and the
    ideal positions. When there are more than 32 blocks, the sum is split into
    chunks of up to 32 blocks, each one evaluated by a nested
    :class:`openmm.CustomCVForce` named ``chunk0``, ``chunk1``, and so on.

    Parameters
    ----------
    residue_blocks
        The blocks of residues. Each block is a sequence of indices that refer
        to positions in ``residues``.
    ideal_positions
        The reference positions paired, in order, with the atoms of each block.
    residues
        The residues from which the N, CA, CB, C, and O atoms are taken. For
        glycine, the HA2 atom is used in place of CB.
    numAtoms
        The number of atoms, forwarded to every :class:`RMSD` instance.
    thresholdRMSD
        The value by which each block RMSD is divided before the step function
        is applied.
    stepFunction
        The step function, written as an expression of the variable ``x``.
        Only the stand-alone identifier ``x`` is substituted, so function names
        that contain the letter (e.g. ``exp`` or ``max``) are left untouched.
    normalize
        Whether to divide the sum by the number of residue blocks.

    Raises
    ------
    ValueError
        If the number of residue blocks is not between 1 and 1024, or if a
        residue lacks one of the required atoms.
    """

    def __init__(
        self,
        residue_blocks: t.List[t.List[int]],
        ideal_positions: t.List[openmm.Vec3],
        residues: t.List[mmapp.topology.Residue],
        numAtoms: int,
        thresholdRMSD: ScalarQuantity,
        stepFunction: str = "(1+x^4)/(1+x^4+x^8)",
        normalize: bool = False,
    ):
        # Imported locally so that this class does not depend on a module-level
        # import that the file may not have.
        import re

        chunk_size = 32
        num_residue_blocks = self._num_residue_blocks = len(residue_blocks)
        if not 1 <= num_residue_blocks <= 1024:
            raise ValueError(
                f"{len(residues)} residues yield {num_residue_blocks} blocks, "
                "which is not between 1 and 1024"
            )
        residue_atoms = list(map(self._getAtomList, residues))
        # Flatten the atoms of the residues that make up each block
        block_atoms = [
            [atom for index in block for atom in residue_atoms[index]]
            for block in residue_blocks
        ]
        # Loop-invariant pieces of the expressions
        threshold = value_in_md_units(thresholdRMSD)
        # Match "x" only as a whole identifier. A plain str.replace would also
        # rewrite the "x" inside names such as "exp" or "max".
        variable = re.compile(r"\bx\b")

        def get_expression(start):
            """Build the expression for the chunk whose first block is `start`."""
            indices = range(start, min(start + chunk_size, num_residue_blocks))
            summation = "+".join(
                variable.sub(f"x{i}", stepFunction) for i in indices
            )
            if normalize:
                # Every chunk is divided by the total number of blocks, so the
                # sum of all chunks is the normalized total.
                summation = f"({summation})/{num_residue_blocks}"
            definitions = [f"x{i}=rmsd{i}/{threshold}" for i in indices]
            return ";".join([summation] + definitions)

        chunked = num_residue_blocks > chunk_size
        if chunked:
            num_chunks = (num_residue_blocks + chunk_size - 1) // chunk_size
            expression = "+".join(f"chunk{i}" for i in range(num_chunks))
        else:
            expression = get_expression(0)
            # With a single chunk, the RMSD variables are added to this force
            force = self
        super().__init__(expression)
        for index, atoms in enumerate(block_atoms):
            if chunked and index % chunk_size == 0:
                # Start a new nested force to hold the next chunk of blocks
                force = openmm.CustomCVForce(get_expression(index))
                self.addCollectiveVariable(f"chunk{index // chunk_size}", force)
            force.addCollectiveVariable(
                f"rmsd{index}",
                RMSD(dict(zip(atoms, ideal_positions)), atoms, numAtoms),
            )

    @classmethod
    def _loadPositions(cls, filename: str) -> t.List[openmm.Vec3]:
        """
        Load positions from a comma-delimited file in the ``data`` directory of
        the ``cvpack`` package.

        Parameters
        ----------
        filename
            The name of the file, relative to the ``data`` directory.

        Returns
        -------
            The loaded coordinates, multiplied by 0.1, as a list of
            :class:`openmm.Vec3` objects.
        """
        path = resources.files("cvpack").joinpath("data").joinpath(filename)
        # NOTE(review): the 0.1 factor presumably converts angstroms to
        # nanometers -- confirm against the data files.
        positions = 0.1 * np.loadtxt(str(path), delimiter=",")
        return [openmm.Vec3(*position) for position in positions]

    @staticmethod
    def _getAtomList(residue: mmapp.topology.Residue) -> t.List[int]:
        """
        Get the indices of the N, CA, CB, C, and O atoms of a residue, in this
        order. For a residue named GLY, the HA2 atom is used in place of CB.

        Parameters
        ----------
        residue
            The residue whose atoms are looked up by name.

        Returns
        -------
            The indices of the five atoms.

        Raises
        ------
        ValueError
            If any of the required atoms is not found in the residue.
        """
        residue_atoms = {atom.name: atom.index for atom in residue.atoms()}
        # Looking HA2 up through the same guarded path as the other atoms makes
        # a glycine without HA2 raise ValueError rather than a bare KeyError.
        third = "HA2" if residue.name == "GLY" else "CB"
        atom_list = []
        for atom in ("N", "CA", third, "C", "O"):
            try:
                atom_list.append(residue_atoms[atom])
            except KeyError as error:
                raise ValueError(
                    f"Atom {atom} not found in residue {residue.name}{residue.id}"
                ) from error
        return atom_list

    def getNumResidueBlocks(self) -> int:
        """
        Get the number of residue blocks.

        Returns
        -------
            The number of residue blocks.
        """
        return self._num_residue_blocks