OpenStructure
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
ligand_scoring_lddtpli.py
Go to the documentation of this file.
1 import numpy as np
2 
3 from ost import LogWarning, LogInfo
4 from ost import geom
5 from ost import mol
6 from ost import seq
7 
8 from ost.mol.alg import lddt
9 from ost.mol.alg import chain_mapping
10 from ost.mol.alg import ligand_scoring_base
11 
13  """ :class:`LigandScorer` implementing LDDT-PLI.
14 
15  LDDT-PLI is an LDDT score considering contacts between ligand and
16  receptor. Where receptor consists of protein and nucleic acid chains that
17  pass the criteria for :class:`chain mapping <ost.mol.alg.chain_mapping>`.
18  This means ignoring other ligands, waters, short polymers as well as any
19  incorrectly connected chains that may be in proximity.
20 
21  :class:`LDDTPLIScorer` computes a score for a specific pair of target/model
22  ligands. Given a target/model ligand pair, all possible mappings of
23  model chains onto their chemically equivalent target chains are enumerated.
24  For each of these enumerations, all possible symmetries, i.e. atom-atom
25  assignments of the ligand as given by :class:`LigandScorer`, are evaluated
26  and an LDDT-PLI score is computed. The best possible LDDT-PLI score is
27  returned.
28 
29  The LDDT-PLI score is a variant of LDDT with a custom inclusion radius
30  (`lddt_pli_radius`), no stereochemistry checks, and which penalizes
31  contacts added in the model within `lddt_pli_radius` by default
32  (can be changed with the `add_mdl_contacts` flag) but only if the involved
33  atoms can be mapped to the target. This is a requirement to
34  1) extract the respective reference distance from the target
35  2) avoid usage of contacts for which we have no experimental evidence.
36  One special case are contacts from chains that are not mapped to the target
37  binding site. It is very well possible that we have experimental evidence
38  for this chain though it's just too far away from the target binding site.
39  We therefore try to map these contacts to the chain in the target with
40  equivalent sequence that is closest to the target binding site. If the
41  respective atoms can be mapped there, the contact is considered not
42  fulfilled and added as penalty.
43 
44  Populates :attr:`LigandScorer.aux_data` with following :class:`dict` keys:
45 
46  * lddt_pli: The LDDT-PLI score
47  * lddt_pli_n_contacts: Number of contacts considered in LDDT computation
48  * target_ligand: The actual target ligand for which the score was computed
49  * model_ligand: The actual model ligand for which the score was computed
50  * chain_mapping: :class:`dict` with a chain mapping of chains involved in
51  binding site - key: trg chain name, value: mdl chain name
52  * bs_ref_res: :class:`set` of residues with potentially non-zero
53  contribution to score. That is every residue with at least one
54  atom within *lddt_pli_radius* + max(*lddt_pli_thresholds*) of
55  the ligand.
56  * bs_mdl_res: Same for model
57 
58  :param model: Passed to parent constructor - see :class:`LigandScorer`.
59  :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
60  :param target: Passed to parent constructor - see :class:`LigandScorer`.
61  :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
62  :param model_ligands: Passed to parent constructor - see
63  :class:`LigandScorer`.
64  :type model_ligands: :class:`list`
65  :param target_ligands: Passed to parent constructor - see
66  :class:`LigandScorer`.
67  :type target_ligands: :class:`list`
68  :param resnum_alignments: Passed to parent constructor - see
69  :class:`LigandScorer`.
70  :type resnum_alignments: :class:`bool`
71  :param rename_ligand_chain: Passed to parent constructor - see
72  :class:`LigandScorer`.
73  :type rename_ligand_chain: :class:`bool`
74  :param substructure_match: Passed to parent constructor - see
75  :class:`LigandScorer`.
76  :type substructure_match: :class:`bool`
77  :param coverage_delta: Passed to parent constructor - see
78  :class:`LigandScorer`.
79  :type coverage_delta: :class:`float`
80  :param max_symmetries: Passed to parent constructor - see
81  :class:`LigandScorer`.
82  :type max_symmetries: :class:`int`
83  :param lddt_pli_radius: LDDT inclusion radius for LDDT-PLI.
84  :type lddt_pli_radius: :class:`float`
85  :param add_mdl_contacts: Whether to penalize added model contacts.
86  :type add_mdl_contacts: :class:`bool`
87  :param lddt_pli_thresholds: Distance difference thresholds for LDDT.
88  :type lddt_pli_thresholds: :class:`list` of :class:`float`
89  :param lddt_pli_binding_site_radius: Pro param - don't use. Providing a value
90  restores behaviour from previous
91  implementation that first extracted a
92  binding site with strict distance
93  threshold and computed LDDT-PLI only on
94  those target residues whereas the
95  current implementation includes every
96  atom within *lddt_pli_radius*.
97  :type lddt_pli_binding_site_radius: :class:`float`
98  :param min_pep_length: See :class:`ost.mol.alg.ligand_scoring_base.LigandScorer`.
99  :type min_pep_length: :class:`int`
100  :param min_nuc_length: See :class:`ost.mol.alg.ligand_scoring_base.LigandScorer`
101  :type min_nuc_length: :class:`int`
102  :param pep_seqid_thr: See :class:`ost.mol.alg.ligand_scoring_base.LigandScorer`
103  :type pep_seqid_thr: :class:`float`
104  :param nuc_seqid_thr: See :class:`ost.mol.alg.ligand_scoring_base.LigandScorer`
105  :type nuc_seqid_thr: :class:`float`
106  :param mdl_map_pep_seqid_thr: See :class:`ost.mol.alg.ligand_scoring_base.LigandScorer`
107  :type mdl_map_pep_seqid_thr: :class:`float`
108  :param mdl_map_nuc_seqid_thr: See :class:`ost.mol.alg.ligand_scoring_base.LigandScorer`
109  :type mdl_map_nuc_seqid_thr: :class:`float`
110  """
111 
def __init__(self, model, target, model_ligands, target_ligands,
             resnum_alignments=False, rename_ligand_chain=False,
             substructure_match=False, coverage_delta=0.2,
             max_symmetries=1e4, lddt_pli_radius=6.0,
             add_mdl_contacts=True,
             lddt_pli_thresholds = [0.5, 1.0, 2.0, 4.0],
             lddt_pli_binding_site_radius=None,
             min_pep_length = 6,
             min_nuc_length = 4, pep_seqid_thr = 95.,
             nuc_seqid_thr = 95.,
             mdl_map_pep_seqid_thr = 0.,
             mdl_map_nuc_seqid_thr = 0.):
    """ Set up the scorer.

    All arguments shared with the parent are forwarded to its constructor
    unchanged. The LDDT-PLI specific settings (*lddt_pli_radius*,
    *add_mdl_contacts*, *lddt_pli_thresholds*,
    *lddt_pli_binding_site_radius*) are stored on the instance, the lazy
    per-ligand caches are initialized empty and two scorer specific state
    codes (10: "no_contact", 20: "unknown") are registered in
    ``state_decoding``.
    """
    super().__init__(model, target, model_ligands, target_ligands,
                     resnum_alignments = resnum_alignments,
                     rename_ligand_chain = rename_ligand_chain,
                     substructure_match = substructure_match,
                     coverage_delta = coverage_delta,
                     max_symmetries = max_symmetries,
                     min_pep_length = min_pep_length,
                     min_nuc_length = min_nuc_length,
                     pep_seqid_thr = pep_seqid_thr,
                     nuc_seqid_thr = nuc_seqid_thr,
                     mdl_map_pep_seqid_thr = mdl_map_pep_seqid_thr,
                     mdl_map_nuc_seqid_thr = mdl_map_nuc_seqid_thr)

    self.lddt_pli_radius = lddt_pli_radius
    self.add_mdl_contacts = add_mdl_contacts
    # Store a private copy: the default argument is a list object shared
    # by all calls, so keeping a reference to it would let a mutation on
    # one instance leak into every other instance created with defaults.
    self.lddt_pli_thresholds = list(lddt_pli_thresholds)
    self.lddt_pli_binding_site_radius = lddt_pli_binding_site_radius

    # lazily precomputed variables to speedup lddt-pli computation
    self._lddt_pli_target_data = dict()
    self._lddt_pli_model_data = dict()
    self.__mappable_atoms = None

    # update state decoding from parent with subclass specific stuff
    self.state_decoding[10] = ("no_contact",
                               "There were no LDDT contacts between the "
                               "binding site and the ligand, and LDDT-PLI "
                               "is undefined.")
    self.state_decoding[20] = ("unknown",
                               "Unknown error occured in LDDTPLIScorer")
