OpenStructure
 All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
lddt.py
Go to the documentation of this file.
1 import numpy as np
2 
3 from ost import mol
4 from ost import conop
5 
6 
class CustomCompound:
    """ Defines atoms for custom compounds

    lDDT requires the reference atoms of a compound which are typically
    extracted from a :class:`ost.conop.CompoundLib`. This lightweight
    container allows handling arbitrary compounds which are not
    necessarily in the compound library.

    :param atom_names: Names of atoms of custom compound
    :type atom_names: :class:`list` of :class:`str`
    """
    def __init__(self, atom_names):
        # Store a copy so that later modifications of the caller's list do not
        # silently change the reference atoms of this compound.
        self.atom_names = list(atom_names)

    @classmethod
    def FromResidue(cls, res):
        """ Construct custom compound from residue

        :param res: Residue from which reference atom names are extracted,
                    hydrogen/deuterium atoms are filtered out
        :type res: :class:`ost.mol.ResidueView`/:class:`ost.mol.ResidueHandle`
        :returns: :class:`CustomCompound` (or an instance of the subclass this
                  is called on)
        :raises: :class:`RuntimeError` if *res* contains several non-hydrogen
                 atoms with the same name
        """
        at_names = [a.name for a in res.atoms if a.element not in ("H", "D")]
        if len(at_names) != len(set(at_names)):
            raise RuntimeError("Duplicate atoms detected in CustomCompound")
        return cls(at_names)
35 
class SymmetrySettings:
    """Container for symmetric compounds

    lDDT considers symmetries and selects the one resulting in the highest
    possible score.

    A symmetry is defined as a renaming operation on one or more atoms that
    leads to a chemically equivalent residue. Example would be OD1 and OD2 in
    ASP => renaming OD1 to OD2 and vice versa gives a chemically equivalent
    residue.

    Use :func:`AddSymmetricCompound` to define a symmetry which can then
    directly be accessed through the *symmetric_compounds* member.
    """
    def __init__(self):
        # compound name => list of atom name pairs defining the symmetry
        self.symmetric_compounds = dict()

    def AddSymmetricCompound(self, name, symmetric_atoms):
        """Adds symmetry for compound with *name*

        :param name: Name of compound with symmetry
        :type name: :class:`str`
        :param symmetric_atoms: Pairs of atom names that define renaming
                                operation, i.e. after applying all switches
                                defined in the tuples, the resulting residue
                                should be chemically equivalent. Atom names
                                must refer to the PDB component dictionary.
                                Any iterable of pairs is accepted; it is
                                stored as a new :class:`list`.
        :type symmetric_atoms: :class:`list` of :class:`tuple`
        :raises: :class:`RuntimeError` if any element is not a pair. Nothing
                 is stored in that case.
        """
        # Materialize before validating: the pairs are iterated here and again
        # later when compounds get set up. A one-shot iterator would otherwise
        # be stored exhausted, and a caller-owned list would stay aliased.
        symmetric_atoms = list(symmetric_atoms)
        for pair in symmetric_atoms:
            if len(pair) != 2:
                raise RuntimeError("Expect pairs when defining symmetries")
        self.symmetric_compounds[name] = symmetric_atoms
69 
70 
def GetDefaultSymmetrySettings():
    """Constructs and returns :class:`SymmetrySettings` object for natural amino
    acids

    Defined symmetries: ASP (OD1/OD2), GLU (OE1/OE2), LEU (CD1/CD2),
    VAL (CG1/CG2), ARG (NH1/NH2), PHE and TYR (CD1/CD2 together with CE1/CE2).

    :returns: :class:`SymmetrySettings`
    """
    default_symmetries = [
        ("ASP", [("OD1", "OD2")]),
        ("GLU", [("OE1", "OE2")]),
        ("LEU", [("CD1", "CD2")]),
        ("VAL", [("CG1", "CG2")]),
        ("ARG", [("NH1", "NH2")]),
        # for the aromatic rings all listed pairs are switched together
        ("PHE", [("CD1", "CD2"), ("CE1", "CE2")]),
        ("TYR", [("CD1", "CD2"), ("CE1", "CE2")]),
    ]
    symmetry_settings = SymmetrySettings()
    for compound_name, atom_pairs in default_symmetries:
        symmetry_settings.AddSymmetricCompound(compound_name, atom_pairs)
    return symmetry_settings
