OpenStructure
ligand_scoring_lddtpli.py
Go to the documentation of this file.
1 import numpy as np
2 
3 from ost import LogWarning, LogInfo
4 from ost import geom
5 from ost import mol
6 from ost import seq
7 
8 from ost.mol.alg import lddt
9 from ost.mol.alg import chain_mapping
10 from ost.mol.alg import ligand_scoring_base
11 
13  """ :class:`LigandScorer` implementing lDDT-PLI.
14 
15  lDDT-PLI is an lDDT score considering contacts between ligand and
16  receptor. The receptor consists of protein and nucleic acid chains that
17  pass the criteria for :class:`chain mapping <ost.mol.alg.chain_mapping>`.
18  This means ignoring other ligands, waters, short polymers as well as any
19  incorrectly connected chains that may be in proximity.
20 
21  :class:`LDDTPLIScorer` computes a score for a specific pair of target/model
22  ligands. Given a target/model ligand pair, all possible mappings of
23  model chains onto their chemically equivalent target chains are enumerated.
24  For each of these enumerations, all possible symmetries, i.e. atom-atom
25  assignments of the ligand as given by :class:`LigandScorer`, are evaluated
26  and an lDDT-PLI score is computed. The best possible lDDT-PLI score is
27  returned.
28 
29  The lDDT-PLI score is a variant of lDDT with a custom inclusion radius
30  (`lddt_pli_radius`), no stereochemistry checks, and which penalizes
31  contacts added in the model within `lddt_pli_radius` by default
32  (can be changed with the `add_mdl_contacts` flag) but only if the involved
33  atoms can be mapped to the target. This is a requirement to
34  1) extract the respective reference distance from the target
35  2) avoid usage of contacts for which we have no experimental evidence.
36  One special case concerns contacts from chains that are not mapped to the target
37  binding site. It is very well possible that we have experimental evidence
38  for this chain, though it's just too far away from the target binding site.
39  We therefore try to map these contacts to the chain in the target with
40  equivalent sequence that is closest to the target binding site. If the
41  respective atoms can be mapped there, the contact is considered not
42  fulfilled and added as penalty.
43 
44  Populates :attr:`LigandScorer.aux_data` with following :class:`dict` keys:
45 
46  * lddt_pli: The LDDT-PLI score
47  * lddt_pli_n_contacts: Number of contacts considered in lDDT computation
48  * target_ligand: The actual target ligand for which the score was computed
49  * model_ligand: The actual model ligand for which the score was computed
50  * bs_ref_res: :class:`set` of residues with potentially non-zero
51  contribution to score. That is every residue with at least one
52  atom within *lddt_pli_radius* + max(*lddt_pli_thresholds*) of
53  the ligand.
54  * bs_mdl_res: Same for model
55 
56  :param model: Passed to parent constructor - see :class:`LigandScorer`.
57  :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
58  :param target: Passed to parent constructor - see :class:`LigandScorer`.
59  :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
60  :param model_ligands: Passed to parent constructor - see
61  :class:`LigandScorer`.
62  :type model_ligands: :class:`list`
63  :param target_ligands: Passed to parent constructor - see
64  :class:`LigandScorer`.
65  :type target_ligands: :class:`list`
66  :param resnum_alignments: Passed to parent constructor - see
67  :class:`LigandScorer`.
68  :type resnum_alignments: :class:`bool`
69  :param rename_ligand_chain: Passed to parent constructor - see
70  :class:`LigandScorer`.
71  :type rename_ligand_chain: :class:`bool`
72  :param substructure_match: Passed to parent constructor - see
73  :class:`LigandScorer`.
74  :type substructure_match: :class:`bool`
75  :param coverage_delta: Passed to parent constructor - see
76  :class:`LigandScorer`.
77  :type coverage_delta: :class:`float`
78  :param max_symmetries: Passed to parent constructor - see
79  :class:`LigandScorer`.
80  :type max_symmetries: :class:`int`
81  :param lddt_pli_radius: lDDT inclusion radius for lDDT-PLI.
82  :type lddt_pli_radius: :class:`float`
83  :param add_mdl_contacts: Whether to penalize added model contacts.
84  :type add_mdl_contacts: :class:`bool`
85  :param lddt_pli_thresholds: Distance difference thresholds for lDDT.
86  :type lddt_pli_thresholds: :class:`list` of :class:`float`
87  :param lddt_pli_binding_site_radius: Pro param - don't use. Providing a value
88  restores behaviour from previous
89  implementation that first extracted a
90  binding site with strict distance
91  threshold and computed lDDT-PLI only on
92  those target residues whereas the
93  current implementation includes every
94  atom within *lddt_pli_radius*.
95  :type lddt_pli_binding_site_radius: :class:`float`
96  """
97 
def __init__(self, model, target, model_ligands=None, target_ligands=None,
             resnum_alignments=False, rename_ligand_chain=False,
             substructure_match=False, coverage_delta=0.2,
             max_symmetries=1e4, lddt_pli_radius=6.0,
             add_mdl_contacts=True,
             lddt_pli_thresholds=None,
             lddt_pli_binding_site_radius=None):
    """ Set up the lDDT-PLI scorer.

    *model* up to *max_symmetries* are forwarded unchanged to the parent
    :class:`LigandScorer` constructor. The lDDT-PLI specific parameters
    are stored as public attributes of the same name.

    :param lddt_pli_thresholds: Distance difference thresholds for lDDT.
                                None (default) selects
                                [0.5, 1.0, 2.0, 4.0]. A fresh list is
                                created for every instance, so mutating
                                the attribute of one scorer never leaks
                                into other scorers.
    """

    super().__init__(model, target, model_ligands=model_ligands,
                     target_ligands=target_ligands,
                     resnum_alignments=resnum_alignments,
                     rename_ligand_chain=rename_ligand_chain,
                     substructure_match=substructure_match,
                     coverage_delta=coverage_delta,
                     max_symmetries=max_symmetries)

    # Sentinel instead of a list literal default: a list default would be
    # shared between all instances that rely on it.
    if lddt_pli_thresholds is None:
        lddt_pli_thresholds = [0.5, 1.0, 2.0, 4.0]

    self.lddt_pli_radius = lddt_pli_radius
    self.add_mdl_contacts = add_mdl_contacts
    self.lddt_pli_thresholds = lddt_pli_thresholds
    self.lddt_pli_binding_site_radius = lddt_pli_binding_site_radius

    # lazily precomputed variables to speedup lddt-pli computation
    # (per-ligand caches filled by the _lddt_pli_get_*_data methods)
    self._lddt_pli_target_data = dict()
    self._lddt_pli_model_data = dict()
    # presumably backing fields of lazily evaluated properties such as
    # _mappable_atoms - NOTE(review): confirm against the property code
    self.__mappable_atoms = None
    self.__chem_mapping = None
    self.__chem_group_alns = None
    self.__ref_mdl_alns = None
    self.__chain_mapping_mdl = None

    # update state decoding from parent with subclass specific stuff
    self.state_decoding[10] = ("no_contact",
                               "There were no lDDT contacts between the "
                               "binding site and the ligand, and lDDT-PLI "
                               "is undefined.")
    self.state_decoding[20] = ("unknown",
                               "Unknown error occured in LDDTPLIScorer")