155 
def _compute(self, symmetries, target_ligand, model_ligand):
    """ Implements interface from parent

    Dispatches to the LDDT-PLI variant selected by ``add_mdl_contacts``
    and translates its result into the tuple expected by the parent.

    :returns: :class:`tuple` ``(score, pair_state, target_ligand_state,
              model_ligand_state, result)``. *pair_state* is 0 if a score
              is available, 10 if the score is undefined because there
              were no contacts and 20 if it is undefined for any other
              reason. Both ligand states are always 0. *result* is the
              full result :class:`dict` of the selected variant.
    """
    if self.add_mdl_contacts:
        LogInfo("Computing LDDT-PLI with added model contacts")
        result = self._compute_lddt_pli_add_mdl_contacts(symmetries,
                                                         target_ligand,
                                                         model_ligand)
    else:
        LogInfo("Computing LDDT-PLI without added model contacts")
        result = self._compute_lddt_pli_classic(symmetries,
                                                target_ligand,
                                                model_ligand)

    pair_state = 0
    score = result["lddt_pli"]

    if score is None or np.isnan(score):
        if result["lddt_pli_n_contacts"] == 0:
            # no contacts at all => score is undefined
            pair_state = 10
        else:
            # unknown error state
            pair_state = 20

    # the ligands get a zero-state...
    target_ligand_state = 0
    model_ligand_state = 0

    return (score, pair_state, target_ligand_state, model_ligand_state,
            result)
