OpenStructure
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
ligand_scoring_lddtpli.py
Go to the documentation of this file.
1 import numpy as np
2 
3 from ost import LogWarning, LogInfo
4 from ost import geom
5 from ost import mol
6 from ost import seq
7 
8 from ost.mol.alg import lddt
9 from ost.mol.alg import chain_mapping
10 from ost.mol.alg import ligand_scoring_base
11 
13  """ :class:`LigandScorer` implementing lDDT-PLI.
14 
15  lDDT-PLI is an lDDT score considering contacts between ligand and
16  receptor, where the receptor consists of protein and nucleic acid chains that
17  pass the criteria for :class:`chain mapping <ost.mol.alg.chain_mapping>`.
18  This means ignoring other ligands, waters, short polymers as well as any
19  incorrectly connected chains that may be in proximity.
20 
21  :class:`LDDTPLIScorer` computes a score for a specific pair of target/model
22  ligands. Given a target/model ligand pair, all possible mappings of
23  model chains onto their chemically equivalent target chains are enumerated.
24  For each of these enumerations, all possible symmetries, i.e. atom-atom
25  assignments of the ligand as given by :class:`LigandScorer`, are evaluated
26  and an lDDT-PLI score is computed. The best possible lDDT-PLI score is
27  returned.
28 
29  The lDDT-PLI score is a variant of lDDT with a custom inclusion radius
30  (`lddt_pli_radius`), no stereochemistry checks, and which penalizes
31  contacts added in the model within `lddt_pli_radius` by default
32  (can be changed with the `add_mdl_contacts` flag) but only if the involved
33  atoms can be mapped to the target. This is a requirement to
34  1) extract the respective reference distance from the target
35  2) avoid usage of contacts for which we have no experimental evidence.
36  One special case concerns contacts from chains that are not mapped to the target
37  binding site. It is quite possible that we have experimental evidence
38  for such a chain even though it is just too far away from the target binding site.
39  We therefore try to map these contacts to the chain in the target with
40  equivalent sequence that is closest to the target binding site. If the
41  respective atoms can be mapped there, the contact is considered not
42  fulfilled and added as a penalty.
43 
44  Populates :attr:`LigandScorer.aux_data` with the following :class:`dict` keys:
45 
46  * lddt_pli: The LDDT-PLI score
47  * lddt_pli_n_contacts: Number of contacts considered in lDDT computation
48  * target_ligand: The actual target ligand for which the score was computed
49  * model_ligand: The actual model ligand for which the score was computed
50  * bs_ref_res: :class:`set` of residues with potentially non-zero
51  contribution to score. That is every residue with at least one
52  atom within *lddt_pli_radius* + max(*lddt_pli_thresholds*) of
53  the ligand.
54  * bs_mdl_res: Same for model, but with *lddt_pli_radius* alone as distance cutoff
55 
56  :param model: Passed to parent constructor - see :class:`LigandScorer`.
57  :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
58  :param target: Passed to parent constructor - see :class:`LigandScorer`.
59  :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
60  :param model_ligands: Passed to parent constructor - see
61  :class:`LigandScorer`.
62  :type model_ligands: :class:`list`
63  :param target_ligands: Passed to parent constructor - see
64  :class:`LigandScorer`.
65  :type target_ligands: :class:`list`
66  :param resnum_alignments: Passed to parent constructor - see
67  :class:`LigandScorer`.
68  :type resnum_alignments: :class:`bool`
69  :param rename_ligand_chain: Passed to parent constructor - see
70  :class:`LigandScorer`.
71  :type rename_ligand_chain: :class:`bool`
72  :param substructure_match: Passed to parent constructor - see
73  :class:`LigandScorer`.
74  :type substructure_match: :class:`bool`
75  :param coverage_delta: Passed to parent constructor - see
76  :class:`LigandScorer`.
77  :type coverage_delta: :class:`float`
78  :param max_symmetries: Passed to parent constructor - see
79  :class:`LigandScorer`.
80  :type max_symmetries: :class:`int`
81  :param lddt_pli_radius: lDDT inclusion radius for lDDT-PLI.
82  :type lddt_pli_radius: :class:`float`
83  :param add_mdl_contacts: Whether to penalize added model contacts.
84  :type add_mdl_contacts: :class:`bool`
85  :param lddt_pli_thresholds: Distance difference thresholds for lDDT.
86  :type lddt_pli_thresholds: :class:`list` of :class:`float`
87  :param lddt_pli_binding_site_radius: Pro param - don't use. Providing a value
88  restores the behaviour of the previous
89  implementation that first extracted a
90  binding site with strict distance
91  threshold and computed lDDT-PLI only on
92  those target residues whereas the
93  current implementation includes every
94  atom within *lddt_pli_radius*.
95  :type lddt_pli_binding_site_radius: :class:`float`
96  """
97 
98  def __init__(self, model, target, model_ligands, target_ligands,
99  resnum_alignments=False, rename_ligand_chain=False,
100  substructure_match=False, coverage_delta=0.2,
101  max_symmetries=1e4, lddt_pli_radius=6.0,
102  add_mdl_contacts=True,
103  lddt_pli_thresholds = [0.5, 1.0, 2.0, 4.0],
104  lddt_pli_binding_site_radius=None):
105 
106  super().__init__(model, target, model_ligands, target_ligands,
107  resnum_alignments = resnum_alignments,
108  rename_ligand_chain = rename_ligand_chain,
109  substructure_match = substructure_match,
110  coverage_delta = coverage_delta,
111  max_symmetries = max_symmetries)
112 
113  self.lddt_pli_radiuslddt_pli_radius = lddt_pli_radius
114  self.add_mdl_contactsadd_mdl_contacts = add_mdl_contacts
115  self.lddt_pli_thresholdslddt_pli_thresholds = lddt_pli_thresholds
116  self.lddt_pli_binding_site_radiuslddt_pli_binding_site_radius = lddt_pli_binding_site_radius
117 
118  # lazily precomputed variables to speedup lddt-pli computation
119  self._lddt_pli_target_data_lddt_pli_target_data = dict()
120  self._lddt_pli_model_data_lddt_pli_model_data = dict()
121  self.__mappable_atoms__mappable_atoms = None
122  self.__chem_mapping__chem_mapping = None
123  self.__chem_group_alns__chem_group_alns = None
124  self.__ref_mdl_alns__ref_mdl_alns = None
125  self.__chain_mapping_mdl__chain_mapping_mdl = None
126 
127  # update state decoding from parent with subclass specific stuff
128  self.state_decodingstate_decoding[10] = ("no_contact",
129  "There were no lDDT contacts between the "
130  "binding site and the ligand, and lDDT-PLI "
131  "is undefined.")
132  self.state_decodingstate_decoding[20] = ("unknown",
133  "Unknown error occured in LDDTPLIScorer")
134 
135  def _compute(self, symmetries, target_ligand, model_ligand):
136  """ Implements interface from parent
137  """
138  if self.add_mdl_contactsadd_mdl_contacts:
139  LogInfo("Computing lDDT-PLI with added model contacts")
140  result = self._compute_lddt_pli_add_mdl_contacts_compute_lddt_pli_add_mdl_contacts(symmetries,
141  target_ligand,
142  model_ligand)
143  else:
144  LogInfo("Computing lDDT-PLI without added model contacts")
145  result = self._compute_lddt_pli_classic_compute_lddt_pli_classic(symmetries,
146  target_ligand,
147  model_ligand)
148 
149  pair_state = 0
150  score = result["lddt_pli"]
151 
152  if score is None or np.isnan(score):
153  if result["lddt_pli_n_contacts"] == 0:
154  # it's a space ship!
155  pair_state = 10
156  else:
157  # unknwon error state
158  pair_state = 20
159 
160  # the ligands get a zero-state...
161  target_ligand_state = 0
162  model_ligand_state = 0
163 
164  return (score, pair_state, target_ligand_state, model_ligand_state,
165  result)
166 
167  def _score_dir(self):
168  """ Implements interface from parent
169  """
170  return '+'
171 
172  def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand,
173  model_ligand):
174 
175 
178 
179  trg_residues, trg_bs, trg_chains, trg_ligand_chain, \
180  trg_ligand_res, scorer, chem_groups = \
181  self._lddt_pli_get_trg_data_lddt_pli_get_trg_data(target_ligand)
182 
183  trg_bs_center = trg_bs.geometric_center
184 
185  # Copy to make sure that we don't change anything on underlying
186  # references
187  # This is not strictly necessary in the current implementation but
188  # hey, maybe it avoids hard to debug errors when someone changes things
189  ref_indices = [a.copy() for a in scorer.ref_indices_ic]
190  ref_distances = [a.copy() for a in scorer.ref_distances_ic]
191 
192  # distance hacking... remove any interchain distance except the ones
193  # with the ligand
194  ligand_start_idx = scorer.chain_start_indices[-1]
195  for at_idx in range(ligand_start_idx):
196  mask = ref_indices[at_idx] >= ligand_start_idx
197  ref_indices[at_idx] = ref_indices[at_idx][mask]
198  ref_distances[at_idx] = ref_distances[at_idx][mask]
199 
200  mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res, \
201  chem_mapping = self._lddt_pli_get_mdl_data_lddt_pli_get_mdl_data(model_ligand)
202 
203 
206 
207  # ref_mdl_alns refers to full chain mapper trg and mdl structures
208  # => need to adapt mdl sequence that only contain residues in contact
209  # with ligand
210  cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns_lddt_pli_cut_ref_mdl_alns(chem_groups,
211  chem_mapping,
212  mdl_bs, trg_bs)
213 
214 
217 
218  # get each chain mapping that we ever observe in scoring
219  chain_mappings = list(chain_mapping._ChainMappings(chem_groups,
220  chem_mapping))
221 
222  # for each mdl ligand atom, we collect all trg ligand atoms that are
223  # ever mapped onto it given *symmetries*
224  ligand_atom_mappings = [set() for a in mdl_ligand_res.atoms]
225  for (trg_sym, mdl_sym) in symmetries:
226  for trg_i, mdl_i in zip(trg_sym, mdl_sym):
227  ligand_atom_mappings[mdl_i].add(trg_i)
228 
229  mdl_ligand_pos = np.zeros((mdl_ligand_res.GetAtomCount(), 3))
230  for a_idx, a in enumerate(mdl_ligand_res.atoms):
231  p = a.GetPos()
232  mdl_ligand_pos[a_idx, 0] = p[0]
233  mdl_ligand_pos[a_idx, 1] = p[1]
234  mdl_ligand_pos[a_idx, 2] = p[2]
235 
236  trg_ligand_pos = np.zeros((trg_ligand_res.GetAtomCount(), 3))
237  for a_idx, a in enumerate(trg_ligand_res.atoms):
238  p = a.GetPos()
239  trg_ligand_pos[a_idx, 0] = p[0]
240  trg_ligand_pos[a_idx, 1] = p[1]
241  trg_ligand_pos[a_idx, 2] = p[2]
242 
243  mdl_lig_hashes = [a.hash_code for a in mdl_ligand_res.atoms]
244 
245  symmetric_atoms = np.asarray(sorted(list(scorer.symmetric_atoms)),
246  dtype=np.int64)
247 
248  # two caches to cache things for each chain mapping => lists
249  # of len(chain_mappings)
250  #
251  # In principle we're caching for each trg/mdl ligand atom pair all
252  # information to update ref_indices/ref_distances and resolving the
253  # symmetries of the binding site.
254  # in detail: each list entry in *scoring_cache* is a dict with
255  # key: (mdl_lig_at_idx, trg_lig_at_idx)
256  # value: tuple with 4 elements - 1: indices of atoms representing added
257  # contacts relative to overall inexing scheme in scorer 2: the
258  # respective distances 3: the same but only containing indices towards
259  # atoms of the binding site that are considered symmetric 4: the
260  # respective indices.
261  # each list entry in *penalty_cache* is a list of len N mdl lig atoms.
262  # For each mdl lig at it contains a penalty for this mdl lig at. That
263  # means the number of contacts in the mdl binding site that can
264  # directly be mapped to the target given the local chain mapping but
265  # are not present in the target binding site, i.e. interacting atoms are
266  # too far away.
267  scoring_cache = list()
268  penalty_cache = list()
269 
270  for mapping in chain_mappings:
271 
272  # flat mapping with mdl chain names as key
273  flat_mapping = dict()
274  for trg_chem_group, mdl_chem_group in zip(chem_groups, mapping):
275  for a,b in zip(trg_chem_group, mdl_chem_group):
276  if a is not None and b is not None:
277  flat_mapping[b] = a
278 
279  # for each mdl bs atom (as atom hash), the trg bs atoms (as index in
280  # scorer)
281  bs_atom_mapping = dict()
282  for mdl_cname, ref_cname in flat_mapping.items():
283  aln = cut_ref_mdl_alns[(ref_cname, mdl_cname)]
284  ref_ch = trg_bs.Select(f"cname={mol.QueryQuoteName(ref_cname)}")
285  mdl_ch = mdl_bs.Select(f"cname={mol.QueryQuoteName(mdl_cname)}")
286  aln.AttachView(0, ref_ch)
287  aln.AttachView(1, mdl_ch)
288  for col in aln:
289  ref_r = col.GetResidue(0)
290  mdl_r = col.GetResidue(1)
291  if ref_r.IsValid() and mdl_r.IsValid():
292  for mdl_a in mdl_r.atoms:
293  ref_a = ref_r.FindAtom(mdl_a.GetName())
294  if ref_a.IsValid():
295  ref_h = ref_a.handle.hash_code
296  if ref_h in scorer.atom_indices:
297  mdl_h = mdl_a.handle.hash_code
298  bs_atom_mapping[mdl_h] = \
299  scorer.atom_indices[ref_h]
300 
301  cache = dict()
302  n_penalties = list()
303 
304  for mdl_a_idx, mdl_a in enumerate(mdl_ligand_res.atoms):
305  n_penalty = 0
306  trg_bs_indices = list()
307  close_a = mdl_bs.FindWithin(mdl_a.GetPos(),
308  self.lddt_pli_radiuslddt_pli_radius)
309  for a in close_a:
310  mdl_a_hash_code = a.hash_code
311  if mdl_a_hash_code in bs_atom_mapping:
312  trg_bs_indices.append(bs_atom_mapping[mdl_a_hash_code])
313  elif mdl_a_hash_code not in mdl_lig_hashes:
314  if a.GetChain().GetName() in flat_mapping:
315  # Its in a mapped chain
316  at_key = (a.GetResidue().GetNumber(), a.name)
317  cname = a.GetChain().name
318  cname_key = (flat_mapping[cname], cname)
319  if at_key in self._mappable_atoms_mappable_atoms[cname_key]:
320  # Its a contact in the model but not part of
321  # trg_bs. It can still be mapped using the
322  # global mdl_ch/ref_ch alignment
323  # d in ref > self.lddt_pli_radius + max_thresh
324  # => guaranteed to be non-fulfilled contact
325  n_penalty += 1
326 
327  n_penalties.append(n_penalty)
328 
329  trg_bs_indices = np.asarray(sorted(trg_bs_indices))
330 
331  for trg_a_idx in ligand_atom_mappings[mdl_a_idx]:
332  # mask selects entries in trg_bs_indices that are not yet
333  # part of classic lDDT ref_indices for atom at trg_a_idx
334  # => added mdl contacts
335  mask = np.isin(trg_bs_indices,
336  ref_indices[ligand_start_idx + trg_a_idx],
337  assume_unique=True, invert=True)
338  added_indices = np.asarray([], dtype=np.int64)
339  added_distances = np.asarray([], dtype=np.float64)
340  if np.sum(mask) > 0:
341  # compute ref distances on reference positions
342  added_indices = trg_bs_indices[mask]
343  tmp = scorer.positions.take(added_indices, axis=0)
344  np.subtract(tmp, trg_ligand_pos[trg_a_idx][None, :],
345  out=tmp)
346  np.square(tmp, out=tmp)
347  tmp = tmp.sum(axis=1)
348  np.sqrt(tmp, out=tmp)
349  added_distances = tmp
350 
351  # extract the distances towards bs atoms that are symmetric
352  sym_mask = np.isin(added_indices, symmetric_atoms,
353  assume_unique=True)
354 
355  cache[(mdl_a_idx, trg_a_idx)] = (added_indices,
356  added_distances,
357  added_indices[sym_mask],
358  added_distances[sym_mask])
359 
360  scoring_cache.append(cache)
361  penalty_cache.append(n_penalties)
362 
363  # cache for model contacts towards non mapped trg chains - this is
364  # relevant for self._lddt_pli_unmapped_chain_penalty
365  # key: tuple in form (trg_ch, mdl_ch)
366  # value: yet another dict with
367  # key: ligand_atom_hash
368  # value: n contacts towards respective trg chain that can be mapped
369  non_mapped_cache = dict()
370 
371 
374 
375  best_score = -1.0
376  best_result = {"lddt_pli": None,
377  "lddt_pli_n_contacts": 0}
378 
379  # dummy alignment for ligand chains which is needed as input later on
380  ligand_aln = seq.CreateAlignment()
381  trg_s = seq.CreateSequence(trg_ligand_chain.name,
382  trg_ligand_res.GetOneLetterCode())
383  mdl_s = seq.CreateSequence(mdl_ligand_chain.name,
384  mdl_ligand_res.GetOneLetterCode())
385  ligand_aln.AddSequence(trg_s)
386  ligand_aln.AddSequence(mdl_s)
387  ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms))
388 
389  sym_idx_collector = [None] * scorer.n_atoms
390  sym_dist_collector = [None] * scorer.n_atoms
391 
392  for mapping, s_cache, p_cache in zip(chain_mappings, scoring_cache,
393  penalty_cache):
394 
395  lddt_chain_mapping = dict()
396  lddt_alns = dict()
397  for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
398  for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
399  # some mdl chains can be None
400  if mdl_ch is not None:
401  lddt_chain_mapping[mdl_ch] = ref_ch
402  lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]
403 
404  # add ligand to lddt_chain_mapping/lddt_alns
405  lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
406  lddt_alns[mdl_ligand_chain.name] = ligand_aln
407 
408  # already process model, positions will be manually hacked for each
409  # symmetry - small overhead for variables that are thrown away here
410  pos, _, _, _, _, _, lddt_symmetries = \
411  scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
412  residue_mapping = lddt_alns,
413  thresholds = self.lddt_pli_thresholdslddt_pli_thresholds,
414  check_resnames = False)
415 
416  # estimate a penalty for unsatisfied model contacts from chains
417  # that are not in the local trg binding site, but can be mapped in
418  # the target.
419  # We're using the trg chain with the closest geometric center to
420  # the trg binding site that can be mapped to the mdl chain
421  # according the chem mapping. An alternative would be to search for
422  # the target chain with the minimal number of additional contacts.
423  # There is not good solution for this problem...
424  unmapped_chains = list()
425  already_mapped = set()
426  for mdl_ch in mdl_chains:
427  if mdl_ch not in lddt_chain_mapping:
428  # check which chain in trg is closest
429  chem_grp_idx = None
430  for i, m in enumerate(self._chem_mapping_chem_mapping):
431  if mdl_ch in m:
432  chem_grp_idx = i
433  break
434  if chem_grp_idx is None:
435  raise RuntimeError("This should never happen... "
436  "ask Gabriel...")
437  closest_ch = None
438  closest_dist = None
439  for trg_ch in self._chain_mapper_chain_mapper.chem_groups[chem_grp_idx]:
440  if trg_ch not in lddt_chain_mapping.values():
441  if trg_ch not in already_mapped:
442  ch = self._chain_mapper_chain_mapper.target.FindChain(trg_ch)
443  c = ch.geometric_center
444  d = geom.Distance(trg_bs_center, c)
445  if closest_dist is None or d < closest_dist:
446  closest_dist = d
447  closest_ch = trg_ch
448  if closest_ch is not None:
449  unmapped_chains.append((closest_ch, mdl_ch))
450  already_mapped.add(closest_ch)
451 
452  for (trg_sym, mdl_sym) in symmetries:
453 
454  # update positions
455  for mdl_i, trg_i in zip(mdl_sym, trg_sym):
456  pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]
457 
458  # start new ref_indices/ref_distances from original values
459  funky_ref_indices = [np.copy(a) for a in ref_indices]
460  funky_ref_distances = [np.copy(a) for a in ref_distances]
461 
462  # The only distances from the binding site towards the ligand
463  # we care about are the ones from the symmetric atoms to
464  # correctly compute scorer._ResolveSymmetries.
465  # We collect them while updating distances from added mdl
466  # contacts
467  for idx in symmetric_atoms:
468  sym_idx_collector[idx] = list()
469  sym_dist_collector[idx] = list()
470 
471  # add data from added mdl contacts cache
472  added_penalty = 0
473  for mdl_i, trg_i in zip(mdl_sym, trg_sym):
474  added_penalty += p_cache[mdl_i]
475  cache = s_cache[mdl_i, trg_i]
476  full_trg_i = ligand_start_idx + trg_i
477  funky_ref_indices[full_trg_i] = \
478  np.append(funky_ref_indices[full_trg_i], cache[0])
479  funky_ref_distances[full_trg_i] = \
480  np.append(funky_ref_distances[full_trg_i], cache[1])
481  for idx, d in zip(cache[2], cache[3]):
482  sym_idx_collector[idx].append(full_trg_i)
483  sym_dist_collector[idx].append(d)
484 
485  for idx in symmetric_atoms:
486  funky_ref_indices[idx] = \
487  np.append(funky_ref_indices[idx],
488  np.asarray(sym_idx_collector[idx],
489  dtype=np.int64))
490  funky_ref_distances[idx] = \
491  np.append(funky_ref_distances[idx],
492  np.asarray(sym_dist_collector[idx],
493  dtype=np.float64))
494 
495  # we can pass funky_ref_indices/funky_ref_distances as
496  # sym_ref_indices/sym_ref_distances in
497  # scorer._ResolveSymmetries as we only have distances of the bs
498  # to the ligand and ligand atoms are "non-symmetric"
499  scorer._ResolveSymmetries(pos, self.lddt_pli_thresholdslddt_pli_thresholds,
500  lddt_symmetries,
501  funky_ref_indices,
502  funky_ref_distances)
503 
504  N = sum([len(funky_ref_indices[i]) for i in ligand_at_indices])
505  N += added_penalty
506 
507  # collect number of expected contacts which can be mapped
508  if len(unmapped_chains) > 0:
509  N += self._lddt_pli_unmapped_chain_penalty_lddt_pli_unmapped_chain_penalty(unmapped_chains,
510  non_mapped_cache,
511  mdl_bs,
512  mdl_ligand_res,
513  mdl_sym)
514 
515  conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
516  self.lddt_pli_thresholdslddt_pli_thresholds,
517  funky_ref_indices,
518  funky_ref_distances),
519  axis=0)
520  score = None
521  if N > 0:
522  score = np.mean(conserved/N)
523 
524  if score is not None and score > best_score:
525  best_score = score
526  best_result = {"lddt_pli": score,
527  "lddt_pli_n_contacts": N}
528 
529  # fill misc info to result object
530  best_result["target_ligand"] = target_ligand
531  best_result["model_ligand"] = model_ligand
532  best_result["bs_ref_res"] = trg_residues
533  best_result["bs_mdl_res"] = mdl_residues
534 
535  return best_result
536 
537 
538  def _compute_lddt_pli_classic(self, symmetries, target_ligand,
539  model_ligand):
540 
541 
544 
545  max_r = None
546  if self.lddt_pli_binding_site_radiuslddt_pli_binding_site_radius:
547  max_r = self.lddt_pli_binding_site_radiuslddt_pli_binding_site_radius
548 
549  trg_residues, trg_bs, trg_chains, trg_ligand_chain, \
550  trg_ligand_res, scorer, chem_groups = \
551  self._lddt_pli_get_trg_data_lddt_pli_get_trg_data(target_ligand, max_r = max_r)
552 
553  # Copy to make sure that we don't change anything on underlying
554  # references
555  # This is not strictly necessary in the current implementation but
556  # hey, maybe it avoids hard to debug errors when someone changes things
557  ref_indices = [a.copy() for a in scorer.ref_indices_ic]
558  ref_distances = [a.copy() for a in scorer.ref_distances_ic]
559 
560  # no matter what mapping/symmetries, the number of expected
561  # contacts stays the same
562  ligand_start_idx = scorer.chain_start_indices[-1]
563  ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms))
564  n_exp = sum([len(ref_indices[i]) for i in ligand_at_indices])
565 
566  mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res, \
567  chem_mapping = self._lddt_pli_get_mdl_data_lddt_pli_get_mdl_data(model_ligand)
568 
569  if n_exp == 0:
570  # no contacts... nothing to compute...
571  return {"lddt_pli": None,
572  "lddt_pli_n_contacts": 0,
573  "target_ligand": target_ligand,
574  "model_ligand": model_ligand,
575  "bs_ref_res": trg_residues,
576  "bs_mdl_res": mdl_residues}
577 
578  # Distance hacking... remove any interchain distance except the ones
579  # with the ligand
580  for at_idx in range(ligand_start_idx):
581  mask = ref_indices[at_idx] >= ligand_start_idx
582  ref_indices[at_idx] = ref_indices[at_idx][mask]
583  ref_distances[at_idx] = ref_distances[at_idx][mask]
584 
585 
588 
589  # ref_mdl_alns refers to full chain mapper trg and mdl structures
590  # => need to adapt mdl sequence that only contain residues in contact
591  # with ligand
592  cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns_lddt_pli_cut_ref_mdl_alns(chem_groups,
593  chem_mapping,
594  mdl_bs, trg_bs)
595 
596 
599 
600  best_score = -1.0
601 
602  # dummy alignment for ligand chains which is needed as input later on
603  l_aln = seq.CreateAlignment()
604  l_aln.AddSequence(seq.CreateSequence(trg_ligand_chain.name,
605  trg_ligand_res.GetOneLetterCode()))
606  l_aln.AddSequence(seq.CreateSequence(mdl_ligand_chain.name,
607  mdl_ligand_res.GetOneLetterCode()))
608 
609  mdl_ligand_pos = np.zeros((model_ligand.GetAtomCount(), 3))
610  for a_idx, a in enumerate(model_ligand.atoms):
611  p = a.GetPos()
612  mdl_ligand_pos[a_idx, 0] = p[0]
613  mdl_ligand_pos[a_idx, 1] = p[1]
614  mdl_ligand_pos[a_idx, 2] = p[2]
615 
616  for mapping in chain_mapping._ChainMappings(chem_groups, chem_mapping):
617 
618  lddt_chain_mapping = dict()
619  lddt_alns = dict()
620  for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
621  for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
622  # some mdl chains can be None
623  if mdl_ch is not None:
624  lddt_chain_mapping[mdl_ch] = ref_ch
625  lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]
626 
627  # add ligand to lddt_chain_mapping/lddt_alns
628  lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
629  lddt_alns[mdl_ligand_chain.name] = l_aln
630 
631  # already process model, positions will be manually hacked for each
632  # symmetry - small overhead for variables that are thrown away here
633  pos, _, _, _, _, _, lddt_symmetries = \
634  scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
635  residue_mapping = lddt_alns,
636  thresholds = self.lddt_pli_thresholdslddt_pli_thresholds,
637  check_resnames = False)
638 
639  for (trg_sym, mdl_sym) in symmetries:
640  for mdl_i, trg_i in zip(mdl_sym, trg_sym):
641  pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]
642  # we can pass ref_indices/ref_distances as
643  # sym_ref_indices/sym_ref_distances in
644  # scorer._ResolveSymmetries as we only have distances of the bs
645  # to the ligand and ligand atoms are "non-symmetric"
646  scorer._ResolveSymmetries(pos, self.lddt_pli_thresholdslddt_pli_thresholds,
647  lddt_symmetries,
648  ref_indices,
649  ref_distances)
650  # compute number of conserved distances for ligand atoms
651  conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
652  self.lddt_pli_thresholdslddt_pli_thresholds,
653  ref_indices,
654  ref_distances), axis=0)
655  score = np.mean(conserved/n_exp)
656 
657  if score > best_score:
658  best_score = score
659 
660  # fill misc info to result object
661  best_result = {"lddt_pli": best_score,
662  "lddt_pli_n_contacts": n_exp,
663  "target_ligand": target_ligand,
664  "model_ligand": model_ligand,
665  "bs_ref_res": trg_residues,
666  "bs_mdl_res": mdl_residues}
667 
668  return best_result
669 
670  def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains,
671  non_mapped_cache,
672  mdl_bs,
673  mdl_ligand_res,
674  mdl_sym):
675 
676  n_exp = 0
677  for ch_tuple in unmapped_chains:
678  if ch_tuple not in non_mapped_cache:
679  # for each ligand atom, we count the number of mappable atoms
680  # within lddt_pli_radius
681  counts = dict()
682  # the select statement also excludes the ligand in mdl_bs
683  # as it resides in a separate chain
684  mdl_cname = ch_tuple[1]
685  query = "cname=" + mol.QueryQuoteName(mdl_cname)
686  mdl_bs_ch = mdl_bs.Select(query)
687  for a in mdl_ligand_res.atoms:
688  close_atoms = \
689  mdl_bs_ch.FindWithin(a.GetPos(), self.lddt_pli_radiuslddt_pli_radius)
690  N = 0
691  for close_a in close_atoms:
692  at_key = (close_a.GetResidue().GetNumber(),
693  close_a.GetName())
694  if at_key in self._mappable_atoms_mappable_atoms[ch_tuple]:
695  N += 1
696  counts[a.hash_code] = N
697 
698  # fill cache
699  non_mapped_cache[ch_tuple] = counts
700 
701  # add number of mdl contacts which can be mapped to target
702  # as non-fulfilled contacts
703  counts = non_mapped_cache[ch_tuple]
704  lig_hash_codes = [a.hash_code for a in mdl_ligand_res.atoms]
705  for i in mdl_sym:
706  n_exp += counts[lig_hash_codes[i]]
707 
708  return n_exp
709 
710 
711  def _lddt_pli_get_mdl_data(self, model_ligand):
712  if model_ligand not in self._lddt_pli_model_data_lddt_pli_model_data:
713 
714  mdl = self._chain_mapping_mdl_chain_mapping_mdl
715 
716  mdl_residues = set()
717  for at in model_ligand.atoms:
718  close_atoms = mdl.FindWithin(at.GetPos(), self.lddt_pli_radiuslddt_pli_radius)
719  for close_at in close_atoms:
720  mdl_residues.add(close_at.GetResidue())
721 
722  max_r = self.lddt_pli_radiuslddt_pli_radius + max(self.lddt_pli_thresholdslddt_pli_thresholds)
723  for r in mdl.residues:
724  r.SetIntProp("bs", 0)
725  for at in model_ligand.atoms:
726  close_atoms = mdl.FindWithin(at.GetPos(), max_r)
727  for close_at in close_atoms:
728  close_at.GetResidue().SetIntProp("bs", 1)
729 
730  mdl_bs = mol.CreateEntityFromView(mdl.Select("grbs:0=1"), True)
731  mdl_chains = set([ch.name for ch in mdl_bs.chains])
732 
733  mdl_editor = mdl_bs.EditXCS(mol.BUFFERED_EDIT)
734  mdl_ligand_chain = None
735  for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
736  try:
737  # I'm pretty sure, one of these chain names is not there...
738  mdl_ligand_chain = mdl_editor.InsertChain(cname)
739  break
740  except:
741  pass
742  if mdl_ligand_chain is None:
743  raise RuntimeError("Fuck this, I'm out...")
744  mdl_ligand_res = mdl_editor.AppendResidue(mdl_ligand_chain,
745  model_ligand,
746  deep=True)
747  mdl_editor.RenameResidue(mdl_ligand_res, "LIG")
748  mdl_editor.SetResidueNumber(mdl_ligand_res, mol.ResNum(1))
749 
750  chem_mapping = list()
751  for m in self._chem_mapping_chem_mapping:
752  chem_mapping.append([x for x in m if x in mdl_chains])
753 
754  self._lddt_pli_model_data_lddt_pli_model_data[model_ligand] = (mdl_residues,
755  mdl_bs,
756  mdl_chains,
757  mdl_ligand_chain,
758  mdl_ligand_res,
759  chem_mapping)
760 
761  return self._lddt_pli_model_data_lddt_pli_model_data[model_ligand]
762 
763 
764  def _lddt_pli_get_trg_data(self, target_ligand, max_r = None):
765  if target_ligand not in self._lddt_pli_target_data_lddt_pli_target_data:
766 
767  trg = self._chain_mapper_chain_mapper.target
768 
769  if max_r is None:
770  max_r = self.lddt_pli_radiuslddt_pli_radius + max(self.lddt_pli_thresholdslddt_pli_thresholds)
771 
772  trg_residues = set()
773  for at in target_ligand.atoms:
774  close_atoms = trg.FindWithin(at.GetPos(), max_r)
775  for close_at in close_atoms:
776  trg_residues.add(close_at.GetResidue())
777 
778  for r in trg.residues:
779  r.SetIntProp("bs", 0)
780 
781  for r in trg_residues:
782  r.SetIntProp("bs", 1)
783 
784  trg_bs = mol.CreateEntityFromView(trg.Select("grbs:0=1"), True)
785  trg_chains = set([ch.name for ch in trg_bs.chains])
786 
787  trg_editor = trg_bs.EditXCS(mol.BUFFERED_EDIT)
788  trg_ligand_chain = None
789  for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
790  try:
791  # I'm pretty sure, one of these chain names is not there yet
792  trg_ligand_chain = trg_editor.InsertChain(cname)
793  break
794  except:
795  pass
796  if trg_ligand_chain is None:
797  raise RuntimeError("Fuck this, I'm out...")
798 
799  trg_ligand_res = trg_editor.AppendResidue(trg_ligand_chain,
800  target_ligand,
801  deep=True)
802  trg_editor.RenameResidue(trg_ligand_res, "LIG")
803  trg_editor.SetResidueNumber(trg_ligand_res, mol.ResNum(1))
804 
805  compound_name = trg_ligand_res.name
806  compound = lddt.CustomCompound.FromResidue(trg_ligand_res)
807  custom_compounds = {compound_name: compound}
808 
809  scorer = lddt.lDDTScorer(trg_bs,
810  custom_compounds = custom_compounds,
811  inclusion_radius = self.lddt_pli_radiuslddt_pli_radius)
812 
813  chem_groups = list()
814  for g in self._chain_mapper_chain_mapper.chem_groups:
815  chem_groups.append([x for x in g if x in trg_chains])
816 
817  self._lddt_pli_target_data_lddt_pli_target_data[target_ligand] = (trg_residues,
818  trg_bs,
819  trg_chains,
820  trg_ligand_chain,
821  trg_ligand_res,
822  scorer,
823  chem_groups)
824 
825  return self._lddt_pli_target_data_lddt_pli_target_data[target_ligand]
826 
827 
def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs,
                               ref_bs):
    """ Restrict reference/model alignments to binding site residues

    For every target chain name in *chem_groups* and every model chain name
    in the parallel entry of *chem_mapping* (the two lists are zipped), the
    stored alignment self._ref_mdl_alns[(ref_ch, mdl_ch)] is used to build a
    new two-sequence alignment of the same length. A column keeps its
    character in the first (target) sequence only if the aligned target
    residue has a residue with the same number in the matching chain of
    *ref_bs*. The same rule applies to the second (model) sequence with
    respect to *mdl_bs*. All other positions become gaps ('-').

    Side effect: a view of the target chain and a view of the model chain
    are attached to the stored alignments in self._ref_mdl_alns.

    :param chem_groups: Target chain names, grouped per chem group
    :param chem_mapping: Model chain names per chem group, parallel to
                         *chem_groups*
    :param mdl_bs: Model binding site. NOTE(review): chains are looked up
                   without a validity check, so callers presumably restrict
                   *chem_mapping* to chains present in *mdl_bs*.
    :param ref_bs: Target binding site
    :returns: :class:`dict` with (ref_ch, mdl_ch) tuples as keys and the cut
              alignments as values. The sequences are named after the
              chains.
    """

    def _keep(col, idx, bs_chain):
        # True if row idx of the column maps to a residue that also exists
        # (same residue number) in the given binding site chain
        res = col.GetResidue(idx)
        if not res.IsValid():
            return False
        return bs_chain.FindResidue(res.GetNumber()).IsValid()

    result = dict()
    for trg_group, mdl_group in zip(chem_groups, chem_mapping):
        for trg_cname in trg_group:
            trg_bs_chain = ref_bs.FindChain(trg_cname)
            trg_view = self._chain_mapper.target.Select(
                "cname=" + mol.QueryQuoteName(trg_cname))

            for mdl_cname in mdl_group:
                full_aln = self._ref_mdl_alns[(trg_cname, mdl_cname)]
                full_aln.AttachView(0, trg_view)
                mdl_bs_chain = mdl_bs.FindChain(mdl_cname)
                mdl_view = self._chain_mapping_mdl.Select(
                    "cname=" + mol.QueryQuoteName(mdl_cname))
                full_aln.AttachView(1, mdl_view)

                # gap-initialized buffers, filled column by column
                n_cols = full_aln.GetLength()
                trg_chars = ['-'] * n_cols
                mdl_chars = ['-'] * n_cols
                for col_idx, col in enumerate(full_aln):
                    if _keep(col, 0, trg_bs_chain):
                        trg_chars[col_idx] = col[0]
                    if _keep(col, 1, mdl_bs_chain):
                        mdl_chars[col_idx] = col[1]

                cut_aln = seq.CreateAlignment()
                cut_aln.AddSequence(seq.CreateSequence(trg_cname,
                                                       ''.join(trg_chars)))
                cut_aln.AddSequence(seq.CreateSequence(mdl_cname,
                                                       ''.join(mdl_chars)))
                result[(trg_cname, mdl_cname)] = cut_aln
    return result