135 
def _compute(self, symmetries, target_ligand, model_ligand):
    """ Implements interface from parent

    Dispatches to the lDDT-PLI variant selected by
    :attr:`add_mdl_contacts` and translates the outcome into states.

    :returns: tuple (score, pair_state, target_ligand_state,
              model_ligand_state, result) where *result* is the dict
              produced by the lDDT-PLI computation. pair_state is 0 for a
              valid score, 10 if the score is undefined because there are
              no contacts, and 20 if it is undefined for another reason.
    """
    if self.add_mdl_contacts:
        LogInfo("Computing lDDT-PLI with added model contacts")
        result = self._compute_lddt_pli_add_mdl_contacts(symmetries,
                                                         target_ligand,
                                                         model_ligand)
    else:
        LogInfo("Computing lDDT-PLI without added model contacts")
        result = self._compute_lddt_pli_classic(symmetries,
                                                target_ligand,
                                                model_ligand)

    pair_state = 0
    score = result["lddt_pli"]

    if score is None or np.isnan(score):
        if result["lddt_pli_n_contacts"] == 0:
            # no contacts at all => lDDT-PLI is undefined
            pair_state = 10
        else:
            # unknown error state
            pair_state = 20

    # the ligands get a zero-state...
    target_ligand_state = 0
    model_ligand_state = 0

    return (score, pair_state, target_ligand_state, model_ligand_state,
            result)
167 
def _score_dir(self):
    """ Implements interface from parent

    :returns: '+' - the lDDT-PLI computations keep the maximum score over
              all chain mappings/symmetries, i.e. higher is better.
    """
    return '+'