187 
188  def _score_dir(self):
189  """ Implements interface from parent
190  """
191  return '+'
192 
def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand,
                                       model_ligand):
    """ Compute LDDT-PLI including a penalty for added model contacts.

    For every chain mapping enumerated by
    ``chain_mapping._ChainMappings`` and every ligand symmetry in
    *symmetries*, the reference contacts of the ligand atoms are extended
    by contacts that are only observed in the model but can be mapped to
    target atoms. Model contacts that can be mapped to the target but lie
    outside the target binding site are counted as non-fulfilled.

    :param symmetries: Iterable of ``(trg_sym, mdl_sym)`` pairs of
                       matching ligand atom indices.
    :param target_ligand: Target ligand residue
    :param model_ligand: Model ligand residue
    :returns: :class:`dict` with keys "lddt_pli", "lddt_pli_n_contacts",
              "chain_mapping", "target_ligand", "model_ligand",
              "bs_ref_res" and "bs_mdl_res". "lddt_pli" and
              "chain_mapping" are None and "lddt_pli_n_contacts" is 0 if
              no combination had at least one contact.
    :raises: :class:`RuntimeError` if a model binding site chain is not
             found in any chem mapping group.
    """

    (trg_residues, trg_bs, trg_chains, trg_ligand_chain,
     trg_ligand_res, scorer, chem_groups) = \
        self._lddt_pli_get_trg_data(target_ligand)

    trg_bs_center = trg_bs.geometric_center

    # Copy to make sure that we don't change anything on underlying
    # references
    # This is not strictly necessary in the current implementation but
    # hey, maybe it avoids hard to debug errors when someone changes things
    ref_indices = [a.copy() for a in scorer.ref_indices_ic]
    ref_distances = [a.copy() for a in scorer.ref_distances_ic]

    # distance hacking... remove any interchain distance except the ones
    # with the ligand
    ligand_start_idx = scorer.chain_start_indices[-1]
    for at_idx in range(ligand_start_idx):
        mask = ref_indices[at_idx] >= ligand_start_idx
        ref_indices[at_idx] = ref_indices[at_idx][mask]
        ref_distances[at_idx] = ref_distances[at_idx][mask]

    (mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res,
     chem_mapping) = self._lddt_pli_get_mdl_data(model_ligand)

    # ref_mdl_alns refers to full chain mapper trg and mdl structures
    # => need to adapt mdl sequence that only contain residues in contact
    #    with ligand
    cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups,
                                                       chem_mapping,
                                                       mdl_bs, trg_bs)

    # get each chain mapping that we ever observe in scoring
    chain_mappings = list(chain_mapping._ChainMappings(chem_groups,
                                                       chem_mapping))

    # for each mdl ligand atom, we collect all trg ligand atoms that are
    # ever mapped onto it given *symmetries*
    ligand_atom_mappings = [set() for _ in mdl_ligand_res.atoms]
    for (trg_sym, mdl_sym) in symmetries:
        for trg_i, mdl_i in zip(trg_sym, mdl_sym):
            ligand_atom_mappings[mdl_i].add(trg_i)

    mdl_ligand_pos = np.zeros((mdl_ligand_res.GetAtomCount(), 3))
    for a_idx, a in enumerate(mdl_ligand_res.atoms):
        p = a.GetPos()
        mdl_ligand_pos[a_idx, 0] = p[0]
        mdl_ligand_pos[a_idx, 1] = p[1]
        mdl_ligand_pos[a_idx, 2] = p[2]

    trg_ligand_pos = np.zeros((trg_ligand_res.GetAtomCount(), 3))
    for a_idx, a in enumerate(trg_ligand_res.atoms):
        p = a.GetPos()
        trg_ligand_pos[a_idx, 0] = p[0]
        trg_ligand_pos[a_idx, 1] = p[1]
        trg_ligand_pos[a_idx, 2] = p[2]

    # only used for membership tests in the innermost loop below => set
    mdl_lig_hashes = {a.hash_code for a in mdl_ligand_res.atoms}

    symmetric_atoms = np.asarray(sorted(scorer.symmetric_atoms),
                                 dtype=np.int64)

    # two caches to cache things for each chain mapping => lists
    # of len(chain_mappings)
    #
    # In principle we're caching for each trg/mdl ligand atom pair all
    # information to update ref_indices/ref_distances and resolving the
    # symmetries of the binding site.
    # in detail: each list entry in *scoring_cache* is a dict with
    # key: (mdl_lig_at_idx, trg_lig_at_idx)
    # value: tuple with 4 elements - 1: indices of atoms representing added
    # contacts relative to overall indexing scheme in scorer 2: the
    # respective distances 3: the same but only containing indices towards
    # atoms of the binding site that are considered symmetric 4: the
    # respective distances.
    # each list entry in *penalty_cache* is a list of len N mdl lig atoms.
    # For each mdl lig at it contains a penalty for this mdl lig at. That
    # means the number of contacts in the mdl binding site that can
    # directly be mapped to the target given the local chain mapping but
    # are not present in the target binding site, i.e. interacting atoms are
    # too far away.
    scoring_cache = list()
    penalty_cache = list()

    for mapping in chain_mappings:

        # flat mapping with mdl chain names as key
        flat_mapping = dict()
        for trg_chem_group, mdl_chem_group in zip(chem_groups, mapping):
            for a, b in zip(trg_chem_group, mdl_chem_group):
                if a is not None and b is not None:
                    flat_mapping[b] = a

        # for each mdl bs atom (as atom hash), the trg bs atoms (as index in
        # scorer)
        bs_atom_mapping = dict()
        for mdl_cname, ref_cname in flat_mapping.items():
            aln = cut_ref_mdl_alns[(ref_cname, mdl_cname)]
            ref_ch = trg_bs.Select(f"cname={mol.QueryQuoteName(ref_cname)}")
            mdl_ch = mdl_bs.Select(f"cname={mol.QueryQuoteName(mdl_cname)}")
            aln.AttachView(0, ref_ch)
            aln.AttachView(1, mdl_ch)
            for col in aln:
                ref_r = col.GetResidue(0)
                mdl_r = col.GetResidue(1)
                if ref_r.IsValid() and mdl_r.IsValid():
                    for mdl_a in mdl_r.atoms:
                        ref_a = ref_r.FindAtom(mdl_a.GetName())
                        if ref_a.IsValid():
                            ref_h = ref_a.handle.hash_code
                            if ref_h in scorer.atom_indices:
                                mdl_h = mdl_a.handle.hash_code
                                bs_atom_mapping[mdl_h] = \
                                    scorer.atom_indices[ref_h]

        cache = dict()
        n_penalties = list()

        for mdl_a_idx, mdl_a in enumerate(mdl_ligand_res.atoms):
            n_penalty = 0
            trg_bs_indices = list()
            close_a = mdl_bs.FindWithin(mdl_a.GetPos(),
                                        self.lddt_pli_radius)
            for a in close_a:
                mdl_a_hash_code = a.hash_code
                if mdl_a_hash_code in bs_atom_mapping:
                    trg_bs_indices.append(bs_atom_mapping[mdl_a_hash_code])
                elif mdl_a_hash_code not in mdl_lig_hashes:
                    if a.GetChain().GetName() in flat_mapping:
                        # Its in a mapped chain
                        at_key = (a.GetResidue().GetNumber(), a.name)
                        cname = a.GetChain().name
                        cname_key = (flat_mapping[cname], cname)
                        if at_key in self._mappable_atoms[cname_key]:
                            # Its a contact in the model but not part of
                            # trg_bs. It can still be mapped using the
                            # global mdl_ch/ref_ch alignment
                            # d in ref > self.lddt_pli_radius + max_thresh
                            # => guaranteed to be non-fulfilled contact
                            n_penalty += 1

            n_penalties.append(n_penalty)

            # explicit dtype keeps the array integer typed even if empty
            trg_bs_indices = np.asarray(sorted(trg_bs_indices),
                                        dtype=np.int64)

            for trg_a_idx in ligand_atom_mappings[mdl_a_idx]:
                # mask selects entries in trg_bs_indices that are not yet
                # part of classic LDDT ref_indices for atom at trg_a_idx
                # => added mdl contacts
                mask = np.isin(trg_bs_indices,
                               ref_indices[ligand_start_idx + trg_a_idx],
                               assume_unique=True, invert=True)
                added_indices = np.asarray([], dtype=np.int64)
                added_distances = np.asarray([], dtype=np.float64)
                if np.sum(mask) > 0:
                    # compute ref distances on reference positions
                    added_indices = trg_bs_indices[mask]
                    tmp = scorer.positions.take(added_indices, axis=0)
                    np.subtract(tmp, trg_ligand_pos[trg_a_idx][None, :],
                                out=tmp)
                    np.square(tmp, out=tmp)
                    tmp = tmp.sum(axis=1)
                    np.sqrt(tmp, out=tmp)
                    added_distances = tmp

                # extract the distances towards bs atoms that are symmetric
                sym_mask = np.isin(added_indices, symmetric_atoms,
                                   assume_unique=True)

                cache[(mdl_a_idx, trg_a_idx)] = (added_indices,
                                                 added_distances,
                                                 added_indices[sym_mask],
                                                 added_distances[sym_mask])

        scoring_cache.append(cache)
        penalty_cache.append(n_penalties)

    # cache for model contacts towards non mapped trg chains - this is
    # relevant for self._lddt_pli_unmapped_chain_penalty
    # key: tuple in form (trg_ch, mdl_ch)
    # value: yet another dict with
    #   key: ligand_atom_hash
    #   value: n contacts towards respective trg chain that can be mapped
    non_mapped_cache = dict()

    best_score = -1.0
    best_result = {"lddt_pli": None,
                   "lddt_pli_n_contacts": 0,
                   "chain_mapping": None}

    # dummy alignment for ligand chains which is needed as input later on
    ligand_aln = seq.CreateAlignment()
    trg_s = seq.CreateSequence(trg_ligand_chain.name,
                               trg_ligand_res.GetOneLetterCode())
    mdl_s = seq.CreateSequence(mdl_ligand_chain.name,
                               mdl_ligand_res.GetOneLetterCode())
    ligand_aln.AddSequence(trg_s)
    ligand_aln.AddSequence(mdl_s)
    ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms))

    sym_idx_collector = [None] * scorer.n_atoms
    sym_dist_collector = [None] * scorer.n_atoms

    for mapping, s_cache, p_cache in zip(chain_mappings, scoring_cache,
                                         penalty_cache):

        lddt_chain_mapping = dict()
        lddt_alns = dict()
        for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
            for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
                # some mdl chains can be None
                if mdl_ch is not None:
                    lddt_chain_mapping[mdl_ch] = ref_ch
                    lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]

        # add ligand to lddt_chain_mapping/lddt_alns
        lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
        lddt_alns[mdl_ligand_chain.name] = ligand_aln

        # already process model, positions will be manually hacked for each
        # symmetry - small overhead for variables that are thrown away here
        pos, _, _, _, _, _, lddt_symmetries = \
            scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
                                 residue_mapping = lddt_alns,
                                 thresholds = self.lddt_pli_thresholds,
                                 check_resnames = False)

        # estimate a penalty for unsatisfied model contacts from chains
        # that are not in the local trg binding site, but can be mapped in
        # the target.
        # We're using the trg chain with the closest geometric center to
        # the trg binding site that can be mapped to the mdl chain
        # according the chem mapping. An alternative would be to search for
        # the target chain with the minimal number of additional contacts.
        # There is not good solution for this problem...
        unmapped_chains = list()
        already_mapped = set()
        # lddt_chain_mapping is not modified anymore for this mapping
        # => collect its target chain names once for fast lookup
        mapped_trg_chains = set(lddt_chain_mapping.values())
        for mdl_ch in mdl_chains:
            if mdl_ch not in lddt_chain_mapping:
                # check which chain in trg is closest
                chem_grp_idx = None
                for i, m in enumerate(self._chem_mapping):
                    if mdl_ch in m:
                        chem_grp_idx = i
                        break
                if chem_grp_idx is None:
                    raise RuntimeError("This should never happen... "
                                       "ask Gabriel...")
                closest_ch = None
                closest_dist = None
                for trg_ch in self._chain_mapper.chem_groups[chem_grp_idx]:
                    if trg_ch not in mapped_trg_chains:
                        if trg_ch not in already_mapped:
                            ch = self._chain_mapper.target.FindChain(trg_ch)
                            c = ch.geometric_center
                            d = geom.Distance(trg_bs_center, c)
                            if closest_dist is None or d < closest_dist:
                                closest_dist = d
                                closest_ch = trg_ch
                if closest_ch is not None:
                    unmapped_chains.append((closest_ch, mdl_ch))
                    already_mapped.add(closest_ch)

        for (trg_sym, mdl_sym) in symmetries:

            # update positions
            for mdl_i, trg_i in zip(mdl_sym, trg_sym):
                pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]

            # start new ref_indices/ref_distances from original values
            funky_ref_indices = [np.copy(a) for a in ref_indices]
            funky_ref_distances = [np.copy(a) for a in ref_distances]

            # The only distances from the binding site towards the ligand
            # we care about are the ones from the symmetric atoms to
            # correctly compute scorer._ResolveSymmetries.
            # We collect them while updating distances from added mdl
            # contacts
            for idx in symmetric_atoms:
                sym_idx_collector[idx] = list()
                sym_dist_collector[idx] = list()

            # add data from added mdl contacts cache
            added_penalty = 0
            for mdl_i, trg_i in zip(mdl_sym, trg_sym):
                added_penalty += p_cache[mdl_i]
                cache = s_cache[mdl_i, trg_i]
                full_trg_i = ligand_start_idx + trg_i
                funky_ref_indices[full_trg_i] = \
                    np.append(funky_ref_indices[full_trg_i], cache[0])
                funky_ref_distances[full_trg_i] = \
                    np.append(funky_ref_distances[full_trg_i], cache[1])
                for idx, d in zip(cache[2], cache[3]):
                    sym_idx_collector[idx].append(full_trg_i)
                    sym_dist_collector[idx].append(d)

            for idx in symmetric_atoms:
                funky_ref_indices[idx] = \
                    np.append(funky_ref_indices[idx],
                              np.asarray(sym_idx_collector[idx],
                                         dtype=np.int64))
                funky_ref_distances[idx] = \
                    np.append(funky_ref_distances[idx],
                              np.asarray(sym_dist_collector[idx],
                                         dtype=np.float64))

            # we can pass funky_ref_indices/funky_ref_distances as
            # sym_ref_indices/sym_ref_distances in
            # scorer._ResolveSymmetries as we only have distances of the bs
            # to the ligand and ligand atoms are "non-symmetric"
            scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds,
                                      lddt_symmetries,
                                      funky_ref_indices,
                                      funky_ref_distances)

            N = sum([len(funky_ref_indices[i]) for i in ligand_at_indices])
            N += added_penalty

            # collect number of expected contacts which can be mapped
            if len(unmapped_chains) > 0:
                N += self._lddt_pli_unmapped_chain_penalty(unmapped_chains,
                                                           non_mapped_cache,
                                                           mdl_bs,
                                                           mdl_ligand_res,
                                                           mdl_sym)

            conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
                                                 self.lddt_pli_thresholds,
                                                 funky_ref_indices,
                                                 funky_ref_distances),
                               axis=0)
            score = None
            if N > 0:
                score = np.mean(conserved/N)

            if score is not None and score > best_score:
                best_score = score
                save_chain_mapping = dict(lddt_chain_mapping)
                del save_chain_mapping[mdl_ligand_chain.name]
                best_result = {"lddt_pli": score,
                               "lddt_pli_n_contacts": N,
                               "chain_mapping": save_chain_mapping}

    # fill misc info to result object
    best_result["target_ligand"] = target_ligand
    best_result["model_ligand"] = model_ligand
    best_result["bs_ref_res"] = trg_residues
    best_result["bs_mdl_res"] = mdl_residues

    return best_result
