OpenStructure
ligand_scoring_scrmsd.py
Go to the documentation of this file.
1 import numpy as np
2 
3 from ost import LogWarning
4 from ost import geom
5 from ost import mol
6 
7 from ost.mol.alg import ligand_scoring_base
8 
10  """ :class:`LigandScorer` implementing symmetry corrected RMSD.
11 
12  :class:`SCRMSDScorer` computes a score for a specific pair of target/model
13  ligands.
14 
15  The returned RMSD is based on a binding site superposition.
16  The binding site of the target structure is defined as all residues with at
17  least one atom within *bs_radius* around the target ligand.
18  It only contains protein and nucleic acid residues from chains that
19  pass the criteria for the
20  :class:`chain mapping <ost.mol.alg.chain_mapping>`. This means ignoring
21  other ligands, waters, short polymers as well as any incorrectly connected
22  chains that may be in proximity.
23  The respective model binding site for superposition is identified by
24  naively enumerating all possible mappings of model chains onto their
25  chemically equivalent target counterparts from the target binding site.
26  The *binding_sites_topn* with respect to lDDT score are evaluated and
27  an RMSD is computed.
28  You can either try to map ALL model chains onto the target binding site by
29  enabling *full_bs_search* or restrict the model chains for a specific
30  target/model ligand pair to the chains with at least one atom within
31  *model_bs_radius* around the model ligand. The latter can be significantly
32  faster in case of large complexes.
33  Symmetry correction is achieved by simply computing an RMSD value for
34  each symmetry, i.e. atom-atom assignments of the ligand as given by
35  :class:`LigandScorer`. The lowest RMSD value is returned.
36 
37  Populates :attr:`LigandScorer.aux_data` with the following :class:`dict` keys:
38 
39  * rmsd: The score
40  * lddt_lp: lDDT of the binding site used for superposition
41  * bs_ref_res: :class:`list` of binding site residues in target
42  * bs_ref_res_mapped: :class:`list` of target binding site residues that
43  are mapped to model
44  * bs_mdl_res_mapped: :class:`list` of same length with respective model
45  residues
46  * bb_rmsd: Backbone RMSD (CA, C3' for nucleotides) for mapped residues
47  given transform
48  * target_ligand: The actual target ligand for which the score was computed
49  * model_ligand: The actual model ligand for which the score was computed
50  * chain_mapping: :class:`dict` with a chain mapping of chains involved in
51  binding site - key: trg chain name, value: mdl chain name
52  * transform: :class:`geom.Mat4` to transform model binding site onto target
53  binding site
54  * inconsistent_residues: :class:`list` of :class:`tuple` representing
55  residues with inconsistent residue names upon mapping (which is given by
56  bs_ref_res_mapped and bs_mdl_res_mapped). Tuples have two elements:
57  1) trg residue 2) mdl residue
58 
59  :param model: Passed to parent constructor - see :class:`LigandScorer`.
60  :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
61  :param target: Passed to parent constructor - see :class:`LigandScorer`.
62  :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
63  :param model_ligands: Passed to parent constructor - see
64  :class:`LigandScorer`.
65  :type model_ligands: :class:`list`
66  :param target_ligands: Passed to parent constructor - see
67  :class:`LigandScorer`.
68  :type target_ligands: :class:`list`
69  :param resnum_alignments: Passed to parent constructor - see
70  :class:`LigandScorer`.
71  :type resnum_alignments: :class:`bool`
72  :param rename_ligand_chain: Passed to parent constructor - see
73  :class:`LigandScorer`.
74  :type rename_ligand_chain: :class:`bool`
75  :param substructure_match: Passed to parent constructor - see
76  :class:`LigandScorer`.
77  :type substructure_match: :class:`bool`
78  :param coverage_delta: Passed to parent constructor - see
79  :class:`LigandScorer`.
80  :type coverage_delta: :class:`float`
81  :param max_symmetries: Passed to parent constructor - see
82  :class:`LigandScorer`.
83  :type max_symmetries: :class:`int`
84  :param bs_radius: Inclusion radius for the binding site. Residues with
85  atoms within this distance of the ligand will be considered
86  for inclusion in the binding site.
87  :type bs_radius: :class:`float`
88  :param lddt_lp_radius: lDDT inclusion radius for lDDT-LP.
89  :type lddt_lp_radius: :class:`float`
90  :param model_bs_radius: inclusion radius for model binding sites.
91  Only used when full_bs_search=False, otherwise the
92  radius is effectively infinite. Only chains with
93  atoms within this distance of a model ligand will
94  be considered in the chain mapping.
95  :type model_bs_radius: :class:`float`
96  :param binding_sites_topn: maximum number of model binding site
97  representations to assess per target binding
98  site.
99  :type binding_sites_topn: :class:`int`
100  :param full_bs_search: If True, all potential binding sites in the model
101  are searched for each target binding site. If False,
102  the search space in the model is reduced to chains
103  around (`model_bs_radius` Å) model ligands.
104  This speeds up computations, but may result in
105  ligands not being scored if the predicted ligand
106  pose is too far from the actual binding site.
107  :type full_bs_search: :class:`bool`
108  """
def __init__(self, model, target, model_ligands=None, target_ligands=None,
             resnum_alignments=False, rename_ligand_chain=False,
             substructure_match=False, coverage_delta=0.2,
             max_symmetries=1e5, bs_radius=4.0, lddt_lp_radius=15.0,
             model_bs_radius=25, binding_sites_topn=100000,
             full_bs_search=False):
    """ Set up the scorer and its (initially empty) caches.

    The first nine parameters are forwarded unchanged to the parent
    constructor; the remaining ones are stored on the instance and
    consumed by the binding site / representation lookups.

    :param bs_radius: Stored as ``self.bs_radius``; search radius around
                      target ligand atoms when collecting the target
                      binding site.
    :param lddt_lp_radius: Stored as ``self.lddt_lp_radius``; passed as
                           *inclusion_radius* to the chain mapper.
    :param model_bs_radius: Stored as ``self.model_bs_radius``; search
                            radius around model ligand atoms when
                            *full_bs_search* is False.
    :param binding_sites_topn: Stored as ``self.binding_sites_topn``;
                               passed as *topn* to the chain mapper.
    :param full_bs_search: Stored as ``self.full_bs_search``; selects
                           between the global and the ligand-local
                           representation search.
    """
    super().__init__(model, target,
                     model_ligands=model_ligands,
                     target_ligands=target_ligands,
                     resnum_alignments=resnum_alignments,
                     rename_ligand_chain=rename_ligand_chain,
                     substructure_match=substructure_match,
                     coverage_delta=coverage_delta,
                     max_symmetries=max_symmetries)

    self.bs_radius = bs_radius
    self.lddt_lp_radius = lddt_lp_radius
    self.model_bs_radius = model_bs_radius
    self.binding_sites_topn = binding_sites_topn
    self.full_bs_search = full_bs_search

    # Residues that are in contact with a ligand => binding site
    # defined as all residues with at least one atom within
    # self.bs_radius
    # key: ligand.handle.hash_code, value: EntityView of whatever
    # entity ligand belongs to
    self._binding_sites = dict()

    # cache for GetRepr chain mapping calls
    self._repr = dict()

    # lazily precomputed variables to speedup GetRepr chain mapping calls
    # for localized GetRepr searches
    self.__chem_mapping = None
    self.__chem_group_alns = None
    self.__ref_mdl_alns = None
    self.__chain_mapping_mdl = None
    self._get_repr_input = dict()

    # update state decoding from parent with subclass specific stuff
    self.state_decoding[10] = ("binding_site",
                               "No residues were in proximity of the "
                               "target ligand.")
    self.state_decoding[11] = ("model_representation",
                               "No representation "
                               "of the reference binding site was found in "
                               "the model, i.e. the binding site was not "
                               "modeled or the model ligand was positioned "
                               "too far in combination with "
                               "full_bs_search=False.")
    self.state_decoding[20] = ("unknown",
                               "Unknown error occured in SCRMSDScorer")
