OpenStructure
lddt.py
Go to the documentation of this file.
1 import itertools
2 import numpy as np
3 
4 from ost import mol
5 from ost import conop
6 
7 # use cdist of scipy, fallback to (slower) numpy implementation if scipy is not
8 # available
try:
    from scipy.spatial.distance import cdist
except ImportError:
    # Only a missing scipy should trigger the fallback; anything else
    # (e.g. KeyboardInterrupt) must propagate.
    def cdist(p1, p2):
        """Pairwise Euclidean distances between two sets of points

        Pure numpy replacement for :func:`scipy.spatial.distance.cdist`
        (Euclidean metric only).

        :param p1: Positions, array of shape (m, d)
        :param p2: Positions, array of shape (n, d)
        :returns: Array of shape (m, n) with distances between all points
                  in *p1* and all points in *p2*
        """
        x2 = np.sum(p1**2, axis=1).reshape(-1, 1)  # (m, 1)
        y2 = np.sum(p2**2, axis=1)  # (n)
        xy = np.matmul(p1, p2.T)  # (m, n)
        # |x - y|^2 = |x|^2 - 2xy + |y|^2, floating point round-off can push
        # the result slightly below zero for (near-)identical points which
        # would turn into NaN in the sqrt => clamp at zero
        sq_dist = np.maximum(x2 - 2*xy + y2, 0.0)  # (m, n)
        return np.sqrt(sq_dist)
18 
class CustomCompound:
    """ Defines atoms for custom compounds

    lDDT requires the reference atoms of a compound which are typically
    extracted from a :class:`ost.conop.CompoundLib`. This lightweight
    container allows to handle arbitrary compounds which are not
    necessarily in the compound library.

    :param atom_names: Names of atoms of custom compound
    :type atom_names: :class:`list` of :class:`str`
    """
    def __init__(self, atom_names):
        # reference atom names, stored as given (no copy, no validation)
        self.atom_names = atom_names

    @staticmethod
    def FromResidue(res):
        """ Construct custom compound from residue

        :param res: Residue from which reference atom names are extracted,
                    hydrogen/deuterium atoms are filtered out
        :type res: :class:`ost.mol.ResidueView`/:class:`ost.mol.ResidueHandle`
        :returns: :class:`CustomCompound`
        :raises: :class:`RuntimeError` if the remaining atom names are not
                 unique
        """
        at_names = [a.name for a in res.atoms if a.element not in ["H", "D"]]
        # atom names identify reference atoms => they must be unique
        if len(at_names) != len(set(at_names)):
            raise RuntimeError("Duplicate atoms detected in CustomCompound")
        return CustomCompound(at_names)
47 
class SymmetrySettings:
    """Container for symmetric compounds

    lDDT considers symmetries and selects the one resulting in the highest
    possible score.

    A symmetry is defined as a renaming operation on one or more atoms that
    leads to a chemically equivalent residue. Example would be OD1 and OD2 in
    ASP => renaming OD1 to OD2 and vice versa gives a chemically equivalent
    residue.

    Use :func:`AddSymmetricCompound` to define a symmetry which can then
    directly be accessed through the *symmetric_compounds* member.
    """
    def __init__(self):
        # compound name => list of atom name pairs defining the renaming
        self.symmetric_compounds = dict()

    def AddSymmetricCompound(self, name, symmetric_atoms):
        """Adds symmetry for compound with *name*

        An already registered symmetry for *name* is overwritten.

        :param name: Name of compound with symmetry
        :type name: :class:`str`
        :param symmetric_atoms: Pairs of atom names that define renaming
                                operation, i.e. after applying all switches
                                defined in the tuples, the resulting residue
                                should be chemically equivalent. Atom names
                                must refer to the PDB component dictionary.
        :type symmetric_atoms: :class:`list` of :class:`tuple`
        :raises: :class:`RuntimeError` if any element of *symmetric_atoms*
                 is not of length 2
        """
        # validate everything before storing, nothing is registered on error
        for pair in symmetric_atoms:
            if len(pair) != 2:
                raise RuntimeError("Expect pairs when defining symmetries")
        self.symmetric_compounds[name] = symmetric_atoms
81 
82 
def GetDefaultSymmetrySettings():
    """Constructs and returns :class:`SymmetrySettings` object for natural amino
    acids and nucleotides

    Registered symmetries: ASP, GLU, LEU, VAL, ARG, PHE and TYR side chain
    atoms as well as the OP1/OP2 phosphate oxygens of the nucleotides
    A, C, G, U, DA, DC, DG and DT.

    :returns: :class:`SymmetrySettings`
    """
    symmetry_settings = SymmetrySettings()

    # amino acids: compound name and the atom name pairs that may be swapped
    amino_acid_symmetries = (
        ("ASP", [("OD1", "OD2")]),
        ("GLU", [("OE1", "OE2")]),
        ("LEU", [("CD1", "CD2")]),
        ("VAL", [("CG1", "CG2")]),
        ("ARG", [("NH1", "NH2")]),
        ("PHE", [("CD1", "CD2"), ("CE1", "CE2")]),
        ("TYR", [("CD1", "CD2"), ("CE1", "CE2")]),
    )
    for compound_name, atom_pairs in amino_acid_symmetries:
        symmetry_settings.AddSymmetricCompound(compound_name, atom_pairs)

    # nucleotides - every compound gets its own list instance
    nuc_names = ["A", "C", "G", "U", "DA", "DC", "DG", "DT"]
    for nuc_name in nuc_names:
        symmetry_settings.AddSymmetricCompound(
            nuc_name, [("OP1", "OP2")]
        )

    return symmetry_settings