561 
562 
def _compute_lddt_pli_classic(self, symmetries, target_ligand,
                              model_ligand):
    """ Compute LDDT-PLI considering target contacts only.

    The number of expected contacts is fixed by the target binding site.
    The best score over all chain mappings enumerated by
    ``chain_mapping._ChainMappings`` and all ligand symmetries in
    *symmetries* is reported.

    :param symmetries: Iterable of ``(trg_sym, mdl_sym)`` pairs of
                       matching ligand atom indices.
    :param target_ligand: Target ligand residue
    :param model_ligand: Model ligand residue
    :returns: :class:`dict` with keys "lddt_pli", "lddt_pli_n_contacts",
              "chain_mapping", "target_ligand", "model_ligand",
              "bs_ref_res" and "bs_mdl_res". "lddt_pli" and
              "chain_mapping" are None if the target ligand has no
              contacts or if no chain mapping/symmetry combination could
              be evaluated.
    """

    max_r = None
    if self.lddt_pli_binding_site_radius:
        max_r = self.lddt_pli_binding_site_radius

    (trg_residues, trg_bs, trg_chains, trg_ligand_chain,
     trg_ligand_res, scorer, chem_groups) = \
        self._lddt_pli_get_trg_data(target_ligand, max_r = max_r)

    # Copy to make sure that we don't change anything on underlying
    # references
    # This is not strictly necessary in the current implementation but
    # hey, maybe it avoids hard to debug errors when someone changes things
    ref_indices = [a.copy() for a in scorer.ref_indices_ic]
    ref_distances = [a.copy() for a in scorer.ref_distances_ic]

    # no matter what mapping/symmetries, the number of expected
    # contacts stays the same
    ligand_start_idx = scorer.chain_start_indices[-1]
    ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms))
    n_exp = sum([len(ref_indices[i]) for i in ligand_at_indices])

    (mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res,
     chem_mapping) = self._lddt_pli_get_mdl_data(model_ligand)

    if n_exp == 0:
        # no contacts... nothing to compute...
        return {"lddt_pli": None,
                "lddt_pli_n_contacts": 0,
                "chain_mapping": None,
                "target_ligand": target_ligand,
                "model_ligand": model_ligand,
                "bs_ref_res": trg_residues,
                "bs_mdl_res": mdl_residues}

    # Distance hacking... remove any interchain distance except the ones
    # with the ligand
    for at_idx in range(ligand_start_idx):
        mask = ref_indices[at_idx] >= ligand_start_idx
        ref_indices[at_idx] = ref_indices[at_idx][mask]
        ref_distances[at_idx] = ref_distances[at_idx][mask]

    # ref_mdl_alns refers to full chain mapper trg and mdl structures
    # => need to adapt mdl sequence that only contain residues in contact
    #    with ligand
    cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups,
                                                       chem_mapping,
                                                       mdl_bs, trg_bs)

    best_score = -1.0
    # Defined up front so that the result assembly below also works if
    # the loops never assign a result (no chain mapping or no symmetry).
    best_result = {"lddt_pli": None,
                   "chain_mapping": None}

    # dummy alignment for ligand chains which is needed as input later on
    l_aln = seq.CreateAlignment()
    l_aln.AddSequence(seq.CreateSequence(trg_ligand_chain.name,
                                         trg_ligand_res.GetOneLetterCode()))
    l_aln.AddSequence(seq.CreateSequence(mdl_ligand_chain.name,
                                         mdl_ligand_res.GetOneLetterCode()))

    mdl_ligand_pos = np.zeros((model_ligand.GetAtomCount(), 3))
    for a_idx, a in enumerate(model_ligand.atoms):
        p = a.GetPos()
        mdl_ligand_pos[a_idx, 0] = p[0]
        mdl_ligand_pos[a_idx, 1] = p[1]
        mdl_ligand_pos[a_idx, 2] = p[2]

    for mapping in chain_mapping._ChainMappings(chem_groups, chem_mapping):

        lddt_chain_mapping = dict()
        lddt_alns = dict()
        for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
            for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
                # some mdl chains can be None
                if mdl_ch is not None:
                    lddt_chain_mapping[mdl_ch] = ref_ch
                    lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]

        # add ligand to lddt_chain_mapping/lddt_alns
        lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
        lddt_alns[mdl_ligand_chain.name] = l_aln

        # already process model, positions will be manually hacked for each
        # symmetry - small overhead for variables that are thrown away here
        pos, _, _, _, _, _, lddt_symmetries = \
            scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
                                 residue_mapping = lddt_alns,
                                 thresholds = self.lddt_pli_thresholds,
                                 check_resnames = False)

        for (trg_sym, mdl_sym) in symmetries:
            for mdl_i, trg_i in zip(mdl_sym, trg_sym):
                pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]
            # we can pass ref_indices/ref_distances as
            # sym_ref_indices/sym_ref_distances in
            # scorer._ResolveSymmetries as we only have distances of the bs
            # to the ligand and ligand atoms are "non-symmetric"
            scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds,
                                      lddt_symmetries,
                                      ref_indices,
                                      ref_distances)
            # compute number of conserved distances for ligand atoms
            conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
                                                 self.lddt_pli_thresholds,
                                                 ref_indices,
                                                 ref_distances), axis=0)
            score = np.mean(conserved/n_exp)

            if score > best_score:
                best_score = score
                save_chain_mapping = dict(lddt_chain_mapping)
                del save_chain_mapping[mdl_ligand_chain.name]
                best_result = {"lddt_pli": score,
                               "chain_mapping": save_chain_mapping}

    # fill misc info to result object
    best_result["lddt_pli_n_contacts"] = n_exp
    best_result["target_ligand"] = target_ligand
    best_result["model_ligand"] = model_ligand
    best_result["bs_ref_res"] = trg_residues
    best_result["bs_mdl_res"] = mdl_residues

    return best_result