103 
104 
106  """lDDT scorer object for a specific target
107 
108  Sets up everything to score models of that target. lDDT (local distance
109  difference test) is defined as fraction of pairwise distances which exhibit
110  a difference < threshold when considering target and model. In case of
111  multiple thresholds, the average is returned. See
112 
113  V. Mariani, M. Biasini, A. Barbato, T. Schwede, lDDT : A local
114  superposition-free score for comparing protein structures and models using
115  distance difference tests, Bioinformatics, 2013
116 
117  :param target: The target
118  :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
119  :param compound_lib: Compound library from which a compound for each residue
120  is extracted based on its name. Uses
121  :func:`ost.conop.GetDefaultLib` if not given, raises
122  if this returns no valid compound library. Atoms
123  defined in the compound are searched in the residue and
124  build the reference for scoring. If the residue has
125  atoms with names ["A", "B", "C"] but the corresponding
126  compound only has ["A", "B"], "A" and "B" are
127  considered for scoring. If the residue has atoms
128  ["A", "B"] but the compound has ["A", "B", "C"], "C" is
129  considered missing and does not influence scoring, even
130  if present in the model.
131  :param custom_compounds: Custom compounds defining reference atoms. If
132  given, *custom_compounds* take precedence over
133  *compound_lib*.
134  :type custom_compounds: :class:`dict` with residue names (:class:`str`) as
135  key and :class:`CustomCompound` as value.
136  :type compound_lib: :class:`ost.conop.CompoundLib`
137  :param inclusion_radius: All pairwise distances < *inclusion_radius* are
138  considered for scoring
139  :type inclusion_radius: :class:`float`
140  :param sequence_separation: Only pairwise distances between atoms of
141  residues which are further apart than this
142  threshold are considered. Residue distance is
143  based on resnum. The default (0) considers all
144  pairwise distances except intra-residue
145  distances.
146  :type sequence_separation: :class:`int`
147  :param symmetry_settings: Define residues exhibiting internal symmetry, uses
148  :func:`GetDefaultSymmetrySettings` if not given.
149  :type symmetry_settings: :class:`SymmetrySettings`
150  :param seqres_mapping: Mapping of model residues at the scoring stage
151  happens with residue numbers defining their location
152  in a reference sequence (SEQRES) using one based
153  indexing. If the residue numbers in *target* don't
154  correspond to that SEQRES, you can specify the
155  mapping manually. You can provide a dictionary to
156  specify a reference sequence (SEQRES) for one or more
157  chain(s). Key: chain name, value: alignment
158  (seq1: SEQRES, seq2: sequence of residues in chain).
159  Example: The residues in a chain with name "A" have
160  sequence "YEAH" and residue numbers [42,43,44,45].
161  You can provide an alignment with seq1 "``HELLYEAH``"
162  and seq2 "``----YEAH``". "Y" gets assigned residue
163  number 5, "E" gets assigned 6 and so on no matter
164  what the original residue numbers were.
165  :type seqres_mapping: :class:`dict` (key: :class:`str`, value:
166  :class:`ost.seq.AlignmentHandle`)
167  :param bb_only: Only consider atoms with name "CA" in case of amino acids and
168  "C3'" for nucleotides. This invalidates *compound_lib*.
169  Raises if any residue in *target* is not
170  `r.chem_class.IsPeptideLinking()` or
171  `r.chem_class.IsNucleotideLinking()`
172  :type bb_only: :class:`bool`
173  :raises: :class:`RuntimeError` if *target* contains compound which is not in
174  *compound_lib*, :class:`RuntimeError` if *symmetry_settings*
175  specifies symmetric atoms that are not present in the according
176  compound in *compound_lib*, :class:`RuntimeError` if
177  *seqres_mapping* is not provided and *target* contains residue
178  numbers with insertion codes or the residue numbers for each chain
179  are not monotonically increasing, :class:`RuntimeError` if
180  *seqres_mapping* is provided but an alignment is invalid
181  (seq1 contains gaps, mismatch in seq1/seq2, seq2 does not match
182  residues in corresponding chains).
183  """
184  def __init__(
185  self,
186  target,
187  compound_lib=None,
188  custom_compounds=None,
189  inclusion_radius=15,
190  sequence_separation=0,
191  symmetry_settings=None,
192  seqres_mapping=dict(),
193  bb_only=False
194  ):
195 
196  self.target = target
197  self.inclusion_radius = inclusion_radius
198  self.sequence_separation = sequence_separation
199  if compound_lib is None:
200  compound_lib = conop.GetDefaultLib()
201  if compound_lib is None:
202  raise RuntimeError("No compound_lib given and conop.GetDefaultLib "
203  "returns no valid compound library")
204  self.compound_lib = compound_lib
205  self.custom_compounds = custom_compounds
206  if symmetry_settings is None:
208  else:
209  self.symmetry_settings = symmetry_settings
210 
211  # whether to only consider atoms with name "CA" (amino acids) or C3'
212  # (nucleotides), invalidates *compound_lib*
213  self.bb_only=bb_only
214 
215  # names of heavy atoms of each unique compound present in *target* as
216  # extracted from *compound_lib*, e.g.
217  # self.compound_anames["GLY"] = ["N", "CA", "C", "O"]
218  self.compound_anames = dict()
219 
220  # stores symmetry information for those compounds as defined in
221  # *symmetry_settings*
223 
224  # list of len(target.chains) containing all chain names in *target*
225  self.chain_names = list()
226 
227  # list of len(target.residues) containing all compound names in *target*
228  self.compound_names = list()
229 
230  # list of len(target.residues) defining start pos in internal reference
231  # positions for each residue
232  self.res_start_indices = list()
233 
234  # list of len(target.residues) defining residue numbers in target
235  self.res_resnums = list()
236 
237  # list of len(target.chains) defining start pos in internal reference
238  # positions for each chain
239  self.chain_start_indices = list()
240 
241  # list of len(target.chains) defining start pos in self.compound_names
242  # for each chain
244 
245  # maps residues in *target* to indices in
246  # self.compound_names/self.res_start_indices. A residue gets identified
247  # by a tuple (first element: chain name, second element: residue number,
248  # residue number is either the actual residue number in *target* or
249  # given by *seqres_mapping*)
250  self.res_mapper = dict()
251 
252  # number of atoms as specified in compounds. not all are necessarily
253  # covered by structure
254  self.n_atoms = None
255 
256  # stores an index for each AtomHandle in *target*
257  # (atom hashcode => index)
258  self.atom_indices = dict()
259 
260  # store indices of all atoms that have symmetry properties
261  self.symmetric_atoms = set()
262 
263  # setup members defined above
264  self._SetupEnv(self.compound_lib, self.custom_compounds,
265  self.symmetry_settings, seqres_mapping, self.bb_only)
266 
267  # distance related members are lazily computed as they're affected
268  # by different flavours of lDDT (e.g. lDDT including inter-chain
269  # contacts or not etc.)
270 
271  # stores for each atom the other atoms within inclusion_radius
272  self._ref_indices = None
273  # the corresponding distances
274  self._ref_distances = None
275 
276  # The following lists will be sparsely populated. We keep for each
277  # symmetry related atom the distances towards all atoms which are NOT
278  # affected by symmetry. So we can evaluate two symmetric versions
279  # against the fixed stuff later on and select the better scoring one.
280  self._sym_ref_indices = None
281  self._sym_ref_distances = None
282 
283  # total number of distances
284  self._n_distances = None
285 
286  # exactly the same as above but without interchain contacts
287  # => single-chain (sc)
288  self._ref_indices_sc = None
289  self._ref_distances_sc = None
290  self._sym_ref_indices_sc = None
291  self._sym_ref_distances_sc = None
292  self._n_distances_sc = None
293 
294  # exactly the same as above but without intrachain contacts
295  # => inter-chain (ic)
296  self._ref_indices_ic = None
297  self._ref_distances_ic = None
298  self._sym_ref_indices_ic = None
299  self._sym_ref_distances_ic = None
300  self._n_distances_ic = None
301 
302  # input parameter checking
304 
305  @property
306  def ref_indices(self):
307  if self._ref_indices is None:
308  self._SetupDistances()
309  return self._ref_indices
310 
311  @property
312  def ref_distances(self):
313  if self._ref_distances is None:
314  self._SetupDistances()
315  return self._ref_distances
316 
317  @property
318  def sym_ref_indices(self):
319  if self._sym_ref_indices is None:
320  self._SetupDistances()
321  return self._sym_ref_indices
322 
323  @property
324  def sym_ref_distances(self):
325  if self._sym_ref_distances is None:
326  self._SetupDistances()
327  return self._sym_ref_distances
328 
329  @property
330  def n_distances(self):
331  if self._n_distances is None:
332  self._n_distances = sum([len(x) for x in self.ref_indices])
333  return self._n_distances
334 
335  @property
336  def ref_indices_sc(self):
337  if self._ref_indices_sc is None:
338  self._SetupDistancesSC()
339  return self._ref_indices_sc
340 
341  @property
342  def ref_distances_sc(self):
343  if self._ref_distances_sc is None:
344  self._SetupDistancesSC()
345  return self._ref_distances_sc
346 
347  @property
349  if self._sym_ref_indices_sc is None:
350  self._SetupDistancesSC()
351  return self._sym_ref_indices_sc
352 
353  @property
355  if self._sym_ref_distances_sc is None:
356  self._SetupDistancesSC()
357  return self._sym_ref_distances_sc
358 
359  @property
360  def n_distances_sc(self):
361  if self._n_distances_sc is None:
362  self._n_distances_sc = sum([len(x) for x in self.ref_indices_sc])
363  return self._n_distances_sc
364 
365  @property
366  def ref_indices_ic(self):
367  if self._ref_indices_ic is None:
368  self._SetupDistancesIC()
369  return self._ref_indices_ic
370 
371  @property
372  def ref_distances_ic(self):
373  if self._ref_distances_ic is None:
374  self._SetupDistancesIC()
375  return self._ref_distances_ic
376 
377  @property
379  if self._sym_ref_indices_ic is None:
380  self._SetupDistancesIC()
381  return self._sym_ref_indices_ic
382 
383  @property
385  if self._sym_ref_distances_ic is None:
386  self._SetupDistancesIC()
387  return self._sym_ref_distances_ic
388 
389  @property
390  def n_distances_ic(self):
391  if self._n_distances_ic is None:
392  self._n_distances_ic = sum([len(x) for x in self.ref_indices_ic])
393  return self._n_distances_ic
394 
395  def lDDT(self, model, thresholds = [0.5, 1.0, 2.0, 4.0],
396  local_lddt_prop=None, local_contact_prop=None,
397  chain_mapping=None, no_interchain=False,
398  no_intrachain=False, penalize_extra_chains=False,
399  residue_mapping=None, return_dist_test=False,
400  check_resnames=True):
401  """Computes lDDT of *model* - globally and per-residue
402 
403  :param model: Model to be scored - models are preferably scored upon
404  performing stereo-chemistry checks in order to punish for
405  non-sensical irregularities. This must be done separately
406  as a pre-processing step. Target contacts that are not
407  covered by *model* are considered not conserved, thus
408  decreasing lDDT score. This also includes missing model
409  chains or model chains for which no mapping is provided in
410  *chain_mapping*.
411  :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
412  :param thresholds: Thresholds of distance differences to be considered
413  as correct - see docs in constructor for more info.
414  default: [0.5, 1.0, 2.0, 4.0]
415  :type thresholds: :class:`list` of :class:`floats`
416  :param local_lddt_prop: If set, per-residue scores will be assigned as
417  generic float property of that name
418  :type local_lddt_prop: :class:`str`
419  :param local_contact_prop: If set, number of expected contacts as well
420  as number of conserved contacts will be
421  assigned as generic int property.
422  Excected contacts will be set as
423  <local_contact_prop>_exp, conserved contacts
424  as <local_contact_prop>_cons. Values
425  are summed over all thresholds.
426  :type local_contact_prop: :class:`str`
427  :param chain_mapping: Mapping of model chains (key) onto target chains
428  (value). This is required if target or model have
429  more than one chain.
430  :type chain_mapping: :class:`dict` with :class:`str` as keys/values
431  :param no_interchain: Whether to exclude interchain contacts
432  :type no_interchain: :class:`bool`
433  :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
434  consider interface related contacts)
435  :type no_intrachain: :class:`bool`
436  :param penalize_extra_chains: Whether to include a fixed penalty for
437  additional chains in the model that are
438  not mapped to the target. ONLY AFFECTS
439  RETURNED GLOBAL SCORE. In detail: adds the
440  number of intra-chain contacts of each
441  extra chain to the expected contacts, thus
442  adding a penalty.
443  :param penalize_extra_chains: :class:`bool`
444  :param residue_mapping: By default, residue mapping is based on residue
445  numbers. That means, a model chain and the
446  respective target chain map to the same
447  underlying reference sequence (SEQRES).
448  Alternatively, you can specify one or
449  several alignment(s) between model and target
450  chains by providing a dictionary. key: Name
451  of chain in model (respective target chain is
452  extracted from *chain_mapping*),
453  value: Alignment with first sequence
454  corresponding to target chain and second
455  sequence to model chain. There is NO reference
456  sequence involved, so the two sequences MUST
457  exactly match the actual residues observed in
458  the respective target/model chains (ATOMSEQ).
459  :type residue_mapping: :class:`dict` with key: :class:`str`,
460  value: :class:`ost.seq.AlignmentHandle`
461  :param return_dist_test: Whether to additionally return the underlying
462  per-residue data for the distance difference
463  test. Adds five objects to the return tuple.
464  First: Number of total contacts summed over all
465  thresholds
466  Second: Number of conserved contacts summed
467  over all thresholds
468  Third: list with length of scored residues.
469  Contains indices referring to model.residues.
470  Fourth: numpy array of size
471  len(scored_residues) containing the number of
472  total contacts,
473  Fifth: numpy matrix of shape
474  (len(scored_residues), len(thresholds))
475  specifying how many for each threshold are
476  conserved.
477  :param check_resnames: On by default. Enforces residue name matches
478  between mapped model and target residues.
479  :type check_resnames: :class:`bool`
480  :returns: global and per-residue lDDT scores as a tuple -
481  first element is global lDDT score and second element
482  a list of per-residue scores with length len(*model*.residues)
483  None is assigned to residues that are not covered by target
484  """
485  if chain_mapping is None:
486  if len(self.chain_names) > 1 or len(model.chains) > 1:
487  raise NotImplementedError("Must provide chain mapping if "
488  "target or model have > 1 chains.")
489  chain_mapping = {model.chains[0].GetName(): self.chain_names[0]}
490  else:
491  # check whether chains specified in mapping exist
492  for model_chain, target_chain in chain_mapping.items():
493  if target_chain not in self.chain_names:
494  raise RuntimeError(f"Target chain specified in "
495  f"chain_mapping ({target_chain}) does "
496  f"not exist. Target has chains: "
497  f"{self.chain_names}")
498  ch = model.FindChain(model_chain)
499  if not ch.IsValid():
500  raise RuntimeError(f"Model chain specified in "
501  f"chain_mapping ({model_chain}) does "
502  f"not exist. Model has chains: "
503  f"{[c.GetName() for c in model.chains]}")
504 
505  # initialize positions with values far in nirvana. If a position is not
506  # set, it should be far away from any position in model.
507  max_pos = model.bounds.GetMax()
508  max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2]))
509  max_coordinate += 42 * max(thresholds)
510  pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate
511 
512  # for each scored residue in model a list of indices describing the
513  # atoms from the reference that should be there
514  res_ref_atom_indices = list()
515 
516  # for each scored residue in model a list of indices of atoms that are
517  # actually there
518  res_atom_indices = list()
519 
520  # indices of the scored residues
521  res_indices = list()
522 
523  # Will contain one element per symmetry group
524  symmetries = list()
525 
526  current_model_res_idx = -1
527  for ch in model.chains:
528  model_ch_name = ch.GetName()
529  if model_ch_name not in chain_mapping:
530  current_model_res_idx += len(ch.residues)
531  continue # additional model chain which is not mapped
532  target_ch_name = chain_mapping[model_ch_name]
533 
534  rnums = self._GetChainRNums(ch, residue_mapping, model_ch_name,
535  target_ch_name)
536 
537  for r, rnum in zip(ch.residues, rnums):
538  current_model_res_idx += 1
539  res_mapper_key = (target_ch_name, rnum)
540  if res_mapper_key not in self.res_mapper:
541  continue
542  r_idx = self.res_mapper[res_mapper_key]
543  if check_resnames and r.name != self.compound_names[r_idx]:
544  raise RuntimeError(
545  f"Residue name mismatch for {r}, "
546  f" expect {self.compound_names[r_idx]}"
547  )
548  res_start_idx = self.res_start_indices[r_idx]
549  rname = self.compound_names[r_idx]
550  anames = self.compound_anames[rname]
551  atoms = [r.FindAtom(aname) for aname in anames]
552  res_ref_atom_indices.append(
553  list(range(res_start_idx, res_start_idx + len(anames)))
554  )
555  res_atom_indices.append(list())
556  res_indices.append(current_model_res_idx)
557  for a_idx, a in enumerate(atoms):
558  if a.IsValid():
559  p = a.GetPos()
560  pos[res_start_idx + a_idx][0] = p[0]
561  pos[res_start_idx + a_idx][1] = p[1]
562  pos[res_start_idx + a_idx][2] = p[2]
563  res_atom_indices[-1].append(res_start_idx + a_idx)
564  if rname in self.compound_symmetric_atoms:
565  sym_indices = list()
566  for sym_tuple in self.compound_symmetric_atoms[rname]:
567  a_one = atoms[sym_tuple[0]]
568  a_two = atoms[sym_tuple[1]]
569  if a_one.IsValid() and a_two.IsValid():
570  sym_indices.append(
571  (
572  res_start_idx + sym_tuple[0],
573  res_start_idx + sym_tuple[1],
574  )
575  )
576  if len(sym_indices) > 0:
577  symmetries.append(sym_indices)
578 
579  if no_interchain and no_intrachain:
580  raise RuntimeError("on_interchain and no_intrachain flags are "
581  "mutually exclusive")
582 
583  if no_interchain:
584  sym_ref_indices = self.sym_ref_indices_sc
585  sym_ref_distances = self.sym_ref_distances_sc
586  ref_indices = self.ref_indices_sc
587  ref_distances = self.ref_distances_sc
588  n_distances = self.n_distances_sc
589  elif no_intrachain:
590  sym_ref_indices = self.sym_ref_indices_ic
591  sym_ref_distances = self.sym_ref_distances_ic
592  ref_indices = self.ref_indices_ic
593  ref_distances = self.ref_distances_ic
594  n_distances = self.n_distances_ic
595  else:
596  sym_ref_indices = self.sym_ref_indices
597  sym_ref_distances = self.sym_ref_distances
598  ref_indices = self.ref_indices
599  ref_distances = self.ref_distances
600  n_distances = self.n_distances
601 
602  self._ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices,
603  sym_ref_distances)
604 
605  per_res_exp = np.asarray([self._GetNExp(res_ref_atom_indices[idx],
606  ref_indices) for idx in range(len(res_indices))], dtype=np.int32)
607  per_res_conserved = self._EvalResidues(pos, thresholds,
608  res_atom_indices,
609  ref_indices, ref_distances)
610 
611  n_thresh = len(thresholds)
612 
613  # do per-residue scores
614  per_res_lDDT = [None] * len(model.residues)
615  for idx in range(len(res_indices)):
616  n_exp = n_thresh * per_res_exp[idx]
617  if n_exp > 0:
618  score = np.sum(per_res_conserved[idx,:]) / n_exp
619  per_res_lDDT[res_indices[idx]] = score
620  else:
621  per_res_lDDT[res_indices[idx]] = 0.0
622 
623 
624  # do full model score
625  if penalize_extra_chains:
626  n_distances += self._GetExtraModelChainPenalty(model, chain_mapping)
627 
628  lDDT_tot = int(n_thresh * n_distances)
629  lDDT_cons = int(np.sum(per_res_conserved))
630  lDDT = None
631  if lDDT_tot > 0:
632  lDDT = float(lDDT_cons) / lDDT_tot
633 
634  # set properties if necessary
635  if local_lddt_prop:
636  residues = model.residues
637  for idx in res_indices:
638  residues[idx].SetFloatProp(local_lddt_prop, per_res_lDDT[idx])
639 
640  if local_contact_prop:
641  residues = model.residues
642  exp_prop = local_contact_prop + "_exp"
643  conserved_prop = local_contact_prop + "_cons"
644 
645  for i, r_idx in enumerate(res_indices):
646  residues[r_idx].SetIntProp(exp_prop,
647  n_thresh * int(per_res_exp[i]))
648  residues[r_idx].SetIntProp(conserved_prop,
649  int(np.sum(per_res_conserved[i,:])))
650 
651  if return_dist_test:
652  return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \
653  per_res_exp, per_res_conserved
654  else:
655  return lDDT, per_res_lDDT
656 
657  def GetNChainContacts(self, target_chain, no_interchain=False):
658  """Returns number of contacts expected for a certain chain in *target*
659 
660  :param target_chain: Chain in *target* for which you want the number
661  of expected contacts
662  :type target_chain: :class:`str`
663  :param no_interchain: Whether to exclude interchain contacts
664  :type no_interchain: :class:`bool`
665  :raises: :class:`RuntimeError` if specified chain doesnt exist
666  """
667  if target_chain not in self.chain_names:
668  raise RuntimeError(f"Specified chain name ({target_chain}) not in "
669  f"target")
670  ch_idx = self.chain_names.index(target_chain)
671  s = self.chain_start_indices[ch_idx]
672  e = self.n_atoms
673  if ch_idx + 1 < len(self.chain_names):
674  e = self.chain_start_indices[ch_idx+1]
675  if no_interchain:
676  return self._GetNExp(list(range(s, e)), self.ref_indices_sc)
677  else:
678  return self._GetNExp(list(range(s, e)), self.ref_indices)
679 
680 
681  def _GetExtraModelChainPenalty(self, model, chain_mapping):
682  """Counts n distances in extra model chains to be added as penalty
683  """
684  penalty = 0
685  for chain in model.chains:
686  ch_name = chain.GetName()
687  if ch_name not in chain_mapping:
688  sm = self.symmetry_settings
689  mdl_sel = model.Select(f"cname={mol.QueryQuoteName(ch_name)}")
690  dummy_scorer = lDDTScorer(mdl_sel, self.compound_lib,
691  symmetry_settings = sm,
692  inclusion_radius = self.inclusion_radius,
693  bb_only = self.bb_only)
694  penalty += dummy_scorer.n_distances
695  return penalty
696 
697  def _GetChainRNums(self, ch, residue_mapping, model_ch_name,
698  target_ch_name):
699  """Map residues in model chain to target residues
700 
701  There are two options: one is simply using residue numbers,
702  the other is a custom mapping as given in *residue_mapping*
703  """
704  if residue_mapping and model_ch_name in residue_mapping:
705  # extract residue numbers from target chain
706  ch_idx = self.chain_names.index(target_ch_name)
707  start_idx = self.chain_res_start_indices[ch_idx]
708  if ch_idx < len(self.chain_names) - 1:
709  end_idx = self.chain_res_start_indices[ch_idx+1]
710  else:
711  end_idx = len(self.compound_names)
712  target_rnums = self.res_resnums[start_idx:end_idx]
713  # get sequences from alignment and do consistency checks
714  target_seq = residue_mapping[model_ch_name].GetSequence(0)
715  model_seq = residue_mapping[model_ch_name].GetSequence(1)
716  if len(target_seq.GetGaplessString()) != len(target_rnums):
717  raise RuntimeError(f"Try to perform residue mapping for "
718  f"model chain {model_ch_name} which "
719  f"maps to {target_ch_name} in target. "
720  f"Target sequence in alignment suggests "
721  f"{len(target_seq.GetGaplessString())} "
722  f"residues but {len(target_rnums)} are "
723  f"expected.")
724  if len(model_seq.GetGaplessString()) != len(ch.residues):
725  raise RuntimeError(f"Try to perform residue mapping for "
726  f"model chain {model_ch_name} which "
727  f"maps to {target_ch_name} in target. "
728  f"Model sequence in alignment suggests "
729  f"{len(model_seq.GetGaplessString())} "
730  f"residues but {len(ch.residues)} are "
731  f"expected.")
732  rnums = list()
733  target_idx = -1
734  for col in residue_mapping[model_ch_name]:
735  if col[0] != '-':
736  target_idx += 1
737  # handle match
738  if col[0] != '-' and col[1] != '-':
739  rnums.append(target_rnums[target_idx])
740  # insertion in model adds None to rnum
741  if col[0] == '-' and col[1] != '-':
742  rnums.append(None)
743  else:
744  rnums = [r.GetNumber() for r in ch.residues]
745 
746  return rnums
747 
748 
749  def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings,
750  seqres_mapping, bb_only):
751  """Sets target related lDDTScorer members defined in constructor
752 
753  No distance related members - see _SetupDistances
754  """
755  residue_numbers = self._GetTargetResidueNumbers(self.target,
756  seqres_mapping)
757  current_idx = 0
758  for chain in self.target.chains:
759  ch_name = chain.GetName()
760  self.chain_names.append(ch_name)
761  self.chain_start_indices.append(current_idx)
762  self.chain_res_start_indices.append(len(self.compound_names))
763  for r, rnum in zip(chain.residues, residue_numbers[ch_name]):
764  if r.name not in self.compound_anames:
765  # sets compound info in self.compound_anames and
766  # self.compound_symmetric_atoms
767  self._SetupCompound(r, compound_lib, custom_compounds,
768  symmetry_settings, bb_only)
769 
770  self.res_start_indices.append(current_idx)
771  self.res_mapper[(ch_name, rnum)] = len(self.compound_names)
772  self.compound_names.append(r.name)
773  self.res_resnums.append(rnum)
774 
775  atoms = [r.FindAtom(an) for an in self.compound_anames[r.name]]
776  for a in atoms:
777  if a.IsValid():
778  self.atom_indices[a.handle.GetHashCode()] = current_idx
779  current_idx += 1
780 
781  if r.name in self.compound_symmetric_atoms:
782  for sym_tuple in self.compound_symmetric_atoms[r.name]:
783  for a_idx in sym_tuple:
784  a = atoms[a_idx]
785  if a.IsValid():
786  hashcode = a.handle.GetHashCode()
787  self.symmetric_atoms.add(
788  self.atom_indices[hashcode]
789  )
790  self.n_atoms = current_idx
791 
792 
793  def _GetTargetResidueNumbers(self, target, seqres_mapping):
794  """Returns residue numbers for each chain in target as dict
795 
796  They're either directly extracted from the raw residue number
797  from the structure or from user provided alignments
798  """
799  residue_numbers = dict()
800  for ch in target.chains:
801  ch_name = ch.GetName()
802  rnums = list()
803  if ch_name in seqres_mapping:
804  seqres = seqres_mapping[ch_name].GetSequence(0).GetString()
805  atomseq = seqres_mapping[ch_name].GetSequence(1).GetString()
806  # SEQRES must not contain gaps
807  if "-" in seqres:
808  raise RuntimeError(
809  "SEQRES in seqres_mapping must not " "contain gaps"
810  )
811  atomseq_from_chain = [r.one_letter_code for r in ch.residues]
812  if atomseq.replace("-", "") != atomseq_from_chain:
813  raise RuntimeError(
814  "ATOMSEQ in seqres_mapping must match "
815  "raw sequence extracted from chain "
816  "residues"
817  )
818  rnum = 0
819  for seqres_olc, atomseq_olc in zip(seqres, atomseq):
820  if seqres_olc != "-":
821  rnum += 1
822  if atomseq_olc != "-":
823  if seqres_olc != atomseq_olc:
824  raise RuntimeError(
825  f"Residue with number {rnum} in "
826  f"chain {ch_name} has SEQRES "
827  f"ATOMSEQ mismatch"
828  )
829  rnums.append(mol.ResNum(rnum))
830  else:
831  rnums = [r.GetNumber() for r in ch.residues]
832  assert len(rnums) == len(ch.residues)
833  residue_numbers[ch_name] = rnums
834  return residue_numbers
835 
def _SetupCompound(self, r, compound_lib, custom_compounds,
                   symmetry_settings, bb_only):
    """fill self.compound_anames/self.compound_symmetric_atoms

    Registers the reference atom names of the compound of residue *r* in
    ``self.compound_anames[r.name]``. If applicable, it also registers the
    pairs of symmetric atoms, as index tuples into these names, in
    ``self.compound_symmetric_atoms[r.name]``.

    :param r: Residue whose compound is set up
    :param compound_lib: Queried with ``FindCompound(r.name)`` if *r* is not
                         covered by *custom_compounds*. Hydrogen/deuterium
                         atoms of the found compound are skipped.
    :param custom_compounds: Optional dict mapping residue names to objects
                             with an ``atom_names`` attribute (e.g.
                             :class:`CustomCompound`). Takes precedence
                             over *compound_lib*.
    :param symmetry_settings: Provides ``symmetric_compounds``, mapping
                              residue names to pairs of atom names
    :param bb_only: If True, only "CA" (peptide linking) or "C3'"
                    (nucleotide linking) is registered, together with an
                    empty symmetry list
    :raises: :class:`RuntimeError` if *bb_only* is True and *r* is neither
             peptide nor nucleotide linking, if *compound_lib* has no entry
             for *r*, or if a symmetric atom name is not among the atom
             names
    """
    if bb_only:
        # throw away compound_lib info
        if r.chem_class.IsPeptideLinking():
            self.compound_anames[r.name] = ["CA"]
        elif r.chem_class.IsNucleotideLinking():
            self.compound_anames[r.name] = ["C3'"]
        else:
            raise RuntimeError(f"Only support amino acids and nucleotides "
                               f"if bb_only is True, failed with {str(r)}")
        self.compound_symmetric_atoms[r.name] = list()
        return

    if custom_compounds is not None and r.GetName() in custom_compounds:
        # copy so later modifications don't alias the custom compound
        atom_names = list(custom_compounds[r.GetName()].atom_names)
    else:
        compound = compound_lib.FindCompound(r.name)
        if compound is None:
            raise RuntimeError(f"no entry for {r} in compound_lib")
        atom_names = [atom_spec.name for atom_spec in compound.GetAtomSpecs()
                      if atom_spec.element not in ["H", "D"]]

    symmetric_atoms = list()
    if r.name in symmetry_settings.symmetric_compounds:
        for pair in symmetry_settings.symmetric_compounds[r.name]:
            try:
                a = atom_names.index(pair[0])
                b = atom_names.index(pair[1])
            except ValueError:
                # list.index signals a missing name with ValueError; other
                # errors (and KeyboardInterrupt) must not be swallowed here
                msg = f"Could not find symmetric atoms "
                msg += f"({pair[0]}, {pair[1]}) for {r.name} "
                msg += f"as specified in SymmetrySettings in "
                msg += f"compound from component dictionary. "
                msg += f"Atoms in compound: {atom_names}"
                raise RuntimeError(msg)
            symmetric_atoms.append((a, b))
    self.compound_anames[r.name] = atom_names
    # only residues with at least one symmetric pair get an entry
    if len(symmetric_atoms) > 0:
        self.compound_symmetric_atoms[r.name] = symmetric_atoms
