OpenStructure
ligand_scoring_scrmsd.py
Go to the documentation of this file.
1 import numpy as np
2 
3 from ost import LogWarning, LogScript, LogInfo, LogVerbose
4 from ost import geom
5 from ost import mol
6 
7 from ost.mol.alg import ligand_scoring_base
8 
9 
11  """ :class:`LigandScorer` implementing symmetry corrected RMSD (BiSyRMSD).
12 
13  :class:`SCRMSDScorer` computes a score for a specific pair of target/model
14  ligands.
15 
16  The returned RMSD is based on a binding site superposition.
17  The binding site of the target structure is defined as all residues with at
18  least one atom within `bs_radius` around the target ligand.
19  It only contains protein and nucleic acid residues from chains that
20  pass the criteria for the
21  :class:`chain mapping <ost.mol.alg.chain_mapping>`. This means ignoring
22  other ligands, waters, short polymers as well as any incorrectly connected
23  chains that may be in proximity.
24  The respective model binding site for superposition is identified by
25  naively enumerating all possible mappings of model chains onto their
26  chemically equivalent target counterparts from the target binding site.
    The top `binding_sites_topn` model binding site representations with
    respect to lDDT score are evaluated and an RMSD is computed for each.
29  You can either try to map ALL model chains onto the target binding site by
30  enabling `full_bs_search` or restrict the model chains for a specific
31  target/model ligand pair to the chains with at least one atom within
32  *model_bs_radius* around the model ligand. The latter can be significantly
33  faster in case of large complexes.
34  Symmetry correction is achieved by simply computing an RMSD value for
35  each symmetry, i.e. atom-atom assignments of the ligand as given by
36  :class:`LigandScorer`. The lowest RMSD value is returned.
37 
    Populates :attr:`LigandScorer.aux_data` with the following :class:`dict` keys:
39 
40  * rmsd: The BiSyRMSD score
41  * lddt_lp: lDDT of the binding pocket used for superposition (lDDT-LP)
42  * bs_ref_res: :class:`list` of binding site residues in target
43  * bs_ref_res_mapped: :class:`list` of target binding site residues that
44  are mapped to model
45  * bs_mdl_res_mapped: :class:`list` of same length with respective model
46  residues
47  * bb_rmsd: Backbone RMSD (CA, C3' for nucleotides; full backbone for
48  binding sites with fewer than 3 residues) for mapped binding site
49  residues after superposition
50  * target_ligand: The actual target ligand for which the score was computed
51  * model_ligand: The actual model ligand for which the score was computed
52  * chain_mapping: :class:`dict` with a chain mapping of chains involved in
53  binding site - key: trg chain name, value: mdl chain name
54  * transform: :class:`geom.Mat4` to transform model binding site onto target
55  binding site
56  * inconsistent_residues: :class:`list` of :class:`tuple` representing
57  residues with inconsistent residue names upon mapping (which is given by
58  bs_ref_res_mapped and bs_mdl_res_mapped). Tuples have two elements:
59  1) trg residue 2) mdl residue
60 
61  :param model: Passed to parent constructor - see :class:`LigandScorer`.
62  :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
63  :param target: Passed to parent constructor - see :class:`LigandScorer`.
64  :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
65  :param model_ligands: Passed to parent constructor - see
66  :class:`LigandScorer`.
67  :type model_ligands: :class:`list`
68  :param target_ligands: Passed to parent constructor - see
69  :class:`LigandScorer`.
70  :type target_ligands: :class:`list`
71  :param resnum_alignments: Passed to parent constructor - see
72  :class:`LigandScorer`.
73  :type resnum_alignments: :class:`bool`
74  :param rename_ligand_chain: Passed to parent constructor - see
75  :class:`LigandScorer`.
76  :type rename_ligand_chain: :class:`bool`
77  :param substructure_match: Passed to parent constructor - see
78  :class:`LigandScorer`.
79  :type substructure_match: :class:`bool`
80  :param coverage_delta: Passed to parent constructor - see
81  :class:`LigandScorer`.
82  :type coverage_delta: :class:`float`
83  :param max_symmetries: Passed to parent constructor - see
84  :class:`LigandScorer`.
85  :type max_symmetries: :class:`int`
86  :param bs_radius: Inclusion radius for the binding site. Residues with
87  atoms within this distance of the ligand will be considered
88  for inclusion in the binding site.
89  :type bs_radius: :class:`float`
90  :param lddt_lp_radius: lDDT inclusion radius for lDDT-LP.
91  :type lddt_lp_radius: :class:`float`
92  :param model_bs_radius: inclusion radius for model binding sites.
93  Only used when full_bs_search=False, otherwise the
94  radius is effectively infinite. Only chains with
95  atoms within this distance of a model ligand will
96  be considered in the chain mapping.
97  :type model_bs_radius: :class:`float`
98  :param binding_sites_topn: maximum number of model binding site
99  representations to assess per target binding
100  site.
101  :type binding_sites_topn: :class:`int`
102  :param full_bs_search: If True, all potential binding sites in the model
103  are searched for each target binding site. If False,
104  the search space in the model is reduced to chains
105  around (`model_bs_radius` Å) model ligands.
106  This speeds up computations, but may result in
107  ligands not being scored if the predicted ligand
108  pose is too far from the actual binding site.
109  :type full_bs_search: :class:`bool`
110  """
111  def __init__(self, model, target, model_ligands=None, target_ligands=None,
112  resnum_alignments=False, rename_ligand_chain=False,
113  substructure_match=False, coverage_delta=0.2,
114  max_symmetries=1e5, bs_radius=4.0, lddt_lp_radius=15.0,
115  model_bs_radius=25, binding_sites_topn=100000,
116  full_bs_search=False):
117 
118  super().__init__(model, target, model_ligands = model_ligands,
119  target_ligands = target_ligands,
120  resnum_alignments = resnum_alignments,
121  rename_ligand_chain = rename_ligand_chain,
122  substructure_match = substructure_match,
123  coverage_delta = coverage_delta,
124  max_symmetries = max_symmetries)
125 
126  self.bs_radiusbs_radius = bs_radius
127  self.lddt_lp_radiuslddt_lp_radius = lddt_lp_radius
128  self.model_bs_radiusmodel_bs_radius = model_bs_radius
129  self.binding_sites_topnbinding_sites_topn = binding_sites_topn
130  self.full_bs_searchfull_bs_search = full_bs_search
131 
132  # Residues that are in contact with a ligand => binding site
133  # defined as all residues with at least one atom within self.radius
134  # key: ligand.handle.hash_code, value: EntityView of whatever
135  # entity ligand belongs to
136  self._binding_sites_binding_sites = dict()
137 
138  # cache for GetRepr chain mapping calls
139  self._repr_repr = dict()
140 
141  # lazily precomputed variables to speedup GetRepr chain mapping calls
142  # for localized GetRepr searches
143  self.__chem_mapping__chem_mapping = None
144  self.__chem_group_alns__chem_group_alns = None
145  self.__ref_mdl_alns__ref_mdl_alns = None
146  self.__chain_mapping_mdl__chain_mapping_mdl = None
147  self._get_repr_input_get_repr_input = dict()
148 
149  # update state decoding from parent with subclass specific stuff
150  self.state_decodingstate_decoding[10] = ("target_binding_site",
151  "No residues were in proximity of the "
152  "target ligand.")
153  self.state_decodingstate_decoding[11] = ("model_binding_site", "Binding site was not"
154  " found in the model, i.e. the binding site"
155  " was not modeled or the model ligand was "
156  "positioned too far in combination with "
157  "full_bs_search=False.")
158  self.state_decodingstate_decoding[20] = ("unknown",
159  "Unknown error occured in SCRMSDScorer")
160 
161  def _compute(self, symmetries, target_ligand, model_ligand):
162  """ Implements interface from parent
163  """
164  # set default to invalid scores
165  best_rmsd_result = {"rmsd": None,
166  "lddt_lp": None,
167  "bs_ref_res": list(),
168  "bs_ref_res_mapped": list(),
169  "bs_mdl_res_mapped": list(),
170  "bb_rmsd": None,
171  "target_ligand": target_ligand,
172  "model_ligand": model_ligand,
173  "chain_mapping": dict(),
174  "transform": geom.Mat4(),
175  "inconsistent_residues": list()}
176 
177  representations = self._get_repr_get_repr(target_ligand, model_ligand)
178  # This step can be slow so give some hints in logs
179  msg = "Computing BiSyRMSD with %d chain mappings" % len(representations)
180  (LogWarning if len(representations) > 10000 else LogInfo)(msg)
181 
182  for r in representations:
183  rmsd = _SCRMSD_symmetries(symmetries, model_ligand,
184  target_ligand, transformation=r.transform)
185 
186  if best_rmsd_result["rmsd"] is None or \
187  rmsd < best_rmsd_result["rmsd"]:
188  best_rmsd_result = {"rmsd": rmsd,
189  "lddt_lp": r.lDDT,
190  "bs_ref_res": r.substructure.residues,
191  "bs_ref_res_mapped": r.ref_residues,
192  "bs_mdl_res_mapped": r.mdl_residues,
193  "bb_rmsd": r.bb_rmsd,
194  "target_ligand": target_ligand,
195  "model_ligand": model_ligand,
196  "chain_mapping": r.GetFlatChainMapping(),
197  "transform": r.transform,
198  "inconsistent_residues":
199  r.inconsistent_residues}
200 
201  target_ligand_state = 0
202  model_ligand_state = 0
203  pair_state = 0
204 
205  if best_rmsd_result["rmsd"] is not None:
206  best_rmsd = best_rmsd_result["rmsd"]
207  else:
208  # try to identify error states
209  best_rmsd = np.nan
210  error_state = 20 # unknown error
211  N = self._get_target_binding_site_get_target_binding_site(target_ligand).GetResidueCount()
212  if N == 0:
213  pair_state = 6 # binding_site
214  target_ligand_state = 10
215  elif len(representations) == 0:
216  pair_state = 11
217 
218  return (best_rmsd, pair_state, target_ligand_state, model_ligand_state,
219  best_rmsd_result)
220 
221  def _score_dir(self):
222  """ Implements interface from parent
223  """
224  return '-'
225 
226  def _get_repr(self, target_ligand, model_ligand):
227 
228  key = None
229  if self.full_bs_searchfull_bs_search:
230  # all possible binding sites, independent from actual model ligand
231  key = (target_ligand.handle.hash_code, 0)
232  else:
233  key = (target_ligand.handle.hash_code,
234  model_ligand.handle.hash_code)
235 
236  if key not in self._repr_repr:
237  ref_bs = self._get_target_binding_site_get_target_binding_site(target_ligand)
238  LogVerbose("%d chains are in proximity of the target ligand: %s" % (
239  ref_bs.chain_count, ", ".join([c.name for c in ref_bs.chains])))
240  if self.full_bs_searchfull_bs_search:
241  reprs = self._chain_mapper_chain_mapper.GetRepr(
242  ref_bs, self.modelmodel, inclusion_radius=self.lddt_lp_radiuslddt_lp_radius,
243  topn=self.binding_sites_topnbinding_sites_topn)
244  else:
245  repr_in = self._get_get_repr_input_get_get_repr_input(model_ligand)
246  radius = self.lddt_lp_radiuslddt_lp_radius
247  reprs = self._chain_mapper_chain_mapper.GetRepr(ref_bs, self.modelmodel,
248  inclusion_radius=radius,
249  topn=self.binding_sites_topnbinding_sites_topn,
250  chem_mapping_result=repr_in)
251  self._repr_repr[key] = reprs
252 
253  return self._repr_repr[key]
254 
255  def _get_target_binding_site(self, target_ligand):
256 
257  if target_ligand.handle.hash_code not in self._binding_sites_binding_sites:
258 
259  # create view of reference binding site
260  ref_residues_hashes = set() # helper to keep track of added residues
261  ignored_residue_hashes = {target_ligand.hash_code}
262  for ligand_at in target_ligand.atoms:
263  close_atoms = self.targettarget.FindWithin(ligand_at.GetPos(),
264  self.bs_radiusbs_radius)
265  for close_at in close_atoms:
266  # Skip any residue not in the chain mapping target
267  ref_res = close_at.GetResidue()
268  h = ref_res.handle.GetHashCode()
269  if h not in ref_residues_hashes and \
270  h not in ignored_residue_hashes:
271  with ligand_scoring_base._SinkVerbosityLevel(1):
272  view = self._chain_mapper_chain_mapper.target.ViewForHandle(ref_res)
273  if view.IsValid():
274  h = ref_res.handle.GetHashCode()
275  ref_residues_hashes.add(h)
276  elif ref_res.is_ligand:
277  msg = f"Ignoring ligand {ref_res.qualified_name} "
278  msg += "in binding site of "
279  msg += str(target_ligand.qualified_name)
280  LogWarning(msg)
281  ignored_residue_hashes.add(h)
282  elif ref_res.chem_type == mol.ChemType.WATERS:
283  pass # That's ok, no need to warn
284  else:
285  msg = f"Ignoring residue {ref_res.qualified_name} "
286  msg += "in binding site of "
287  msg += str(target_ligand.qualified_name)
288  LogWarning(msg)
289  ignored_residue_hashes.add(h)
290 
291  ref_bs = self.targettarget.CreateEmptyView()
292  if ref_residues_hashes:
293  # reason for doing that separately is to guarantee same ordering
294  # of residues as in underlying entity. (Reorder by ResNum seems
295  # only available on ChainHandles)
296  for ch in self.targettarget.chains:
297  for r in ch.residues:
298  if r.handle.GetHashCode() in ref_residues_hashes:
299  ref_bs.AddResidue(r, mol.ViewAddFlag.INCLUDE_ALL)
300  if len(ref_bs.residues) == 0:
301  raise RuntimeError("Failed to add proximity residues to "
302  "the reference binding site entity")
303 
304  self._binding_sites_binding_sites[target_ligand.handle.hash_code] = ref_bs
305 
306  return self._binding_sites_binding_sites[target_ligand.handle.hash_code]
307 
308  @property
309  def _chem_mapping(self):
310  if self.__chem_mapping__chem_mapping is None:
311  self.__chem_mapping__chem_mapping, self.__chem_group_alns__chem_group_alns, \
312  self.__chain_mapping_mdl__chain_mapping_mdl = \
313  self._chain_mapper_chain_mapper.GetChemMapping(self.modelmodel)
314  return self.__chem_mapping__chem_mapping
315 
316  @property
317  def _chem_group_alns(self):
318  if self.__chem_group_alns__chem_group_alns is None:
319  self.__chem_mapping__chem_mapping, self.__chem_group_alns__chem_group_alns, \
320  self.__chain_mapping_mdl__chain_mapping_mdl = \
321  self._chain_mapper_chain_mapper.GetChemMapping(self.modelmodel)
322  return self.__chem_group_alns__chem_group_alns
323 
324  @property
325  def _ref_mdl_alns(self):
326  if self.__ref_mdl_alns__ref_mdl_alns is None:
327  self.__ref_mdl_alns__ref_mdl_alns = \
328  chain_mapping._GetRefMdlAlns(self._chain_mapper_chain_mapper.chem_groups,
329  self._chain_mapper_chain_mapper.chem_group_alignments,
330  self._chem_mapping_chem_mapping,
331  self._chem_group_alns_chem_group_alns)
332  return self.__ref_mdl_alns__ref_mdl_alns
333 
334  @property
335  def _chain_mapping_mdl(self):
336  if self.__chain_mapping_mdl__chain_mapping_mdl is None:
337  self.__chem_mapping__chem_mapping, self.__chem_group_alns__chem_group_alns, \
338  self.__chain_mapping_mdl__chain_mapping_mdl = \
339  self._chain_mapper_chain_mapper.GetChemMapping(self.modelmodel)
340  return self.__chain_mapping_mdl__chain_mapping_mdl
341 
342  def _get_get_repr_input(self, mdl_ligand):
343  if mdl_ligand.handle.hash_code not in self._get_repr_input_get_repr_input:
344 
345  # figure out what chains in the model are in contact with the ligand
346  # that may give a non-zero contribution to lDDT in
347  # chain_mapper.GetRepr
348  radius = self.model_bs_radiusmodel_bs_radius
349  chains = set()
350  for at in mdl_ligand.atoms:
351  with ligand_scoring_base._SinkVerbosityLevel(1):
352  close_atoms = self._chain_mapping_mdl_chain_mapping_mdl.FindWithin(at.GetPos(),
353  radius)
354  for close_at in close_atoms:
355  chains.add(close_at.GetChain().GetName())
356 
357  if len(chains) > 0:
358  LogVerbose("%d chains are in proximity of the model ligand: %s" % (
359  len(chains), ", ".join(chains)))
360 
361  # the chain mapping model which only contains close chains
362  query = "cname="
363  query += ','.join([mol.QueryQuoteName(x) for x in chains])
364  mdl = self._chain_mapping_mdl_chain_mapping_mdl.Select(query)
365 
366  # chem mapping which is reduced to the respective chains
367  chem_mapping = list()
368  for m in self._chem_mapping_chem_mapping:
369  chem_mapping.append([x for x in m if x in chains])
370 
371  self._get_repr_input_get_repr_input[mdl_ligand.handle.hash_code] = \
372  (mdl, chem_mapping)
373 
374  else:
375  self._get_repr_input_get_repr_input[mdl_ligand.handle.hash_code] = \
376  (self._chain_mapping_mdl_chain_mapping_mdl.CreateEmptyView(),
377  [list() for _ in self._chem_mapping_chem_mapping])
378 
379  return (self._get_repr_input_get_repr_input[mdl_ligand.hash_code][1],
380  self._chem_group_alns_chem_group_alns,
381  self._get_repr_input_get_repr_input[mdl_ligand.hash_code][0])
382 
383 
def SCRMSD(model_ligand, target_ligand, transformation=geom.Mat4(),
           substructure_match=False, max_symmetries=1e6):
    """Calculate symmetry-corrected RMSD.

    The binding site superposition is not computed here; pass it as
    `transformation`. The RMSD is evaluated for every symmetry (isomorphism)
    between the ligands, and the lowest value is returned.

    :param model_ligand: The model ligand
    :type model_ligand: :class:`ost.mol.ResidueHandle` or
                        :class:`ost.mol.ResidueView`
    :param target_ligand: The target ligand
    :type target_ligand: :class:`ost.mol.ResidueHandle` or
                         :class:`ost.mol.ResidueView`
    :param transformation: Optional transformation to apply on each atom
                           position of model_ligand.
    :type transformation: :class:`ost.geom.Mat4`
    :param substructure_match: Set this to True to allow partial target
                               ligand.
    :type substructure_match: :class:`bool`
    :param max_symmetries: If more than that many isomorphisms exist, raise
      a :class:`TooManySymmetriesError`. This can only be assessed by
      generating at least that many isomorphisms and can take some time.
    :type max_symmetries: :class:`int`
    :rtype: :class:`float`
    :raises: :class:`ost.mol.alg.ligand_scoring_base.NoSymmetryError` when no
             symmetry can be found,
             :class:`ost.mol.alg.ligand_scoring_base.DisconnectedGraphError`
             when ligand graph is disconnected,
             :class:`ost.mol.alg.ligand_scoring_base.TooManySymmetriesError`
             when more than *max_symmetries* isomorphisms are found.
    """
    symmetry_options = dict(substructure_match=substructure_match,
                            by_atom_index=True,
                            max_symmetries=max_symmetries)
    symmetries = ligand_scoring_base.ComputeSymmetries(model_ligand,
                                                       target_ligand,
                                                       **symmetry_options)
    return _SCRMSD_symmetries(symmetries, model_ligand, target_ligand,
                              transformation)