159 
def _compute(self, symmetries, target_ligand, model_ligand):
    """ Implements interface from parent

    Scores one target/model ligand pair: every binding site
    representation returned by ``self._get_repr`` provides a transform,
    the symmetry corrected RMSD is evaluated under each of them and the
    representation with the lowest RMSD wins.

    :param symmetries: Atom index mappings between the two ligands,
                       forwarded to ``_SCRMSD_symmetries``.
    :param target_ligand: The target ligand of the pair.
    :param model_ligand: The model ligand of the pair.
    :returns: :class:`tuple` ``(rmsd, pair_state, target_ligand_state,
              model_ligand_state, aux_data)``. *rmsd* is ``np.nan`` if no
              representation could be scored; all states are 0 on success.
    """
    # set default to invalid scores
    best_rmsd_result = {"rmsd": None,
                        "lddt_lp": None,
                        "bs_ref_res": list(),
                        "bs_ref_res_mapped": list(),
                        "bs_mdl_res_mapped": list(),
                        "bb_rmsd": None,
                        "target_ligand": target_ligand,
                        "model_ligand": model_ligand,
                        "chain_mapping": dict(),
                        "transform": geom.Mat4(),
                        "inconsistent_residues": list()}

    representations = self._get_repr(target_ligand, model_ligand)

    for r in representations:
        rmsd = _SCRMSD_symmetries(symmetries, model_ligand,
                                  target_ligand, transformation=r.transform)

        # first representation always wins over the invalid default
        if best_rmsd_result["rmsd"] is None or \
           rmsd < best_rmsd_result["rmsd"]:
            best_rmsd_result = {"rmsd": rmsd,
                                "lddt_lp": r.lDDT,
                                "bs_ref_res": r.substructure.residues,
                                "bs_ref_res_mapped": r.ref_residues,
                                "bs_mdl_res_mapped": r.mdl_residues,
                                "bb_rmsd": r.bb_rmsd,
                                "target_ligand": target_ligand,
                                "model_ligand": model_ligand,
                                "chain_mapping": r.GetFlatChainMapping(),
                                "transform": r.transform,
                                "inconsistent_residues":
                                r.inconsistent_residues}

    target_ligand_state = 0
    model_ligand_state = 0
    pair_state = 0

    if best_rmsd_result["rmsd"] is not None:
        best_rmsd = best_rmsd_result["rmsd"]
    else:
        # try to identify error states
        best_rmsd = np.nan
        binding_site = self._get_target_binding_site(target_ligand)
        if binding_site.GetResidueCount() == 0:
            # NOTE(review): pair state 6 is decoded by the parent class,
            # not in this file - confirm it is the intended code for a
            # problem that is specific to a single ligand.
            pair_state = 6
            target_ligand_state = 10  # "binding_site"
        elif len(representations) == 0:
            pair_state = 11  # "model_representation"

    return (best_rmsd, pair_state, target_ligand_state, model_ligand_state,
            best_rmsd_result)
