OpenStructure
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
ligand_scoring_scrmsd.py
Go to the documentation of this file.
1 import numpy as np
2 
3 from ost import LogWarning, LogScript, LogInfo, LogVerbose
4 from ost import geom
5 from ost import mol
6 
7 from ost.mol.alg import ligand_scoring_base
8 
9 
class SCRMSDScorer(ligand_scoring_base.LigandScorer):
    """ :class:`LigandScorer` implementing symmetry corrected RMSD (BiSyRMSD).

    :class:`SCRMSDScorer` computes a score for a specific pair of target/model
    ligands.

    The returned RMSD is based on a binding site superposition.
    The binding site of the target structure is defined as all residues with at
    least one atom within `bs_radius` around the target ligand.
    It only contains protein and nucleic acid residues from chains that
    pass the criteria for the
    :class:`chain mapping <ost.mol.alg.chain_mapping>`. This means ignoring
    other ligands, waters, short polymers as well as any incorrectly connected
    chains that may be in proximity.
    The respective model binding site for superposition is identified by
    naively enumerating all possible mappings of model chains onto their
    chemically equivalent target counterparts from the target binding site.
    The `binding_sites_topn` with respect to lDDT score are evaluated and
    an RMSD is computed.
    You can either try to map ALL model chains onto the target binding site by
    enabling `full_bs_search` or restrict the model chains for a specific
    target/model ligand pair to the chains with at least one atom within
    `model_bs_radius` around the model ligand. The latter can be significantly
    faster in case of large complexes.
    Symmetry correction is achieved by simply computing an RMSD value for
    each symmetry, i.e. atom-atom assignments of the ligand as given by
    :class:`LigandScorer`. The lowest RMSD value is returned.

    Populates :attr:`LigandScorer.aux_data` with following :class:`dict` keys:

    * rmsd: The BiSyRMSD score
    * lddt_lp: lDDT of the binding pocket used for superposition (lDDT-LP)
    * bs_ref_res: :class:`list` of binding site residues in target
    * bs_ref_res_mapped: :class:`list` of target binding site residues that
      are mapped to model
    * bs_mdl_res_mapped: :class:`list` of same length with respective model
      residues
    * bb_rmsd: Backbone RMSD (CA, C3' for nucleotides; full backbone for
      binding sites with fewer than 3 residues) for mapped binding site
      residues after superposition
    * target_ligand: The actual target ligand for which the score was computed
    * model_ligand: The actual model ligand for which the score was computed
    * chain_mapping: :class:`dict` with a chain mapping of chains involved in
      binding site - key: trg chain name, value: mdl chain name
    * transform: :class:`geom.Mat4` to transform model binding site onto target
      binding site
    * inconsistent_residues: :class:`list` of :class:`tuple` representing
      residues with inconsistent residue names upon mapping (which is given by
      bs_ref_res_mapped and bs_mdl_res_mapped). Tuples have two elements:
      1) trg residue 2) mdl residue

    :param model: Passed to parent constructor - see :class:`LigandScorer`.
    :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
    :param target: Passed to parent constructor - see :class:`LigandScorer`.
    :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
    :param model_ligands: Passed to parent constructor - see
                          :class:`LigandScorer`.
    :type model_ligands: :class:`list`
    :param target_ligands: Passed to parent constructor - see
                           :class:`LigandScorer`.
    :type target_ligands: :class:`list`
    :param resnum_alignments: Passed to parent constructor - see
                              :class:`LigandScorer`.
    :type resnum_alignments: :class:`bool`
    :param rename_ligand_chain: Passed to parent constructor - see
                                :class:`LigandScorer`.
    :type rename_ligand_chain: :class:`bool`
    :param substructure_match: Passed to parent constructor - see
                               :class:`LigandScorer`.
    :type substructure_match: :class:`bool`
    :param coverage_delta: Passed to parent constructor - see
                           :class:`LigandScorer`.
    :type coverage_delta: :class:`float`
    :param max_symmetries: Passed to parent constructor - see
                           :class:`LigandScorer`.
    :type max_symmetries: :class:`int`
    :param bs_radius: Inclusion radius for the binding site. Residues with
                      atoms within this distance of the ligand will be
                      considered for inclusion in the binding site.
    :type bs_radius: :class:`float`
    :param lddt_lp_radius: lDDT inclusion radius for lDDT-LP.
    :type lddt_lp_radius: :class:`float`
    :param model_bs_radius: inclusion radius for model binding sites.
                            Only used when full_bs_search=False, otherwise the
                            radius is effectively infinite. Only chains with
                            atoms within this distance of a model ligand will
                            be considered in the chain mapping.
    :type model_bs_radius: :class:`float`
    :param binding_sites_topn: maximum number of model binding site
                               representations to assess per target binding
                               site.
    :type binding_sites_topn: :class:`int`
    :param full_bs_search: If True, all potential binding sites in the model
                           are searched for each target binding site. If False,
                           the search space in the model is reduced to chains
                           around (`model_bs_radius` Å) model ligands.
                           This speeds up computations, but may result in
                           ligands not being scored if the predicted ligand
                           pose is too far from the actual binding site.
    :type full_bs_search: :class:`bool`
    """
    def __init__(self, model, target, model_ligands, target_ligands,
                 resnum_alignments=False, rename_ligand_chain=False,
                 substructure_match=False, coverage_delta=0.2,
                 max_symmetries=1e5, bs_radius=4.0, lddt_lp_radius=15.0,
                 model_bs_radius=25, binding_sites_topn=100000,
                 full_bs_search=False):

        super().__init__(model, target, model_ligands, target_ligands,
                         resnum_alignments = resnum_alignments,
                         rename_ligand_chain = rename_ligand_chain,
                         substructure_match = substructure_match,
                         coverage_delta = coverage_delta,
                         max_symmetries = max_symmetries)

        self.bs_radius = bs_radius
        self.lddt_lp_radius = lddt_lp_radius
        self.model_bs_radius = model_bs_radius
        self.binding_sites_topn = binding_sites_topn
        self.full_bs_search = full_bs_search

        # Residues that are in contact with a ligand => binding site
        # defined as all residues with at least one atom within self.bs_radius
        # key: ligand.handle.hash_code, value: EntityView of whatever
        # entity ligand belongs to
        self._binding_sites = dict()

        # cache for GetRepr chain mapping calls
        # key: (target ligand handle hash, model ligand handle hash), with
        # the model part set to 0 when full_bs_search is enabled
        self._repr = dict()

        # lazily precomputed variables to speedup GetRepr chain mapping calls
        # for localized GetRepr searches
        self.__chem_mapping = None
        self.__chem_group_alns = None
        self.__ref_mdl_alns = None
        self.__chain_mapping_mdl = None
        # key: model ligand handle hash, value: (model view, chem mapping)
        self._get_repr_input = dict()

        # update state decoding from parent with subclass specific stuff
        self.state_decoding[10] = ("target_binding_site",
                                   "No residues were in proximity of the "
                                   "target ligand.")
        self.state_decoding[11] = ("model_binding_site", "Binding site was not"
                                   " found in the model, i.e. the binding site"
                                   " was not modeled or the model ligand was "
                                   "positioned too far in combination with "
                                   "full_bs_search=False.")
        self.state_decoding[20] = ("unknown",
                                   "Unknown error occured in SCRMSDScorer")

    def _compute(self, symmetries, target_ligand, model_ligand):
        """ Implements interface from parent

        Evaluates the symmetry corrected RMSD of *model_ligand* against
        *target_ligand* for every candidate binding site superposition
        returned by :meth:`_get_repr` and keeps the lowest one.

        :param symmetries: Atom-atom assignments forwarded to
                           :func:`_SCRMSD_symmetries`.
        :returns: :class:`tuple` (best_rmsd, pair_state, target_ligand_state,
                  model_ligand_state, aux data :class:`dict`). best_rmsd is
                  ``numpy.nan`` if no score could be computed, in which case
                  the state values encode the reason (6/10: no target binding
                  site, 11: no model binding site, 20: unknown).
        """
        # set default to invalid scores
        best_rmsd_result = {"rmsd": None,
                            "lddt_lp": None,
                            "bs_ref_res": list(),
                            "bs_ref_res_mapped": list(),
                            "bs_mdl_res_mapped": list(),
                            "bb_rmsd": None,
                            "target_ligand": target_ligand,
                            "model_ligand": model_ligand,
                            "chain_mapping": dict(),
                            "transform": geom.Mat4(),
                            "inconsistent_residues": list()}

        representations = self._get_repr(target_ligand, model_ligand)
        # This step can be slow so give some hints in logs
        msg = "Computing BiSyRMSD with %d chain mappings" % len(representations)
        (LogWarning if len(representations) > 10000 else LogInfo)(msg)

        for r in representations:
            rmsd = _SCRMSD_symmetries(symmetries, model_ligand,
                                      target_ligand, transformation=r.transform)

            if best_rmsd_result["rmsd"] is None or \
               rmsd < best_rmsd_result["rmsd"]:
                best_rmsd_result = {"rmsd": rmsd,
                                    "lddt_lp": r.lDDT,
                                    "bs_ref_res": r.substructure.residues,
                                    "bs_ref_res_mapped": r.ref_residues,
                                    "bs_mdl_res_mapped": r.mdl_residues,
                                    "bb_rmsd": r.bb_rmsd,
                                    "target_ligand": target_ligand,
                                    "model_ligand": model_ligand,
                                    "chain_mapping": r.GetFlatChainMapping(),
                                    "transform": r.transform,
                                    "inconsistent_residues":
                                    r.inconsistent_residues}

        target_ligand_state = 0
        model_ligand_state = 0
        pair_state = 0

        if best_rmsd_result["rmsd"] is not None:
            best_rmsd = best_rmsd_result["rmsd"]
        else:
            # try to identify error states; fall back to "unknown" (20) so a
            # missing score is never reported with a success state
            best_rmsd = np.nan
            pair_state = 20 # unknown error
            N = self._get_target_binding_site(target_ligand).GetResidueCount()
            if N == 0:
                pair_state = 6 # binding_site
                target_ligand_state = 10
            elif len(representations) == 0:
                pair_state = 11

        return (best_rmsd, pair_state, target_ligand_state, model_ligand_state,
                best_rmsd_result)

    def _score_dir(self):
        """ Implements interface from parent

        Presumably tells the parent that lower scores are better, as
        expected for an RMSD - see :class:`LigandScorer`.
        """
        return '-'

    def _get_repr(self, target_ligand, model_ligand):
        """Return (cached) chain mapping representations for a ligand pair.

        With full_bs_search the result only depends on the target ligand,
        otherwise the model search space is restricted to the chains close
        to *model_ligand* (see :meth:`_get_get_repr_input`).
        """
        if self.full_bs_search:
            # all possible binding sites, independent from actual model ligand
            key = (target_ligand.handle.hash_code, 0)
        else:
            key = (target_ligand.handle.hash_code,
                   model_ligand.handle.hash_code)

        if key not in self._repr:
            ref_bs = self._get_target_binding_site(target_ligand)
            LogVerbose("%d chains are in proximity of the target ligand: %s" % (
                ref_bs.chain_count, ", ".join([c.name for c in ref_bs.chains])))
            if self.full_bs_search:
                reprs = self._chain_mapper.GetRepr(
                    ref_bs, self.model, inclusion_radius=self.lddt_lp_radius,
                    topn=self.binding_sites_topn)
            else:
                repr_in = self._get_get_repr_input(model_ligand)
                reprs = self._chain_mapper.GetRepr(
                    ref_bs, self.model, inclusion_radius=self.lddt_lp_radius,
                    topn=self.binding_sites_topn,
                    chem_mapping_result=repr_in)
            self._repr[key] = reprs

        return self._repr[key]

    def _get_target_binding_site(self, target_ligand):
        """Return (cached) view of the target binding site of *target_ligand*.

        The binding site consists of all residues with at least one atom
        within self.bs_radius of any ligand atom that are also part of the
        chain mapper's target. Other ligands and non-water residues outside
        the chain mapper's target are skipped with a warning. Residues keep
        the ordering of the underlying target entity.

        :raises: :class:`RuntimeError` if proximity residues were found but
                 could not be added to the view.
        """
        key = target_ligand.handle.hash_code
        if key not in self._binding_sites:

            # create view of reference binding site
            ref_residues_hashes = set() # helper to keep track of added residues
            # Use the handle hash: it is what close residues are compared with
            # below, so the ligand itself is reliably skipped also when it is
            # passed as a view.
            ignored_residue_hashes = {target_ligand.handle.hash_code}
            for ligand_at in target_ligand.atoms:
                close_atoms = self.target.FindWithin(ligand_at.GetPos(),
                                                     self.bs_radius)
                for close_at in close_atoms:
                    ref_res = close_at.GetResidue()
                    h = ref_res.handle.GetHashCode()
                    if h in ref_residues_hashes or h in ignored_residue_hashes:
                        continue
                    # Skip any residue not in the chain mapping target
                    with ligand_scoring_base._SinkVerbosityLevel(1):
                        view = self._chain_mapper.target.ViewForHandle(ref_res)
                    if view.IsValid():
                        ref_residues_hashes.add(h)
                    elif ref_res.is_ligand:
                        msg = f"Ignoring ligand {ref_res.qualified_name} "
                        msg += "in binding site of "
                        msg += str(target_ligand.qualified_name)
                        LogWarning(msg)
                        ignored_residue_hashes.add(h)
                    elif ref_res.chem_type == mol.ChemType.WATERS:
                        pass # That's ok, no need to warn
                    else:
                        msg = f"Ignoring residue {ref_res.qualified_name} "
                        msg += "in binding site of "
                        msg += str(target_ligand.qualified_name)
                        LogWarning(msg)
                        ignored_residue_hashes.add(h)

            ref_bs = self.target.CreateEmptyView()
            if ref_residues_hashes:
                # reason for doing that separately is to guarantee same ordering
                # of residues as in underlying entity. (Reorder by ResNum seems
                # only available on ChainHandles)
                for ch in self.target.chains:
                    for r in ch.residues:
                        if r.handle.GetHashCode() in ref_residues_hashes:
                            ref_bs.AddResidue(r, mol.ViewAddFlag.INCLUDE_ALL)
                if len(ref_bs.residues) == 0:
                    raise RuntimeError("Failed to add proximity residues to "
                                       "the reference binding site entity")

            self._binding_sites[key] = ref_bs

        return self._binding_sites[key]

    def __init_chem_mapping(self):
        """Run GetChemMapping once and fill the three values derived from it."""
        self.__chem_mapping, self.__chem_group_alns, \
            self.__chain_mapping_mdl = \
            self._chain_mapper.GetChemMapping(self.model)

    @property
    def _chem_mapping(self):
        """Lazily computed chem mapping of the model (from GetChemMapping)."""
        if self.__chem_mapping is None:
            self.__init_chem_mapping()
        return self.__chem_mapping

    @property
    def _chem_group_alns(self):
        """Lazily computed chem group alignments (from GetChemMapping)."""
        if self.__chem_group_alns is None:
            self.__init_chem_mapping()
        return self.__chem_group_alns

    @property
    def _ref_mdl_alns(self):
        """Lazily computed reference/model alignments."""
        if self.__ref_mdl_alns is None:
            # chain_mapping is not among this module's top-level imports;
            # import it here so accessing this property doesn't raise
            # NameError.
            from ost.mol.alg import chain_mapping
            self.__ref_mdl_alns = \
                chain_mapping._GetRefMdlAlns(
                    self._chain_mapper.chem_groups,
                    self._chain_mapper.chem_group_alignments,
                    self._chem_mapping,
                    self._chem_group_alns)
        return self.__ref_mdl_alns

    @property
    def _chain_mapping_mdl(self):
        """Lazily computed chain mapping model (from GetChemMapping)."""
        if self.__chain_mapping_mdl is None:
            self.__init_chem_mapping()
        return self.__chain_mapping_mdl

    def _get_get_repr_input(self, mdl_ligand):
        """Return (cached) reduced GetRepr input for a model ligand.

        Only model chains with at least one atom within self.model_bs_radius
        of *mdl_ligand* are kept.

        :returns: :class:`tuple` (chem mapping reduced to those chains,
                  chem group alignments, chain mapping model view of those
                  chains), passed as ``chem_mapping_result`` to GetRepr.
        """
        # Same key for storing and lookup (it previously used
        # mdl_ligand.hash_code on lookup, which may differ for views).
        key = mdl_ligand.handle.hash_code
        if key not in self._get_repr_input:

            # figure out what chains in the model are in contact with the ligand
            # that may give a non-zero contribution to lDDT in
            # chain_mapper.GetRepr
            radius = self.model_bs_radius
            chains = set()
            for at in mdl_ligand.atoms:
                with ligand_scoring_base._SinkVerbosityLevel(1):
                    close_atoms = self._chain_mapping_mdl.FindWithin(
                        at.GetPos(), radius)
                for close_at in close_atoms:
                    chains.add(close_at.GetChain().GetName())

            if chains:
                sorted_chains = sorted(chains) # deterministic log/query order
                LogVerbose("%d chains are in proximity of the model ligand: %s" % (
                    len(chains), ", ".join(sorted_chains)))

                # the chain mapping model which only contains close chains
                query = "cname="
                query += ','.join([mol.QueryQuoteName(x) for x in sorted_chains])
                mdl = self._chain_mapping_mdl.Select(query)

                # chem mapping which is reduced to the respective chains
                chem_mapping = [[x for x in m if x in chains]
                                for m in self._chem_mapping]

                self._get_repr_input[key] = (mdl, chem_mapping)

            else:
                self._get_repr_input[key] = \
                    (self._chain_mapping_mdl.CreateEmptyView(),
                     [list() for _ in self._chem_mapping])

        mdl, chem_mapping = self._get_repr_input[key]
        return (chem_mapping, self._chem_group_alns, mdl)