172 
def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand,
                                       model_ligand):
    """ Compute lDDT-PLI, penalizing contacts added in the model.

    For every chain mapping of the model binding site chains onto their
    chemically equivalent target chains, and for every ligand symmetry in
    *symmetries*, an lDDT-PLI score is computed. The classic reference
    contacts are extended with model contacts (within
    :attr:`lddt_pli_radius` of a model ligand atom) whose partner atoms
    can be mapped to the target. Contacts towards model chains outside
    the local chain mapping are penalized by mapping them onto the closest
    chemically equivalent target chain.

    :param symmetries: Ligand atom-atom assignments; each element is a
                       (trg_sym, mdl_sym) pair of index sequences of equal
                       length.
    :param target_ligand: The target ligand
    :param model_ligand: The model ligand
    :returns: :class:`dict` with keys "lddt_pli" (best score, None if no
              score could be computed), "lddt_pli_n_contacts",
              "target_ligand", "model_ligand", "bs_ref_res" and
              "bs_mdl_res".
    """

    # -- Data that stays constant for all chain mappings/symmetries --

    (trg_residues, trg_bs, trg_chains, trg_ligand_chain,
     trg_ligand_res, scorer, chem_groups) = \
        self._lddt_pli_get_trg_data(target_ligand)

    trg_bs_center = trg_bs.geometric_center

    # Copy to make sure that we don't change anything on underlying
    # references
    # This is not strictly necessary in the current implementation but
    # hey, maybe it avoids hard to debug errors when someone changes things
    ref_indices = [a.copy() for a in scorer.ref_indices_ic]
    ref_distances = [a.copy() for a in scorer.ref_distances_ic]

    # distance hacking... remove any interchain distance except the ones
    # with the ligand (the ligand is the last chain in the scorer)
    ligand_start_idx = scorer.chain_start_indices[-1]
    for at_idx in range(ligand_start_idx):
        mask = ref_indices[at_idx] >= ligand_start_idx
        ref_indices[at_idx] = ref_indices[at_idx][mask]
        ref_distances[at_idx] = ref_distances[at_idx][mask]

    (mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res,
     chem_mapping) = self._lddt_pli_get_mdl_data(model_ligand)

    # ref_mdl_alns refers to full chain mapper trg and mdl structures
    # => need to adapt mdl sequence that only contain residues in contact
    # with ligand
    cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups,
                                                       chem_mapping,
                                                       mdl_bs, trg_bs)

    # -- Per chain mapping caches --

    # get each chain mapping that we ever observe in scoring
    chain_mappings = list(chain_mapping._ChainMappings(chem_groups,
                                                       chem_mapping))

    # for each mdl ligand atom, we collect all trg ligand atoms that are
    # ever mapped onto it given *symmetries*
    ligand_atom_mappings = [set() for _ in mdl_ligand_res.atoms]
    for trg_sym, mdl_sym in symmetries:
        for trg_i, mdl_i in zip(trg_sym, mdl_sym):
            ligand_atom_mappings[mdl_i].add(trg_i)

    mdl_ligand_pos = np.zeros((mdl_ligand_res.GetAtomCount(), 3))
    for a_idx, a in enumerate(mdl_ligand_res.atoms):
        p = a.GetPos()
        mdl_ligand_pos[a_idx, 0] = p[0]
        mdl_ligand_pos[a_idx, 1] = p[1]
        mdl_ligand_pos[a_idx, 2] = p[2]

    trg_ligand_pos = np.zeros((trg_ligand_res.GetAtomCount(), 3))
    for a_idx, a in enumerate(trg_ligand_res.atoms):
        p = a.GetPos()
        trg_ligand_pos[a_idx, 0] = p[0]
        trg_ligand_pos[a_idx, 1] = p[1]
        trg_ligand_pos[a_idx, 2] = p[2]

    # set => O(1) membership tests in the per-atom loops below
    mdl_lig_hashes = {a.hash_code for a in mdl_ligand_res.atoms}

    symmetric_atoms = np.asarray(sorted(scorer.symmetric_atoms),
                                 dtype=np.int64)

    # two caches to cache things for each chain mapping => lists
    # of len(chain_mappings)
    #
    # In principle we're caching for each trg/mdl ligand atom pair all
    # information to update ref_indices/ref_distances and resolving the
    # symmetries of the binding site.
    # in detail: each list entry in *scoring_cache* is a dict with
    # key: (mdl_lig_at_idx, trg_lig_at_idx)
    # value: tuple with 4 elements - 1: indices of atoms representing added
    # contacts relative to overall indexing scheme in scorer 2: the
    # respective distances 3: the same but only containing indices towards
    # atoms of the binding site that are considered symmetric 4: the
    # respective distances.
    # each list entry in *penalty_cache* is a list of len N mdl lig atoms.
    # For each mdl lig atom it contains a penalty for this mdl lig atom.
    # That means the number of contacts in the mdl binding site that can
    # directly be mapped to the target given the local chain mapping but
    # are not present in the target binding site, i.e. interacting atoms
    # are too far away.
    scoring_cache = list()
    penalty_cache = list()

    for mapping in chain_mappings:

        # flat mapping with mdl chain names as key
        flat_mapping = dict()
        for trg_chem_group, mdl_chem_group in zip(chem_groups, mapping):
            for a, b in zip(trg_chem_group, mdl_chem_group):
                if a is not None and b is not None:
                    flat_mapping[b] = a

        # for each mdl bs atom (as atom hash), the trg bs atoms (as index in
        # scorer)
        bs_atom_mapping = dict()
        for mdl_cname, ref_cname in flat_mapping.items():
            aln = cut_ref_mdl_alns[(ref_cname, mdl_cname)]
            ref_ch = trg_bs.Select(f"cname={mol.QueryQuoteName(ref_cname)}")
            mdl_ch = mdl_bs.Select(f"cname={mol.QueryQuoteName(mdl_cname)}")
            aln.AttachView(0, ref_ch)
            aln.AttachView(1, mdl_ch)
            for col in aln:
                ref_r = col.GetResidue(0)
                mdl_r = col.GetResidue(1)
                if not (ref_r.IsValid() and mdl_r.IsValid()):
                    continue
                for mdl_a in mdl_r.atoms:
                    ref_a = ref_r.FindAtom(mdl_a.GetName())
                    if ref_a.IsValid():
                        ref_h = ref_a.handle.hash_code
                        if ref_h in scorer.atom_indices:
                            mdl_h = mdl_a.handle.hash_code
                            bs_atom_mapping[mdl_h] = \
                                scorer.atom_indices[ref_h]

        cache = dict()
        n_penalties = list()

        for mdl_a_idx, mdl_a in enumerate(mdl_ligand_res.atoms):
            n_penalty = 0
            trg_bs_indices = list()
            close_a = mdl_bs.FindWithin(mdl_a.GetPos(),
                                        self.lddt_pli_radius)
            for a in close_a:
                mdl_a_hash_code = a.hash_code
                if mdl_a_hash_code in bs_atom_mapping:
                    trg_bs_indices.append(bs_atom_mapping[mdl_a_hash_code])
                elif mdl_a_hash_code not in mdl_lig_hashes:
                    if a.GetChain().GetName() in flat_mapping:
                        # It's in a mapped chain
                        at_key = (a.GetResidue().GetNumber(), a.name)
                        cname = a.GetChain().name
                        cname_key = (flat_mapping[cname], cname)
                        if at_key in self._mappable_atoms[cname_key]:
                            # It's a contact in the model but not part of
                            # trg_bs. It can still be mapped using the
                            # global mdl_ch/ref_ch alignment
                            # d in ref > self.lddt_pli_radius + max_thresh
                            # => guaranteed to be non-fulfilled contact
                            n_penalty += 1

            n_penalties.append(n_penalty)

            trg_bs_indices = np.asarray(sorted(trg_bs_indices),
                                        dtype=np.int64)

            for trg_a_idx in ligand_atom_mappings[mdl_a_idx]:
                # mask selects entries in trg_bs_indices that are not yet
                # part of classic lDDT ref_indices for atom at trg_a_idx
                # => added mdl contacts
                mask = np.isin(trg_bs_indices,
                               ref_indices[ligand_start_idx + trg_a_idx],
                               assume_unique=True, invert=True)
                added_indices = np.asarray([], dtype=np.int64)
                added_distances = np.asarray([], dtype=np.float64)
                if mask.any():
                    # compute ref distances on reference positions
                    added_indices = trg_bs_indices[mask]
                    tmp = scorer.positions.take(added_indices, axis=0)
                    np.subtract(tmp, trg_ligand_pos[trg_a_idx][None, :],
                                out=tmp)
                    np.square(tmp, out=tmp)
                    tmp = tmp.sum(axis=1)
                    np.sqrt(tmp, out=tmp)
                    added_distances = tmp

                # extract the distances towards bs atoms that are symmetric
                sym_mask = np.isin(added_indices, symmetric_atoms,
                                   assume_unique=True)

                cache[(mdl_a_idx, trg_a_idx)] = (added_indices,
                                                 added_distances,
                                                 added_indices[sym_mask],
                                                 added_distances[sym_mask])

        scoring_cache.append(cache)
        penalty_cache.append(n_penalties)

    # cache for model contacts towards non mapped trg chains - this is
    # relevant for self._lddt_pli_unmapped_chain_penalty
    # key: tuple in form (trg_ch, mdl_ch)
    # value: yet another dict with
    #   key: ligand_atom_hash
    #   value: n contacts towards respective trg chain that can be mapped
    non_mapped_cache = dict()

    # -- Score every chain mapping / symmetry combination --

    best_score = -1.0
    best_result = {"lddt_pli": None,
                   "lddt_pli_n_contacts": 0}

    # dummy alignment for ligand chains which is needed as input later on
    ligand_aln = seq.CreateAlignment()
    trg_s = seq.CreateSequence(trg_ligand_chain.name,
                               trg_ligand_res.GetOneLetterCode())
    mdl_s = seq.CreateSequence(mdl_ligand_chain.name,
                               mdl_ligand_res.GetOneLetterCode())
    ligand_aln.AddSequence(trg_s)
    ligand_aln.AddSequence(mdl_s)
    ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms))

    sym_idx_collector = [None] * scorer.n_atoms
    sym_dist_collector = [None] * scorer.n_atoms

    for mapping, s_cache, p_cache in zip(chain_mappings, scoring_cache,
                                         penalty_cache):

        lddt_chain_mapping = dict()
        lddt_alns = dict()
        for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
            for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
                # some mdl chains can be None
                if mdl_ch is not None:
                    lddt_chain_mapping[mdl_ch] = ref_ch
                    lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]

        # add ligand to lddt_chain_mapping/lddt_alns
        lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
        lddt_alns[mdl_ligand_chain.name] = ligand_aln

        # already process model, positions will be manually hacked for each
        # symmetry - small overhead for variables that are thrown away here
        pos, _, _, _, _, _, lddt_symmetries = \
            scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
                                 residue_mapping=lddt_alns,
                                 thresholds=self.lddt_pli_thresholds,
                                 check_resnames=False)

        # estimate a penalty for unsatisfied model contacts from chains
        # that are not in the local trg binding site, but can be mapped in
        # the target.
        # We're using the trg chain with the closest geometric center to
        # the trg binding site that can be mapped to the mdl chain
        # according the chem mapping. An alternative would be to search for
        # the target chain with the minimal number of additional contacts.
        # There is no good solution for this problem...
        mapped_trg_chains = set(lddt_chain_mapping.values())
        unmapped_chains = list()
        already_mapped = set()
        for mdl_ch in mdl_chains:
            if mdl_ch in lddt_chain_mapping:
                continue
            # check which chain in trg is closest
            chem_grp_idx = None
            for i, m in enumerate(self._chem_mapping):
                if mdl_ch in m:
                    chem_grp_idx = i
                    break
            if chem_grp_idx is None:
                raise RuntimeError("This should never happen... "
                                   "ask Gabriel...")
            closest_ch = None
            closest_dist = None
            for trg_ch in self._chain_mapper.chem_groups[chem_grp_idx]:
                if trg_ch in mapped_trg_chains or trg_ch in already_mapped:
                    continue
                ch = self._chain_mapper.target.FindChain(trg_ch)
                d = geom.Distance(trg_bs_center, ch.geometric_center)
                if closest_dist is None or d < closest_dist:
                    closest_dist = d
                    closest_ch = trg_ch
            if closest_ch is not None:
                unmapped_chains.append((closest_ch, mdl_ch))
                already_mapped.add(closest_ch)

        for trg_sym, mdl_sym in symmetries:

            # update positions
            for mdl_i, trg_i in zip(mdl_sym, trg_sym):
                pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]

            # start new ref_indices/ref_distances from original values
            funky_ref_indices = [np.copy(a) for a in ref_indices]
            funky_ref_distances = [np.copy(a) for a in ref_distances]

            # The only distances from the binding site towards the ligand
            # we care about are the ones from the symmetric atoms to
            # correctly compute scorer._ResolveSymmetries.
            # We collect them while updating distances from added mdl
            # contacts
            for idx in symmetric_atoms:
                sym_idx_collector[idx] = list()
                sym_dist_collector[idx] = list()

            # add data from added mdl contacts cache
            added_penalty = 0
            for mdl_i, trg_i in zip(mdl_sym, trg_sym):
                added_penalty += p_cache[mdl_i]
                cache = s_cache[(mdl_i, trg_i)]
                full_trg_i = ligand_start_idx + trg_i
                funky_ref_indices[full_trg_i] = \
                    np.append(funky_ref_indices[full_trg_i], cache[0])
                funky_ref_distances[full_trg_i] = \
                    np.append(funky_ref_distances[full_trg_i], cache[1])
                for idx, d in zip(cache[2], cache[3]):
                    sym_idx_collector[idx].append(full_trg_i)
                    sym_dist_collector[idx].append(d)

            for idx in symmetric_atoms:
                funky_ref_indices[idx] = \
                    np.append(funky_ref_indices[idx],
                              np.asarray(sym_idx_collector[idx],
                                         dtype=np.int64))
                funky_ref_distances[idx] = \
                    np.append(funky_ref_distances[idx],
                              np.asarray(sym_dist_collector[idx],
                                         dtype=np.float64))

            # we can pass funky_ref_indices/funky_ref_distances as
            # sym_ref_indices/sym_ref_distances in
            # scorer._ResolveSymmetries as we only have distances of the bs
            # to the ligand and ligand atoms are "non-symmetric"
            scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds,
                                      lddt_symmetries,
                                      funky_ref_indices,
                                      funky_ref_distances)

            N = sum(len(funky_ref_indices[i]) for i in ligand_at_indices)
            N += added_penalty

            # collect number of expected contacts which can be mapped
            if len(unmapped_chains) > 0:
                N += self._lddt_pli_unmapped_chain_penalty(unmapped_chains,
                                                           non_mapped_cache,
                                                           mdl_bs,
                                                           mdl_ligand_res,
                                                           mdl_sym)

            conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
                                                 self.lddt_pli_thresholds,
                                                 funky_ref_indices,
                                                 funky_ref_distances),
                               axis=0)
            score = None
            if N > 0:
                score = np.mean(conserved / N)

            if score is not None and score > best_score:
                best_score = score
                best_result = {"lddt_pli": score,
                               "lddt_pli_n_contacts": N}

    # fill misc info to result object
    best_result["target_ligand"] = target_ligand
    best_result["model_ligand"] = model_ligand
    best_result["bs_ref_res"] = trg_residues
    best_result["bs_mdl_res"] = mdl_residues

    return best_result