423 
424 
def _ligand_positions(ligand):
    """Return the atom positions of *ligand* as a (n_atoms, 3) float array."""
    positions = np.zeros((ligand.GetAtomCount(), 3))
    for a_idx, a in enumerate(ligand.atoms):
        p = a.GetPos()
        positions[a_idx] = (p[0], p[1], p[2])
    return positions


def _SCRMSD_symmetries(symmetries, model_ligand, target_ligand,
                       transformation):
    """Compute SCRMSD with pre-computed symmetries. Internal.

    :param symmetries: Iterable of (target atom indices, model atom indices)
                       pairs, as produced by ``ComputeSymmetries`` with
                       ``by_atom_index=True``.
    :param transformation: 4x4 transformation applied to the model ligand
                           positions; only ``transformation[i, j]`` access is
                           required.
    :returns: Lowest RMSD over all symmetries, computed over the mapped atom
              pairs. ``np.inf`` if *symmetries* is empty.
    """
    # model positions, transformed in homogeneous coordinates
    mdl_pos = _ligand_positions(model_ligand)
    np_transformation = np.array([[transformation[i, j] for j in range(4)]
                                  for i in range(4)], dtype=float)
    mdl_pos_h = np.hstack([mdl_pos, np.ones((mdl_pos.shape[0], 1))])
    mdl_pos = mdl_pos_h.dot(np_transformation.T)[:, :3]

    trg_pos = _ligand_positions(target_ligand)

    # There is a guarantee that each target ligand atom is part of every
    # symmetry, i.e. every symmetry maps target_ligand.GetAtomCount() atoms.
    # Index only the mapped pairs (zip semantics) so that no stale rows from
    # a previous symmetry can enter the RMSD.
    best_rmsd = np.inf
    for trg_sym, mdl_sym in symmetries:
        n_pairs = min(len(trg_sym), len(mdl_sym))
        mdl_idx = np.asarray(mdl_sym[:n_pairs], dtype=int)
        trg_idx = np.asarray(trg_sym[:n_pairs], dtype=int)
        diff = mdl_pos[mdl_idx] - trg_pos[trg_idx]
        rmsd = np.sqrt((diff ** 2).sum(-1).mean())
        if rmsd < best_rmsd:
            best_rmsd = rmsd

    return best_rmsd
469 
# Public interface of this module (used by ``from ... import *``): the
# scorer class and the standalone symmetry-corrected RMSD function.
__all__ = (
    'SCRMSDScorer',
    'SCRMSD',
)
def _get_repr(self, target_ligand, model_ligand)
def __init__(self, model, target, model_ligands=None, target_ligands=None, resnum_alignments=False, rename_ligand_chain=False, substructure_match=False, coverage_delta=0.2, max_symmetries=1e5, bs_radius=4.0, lddt_lp_radius=15.0, model_bs_radius=25, binding_sites_topn=100000, full_bs_search=False)
def SCRMSD(model_ligand, target_ligand, transformation=geom.Mat4(), substructure_match=False, max_symmetries=1e6)