872 
@property
def _mappable_atoms(self):
    """ Stores mappable atoms given a chain mapping

    Store for each ref_ch,mdl_ch pair all mdl atoms that can be
    mapped. Don't store mappable atoms as hashes but rather as tuple
    (mdl_r.GetNumber(), mdl_a.GetName()). Reason for that is that one might
    operate on Copied EntityHandle objects without corresponding hashes.
    Given a tuple defining c_pair: (ref_cname, mdl_cname), one
    can check if a certain atom is mappable by evaluating:
    if (mdl_r.GetNumber(), mdl_a.GetName()) in self._mappable_atoms[c_pair]

    A model atom counts as mappable if its residue is aligned (in
    self._ref_mdl_alns) to a target residue that has an atom of the same
    name. Every key of self._ref_mdl_alns gets an entry, possibly an empty
    set.

    The result is computed once and cached. The cache is only stored after
    it has been fully built. An exception raised during construction
    therefore does not leave a partially filled dict behind for later
    accesses.
    """
    if self.__mappable_atoms is None:
        mappable = dict()
        for c_key, aln in self._ref_mdl_alns.items():
            ref_cname, mdl_cname = c_key
            at_keys = set()
            ref_query = f"cname={mol.QueryQuoteName(ref_cname)}"
            mdl_query = f"cname={mol.QueryQuoteName(mdl_cname)}"
            ref_ch = self._chain_mapper.target.Select(ref_query)
            mdl_ch = self._chain_mapping_mdl.Select(mdl_query)
            # attach chain views so col.GetResidue() below can resolve
            # alignment columns to residues
            aln.AttachView(0, ref_ch)
            aln.AttachView(1, mdl_ch)
            for col in aln:
                ref_r = col.GetResidue(0)
                mdl_r = col.GetResidue(1)
                if not (ref_r.IsValid() and mdl_r.IsValid()):
                    continue
                mdl_rnum = mdl_r.GetNumber()
                for mdl_a in mdl_r.atoms:
                    if ref_r.FindAtom(mdl_a.name).IsValid():
                        at_keys.add((mdl_rnum, mdl_a.name))
            mappable[c_key] = at_keys
        # publish only the complete result
        self.__mappable_atoms = mappable

    return self.__mappable_atoms
