OpenStructure
ligand_scoring_lddtpli.py
Go to the documentation of this file.
1 import numpy as np
2 
3 from ost import LogWarning, LogInfo
4 from ost import geom
5 from ost import mol
6 from ost import seq
7 
8 from ost.mol.alg import lddt
9 from ost.mol.alg import chain_mapping
10 from ost.mol.alg import ligand_scoring_base
11 
13  """ :class:`LigandScorer` implementing LDDT-PLI.
14 
15  LDDT-PLI is an LDDT score considering contacts between ligand and
16  receptor. The receptor consists of protein and nucleic acid chains that
17  pass the criteria for :class:`chain mapping <ost.mol.alg.chain_mapping>`.
18  This means ignoring other ligands, waters, short polymers as well as any
19  incorrectly connected chains that may be in proximity.
20 
21  :class:`LDDTPLIScorer` computes a score for a specific pair of target/model
22  ligands. Given a target/model ligand pair, all possible mappings of
23  model chains onto their chemically equivalent target chains are enumerated.
24  For each of these enumerations, all possible symmetries, i.e. atom-atom
25  assignments of the ligand as given by :class:`LigandScorer`, are evaluated
26  and an LDDT-PLI score is computed. The best possible LDDT-PLI score is
27  returned.
28 
29  The LDDT-PLI score is a variant of LDDT with a custom inclusion radius
30  (`lddt_pli_radius`), no stereochemistry checks, and which penalizes
31  contacts added in the model within `lddt_pli_radius` by default
32  (can be changed with the `add_mdl_contacts` flag) but only if the involved
33  atoms can be mapped to the target. This is a requirement to
34  1) extract the respective reference distance from the target
35  2) avoid usage of contacts for which we have no experimental evidence.
36  One special case are contacts from chains that are not mapped to the target
37  binding site. It is very well possible that we have experimental evidence
38  for this chain even though it's just too far away from the target binding site.
39  We therefore try to map these contacts to the chain in the target with
40  equivalent sequence that is closest to the target binding site. If the
41  respective atoms can be mapped there, the contact is considered not
42  fulfilled and added as penalty.
43 
44  Populates :attr:`LigandScorer.aux_data` with following :class:`dict` keys:
45 
46  * lddt_pli: The LDDT-PLI score
47  * lddt_pli_n_contacts: Number of contacts considered in LDDT computation
48  * target_ligand: The actual target ligand for which the score was computed
49  * model_ligand: The actual model ligand for which the score was computed
50  * chain_mapping: :class:`dict` with a chain mapping of chains involved in
51  binding site - key: mdl chain name, value: trg chain name
52  * bs_ref_res: :class:`set` of residues with potentially non-zero
53  contribution to score. That is every residue with at least one
54  atom within *lddt_pli_radius* + max(*lddt_pli_thresholds*) of
55  the ligand.
56  * bs_mdl_res: Same for model
57 
58  :param model: Passed to parent constructor - see :class:`LigandScorer`.
59  :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
60  :param target: Passed to parent constructor - see :class:`LigandScorer`.
61  :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
62  :param model_ligands: Passed to parent constructor - see
63  :class:`LigandScorer`.
64  :type model_ligands: :class:`list`
65  :param target_ligands: Passed to parent constructor - see
66  :class:`LigandScorer`.
67  :type target_ligands: :class:`list`
68  :param resnum_alignments: Passed to parent constructor - see
69  :class:`LigandScorer`.
70  :type resnum_alignments: :class:`bool`
71  :param rename_ligand_chain: Passed to parent constructor - see
72  :class:`LigandScorer`.
73  :type rename_ligand_chain: :class:`bool`
74  :param substructure_match: Passed to parent constructor - see
75  :class:`LigandScorer`.
76  :type substructure_match: :class:`bool`
77  :param coverage_delta: Passed to parent constructor - see
78  :class:`LigandScorer`.
79  :type coverage_delta: :class:`float`
80  :param max_symmetries: Passed to parent constructor - see
81  :class:`LigandScorer`.
82  :type max_symmetries: :class:`int`
83  :param lddt_pli_radius: LDDT inclusion radius for LDDT-PLI.
84  :type lddt_pli_radius: :class:`float`
85  :param add_mdl_contacts: Whether to penalize added model contacts.
86  :type add_mdl_contacts: :class:`bool`
87  :param lddt_pli_thresholds: Distance difference thresholds for LDDT.
88  :type lddt_pli_thresholds: :class:`list` of :class:`float`
89  :param lddt_pli_binding_site_radius: Pro param - don't use. Providing a value
90  Restores behaviour from previous
91  implementation that first extracted a
92  binding site with strict distance
93  threshold and computed LDDT-PLI only on
94  those target residues whereas the
95  current implementation includes every
96  atom within *lddt_pli_radius*.
97  :type lddt_pli_binding_site_radius: :class:`float`
98  :param min_pep_length: See :class:`ost.mol.alg.ligand_scoring_base.LigandScorer`.
99  :type min_pep_length: :class:`int`
100  :param min_nuc_length: See :class:`ost.mol.alg.ligand_scoring_base.LigandScorer`
101  :type min_nuc_length: :class:`int`
102  :param pep_seqid_thr: See :class:`ost.mol.alg.ligand_scoring_base.LigandScorer`
103  :type pep_seqid_thr: :class:`float`
104  :param nuc_seqid_thr: See :class:`ost.mol.alg.ligand_scoring_base.LigandScorer`
105  :type nuc_seqid_thr: :class:`float`
106  :param mdl_map_pep_seqid_thr: See :class:`ost.mol.alg.ligand_scoring_base.LigandScorer`
107  :type mdl_map_pep_seqid_thr: :class:`float`
108  :param mdl_map_nuc_seqid_thr: See :class:`ost.mol.alg.ligand_scoring_base.LigandScorer`
109  :type mdl_map_nuc_seqid_thr: :class:`float`
110  """
111 
112  def __init__(self, model, target, model_ligands, target_ligands,
113  resnum_alignments=False, rename_ligand_chain=False,
114  substructure_match=False, coverage_delta=0.2,
115  max_symmetries=1e4, lddt_pli_radius=6.0,
116  add_mdl_contacts=True,
117  lddt_pli_thresholds = [0.5, 1.0, 2.0, 4.0],
118  lddt_pli_binding_site_radius=None,
119  min_pep_length = 6,
120  min_nuc_length = 4, pep_seqid_thr = 95.,
121  nuc_seqid_thr = 95.,
122  mdl_map_pep_seqid_thr = 0.,
123  mdl_map_nuc_seqid_thr = 0.,
124  seqres=None,
125  trg_seqres_mapping=None):
126 
127  super().__init__(model, target, model_ligands, target_ligands,
128  resnum_alignments = resnum_alignments,
129  rename_ligand_chain = rename_ligand_chain,
130  substructure_match = substructure_match,
131  coverage_delta = coverage_delta,
132  max_symmetries = max_symmetries,
133  min_pep_length = min_pep_length,
134  min_nuc_length = min_nuc_length,
135  pep_seqid_thr = pep_seqid_thr,
136  nuc_seqid_thr = nuc_seqid_thr,
137  mdl_map_pep_seqid_thr = mdl_map_pep_seqid_thr,
138  mdl_map_nuc_seqid_thr = mdl_map_nuc_seqid_thr,
139  seqres = seqres,
140  trg_seqres_mapping = trg_seqres_mapping)
141 
142  self.lddt_pli_radiuslddt_pli_radius = lddt_pli_radius
143  self.add_mdl_contactsadd_mdl_contacts = add_mdl_contacts
144  self.lddt_pli_thresholdslddt_pli_thresholds = lddt_pli_thresholds
145  self.lddt_pli_binding_site_radiuslddt_pli_binding_site_radius = lddt_pli_binding_site_radius
146 
147  # lazily precomputed variables to speedup lddt-pli computation
148  self._lddt_pli_target_data_lddt_pli_target_data = dict()
149  self._lddt_pli_model_data_lddt_pli_model_data = dict()
150  self.__mappable_atoms__mappable_atoms = None
151 
152  # update state decoding from parent with subclass specific stuff
153  self.state_decodingstate_decoding[10] = ("no_contact",
154  "There were no LDDT contacts between the "
155  "binding site and the ligand, and LDDT-PLI "
156  "is undefined.")
157  self.state_decodingstate_decoding[20] = ("unknown",
158  "Unknown error occured in LDDTPLIScorer")
159 
160  def _compute(self, symmetries, target_ligand, model_ligand):
161  """ Implements interface from parent
162  """
163  if self.add_mdl_contactsadd_mdl_contacts:
164  LogInfo("Computing LDDT-PLI with added model contacts")
165  result = self._compute_lddt_pli_add_mdl_contacts_compute_lddt_pli_add_mdl_contacts(symmetries,
166  target_ligand,
167  model_ligand)
168  else:
169  LogInfo("Computing LDDT-PLI without added model contacts")
170  result = self._compute_lddt_pli_classic_compute_lddt_pli_classic(symmetries,
171  target_ligand,
172  model_ligand)
173 
174  pair_state = 0
175  score = result["lddt_pli"]
176 
177  if score is None or np.isnan(score):
178  if result["lddt_pli_n_contacts"] == 0:
179  # it's a space ship!
180  pair_state = 10
181  else:
182  # unknwon error state
183  pair_state = 20
184 
185  # the ligands get a zero-state...
186  target_ligand_state = 0
187  model_ligand_state = 0
188 
189  return (score, pair_state, target_ligand_state, model_ligand_state,
190  result)
191 
192  def _score_dir(self):
193  """ Implements interface from parent
194  """
195  return '+'
196 
197  def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand,
198  model_ligand):
199 
200 
203 
204  trg_residues, trg_bs, trg_chains, trg_ligand_chain, \
205  trg_ligand_res, scorer, chem_groups = \
206  self._lddt_pli_get_trg_data_lddt_pli_get_trg_data(target_ligand)
207 
208  trg_bs_center = trg_bs.geometric_center
209 
210  # Copy to make sure that we don't change anything on underlying
211  # references
212  # This is not strictly necessary in the current implementation but
213  # hey, maybe it avoids hard to debug errors when someone changes things
214  ref_indices = [a.copy() for a in scorer.ref_indices_ic]
215  ref_distances = [a.copy() for a in scorer.ref_distances_ic]
216 
217  # distance hacking... remove any interchain distance except the ones
218  # with the ligand
219  ligand_start_idx = scorer.chain_start_indices[-1]
220  for at_idx in range(ligand_start_idx):
221  mask = ref_indices[at_idx] >= ligand_start_idx
222  ref_indices[at_idx] = ref_indices[at_idx][mask]
223  ref_distances[at_idx] = ref_distances[at_idx][mask]
224 
225  mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res, \
226  chem_mapping = self._lddt_pli_get_mdl_data_lddt_pli_get_mdl_data(model_ligand)
227 
228 
231 
232  # ref_mdl_alns refers to full chain mapper trg and mdl structures
233  # => need to adapt mdl sequence that only contain residues in contact
234  # with ligand
235  cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns_lddt_pli_cut_ref_mdl_alns(chem_groups,
236  chem_mapping,
237  mdl_bs, trg_bs)
238 
239 
242 
243  # get each chain mapping that we ever observe in scoring
244  chain_mappings = list(chain_mapping._ChainMappings(chem_groups,
245  chem_mapping))
246 
247  # for each mdl ligand atom, we collect all trg ligand atoms that are
248  # ever mapped onto it given *symmetries*
249  ligand_atom_mappings = [set() for a in mdl_ligand_res.atoms]
250  for (trg_sym, mdl_sym) in symmetries:
251  for trg_i, mdl_i in zip(trg_sym, mdl_sym):
252  ligand_atom_mappings[mdl_i].add(trg_i)
253 
254  mdl_ligand_pos = np.zeros((mdl_ligand_res.GetAtomCount(), 3))
255  for a_idx, a in enumerate(mdl_ligand_res.atoms):
256  p = a.GetPos()
257  mdl_ligand_pos[a_idx, 0] = p[0]
258  mdl_ligand_pos[a_idx, 1] = p[1]
259  mdl_ligand_pos[a_idx, 2] = p[2]
260 
261  trg_ligand_pos = np.zeros((trg_ligand_res.GetAtomCount(), 3))
262  for a_idx, a in enumerate(trg_ligand_res.atoms):
263  p = a.GetPos()
264  trg_ligand_pos[a_idx, 0] = p[0]
265  trg_ligand_pos[a_idx, 1] = p[1]
266  trg_ligand_pos[a_idx, 2] = p[2]
267 
268  mdl_lig_hashes = [a.hash_code for a in mdl_ligand_res.atoms]
269 
270  symmetric_atoms = np.asarray(sorted(list(scorer.symmetric_atoms)),
271  dtype=np.int64)
272 
273  # two caches to cache things for each chain mapping => lists
274  # of len(chain_mappings)
275  #
276  # In principle we're caching for each trg/mdl ligand atom pair all
277  # information to update ref_indices/ref_distances and resolving the
278  # symmetries of the binding site.
279  # in detail: each list entry in *scoring_cache* is a dict with
280  # key: (mdl_lig_at_idx, trg_lig_at_idx)
281  # value: tuple with 4 elements - 1: indices of atoms representing added
282  # contacts relative to overall inexing scheme in scorer 2: the
283  # respective distances 3: the same but only containing indices towards
284  # atoms of the binding site that are considered symmetric 4: the
285  # respective indices.
286  # each list entry in *penalty_cache* is a list of len N mdl lig atoms.
287  # For each mdl lig at it contains a penalty for this mdl lig at. That
288  # means the number of contacts in the mdl binding site that can
289  # directly be mapped to the target given the local chain mapping but
290  # are not present in the target binding site, i.e. interacting atoms are
291  # too far away.
292  scoring_cache = list()
293  penalty_cache = list()
294 
295  for mapping in chain_mappings:
296 
297  # flat mapping with mdl chain names as key
298  flat_mapping = dict()
299  for trg_chem_group, mdl_chem_group in zip(chem_groups, mapping):
300  for a,b in zip(trg_chem_group, mdl_chem_group):
301  if a is not None and b is not None:
302  flat_mapping[b] = a
303 
304  # for each mdl bs atom (as atom hash), the trg bs atoms (as index in
305  # scorer)
306  bs_atom_mapping = dict()
307  for mdl_cname, ref_cname in flat_mapping.items():
308  aln = cut_ref_mdl_alns[(ref_cname, mdl_cname)]
309  ref_ch = trg_bs.Select(f"cname={mol.QueryQuoteName(ref_cname)}")
310  mdl_ch = mdl_bs.Select(f"cname={mol.QueryQuoteName(mdl_cname)}")
311  aln.AttachView(0, ref_ch)
312  aln.AttachView(1, mdl_ch)
313  for col in aln:
314  ref_r = col.GetResidue(0)
315  mdl_r = col.GetResidue(1)
316  if ref_r.IsValid() and mdl_r.IsValid():
317  for mdl_a in mdl_r.atoms:
318  ref_a = ref_r.FindAtom(mdl_a.GetName())
319  if ref_a.IsValid():
320  ref_h = ref_a.handle.hash_code
321  if ref_h in scorer.atom_indices:
322  mdl_h = mdl_a.handle.hash_code
323  bs_atom_mapping[mdl_h] = \
324  scorer.atom_indices[ref_h]
325 
326  cache = dict()
327  n_penalties = list()
328 
329  for mdl_a_idx, mdl_a in enumerate(mdl_ligand_res.atoms):
330  n_penalty = 0
331  trg_bs_indices = list()
332  close_a = mdl_bs.FindWithin(mdl_a.GetPos(),
333  self.lddt_pli_radiuslddt_pli_radius)
334  for a in close_a:
335  mdl_a_hash_code = a.hash_code
336  if mdl_a_hash_code in bs_atom_mapping:
337  trg_bs_indices.append(bs_atom_mapping[mdl_a_hash_code])
338  elif mdl_a_hash_code not in mdl_lig_hashes:
339  if a.GetChain().GetName() in flat_mapping:
340  # Its in a mapped chain
341  at_key = (a.GetResidue().GetNumber(), a.name)
342  cname = a.GetChain().name
343  cname_key = (flat_mapping[cname], cname)
344  if at_key in self._mappable_atoms_mappable_atoms[cname_key]:
345  # Its a contact in the model but not part of
346  # trg_bs. It can still be mapped using the
347  # global mdl_ch/ref_ch alignment
348  # d in ref > self.lddt_pli_radius + max_thresh
349  # => guaranteed to be non-fulfilled contact
350  n_penalty += 1
351 
352  n_penalties.append(n_penalty)
353 
354  trg_bs_indices = np.asarray(sorted(trg_bs_indices))
355 
356  for trg_a_idx in ligand_atom_mappings[mdl_a_idx]:
357  # mask selects entries in trg_bs_indices that are not yet
358  # part of classic LDDT ref_indices for atom at trg_a_idx
359  # => added mdl contacts
360  mask = np.isin(trg_bs_indices,
361  ref_indices[ligand_start_idx + trg_a_idx],
362  assume_unique=True, invert=True)
363  added_indices = np.asarray([], dtype=np.int64)
364  added_distances = np.asarray([], dtype=np.float64)
365  if np.sum(mask) > 0:
366  # compute ref distances on reference positions
367  added_indices = trg_bs_indices[mask]
368  tmp = scorer.positions.take(added_indices, axis=0)
369  np.subtract(tmp, trg_ligand_pos[trg_a_idx][None, :],
370  out=tmp)
371  np.square(tmp, out=tmp)
372  tmp = tmp.sum(axis=1)
373  np.sqrt(tmp, out=tmp)
374  added_distances = tmp
375 
376  # extract the distances towards bs atoms that are symmetric
377  sym_mask = np.isin(added_indices, symmetric_atoms,
378  assume_unique=True)
379 
380  cache[(mdl_a_idx, trg_a_idx)] = (added_indices,
381  added_distances,
382  added_indices[sym_mask],
383  added_distances[sym_mask])
384 
385  scoring_cache.append(cache)
386  penalty_cache.append(n_penalties)
387 
388  # cache for model contacts towards non mapped trg chains - this is
389  # relevant for self._lddt_pli_unmapped_chain_penalty
390  # key: tuple in form (trg_ch, mdl_ch)
391  # value: yet another dict with
392  # key: ligand_atom_hash
393  # value: n contacts towards respective trg chain that can be mapped
394  non_mapped_cache = dict()
395 
396 
399 
400  best_score = -1.0
401  best_result = {"lddt_pli": None,
402  "lddt_pli_n_contacts": 0,
403  "chain_mapping": None}
404 
405  # dummy alignment for ligand chains which is needed as input later on
406  ligand_aln = seq.CreateAlignment()
407  trg_s = seq.CreateSequence(trg_ligand_chain.name,
408  trg_ligand_res.GetOneLetterCode())
409  mdl_s = seq.CreateSequence(mdl_ligand_chain.name,
410  mdl_ligand_res.GetOneLetterCode())
411  ligand_aln.AddSequence(trg_s)
412  ligand_aln.AddSequence(mdl_s)
413  ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms))
414 
415  sym_idx_collector = [None] * scorer.n_atoms
416  sym_dist_collector = [None] * scorer.n_atoms
417 
418  for mapping, s_cache, p_cache in zip(chain_mappings, scoring_cache,
419  penalty_cache):
420 
421  lddt_chain_mapping = dict()
422  lddt_alns = dict()
423  for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
424  for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
425  # some mdl chains can be None
426  if mdl_ch is not None:
427  lddt_chain_mapping[mdl_ch] = ref_ch
428  lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]
429 
430  # add ligand to lddt_chain_mapping/lddt_alns
431  lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
432  lddt_alns[mdl_ligand_chain.name] = ligand_aln
433 
434  # already process model, positions will be manually hacked for each
435  # symmetry - small overhead for variables that are thrown away here
436  pos, _, _, _, _, _, lddt_symmetries = \
437  scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
438  residue_mapping = lddt_alns,
439  nirvana_dist = self.lddt_pli_radiuslddt_pli_radius + max(self.lddt_pli_thresholdslddt_pli_thresholds),
440  check_resnames = False)
441 
442  # estimate a penalty for unsatisfied model contacts from chains
443  # that are not in the local trg binding site, but can be mapped in
444  # the target.
445  # We're using the trg chain with the closest geometric center to
446  # the trg binding site that can be mapped to the mdl chain
447  # according the chem mapping. An alternative would be to search for
448  # the target chain with the minimal number of additional contacts.
449  # There is not good solution for this problem...
450  unmapped_chains = list()
451  already_mapped = set()
452  for mdl_ch in mdl_chains:
453  if mdl_ch not in lddt_chain_mapping:
454 
455  if mdl_ch in self._mdl_chains_without_chem_mapping_mdl_chains_without_chem_mapping:
456  # this mdl chain does not map to any trg chain
457  continue
458 
459  # check which chain in trg is closest
460  chem_grp_idx = None
461  for i, m in enumerate(self._chem_mapping_chem_mapping):
462  if mdl_ch in m:
463  chem_grp_idx = i
464  break
465  if chem_grp_idx is None:
466  raise RuntimeError("This should never happen... "
467  "ask Gabriel...")
468  closest_ch = None
469  closest_dist = None
470  for trg_ch in self._chain_mapper_chain_mapper.chem_groups[chem_grp_idx]:
471  if trg_ch not in lddt_chain_mapping.values():
472  if trg_ch not in already_mapped:
473  ch = self._chain_mapper_chain_mapper.target.FindChain(trg_ch)
474  c = ch.geometric_center
475  d = geom.Distance(trg_bs_center, c)
476  if closest_dist is None or d < closest_dist:
477  closest_dist = d
478  closest_ch = trg_ch
479  if closest_ch is not None:
480  unmapped_chains.append((closest_ch, mdl_ch))
481  already_mapped.add(closest_ch)
482 
483  for (trg_sym, mdl_sym) in symmetries:
484 
485  # update positions
486  for mdl_i, trg_i in zip(mdl_sym, trg_sym):
487  pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]
488 
489  # start new ref_indices/ref_distances from original values
490  funky_ref_indices = [np.copy(a) for a in ref_indices]
491  funky_ref_distances = [np.copy(a) for a in ref_distances]
492 
493  # The only distances from the binding site towards the ligand
494  # we care about are the ones from the symmetric atoms to
495  # correctly compute scorer._ResolveSymmetries.
496  # We collect them while updating distances from added mdl
497  # contacts
498  for idx in symmetric_atoms:
499  sym_idx_collector[idx] = list()
500  sym_dist_collector[idx] = list()
501 
502  # add data from added mdl contacts cache
503  added_penalty = 0
504  for mdl_i, trg_i in zip(mdl_sym, trg_sym):
505  added_penalty += p_cache[mdl_i]
506  cache = s_cache[mdl_i, trg_i]
507  full_trg_i = ligand_start_idx + trg_i
508  funky_ref_indices[full_trg_i] = \
509  np.append(funky_ref_indices[full_trg_i], cache[0])
510  funky_ref_distances[full_trg_i] = \
511  np.append(funky_ref_distances[full_trg_i], cache[1])
512  for idx, d in zip(cache[2], cache[3]):
513  sym_idx_collector[idx].append(full_trg_i)
514  sym_dist_collector[idx].append(d)
515 
516  for idx in symmetric_atoms:
517  funky_ref_indices[idx] = \
518  np.append(funky_ref_indices[idx],
519  np.asarray(sym_idx_collector[idx],
520  dtype=np.int64))
521  funky_ref_distances[idx] = \
522  np.append(funky_ref_distances[idx],
523  np.asarray(sym_dist_collector[idx],
524  dtype=np.float64))
525 
526  # we can pass funky_ref_indices/funky_ref_distances as
527  # sym_ref_indices/sym_ref_distances in
528  # scorer._ResolveSymmetries as we only have distances of the bs
529  # to the ligand and ligand atoms are "non-symmetric"
530  scorer._ResolveSymmetries(pos, self.lddt_pli_thresholdslddt_pli_thresholds,
531  lddt_symmetries,
532  funky_ref_indices,
533  funky_ref_distances)
534 
535  N = sum([len(funky_ref_indices[i]) for i in ligand_at_indices])
536  N += added_penalty
537 
538  # collect number of expected contacts which can be mapped
539  if len(unmapped_chains) > 0:
540  N += self._lddt_pli_unmapped_chain_penalty_lddt_pli_unmapped_chain_penalty(unmapped_chains,
541  non_mapped_cache,
542  mdl_bs,
543  mdl_ligand_res,
544  mdl_sym)
545 
546  conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
547  self.lddt_pli_thresholdslddt_pli_thresholds,
548  funky_ref_indices,
549  funky_ref_distances),
550  axis=0)
551  score = None
552  if N > 0:
553  score = np.mean(conserved/N)
554 
555  if score is not None and score > best_score:
556  best_score = score
557  save_chain_mapping = dict(lddt_chain_mapping)
558  del save_chain_mapping[mdl_ligand_chain.name]
559  best_result = {"lddt_pli": score,
560  "lddt_pli_n_contacts": N,
561  "chain_mapping": save_chain_mapping}
562 
563  # fill misc info to result object
564  best_result["target_ligand"] = target_ligand
565  best_result["model_ligand"] = model_ligand
566  best_result["bs_ref_res"] = trg_residues
567  best_result["bs_mdl_res"] = mdl_residues
568 
569  return best_result
570 
571 
572  def _compute_lddt_pli_classic(self, symmetries, target_ligand,
573  model_ligand):
574 
575 
578 
579  max_r = None
580  if self.lddt_pli_binding_site_radiuslddt_pli_binding_site_radius:
581  max_r = self.lddt_pli_binding_site_radiuslddt_pli_binding_site_radius
582 
583  trg_residues, trg_bs, trg_chains, trg_ligand_chain, \
584  trg_ligand_res, scorer, chem_groups = \
585  self._lddt_pli_get_trg_data_lddt_pli_get_trg_data(target_ligand, max_r = max_r)
586 
587  # Copy to make sure that we don't change anything on underlying
588  # references
589  # This is not strictly necessary in the current implementation but
590  # hey, maybe it avoids hard to debug errors when someone changes things
591  ref_indices = [a.copy() for a in scorer.ref_indices_ic]
592  ref_distances = [a.copy() for a in scorer.ref_distances_ic]
593 
594  # no matter what mapping/symmetries, the number of expected
595  # contacts stays the same
596  ligand_start_idx = scorer.chain_start_indices[-1]
597  ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms))
598  n_exp = sum([len(ref_indices[i]) for i in ligand_at_indices])
599 
600  mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res, \
601  chem_mapping = self._lddt_pli_get_mdl_data_lddt_pli_get_mdl_data(model_ligand)
602 
603  if n_exp == 0:
604  # no contacts... nothing to compute...
605  return {"lddt_pli": None,
606  "lddt_pli_n_contacts": 0,
607  "chain_mapping": None,
608  "target_ligand": target_ligand,
609  "model_ligand": model_ligand,
610  "bs_ref_res": trg_residues,
611  "bs_mdl_res": mdl_residues}
612 
613  # Distance hacking... remove any interchain distance except the ones
614  # with the ligand
615  for at_idx in range(ligand_start_idx):
616  mask = ref_indices[at_idx] >= ligand_start_idx
617  ref_indices[at_idx] = ref_indices[at_idx][mask]
618  ref_distances[at_idx] = ref_distances[at_idx][mask]
619 
620 
623 
624  # ref_mdl_alns refers to full chain mapper trg and mdl structures
625  # => need to adapt mdl sequence that only contain residues in contact
626  # with ligand
627  cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns_lddt_pli_cut_ref_mdl_alns(chem_groups,
628  chem_mapping,
629  mdl_bs, trg_bs)
630 
631 
634 
635  best_score = -1.0
636 
637  # dummy alignment for ligand chains which is needed as input later on
638  l_aln = seq.CreateAlignment()
639  l_aln.AddSequence(seq.CreateSequence(trg_ligand_chain.name,
640  trg_ligand_res.GetOneLetterCode()))
641  l_aln.AddSequence(seq.CreateSequence(mdl_ligand_chain.name,
642  mdl_ligand_res.GetOneLetterCode()))
643 
644  mdl_ligand_pos = np.zeros((model_ligand.GetAtomCount(), 3))
645  for a_idx, a in enumerate(model_ligand.atoms):
646  p = a.GetPos()
647  mdl_ligand_pos[a_idx, 0] = p[0]
648  mdl_ligand_pos[a_idx, 1] = p[1]
649  mdl_ligand_pos[a_idx, 2] = p[2]
650 
651  for mapping in chain_mapping._ChainMappings(chem_groups, chem_mapping):
652 
653  lddt_chain_mapping = dict()
654  lddt_alns = dict()
655  for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
656  for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
657  # some mdl chains can be None
658  if mdl_ch is not None:
659  lddt_chain_mapping[mdl_ch] = ref_ch
660  lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]
661 
662  # add ligand to lddt_chain_mapping/lddt_alns
663  lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
664  lddt_alns[mdl_ligand_chain.name] = l_aln
665 
666  # already process model, positions will be manually hacked for each
667  # symmetry - small overhead for variables that are thrown away here
668  pos, _, _, _, _, _, lddt_symmetries = \
669  scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
670  residue_mapping = lddt_alns,
671  nirvana_dist = self.lddt_pli_radiuslddt_pli_radius + max(self.lddt_pli_thresholdslddt_pli_thresholds),
672  check_resnames = False)
673 
674  for (trg_sym, mdl_sym) in symmetries:
675  for mdl_i, trg_i in zip(mdl_sym, trg_sym):
676  pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]
677  # we can pass ref_indices/ref_distances as
678  # sym_ref_indices/sym_ref_distances in
679  # scorer._ResolveSymmetries as we only have distances of the bs
680  # to the ligand and ligand atoms are "non-symmetric"
681  scorer._ResolveSymmetries(pos, self.lddt_pli_thresholdslddt_pli_thresholds,
682  lddt_symmetries,
683  ref_indices,
684  ref_distances)
685  # compute number of conserved distances for ligand atoms
686  conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
687  self.lddt_pli_thresholdslddt_pli_thresholds,
688  ref_indices,
689  ref_distances), axis=0)
690  score = np.mean(conserved/n_exp)
691 
692  if score > best_score:
693  best_score = score
694  save_chain_mapping = dict(lddt_chain_mapping)
695  del save_chain_mapping[mdl_ligand_chain.name]
696  best_result = {"lddt_pli": score,
697  "chain_mapping": save_chain_mapping}
698 
699  # fill misc info to result object
700  best_result["lddt_pli_n_contacts"] = n_exp
701  best_result["target_ligand"] = target_ligand
702  best_result["model_ligand"] = model_ligand
703  best_result["bs_ref_res"] = trg_residues
704  best_result["bs_mdl_res"] = mdl_residues
705 
706  return best_result
707 
708  def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains,
709  non_mapped_cache,
710  mdl_bs,
711  mdl_ligand_res,
712  mdl_sym):
713 
714  n_exp = 0
715  for ch_tuple in unmapped_chains:
716  if ch_tuple not in non_mapped_cache:
717  # for each ligand atom, we count the number of mappable atoms
718  # within lddt_pli_radius
719  counts = dict()
720  # the select statement also excludes the ligand in mdl_bs
721  # as it resides in a separate chain
722  mdl_cname = ch_tuple[1]
723  query = "cname=" + mol.QueryQuoteName(mdl_cname)
724  mdl_bs_ch = mdl_bs.Select(query)
725  for a in mdl_ligand_res.atoms:
726  close_atoms = \
727  mdl_bs_ch.FindWithin(a.GetPos(), self.lddt_pli_radiuslddt_pli_radius)
728  N = 0
729  for close_a in close_atoms:
730  at_key = (close_a.GetResidue().GetNumber(),
731  close_a.GetName())
732  if at_key in self._mappable_atoms_mappable_atoms[ch_tuple]:
733  N += 1
734  counts[a.hash_code] = N
735 
736  # fill cache
737  non_mapped_cache[ch_tuple] = counts
738 
739  # add number of mdl contacts which can be mapped to target
740  # as non-fulfilled contacts
741  counts = non_mapped_cache[ch_tuple]
742  lig_hash_codes = [a.hash_code for a in mdl_ligand_res.atoms]
743  for i in mdl_sym:
744  n_exp += counts[lig_hash_codes[i]]
745 
746  return n_exp
747 
748 
749  def _lddt_pli_get_mdl_data(self, model_ligand):
750  if model_ligand not in self._lddt_pli_model_data_lddt_pli_model_data:
751 
752  mdl = self._chain_mapping_mdl_chain_mapping_mdl
753 
754  mdl_residues = set()
755  for at in model_ligand.atoms:
756  close_atoms = mdl.FindWithin(at.GetPos(), self.lddt_pli_radiuslddt_pli_radius)
757  for close_at in close_atoms:
758  mdl_residues.add(close_at.GetResidue())
759 
760  max_r = self.lddt_pli_radiuslddt_pli_radius + max(self.lddt_pli_thresholdslddt_pli_thresholds)
761  for r in mdl.residues:
762  r.SetIntProp("bs", 0)
763  for at in model_ligand.atoms:
764  close_atoms = mdl.FindWithin(at.GetPos(), max_r)
765  for close_at in close_atoms:
766  close_at.GetResidue().SetIntProp("bs", 1)
767 
768  mdl_bs = mol.CreateEntityFromView(mdl.Select("grbs:0=1"), True)
769  mdl_chains = set([ch.name for ch in mdl_bs.chains])
770 
771  mdl_editor = mdl_bs.EditXCS(mol.BUFFERED_EDIT)
772  mdl_ligand_chain = None
773  for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
774  try:
775  # I'm pretty sure, one of these chain names is not there...
776  mdl_ligand_chain = mdl_editor.InsertChain(cname)
777  break
778  except:
779  pass
780  if mdl_ligand_chain is None:
781  raise RuntimeError("Fuck this, I'm out...")
782  mdl_ligand_res = mdl_editor.AppendResidue(mdl_ligand_chain,
783  model_ligand,
784  deep=True)
785  mdl_editor.RenameResidue(mdl_ligand_res, "LIG")
786  mdl_editor.SetResidueNumber(mdl_ligand_res, mol.ResNum(1))
787 
788  chem_mapping = list()
789  for m in self._chem_mapping_chem_mapping:
790  chem_mapping.append([x for x in m if x in mdl_chains])
791 
792  self._lddt_pli_model_data_lddt_pli_model_data[model_ligand] = (mdl_residues,
793  mdl_bs,
794  mdl_chains,
795  mdl_ligand_chain,
796  mdl_ligand_res,
797  chem_mapping)
798 
799  return self._lddt_pli_model_data_lddt_pli_model_data[model_ligand]
800 
801 
def _lddt_pli_get_trg_data(self, target_ligand, max_r=None):
    """ Compute (and cache) the target-side data for one target ligand

    The target binding site consists of all residues of the chain mapper
    target with at least one atom within `max_r` of any atom of
    `target_ligand`. An entity copy of that binding site is created, a new
    chain is inserted into it and `target_ligand` is appended to that chain
    as residue "LIG" with residue number 1. An :class:`lddt.lDDTScorer`
    using the ligand as custom compound and `lddt_pli_radius` as inclusion
    radius is set up on the copy.

    Results are cached in ``self._lddt_pli_target_data`` keyed by
    `target_ligand` only. `max_r` is therefore only used when the ligand is
    not cached yet; a later call with a different `max_r` returns the
    cached data unchanged.

    Side effect: the int property "bs" of every residue of the chain mapper
    target is overwritten (1 for binding site residues, 0 otherwise).

    :param target_ligand: Target ligand residue
    :param max_r: Distance cutoff defining the binding site. Defaults to
                  ``lddt_pli_radius + max(lddt_pli_thresholds)``.
    :returns: Tuple (trg_residues, trg_bs, trg_chains, trg_ligand_chain,
              trg_ligand_res, scorer, chem_groups). trg_chains is the set
              of chain names in trg_bs (before the ligand chain is added),
              chem_groups are the chain mapper chem groups restricted to
              those chain names.
    :raises RuntimeError: if no chain to hold the ligand can be inserted.
    """
    cache = self._lddt_pli_target_data
    if target_ligand not in cache:

        trg = self._chain_mapper.target

        if max_r is None:
            max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds)

        # all target residues with at least one atom within max_r of the
        # ligand
        trg_residues = set()
        for at in target_ligand.atoms:
            for close_at in trg.FindWithin(at.GetPos(), max_r):
                trg_residues.add(close_at.GetResidue())

        # flag binding site residues so they can be selected in one query
        for r in trg.residues:
            r.SetIntProp("bs", 0)
        for r in trg_residues:
            r.SetIntProp("bs", 1)

        trg_bs = mol.CreateEntityFromView(trg.Select("grbs:0=1"), True)
        trg_chains = {ch.name for ch in trg_bs.chains}

        trg_editor = trg_bs.EditXCS(mol.BUFFERED_EDIT)
        trg_ligand_chain = None
        for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
            try:
                # Presumably InsertChain raises if the name is already
                # taken; that's why a second candidate name is tried.
                trg_ligand_chain = trg_editor.InsertChain(cname)
                break
            except Exception:
                # try the next candidate name; do not swallow
                # KeyboardInterrupt / SystemExit
                pass
        if trg_ligand_chain is None:
            raise RuntimeError("Failed to insert a chain to hold the target "
                               "ligand in the target binding site entity")

        trg_ligand_res = trg_editor.AppendResidue(trg_ligand_chain,
                                                  target_ligand,
                                                  deep=True)
        trg_editor.RenameResidue(trg_ligand_res, "LIG")
        trg_editor.SetResidueNumber(trg_ligand_res, mol.ResNum(1))

        # the ligand is not a standard compound, pass it explicitly
        compound_name = trg_ligand_res.name
        compound = lddt.CustomCompound.FromResidue(trg_ligand_res)
        custom_compounds = {compound_name: compound}

        scorer = lddt.lDDTScorer(trg_bs,
                                 custom_compounds=custom_compounds,
                                 inclusion_radius=self.lddt_pli_radius)

        chem_groups = [[x for x in g if x in trg_chains]
                       for g in self._chain_mapper.chem_groups]

        cache[target_ligand] = (trg_residues,
                                trg_bs,
                                trg_chains,
                                trg_ligand_chain,
                                trg_ligand_res,
                                scorer,
                                chem_groups)

    return cache[target_ligand]