122 
123 
125  """lDDT scorer object for a specific target
126 
127  Sets up everything to score models of that target. lDDT (local distance
128  difference test) is defined as fraction of pairwise distances which exhibit
129  a difference < threshold when considering target and model. In case of
130  multiple thresholds, the average is returned. See
131 
132  V. Mariani, M. Biasini, A. Barbato, T. Schwede, lDDT : A local
133  superposition-free score for comparing protein structures and models using
134  distance difference tests, Bioinformatics, 2013
135 
136  :param target: The target
137  :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
138  :param compound_lib: Compound library from which a compound for each residue
139  is extracted based on its name. Uses
140  :func:`ost.conop.GetDefaultLib` if not given, raises
141  if this returns no valid compound library. Atoms
142  defined in the compound are searched in the residue and
143  build the reference for scoring. If the residue has
144  atoms with names ["A", "B", "C"] but the corresponding
145  compound only has ["A", "B"], "A" and "B" are
146  considered for scoring. If the residue has atoms
147  ["A", "B"] but the compound has ["A", "B", "C"], "C" is
148  considered missing and does not influence scoring, even
149  if present in the model.
150  :param custom_compounds: Custom compounds defining reference atoms. If
151  given, *custom_compounds* take precedence over
152  *compound_lib*.
153  :type custom_compounds: :class:`dict` with residue names (:class:`str`) as
154  key and :class:`CustomCompound` as value.
155  :type compound_lib: :class:`ost.conop.CompoundLib`
156  :param inclusion_radius: All pairwise distances < *inclusion_radius* are
157  considered for scoring
158  :type inclusion_radius: :class:`float`
159  :param sequence_separation: Only pairwise distances between atoms of
160  residues which are further apart than this
161  threshold are considered. Residue distance is
162  based on resnum. The default (0) considers all
163  pairwise distances except intra-residue
164  distances.
165  :type sequence_separation: :class:`int`
166  :param symmetry_settings: Define residues exhibiting internal symmetry, uses
167  :func:`GetDefaultSymmetrySettings` if not given.
168  :type symmetry_settings: :class:`SymmetrySettings`
169  :param seqres_mapping: Mapping of model residues at the scoring stage
170  happens with residue numbers defining their location
171  in a reference sequence (SEQRES) using one based
172  indexing. If the residue numbers in *target* don't
173  correspond to that SEQRES, you can specify the
174  mapping manually. You can provide a dictionary to
175  specify a reference sequence (SEQRES) for one or more
176  chain(s). Key: chain name, value: alignment
177  (seq1: SEQRES, seq2: sequence of residues in chain).
178  Example: The residues in a chain with name "A" have
179  sequence "YEAH" and residue numbers [42,43,44,45].
180  You can provide an alignment with seq1 "``HELLYEAH``"
181  and seq2 "``----YEAH``". "Y" gets assigned residue
182  number 5, "E" gets assigned 6 and so on no matter
183  what the original residue numbers were.
184  :type seqres_mapping: :class:`dict` (key: :class:`str`, value:
185  :class:`ost.seq.AlignmentHandle`)
186  :param bb_only: Only consider atoms with name "CA" in case of amino acids and
187  "C3'" for Nucleotides. this invalidates *compound_lib*.
188  Raises if any residue in *target* is not
189  `r.chem_class.IsPeptideLinking()` or
190  `r.chem_class.IsNucleotideLinking()`
191  :type bb_only: :class:`bool`
192  :raises: :class:`RuntimeError` if *target* contains compound which is not in
193  *compound_lib*, :class:`RuntimeError` if *symmetry_settings*
194  specifies symmetric atoms that are not present in the according
195  compound in *compound_lib*, :class:`RuntimeError` if
196  *seqres_mapping* is not provided and *target* contains residue
197  numbers with insertion codes or the residue numbers for each chain
198  are not monotonically increasing, :class:`RuntimeError` if
199  *seqres_mapping* is provided but an alignment is invalid
200  (seq1 contains gaps, mismatch in seq1/seq2, seq2 does not match
201  residues in corresponding chains).
202  """
203  def __init__(
204  self,
205  target,
206  compound_lib=None,
207  custom_compounds=None,
208  inclusion_radius=15,
209  sequence_separation=0,
210  symmetry_settings=None,
211  seqres_mapping=dict(),
212  bb_only=False
213  ):
214 
215  self.targettarget = target
216  self.inclusion_radiusinclusion_radius = inclusion_radius
217  self.sequence_separationsequence_separation = sequence_separation
218  if compound_lib is None:
219  compound_lib = conop.GetDefaultLib()
220  if compound_lib is None:
221  raise RuntimeError("No compound_lib given and conop.GetDefaultLib "
222  "returns no valid compound library")
223  self.compound_libcompound_lib = compound_lib
224  self.custom_compoundscustom_compounds = custom_compounds
225  if symmetry_settings is None:
227  else:
228  self.symmetry_settingssymmetry_settings = symmetry_settings
229 
230  # whether to only consider atoms with name "CA" (amino acids) or C3'
231  # (nucleotides), invalidates *compound_lib*
232  self.bb_onlybb_only=bb_only
233 
234  # names of heavy atoms of each unique compound present in *target* as
235  # extracted from *compound_lib*, e.g.
236  # self.compound_anames["GLY"] = ["N", "CA", "C", "O"]
237  self.compound_anamescompound_anames = dict()
238 
239  # stores symmetry information for those compounds as defined in
240  # *symmetry_settings*
241  self.compound_symmetric_atomscompound_symmetric_atoms = dict()
242 
243  # list of len(target.chains) containing all chain names in *target*
244  self.chain_nameschain_names = list()
245 
246  # list of len(target.residues) containing all compound names in *target*
247  self.compound_namescompound_names = list()
248 
249  # list of len(target.residues) defining start pos in internal reference
250  # positions for each residue
251  self.res_start_indicesres_start_indices = list()
252 
253  # list of len(target.residues) defining residue numbers in target
254  self.res_resnumsres_resnums = list()
255 
256  # list of len(target.chains) defining start pos in internal reference
257  # positions for each chain
258  self.chain_start_indiceschain_start_indices = list()
259 
260  # list of len(target.chains) defining start pos in self.compound_names
261  # for each chain
262  self.chain_res_start_indiceschain_res_start_indices = list()
263 
264  # maps residues in *target* to indices in
265  # self.compound_names/self.res_start_indices. A residue gets identified
266  # by a tuple (first element: chain name, second element: residue number,
267  # residue number is either the actual residue number in *target* or
268  # given by *seqres_mapping*)
269  self.res_mapperres_mapper = dict()
270 
271  # number of atoms as specified in compounds. not all are necessarily
272  # covered by structure
273  self.n_atomsn_atoms = None
274 
275  # stores an index for each AtomHandle in *target*
276  # (atom hashcode => index)
277  self.atom_indicesatom_indices = dict()
278 
279  # store indices of all atoms that have symmetry properties
280  self.symmetric_atomssymmetric_atoms = set()
281 
282  # the actual target positions in a numpy array of shape (self.n_atoms,3)
283  self.positionspositions = None
284 
285  # setup members defined above
286  self._SetupEnv_SetupEnv(self.compound_libcompound_lib, self.custom_compoundscustom_compounds,
287  self.symmetry_settingssymmetry_settings, seqres_mapping, self.bb_onlybb_only)
288 
289  # distance related members are lazily computed as they're affected
290  # by different flavours of lDDT (e.g. lDDT including inter-chain
291  # contacts or not etc.)
292 
293  # stores for each atom the other atoms within inclusion_radius
294  self._ref_indices_ref_indices = None
295  # the corresponding distances
296  self._ref_distances_ref_distances = None
297 
298  # The following lists will be sparsely populated. We keep for each
299  # symmetry related atom the distances towards all atoms which are NOT
300  # affected by symmetry. So we can evaluate two symmetric versions
301  # against the fixed stuff later on and select the better scoring one.
302  self._sym_ref_indices_sym_ref_indices = None
303  self._sym_ref_distances_sym_ref_distances = None
304 
305  # exactly the same as above but without interchain contacts
306  # => single-chain (sc)
307  self._ref_indices_sc_ref_indices_sc = None
308  self._ref_distances_sc_ref_distances_sc = None
309  self._sym_ref_indices_sc_sym_ref_indices_sc = None
310  self._sym_ref_distances_sc_sym_ref_distances_sc = None
311 
312  # exactly the same as above but without intrachain contacts
313  # => inter-chain (ic)
314  self._ref_indices_ic_ref_indices_ic = None
315  self._ref_distances_ic_ref_distances_ic = None
316  self._sym_ref_indices_ic_sym_ref_indices_ic = None
317  self._sym_ref_distances_ic_sym_ref_distances_ic = None
318 
319  # input parameter checking
320  self._ProcessSequenceSeparation_ProcessSequenceSeparation()
321 
322  @property
323  def ref_indices(self):
324  if self._ref_indices_ref_indices is None:
325  self._ref_indices_ref_indices, self._ref_distances_ref_distances = \
326  lDDTScorer._SetupDistances(self.targettarget, self.n_atomsn_atoms,
327  self.atom_indicesatom_indices,
328  self.inclusion_radiusinclusion_radius)
329  return self._ref_indices_ref_indices
330 
331  @property
332  def ref_distances(self):
333  if self._ref_distances_ref_distances is None:
334  self._ref_indices_ref_indices, self._ref_distances_ref_distances = \
335  lDDTScorer._SetupDistances(self.targettarget, self.n_atomsn_atoms,
336  self.atom_indicesatom_indices,
337  self.inclusion_radiusinclusion_radius)
338  return self._ref_distances_ref_distances
339 
340  @property
341  def sym_ref_indices(self):
342  if self._sym_ref_indices_sym_ref_indices is None:
343  self._sym_ref_indices_sym_ref_indices, self._sym_ref_distances_sym_ref_distances = \
344  lDDTScorer._NonSymDistances(self.n_atomsn_atoms, self.symmetric_atomssymmetric_atoms,
345  self.ref_indicesref_indices, self.ref_distancesref_distances)
346  return self._sym_ref_indices_sym_ref_indices
347 
348  @property
349  def sym_ref_distances(self):
350  if self._sym_ref_distances_sym_ref_distances is None:
351  self._sym_ref_indices_sym_ref_indices, self._sym_ref_distances_sym_ref_distances = \
352  lDDTScorer._NonSymDistances(self.n_atomsn_atoms, self.symmetric_atomssymmetric_atoms,
353  self.ref_indicesref_indices, self.ref_distancesref_distances)
354  return self._sym_ref_distances_sym_ref_distances
355 
356  @property
357  def ref_indices_sc(self):
358  if self._ref_indices_sc_ref_indices_sc is None:
359  self._ref_indices_sc_ref_indices_sc, self._ref_distances_sc_ref_distances_sc = \
360  lDDTScorer._SetupDistancesSC(self.n_atomsn_atoms,
361  self.chain_start_indiceschain_start_indices,
362  self.ref_indicesref_indices,
363  self.ref_distancesref_distances)
364  return self._ref_indices_sc_ref_indices_sc
365 
366  @property
367  def ref_distances_sc(self):
368  if self._ref_distances_sc_ref_distances_sc is None:
369  self._ref_indices_sc_ref_indices_sc, self._ref_distances_sc_ref_distances_sc = \
370  lDDTScorer._SetupDistancesSC(self.n_atomsn_atoms,
371  self.chain_start_indiceschain_start_indices,
372  self.ref_indicesref_indices,
373  self.ref_distancesref_distances)
374  return self._ref_distances_sc_ref_distances_sc
375 
376  @property
378  if self._sym_ref_indices_sc_sym_ref_indices_sc is None:
379  self._sym_ref_indices_sc_sym_ref_indices_sc, self._sym_ref_distances_sc_sym_ref_distances_sc = \
380  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
381  self.symmetric_atomssymmetric_atoms,
382  self.ref_indices_scref_indices_sc,
383  self.ref_distances_scref_distances_sc)
384  return self._sym_ref_indices_sc_sym_ref_indices_sc
385 
386  @property
388  if self._sym_ref_distances_sc_sym_ref_distances_sc is None:
389  self._sym_ref_indices_sc_sym_ref_indices_sc, self._sym_ref_distances_sc_sym_ref_distances_sc = \
390  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
391  self.symmetric_atomssymmetric_atoms,
392  self.ref_indices_scref_indices_sc,
393  self.ref_distances_scref_distances_sc)
394  return self._sym_ref_distances_sc_sym_ref_distances_sc
395 
396  @property
397  def ref_indices_ic(self):
398  if self._ref_indices_ic_ref_indices_ic is None:
399  self._ref_indices_ic_ref_indices_ic, self._ref_distances_ic_ref_distances_ic = \
400  lDDTScorer._SetupDistancesIC(self.n_atomsn_atoms,
401  self.chain_start_indiceschain_start_indices,
402  self.ref_indicesref_indices,
403  self.ref_distancesref_distances)
404  return self._ref_indices_ic_ref_indices_ic
405 
406  @property
407  def ref_distances_ic(self):
408  if self._ref_distances_ic_ref_distances_ic is None:
409  self._ref_indices_ic_ref_indices_ic, self._ref_distances_ic_ref_distances_ic = \
410  lDDTScorer._SetupDistancesIC(self.n_atomsn_atoms,
411  self.chain_start_indiceschain_start_indices,
412  self.ref_indicesref_indices,
413  self.ref_distancesref_distances)
414  return self._ref_distances_ic_ref_distances_ic
415 
416  @property
418  if self._sym_ref_indices_ic_sym_ref_indices_ic is None:
419  self._sym_ref_indices_ic_sym_ref_indices_ic, self._sym_ref_distances_ic_sym_ref_distances_ic = \
420  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
421  self.symmetric_atomssymmetric_atoms,
422  self.ref_indices_icref_indices_ic,
423  self.ref_distances_icref_distances_ic)
424  return self._sym_ref_indices_ic_sym_ref_indices_ic
425 
426  @property
428  if self._sym_ref_distances_ic_sym_ref_distances_ic is None:
429  self._sym_ref_indices_ic_sym_ref_indices_ic, self._sym_ref_distances_ic_sym_ref_distances_ic = \
430  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
431  self.symmetric_atomssymmetric_atoms,
432  self.ref_indices_icref_indices_ic,
433  self.ref_distances_icref_distances_ic)
434  return self._sym_ref_distances_ic_sym_ref_distances_ic
435 
436  def lDDT(self, model, thresholds = [0.5, 1.0, 2.0, 4.0],
437  local_lddt_prop=None, local_contact_prop=None,
438  chain_mapping=None, no_interchain=False,
439  no_intrachain=False, penalize_extra_chains=False,
440  residue_mapping=None, return_dist_test=False,
441  check_resnames=True, add_mdl_contacts=False,
442  interaction_data=None, set_atom_props=False):
443  """Computes lDDT of *model* - globally and per-residue
444 
445  :param model: Model to be scored - models are preferably scored upon
446  performing stereo-chemistry checks in order to punish for
447  non-sensical irregularities. This must be done separately
448  as a pre-processing step. Target contacts that are not
449  covered by *model* are considered not conserved, thus
450  decreasing lDDT score. This also includes missing model
451  chains or model chains for which no mapping is provided in
452  *chain_mapping*.
453  :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
454  :param thresholds: Thresholds of distance differences to be considered
455  as correct - see docs in constructor for more info.
456  default: [0.5, 1.0, 2.0, 4.0]
457  :type thresholds: :class:`list` of :class:`floats`
458  :param local_lddt_prop: If set, per-residue scores will be assigned as
459  generic float property of that name
460  :type local_lddt_prop: :class:`str`
461  :param local_contact_prop: If set, number of expected contacts as well
462  as number of conserved contacts will be
463  assigned as generic int property.
464  Excected contacts will be set as
465  <local_contact_prop>_exp, conserved contacts
466  as <local_contact_prop>_cons. Values
467  are summed over all thresholds.
468  :type local_contact_prop: :class:`str`
469  :param chain_mapping: Mapping of model chains (key) onto target chains
470  (value). This is required if target or model have
471  more than one chain.
472  :type chain_mapping: :class:`dict` with :class:`str` as keys/values
473  :param no_interchain: Whether to exclude interchain contacts
474  :type no_interchain: :class:`bool`
475  :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
476  consider interface related contacts)
477  :type no_intrachain: :class:`bool`
478  :param penalize_extra_chains: Whether to include a fixed penalty for
479  additional chains in the model that are
480  not mapped to the target. ONLY AFFECTS
481  RETURNED GLOBAL SCORE. In detail: adds the
482  number of intra-chain contacts of each
483  extra chain to the expected contacts, thus
484  adding a penalty.
485  :type penalize_extra_chains: :class:`bool`
486  :param residue_mapping: By default, residue mapping is based on residue
487  numbers. That means, a model chain and the
488  respective target chain map to the same
489  underlying reference sequence (SEQRES).
490  Alternatively, you can specify one or
491  several alignment(s) between model and target
492  chains by providing a dictionary. key: Name
493  of chain in model (respective target chain is
494  extracted from *chain_mapping*),
495  value: Alignment with first sequence
496  corresponding to target chain and second
497  sequence to model chain. There is NO reference
498  sequence involved, so the two sequences MUST
499  exactly match the actual residues observed in
500  the respective target/model chains (ATOMSEQ).
501  :type residue_mapping: :class:`dict` with key: :class:`str`,
502  value: :class:`ost.seq.AlignmentHandle`
503  :param return_dist_test: Whether to additionally return the underlying
504  per-residue data for the distance difference
505  test. Adds five objects to the return tuple.
506  First: Number of total contacts summed over all
507  thresholds
508  Second: Number of conserved contacts summed
509  over all thresholds
510  Third: list with length of scored residues.
511  Contains indices referring to model.residues.
512  Fourth: numpy array of size
513  len(scored_residues) containing the number of
514  total contacts,
515  Fifth: numpy matrix of shape
516  (len(scored_residues), len(thresholds))
517  specifying how many for each threshold are
518  conserved.
519  :param check_resnames: On by default. Enforces residue name matches
520  between mapped model and target residues.
521  :type check_resnames: :class:`bool`
522  :param add_mdl_contacts: Adds model contacts - Only using contacts that
523  are within a certain distance threshold in the
524  target does not penalize for added model
525  contacts. If set to True, this flag will also
526  consider target contacts that are within the
527  specified distance threshold in the model but
528  not necessarily in the target. No contact will
529  be added if the respective atom pair is not
530  resolved in the target.
531  :type add_mdl_contacts: :class:`bool`
532  :param interaction_data: Pro param - don't use
533  :type interaction_data: :class:`tuple`
534  :param set_atom_props: If True, sets generic properties on a per atom
535  level if *local_lddt_prop*/*local_contact_prop*
536  are set as well.
537  In other words: this is the only way you can
538  get per-atom lDDT values.
539  :type set_atom_props: :class:`bool`
540 
541  :returns: global and per-residue lDDT scores as a tuple -
542  first element is global lDDT score (None if *target* has no
543  contacts) and second element a list of per-residue scores with
544  length len(*model*.residues). None is assigned to residues that
545  are not covered by target. If a residue is covered but has no
546  contacts in *target*, 0.0 is assigned.
547  """
548  if chain_mapping is None:
549  if len(self.chain_nameschain_names) > 1 or len(model.chains) > 1:
550  raise NotImplementedError("Must provide chain mapping if "
551  "target or model have > 1 chains.")
552  chain_mapping = {model.chains[0].GetName(): self.chain_nameschain_names[0]}
553  else:
554  # check whether chains specified in mapping exist
555  for model_chain, target_chain in chain_mapping.items():
556  if target_chain not in self.chain_nameschain_names:
557  raise RuntimeError(f"Target chain specified in "
558  f"chain_mapping ({target_chain}) does "
559  f"not exist. Target has chains: "
560  f"{self.chain_names}")
561  ch = model.FindChain(model_chain)
562  if not ch.IsValid():
563  raise RuntimeError(f"Model chain specified in "
564  f"chain_mapping ({model_chain}) does "
565  f"not exist. Model has chains: "
566  f"{[c.GetName() for c in model.chains]}")
567 
568  # data objects defining model data - see _ProcessModel for rough
569  # description
570  pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
571  res_indices, ref_res_indices, symmetries = \
572  self._ProcessModel_ProcessModel(model, chain_mapping,
573  residue_mapping = residue_mapping,
574  thresholds = thresholds,
575  check_resnames = check_resnames)
576 
577  if no_interchain and no_intrachain:
578  raise RuntimeError("no_interchain and no_intrachain flags are "
579  "mutually exclusive")
580 
581 
582  sym_ref_indices = None
583  sym_ref_distances = None
584  ref_indices = None
585  ref_distances = None
586 
587  if interaction_data is None:
588  if no_interchain:
589  sym_ref_indices = self.sym_ref_indices_scsym_ref_indices_sc
590  sym_ref_distances = self.sym_ref_distances_scsym_ref_distances_sc
591  ref_indices = self.ref_indices_scref_indices_sc
592  ref_distances = self.ref_distances_scref_distances_sc
593  elif no_intrachain:
594  sym_ref_indices = self.sym_ref_indices_icsym_ref_indices_ic
595  sym_ref_distances = self.sym_ref_distances_icsym_ref_distances_ic
596  ref_indices = self.ref_indices_icref_indices_ic
597  ref_distances = self.ref_distances_icref_distances_ic
598  else:
599  sym_ref_indices = self.sym_ref_indicessym_ref_indices
600  sym_ref_distances = self.sym_ref_distancessym_ref_distances
601  ref_indices = self.ref_indicesref_indices
602  ref_distances = self.ref_distancesref_distances
603 
604  if add_mdl_contacts:
605  ref_indices, ref_distances = \
606  self._AddMdlContacts_AddMdlContacts(model, res_atom_indices, res_atom_hashes,
607  ref_indices, ref_distances,
608  no_interchain, no_intrachain)
609  # recompute symmetry related indices/distances
610  sym_ref_indices, sym_ref_distances = \
611  lDDTScorer._NonSymDistances(self.n_atomsn_atoms, self.symmetric_atomssymmetric_atoms,
612  ref_indices, ref_distances)
613  else:
614  sym_ref_indices, sym_ref_distances, ref_indices, ref_distances = \
615  interaction_data
616 
617  self._ResolveSymmetries_ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices,
618  sym_ref_distances)
619 
620  atom_indices = list(itertools.chain.from_iterable(res_atom_indices))
621 
622  per_atom_exp = np.asarray([self._GetNExp_GetNExp(i, ref_indices)
623  for i in atom_indices], dtype=np.int32)
624  per_res_exp = np.asarray([self._GetNExp_GetNExp(res_ref_atom_indices[idx],
625  ref_indices) for idx in range(len(res_indices))], dtype=np.int32)
626 
627  per_atom_conserved = self._EvalAtoms_EvalAtoms(pos, atom_indices, thresholds,
628  ref_indices, ref_distances)
629  per_res_conserved = np.zeros((len(res_atom_indices), len(thresholds)),
630  dtype=np.int32)
631  start_idx = 0
632  for r_idx in range(len(res_atom_indices)):
633  end_idx = start_idx + len(res_atom_indices[r_idx])
634  per_res_conserved[r_idx] = np.sum(per_atom_conserved[start_idx:end_idx,:],
635  axis=0)
636  start_idx = end_idx
637 
638  n_thresh = len(thresholds)
639 
640  # do per-residue scores
641  per_res_lDDT = [None] * model.GetResidueCount()
642  for idx in range(len(res_indices)):
643  n_exp = n_thresh * per_res_exp[idx]
644  if n_exp > 0:
645  score = np.sum(per_res_conserved[idx,:]) / n_exp
646  per_res_lDDT[res_indices[idx]] = score
647  else:
648  per_res_lDDT[res_indices[idx]] = 0.0
649 
650  # do full model score
651  n_distances = sum([len(x) for x in ref_indices])
652  if penalize_extra_chains:
653  n_distances += self._GetExtraModelChainPenalty_GetExtraModelChainPenalty(model, chain_mapping)
654 
655  lDDT_tot = int(n_thresh * n_distances)
656  lDDT_cons = int(np.sum(per_res_conserved))
657  lDDT = None
658  if lDDT_tot > 0:
659  lDDT = float(lDDT_cons) / lDDT_tot
660 
661  # set properties if necessary
662  if local_lddt_prop:
663  residues = model.residues
664  for idx in res_indices:
665  residues[idx].SetFloatProp(local_lddt_prop, per_res_lDDT[idx])
666 
667  if local_contact_prop:
668  residues = model.residues
669  exp_prop = local_contact_prop + "_exp"
670  conserved_prop = local_contact_prop + "_cons"
671 
672  for i, r_idx in enumerate(res_indices):
673  residues[r_idx].SetIntProp(exp_prop,
674  n_thresh * int(per_res_exp[i]))
675  residues[r_idx].SetIntProp(conserved_prop,
676  int(np.sum(per_res_conserved[i,:])))
677 
678  if set_atom_props and (local_lddt_prop or local_contact_prop):
679  atom_list = list()
680  residues = model.residues
681  for i, indices in enumerate(res_atom_indices):
682  r = residues[res_indices[i]]
683  r_idx = ref_res_indices[i]
684  res_start_idx = self.res_start_indicesres_start_indices[r_idx]
685  anames = self.compound_anamescompound_anames[self.compound_namescompound_names[r_idx]]
686  for a_i in indices:
687  a = r.FindAtom(anames[a_i - res_start_idx])
688  assert(a.IsValid())
689  atom_list.append(a)
690 
691  summed_per_atom_conserved = per_atom_conserved.sum(axis=1)
692  if local_lddt_prop:
693  # the only place where actually need to compute per-atom lDDT
694  # scores
695  for a_idx in range(len(atom_list)):
696  if per_atom_exp[a_idx] != 0:
697  tmp = summed_per_atom_conserved[a_idx] / per_atom_exp[a_idx]
698  tmp = tmp / n_thresh
699  atom_list[a_idx].SetFloatProp(local_lddt_prop, tmp)
700 
701  if local_contact_prop:
702  conserved_prop = local_contact_prop + "_cons"
703  exp_prop = local_contact_prop + "_exp"
704  for a_idx in range(len(atom_list)):
705  # do number of conserved contacts
706  tmp = summed_per_atom_conserved[a_idx]
707  atom_list[a_idx].SetIntProp(conserved_prop, tmp)
708  # do number of expected contacts
709  tmp = per_atom_exp[a_idx] * n_thresh
710  atom_list[a_idx].SetIntProp(exp_prop, tmp)
711 
712  if return_dist_test:
713  return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \
714  per_res_exp, per_res_conserved
715  else:
716  return lDDT, per_res_lDDT
717 
718  def GetNChainContacts(self, target_chain, no_interchain=False):
719  """Returns number of contacts expected for a certain chain in *target*
720 
721  :param target_chain: Chain in *target* for which you want the number
722  of expected contacts
723  :type target_chain: :class:`str`
724  :param no_interchain: Whether to exclude interchain contacts
725  :type no_interchain: :class:`bool`
726  :raises: :class:`RuntimeError` if specified chain doesnt exist
727  """
728  if target_chain not in self.chain_nameschain_names:
729  raise RuntimeError(f"Specified chain name ({target_chain}) not in "
730  f"target")
731  ch_idx = self.chain_nameschain_names.index(target_chain)
732  s = self.chain_start_indiceschain_start_indices[ch_idx]
733  e = self.n_atomsn_atoms
734  if ch_idx + 1 < len(self.chain_nameschain_names):
735  e = self.chain_start_indiceschain_start_indices[ch_idx+1]
736  if no_interchain:
737  return self._GetNExp_GetNExp(list(range(s, e)), self.ref_indices_scref_indices_sc)
738  else:
739  return self._GetNExp_GetNExp(list(range(s, e)), self.ref_indicesref_indices)
740 
def _ProcessModel(self, model, chain_mapping, residue_mapping = None,
                  thresholds = [0.5, 1.0, 2.0, 4.0],
                  check_resnames = True):
    """ Helper that generates data structures from model

    Only residues in mapped chains that can be matched to a target residue
    via *self.res_mapper* are processed ("scored residues").

    :returns: Tuple with: positions (one row per target atom), for each
              scored residue the expected reference atom indices, the
              atom indices actually present in *model*, the hash codes of
              the present atoms, the residue index in *model*, the
              residue index in the reference and a list with one entry
              per residue that has resolvable symmetries.
    :raises: :class:`RuntimeError` if *check_resnames* is True and a
             model residue name differs from the mapped target residue
    """

    # Positions that are never set must end up far away from anything in
    # the model: start from the largest model coordinate and move well
    # beyond the largest threshold.
    upper = model.bounds.GetMax()
    far_coordinate = abs(max(upper[0], upper[1], upper[2]))
    far_coordinate += 42 * max(thresholds)
    pos = np.ones((self.n_atoms, 3), dtype=np.float32) * far_coordinate

    # per scored residue: indices of atoms expected from the reference
    res_ref_atom_indices = list()
    # per scored residue: indices of atoms that are actually there
    res_atom_indices = list()
    # per scored residue: hash codes of the atoms that are actually there
    # (required if add_mdl_contacts is set to True)
    res_atom_hashes = list()
    # per scored residue: residue index in model
    res_indices = list()
    # per scored residue: residue index in reference
    ref_res_indices = list()
    # one element per symmetry group
    symmetries = list()

    model_res_idx = -1
    for chain in model.chains:
        mdl_ch_name = chain.GetName()
        if mdl_ch_name not in chain_mapping:
            # additional model chain which is not mapped - its residues
            # still count for the running model residue index
            model_res_idx += len(chain.residues)
            continue
        trg_ch_name = chain_mapping[mdl_ch_name]

        rnums = self._GetChainRNums(chain, residue_mapping, mdl_ch_name,
                                    trg_ch_name)

        for res, rnum in zip(chain.residues, rnums):
            model_res_idx += 1
            mapper_key = (trg_ch_name, rnum)
            if mapper_key not in self.res_mapper:
                continue
            ref_idx = self.res_mapper[mapper_key]
            rname = self.compound_names[ref_idx]
            if check_resnames and res.name != rname:
                raise RuntimeError(
                    f"Residue name mismatch for {res}, "
                    f" expect {self.compound_names[ref_idx]}"
                )
            start = self.res_start_indices[ref_idx]
            anames = self.compound_anames[rname]
            atoms = [res.FindAtom(aname) for aname in anames]

            present_indices = list()
            present_hashes = list()
            for offset, atom in enumerate(atoms):
                if not atom.IsValid():
                    continue
                full_idx = start + offset
                p = atom.GetPos()
                pos[full_idx][0] = p[0]
                pos[full_idx][1] = p[1]
                pos[full_idx][2] = p[2]
                present_indices.append(full_idx)
                present_hashes.append(atom.handle.GetHashCode())

            res_ref_atom_indices.append(
                list(range(start, start + len(anames)))
            )
            res_atom_indices.append(present_indices)
            res_atom_hashes.append(present_hashes)
            res_indices.append(model_res_idx)
            ref_res_indices.append(ref_idx)

            if rname in self.compound_symmetric_atoms:
                # a symmetric pair only counts if both atoms are present
                sym_indices = [
                    (start + sym_tuple[0], start + sym_tuple[1])
                    for sym_tuple in self.compound_symmetric_atoms[rname]
                    if atoms[sym_tuple[0]].IsValid()
                    and atoms[sym_tuple[1]].IsValid()
                ]
                if len(sym_indices) > 0:
                    symmetries.append(sym_indices)

    return (pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes,
            res_indices, ref_res_indices, symmetries)