537 
538 
def _compute_lddt_pli_classic(self, symmetries, target_ligand,
                              model_ligand):
    """ Compute lDDT-PLI without penalizing contacts added in the model.

    Only contacts present in the target are evaluated. The number of
    expected contacts is therefore the same for every chain mapping and
    ligand symmetry. The best score over all chain mappings of the model
    binding site chains onto their chemically equivalent target chains
    and all *symmetries* is returned.

    If :attr:`lddt_pli_binding_site_radius` is set (truthy), the target
    binding site is extracted with that radius instead of the default
    used by :meth:`_lddt_pli_get_trg_data`.

    :param symmetries: Ligand atom-atom assignments; each element is a
                       (trg_sym, mdl_sym) pair of index sequences of equal
                       length.
    :param target_ligand: The target ligand
    :param model_ligand: The model ligand
    :returns: :class:`dict` with keys "lddt_pli", "lddt_pli_n_contacts",
              "target_ligand", "model_ligand", "bs_ref_res" and
              "bs_mdl_res". "lddt_pli" is None if the target ligand has no
              contacts or if no chain mapping/symmetry combination was
              evaluated.
    """

    max_r = None
    if self.lddt_pli_binding_site_radius:
        max_r = self.lddt_pli_binding_site_radius

    (trg_residues, trg_bs, trg_chains, trg_ligand_chain,
     trg_ligand_res, scorer, chem_groups) = \
        self._lddt_pli_get_trg_data(target_ligand, max_r=max_r)

    # Copy to make sure that we don't change anything on underlying
    # references
    # This is not strictly necessary in the current implementation but
    # hey, maybe it avoids hard to debug errors when someone changes things
    ref_indices = [a.copy() for a in scorer.ref_indices_ic]
    ref_distances = [a.copy() for a in scorer.ref_distances_ic]

    # no matter what mapping/symmetries, the number of expected
    # contacts stays the same (the ligand is the last chain in the scorer)
    ligand_start_idx = scorer.chain_start_indices[-1]
    ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms))
    n_exp = sum(len(ref_indices[i]) for i in ligand_at_indices)

    (mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res,
     chem_mapping) = self._lddt_pli_get_mdl_data(model_ligand)

    if n_exp == 0:
        # no contacts... nothing to compute...
        return {"lddt_pli": None,
                "lddt_pli_n_contacts": 0,
                "target_ligand": target_ligand,
                "model_ligand": model_ligand,
                "bs_ref_res": trg_residues,
                "bs_mdl_res": mdl_residues}

    # Distance hacking... remove any interchain distance except the ones
    # with the ligand
    for at_idx in range(ligand_start_idx):
        mask = ref_indices[at_idx] >= ligand_start_idx
        ref_indices[at_idx] = ref_indices[at_idx][mask]
        ref_distances[at_idx] = ref_distances[at_idx][mask]

    # ref_mdl_alns refers to full chain mapper trg and mdl structures
    # => need to adapt mdl sequence that only contain residues in contact
    # with ligand
    cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups,
                                                       chem_mapping,
                                                       mdl_bs, trg_bs)

    # None until a score was actually computed - a numeric sentinel would
    # otherwise leak out as a (bogus) lDDT-PLI score
    best_score = None

    # dummy alignment for ligand chains which is needed as input later on
    l_aln = seq.CreateAlignment()
    l_aln.AddSequence(seq.CreateSequence(trg_ligand_chain.name,
                                         trg_ligand_res.GetOneLetterCode()))
    l_aln.AddSequence(seq.CreateSequence(mdl_ligand_chain.name,
                                         mdl_ligand_res.GetOneLetterCode()))

    mdl_ligand_pos = np.zeros((model_ligand.GetAtomCount(), 3))
    for a_idx, a in enumerate(model_ligand.atoms):
        p = a.GetPos()
        mdl_ligand_pos[a_idx, 0] = p[0]
        mdl_ligand_pos[a_idx, 1] = p[1]
        mdl_ligand_pos[a_idx, 2] = p[2]

    for mapping in chain_mapping._ChainMappings(chem_groups, chem_mapping):

        lddt_chain_mapping = dict()
        lddt_alns = dict()
        for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
            for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
                # some mdl chains can be None
                if mdl_ch is not None:
                    lddt_chain_mapping[mdl_ch] = ref_ch
                    lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]

        # add ligand to lddt_chain_mapping/lddt_alns
        lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
        lddt_alns[mdl_ligand_chain.name] = l_aln

        # already process model, positions will be manually hacked for each
        # symmetry - small overhead for variables that are thrown away here
        pos, _, _, _, _, _, lddt_symmetries = \
            scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
                                 residue_mapping=lddt_alns,
                                 thresholds=self.lddt_pli_thresholds,
                                 check_resnames=False)

        for trg_sym, mdl_sym in symmetries:
            for mdl_i, trg_i in zip(mdl_sym, trg_sym):
                pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]
            # we can pass ref_indices/ref_distances as
            # sym_ref_indices/sym_ref_distances in
            # scorer._ResolveSymmetries as we only have distances of the bs
            # to the ligand and ligand atoms are "non-symmetric"
            scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds,
                                      lddt_symmetries,
                                      ref_indices,
                                      ref_distances)
            # compute number of conserved distances for ligand atoms
            conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
                                                 self.lddt_pli_thresholds,
                                                 ref_indices,
                                                 ref_distances), axis=0)
            score = np.mean(conserved / n_exp)

            if best_score is None or score > best_score:
                best_score = score

    # fill misc info to result object
    return {"lddt_pli": best_score,
            "lddt_pli_n_contacts": n_exp,
            "target_ligand": target_ligand,
            "model_ligand": model_ligand,
            "bs_ref_res": trg_residues,
            "bs_mdl_res": mdl_residues}