878 
def _SetupDistances(self):
    """Compute distance related members of lDDTScorer

    For every atom of ``self.target`` that has an entry in
    ``self.atom_indices``, stores in ``self._ref_indices`` /
    ``self._ref_distances`` the indices of, and distances to, all atoms
    within ``self.inclusion_radius``. Atoms of the atom's own residue are
    excluded. ``self._sym_ref_indices`` / ``self._sym_ref_distances`` are
    then filled in place by :func:`_NonSymDistances`. Atoms not present in
    the target keep their empty initial arrays.
    """
    n_atoms = self.n_atoms
    # init
    self._ref_indices = [np.asarray([], dtype=np.int64)
                         for _ in range(n_atoms)]
    self._ref_distances = [np.asarray([], dtype=np.float64)
                           for _ in range(n_atoms)]
    self._sym_ref_indices = [np.asarray([], dtype=np.int64)
                             for _ in range(n_atoms)]
    self._sym_ref_distances = [np.asarray([], dtype=np.float64)
                               for _ in range(n_atoms)]

    # Initialize positions with values far in nirvana. A placeholder
    # coordinate exceeds the target's maximum bound by at least
    # 2 * inclusion_radius in every dimension. An unset position is
    # therefore never within inclusion_radius of a target atom.
    max_pos = self.target.bounds.GetMax()
    max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2]))
    max_coordinate += 2 * self.inclusion_radius
    pos = np.full((n_atoms, 3), max_coordinate, dtype=np.float32)

    atom_indices = list()
    mask_start = list()
    mask_end = list()
    for r_idx, r in enumerate(self.target.residues):
        # atoms of residue r occupy indices [r_start_idx, r_end_idx)
        r_start_idx = self.res_start_indices[r_idx]
        r_end_idx = r_start_idx + len(self.compound_anames[r.name])
        for a in r.atoms:
            hashcode = a.handle.GetHashCode()
            if hashcode not in self.atom_indices:
                continue  # skip atoms without an assigned index
            idx = self.atom_indices[hashcode]
            p = a.GetPos()
            pos[idx] = (p[0], p[1], p[2])
            atom_indices.append(idx)
            mask_start.append(r_start_idx)
            mask_end.append(r_end_idx)

    indices, distances = self._CloseStuff(pos, self.inclusion_radius,
                                          atom_indices, mask_start,
                                          mask_end)

    for idx, close_indices, close_distances in zip(atom_indices, indices,
                                                   distances):
        self._ref_indices[idx] = close_indices
        self._ref_distances[idx] = close_distances

    self._NonSymDistances(self._ref_indices, self._ref_distances,
                          self._sym_ref_indices,
                          self._sym_ref_distances)