833 
834 
def _GetExtraModelChainPenalty(self, model, chain_mapping):
    """Counts n distances in extra model chains to be added as penalty

    Every chain of *model* that is not a key in *chain_mapping* is scored
    on its own with a temporary scorer that reuses the settings of this
    scorer; the number of its reference contacts is added to the penalty.
    """
    penalty = 0
    for chain in model.chains:
        ch_name = chain.GetName()
        if ch_name in chain_mapping:
            continue  # mapped chains never contribute to the penalty
        sym_settings = self.symmetry_settings
        extra_chain = model.Select(f"cname={mol.QueryQuoteName(ch_name)}")
        extra_scorer = lDDTScorer(extra_chain, self.compound_lib,
                                  symmetry_settings = sym_settings,
                                  inclusion_radius = self.inclusion_radius,
                                  bb_only = self.bb_only)
        penalty += sum(len(x) for x in extra_scorer.ref_indices)
    return penalty
850 
def _GetChainRNums(self, ch, residue_mapping, model_ch_name,
                   target_ch_name):
    """Map residues in model chain to target residues

    There are two options: one is simply using residue numbers,
    the other is a custom mapping as given in *residue_mapping*

    :returns: One entry per residue in *ch*: the residue number of the
              matched target residue, None for model residues that are
              aligned to a gap in the target.
    :raises: :class:`RuntimeError` if the sequences in the alignment do
             not match the number of residues in target/model chain
    """
    if not residue_mapping or model_ch_name not in residue_mapping:
        # no custom mapping for this chain: rely on residue numbers
        return [r.GetNumber() for r in ch.residues]

    aln = residue_mapping[model_ch_name]

    # extract residue numbers from target chain
    ch_idx = self.chain_names.index(target_ch_name)
    first = self.chain_res_start_indices[ch_idx]
    if ch_idx < len(self.chain_names) - 1:
        last = self.chain_res_start_indices[ch_idx+1]
    else:
        last = len(self.compound_names)
    target_rnums = self.res_resnums[first:last]

    # get sequences from alignment and do consistency checks
    n_target = len(aln.GetSequence(0).GetGaplessString())
    n_model = len(aln.GetSequence(1).GetGaplessString())
    if n_target != len(target_rnums):
        raise RuntimeError(f"Try to perform residue mapping for "
                           f"model chain {model_ch_name} which "
                           f"maps to {target_ch_name} in target. "
                           f"Target sequence in alignment suggests "
                           f"{n_target} "
                           f"residues but {len(target_rnums)} are "
                           f"expected.")
    if n_model != len(ch.residues):
        raise RuntimeError(f"Try to perform residue mapping for "
                           f"model chain {model_ch_name} which "
                           f"maps to {target_ch_name} in target. "
                           f"Model sequence in alignment suggests "
                           f"{n_model} "
                           f"residues but {len(ch.residues)} are "
                           f"expected.")

    rnums = list()
    target_pos = -1
    for col in aln:
        in_target = col[0] != '-'
        in_model = col[1] != '-'
        if in_target:
            target_pos += 1
        if in_model:
            # match => target residue number,
            # insertion in model => None
            rnums.append(target_rnums[target_pos] if in_target else None)
    return rnums