216 
def _score_dir(self):
    """ Implements interface from parent

    :returns: The constant string ``'-'``.
              NOTE(review): presumably signals "lower is better" to the
              parent class - confirm against the parent's documentation.
    """
    return '-'
221 
def _get_repr(self, target_ligand, model_ligand):
    """ Return (and cache) binding site representations for a ligand pair.

    With ``self.full_bs_search`` enabled the result only depends on the
    target ligand, so the cache key uses 0 in place of the model ligand
    hash. Otherwise the chain mapper search is restricted to the input
    from ``self._get_get_repr_input(model_ligand)`` and results are
    cached per target/model ligand pair.

    :param target_ligand: Target ligand; must expose ``handle.hash_code``.
    :param model_ligand: Model ligand; must expose ``handle.hash_code``.
    :returns: Whatever ``self._chain_mapper.GetRepr`` returned for the
              target binding site.
    """
    if self.full_bs_search:
        # all possible binding sites, independent from actual model ligand
        key = (target_ligand.handle.hash_code, 0)
    else:
        key = (target_ligand.handle.hash_code,
               model_ligand.handle.hash_code)

    if key not in self._repr:
        ref_bs = self._get_target_binding_site(target_ligand)
        if self.full_bs_search:
            reprs = self._chain_mapper.GetRepr(
                ref_bs, self.model,
                inclusion_radius=self.lddt_lp_radius,
                topn=self.binding_sites_topn)
        else:
            repr_in = self._get_get_repr_input(model_ligand)
            reprs = self._chain_mapper.GetRepr(
                ref_bs, self.model,
                inclusion_radius=self.lddt_lp_radius,
                topn=self.binding_sites_topn,
                chem_mapping_result=repr_in)
        self._repr[key] = reprs

    return self._repr[key]
