OpenStructure
ligand_scoring_lddtpli.py
Go to the documentation of this file.
1 import numpy as np
2 
3 from ost import LogWarning
4 from ost import geom
5 from ost import mol
6 from ost import seq
7 
8 from ost.mol.alg import lddt
9 from ost.mol.alg import chain_mapping
10 from ost.mol.alg import ligand_scoring_base
11 
class LDDTPLIScorer(ligand_scoring_base.LigandScorer):
    """ :class:`LigandScorer` implementing lDDT-PLI.

    lDDT-PLI is an lDDT score considering contacts between ligand and
    receptor. Where receptor consists of protein and nucleic acid chains that
    pass the criteria for :class:`chain mapping <ost.mol.alg.chain_mapping>`.
    This means ignoring other ligands, waters, short polymers as well as any
    incorrectly connected chains that may be in proximity.

    :class:`LDDTPLIScorer` computes a score for a specific pair of target/model
    ligands. Given a target/model ligand pair, all possible mappings of
    model chains onto their chemically equivalent target chains are enumerated.
    For each of these enumerations, all possible symmetries, i.e. atom-atom
    assignments of the ligand as given by :class:`LigandScorer`, are evaluated
    and an lDDT-PLI score is computed. The best possible lDDT-PLI score is
    returned.

    By default, classic lDDT is computed. That means, contacts within
    *lddt_pli_radius* are identified in the target and checked if they're
    conserved in the model. Added contacts are not penalized. That means if
    the ligand is nicely placed in the correct pocket, but that pocket now
    suddenly interacts with MORE residues in the model, you still get a high
    score. You can penalize for these added contacts with the
    *add_mdl_contacts* flag. This additionally considers contacts within
    *lddt_pli_radius* in the model but only if the involved atoms can
    be mapped to the target. This is a requirement to 1) extract the respective
    reference distance from the target 2) avoid usage of contacts for which
    we have no experimental evidence. One special case is
    contacts from chains that are NOT mapped to the target binding site. It is
    very well possible that we have experimental evidence for this chain though
    it's just too far away from the target binding site.
    We therefore try to map these contacts to the chain in the target with
    equivalent sequence that is closest to the target binding site. If the
    respective atoms can be mapped there, the contact is considered not
    fulfilled and added as penalty.

    Populates :attr:`LigandScorer.aux_data` with following :class:`dict` keys:

    * lddt_pli: The score
    * lddt_pli_n_contacts: Number of contacts considered in lDDT computation
    * target_ligand: The actual target ligand for which the score was computed
    * model_ligand: The actual model ligand for which the score was computed
    * bs_ref_res: :class:`set` of residues with potentially non-zero
      contribution to score. That is every residue with at least one
      atom within *lddt_pli_radius* + max(*lddt_pli_thresholds*) of
      the ligand.
    * bs_mdl_res: Same for model

    :param model: Passed to parent constructor - see :class:`LigandScorer`.
    :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
    :param target: Passed to parent constructor - see :class:`LigandScorer`.
    :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
    :param model_ligands: Passed to parent constructor - see
                          :class:`LigandScorer`.
    :type model_ligands: :class:`list`
    :param target_ligands: Passed to parent constructor - see
                           :class:`LigandScorer`.
    :type target_ligands: :class:`list`
    :param resnum_alignments: Passed to parent constructor - see
                              :class:`LigandScorer`.
    :type resnum_alignments: :class:`bool`
    :param rename_ligand_chain: Passed to parent constructor - see
                                :class:`LigandScorer`.
    :type rename_ligand_chain: :class:`bool`
    :param substructure_match: Passed to parent constructor - see
                               :class:`LigandScorer`.
    :type substructure_match: :class:`bool`
    :param coverage_delta: Passed to parent constructor - see
                           :class:`LigandScorer`.
    :type coverage_delta: :class:`float`
    :param max_symmetries: Passed to parent constructor - see
                           :class:`LigandScorer`.
    :type max_symmetries: :class:`int`
    :param lddt_pli_radius: lDDT inclusion radius for lDDT-PLI.
    :type lddt_pli_radius: :class:`float`
    :param add_mdl_contacts: Whether to add mdl contacts.
    :type add_mdl_contacts: :class:`bool`
    :param lddt_pli_thresholds: Distance difference thresholds for lDDT.
    :type lddt_pli_thresholds: :class:`list` of :class:`float`
    :param lddt_pli_binding_site_radius: Pro param - don't use. Providing a
                                         value restores behaviour from previous
                                         implementation that first extracted a
                                         binding site with strict distance
                                         threshold and computed lDDT-PLI only on
                                         those target residues whereas the
                                         current implementation includes every
                                         atom within *lddt_pli_radius*.
    :type lddt_pli_binding_site_radius: :class:`float`
    """