906 
@property
def _chem_mapping(self):
    """ Chem mapping of the model, lazily computed

    First element of self._chain_mapper.GetChemMapping(self.model). The
    second and third elements of that result are cached alongside for
    _chem_group_alns and _chain_mapping_mdl.

    The call is wrapped in ligand_scoring_base._SinkVerbosityLevel(), as
    _chain_mapping_mdl already does. Its effect on logging therefore no
    longer depends on which of the three lazily computed properties happens
    to be accessed first.
    """
    if self.__chem_mapping is None:
        with ligand_scoring_base._SinkVerbosityLevel():
            self.__chem_mapping, self.__chem_group_alns, \
            self.__chain_mapping_mdl = \
            self._chain_mapper.GetChemMapping(self.model)
    return self.__chem_mapping
914 
@property
def _chem_group_alns(self):
    """ Chem group alignments of the model, lazily computed

    Second element of self._chain_mapper.GetChemMapping(self.model). The
    first and third elements of that result are cached alongside for
    _chem_mapping and _chain_mapping_mdl.

    The call is wrapped in ligand_scoring_base._SinkVerbosityLevel(), as
    _chain_mapping_mdl already does. Its effect on logging therefore no
    longer depends on which of the three lazily computed properties happens
    to be accessed first.
    """
    if self.__chem_group_alns is None:
        with ligand_scoring_base._SinkVerbosityLevel():
            self.__chem_mapping, self.__chem_group_alns, \
            self.__chain_mapping_mdl = \
            self._chain_mapper.GetChemMapping(self.model)
    return self.__chem_group_alns