248 
def _get_target_binding_site(self, target_ligand):
    """ Return (and cache) the target binding site of *target_ligand*.

    Collects every residue of ``self.target`` with an atom within
    ``self.bs_radius`` of any ligand atom and keeps those that have a
    valid view in the chain mapper target. The ligand itself is skipped,
    other ligands and unmappable residues are skipped with a warning,
    waters are skipped silently.

    :param target_ligand: Target ligand; must expose ``handle.hash_code``,
                          ``atoms`` and ``qualified_name``.
    :returns: View created from ``self.target`` holding the binding site
              residues; empty if nothing mappable was in proximity.
    :raises: :class:`RuntimeError` if residues were selected but none
             could be added to the view.
    """
    # All bookkeeping is done on *handle* hash codes, so the ligand itself
    # must be registered with its handle hash code as well in order to be
    # recognised when it shows up in its own neighbourhood.
    ligand_hash = target_ligand.handle.hash_code

    if ligand_hash not in self._binding_sites:

        # create view of reference binding site
        ref_residues_hashes = set()  # helper to keep track of added residues
        ignored_residue_hashes = {ligand_hash}
        for ligand_at in target_ligand.atoms:
            close_atoms = self.target.FindWithin(ligand_at.GetPos(),
                                                 self.bs_radius)
            for close_at in close_atoms:
                ref_res = close_at.GetResidue()
                h = ref_res.handle.GetHashCode()
                if h in ref_residues_hashes or h in ignored_residue_hashes:
                    continue  # already classified

                # Skip any residue not in the chain mapping target
                view = self._chain_mapper.target.ViewForHandle(ref_res)
                if view.IsValid():
                    ref_residues_hashes.add(h)
                elif ref_res.is_ligand:
                    msg = f"Ignoring ligand {ref_res.qualified_name} "
                    msg += "in binding site of "
                    msg += str(target_ligand.qualified_name)
                    LogWarning(msg)
                    ignored_residue_hashes.add(h)
                elif ref_res.chem_type == mol.ChemType.WATERS:
                    pass  # That's ok, no need to warn
                else:
                    msg = f"Ignoring residue {ref_res.qualified_name} "
                    msg += "in binding site of "
                    msg += str(target_ligand.qualified_name)
                    LogWarning(msg)
                    ignored_residue_hashes.add(h)

        ref_bs = self.target.CreateEmptyView()
        if ref_residues_hashes:
            # reason for doing that separately is to guarantee same ordering
            # of residues as in underlying entity. (Reorder by ResNum seems
            # only available on ChainHandles)
            for ch in self.target.chains:
                for r in ch.residues:
                    if r.handle.GetHashCode() in ref_residues_hashes:
                        ref_bs.AddResidue(r, mol.ViewAddFlag.INCLUDE_ALL)
            if len(ref_bs.residues) == 0:
                raise RuntimeError("Failed to add proximity residues to "
                                   "the reference binding site entity")

        self._binding_sites[ligand_hash] = ref_bs

    return self._binding_sites[ligand_hash]
300 
@property
def _chem_mapping(self):
    """ Lazily computed chem mapping of the model.

    First element of ``self._chain_mapper.GetChemMapping(self.model)``.
    The call also yields the chem group alignments and the chain mapping
    model, so all three cached values are filled at once.
    """
    if self.__chem_mapping is None:
        self.__chem_mapping, self.__chem_group_alns, \
        self.__chain_mapping_mdl = \
        self._chain_mapper.GetChemMapping(self.model)
    return self.__chem_mapping
308 
@property
def _chem_group_alns(self):
    """ Lazily computed chem group alignments of the model.

    Second element of ``self._chain_mapper.GetChemMapping(self.model)``.
    The call also yields the chem mapping and the chain mapping model,
    so all three cached values are filled at once.
    """
    if self.__chem_group_alns is None:
        self.__chem_mapping, self.__chem_group_alns, \
        self.__chain_mapping_mdl = \
        self._chain_mapper.GetChemMapping(self.model)
    return self.__chem_group_alns