901 
902 
def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings,
              seqres_mapping, bb_only):
    """Sets target related lDDTScorer members defined in constructor

    No distance related members - see _SetupDistances

    Every atom expected for a residue (as listed in
    *self.compound_anames*) gets an index, also if it is missing in the
    target. Missing atoms get a zero position and no entry in
    *self.atom_indices*.
    """
    residue_numbers = self._GetTargetResidueNumbers(self.target,
                                                    seqres_mapping)
    next_atom_idx = 0
    coords = list()
    for chain in self.target.chains:
        ch_name = chain.GetName()
        self.chain_names.append(ch_name)
        self.chain_start_indices.append(next_atom_idx)
        self.chain_res_start_indices.append(len(self.compound_names))
        for res, rnum in zip(chain.residues, residue_numbers[ch_name]):
            rname = res.name
            if rname not in self.compound_anames:
                # sets compound info in self.compound_anames and
                # self.compound_symmetric_atoms
                self._SetupCompound(res, compound_lib, custom_compounds,
                                    symmetry_settings, bb_only)

            self.res_start_indices.append(next_atom_idx)
            self.res_mapper[(ch_name, rnum)] = len(self.compound_names)
            self.compound_names.append(rname)
            self.res_resnums.append(rnum)

            atoms = [res.FindAtom(an) for an in self.compound_anames[rname]]
            for atom in atoms:
                if atom.IsValid():
                    self.atom_indices[atom.handle.GetHashCode()] = next_atom_idx
                    p = atom.GetPos()
                    xyz = np.asarray([p[0], p[1], p[2]], dtype=np.float32)
                else:
                    # placeholder keeps the index layout intact
                    xyz = np.zeros(3, dtype=np.float32)
                coords.append(xyz)
                next_atom_idx += 1

            if rname in self.compound_symmetric_atoms:
                # remember every present atom that is part of a symmetry
                for sym_tuple in self.compound_symmetric_atoms[rname]:
                    for a_idx in sym_tuple:
                        atom = atoms[a_idx]
                        if not atom.IsValid():
                            continue
                        hashcode = atom.handle.GetHashCode()
                        self.symmetric_atoms.add(
                            self.atom_indices[hashcode]
                        )
    self.positions = np.vstack(coords)
    self.n_atoms = next_atom_idx