925 
def _SetupDistancesSC(self):
    """Select subset of contacts only covering intra-chain contacts

    Starts from the overall contacts (``self.ref_indices`` /
    ``self.ref_distances``). For every atom it keeps only the partners
    whose index lies in the atom index range of the atom's own chain.
    Chain ranges start at ``self.chain_start_indices`` and the last chain
    ends at ``self.n_atoms``. Results go to ``self._ref_indices_sc`` /
    ``self._ref_distances_sc``. ``self._sym_ref_indices_sc`` /
    ``self._sym_ref_distances_sc`` are then filled in place by
    :func:`_NonSymDistances`.
    """
    n_atoms = self.n_atoms
    # init
    self._ref_indices_sc = [np.asarray([], dtype=np.int64)
                            for _ in range(n_atoms)]
    self._ref_distances_sc = [np.asarray([], dtype=np.float64)
                              for _ in range(n_atoms)]
    self._sym_ref_indices_sc = [np.asarray([], dtype=np.int64)
                                for _ in range(n_atoms)]
    self._sym_ref_distances_sc = [np.asarray([], dtype=np.float64)
                                  for _ in range(n_atoms)]

    # start from overall contacts
    ref_indices = self.ref_indices
    ref_distances = self.ref_distances

    n_chains = len(self.chain_start_indices)
    for ch_idx, _ in enumerate(self.target.chains):
        # atoms of this chain occupy indices [chain_s, chain_e)
        chain_s = self.chain_start_indices[ch_idx]
        chain_e = n_atoms
        if ch_idx + 1 < n_chains:
            chain_e = self.chain_start_indices[ch_idx + 1]
        for i in range(chain_s, chain_e):
            if len(ref_indices[i]) > 0:
                intra = np.logical_and(ref_indices[i] >= chain_s,
                                       ref_indices[i] < chain_e)
                self._ref_indices_sc[i] = ref_indices[i][intra]
                self._ref_distances_sc[i] = ref_distances[i][intra]

    # derive contacts towards non-symmetric atoms from the intra-chain subset
    self._NonSymDistances(self._ref_indices_sc, self._ref_distances_sc,
                          self._sym_ref_indices_sc,
                          self._sym_ref_distances_sc)