698 
699  def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains,
700  non_mapped_cache,
701  mdl_bs,
702  mdl_ligand_res,
703  mdl_sym):
704 
705  n_exp = 0
706  for ch_tuple in unmapped_chains:
707  if ch_tuple not in non_mapped_cache:
708  # for each ligand atom, we count the number of mappable atoms
709  # within lddt_pli_radius
710  counts = dict()
711  # the select statement also excludes the ligand in mdl_bs
712  # as it resides in a separate chain
713  mdl_cname = ch_tuple[1]
714  query = "cname=" + mol.QueryQuoteName(mdl_cname)
715  mdl_bs_ch = mdl_bs.Select(query)
716  for a in mdl_ligand_res.atoms:
717  close_atoms = \
718  mdl_bs_ch.FindWithin(a.GetPos(), self.lddt_pli_radiuslddt_pli_radius)
719  N = 0
720  for close_a in close_atoms:
721  at_key = (close_a.GetResidue().GetNumber(),
722  close_a.GetName())
723  if at_key in self._mappable_atoms_mappable_atoms[ch_tuple]:
724  N += 1
725  counts[a.hash_code] = N
726 
727  # fill cache
728  non_mapped_cache[ch_tuple] = counts
729 
730  # add number of mdl contacts which can be mapped to target
731  # as non-fulfilled contacts
732  counts = non_mapped_cache[ch_tuple]
733  lig_hash_codes = [a.hash_code for a in mdl_ligand_res.atoms]
734  for i in mdl_sym:
735  n_exp += counts[lig_hash_codes[i]]
736 
737  return n_exp
738 
739 
740  def _lddt_pli_get_mdl_data(self, model_ligand):
741  if model_ligand not in self._lddt_pli_model_data_lddt_pli_model_data:
742 
743  mdl = self._chain_mapping_mdl_chain_mapping_mdl
744 
745  mdl_residues = set()
746  for at in model_ligand.atoms:
747  close_atoms = mdl.FindWithin(at.GetPos(), self.lddt_pli_radiuslddt_pli_radius)
748  for close_at in close_atoms:
749  mdl_residues.add(close_at.GetResidue())
750 
751  max_r = self.lddt_pli_radiuslddt_pli_radius + max(self.lddt_pli_thresholdslddt_pli_thresholds)
752  for r in mdl.residues:
753  r.SetIntProp("bs", 0)
754  for at in model_ligand.atoms:
755  close_atoms = mdl.FindWithin(at.GetPos(), max_r)
756  for close_at in close_atoms:
757  close_at.GetResidue().SetIntProp("bs", 1)
758 
759  mdl_bs = mol.CreateEntityFromView(mdl.Select("grbs:0=1"), True)
760  mdl_chains = set([ch.name for ch in mdl_bs.chains])
761 
762  mdl_editor = mdl_bs.EditXCS(mol.BUFFERED_EDIT)
763  mdl_ligand_chain = None
764  for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
765  try:
766  # I'm pretty sure, one of these chain names is not there...
767  mdl_ligand_chain = mdl_editor.InsertChain(cname)
768  break
769  except:
770  pass
771  if mdl_ligand_chain is None:
772  raise RuntimeError("Fuck this, I'm out...")
773  mdl_ligand_res = mdl_editor.AppendResidue(mdl_ligand_chain,
774  model_ligand,
775  deep=True)
776  mdl_editor.RenameResidue(mdl_ligand_res, "LIG")
777  mdl_editor.SetResidueNumber(mdl_ligand_res, mol.ResNum(1))
778 
779  chem_mapping = list()
780  for m in self._chem_mapping_chem_mapping:
781  chem_mapping.append([x for x in m if x in mdl_chains])
782 
783  self._lddt_pli_model_data_lddt_pli_model_data[model_ligand] = (mdl_residues,
784  mdl_bs,
785  mdl_chains,
786  mdl_ligand_chain,
787  mdl_ligand_res,
788  chem_mapping)
789 
790  return self._lddt_pli_model_data_lddt_pli_model_data[model_ligand]
791 
792 
def _lddt_pli_get_trg_data(self, target_ligand, max_r = None):
    """ Returns (and caches) binding site data for *target_ligand*

    On the first call for a given *target_ligand*, the target binding
    site is extracted into a new entity, the ligand is appended to it as
    residue "LIG" (number 1) in a dedicated chain and an lDDT scorer is
    set up on that entity. The result is stored in
    ``self._lddt_pli_target_data`` and returned directly on later calls.

    Note that the cache is keyed on *target_ligand* only: *max_r* is
    ignored if data for this ligand has already been computed.

    As a side effect, the int property "bs" is (re)set on every residue
    of the chain mapper target.

    :param target_ligand: Ligand residue of the target
    :param max_r: Radius around each ligand atom used to collect binding
                  site residues. If None, defaults to
                  ``self.lddt_pli_radius + max(self.lddt_pli_thresholds)``.
    :returns: :class:`tuple` with elements (trg_residues, trg_bs,
              trg_chains, trg_ligand_chain, trg_ligand_res, scorer,
              chem_groups)
    :raises: :class:`RuntimeError` if no chain for the ligand could be
             inserted into the binding site entity.
    """
    if target_ligand in self._lddt_pli_target_data:
        return self._lddt_pli_target_data[target_ligand]

    trg = self._chain_mapper.target

    if max_r is None:
        max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds)

    # every residue with at least one atom within max_r of any ligand atom
    trg_residues = set()
    for at in target_ligand.atoms:
        for close_at in trg.FindWithin(at.GetPos(), max_r):
            trg_residues.add(close_at.GetResidue())

    # flag binding site residues so that they can be selected by query;
    # all residues are reset first as the property may still be set from
    # a previous ligand
    for r in trg.residues:
        r.SetIntProp("bs", 0)
    for r in trg_residues:
        r.SetIntProp("bs", 1)

    trg_bs = mol.CreateEntityFromView(trg.Select("grbs:0=1"), True)
    trg_chains = {ch.name for ch in trg_bs.chains}

    trg_editor = trg_bs.EditXCS(mol.BUFFERED_EDIT)
    trg_ligand_chain = None
    placeholder_names = ["hugo_the_cat_terminator", "ida_the_cheese_monster"]
    for cname in placeholder_names:
        try:
            # presumably fails if the chain name is already in use, in
            # which case the next placeholder name is tried
            trg_ligand_chain = trg_editor.InsertChain(cname)
            break
        except Exception:
            # only swallow regular errors - KeyboardInterrupt and friends
            # must propagate
            continue
    if trg_ligand_chain is None:
        raise RuntimeError("Could not insert a chain for the target ligand "
                           "into the binding site entity. Tried chain "
                           "names: " + ", ".join(placeholder_names))

    trg_ligand_res = trg_editor.AppendResidue(trg_ligand_chain,
                                              target_ligand,
                                              deep=True)
    trg_editor.RenameResidue(trg_ligand_res, "LIG")
    trg_editor.SetResidueNumber(trg_ligand_res, mol.ResNum(1))

    # the scorer is given a compound derived from the ligand residue
    # itself, registered under the residue name
    compound_name = trg_ligand_res.name
    compound = lddt.CustomCompound.FromResidue(trg_ligand_res)
    custom_compounds = {compound_name: compound}

    scorer = lddt.lDDTScorer(trg_bs,
                             custom_compounds = custom_compounds,
                             inclusion_radius = self.lddt_pli_radius)

    # chem groups reduced to chains that are present in the binding site;
    # groups are kept even if they end up empty
    chem_groups = [[x for x in g if x in trg_chains]
                   for g in self._chain_mapper.chem_groups]

    self._lddt_pli_target_data[target_ligand] = (trg_residues,
                                                 trg_bs,
                                                 trg_chains,
                                                 trg_ligand_chain,
                                                 trg_ligand_res,
                                                 scorer,
                                                 chem_groups)

    return self._lddt_pli_target_data[target_ligand]