952 
def _GetTargetResidueNumbers(self, target, seqres_mapping):
    """Returns residue numbers for each chain in target as dict

    They're either directly extracted from the raw residue number
    from the structure or from user provided alignments

    :param target: Structure to extract residue numbers from
    :param seqres_mapping: Maps chain names to alignments with SEQRES as
                           first and ATOMSEQ as second sequence. Chains
                           without entry use their raw residue numbers.
    :returns: :class:`dict` with chain names as keys and one residue
              number per residue as values
    :raises: :class:`RuntimeError` if SEQRES contains gaps, if ATOMSEQ
             does not match the residues of the chain or if SEQRES and
             ATOMSEQ disagree in an aligned column
    """
    residue_numbers = dict()
    for ch in target.chains:
        ch_name = ch.GetName()
        if ch_name in seqres_mapping:
            aln = seqres_mapping[ch_name]
            seqres = aln.GetSequence(0).GetString()
            atomseq = aln.GetSequence(1).GetString()
            # SEQRES must not contain gaps
            if "-" in seqres:
                raise RuntimeError(
                    "SEQRES in seqres_mapping must not " "contain gaps"
                )
            # Must be a string: comparing a list of one letter codes
            # with the gapless ATOMSEQ string is never equal and would
            # reject every valid mapping.
            atomseq_from_chain = "".join(
                r.one_letter_code for r in ch.residues
            )
            if atomseq.replace("-", "") != atomseq_from_chain:
                raise RuntimeError(
                    "ATOMSEQ in seqres_mapping must match "
                    "raw sequence extracted from chain "
                    "residues"
                )
            rnums = list()
            # SEQRES is gapless, so the residue number is simply the
            # one-based column index in the alignment
            for rnum, (seqres_olc, atomseq_olc) in enumerate(
                zip(seqres, atomseq), start=1
            ):
                if atomseq_olc == "-":
                    continue  # residue not resolved in structure
                if seqres_olc != atomseq_olc:
                    raise RuntimeError(
                        f"Residue with number {rnum} in "
                        f"chain {ch_name} has SEQRES "
                        f"ATOMSEQ mismatch"
                    )
                rnums.append(mol.ResNum(rnum))
        else:
            rnums = [r.GetNumber() for r in ch.residues]
        # internal invariant, guaranteed by the checks above
        assert len(rnums) == len(ch.residues)
        residue_numbers[ch_name] = rnums
    return residue_numbers