316 
@property
def _ref_mdl_alns(self):
    """ Lazily computed reference/model alignments.

    Result of ``chain_mapping._GetRefMdlAlns`` called with the chem groups
    and chem group alignments of ``self._chain_mapper`` and the chem
    mapping / chem group alignments of the model.
    """
    if self.__ref_mdl_alns is None:
        # chain_mapping is not imported at module level in this file;
        # import it here so the first access does not fail with NameError.
        from ost.mol.alg import chain_mapping
        self.__ref_mdl_alns = \
        chain_mapping._GetRefMdlAlns(self._chain_mapper.chem_groups,
                                     self._chain_mapper.chem_group_alignments,
                                     self._chem_mapping,
                                     self._chem_group_alns)
    return self.__ref_mdl_alns
326 
@property
def _chain_mapping_mdl(self):
    """ Lazily computed model as processed by the chain mapper.

    Third element of ``self._chain_mapper.GetChemMapping(self.model)``.
    The call also yields the chem mapping and the chem group alignments,
    so all three cached values are filled at once.
    """
    if self.__chain_mapping_mdl is None:
        self.__chem_mapping, self.__chem_group_alns, \
        self.__chain_mapping_mdl = \
        self._chain_mapper.GetChemMapping(self.model)
    return self.__chain_mapping_mdl
334 
def _get_get_repr_input(self, mdl_ligand):
    """ Return (and cache) the ligand-local input for the GetRepr search.

    Collects the names of all chains of ``self._chain_mapping_mdl`` with
    an atom within ``self.model_bs_radius`` of any atom of *mdl_ligand*
    and reduces both the model and ``self._chem_mapping`` to these chains.
    If no chain is close, an empty view and one empty list per chem group
    are used instead.

    :param mdl_ligand: Model ligand; must expose ``handle.hash_code`` and
                       ``atoms``.
    :returns: :class:`tuple` ``(reduced chem mapping,
              self._chem_group_alns, reduced model)``
    """
    # The cache is keyed by the *handle* hash code. Use that very same key
    # for the lookup on return as well, otherwise ligands whose own hash
    # code differs from the one of their handle cannot be found again.
    ligand_hash = mdl_ligand.handle.hash_code

    if ligand_hash not in self._get_repr_input:

        # figure out what chains in the model are in contact with the ligand
        # that may give a non-zero contribution to lDDT in
        # chain_mapper.GetRepr
        radius = self.model_bs_radius
        chains = set()
        for at in mdl_ligand.atoms:
            close_atoms = self._chain_mapping_mdl.FindWithin(at.GetPos(),
                                                             radius)
            for close_at in close_atoms:
                chains.add(close_at.GetChain().GetName())

        if len(chains) > 0:

            # the chain mapping model which only contains close chains
            query = "cname="
            query += ','.join([mol.QueryQuoteName(x) for x in chains])
            mdl = self._chain_mapping_mdl.Select(query)

            # chem mapping which is reduced to the respective chains
            chem_mapping = [[x for x in m if x in chains]
                            for m in self._chem_mapping]

            self._get_repr_input[ligand_hash] = (mdl, chem_mapping)

        else:
            self._get_repr_input[ligand_hash] = \
            (self._chain_mapping_mdl.CreateEmptyView(),
             [list() for _ in self._chem_mapping])

    return (self._get_repr_input[ligand_hash][1],
            self._chem_group_alns,
            self._get_repr_input[ligand_hash][0])