957 
def _SetupDistancesIC(self):
    """Select subset of contacts only covering inter-chain contacts

    Starts from the overall contacts (``self.ref_indices`` /
    ``self.ref_distances``). For every atom it keeps only the partners
    whose index lies outside the atom index range of the atom's own chain.
    Chain ranges start at ``self.chain_start_indices`` and the last chain
    ends at ``self.n_atoms``. Results go to ``self._ref_indices_ic`` /
    ``self._ref_distances_ic``. ``self._sym_ref_indices_ic`` /
    ``self._sym_ref_distances_ic`` are then filled in place by
    :func:`_NonSymDistances`.
    """
    n_atoms = self.n_atoms
    # init
    self._ref_indices_ic = [np.asarray([], dtype=np.int64)
                            for _ in range(n_atoms)]
    self._ref_distances_ic = [np.asarray([], dtype=np.float64)
                              for _ in range(n_atoms)]
    self._sym_ref_indices_ic = [np.asarray([], dtype=np.int64)
                                for _ in range(n_atoms)]
    self._sym_ref_distances_ic = [np.asarray([], dtype=np.float64)
                                  for _ in range(n_atoms)]

    # start from overall contacts
    ref_indices = self.ref_indices
    ref_distances = self.ref_distances

    n_chains = len(self.chain_start_indices)
    for ch_idx, _ in enumerate(self.target.chains):
        # atoms of this chain occupy indices [chain_s, chain_e)
        chain_s = self.chain_start_indices[ch_idx]
        chain_e = n_atoms
        if ch_idx + 1 < n_chains:
            chain_e = self.chain_start_indices[ch_idx + 1]
        for i in range(chain_s, chain_e):
            if len(ref_indices[i]) > 0:
                inter = np.logical_or(ref_indices[i] < chain_s,
                                      ref_indices[i] >= chain_e)
                self._ref_indices_ic[i] = ref_indices[i][inter]
                self._ref_distances_ic[i] = ref_distances[i][inter]

    # derive contacts towards non-symmetric atoms from the inter-chain subset
    self._NonSymDistances(self._ref_indices_ic, self._ref_distances_ic,
                          self._sym_ref_indices_ic,
                          self._sym_ref_distances_ic)