864 
865 
def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs,
                               ref_bs):
    """ Cut reference/model alignments down to binding site residues

    For every target chain name in `chem_groups` and every model chain name
    in the matching (zipped) entry of `chem_mapping`, the stored alignment
    ``self._ref_mdl_alns[(ref_ch, mdl_ch)]`` gets the full target and model
    chain views attached (this modifies the stored alignment). A new
    two-sequence alignment of the same length is then built. In it, every
    position whose residue is missing, or has no residue with the same
    residue number in the corresponding chain of `ref_bs` (first sequence)
    or `mdl_bs` (second sequence), is replaced by '-'.

    :param chem_groups: Lists of target chain names
    :param chem_mapping: Lists of model chain names, zipped with
                         `chem_groups`
    :param mdl_bs: Model entity whose residues define the kept model
                   positions
    :param ref_bs: Target entity whose residues define the kept target
                   positions
    :returns: dict mapping (ref_ch, mdl_ch) to the cut alignment. Its
              sequences are named ref_ch and mdl_ch.
    """

    def in_bs(column, idx, bs_chain):
        # True if sequence idx of column has a residue that is also present
        # (by residue number) in bs_chain
        res = column.GetResidue(idx)
        return res.IsValid() and bs_chain.FindResidue(res.GetNumber()).IsValid()

    result = dict()
    for trg_group, mdl_group in zip(chem_groups, chem_mapping):
        for ref_cname in trg_group:

            ref_bs_ch = ref_bs.FindChain(ref_cname)
            ref_sel = self._chain_mapper.target.Select(
                "cname=" + mol.QueryQuoteName(ref_cname))

            for mdl_cname in mdl_group:
                full_aln = self._ref_mdl_alns[(ref_cname, mdl_cname)]
                full_aln.AttachView(0, ref_sel)

                mdl_bs_ch = mdl_bs.FindChain(mdl_cname)
                mdl_sel = self._chain_mapping_mdl.Select(
                    "cname=" + mol.QueryQuoteName(mdl_cname))
                full_aln.AttachView(1, mdl_sel)

                n_cols = full_aln.GetLength()
                ref_chars = ['-'] * n_cols
                mdl_chars = ['-'] * n_cols
                for col_idx, column in enumerate(full_aln):
                    if in_bs(column, 0, ref_bs_ch):
                        ref_chars[col_idx] = column[0]
                    if in_bs(column, 1, mdl_bs_ch):
                        mdl_chars[col_idx] = column[1]

                cut_aln = seq.CreateAlignment()
                cut_aln.AddSequence(
                    seq.CreateSequence(ref_cname, ''.join(ref_chars)))
                cut_aln.AddSequence(
                    seq.CreateSequence(mdl_cname, ''.join(mdl_chars)))
                result[(ref_cname, mdl_cname)] = cut_aln

    return result