995 
def _SetupCompound(self, r, compound_lib, custom_compounds,
                   symmetry_settings, bb_only):
    """fill self.compound_anames/self.compound_symmetric_atoms

    :param r: Residue for which compound info is set up, the residue
              name serves as key
    :param compound_lib: Queried with *FindCompound* if the residue name
                         is not in *custom_compounds*
    :param custom_compounds: None or dict with residue names as key and
                             objects with *atom_names* attribute as value
    :param symmetry_settings: Its *symmetric_compounds* provide pairs of
                              symmetric atom names per residue name
    :param bb_only: Only consider CA (peptides) / C3' (nucleotides)
    :raises: :class:`RuntimeError` if *bb_only* is True and *r* is neither
             peptide nor nucleotide linking, if there is no entry for
             *r* in *compound_lib* or if a symmetric atom name is not
             part of the compound
    """
    if bb_only:
        # throw away compound_lib info
        if r.chem_class.IsPeptideLinking():
            self.compound_anames[r.name] = ["CA"]
        elif r.chem_class.IsNucleotideLinking():
            self.compound_anames[r.name] = ["C3'"]
        else:
            raise RuntimeError(f"Only support amino acids and nucleotides "
                               f"if bb_only is True, failed with {str(r)}")
        self.compound_symmetric_atoms[r.name] = list()
        return

    atom_names = list()
    if custom_compounds is not None and r.GetName() in custom_compounds:
        atom_names = list(custom_compounds[r.GetName()].atom_names)
    else:
        compound = compound_lib.FindCompound(r.name)
        if compound is None:
            raise RuntimeError(f"no entry for {r} in compound_lib")
        # hydrogen/deuterium atoms are never considered
        atom_names = [spec.name for spec in compound.GetAtomSpecs()
                      if spec.element not in ["H", "D"]]

    symmetric_atoms = list()
    if r.name in symmetry_settings.symmetric_compounds:
        for pair in symmetry_settings.symmetric_compounds[r.name]:
            try:
                a = atom_names.index(pair[0])
                b = atom_names.index(pair[1])
            except ValueError:
                # list.index signals a missing name with ValueError;
                # anything else (e.g. KeyboardInterrupt) must pass
                msg = f"Could not find symmetric atoms "
                msg += f"({pair[0]}, {pair[1]}) for {r.name} "
                msg += f"as specified in SymmetrySettings in "
                msg += f"compound from component dictionary. "
                msg += f"Atoms in compound: {atom_names}"
                raise RuntimeError(msg)
            symmetric_atoms.append((a, b))

    self.compound_anames[r.name] = atom_names
    if len(symmetric_atoms) > 0:
        self.compound_symmetric_atoms[r.name] = symmetric_atoms
1038 
def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes,
                    ref_indices, ref_distances, no_interchain,
                    no_intrachain):
    """Extends reference contacts with contacts only observed in model

    Contacts are searched in *model* among atoms that are also present in
    target. Contacts that are not yet part of *ref_indices* are appended,
    the respective distances are computed from the target positions.

    :returns: Tuple with the updated *ref_indices* and *ref_distances*.
              The passed lists are updated in place.
    """

    # buildup an index map for mdl atoms that are also present in target
    in_target = np.zeros(self.n_atoms, dtype=bool)
    for trg_idx in self.atom_indices.values():
        in_target[trg_idx] = True
    mdl_atom_indices = {
        at_hash: at_idx
        for at_indices, at_hashes in zip(res_atom_indices, res_atom_hashes)
        for at_idx, at_hash in zip(at_indices, at_hashes)
        if in_target[at_idx]
    }

    # get contacts for mdl - the contacts are only from atom pairs that
    # are also present in target, as we only provide the respective
    # hashes in mdl_atom_indices
    mdl_ref_indices, mdl_ref_distances = \
    lDDTScorer._SetupDistances(model, self.n_atoms, mdl_atom_indices,
                               self.inclusion_radius)
    if no_interchain:
        mdl_ref_indices, mdl_ref_distances = \
        lDDTScorer._SetupDistancesSC(self.n_atoms,
                                     self.chain_start_indices,
                                     mdl_ref_indices,
                                     mdl_ref_distances)

    if no_intrachain:
        mdl_ref_indices, mdl_ref_distances = \
        lDDTScorer._SetupDistancesIC(self.n_atoms,
                                     self.chain_start_indices,
                                     mdl_ref_indices,
                                     mdl_ref_distances)

    # update ref_indices/ref_distances => add mdl contacts
    for at_idx in range(self.n_atoms):
        is_new = np.isin(mdl_ref_indices[at_idx], ref_indices[at_idx],
                         assume_unique=True, invert=True)
        if not is_new.any():
            continue
        new_partners = mdl_ref_indices[at_idx][is_new]
        ref_indices[at_idx] = np.append(ref_indices[at_idx], new_partners)

        # distances need to be recomputed from ref positions
        delta = self.positions.take(new_partners, axis=0)
        np.subtract(delta, self.positions[at_idx][None, :], out=delta)
        np.square(delta, out=delta)
        delta = delta.sum(axis=1)
        np.sqrt(delta, out=delta)  # distances against all relevant atoms
        ref_distances[at_idx] = np.append(ref_distances[at_idx], delta)

    return (ref_indices, ref_distances)