989 
def _CloseStuff(self, pos, inclusion_radius, indices, mask_start, mask_end):
    """returns close stuff for positions specified by indices

    For the i-th entry ``idx`` of *indices*, finds all rows of *pos* within
    *inclusion_radius* of ``pos[idx]``. Rows in the half-open range
    [``mask_start[i]``, ``mask_end[i]``) are ignored; these are the atoms
    of the query atom's own residue.

    :param pos: 2D array with one position per row
    :param inclusion_radius: Distance cutoff, inclusive
    :param indices: Row indices of the query positions
    :param mask_start: Per query, first row to ignore
    :param mask_end: Per query, end (exclusive) of rows to ignore
    :returns: Tuple of two lists with one entry per query: arrays of close
              row indices and arrays of the corresponding distances
    """
    # TODO: this function does brute force distance computation which has
    # quadratic complexity...
    close_indices = list()
    distances = list()
    # work with squared_inclusion_radius (sir) to save some square roots
    sir = inclusion_radius ** 2
    far_away = 2 * sir
    for idx, ms, me in zip(indices, mask_start, mask_end):
        diff = pos - pos[idx, :][None, :]
        np.square(diff, out=diff)
        sq_dists = diff.sum(axis=1)
        # mask out atoms of own residue => put them far away. A plain slice
        # writes through a view, unlike fancy indexing with a range object,
        # which builds an index array on every query.
        sq_dists[ms:me] = far_away
        close = np.nonzero(sq_dists <= sir)[0]
        close_indices.append(close)
        distances.append(np.sqrt(sq_dists[close]))
    return (close_indices, distances)