101 
102  def __init__(self, model, target, model_ligands=None, target_ligands=None,
103  resnum_alignments=False, rename_ligand_chain=False,
104  substructure_match=False, coverage_delta=0.2,
105  max_symmetries=1e5, lddt_pli_radius=6.0,
106  add_mdl_contacts=False,
107  lddt_pli_thresholds = [0.5, 1.0, 2.0, 4.0],
108  lddt_pli_binding_site_radius=None):
109 
110  super().__init__(model, target, model_ligands = model_ligands,
111  target_ligands = target_ligands,
112  resnum_alignments = resnum_alignments,
113  rename_ligand_chain = rename_ligand_chain,
114  substructure_match = substructure_match,
115  coverage_delta = coverage_delta,
116  max_symmetries = max_symmetries)
117 
118  self.lddt_pli_radiuslddt_pli_radius = lddt_pli_radius
119  self.add_mdl_contactsadd_mdl_contacts = add_mdl_contacts
120  self.lddt_pli_thresholdslddt_pli_thresholds = lddt_pli_thresholds
121  self.lddt_pli_binding_site_radiuslddt_pli_binding_site_radius = lddt_pli_binding_site_radius
122 
123  # lazily precomputed variables to speedup lddt-pli computation
124  self._lddt_pli_target_data_lddt_pli_target_data = dict()
125  self._lddt_pli_model_data_lddt_pli_model_data = dict()
126  self.__mappable_atoms__mappable_atoms = None
127  self.__chem_mapping__chem_mapping = None
128  self.__chem_group_alns__chem_group_alns = None
129  self.__ref_mdl_alns__ref_mdl_alns = None
130  self.__chain_mapping_mdl__chain_mapping_mdl = None
131 
132  # update state decoding from parent with subclass specific stuff
133  self.state_decodingstate_decoding[10] = ("no_contact",
134  "There were no lDDT contacts between the "
135  "binding site and the ligand, and lDDT-PLI "
136  "is undefined.")
137  self.state_decodingstate_decoding[20] = ("unknown",
138  "Unknown error occured in LDDTPLIScorer")
139 
140  def _compute(self, symmetries, target_ligand, model_ligand):
141  """ Implements interface from parent
142  """
143  if self.add_mdl_contactsadd_mdl_contacts:
144  result = self._compute_lddt_pli_add_mdl_contacts_compute_lddt_pli_add_mdl_contacts(symmetries,
145  target_ligand,
146  model_ligand)
147  else:
148  result = self._compute_lddt_pli_classic_compute_lddt_pli_classic(symmetries,
149  target_ligand,
150  model_ligand)
151 
152  pair_state = 0
153  score = result["lddt_pli"]
154 
155  if score is None or np.isnan(score):
156  if result["lddt_pli_n_contacts"] == 0:
157  # it's a space ship!
158  pair_state = 10
159  else:
160  # unknwon error state
161  pair_state = 20
162 
163  # the ligands get a zero-state...
164  target_ligand_state = 0
165  model_ligand_state = 0
166 
167  return (score, pair_state, target_ligand_state, model_ligand_state,
168  result)
169 
170  def _score_dir(self):
171  """ Implements interface from parent
172  """
173  return '+'
174 
175  def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand,
176  model_ligand):
177 
178 
181 
182  trg_residues, trg_bs, trg_chains, trg_ligand_chain, \
183  trg_ligand_res, scorer, chem_groups = \
184  self._lddt_pli_get_trg_data_lddt_pli_get_trg_data(target_ligand)
185 
186  trg_bs_center = trg_bs.geometric_center
187 
188  # Copy to make sure that we don't change anything on underlying
189  # references
190  # This is not strictly necessary in the current implementation but
191  # hey, maybe it avoids hard to debug errors when someone changes things
192  ref_indices = [a.copy() for a in scorer.ref_indices_ic]
193  ref_distances = [a.copy() for a in scorer.ref_distances_ic]
194 
195  # distance hacking... remove any interchain distance except the ones
196  # with the ligand
197  ligand_start_idx = scorer.chain_start_indices[-1]
198  for at_idx in range(ligand_start_idx):
199  mask = ref_indices[at_idx] >= ligand_start_idx
200  ref_indices[at_idx] = ref_indices[at_idx][mask]
201  ref_distances[at_idx] = ref_distances[at_idx][mask]
202 
203  mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res, \
204  chem_mapping = self._lddt_pli_get_mdl_data_lddt_pli_get_mdl_data(model_ligand)
205 
206 
209 
210  # ref_mdl_alns refers to full chain mapper trg and mdl structures
211  # => need to adapt mdl sequence that only contain residues in contact
212  # with ligand
213  cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns_lddt_pli_cut_ref_mdl_alns(chem_groups,
214  chem_mapping,
215  mdl_bs, trg_bs)
216 
217 
220 
221  # get each chain mapping that we ever observe in scoring
222  chain_mappings = list(chain_mapping._ChainMappings(chem_groups,
223  chem_mapping))
224 
225  # for each mdl ligand atom, we collect all trg ligand atoms that are
226  # ever mapped onto it given *symmetries*
227  ligand_atom_mappings = [set() for a in mdl_ligand_res.atoms]
228  for (trg_sym, mdl_sym) in symmetries:
229  for trg_i, mdl_i in zip(trg_sym, mdl_sym):
230  ligand_atom_mappings[mdl_i].add(trg_i)
231 
232  mdl_ligand_pos = np.zeros((mdl_ligand_res.GetAtomCount(), 3))
233  for a_idx, a in enumerate(mdl_ligand_res.atoms):
234  p = a.GetPos()
235  mdl_ligand_pos[a_idx, 0] = p[0]
236  mdl_ligand_pos[a_idx, 1] = p[1]
237  mdl_ligand_pos[a_idx, 2] = p[2]
238 
239  trg_ligand_pos = np.zeros((trg_ligand_res.GetAtomCount(), 3))
240  for a_idx, a in enumerate(trg_ligand_res.atoms):
241  p = a.GetPos()
242  trg_ligand_pos[a_idx, 0] = p[0]
243  trg_ligand_pos[a_idx, 1] = p[1]
244  trg_ligand_pos[a_idx, 2] = p[2]
245 
246  mdl_lig_hashes = [a.hash_code for a in mdl_ligand_res.atoms]
247 
248  symmetric_atoms = np.asarray(sorted(list(scorer.symmetric_atoms)),
249  dtype=np.int64)
250 
251  # two caches to cache things for each chain mapping => lists
252  # of len(chain_mappings)
253  #
254  # In principle we're caching for each trg/mdl ligand atom pair all
255  # information to update ref_indices/ref_distances and resolving the
256  # symmetries of the binding site.
257  # in detail: each list entry in *scoring_cache* is a dict with
258  # key: (mdl_lig_at_idx, trg_lig_at_idx)
259  # value: tuple with 4 elements - 1: indices of atoms representing added
260  # contacts relative to overall inexing scheme in scorer 2: the
261  # respective distances 3: the same but only containing indices towards
262  # atoms of the binding site that are considered symmetric 4: the
263  # respective indices.
264  # each list entry in *penalty_cache* is a list of len N mdl lig atoms.
265  # For each mdl lig at it contains a penalty for this mdl lig at. That
266  # means the number of contacts in the mdl binding site that can
267  # directly be mapped to the target given the local chain mapping but
268  # are not present in the target binding site, i.e. interacting atoms are
269  # too far away.
270  scoring_cache = list()
271  penalty_cache = list()
272 
273  for mapping in chain_mappings:
274 
275  # flat mapping with mdl chain names as key
276  flat_mapping = dict()
277  for trg_chem_group, mdl_chem_group in zip(chem_groups, mapping):
278  for a,b in zip(trg_chem_group, mdl_chem_group):
279  if a is not None and b is not None:
280  flat_mapping[b] = a
281 
282  # for each mdl bs atom (as atom hash), the trg bs atoms (as index in
283  # scorer)
284  bs_atom_mapping = dict()
285  for mdl_cname, ref_cname in flat_mapping.items():
286  aln = cut_ref_mdl_alns[(ref_cname, mdl_cname)]
287  ref_ch = trg_bs.Select(f"cname={mol.QueryQuoteName(ref_cname)}")
288  mdl_ch = mdl_bs.Select(f"cname={mol.QueryQuoteName(mdl_cname)}")
289  aln.AttachView(0, ref_ch)
290  aln.AttachView(1, mdl_ch)
291  for col in aln:
292  ref_r = col.GetResidue(0)
293  mdl_r = col.GetResidue(1)
294  if ref_r.IsValid() and mdl_r.IsValid():
295  for mdl_a in mdl_r.atoms:
296  ref_a = ref_r.FindAtom(mdl_a.GetName())
297  if ref_a.IsValid():
298  ref_h = ref_a.handle.hash_code
299  if ref_h in scorer.atom_indices:
300  mdl_h = mdl_a.handle.hash_code
301  bs_atom_mapping[mdl_h] = \
302  scorer.atom_indices[ref_h]
303 
304  cache = dict()
305  n_penalties = list()
306 
307  for mdl_a_idx, mdl_a in enumerate(mdl_ligand_res.atoms):
308  n_penalty = 0
309  trg_bs_indices = list()
310  close_a = mdl_bs.FindWithin(mdl_a.GetPos(),
311  self.lddt_pli_radiuslddt_pli_radius)
312  for a in close_a:
313  mdl_a_hash_code = a.hash_code
314  if mdl_a_hash_code in bs_atom_mapping:
315  trg_bs_indices.append(bs_atom_mapping[mdl_a_hash_code])
316  elif mdl_a_hash_code not in mdl_lig_hashes:
317  if a.GetChain().GetName() in flat_mapping:
318  # Its in a mapped chain
319  at_key = (a.GetResidue().GetNumber(), a.name)
320  cname = a.GetChain().name
321  cname_key = (flat_mapping[cname], cname)
322  if at_key in self._mappable_atoms_mappable_atoms[cname_key]:
323  # Its a contact in the model but not part of
324  # trg_bs. It can still be mapped using the
325  # global mdl_ch/ref_ch alignment
326  # d in ref > self.lddt_pli_radius + max_thresh
327  # => guaranteed to be non-fulfilled contact
328  n_penalty += 1
329 
330  n_penalties.append(n_penalty)
331 
332  trg_bs_indices = np.asarray(sorted(trg_bs_indices))
333 
334  for trg_a_idx in ligand_atom_mappings[mdl_a_idx]:
335  # mask selects entries in trg_bs_indices that are not yet
336  # part of classic lDDT ref_indices for atom at trg_a_idx
337  # => added mdl contacts
338  mask = np.isin(trg_bs_indices,
339  ref_indices[ligand_start_idx + trg_a_idx],
340  assume_unique=True, invert=True)
341  added_indices = np.asarray([], dtype=np.int64)
342  added_distances = np.asarray([], dtype=np.float64)
343  if np.sum(mask) > 0:
344  # compute ref distances on reference positions
345  added_indices = trg_bs_indices[mask]
346  tmp = scorer.positions.take(added_indices, axis=0)
347  np.subtract(tmp, trg_ligand_pos[trg_a_idx][None, :],
348  out=tmp)
349  np.square(tmp, out=tmp)
350  tmp = tmp.sum(axis=1)
351  np.sqrt(tmp, out=tmp)
352  added_distances = tmp
353 
354  # extract the distances towards bs atoms that are symmetric
355  sym_mask = np.isin(added_indices, symmetric_atoms,
356  assume_unique=True)
357 
358  cache[(mdl_a_idx, trg_a_idx)] = (added_indices,
359  added_distances,
360  added_indices[sym_mask],
361  added_distances[sym_mask])
362 
363  scoring_cache.append(cache)
364  penalty_cache.append(n_penalties)
365 
366  # cache for model contacts towards non mapped trg chains - this is
367  # relevant for self._lddt_pli_unmapped_chain_penalty
368  # key: tuple in form (trg_ch, mdl_ch)
369  # value: yet another dict with
370  # key: ligand_atom_hash
371  # value: n contacts towards respective trg chain that can be mapped
372  non_mapped_cache = dict()
373 
374 
377 
378  best_score = -1.0
379  best_result = {"lddt_pli": None,
380  "lddt_pli_n_contacts": 0}
381 
382  # dummy alignment for ligand chains which is needed as input later on
383  ligand_aln = seq.CreateAlignment()
384  trg_s = seq.CreateSequence(trg_ligand_chain.name,
385  trg_ligand_res.GetOneLetterCode())
386  mdl_s = seq.CreateSequence(mdl_ligand_chain.name,
387  mdl_ligand_res.GetOneLetterCode())
388  ligand_aln.AddSequence(trg_s)
389  ligand_aln.AddSequence(mdl_s)
390  ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms))
391 
392  sym_idx_collector = [None] * scorer.n_atoms
393  sym_dist_collector = [None] * scorer.n_atoms
394 
395  for mapping, s_cache, p_cache in zip(chain_mappings, scoring_cache,
396  penalty_cache):
397 
398  lddt_chain_mapping = dict()
399  lddt_alns = dict()
400  for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
401  for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
402  # some mdl chains can be None
403  if mdl_ch is not None:
404  lddt_chain_mapping[mdl_ch] = ref_ch
405  lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]
406 
407  # add ligand to lddt_chain_mapping/lddt_alns
408  lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
409  lddt_alns[mdl_ligand_chain.name] = ligand_aln
410 
411  # already process model, positions will be manually hacked for each
412  # symmetry - small overhead for variables that are thrown away here
413  pos, _, _, _, _, _, lddt_symmetries = \
414  scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
415  residue_mapping = lddt_alns,
416  thresholds = self.lddt_pli_thresholdslddt_pli_thresholds,
417  check_resnames = False)
418 
419  # estimate a penalty for unsatisfied model contacts from chains
420  # that are not in the local trg binding site, but can be mapped in
421  # the target.
422  # We're using the trg chain with the closest geometric center to
423  # the trg binding site that can be mapped to the mdl chain
424  # according the chem mapping. An alternative would be to search for
425  # the target chain with the minimal number of additional contacts.
426  # There is not good solution for this problem...
427  unmapped_chains = list()
428  already_mapped = set()
429  for mdl_ch in mdl_chains:
430  if mdl_ch not in lddt_chain_mapping:
431  # check which chain in trg is closest
432  chem_grp_idx = None
433  for i, m in enumerate(self._chem_mapping_chem_mapping):
434  if mdl_ch in m:
435  chem_grp_idx = i
436  break
437  if chem_grp_idx is None:
438  raise RuntimeError("This should never happen... "
439  "ask Gabriel...")
440  closest_ch = None
441  closest_dist = None
442  for trg_ch in self._chain_mapper_chain_mapper.chem_groups[chem_grp_idx]:
443  if trg_ch not in lddt_chain_mapping.values():
444  if trg_ch not in already_mapped:
445  ch = self._chain_mapper_chain_mapper.target.FindChain(trg_ch)
446  c = ch.geometric_center
447  d = geom.Distance(trg_bs_center, c)
448  if closest_dist is None or d < closest_dist:
449  closest_dist = d
450  closest_ch = trg_ch
451  if closest_ch is not None:
452  unmapped_chains.append((closest_ch, mdl_ch))
453  already_mapped.add(closest_ch)
454 
455  for (trg_sym, mdl_sym) in symmetries:
456 
457  # update positions
458  for mdl_i, trg_i in zip(mdl_sym, trg_sym):
459  pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]
460 
461  # start new ref_indices/ref_distances from original values
462  funky_ref_indices = [np.copy(a) for a in ref_indices]
463  funky_ref_distances = [np.copy(a) for a in ref_distances]
464 
465  # The only distances from the binding site towards the ligand
466  # we care about are the ones from the symmetric atoms to
467  # correctly compute scorer._ResolveSymmetries.
468  # We collect them while updating distances from added mdl
469  # contacts
470  for idx in symmetric_atoms:
471  sym_idx_collector[idx] = list()
472  sym_dist_collector[idx] = list()
473 
474  # add data from added mdl contacts cache
475  added_penalty = 0
476  for mdl_i, trg_i in zip(mdl_sym, trg_sym):
477  added_penalty += p_cache[mdl_i]
478  cache = s_cache[mdl_i, trg_i]
479  full_trg_i = ligand_start_idx + trg_i
480  funky_ref_indices[full_trg_i] = \
481  np.append(funky_ref_indices[full_trg_i], cache[0])
482  funky_ref_distances[full_trg_i] = \
483  np.append(funky_ref_distances[full_trg_i], cache[1])
484  for idx, d in zip(cache[2], cache[3]):
485  sym_idx_collector[idx].append(full_trg_i)
486  sym_dist_collector[idx].append(d)
487 
488  for idx in symmetric_atoms:
489  funky_ref_indices[idx] = \
490  np.append(funky_ref_indices[idx],
491  np.asarray(sym_idx_collector[idx],
492  dtype=np.int64))
493  funky_ref_distances[idx] = \
494  np.append(funky_ref_distances[idx],
495  np.asarray(sym_dist_collector[idx],
496  dtype=np.float64))
497 
498  # we can pass funky_ref_indices/funky_ref_distances as
499  # sym_ref_indices/sym_ref_distances in
500  # scorer._ResolveSymmetries as we only have distances of the bs
501  # to the ligand and ligand atoms are "non-symmetric"
502  scorer._ResolveSymmetries(pos, self.lddt_pli_thresholdslddt_pli_thresholds,
503  lddt_symmetries,
504  funky_ref_indices,
505  funky_ref_distances)
506 
507  N = sum([len(funky_ref_indices[i]) for i in ligand_at_indices])
508  N += added_penalty
509 
510  # collect number of expected contacts which can be mapped
511  if len(unmapped_chains) > 0:
512  N += self._lddt_pli_unmapped_chain_penalty_lddt_pli_unmapped_chain_penalty(unmapped_chains,
513  non_mapped_cache,
514  mdl_bs,
515  mdl_ligand_res,
516  mdl_sym)
517 
518  conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
519  self.lddt_pli_thresholdslddt_pli_thresholds,
520  funky_ref_indices,
521  funky_ref_distances),
522  axis=0)
523  score = None
524  if N > 0:
525  score = np.mean(conserved/N)
526 
527  if score is not None and score > best_score:
528  best_score = score
529  best_result = {"lddt_pli": score,
530  "lddt_pli_n_contacts": N}
531 
532  # fill misc info to result object
533  best_result["target_ligand"] = target_ligand
534  best_result["model_ligand"] = model_ligand
535  best_result["bs_ref_res"] = trg_residues
536  best_result["bs_mdl_res"] = mdl_residues
537 
538  return best_result
539 
540 
541  def _compute_lddt_pli_classic(self, symmetries, target_ligand,
542  model_ligand):
543 
544 
547 
548  max_r = None
549  if self.lddt_pli_binding_site_radiuslddt_pli_binding_site_radius:
550  max_r = self.lddt_pli_binding_site_radiuslddt_pli_binding_site_radius
551 
552  trg_residues, trg_bs, trg_chains, trg_ligand_chain, \
553  trg_ligand_res, scorer, chem_groups = \
554  self._lddt_pli_get_trg_data_lddt_pli_get_trg_data(target_ligand, max_r = max_r)
555 
556  # Copy to make sure that we don't change anything on underlying
557  # references
558  # This is not strictly necessary in the current implementation but
559  # hey, maybe it avoids hard to debug errors when someone changes things
560  ref_indices = [a.copy() for a in scorer.ref_indices_ic]
561  ref_distances = [a.copy() for a in scorer.ref_distances_ic]
562 
563  # no matter what mapping/symmetries, the number of expected
564  # contacts stays the same
565  ligand_start_idx = scorer.chain_start_indices[-1]
566  ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms))
567  n_exp = sum([len(ref_indices[i]) for i in ligand_at_indices])
568 
569  mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res, \
570  chem_mapping = self._lddt_pli_get_mdl_data_lddt_pli_get_mdl_data(model_ligand)
571 
572  if n_exp == 0:
573  # no contacts... nothing to compute...
574  return {"lddt_pli": None,
575  "lddt_pli_n_contacts": 0,
576  "target_ligand": target_ligand,
577  "model_ligand": model_ligand,
578  "bs_ref_res": trg_residues,
579  "bs_mdl_res": mdl_residues}
580 
581  # Distance hacking... remove any interchain distance except the ones
582  # with the ligand
583  for at_idx in range(ligand_start_idx):
584  mask = ref_indices[at_idx] >= ligand_start_idx
585  ref_indices[at_idx] = ref_indices[at_idx][mask]
586  ref_distances[at_idx] = ref_distances[at_idx][mask]
587 
588 
591 
592  # ref_mdl_alns refers to full chain mapper trg and mdl structures
593  # => need to adapt mdl sequence that only contain residues in contact
594  # with ligand
595  cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns_lddt_pli_cut_ref_mdl_alns(chem_groups,
596  chem_mapping,
597  mdl_bs, trg_bs)
598 
599 
602 
603  best_score = -1.0
604 
605  # dummy alignment for ligand chains which is needed as input later on
606  l_aln = seq.CreateAlignment()
607  l_aln.AddSequence(seq.CreateSequence(trg_ligand_chain.name,
608  trg_ligand_res.GetOneLetterCode()))
609  l_aln.AddSequence(seq.CreateSequence(mdl_ligand_chain.name,
610  mdl_ligand_res.GetOneLetterCode()))
611 
612  mdl_ligand_pos = np.zeros((model_ligand.GetAtomCount(), 3))
613  for a_idx, a in enumerate(model_ligand.atoms):
614  p = a.GetPos()
615  mdl_ligand_pos[a_idx, 0] = p[0]
616  mdl_ligand_pos[a_idx, 1] = p[1]
617  mdl_ligand_pos[a_idx, 2] = p[2]
618 
619  for mapping in chain_mapping._ChainMappings(chem_groups, chem_mapping):
620 
621  lddt_chain_mapping = dict()
622  lddt_alns = dict()
623  for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
624  for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
625  # some mdl chains can be None
626  if mdl_ch is not None:
627  lddt_chain_mapping[mdl_ch] = ref_ch
628  lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]
629 
630  # add ligand to lddt_chain_mapping/lddt_alns
631  lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
632  lddt_alns[mdl_ligand_chain.name] = l_aln
633 
634  # already process model, positions will be manually hacked for each
635  # symmetry - small overhead for variables that are thrown away here
636  pos, _, _, _, _, _, lddt_symmetries = \
637  scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
638  residue_mapping = lddt_alns,
639  thresholds = self.lddt_pli_thresholdslddt_pli_thresholds,
640  check_resnames = False)
641 
642  for (trg_sym, mdl_sym) in symmetries:
643  for mdl_i, trg_i in zip(mdl_sym, trg_sym):
644  pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]
645  # we can pass ref_indices/ref_distances as
646  # sym_ref_indices/sym_ref_distances in
647  # scorer._ResolveSymmetries as we only have distances of the bs
648  # to the ligand and ligand atoms are "non-symmetric"
649  scorer._ResolveSymmetries(pos, self.lddt_pli_thresholdslddt_pli_thresholds,
650  lddt_symmetries,
651  ref_indices,
652  ref_distances)
653  # compute number of conserved distances for ligand atoms
654  conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
655  self.lddt_pli_thresholdslddt_pli_thresholds,
656  ref_indices,
657  ref_distances), axis=0)
658  score = np.mean(conserved/n_exp)
659 
660  if score > best_score:
661  best_score = score
662 
663  # fill misc info to result object
664  best_result = {"lddt_pli": best_score,
665  "lddt_pli_n_contacts": n_exp,
666  "target_ligand": target_ligand,
667  "model_ligand": model_ligand,
668  "bs_ref_res": trg_residues,
669  "bs_mdl_res": mdl_residues}
670 
671  return best_result
672 
673  def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains,
674  non_mapped_cache,
675  mdl_bs,
676  mdl_ligand_res,
677  mdl_sym):
678 
679  n_exp = 0
680  for ch_tuple in unmapped_chains:
681  if ch_tuple not in non_mapped_cache:
682  # for each ligand atom, we count the number of mappable atoms
683  # within lddt_pli_radius
684  counts = dict()
685  # the select statement also excludes the ligand in mdl_bs
686  # as it resides in a separate chain
687  mdl_cname = ch_tuple[1]
688  query = "cname=" + mol.QueryQuoteName(mdl_cname)
689  mdl_bs_ch = mdl_bs.Select(query)
690  for a in mdl_ligand_res.atoms:
691  close_atoms = \
692  mdl_bs_ch.FindWithin(a.GetPos(), self.lddt_pli_radiuslddt_pli_radius)
693  N = 0
694  for close_a in close_atoms:
695  at_key = (close_a.GetResidue().GetNumber(),
696  close_a.GetName())
697  if at_key in self._mappable_atoms_mappable_atoms[ch_tuple]:
698  N += 1
699  counts[a.hash_code] = N
700 
701  # fill cache
702  non_mapped_cache[ch_tuple] = counts
703 
704  # add number of mdl contacts which can be mapped to target
705  # as non-fulfilled contacts
706  counts = non_mapped_cache[ch_tuple]
707  lig_hash_codes = [a.hash_code for a in mdl_ligand_res.atoms]
708  for i in mdl_sym:
709  n_exp += counts[lig_hash_codes[i]]
710 
711  return n_exp
712 
713 
714  def _lddt_pli_get_mdl_data(self, model_ligand):
715  if model_ligand not in self._lddt_pli_model_data_lddt_pli_model_data:
716 
717  mdl = self._chain_mapping_mdl_chain_mapping_mdl
718 
719  mdl_residues = set()
720  for at in model_ligand.atoms:
721  close_atoms = mdl.FindWithin(at.GetPos(), self.lddt_pli_radiuslddt_pli_radius)
722  for close_at in close_atoms:
723  mdl_residues.add(close_at.GetResidue())
724 
725  max_r = self.lddt_pli_radiuslddt_pli_radius + max(self.lddt_pli_thresholdslddt_pli_thresholds)
726  for r in mdl.residues:
727  r.SetIntProp("bs", 0)
728  for at in model_ligand.atoms:
729  close_atoms = mdl.FindWithin(at.GetPos(), max_r)
730  for close_at in close_atoms:
731  close_at.GetResidue().SetIntProp("bs", 1)
732 
733  mdl_bs = mol.CreateEntityFromView(mdl.Select("grbs:0=1"), True)
734  mdl_chains = set([ch.name for ch in mdl_bs.chains])
735 
736  mdl_editor = mdl_bs.EditXCS(mol.BUFFERED_EDIT)
737  mdl_ligand_chain = None
738  for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
739  try:
740  # I'm pretty sure, one of these chain names is not there...
741  mdl_ligand_chain = mdl_editor.InsertChain(cname)
742  break
743  except:
744  pass
745  if mdl_ligand_chain is None:
746  raise RuntimeError("Fuck this, I'm out...")
747  mdl_ligand_res = mdl_editor.AppendResidue(mdl_ligand_chain,
748  model_ligand,
749  deep=True)
750  mdl_editor.RenameResidue(mdl_ligand_res, "LIG")
751  mdl_editor.SetResidueNumber(mdl_ligand_res, mol.ResNum(1))
752 
753  chem_mapping = list()
754  for m in self._chem_mapping_chem_mapping:
755  chem_mapping.append([x for x in m if x in mdl_chains])
756 
757  self._lddt_pli_model_data_lddt_pli_model_data[model_ligand] = (mdl_residues,
758  mdl_bs,
759  mdl_chains,
760  mdl_ligand_chain,
761  mdl_ligand_res,
762  chem_mapping)
763 
764  return self._lddt_pli_model_data_lddt_pli_model_data[model_ligand]
765 
766 
767  def _lddt_pli_get_trg_data(self, target_ligand, max_r = None):
768  if target_ligand not in self._lddt_pli_target_data_lddt_pli_target_data:
769 
770  trg = self._chain_mapper_chain_mapper.target
771 
772  if max_r is None:
773  max_r = self.lddt_pli_radiuslddt_pli_radius + max(self.lddt_pli_thresholdslddt_pli_thresholds)
774 
775  trg_residues = set()
776  for at in target_ligand.atoms:
777  close_atoms = trg.FindWithin(at.GetPos(), max_r)
778  for close_at in close_atoms:
779  trg_residues.add(close_at.GetResidue())
780 
781  for r in trg.residues:
782  r.SetIntProp("bs", 0)
783 
784  for r in trg_residues:
785  r.SetIntProp("bs", 1)
786 
787  trg_bs = mol.CreateEntityFromView(trg.Select("grbs:0=1"), True)
788  trg_chains = set([ch.name for ch in trg_bs.chains])
789 
790  trg_editor = trg_bs.EditXCS(mol.BUFFERED_EDIT)
791  trg_ligand_chain = None
792  for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
793  try:
794  # I'm pretty sure, one of these chain names is not there yet
795  trg_ligand_chain = trg_editor.InsertChain(cname)
796  break
797  except:
798  pass
799  if trg_ligand_chain is None:
800  raise RuntimeError("Fuck this, I'm out...")
801 
802  trg_ligand_res = trg_editor.AppendResidue(trg_ligand_chain,
803  target_ligand,
804  deep=True)
805  trg_editor.RenameResidue(trg_ligand_res, "LIG")
806  trg_editor.SetResidueNumber(trg_ligand_res, mol.ResNum(1))
807 
808  compound_name = trg_ligand_res.name
809  compound = lddt.CustomCompound.FromResidue(trg_ligand_res)
810  custom_compounds = {compound_name: compound}
811 
812  scorer = lddt.lDDTScorer(trg_bs,
813  custom_compounds = custom_compounds,
814  inclusion_radius = self.lddt_pli_radiuslddt_pli_radius)
815 
816  chem_groups = list()
817  for g in self._chain_mapper_chain_mapper.chem_groups:
818  chem_groups.append([x for x in g if x in trg_chains])
819 
820  self._lddt_pli_target_data_lddt_pli_target_data[target_ligand] = (trg_residues,
821  trg_bs,
822  trg_chains,
823  trg_ligand_chain,
824  trg_ligand_res,
825  scorer,
826  chem_groups)
827 
828  return self._lddt_pli_target_data_lddt_pli_target_data[target_ligand]
829 
830 
def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs,
                               ref_bs):
    """ Reduce reference/model alignments to binding site residues

    Walks through *chem_groups* and *chem_mapping* in lockstep and, for
    every combination of reference and model chain name within paired
    groups, derives a new pairwise alignment from the one stored in
    ``self._ref_mdl_alns``. A column keeps its one letter code only if the
    aligned residue is also found (by residue number) in the equally named
    chain of the respective binding site entity. Everything else becomes a
    gap ('-').

    As a side effect, views of the full target/model chains are attached
    to the alignments stored in ``self._ref_mdl_alns``.

    :param chem_groups: Groups of reference chain names
    :param chem_mapping: Groups of model chain names, same order as
                         *chem_groups*
    :param mdl_bs: Model binding site, must provide FindChain
    :param ref_bs: Reference binding site, must provide FindChain
    :returns: :class:`dict` with key (ref_ch, mdl_ch) and the cut
              alignment as value. First sequence is the reference, second
              sequence is the model.
    """
    result = {}
    for ref_names, mdl_names in zip(chem_groups, chem_mapping):
        for ref_ch in ref_names:

            # reference specific data is shared among all model chains
            ref_bs_chain = ref_bs.FindChain(ref_ch)
            ref_query = "cname=" + mol.QueryQuoteName(ref_ch)
            ref_view = self._chain_mapper.target.Select(ref_query)

            for mdl_ch in mdl_names:
                full_aln = self._ref_mdl_alns[(ref_ch, mdl_ch)]
                full_aln.AttachView(0, ref_view)

                mdl_bs_chain = mdl_bs.FindChain(mdl_ch)
                mdl_query = "cname=" + mol.QueryQuoteName(mdl_ch)
                mdl_view = self._chain_mapping_mdl.Select(mdl_query)
                full_aln.AttachView(1, mdl_view)

                ref_letters = list()
                mdl_letters = list()
                for col in full_aln:

                    # reference side: keep letter if residue is part of
                    # the reference binding site
                    res = col.GetResidue(0)
                    keep = res.IsValid() and \
                        ref_bs_chain.FindResidue(res.GetNumber()).IsValid()
                    ref_letters.append(col[0] if keep else '-')

                    # model side: same check against model binding site
                    res = col.GetResidue(1)
                    keep = res.IsValid() and \
                        mdl_bs_chain.FindResidue(res.GetNumber()).IsValid()
                    mdl_letters.append(col[1] if keep else '-')

                cut_aln = seq.CreateAlignment()
                cut_aln.AddSequence(seq.CreateSequence(ref_ch,
                                                       ''.join(ref_letters)))
                cut_aln.AddSequence(seq.CreateSequence(mdl_ch,
                                                       ''.join(mdl_letters)))
                result[(ref_ch, mdl_ch)] = cut_aln
    return result