1091 
1092 
1093 
@staticmethod
def _SetupDistances(structure, n_atoms, atom_index_mapping,
                    inclusion_radius):

    """Compute distance related members of lDDTScorer

    Brute force all vs all distance computation kills lDDT for large
    complexes. Instead of building some KD tree data structure, we make use
    of expected spatial proximity of atoms in the same chain. Distances are
    computed as follows:

    - process each chain individually
    - perform crude collision detection
    - process potentially interacting chain pairs
    - concatenate distances from all processing steps

    :param structure: Structure to extract contacts from
    :param n_atoms: Total number of atom indices, defines the length of
                    the returned lists
    :param atom_index_mapping: Maps atom hash codes to atom indices, atoms
                               without entry are ignored
    :param inclusion_radius: Distance cutoff for contacts
    :returns: Tuple with two lists of length *n_atoms*: for each atom the
              indices of its contact partners and the respective
              distances. Contacts of an atom are grouped in per-chain
              blocks which are ordered by the first index of each block.
    """

    def _insert_block(index_blocks, distance_blocks, new_indices,
                      new_distances):
        # Even though not being a strict requirement, blocks are kept in
        # ascending order of their first index such that the indices for
        # each atom are ordered after the hstack operation: the new block
        # goes right before the first block that starts at a larger index.
        insertion_idx = len(index_blocks)
        for block_idx, index_block in enumerate(index_blocks):
            if new_indices[0] < index_block[0]:
                insertion_idx = block_idx
                break
        index_blocks.insert(insertion_idx, new_indices)
        distance_blocks.insert(insertion_idx, new_distances)

    ref_indices = [np.asarray([], dtype=np.int64) for _ in range(n_atoms)]
    ref_distances = [np.asarray([], dtype=np.float64) for _ in range(n_atoms)]

    indices = [list() for _ in range(n_atoms)]
    distances = [list() for _ in range(n_atoms)]
    per_chain_pos = list()
    per_chain_indices = list()

    # Process individual chains
    for ch in structure.chains:
        pos_list = list()
        atom_indices = list()
        # for each atom the range of chain-local atoms that belong to the
        # same residue, these must not become contacts
        mask_start = list()
        mask_end = list()
        r_start_idx = 0
        for r in ch.residues:
            n_valid_atoms = 0
            for a in r.atoms:
                hash_code = a.handle.GetHashCode()
                if hash_code in atom_index_mapping:
                    p = a.GetPos()
                    pos_list.append(np.asarray([p[0], p[1], p[2]]))
                    atom_indices.append(atom_index_mapping[hash_code])
                    n_valid_atoms += 1
            mask_start.extend([r_start_idx] * n_valid_atoms)
            mask_end.extend([r_start_idx + n_valid_atoms] * n_valid_atoms)
            r_start_idx += n_valid_atoms

        if len(pos_list) == 0:
            # nothing to do...
            continue

        pos = np.vstack(pos_list)
        atom_indices = np.asarray(atom_indices)
        dists = cdist(pos, pos)

        # apply masks: push atoms of the same residue out of reach
        far_away = 2 * inclusion_radius
        for idx in range(atom_indices.shape[0]):
            dists[idx, mask_start[idx]:mask_end[idx]] = far_away

        # fish out and store close atoms within inclusion radius
        # NOTE(review): strict "<" here but "<=" for chain pairs below -
        # kept as is to not alter scores
        within_mask = dists < inclusion_radius
        for idx in range(atom_indices.shape[0]):
            row_mask = within_mask[idx, :]
            if row_mask.any():
                full_at_idx = atom_indices[idx]
                indices[full_at_idx].append(atom_indices[row_mask])
                distances[full_at_idx].append(dists[idx, row_mask])

        per_chain_pos.append(pos)
        per_chain_indices.append(atom_indices)

    # perform crude collision detection on axis aligned bounding boxes
    min_pos = [p.min(0) for p in per_chain_pos]
    max_pos = [p.max(0) for p in per_chain_pos]
    chain_pairs = list()
    for idx_one in range(len(per_chain_pos)):
        for idx_two in range(idx_one + 1, len(per_chain_pos)):
            if np.max(min_pos[idx_one] - max_pos[idx_two]) > inclusion_radius:
                continue
            if np.max(min_pos[idx_two] - max_pos[idx_one]) > inclusion_radius:
                continue
            chain_pairs.append((idx_one, idx_two))

    # process potentially interacting chains
    for idx_one, idx_two in chain_pairs:
        dists = cdist(per_chain_pos[idx_one], per_chain_pos[idx_two])
        within = dists <= inclusion_radius
        indices_one = per_chain_indices[idx_one]
        indices_two = per_chain_indices[idx_two]

        # atoms of first chain: partners are in second chain
        for row in np.flatnonzero(within.any(axis=1)):
            at_idx = indices_one[row]
            partner_mask = within[row, :]
            _insert_block(indices[at_idx], distances[at_idx],
                          indices_two[partner_mask],
                          dists[row, partner_mask])

        # atoms of second chain: partners are in first chain
        for col in np.flatnonzero(within.any(axis=0)):
            at_idx = indices_two[col]
            partner_mask = within[:, col]
            _insert_block(indices[at_idx], distances[at_idx],
                          indices_one[partner_mask],
                          dists[partner_mask, col])

    # concatenate distances from all processing steps
    for at_idx in range(n_atoms):
        if len(indices[at_idx]) > 0:
            ref_indices[at_idx] = np.hstack(indices[at_idx])
            ref_distances[at_idx] = np.hstack(distances[at_idx])

    return (ref_indices, ref_distances)
1223 
@staticmethod
def _SetupDistancesSC(n_atoms, chain_start_indices,
                      ref_indices, ref_distances):
    """Select subset of contacts only covering intra-chain contacts

    :returns: Tuple with two lists of length *n_atoms*: per atom the
              contact partner indices that are in the same chain and the
              respective distances
    """
    # init
    ref_indices_sc = [np.asarray([], dtype=np.int64) for _ in range(n_atoms)]
    ref_distances_sc = [np.asarray([], dtype=np.float64) for _ in range(n_atoms)]

    # chain k covers atom indices [bounds[k], bounds[k+1])
    bounds = list(chain_start_indices) + [n_atoms]
    for chain_s, chain_e in zip(bounds[:-1], bounds[1:]):
        for at_idx in range(chain_s, chain_e):
            partners = ref_indices[at_idx]
            if len(partners) == 0:
                continue
            same_chain = (partners >= chain_s) & (partners < chain_e)
            ref_indices_sc[at_idx] = partners[same_chain]
            ref_distances_sc[at_idx] = ref_distances[at_idx][same_chain]

    return (ref_indices_sc, ref_distances_sc)
1247 
@staticmethod
def _SetupDistancesIC(n_atoms, chain_start_indices,
                      ref_indices, ref_distances):
    """Select subset of contacts only covering inter-chain contacts

    :returns: Tuple with two lists of length *n_atoms*: per atom the
              contact partner indices that are in another chain and the
              respective distances
    """
    # init
    ref_indices_ic = [np.asarray([], dtype=np.int64) for _ in range(n_atoms)]
    ref_distances_ic = [np.asarray([], dtype=np.float64) for _ in range(n_atoms)]

    # chain k covers atom indices [bounds[k], bounds[k+1])
    bounds = list(chain_start_indices) + [n_atoms]
    for chain_s, chain_e in zip(bounds[:-1], bounds[1:]):
        for at_idx in range(chain_s, chain_e):
            partners = ref_indices[at_idx]
            if len(partners) == 0:
                continue
            other_chain = (partners < chain_s) | (partners >= chain_e)
            ref_indices_ic[at_idx] = partners[other_chain]
            ref_distances_ic[at_idx] = ref_distances[at_idx][other_chain]

    return (ref_indices_ic, ref_distances_ic)