670 
def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains,
                                     non_mapped_cache,
                                     mdl_bs,
                                     mdl_ligand_res,
                                     mdl_sym):
    """ Count model contacts towards chains that are unmapped locally.

    For each (trg_ch, mdl_ch) pair in *unmapped_chains*, counts the model
    atoms of mdl_ch within :attr:`lddt_pli_radius` of each model ligand
    atom that are mappable to trg_ch. These counts are the non-fulfilled
    contacts added as penalty.

    :param unmapped_chains: :class:`list` of (trg_ch, mdl_ch) name tuples
    :param non_mapped_cache: :class:`dict` (trg_ch, mdl_ch) -> {ligand atom
                             hash code: count}. Missing entries are
                             computed and inserted (the dict is mutated).
    :param mdl_bs: Model binding site entity, used to select mdl_ch
    :param mdl_ligand_res: Model ligand residue
    :param mdl_sym: Indices of model ligand atoms in the current symmetry;
                    each index contributes its count once per occurrence.
    :returns: Total number of penalty contacts
    """
    # ligand atom hash codes don't depend on the chain pair => compute once
    lig_hash_codes = [a.hash_code for a in mdl_ligand_res.atoms]

    n_exp = 0
    for ch_tuple in unmapped_chains:
        if ch_tuple not in non_mapped_cache:
            # for each ligand atom, we count the number of mappable atoms
            # within lddt_pli_radius
            counts = dict()
            # the select statement also excludes the ligand in mdl_bs
            # as it resides in a separate chain
            mdl_cname = ch_tuple[1]
            query = "cname=" + mol.QueryQuoteName(mdl_cname)
            mdl_bs_ch = mdl_bs.Select(query)
            for a in mdl_ligand_res.atoms:
                close_atoms = mdl_bs_ch.FindWithin(a.GetPos(),
                                                   self.lddt_pli_radius)
                n_mappable = 0
                for close_a in close_atoms:
                    at_key = (close_a.GetResidue().GetNumber(),
                              close_a.GetName())
                    if at_key in self._mappable_atoms[ch_tuple]:
                        n_mappable += 1
                counts[a.hash_code] = n_mappable

            # fill cache
            non_mapped_cache[ch_tuple] = counts

        # add number of mdl contacts which can be mapped to target
        # as non-fulfilled contacts
        counts = non_mapped_cache[ch_tuple]
        for i in mdl_sym:
            n_exp += counts[lig_hash_codes[i]]

    return n_exp