875 
@property
def _mappable_atoms(self):
    """ Stores mappable atoms given a chain mapping

    Store for each ref_ch,mdl_ch pair all mdl atoms that can be
    mapped. Don't store mappable atoms as hashes but rather as tuple
    (mdl_r.GetNumber(), mdl_a.GetName()). Reason for that is that one might
    operate on Copied EntityHandle objects without corresponding hashes.
    Given a tuple defining c_pair: (ref_cname, mdl_cname), one
    can check if a certain atom is mappable by evaluating:
    if (mdl_r.GetNumber(), mdl_a.GetName()) in self._mappable_atoms[c_pair]

    The result is computed on first access and cached. As a side effect,
    chain views of target and model are attached to the alignments in
    ``self._ref_mdl_alns``.

    :returns: :class:`dict` with key (ref_cname, mdl_cname) and a
              :class:`set` of (residue number, atom name) tuples as value
    """
    if self.__mappable_atoms is None:
        # Fill a local dict and only publish it once complete. Going
        # through the property while it is being initialized would
        # re-enter this getter, and a failure half way through must not
        # leave an incomplete cache behind.
        mappable = dict()
        for (ref_cname, mdl_cname), aln in self._ref_mdl_alns.items():
            ref_query = f"cname={mol.QueryQuoteName(ref_cname)}"
            mdl_query = f"cname={mol.QueryQuoteName(mdl_cname)}"
            ref_ch = self._chain_mapper.target.Select(ref_query)
            mdl_ch = self._chain_mapping_mdl.Select(mdl_query)
            aln.AttachView(0, ref_ch)
            aln.AttachView(1, mdl_ch)

            # every chain pair gets an entry, even if nothing is mappable
            atom_keys = set()
            for col in aln:
                ref_r = col.GetResidue(0)
                mdl_r = col.GetResidue(1)
                if not (ref_r.IsValid() and mdl_r.IsValid()):
                    continue
                mdl_num = mdl_r.GetNumber()
                for mdl_a in mdl_r.atoms:
                    # mappable: atom of same name exists in aligned
                    # reference residue
                    if ref_r.FindAtom(mdl_a.name).IsValid():
                        atom_keys.add((mdl_num, mdl_a.name))
            mappable[(ref_cname, mdl_cname)] = atom_keys

        self.__mappable_atoms = mappable

    return self.__mappable_atoms