378  return (self._get_repr_input_get_repr_input[mdl_ligand.hash_code][1],
379  self._chem_group_alns_chem_group_alns,
380  self._get_repr_input_get_repr_input[mdl_ligand.hash_code][0])
381 
382 
def SCRMSD(model_ligand, target_ligand, transformation=geom.Mat4(),
           substructure_match=False, max_symmetries=1e6):
    """Calculate symmetry-corrected RMSD.

    All atom-atom correspondences (symmetries) between the two ligands are
    enumerated and the lowest RMSD over them is returned. No superposition is
    performed here: a binding site superposition must be computed separately
    and passed as `transformation`.

    :param model_ligand: The model ligand
    :type model_ligand: :class:`ost.mol.ResidueHandle` or
                        :class:`ost.mol.ResidueView`
    :param target_ligand: The target ligand
    :type target_ligand: :class:`ost.mol.ResidueHandle` or
                         :class:`ost.mol.ResidueView`
    :param transformation: Optional transformation to apply on each atom
                           position of model_ligand. The default
                           ``geom.Mat4()`` is expected to be the identity.
    :type transformation: :class:`ost.geom.Mat4`
    :param substructure_match: Set this to True to allow partial target
                               ligand.
    :type substructure_match: :class:`bool`
    :param max_symmetries: If more than that many isomorphisms exist, raise
      a :class:`TooManySymmetriesError`. This can only be assessed by
      generating at least that many isomorphisms and can take some time.
    :type max_symmetries: :class:`int`
    :rtype: :class:`float`
    :raises: :class:`ost.mol.alg.ligand_scoring_base.NoSymmetryError` when no
             symmetry can be found,
             :class:`ost.mol.alg.ligand_scoring_base.DisconnectedGraphError`
             when ligand graph is disconnected,
             :class:`ost.mol.alg.ligand_scoring_base.TooManySymmetriesError`
             when more than *max_symmetries* isomorphisms are found.
    """
    # Symmetries are requested by atom index, which is what the RMSD helper
    # uses to index its position matrices.
    symmetry_options = dict(substructure_match=substructure_match,
                            by_atom_index=True,
                            max_symmetries=max_symmetries)
    all_symmetries = ligand_scoring_base.ComputeSymmetries(
        model_ligand, target_ligand, **symmetry_options)
    return _SCRMSD_symmetries(all_symmetries, model_ligand, target_ligand,
                              transformation)