855 
856 
def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs,
                               ref_bs):
    """ Reduces reference/model alignments to binding site residues

    For every pair of reference chain and model chain taken from
    corresponding entries of *chem_groups* and *chem_mapping*, the
    alignment stored in ``self._ref_mdl_alns`` is copied column by
    column. A character is only kept if the column has a residue attached
    and a residue with the same number exists in the respective binding
    site chain. Everything else becomes a gap ('-').

    As a side effect, views of the full target and model chains are
    attached to the alignments in ``self._ref_mdl_alns``.

    :param chem_groups: Lists of reference chain names
    :param chem_mapping: Lists of model chain names, same order as
                         *chem_groups*
    :param mdl_bs: Model binding site entity
    :param ref_bs: Reference binding site entity
    :returns: :class:`dict` with key (ref_ch, mdl_ch) and the reduced
              alignment as value (reference sequence first, model
              sequence second)
    """

    def _in_binding_site(residue, bs_chain):
        # gap columns have no valid residue attached
        if not residue.IsValid():
            return False
        return bs_chain.FindResidue(residue.GetNumber()).IsValid()

    trimmed_alns = dict()
    for ref_group, mdl_group in zip(chem_groups, chem_mapping):
        for ref_ch in ref_group:

            ref_bs_chain = ref_bs.FindChain(ref_ch)
            ref_query = "cname=" + mol.QueryQuoteName(ref_ch)
            ref_view = self._chain_mapper.target.Select(ref_query)

            for mdl_ch in mdl_group:
                aln = self._ref_mdl_alns[(ref_ch, mdl_ch)]
                aln.AttachView(0, ref_view)

                mdl_bs_chain = mdl_bs.FindChain(mdl_ch)
                mdl_query = "cname=" + mol.QueryQuoteName(mdl_ch)
                mdl_view = self._chain_mapping_mdl.Select(mdl_query)
                aln.AttachView(1, mdl_view)

                # start from all-gap sequences and fill in what survives
                n_cols = aln.GetLength()
                ref_chars = ['-'] * n_cols
                mdl_chars = ['-'] * n_cols
                for idx, col in enumerate(aln):
                    if _in_binding_site(col.GetResidue(0), ref_bs_chain):
                        ref_chars[idx] = col[0]
                    if _in_binding_site(col.GetResidue(1), mdl_bs_chain):
                        mdl_chars[idx] = col[1]

                ref_seq = seq.CreateSequence(ref_ch, ''.join(ref_chars))
                mdl_seq = seq.CreateSequence(mdl_ch, ''.join(mdl_chars))
                trimmed = seq.CreateAlignment()
                trimmed.AddSequence(ref_seq)
                trimmed.AddSequence(mdl_seq)
                trimmed_alns[(ref_ch, mdl_ch)] = trimmed

    return trimmed_alns