710 
711 
def _lddt_pli_get_mdl_data(self, model_ligand):
    """ Get (and cache) model binding site data for *model_ligand*.

    The binding site consists of all residues of the chain mapping model
    with at least one atom within :attr:`lddt_pli_radius` +
    max(:attr:`lddt_pli_thresholds`) of the ligand. A copy of the ligand
    is added in a separate chain, renamed to "LIG" with residue number 1.

    Side effect: sets int property "bs" on all residues of the chain
    mapping model.

    :returns: tuple (mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain,
              mdl_ligand_res, chem_mapping). chem_mapping is the chain
              mapper chem mapping restricted to chains in the binding site.
    :raises: :class:`RuntimeError` if no ligand chain could be inserted
    """
    if model_ligand not in self._lddt_pli_model_data:

        mdl = self._chain_mapping_mdl

        # NOTE(review): the class docs describe bs_mdl_res as the model
        # equivalent of bs_ref_res (lddt_pli_radius +
        # max(lddt_pli_thresholds)), but only lddt_pli_radius is used
        # here. Confirm which radius is intended.
        mdl_residues = set()
        for at in model_ligand.atoms:
            close_atoms = mdl.FindWithin(at.GetPos(), self.lddt_pli_radius)
            for close_at in close_atoms:
                mdl_residues.add(close_at.GetResidue())

        # flag binding site residues and extract them as a new entity
        max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds)
        for r in mdl.residues:
            r.SetIntProp("bs", 0)
        for at in model_ligand.atoms:
            close_atoms = mdl.FindWithin(at.GetPos(), max_r)
            for close_at in close_atoms:
                close_at.GetResidue().SetIntProp("bs", 1)

        mdl_bs = mol.CreateEntityFromView(mdl.Select("grbs:0=1"), True)
        mdl_chains = {ch.name for ch in mdl_bs.chains}

        mdl_editor = mdl_bs.EditXCS(mol.BUFFERED_EDIT)
        mdl_ligand_chain = None
        for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
            try:
                # I'm pretty sure, one of these chain names is not there...
                mdl_ligand_chain = mdl_editor.InsertChain(cname)
                break
            except Exception:
                # chain name presumably taken already - try the next one.
                # (a bare except would also swallow KeyboardInterrupt)
                pass
        if mdl_ligand_chain is None:
            raise RuntimeError("Fuck this, I'm out...")
        mdl_ligand_res = mdl_editor.AppendResidue(mdl_ligand_chain,
                                                  model_ligand,
                                                  deep=True)
        mdl_editor.RenameResidue(mdl_ligand_res, "LIG")
        mdl_editor.SetResidueNumber(mdl_ligand_res, mol.ResNum(1))

        chem_mapping = [[x for x in m if x in mdl_chains]
                        for m in self._chem_mapping]

        self._lddt_pli_model_data[model_ligand] = (mdl_residues,
                                                   mdl_bs,
                                                   mdl_chains,
                                                   mdl_ligand_chain,
                                                   mdl_ligand_res,
                                                   chem_mapping)

    return self._lddt_pli_model_data[model_ligand]