922 
@property
def _ref_mdl_alns(self):
    """ Reference/model chain alignments, lazily computed and cached

    Built once by chain_mapping._GetRefMdlAlns from the chain mapper's chem
    groups and chem group alignments together with self._chem_mapping and
    self._chem_group_alns. Callers index the result with
    (ref_cname, mdl_cname) tuples.
    """
    if self.__ref_mdl_alns is not None:
        return self.__ref_mdl_alns

    self.__ref_mdl_alns = chain_mapping._GetRefMdlAlns(
        self._chain_mapper.chem_groups,
        self._chain_mapper.chem_group_alignments,
        self._chem_mapping,
        self._chem_group_alns)
    return self.__ref_mdl_alns
932 
@property
def _chain_mapping_mdl(self):
    """ Model as returned by the chain mapper's chem mapping, cached

    Third element of self._chain_mapper.GetChemMapping(self.model). The
    first two elements are cached alongside for _chem_mapping and
    _chem_group_alns. The call is wrapped in
    ligand_scoring_base._SinkVerbosityLevel().
    """
    if self.__chain_mapping_mdl is not None:
        return self.__chain_mapping_mdl

    with ligand_scoring_base._SinkVerbosityLevel():
        (self.__chem_mapping,
         self.__chem_group_alns,
         self.__chain_mapping_mdl) = \
            self._chain_mapper.GetChemMapping(self.model)
    return self.__chain_mapping_mdl
941 
# Public interface of this module: only the scorer class is exported
__all__ = ("LDDTPLIScorer",)
def _lddt_pli_get_trg_data(self, target_ligand, max_r=None)
def _compute_lddt_pli_classic(self, symmetries, target_ligand, model_ligand)
def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains, non_mapped_cache, mdl_bs, mdl_ligand_res, mdl_sym)
def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs, ref_bs)
def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand, model_ligand)
def __init__(self, model, target, model_ligands, target_ligands, resnum_alignments=False, rename_ligand_chain=False, substructure_match=False, coverage_delta=0.2, max_symmetries=1e4, lddt_pli_radius=6.0, add_mdl_contacts=True, lddt_pli_thresholds=[0.5, 1.0, 2.0, 4.0], lddt_pli_binding_site_radius=None)
Real DLLEXPORT_OST_GEOM Distance(const Line2 &l, const Vec2 &v)