910 
@property
def _mappable_atoms(self):
    """ Stores mappable atoms given a chain mapping

    For each (ref_cname, mdl_cname) pair in ``self._ref_mdl_alns``, store
    all model atoms that can be mapped. A model atom is mappable if it sits
    in an aligned column with a valid target residue that has an atom of
    the same name. Mappable atoms are not stored as hashes but as tuples
    (mdl_r.GetNumber(), mdl_a.GetName()). The reason is that one might
    operate on copied EntityHandle objects without corresponding hashes.

    Given a tuple c_pair = (ref_cname, mdl_cname), one can check whether a
    certain atom is mappable with:
    ``(mdl_r.GetNumber(), mdl_a.GetName()) in self._mappable_atoms[c_pair]``

    The result is computed on first access and cached. Computing it
    attaches the full target/model chain views to the alignments stored in
    ``self._ref_mdl_alns``.

    :returns: dict mapping (ref_cname, mdl_cname) to a set of
              (residue number, atom name) tuples.
    """
    if self.__mappable_atoms is None:
        # Build into a local dict and publish it only once complete, so an
        # exception half-way through cannot leave a partially filled cache
        # that later accesses would silently return.
        mappable = dict()
        for c_key, aln in self._ref_mdl_alns.items():
            ref_cname, mdl_cname = c_key
            ref_query = f"cname={mol.QueryQuoteName(ref_cname)}"
            mdl_query = f"cname={mol.QueryQuoteName(mdl_cname)}"
            ref_ch = self._chain_mapper.target.Select(ref_query)
            mdl_ch = self._chain_mapping_mdl.Select(mdl_query)
            aln.AttachView(0, ref_ch)
            aln.AttachView(1, mdl_ch)
            atoms = set()
            for col in aln:
                ref_r = col.GetResidue(0)
                mdl_r = col.GetResidue(1)
                if not (ref_r.IsValid() and mdl_r.IsValid()):
                    continue
                for mdl_a in mdl_r.atoms:
                    if ref_r.FindAtom(mdl_a.name).IsValid():
                        atoms.add((mdl_r.GetNumber(), mdl_a.name))
            mappable[c_key] = atoms
        self.__mappable_atoms = mappable

    return self.__mappable_atoms
944 
# Public interface of this module: names exported by a star import
__all__ = ("LDDTPLIScorer",)
def __init__(self, model, target, model_ligands, target_ligands, resnum_alignments=False, rename_ligand_chain=False, substructure_match=False, coverage_delta=0.2, max_symmetries=1e4, lddt_pli_radius=6.0, add_mdl_contacts=True, lddt_pli_thresholds=[0.5, 1.0, 2.0, 4.0], lddt_pli_binding_site_radius=None, min_pep_length=6, min_nuc_length=4, pep_seqid_thr=95., nuc_seqid_thr=95., mdl_map_pep_seqid_thr=0., mdl_map_nuc_seqid_thr=0., seqres=None, trg_seqres_mapping=None)
def _lddt_pli_get_trg_data(self, target_ligand, max_r=None)
def _compute_lddt_pli_classic(self, symmetries, target_ligand, model_ligand)
def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains, non_mapped_cache, mdl_bs, mdl_ligand_res, mdl_sym)
def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs, ref_bs)
def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand, model_ligand)
Real DLLEXPORT_OST_GEOM Distance(const Line2 &l, const Vec2 &v)