1009 
def _NonSymDistances(self, ref_indices, ref_distances,
                     sym_ref_indices, sym_ref_distances):
    """Transfer indices/distances of non-symmetric atoms in place

    For every atom index in ``self.symmetric_atoms``, stores in
    *sym_ref_indices* / *sym_ref_distances* the subset of its contacts from
    *ref_indices* / *ref_distances* whose partner is not itself a
    symmetric atom. Entries of all other atoms are left untouched.

    Both outputs are stored as :class:`np.ndarray`, with indices as int64,
    so consumers can rely on array semantics for every slot.
    """
    if not self.symmetric_atoms:
        return
    symmetric = np.fromiter(self.symmetric_atoms, dtype=np.int64,
                            count=len(self.symmetric_atoms))
    for idx in self.symmetric_atoms:
        indices = np.asarray(ref_indices[idx], dtype=np.int64)
        distances = np.asarray(ref_distances[idx])
        keep = ~np.isin(indices, symmetric)
        sym_ref_indices[idx] = indices[keep]
        sym_ref_distances[idx] = distances[keep]
1023 
def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances):
    """Count preserved distances of a single atom, once per threshold

    Measures the distances from ``pos[atom_idx]`` to every partner listed
    in ``ref_indices[atom_idx]`` and compares them with the reference
    values in ``ref_distances[atom_idx]``.

    :returns: np.array (dtype int32) with len(thresholds) elements. Entry k
              counts the partners whose absolute distance difference is
              <= thresholds[k].
    """
    # np.take returns a fresh copy, so the in-place ops below leave pos alone
    partners = np.take(pos, ref_indices[atom_idx], axis=0)
    np.subtract(partners, pos[atom_idx, :][None, :], out=partners)
    np.square(partners, out=partners)
    dist = partners.sum(axis=1)
    np.sqrt(dist, out=dist)
    # |reference distance - current distance|, reusing the same buffer
    np.subtract(ref_distances[atom_idx], dist, out=dist)
    np.absolute(dist, out=dist)
    counts = [np.count_nonzero(dist <= cutoff) for cutoff in thresholds]
    return np.array(counts, dtype=np.int32)