1271 
@staticmethod
def _NonSymDistances(n_atoms, symmetric_atoms, ref_indices, ref_distances):
    """Transfer indices/distances of non-symmetric atoms and return

    :param n_atoms: Length of the returned lists
    :param symmetric_atoms: Indices of atoms that are part of a symmetry
    :param ref_indices: Per atom the indices of its contact partners
    :param ref_distances: Per atom the distances to its contact partners
    :returns: Tuple with two lists of length *n_atoms*. For each atom in
              *symmetric_atoms*: the contact partners that are not in
              *symmetric_atoms* themselves and the respective distances,
              both as numpy arrays. Entries of all other atoms are empty.
    """

    sym_ref_indices = [np.asarray([], dtype=np.int64) for _ in range(n_atoms)]
    sym_ref_distances = [np.asarray([], dtype=np.float64) for _ in range(n_atoms)]

    if len(symmetric_atoms) == 0:
        return (sym_ref_indices, sym_ref_distances)

    # array copy of the symmetric atom indices for vectorized lookups
    sym_arr = np.array(list(symmetric_atoms), dtype=np.int64)

    for idx in symmetric_atoms:
        partners = np.asarray(ref_indices[idx], dtype=np.int64)
        partner_distances = np.asarray(ref_distances[idx])
        # keep contacts towards atoms that are not symmetric themselves
        keep = np.isin(partners, sym_arr, invert=True)
        sym_ref_indices[idx] = partners[keep]
        sym_ref_distances[idx] = partner_distances[keep]

    return (sym_ref_indices, sym_ref_distances)
1291 
def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances):
    """Computes number of distance differences within given thresholds

    Distances between atom *atom_idx* and its contact partners are
    measured in *pos* and compared with the expected *ref_distances*.

    returns np.array with len(thresholds) elements
    """
    center = pos[atom_idx, :]
    # take() hands out a new array, the in-place operations below
    # therefore never modify *pos*
    work = pos.take(ref_indices[atom_idx], axis=0)
    np.subtract(work, center[None, :], out=work)
    np.square(work, out=work)
    work = work.sum(axis=1)
    np.sqrt(work, out=work)  # distances against all relevant atoms
    np.subtract(ref_distances[atom_idx], work, out=work)
    np.absolute(work, out=work)  # absolute dist diffs
    counts = [(work <= cutoff).sum() for cutoff in thresholds]
    return np.asarray(counts, dtype=np.int32)
1307 
def _EvalAtoms(
    self, pos, atom_indices, thresholds, ref_indices, ref_distances
):
    """Calls _EvalAtom for several atoms and collects the computed number
    of distance differences within given thresholds, one row per atom

    returns numpy matrix of shape (n_atoms, len(threshold))
    """
    n_rows = len(atom_indices)
    conserved = np.zeros((n_rows, len(thresholds)), dtype=np.int32)
    for row, atom_idx in zip(range(n_rows), atom_indices):
        conserved[row, :] = self._EvalAtom(pos, atom_idx, thresholds,
                                           ref_indices, ref_distances)
    return conserved
1322 
def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices,
                  ref_distances):
    """Calls _EvalAtoms for a bunch of residues

    residues are defined in *res_atom_indices* as lists of atom indices,
    the per-atom counts of each residue are summed up
    returns numpy matrix of shape (n_residues, len(thresholds)).
    """
    n_res = len(res_atom_indices)
    conserved = np.zeros((n_res, len(thresholds)), dtype=np.int32)
    for row, atom_indices in enumerate(res_atom_indices):
        per_atom = self._EvalAtoms(pos, atom_indices, thresholds,
                                   ref_indices, ref_distances)
        conserved[row, :] = np.sum(per_atom, axis=0)
    return conserved
1336 
def _ProcessSequenceSeparation(self):
    """Checks whether the requested sequence separation is supported

    :raises: :class:`NotImplementedError` if *self.sequence_separation*
             is not 0
    """
    if self.sequence_separation == 0:
        return
    raise NotImplementedError("Congratulations! You're the first one "
                              "requesting a non-default "
                              "sequence_separation in the new and "
                              "awesome lDDT implementation. A crate of "
                              "beer for Gabriel and he'll implement "
                              "it.")
1345 
def _GetNExp(self, atom_idx, ref_indices):
    """Returns number of close atoms around one or several atoms

    :param atom_idx: Single atom index (:class:`int` or numpy integer) or
                     a collection of atom indices (:class:`list`,
                     :class:`tuple`, :class:`range` or numpy array)
    :param ref_indices: Per atom the indices of its contact partners
    :returns: Number of contact partners of the atom, summed up if
              several atoms are given
    :raises: :class:`RuntimeError` if *atom_idx* is of any other type
    """
    # indices extracted from numpy arrays are numpy integers, not int
    if isinstance(atom_idx, (int, np.integer)):
        return len(ref_indices[atom_idx])
    if isinstance(atom_idx, (list, tuple, range, np.ndarray)):
        return sum(len(ref_indices[idx]) for idx in atom_idx)
    raise RuntimeError("invalid input type")
1355 
def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices,
                       sym_ref_distances):
    """Swaps symmetric positions in-place in order to maximize lDDT scores
    towards non-symmetric atoms.

    Each element in *symmetries* is a list of index pairs. Both
    assignments are scored and the swapped one is only kept if it scores
    strictly better.
    """

    def _swap_all(pairs):
        # exchanges the positions of both atoms of every pair
        for pair in pairs:
            pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]

    for sym in symmetries:

        atom_indices = [idx for sym_tuple in sym
                        for idx in (sym_tuple[0], sym_tuple[1])]
        tot = self._GetNExp(atom_indices, sym_ref_indices)

        if tot == 0:
            continue  # nothing to do

        norm = len(thresholds) * tot

        # score as is
        conserved_as_is = self._EvalAtoms(pos, atom_indices, thresholds,
                                          sym_ref_indices,
                                          sym_ref_distances)

        # switch positions and score again
        _swap_all(sym)
        conserved_swapped = self._EvalAtoms(pos, atom_indices, thresholds,
                                            sym_ref_indices,
                                            sym_ref_distances)

        score_as_is = np.sum(conserved_as_is) / norm
        score_swapped = np.sum(conserved_swapped) / norm

        if score_as_is >= score_swapped:
            # switch back, initial positions were better or equal
            # for the equal case: we still switch back to reproduce the old
            # lDDT behaviour
            _swap_all(sym)
def __init__(self, atom_names)
Definition: lddt.py:30
def AddSymmetricCompound(self, name, symmetric_atoms)
Definition: lddt.py:65
def _SetupCompound(self, r, compound_lib, custom_compounds, symmetry_settings, bb_only)
Definition: lddt.py:997
def lDDT(self, model, thresholds=[0.5, 1.0, 2.0, 4.0], local_lddt_prop=None, local_contact_prop=None, chain_mapping=None, no_interchain=False, no_intrachain=False, penalize_extra_chains=False, residue_mapping=None, return_dist_test=False, check_resnames=True, add_mdl_contacts=False, interaction_data=None, set_atom_props=False)
Definition: lddt.py:442
def _ProcessModel(self, model, chain_mapping, residue_mapping=None, thresholds=[0.5, 1.0, 2.0, 4.0], check_resnames=True)
Definition: lddt.py:743
def _GetChainRNums(self, ch, residue_mapping, model_ch_name, target_ch_name)
Definition: lddt.py:852
def _ProcessSequenceSeparation(self)
Definition: lddt.py:1337
def sym_ref_distances(self)
Definition: lddt.py:349
def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices, sym_ref_distances)
Definition: lddt.py:1357
def ref_distances_ic(self)
Definition: lddt.py:407
def _GetTargetResidueNumbers(self, target, seqres_mapping)
Definition: lddt.py:953
def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances)
Definition: lddt.py:1292
def sym_ref_distances_ic(self)
Definition: lddt.py:427
def sym_ref_distances_sc(self)
Definition: lddt.py:387
def sym_ref_indices_ic(self)
Definition: lddt.py:417
def GetNChainContacts(self, target_chain, no_interchain=False)
Definition: lddt.py:718
def sym_ref_indices_sc(self)
Definition: lddt.py:377
def ref_distances_sc(self)
Definition: lddt.py:367
def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes, ref_indices, ref_distances, no_interchain, no_intrachain)
Definition: lddt.py:1041
def _GetNExp(self, atom_idx, ref_indices)
Definition: lddt.py:1346
def __init__(self, target, compound_lib=None, custom_compounds=None, inclusion_radius=15, sequence_separation=0, symmetry_settings=None, seqres_mapping=dict(), bb_only=False)
Definition: lddt.py:213
def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings, seqres_mapping, bb_only)
Definition: lddt.py:904
def _GetExtraModelChainPenalty(self, model, chain_mapping)
Definition: lddt.py:835
def _EvalAtoms(self, pos, atom_indices, thresholds, ref_indices, ref_distances)
Definition: lddt.py:1310
def GetDefaultSymmetrySettings()
Definition: lddt.py:83
def cdist(p1, p2)
Definition: lddt.py:12