422 
423 
424 def _SCRMSD_symmetries(symmetries, model_ligand, target_ligand,
425  transformation):
426  """Compute SCRMSD with pre-computed symmetries. Internal. """
427 
428  # setup numpy positions for model ligand and apply transformation
429  mdl_ligand_pos = np.ones((model_ligand.GetAtomCount(), 4))
430  for a_idx, a in enumerate(model_ligand.atoms):
431  p = a.GetPos()
432  mdl_ligand_pos[a_idx, 0] = p[0]
433  mdl_ligand_pos[a_idx, 1] = p[1]
434  mdl_ligand_pos[a_idx, 2] = p[2]
435  np_transformation = np.zeros((4,4))
436  for i in range(4):
437  for j in range(4):
438  np_transformation[i,j] = transformation[i,j]
439  mdl_ligand_pos = mdl_ligand_pos.dot(np_transformation.T)[:,:3]
440 
441  # setup numpy positions for target ligand
442  trg_ligand_pos = np.zeros((target_ligand.GetAtomCount(), 3))
443  for a_idx, a in enumerate(target_ligand.atoms):
444  p = a.GetPos()
445  trg_ligand_pos[a_idx, 0] = p[0]
446  trg_ligand_pos[a_idx, 1] = p[1]
447  trg_ligand_pos[a_idx, 2] = p[2]
448 
449  # position matrices to iterate symmetries
450  # there is a guarantee that
451  # target_ligand.GetAtomCount() <= model_ligand.GetAtomCount()
452  # and that each target ligand atom is part of every symmetry
453  # => target_ligand.GetAtomCount() is size of both position matrices
454  rmsd_mdl_pos = np.zeros((target_ligand.GetAtomCount(), 3))
455  rmsd_trg_pos = np.zeros((target_ligand.GetAtomCount(), 3))
456 
457  # iterate symmetries and find the one with lowest RMSD
458  best_rmsd = np.inf
459  for i, (trg_sym, mdl_sym) in enumerate(symmetries):
460  for idx, (mdl_anum, trg_anum) in enumerate(zip(mdl_sym, trg_sym)):
461  rmsd_mdl_pos[idx,:] = mdl_ligand_pos[mdl_anum, :]
462  rmsd_trg_pos[idx,:] = trg_ligand_pos[trg_anum, :]
463  rmsd = np.sqrt(((rmsd_mdl_pos - rmsd_trg_pos)**2).sum(-1).mean())
464  if rmsd < best_rmsd:
465  best_rmsd = rmsd
466 
467  return best_rmsd
468 
# Public interface of this module; everything else is internal.
__all__ = (
    "SCRMSDScorer",
    "SCRMSD",
)
def _get_repr(self, target_ligand, model_ligand)
def __init__(self, model, target, model_ligands, target_ligands, resnum_alignments=False, rename_ligand_chain=False, substructure_match=False, coverage_delta=0.2, max_symmetries=1e5, bs_radius=4.0, lddt_lp_radius=15.0, model_bs_radius=25, binding_sites_topn=100000, full_bs_search=False)
def SCRMSD(model_ligand, target_ligand, transformation=geom.Mat4(), substructure_match=False, max_symmetries=1e6)