372 
373 
def SCRMSD(model_ligand, target_ligand, transformation=geom.Mat4(),
           substructure_match=False, max_symmetries=1e6):
    """Calculate symmetry-corrected RMSD.

    Binding site superposition must be computed separately and passed as
    `transformation`.

    :param model_ligand: The model ligand
    :type model_ligand: :class:`ost.mol.ResidueHandle` or
                        :class:`ost.mol.ResidueView`
    :param target_ligand: The target ligand
    :type target_ligand: :class:`ost.mol.ResidueHandle` or
                         :class:`ost.mol.ResidueView`
    :param transformation: Optional transformation to apply on each atom
                           position of model_ligand.
    :type transformation: :class:`ost.geom.Mat4`
    :param substructure_match: Set this to True to allow partial target
                               ligand.
    :type substructure_match: :class:`bool`
    :param max_symmetries: If more than that many isomorphisms exist, raise
      a :class:`TooManySymmetriesError`. This can only be assessed by
      generating at least that many isomorphisms and can take some time.
    :type max_symmetries: :class:`int`
    :rtype: :class:`float`
    :raises: :class:`ost.mol.alg.ligand_scoring_base.NoSymmetryError` when no
             symmetry can be found,
             :class:`ost.mol.alg.ligand_scoring_base.DisconnectedGraphError`
             when ligand graph is disconnected,
             :class:`ost.mol.alg.ligand_scoring_base.TooManySymmetriesError`
             when more than *max_symmetries* isomorphisms are found.
    """
    # Symmetries are requested by atom index because _SCRMSD_symmetries
    # uses them to index rows of its position arrays.
    symmetries = ligand_scoring_base.ComputeSymmetries(
        model_ligand,
        target_ligand,
        substructure_match=substructure_match,
        by_atom_index=True,
        max_symmetries=max_symmetries)
    return _SCRMSD_symmetries(symmetries, model_ligand, target_ligand,
                              transformation)
413 
414 
def _SCRMSD_symmetries(symmetries, model_ligand, target_ligand,
                       transformation):
    """Compute SCRMSD with pre-computed symmetries. Internal.

    :param symmetries: Iterable of ``(trg_sym, mdl_sym)`` pairs. Both items
                       are equally long sequences of atom indices; entry
                       *k* of *trg_sym* is paired with entry *k* of
                       *mdl_sym*.
    :param model_ligand: Object providing ``GetAtomCount()`` and ``atoms``
                         whose items provide ``GetPos()``.
    :param target_ligand: Same interface as *model_ligand*.
    :param transformation: 4x4 matrix, indexable as ``transformation[i, j]``,
                           applied to the model ligand positions.
    :returns: Lowest RMSD over all symmetries, ``np.inf`` if *symmetries*
              is empty.
    """

    # model ligand positions as homogeneous coordinates (4th column stays 1)
    # so that the translational part of the transformation gets applied
    mdl_ligand_pos = np.ones((model_ligand.GetAtomCount(), 4))
    for a_idx, a in enumerate(model_ligand.atoms):
        p = a.GetPos()
        mdl_ligand_pos[a_idx, :3] = (p[0], p[1], p[2])

    np_transformation = np.array(
        [[transformation[i, j] for j in range(4)] for i in range(4)],
        dtype=float)
    mdl_ligand_pos = mdl_ligand_pos.dot(np_transformation.T)[:, :3]

    # target ligand positions are used as is
    trg_ligand_pos = np.zeros((target_ligand.GetAtomCount(), 3))
    for a_idx, a in enumerate(target_ligand.atoms):
        p = a.GetPos()
        trg_ligand_pos[a_idx, :] = (p[0], p[1], p[2])

    # iterate symmetries and find the one with lowest RMSD
    # Every symmetry is expected to cover each target ligand atom exactly
    # once. Gathering the paired rows per symmetry (instead of filling a
    # shared buffer) guarantees that only atoms of the current symmetry
    # contribute to its RMSD.
    best_rmsd = np.inf
    for trg_sym, mdl_sym in symmetries:
        # explicit index arrays: a plain tuple would be interpreted as a
        # multi-dimensional index by numpy
        mdl_idx = np.asarray(mdl_sym, dtype=np.intp)
        trg_idx = np.asarray(trg_sym, dtype=np.intp)
        delta = mdl_ligand_pos[mdl_idx] - trg_ligand_pos[trg_idx]
        rmsd = np.sqrt((delta ** 2).sum(-1).mean())
        if rmsd < best_rmsd:
            best_rmsd = rmsd

    return best_rmsd
459 
# specify public interface: names exported by `from <module> import *`
__all__ = ('SCRMSDScorer', 'SCRMSD')
def _get_repr(self, target_ligand, model_ligand)
def __init__(self, model, target, model_ligands=None, target_ligands=None, resnum_alignments=False, rename_ligand_chain=False, substructure_match=False, coverage_delta=0.2, max_symmetries=1e5, bs_radius=4.0, lddt_lp_radius=15.0, model_bs_radius=25, binding_sites_topn=100000, full_bs_search=False)
def SCRMSD(model_ligand, target_ligand, transformation=geom.Mat4(), substructure_match=False, max_symmetries=1e6)