1039 
def _EvalAtoms(
    self, pos, atom_indices, thresholds, ref_indices, ref_distances
):
    """Per-atom distance conservation counts for several atoms

    Runs :func:`_EvalAtom` for every index in *atom_indices* and stacks the
    results row by row.

    returns numpy matrix of shape (n_atoms, len(threshold)), dtype int32
    """
    result = np.zeros((len(atom_indices), len(thresholds)), dtype=np.int32)
    # iterating a 2D array yields writable row views
    for row, atom in zip(result, atom_indices):
        row[:] = self._EvalAtom(pos, atom, thresholds, ref_indices,
                                ref_distances)
    return result
1054 
def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices,
                  ref_distances):
    """Summed distance conservation counts per residue

    Each element of *res_atom_indices* is a list of atom indices forming
    one residue. The per-atom counts from :func:`_EvalAtoms` are summed
    over the atoms of each residue.

    returns numpy matrix of shape (n_residues, len(thresholds)), dtype int32
    """
    per_residue = np.zeros((len(res_atom_indices), len(thresholds)),
                           dtype=np.int32)
    for row, atoms in zip(per_residue, res_atom_indices):
        atom_counts = self._EvalAtoms(pos, atoms, thresholds, ref_indices,
                                      ref_distances)
        row[:] = atom_counts.sum(axis=0)
    return per_residue
1068 
def _ProcessSequenceSeparation(self):
    """Validate ``self.sequence_separation``

    Only the default value 0 is supported. Any other value is rejected.

    :raises: :class:`NotImplementedError` if ``self.sequence_separation``
             is not 0
    """
    if self.sequence_separation == 0:
        return
    msg = ("Congratulations! You're the first one "
           "requesting a non-default "
           "sequence_separation in the new and "
           "awesome lDDT implementation. A crate of "
           "beer for Gabriel and he'll implement "
           "it.")
    raise NotImplementedError(msg)
1077 
def _GetNExp(self, atom_idx, ref_indices):
    """Returns number of close atoms around one or several atoms

    :param atom_idx: A single atom index (Python or NumPy integer) or a
                     list/tuple of atom indices
    :param ref_indices: Per-atom collections of close atom indices
    :returns: Total number of close atoms for the given atom(s)
    :raises: :class:`RuntimeError` for any other input type
    """
    # NumPy integer scalars (e.g. taken from index arrays) are not ``int``
    # instances but are valid single indices
    if isinstance(atom_idx, (int, np.integer)):
        return len(ref_indices[atom_idx])
    if isinstance(atom_idx, (list, tuple)):
        return sum(len(ref_indices[idx]) for idx in atom_idx)
    raise RuntimeError("invalid input type")
1087 
def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices,
                       sym_ref_distances):
    """Swaps symmetric positions in-place in order to maximize lDDT scores
    towards non-symmetric atoms.

    Each element of *symmetries* is a collection of index pairs. The atoms
    involved in a symmetry are scored with :func:`_EvalAtoms` against
    *sym_ref_indices* / *sym_ref_distances*: once as is, and once with the
    positions of every pair swapped. The swapped positions are kept only if
    they score strictly better. Ties restore the initial positions to
    reproduce the old lDDT behaviour. Symmetries whose atoms have no
    reference contacts are skipped.

    :param pos: Positions indexed by atom index, modified in place
    :param thresholds: Distance difference thresholds
    :param symmetries: Iterable of symmetries, each an iterable of pairs
    :param sym_ref_indices: Per-atom reference contact indices
    :param sym_ref_distances: Per-atom reference contact distances
    """
    def swap_pairs(sym):
        # the right-hand fancy index produces a copy first => proper swap
        for pair in sym:
            first, second = pair[0], pair[1]
            pos[[first, second]] = pos[[second, first]]

    n_thresholds = len(thresholds)
    for sym in symmetries:
        involved = [idx for pair in sym for idx in (pair[0], pair[1])]
        n_exp = self._GetNExp(involved, sym_ref_indices)
        if n_exp == 0:
            continue  # nothing to do
        normalizer = n_thresholds * n_exp

        conserved_initial = self._EvalAtoms(pos, involved, thresholds,
                                            sym_ref_indices,
                                            sym_ref_distances)
        swap_pairs(sym)
        conserved_swapped = self._EvalAtoms(pos, involved, thresholds,
                                            sym_ref_indices,
                                            sym_ref_distances)

        score_initial = np.sum(conserved_initial) / normalizer
        score_swapped = np.sum(conserved_swapped) / normalizer
        if score_initial >= score_swapped:
            # initial positions were better or equal => switch back
            swap_pairs(sym)
def GetDefaultSymmetrySettings
Definition: lddt.py:71