909 
@property
def _chem_mapping(self):
    """ First element of ``self._chain_mapper.GetChemMapping(self.model)``

    Lazily evaluated. The call returns three values which are all cached
    at once, i.e. this also fills the caches behind
    ``_chem_group_alns`` and ``_chain_mapping_mdl``.
    """
    if self.__chem_mapping is not None:
        return self.__chem_mapping

    mapping, group_alns, mapping_mdl = \
        self._chain_mapper.GetChemMapping(self.model)
    self.__chem_mapping = mapping
    self.__chem_group_alns = group_alns
    self.__chain_mapping_mdl = mapping_mdl
    return self.__chem_mapping
917 
@property
def _chem_group_alns(self):
    """ Second element of ``self._chain_mapper.GetChemMapping(self.model)``

    Lazily evaluated. The call returns three values which are all cached
    at once, i.e. this also fills the caches behind
    ``_chem_mapping`` and ``_chain_mapping_mdl``.
    """
    if self.__chem_group_alns is not None:
        return self.__chem_group_alns

    mapping, group_alns, mapping_mdl = \
        self._chain_mapper.GetChemMapping(self.model)
    self.__chem_mapping = mapping
    self.__chem_group_alns = group_alns
    self.__chain_mapping_mdl = mapping_mdl
    return self.__chem_group_alns