763 
764 
def _lddt_pli_get_trg_data(self, target_ligand, max_r=None):
    """ Get (and cache) target binding site data for *target_ligand*.

    The binding site consists of all residues of the chain mapper target
    with at least one atom within *max_r* of the ligand. A copy of the
    ligand is added in a separate chain, renamed to "LIG" with residue
    number 1. An :class:`lddt.lDDTScorer` with inclusion radius
    :attr:`lddt_pli_radius` is set up on that binding site.

    Results are cached per (target_ligand, max_r): the binding site
    depends on *max_r*, so calls with different radii must not share an
    entry.

    Side effect: sets int property "bs" on all residues of the chain
    mapper target.

    :param max_r: Binding site radius. None selects
                  :attr:`lddt_pli_radius` + max(:attr:`lddt_pli_thresholds`).
    :returns: tuple (trg_residues, trg_bs, trg_chains, trg_ligand_chain,
              trg_ligand_res, scorer, chem_groups). chem_groups are the
              chain mapper chem groups restricted to binding site chains.
    :raises: :class:`RuntimeError` if no ligand chain could be inserted
    """
    if max_r is None:
        max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds)

    cache_key = (target_ligand, max_r)
    if cache_key not in self._lddt_pli_target_data:

        trg = self._chain_mapper.target

        trg_residues = set()
        for at in target_ligand.atoms:
            close_atoms = trg.FindWithin(at.GetPos(), max_r)
            for close_at in close_atoms:
                trg_residues.add(close_at.GetResidue())

        # flag binding site residues and extract them as a new entity
        for r in trg.residues:
            r.SetIntProp("bs", 0)

        for r in trg_residues:
            r.SetIntProp("bs", 1)

        trg_bs = mol.CreateEntityFromView(trg.Select("grbs:0=1"), True)
        trg_chains = {ch.name for ch in trg_bs.chains}

        trg_editor = trg_bs.EditXCS(mol.BUFFERED_EDIT)
        trg_ligand_chain = None
        for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
            try:
                # I'm pretty sure, one of these chain names is not there yet
                trg_ligand_chain = trg_editor.InsertChain(cname)
                break
            except Exception:
                # chain name presumably taken already - try the next one.
                # (a bare except would also swallow KeyboardInterrupt)
                pass
        if trg_ligand_chain is None:
            raise RuntimeError("Fuck this, I'm out...")

        trg_ligand_res = trg_editor.AppendResidue(trg_ligand_chain,
                                                  target_ligand,
                                                  deep=True)
        trg_editor.RenameResidue(trg_ligand_res, "LIG")
        trg_editor.SetResidueNumber(trg_ligand_res, mol.ResNum(1))

        # make the ligand known to the lDDT scorer as custom compound
        compound_name = trg_ligand_res.name
        compound = lddt.CustomCompound.FromResidue(trg_ligand_res)
        custom_compounds = {compound_name: compound}

        scorer = lddt.lDDTScorer(trg_bs,
                                 custom_compounds=custom_compounds,
                                 inclusion_radius=self.lddt_pli_radius)

        chem_groups = [[x for x in g if x in trg_chains]
                       for g in self._chain_mapper.chem_groups]

        self._lddt_pli_target_data[cache_key] = (trg_residues,
                                                 trg_bs,
                                                 trg_chains,
                                                 trg_ligand_chain,
                                                 trg_ligand_res,
                                                 scorer,
                                                 chem_groups)

    return self._lddt_pli_target_data[cache_key]
827 
828 
def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs,
                               ref_bs):
    """ Restrict the chain mapper alignments to binding site residues.

    For each (ref_ch, mdl_ch) pair from matching entries of *chem_groups*
    and *chem_mapping*, the full chain alignment is taken from
    :attr:`_ref_mdl_alns`. Every residue not present in the respective
    binding site (*ref_bs* / *mdl_bs*) is replaced by a gap. The alignment
    length and column layout are preserved.

    Side effect: attaches views to the cached full alignments.

    :param chem_groups: Target chain names per chem group (binding site
                        chains only)
    :param chem_mapping: Model chain names per chem group (binding site
                         chains only)
    :param mdl_bs: Model binding site entity
    :param ref_bs: Target binding site entity
    :returns: :class:`dict` with (ref_ch, mdl_ch) keys and the cut
              alignments as values
    """
    cut_ref_mdl_alns = dict()
    for ref_chem_group, mdl_chem_group in zip(chem_groups, chem_mapping):
        if not mdl_chem_group:
            # no model chain to pair with => nothing to cut in this group
            continue
        for ref_ch in ref_chem_group:

            ref_bs_chain = ref_bs.FindChain(ref_ch)
            query = "cname=" + mol.QueryQuoteName(ref_ch)
            ref_view = self._chain_mapper.target.Select(query)

            for mdl_ch in mdl_chem_group:
                aln = self._ref_mdl_alns[(ref_ch, mdl_ch)]

                aln.AttachView(0, ref_view)

                mdl_bs_chain = mdl_bs.FindChain(mdl_ch)
                query = "cname=" + mol.QueryQuoteName(mdl_ch)
                aln.AttachView(1, self._chain_mapping_mdl.Select(query))

                aln_length = aln.GetLength()
                cut_mdl_seq = ['-'] * aln_length
                cut_ref_seq = ['-'] * aln_length
                for i, col in enumerate(aln):

                    # check ref residue
                    r = col.GetResidue(0)
                    if r.IsValid():
                        bs_r = ref_bs_chain.FindResidue(r.GetNumber())
                        if bs_r.IsValid():
                            cut_ref_seq[i] = col[0]

                    # check mdl residue
                    r = col.GetResidue(1)
                    if r.IsValid():
                        bs_r = mdl_bs_chain.FindResidue(r.GetNumber())
                        if bs_r.IsValid():
                            cut_mdl_seq[i] = col[1]

                cut_aln = seq.CreateAlignment()
                cut_aln.AddSequence(seq.CreateSequence(ref_ch,
                                                       ''.join(cut_ref_seq)))
                cut_aln.AddSequence(seq.CreateSequence(mdl_ch,
                                                       ''.join(cut_mdl_seq)))
                cut_ref_mdl_alns[(ref_ch, mdl_ch)] = cut_aln
    return cut_ref_mdl_alns