901 
@property
def _mappable_atoms(self):
    """ Stores mappable atoms given a chain mapping

    Store for each ref_ch,mdl_ch pair all mdl atoms that can be
    mapped. Don't store mappable atoms as hashes but rather as tuple
    (mdl_r.GetNumber(), mdl_a.GetName()). Reason for that is that one might
    operate on Copied EntityHandle objects without corresponding hashes.
    Given a tuple defining c_pair: (ref_cname, mdl_cname), one
    can check if a certain atom is mappable by evaluating:
    if (mdl_r.GetNumber(), mdl_a.GetName()) in self._mappable_atoms[c_pair]

    The result is computed on first access and cached. As a side effect,
    views of the full target and model chains are attached to the
    alignments in ``self._ref_mdl_alns``.

    :returns: :class:`dict` with key (ref_cname, mdl_cname) and a
              :class:`set` of (residue number, atom name) tuples as value
    """
    if self.__mappable_atoms is None:
        # Build in a local dict and only publish it once complete. This
        # way, an exception raised half way through does not leave a
        # partially filled cache behind.
        mappable = dict()
        for (ref_cname, mdl_cname), aln in self._ref_mdl_alns.items():
            c_key = (ref_cname, mdl_cname)
            at_keys = set()
            ref_query = f"cname={mol.QueryQuoteName(ref_cname)}"
            mdl_query = f"cname={mol.QueryQuoteName(mdl_cname)}"
            ref_ch = self._chain_mapper.target.Select(ref_query)
            mdl_ch = self._chain_mapping_mdl.Select(mdl_query)
            aln.AttachView(0, ref_ch)
            aln.AttachView(1, mdl_ch)
            for col in aln:
                ref_r = col.GetResidue(0)
                mdl_r = col.GetResidue(1)
                # only aligned columns with residues on both sides count
                if not (ref_r.IsValid() and mdl_r.IsValid()):
                    continue
                mdl_num = mdl_r.GetNumber()
                for mdl_a in mdl_r.atoms:
                    # mappable: atom of same name exists in ref residue
                    if ref_r.FindAtom(mdl_a.name).IsValid():
                        at_keys.add((mdl_num, mdl_a.name))
            mappable[c_key] = at_keys
        self.__mappable_atoms = mappable

    return self.__mappable_atoms
935 
# names exported by a star-import of this module
__all__ = (
    'LDDTPLIScorer',
)
def _lddt_pli_get_trg_data(self, target_ligand, max_r=None)
def _compute_lddt_pli_classic(self, symmetries, target_ligand, model_ligand)
def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains, non_mapped_cache, mdl_bs, mdl_ligand_res, mdl_sym)
def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs, ref_bs)
def __init__(self, model, target, model_ligands, target_ligands, resnum_alignments=False, rename_ligand_chain=False, substructure_match=False, coverage_delta=0.2, max_symmetries=1e4, lddt_pli_radius=6.0, add_mdl_contacts=True, lddt_pli_thresholds=[0.5, 1.0, 2.0, 4.0], lddt_pli_binding_site_radius=None, min_pep_length=6, min_nuc_length=4, pep_seqid_thr=95., nuc_seqid_thr=95., mdl_map_pep_seqid_thr=0., mdl_map_nuc_seqid_thr=0.)
def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand, model_ligand)
Real DLLEXPORT_OST_GEOM Distance(const Line2 &l, const Vec2 &v)