925 
@property
def _ref_mdl_alns(self):
    """ Result of ``chain_mapping._GetRefMdlAlns``, lazily evaluated

    Input are the chem groups and chem group alignments of
    ``self._chain_mapper`` as well as ``self._chem_mapping`` and
    ``self._chem_group_alns``. Other methods of this class use the result
    as :class:`dict` with key (ref_ch, mdl_ch) and alignments as values.
    """
    if self.__ref_mdl_alns is not None:
        return self.__ref_mdl_alns

    # same evaluation order as the arguments are passed
    ref_chem_groups = self._chain_mapper.chem_groups
    ref_chem_group_alns = self._chain_mapper.chem_group_alignments
    mdl_chem_mapping = self._chem_mapping
    mdl_chem_group_alns = self._chem_group_alns
    self.__ref_mdl_alns = chain_mapping._GetRefMdlAlns(ref_chem_groups,
                                                       ref_chem_group_alns,
                                                       mdl_chem_mapping,
                                                       mdl_chem_group_alns)
    return self.__ref_mdl_alns
935 
@property
def _chain_mapping_mdl(self):
    """ Third element of ``self._chain_mapper.GetChemMapping(self.model)``

    Lazily evaluated. The call returns three values which are all cached
    at once, i.e. this also fills the caches behind
    ``_chem_mapping`` and ``_chem_group_alns``.
    """
    if self.__chain_mapping_mdl is not None:
        return self.__chain_mapping_mdl

    mapping, group_alns, mapping_mdl = \
        self._chain_mapper.GetChemMapping(self.model)
    self.__chem_mapping = mapping
    self.__chem_group_alns = group_alns
    self.__chain_mapping_mdl = mapping_mdl
    return self.__chain_mapping_mdl
943 
# public interface of this module - only the scorer class gets exported
__all__ = (
    'LDDTPLIScorer',
)
def __init__(self, model, target, model_ligands=None, target_ligands=None, resnum_alignments=False, rename_ligand_chain=False, substructure_match=False, coverage_delta=0.2, max_symmetries=1e5, lddt_pli_radius=6.0, add_mdl_contacts=False, lddt_pli_thresholds=[0.5, 1.0, 2.0, 4.0], lddt_pli_binding_site_radius=None)
def _lddt_pli_get_trg_data(self, target_ligand, max_r=None)
def _compute_lddt_pli_classic(self, symmetries, target_ligand, model_ligand)
def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains, non_mapped_cache, mdl_bs, mdl_ligand_res, mdl_sym)
def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs, ref_bs)
def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand, model_ligand)
Real DLLEXPORT_OST_GEOM Distance(const Line2 &l, const Vec2 &v)