873 
@property
def _mappable_atoms(self):
    """ Model atoms that can be mapped to the target, per chain pair

    Lazily computed and cached. For each (ref_cname, mdl_cname) key in
    :attr:`_ref_mdl_alns`, the full target chain and full model chain are
    attached to the alignment. Every model atom whose aligned target residue
    contains an atom of the same name is then recorded.

    Mappable atoms are not stored as hashes but as tuples
    (mdl_r.GetNumber(), mdl_a.GetName()), because one might operate on
    copied EntityHandle objects without corresponding hashes. Given
    c_pair = (ref_cname, mdl_cname), check whether an atom is mappable with::

        (mdl_r.GetNumber(), mdl_a.GetName()) in self._mappable_atoms[c_pair]

    Side effect: views are (re-)attached to the cached alignments in
    :attr:`_ref_mdl_alns`.

    :returns: :class:`dict` mapping (ref_cname, mdl_cname) to a :class:`set`
              of (residue number, atom name) tuples
    """
    if self.__mappable_atoms is None:
        # Build into a local dict and publish it only once complete. The
        # previous version stored an empty dict first and filled it through
        # this property, so an exception halfway through left a partial
        # cache that later accesses returned silently.
        mappable = dict()
        for c_key, aln in self._ref_mdl_alns.items():
            ref_cname, mdl_cname = c_key
            ref_query = f"cname={mol.QueryQuoteName(ref_cname)}"
            mdl_query = f"cname={mol.QueryQuoteName(mdl_cname)}"
            ref_ch = self._chain_mapper.target.Select(ref_query)
            mdl_ch = self._chain_mapping_mdl.Select(mdl_query)
            aln.AttachView(0, ref_ch)
            aln.AttachView(1, mdl_ch)
            atoms = set()
            for col in aln:
                ref_r = col.GetResidue(0)
                mdl_r = col.GetResidue(1)
                if not (ref_r.IsValid() and mdl_r.IsValid()):
                    continue
                for mdl_a in mdl_r.atoms:
                    if ref_r.FindAtom(mdl_a.name).IsValid():
                        atoms.add((mdl_r.GetNumber(), mdl_a.name))
            mappable[c_key] = atoms
        self.__mappable_atoms = mappable
    return self.__mappable_atoms
907 
@property
def _chem_mapping(self):
    """ Chem mapping of the model onto the chain mapper's chem groups

    Lazily computed by calling ``GetChemMapping(self.model)`` on the chain
    mapper. All three return values are cached at once: the chem mapping,
    the chem group alignments and the chain mapping model. As a result,
    :attr:`_chem_group_alns` and :attr:`_chain_mapping_mdl` reuse the same
    computation.

    :returns: First return value of ``GetChemMapping``. It is zipped
              group-wise with the chain mapper's chem groups elsewhere in
              this class.
    """
    if self.__chem_mapping is None:
        # Use the same verbosity sink as _chain_mapping_mdl. Whichever of the
        # three lazy properties triggers GetChemMapping first, the log output
        # should not depend on attribute access order.
        with ligand_scoring_base._SinkVerbosityLevel():
            self.__chem_mapping, self.__chem_group_alns, \
            self.__chain_mapping_mdl = \
            self._chain_mapper.GetChemMapping(self.model)
    return self.__chem_mapping
915 
@property
def _chem_group_alns(self):
    """ Chem group alignments produced alongside the model chem mapping

    Lazily computed by calling ``GetChemMapping(self.model)`` on the chain
    mapper. All three return values are cached at once, so
    :attr:`_chem_mapping` and :attr:`_chain_mapping_mdl` reuse the same
    computation.

    :returns: Second return value of ``GetChemMapping``
    """
    if self.__chem_group_alns is None:
        # Use the same verbosity sink as _chain_mapping_mdl. Whichever of the
        # three lazy properties triggers GetChemMapping first, the log output
        # should not depend on attribute access order.
        with ligand_scoring_base._SinkVerbosityLevel():
            self.__chem_mapping, self.__chem_group_alns, \
            self.__chain_mapping_mdl = \
            self._chain_mapper.GetChemMapping(self.model)
    return self.__chem_group_alns
923 
@property
def _ref_mdl_alns(self):
    """ Pairwise reference/model chain alignments, computed lazily

    Built once by ``chain_mapping._GetRefMdlAlns``. Its inputs are the chain
    mapper's chem groups and chem group alignments, plus
    :attr:`_chem_mapping` and :attr:`_chem_group_alns`. The result is cached.
    Elsewhere in this class it is indexed with (ref_cname, mdl_cname) tuples.
    """
    cached = self.__ref_mdl_alns
    if cached is None:
        mapper = self._chain_mapper
        cached = chain_mapping._GetRefMdlAlns(mapper.chem_groups,
                                             mapper.chem_group_alignments,
                                             self._chem_mapping,
                                             self._chem_group_alns)
        self.__ref_mdl_alns = cached
    return cached
933 
@property
def _chain_mapping_mdl(self):
    """ Model structure used for chain mapping, computed lazily

    Third return value of ``GetChemMapping(self.model)`` on the chain
    mapper. Chains are selected from it by name elsewhere in this class.
    The call runs inside a ``ligand_scoring_base._SinkVerbosityLevel()``
    context, presumably to quiet log output while mapping. All three return
    values are cached together with :attr:`_chem_mapping` and
    :attr:`_chem_group_alns`.
    """
    if self.__chain_mapping_mdl is not None:
        return self.__chain_mapping_mdl
    with ligand_scoring_base._SinkVerbosityLevel():
        mapping_result = self._chain_mapper.GetChemMapping(self.model)
        (self.__chem_mapping,
         self.__chem_group_alns,
         self.__chain_mapping_mdl) = mapping_result
    return self.__chain_mapping_mdl
942 
# Public interface of this module: a star-import exposes only the scorer
# class. All helpers in this module are underscore-prefixed.
__all__ = (
    "LDDTPLIScorer",
)
def __init__(self, model, target, model_ligands=None, target_ligands=None, resnum_alignments=False, rename_ligand_chain=False, substructure_match=False, coverage_delta=0.2, max_symmetries=1e4, lddt_pli_radius=6.0, add_mdl_contacts=True, lddt_pli_thresholds=[0.5, 1.0, 2.0, 4.0], lddt_pli_binding_site_radius=None)
def _lddt_pli_get_trg_data(self, target_ligand, max_r=None)
def _compute_lddt_pli_classic(self, symmetries, target_ligand, model_ligand)
def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains, non_mapped_cache, mdl_bs, mdl_ligand_res, mdl_sym)
def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs, ref_bs)
def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand, model_ligand)
Real DLLEXPORT_OST_GEOM Distance(const Line2 &l, const Vec2 &v)