OpenStructure
lddt.py
Go to the documentation of this file.
1 import numpy as np
2 
3 from ost import mol
4 from ost import conop
5 
# Use cdist of scipy; fall back to a (slower) numpy implementation if scipy is
# not available.
try:
    from scipy.spatial.distance import cdist
except Exception:
    # Deliberately broader than ImportError: a broken scipy installation may
    # raise other errors on import and we still want the fallback. Unlike a
    # bare except, this does not swallow KeyboardInterrupt/SystemExit.
    def cdist(p1, p2):
        """Pairwise euclidean distances between the rows of *p1* and *p2*

        Numpy fallback for :func:`scipy.spatial.distance.cdist` (euclidean
        metric only).

        :param p1: Coordinates, one point per row, shape (m, d)
        :param p2: Coordinates, one point per row, shape (n, d)
        :returns: Distance matrix of shape (m, n)
        """
        x2 = np.sum(p1**2, axis=1).reshape(-1, 1)  # (m, 1)
        y2 = np.sum(p2**2, axis=1)  # (n)
        xy = np.matmul(p1, p2.T)  # (m, n)
        # |x|^2 - 2xy + |y|^2 can become slightly negative due to floating
        # point rounding for (nearly) identical points. Clamp to zero, otherwise
        # sqrt returns NaN.
        sq_dists = np.maximum(x2 - 2 * xy + y2, 0.0)
        return np.sqrt(sq_dists)  # (m, n)
17 
class CustomCompound:
    """ Defines atoms for custom compounds

    lDDT needs the reference atoms of every compound. Usually they come from
    a :class:`ost.conop.CompoundLib`. This lightweight container provides
    them for arbitrary compounds that are not necessarily part of the
    compound library.

    :param atom_names: Names of atoms of custom compound
    :type atom_names: :class:`list` of :class:`str`
    """
    def __init__(self, atom_names):
        self.atom_names = atom_names

    @staticmethod
    def FromResidue(res):
        """ Construct custom compound from residue

        Atom names are taken in the order of *res.atoms*. Hydrogen ("H") and
        deuterium ("D") atoms are skipped.

        :param res: Residue from which reference atom names are extracted
        :type res: :class:`ost.mol.ResidueView`/:class:`ost.mol.ResidueHandle`
        :returns: :class:`CustomCompound`
        :raises: :class:`RuntimeError` if the remaining atom names are not
                 unique
        """
        heavy_names = list()
        for atom in res.atoms:
            if atom.element in ("H", "D"):
                continue
            heavy_names.append(atom.name)
        if len(set(heavy_names)) != len(heavy_names):
            raise RuntimeError("Duplicate atoms detected in CustomCompound")
        return CustomCompound(heavy_names)
46 
class SymmetrySettings:
    """Container for symmetric compounds

    lDDT considers symmetries and selects the one resulting in the highest
    possible score.

    A symmetry is a renaming operation on one or more atoms that yields a
    chemically equivalent residue. For example, in ASP, swapping the names
    OD1 and OD2 gives a chemically equivalent residue.

    Symmetries are registered with :func:`AddSymmetricCompound`. They can then
    be accessed directly through the *symmetric_compounds* member, a
    :class:`dict` mapping compound name to its list of atom name pairs.
    """
    def __init__(self):
        self.symmetric_compounds = {}

    def AddSymmetricCompound(self, name, symmetric_atoms):
        """Adds symmetry for compound with *name*

        An existing entry for *name* is overwritten.

        :param name: Name of compound with symmetry
        :type name: :class:`str`
        :param symmetric_atoms: Pairs of atom names that define the renaming
                                operation. After applying all switches
                                defined in the tuples, the resulting residue
                                should be chemically equivalent. Atom names
                                must refer to the PDB component dictionary.
        :type symmetric_atoms: :class:`list` of :class:`tuple`
        :raises: :class:`RuntimeError` if any element of *symmetric_atoms*
                 does not have length 2. Nothing is stored in that case.
        """
        if any(len(pair) != 2 for pair in symmetric_atoms):
            raise RuntimeError("Expect pairs when defining symmetries")
        self.symmetric_compounds[name] = symmetric_atoms
80 
81 
def GetDefaultSymmetrySettings():
    """Constructs and returns :class:`SymmetrySettings` object for natural amino
    acids

    Nucleotides (A, C, G, U, DA, DC, DG, DT) are covered too: their OP1/OP2
    atoms are registered as a symmetric pair.
    """
    settings = SymmetrySettings()

    # amino acids and the atom name pairs that can be swapped
    amino_acid_symmetries = [
        ("ASP", [("OD1", "OD2")]),
        ("GLU", [("OE1", "OE2")]),
        ("LEU", [("CD1", "CD2")]),
        ("VAL", [("CG1", "CG2")]),
        ("ARG", [("NH1", "NH2")]),
        ("PHE", [("CD1", "CD2"), ("CE1", "CE2")]),
        ("TYR", [("CD1", "CD2"), ("CE1", "CE2")]),
    ]
    for compound_name, pairs in amino_acid_symmetries:
        settings.AddSymmetricCompound(compound_name, pairs)

    # nucleotides - OP1/OP2 are interchangeable
    for nuc_name in ("A", "C", "G", "U", "DA", "DC", "DG", "DT"):
        settings.AddSymmetricCompound(nuc_name, [("OP1", "OP2")])

    return settings
121 
122 
class lDDTScorer:
    """lDDT scorer object for a specific target

    Sets up everything to score models of that target. lDDT (local distance
    difference test) is defined as the fraction of pairwise distances that
    differ by less than a threshold between target and model. If multiple
    thresholds are given, the average is returned. See

    V. Mariani, M. Biasini, A. Barbato, T. Schwede, lDDT : A local
    superposition-free score for comparing protein structures and models using
    distance difference tests, Bioinformatics, 2013

    :param target: The target
    :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
    :param compound_lib: Compound library from which a compound for each residue
                         is extracted based on its name. Uses
                         :func:`ost.conop.GetDefaultLib` if not given and
                         raises if this returns no valid compound library.
                         Atoms defined in the compound are searched in the
                         residue and build the reference for scoring. If the
                         residue has atoms with names ["A", "B", "C"] but the
                         corresponding compound only has ["A", "B"], "A" and
                         "B" are considered for scoring. If the residue has
                         atoms ["A", "B"] but the compound has
                         ["A", "B", "C"], "C" is considered missing and does
                         not influence scoring, even if present in the model.
    :type compound_lib: :class:`ost.conop.CompoundLib`
    :param custom_compounds: Custom compounds defining reference atoms. If
                             given, *custom_compounds* take precedence over
                             *compound_lib*.
    :type custom_compounds: :class:`dict` with residue names (:class:`str`) as
                            key and :class:`CustomCompound` as value.
    :param inclusion_radius: All pairwise distances < *inclusion_radius* are
                             considered for scoring
    :type inclusion_radius: :class:`float`
    :param sequence_separation: Only pairwise distances between atoms of
                                residues which are further apart than this
                                threshold are considered. Residue distance is
                                based on resnum. The default (0) considers all
                                pairwise distances except intra-residue
                                distances.
    :type sequence_separation: :class:`int`
    :param symmetry_settings: Define residues exhibiting internal symmetry, uses
                              :func:`GetDefaultSymmetrySettings` if not given.
    :type symmetry_settings: :class:`SymmetrySettings`
    :param seqres_mapping: At the scoring stage, model residues are mapped
                           using residue numbers, which define their location
                           in a reference sequence (SEQRES) with one-based
                           indexing. If the residue numbers in *target* don't
                           correspond to that SEQRES, you can specify the
                           mapping manually. You can provide a dictionary to
                           specify a reference sequence (SEQRES) for one or more
                           chain(s). Key: chain name, value: alignment
                           (seq1: SEQRES, seq2: sequence of residues in chain).
                           Example: The residues in a chain with name "A" have
                           sequence "YEAH" and residue numbers [42,43,44,45].
                           You can provide an alignment with seq1 "``HELLYEAH``"
                           and seq2 "``----YEAH``". "Y" gets assigned residue
                           number 5, "E" gets assigned 6 and so on no matter
                           what the original residue numbers were. Defaults to
                           no custom mapping.
    :type seqres_mapping: :class:`dict` (key: :class:`str`, value:
                          :class:`ost.seq.AlignmentHandle`)
    :param bb_only: Only consider atoms with name "CA" in case of amino acids and
                    "C3'" for nucleotides. This invalidates *compound_lib*.
                    Raises if any residue in *target* is not
                    `r.chem_class.IsPeptideLinking()` or
                    `r.chem_class.IsNucleotideLinking()`
    :type bb_only: :class:`bool`
    :raises: :class:`RuntimeError` if *target* contains compound which is not in
             *compound_lib*, :class:`RuntimeError` if *symmetry_settings*
             specifies symmetric atoms that are not present in the according
             compound in *compound_lib*, :class:`RuntimeError` if
             *seqres_mapping* is not provided and *target* contains residue
             numbers with insertion codes or the residue numbers for each chain
             are not monotonically increasing, :class:`RuntimeError` if
             *seqres_mapping* is provided but an alignment is invalid
             (seq1 contains gaps, mismatch in seq1/seq2, seq2 does not match
             residues in corresponding chains).
    """
    def __init__(
        self,
        target,
        compound_lib=None,
        custom_compounds=None,
        inclusion_radius=15,
        sequence_separation=0,
        symmetry_settings=None,
        seqres_mapping=None,
        bb_only=False
    ):
        # None sentinel instead of a dict() default: a default dict object
        # would be shared across all lDDTScorer instances
        if seqres_mapping is None:
            seqres_mapping = dict()

        self.target = target
        self.inclusion_radius = inclusion_radius
        self.sequence_separation = sequence_separation
        if compound_lib is None:
            compound_lib = conop.GetDefaultLib()
        if compound_lib is None:
            raise RuntimeError("No compound_lib given and conop.GetDefaultLib "
                               "returns no valid compound library")
        self.compound_lib = compound_lib
        self.custom_compounds = custom_compounds
        if symmetry_settings is None:
            self.symmetry_settings = GetDefaultSymmetrySettings()
        else:
            self.symmetry_settings = symmetry_settings

        # whether to only consider atoms with name "CA" (amino acids) or C3'
        # (nucleotides), invalidates *compound_lib*
        self.bb_only = bb_only

        # names of heavy atoms of each unique compound present in *target* as
        # extracted from *compound_lib*, e.g.
        # self.compound_anames["GLY"] = ["N", "CA", "C", "O"]
        self.compound_anames = dict()

        # stores symmetry information for those compounds as defined in
        # *symmetry_settings*
        self.compound_symmetric_atoms = dict()

        # list of len(target.chains) containing all chain names in *target*
        self.chain_names = list()

        # list of len(target.residues) containing all compound names in *target*
        self.compound_names = list()

        # list of len(target.residues) defining start pos in internal reference
        # positions for each residue
        self.res_start_indices = list()

        # list of len(target.residues) defining residue numbers in target
        self.res_resnums = list()

        # list of len(target.chains) defining start pos in internal reference
        # positions for each chain
        self.chain_start_indices = list()

        # list of len(target.chains) defining start pos in self.compound_names
        # for each chain
        self.chain_res_start_indices = list()

        # maps residues in *target* to indices in
        # self.compound_names/self.res_start_indices. A residue gets identified
        # by a tuple (first element: chain name, second element: residue number,
        # residue number is either the actual residue number in *target* or
        # given by *seqres_mapping*)
        self.res_mapper = dict()

        # number of atoms as specified in compounds. not all are necessarily
        # covered by structure
        self.n_atoms = None

        # stores an index for each AtomHandle in *target*
        # (atom hashcode => index)
        self.atom_indices = dict()

        # store indices of all atoms that have symmetry properties
        self.symmetric_atoms = set()

        # the actual target positions in a numpy array of shape (self.n_atoms,3)
        self.positions = None

        # setup members defined above
        self._SetupEnv(self.compound_lib, self.custom_compounds,
                       self.symmetry_settings, seqres_mapping, self.bb_only)

        # distance related members are lazily computed as they're affected
        # by different flavours of lDDT (e.g. lDDT including inter-chain
        # contacts or not etc.)

        # stores for each atom the other atoms within inclusion_radius
        self._ref_indices = None
        # the corresponding distances
        self._ref_distances = None

        # The following lists will be sparsely populated. We keep for each
        # symmetry related atom the distances towards all atoms which are NOT
        # affected by symmetry. So we can evaluate two symmetric versions
        # against the fixed stuff later on and select the better scoring one.
        self._sym_ref_indices = None
        self._sym_ref_distances = None

        # exactly the same as above but without interchain contacts
        # => single-chain (sc)
        self._ref_indices_sc = None
        self._ref_distances_sc = None
        self._sym_ref_indices_sc = None
        self._sym_ref_distances_sc = None

        # exactly the same as above but without intrachain contacts
        # => inter-chain (ic)
        self._ref_indices_ic = None
        self._ref_distances_ic = None
        self._sym_ref_indices_ic = None
        self._sym_ref_distances_ic = None

        # input parameter checking
        self._ProcessSequenceSeparation()
320 
321  @property
322  def ref_indices(self):
323  if self._ref_indices_ref_indices is None:
324  self._ref_indices_ref_indices, self._ref_distances_ref_distances = \
325  lDDTScorer._SetupDistances(self.targettarget, self.n_atomsn_atoms,
326  self.atom_indicesatom_indices,
327  self.inclusion_radiusinclusion_radius)
328  return self._ref_indices_ref_indices
329 
330  @property
331  def ref_distances(self):
332  if self._ref_distances_ref_distances is None:
333  self._ref_indices_ref_indices, self._ref_distances_ref_distances = \
334  lDDTScorer._SetupDistances(self.targettarget, self.n_atomsn_atoms,
335  self.atom_indicesatom_indices,
336  self.inclusion_radiusinclusion_radius)
337  return self._ref_distances_ref_distances
338 
339  @property
340  def sym_ref_indices(self):
341  if self._sym_ref_indices_sym_ref_indices is None:
342  self._sym_ref_indices_sym_ref_indices, self._sym_ref_distances_sym_ref_distances = \
343  lDDTScorer._NonSymDistances(self.n_atomsn_atoms, self.symmetric_atomssymmetric_atoms,
344  self.ref_indicesref_indices, self.ref_distancesref_distances)
345  return self._sym_ref_indices_sym_ref_indices
346 
347  @property
348  def sym_ref_distances(self):
349  if self._sym_ref_distances_sym_ref_distances is None:
350  self._sym_ref_indices_sym_ref_indices, self._sym_ref_distances_sym_ref_distances = \
351  lDDTScorer._NonSymDistances(self.n_atomsn_atoms, self.symmetric_atomssymmetric_atoms,
352  self.ref_indicesref_indices, self.ref_distancesref_distances)
353  return self._sym_ref_distances_sym_ref_distances
354 
355  @property
356  def ref_indices_sc(self):
357  if self._ref_indices_sc_ref_indices_sc is None:
358  self._ref_indices_sc_ref_indices_sc, self._ref_distances_sc_ref_distances_sc = \
359  lDDTScorer._SetupDistancesSC(self.n_atomsn_atoms,
360  self.chain_start_indiceschain_start_indices,
361  self.ref_indicesref_indices,
362  self.ref_distancesref_distances)
363  return self._ref_indices_sc_ref_indices_sc
364 
365  @property
366  def ref_distances_sc(self):
367  if self._ref_distances_sc_ref_distances_sc is None:
368  self._ref_indices_sc_ref_indices_sc, self._ref_distances_sc_ref_distances_sc = \
369  lDDTScorer._SetupDistancesSC(self.n_atomsn_atoms,
370  self.chain_start_indiceschain_start_indices,
371  self.ref_indicesref_indices,
372  self.ref_distancesref_distances)
373  return self._ref_distances_sc_ref_distances_sc
374 
375  @property
377  if self._sym_ref_indices_sc_sym_ref_indices_sc is None:
378  self._sym_ref_indices_sc_sym_ref_indices_sc, self._sym_ref_distances_sc_sym_ref_distances_sc = \
379  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
380  self.symmetric_atomssymmetric_atoms,
381  self.ref_indices_scref_indices_sc,
382  self.ref_distances_scref_distances_sc)
383  return self._sym_ref_indices_sc_sym_ref_indices_sc
384 
385  @property
387  if self._sym_ref_distances_sc_sym_ref_distances_sc is None:
388  self._sym_ref_indices_sc_sym_ref_indices_sc, self._sym_ref_distances_sc_sym_ref_distances_sc = \
389  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
390  self.symmetric_atomssymmetric_atoms,
391  self.ref_indices_scref_indices_sc,
392  self.ref_distances_scref_distances_sc)
393  return self._sym_ref_distances_sc_sym_ref_distances_sc
394 
395  @property
396  def ref_indices_ic(self):
397  if self._ref_indices_ic_ref_indices_ic is None:
398  self._ref_indices_ic_ref_indices_ic, self._ref_distances_ic_ref_distances_ic = \
399  lDDTScorer._SetupDistancesIC(self.n_atomsn_atoms,
400  self.chain_start_indiceschain_start_indices,
401  self.ref_indicesref_indices,
402  self.ref_distancesref_distances)
403  return self._ref_indices_ic_ref_indices_ic
404 
405  @property
406  def ref_distances_ic(self):
407  if self._ref_distances_ic_ref_distances_ic is None:
408  self._ref_indices_ic_ref_indices_ic, self._ref_distances_ic_ref_distances_ic = \
409  lDDTScorer._SetupDistancesIC(self.n_atomsn_atoms,
410  self.chain_start_indiceschain_start_indices,
411  self.ref_indicesref_indices,
412  self.ref_distancesref_distances)
413  return self._ref_distances_ic_ref_distances_ic
414 
415  @property
417  if self._sym_ref_indices_ic_sym_ref_indices_ic is None:
418  self._sym_ref_indices_ic_sym_ref_indices_ic, self._sym_ref_distances_ic_sym_ref_distances_ic = \
419  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
420  self.symmetric_atomssymmetric_atoms,
421  self.ref_indices_icref_indices_ic,
422  self.ref_distances_icref_distances_ic)
423  return self._sym_ref_indices_ic_sym_ref_indices_ic
424 
425  @property
427  if self._sym_ref_distances_ic_sym_ref_distances_ic is None:
428  self._sym_ref_indices_ic_sym_ref_indices_ic, self._sym_ref_distances_ic_sym_ref_distances_ic = \
429  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
430  self.symmetric_atomssymmetric_atoms,
431  self.ref_indices_icref_indices_ic,
432  self.ref_distances_icref_distances_ic)
433  return self._sym_ref_distances_ic_sym_ref_distances_ic
434 
435  def lDDT(self, model, thresholds = [0.5, 1.0, 2.0, 4.0],
436  local_lddt_prop=None, local_contact_prop=None,
437  chain_mapping=None, no_interchain=False,
438  no_intrachain=False, penalize_extra_chains=False,
439  residue_mapping=None, return_dist_test=False,
440  check_resnames=True, add_mdl_contacts=False,
441  interaction_data=None):
442  """Computes lDDT of *model* - globally and per-residue
443 
444  :param model: Model to be scored - models are preferably scored upon
445  performing stereo-chemistry checks in order to punish for
446  non-sensical irregularities. This must be done separately
447  as a pre-processing step. Target contacts that are not
448  covered by *model* are considered not conserved, thus
449  decreasing lDDT score. This also includes missing model
450  chains or model chains for which no mapping is provided in
451  *chain_mapping*.
452  :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
453  :param thresholds: Thresholds of distance differences to be considered
454  as correct - see docs in constructor for more info.
455  default: [0.5, 1.0, 2.0, 4.0]
456  :type thresholds: :class:`list` of :class:`floats`
457  :param local_lddt_prop: If set, per-residue scores will be assigned as
458  generic float property of that name
459  :type local_lddt_prop: :class:`str`
460  :param local_contact_prop: If set, number of expected contacts as well
461  as number of conserved contacts will be
462  assigned as generic int property.
463  Excected contacts will be set as
464  <local_contact_prop>_exp, conserved contacts
465  as <local_contact_prop>_cons. Values
466  are summed over all thresholds.
467  :type local_contact_prop: :class:`str`
468  :param chain_mapping: Mapping of model chains (key) onto target chains
469  (value). This is required if target or model have
470  more than one chain.
471  :type chain_mapping: :class:`dict` with :class:`str` as keys/values
472  :param no_interchain: Whether to exclude interchain contacts
473  :type no_interchain: :class:`bool`
474  :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
475  consider interface related contacts)
476  :type no_intrachain: :class:`bool`
477  :param penalize_extra_chains: Whether to include a fixed penalty for
478  additional chains in the model that are
479  not mapped to the target. ONLY AFFECTS
480  RETURNED GLOBAL SCORE. In detail: adds the
481  number of intra-chain contacts of each
482  extra chain to the expected contacts, thus
483  adding a penalty.
484  :type penalize_extra_chains: :class:`bool`
485  :param residue_mapping: By default, residue mapping is based on residue
486  numbers. That means, a model chain and the
487  respective target chain map to the same
488  underlying reference sequence (SEQRES).
489  Alternatively, you can specify one or
490  several alignment(s) between model and target
491  chains by providing a dictionary. key: Name
492  of chain in model (respective target chain is
493  extracted from *chain_mapping*),
494  value: Alignment with first sequence
495  corresponding to target chain and second
496  sequence to model chain. There is NO reference
497  sequence involved, so the two sequences MUST
498  exactly match the actual residues observed in
499  the respective target/model chains (ATOMSEQ).
500  :type residue_mapping: :class:`dict` with key: :class:`str`,
501  value: :class:`ost.seq.AlignmentHandle`
502  :param return_dist_test: Whether to additionally return the underlying
503  per-residue data for the distance difference
504  test. Adds five objects to the return tuple.
505  First: Number of total contacts summed over all
506  thresholds
507  Second: Number of conserved contacts summed
508  over all thresholds
509  Third: list with length of scored residues.
510  Contains indices referring to model.residues.
511  Fourth: numpy array of size
512  len(scored_residues) containing the number of
513  total contacts,
514  Fifth: numpy matrix of shape
515  (len(scored_residues), len(thresholds))
516  specifying how many for each threshold are
517  conserved.
518  :param check_resnames: On by default. Enforces residue name matches
519  between mapped model and target residues.
520  :type check_resnames: :class:`bool`
521  :param add_mdl_contacts: Adds model contacts - Only using contacts that
522  are within a certain distance threshold in the
523  target does not penalize for added model
524  contacts. If set to True, this flag will also
525  consider target contacts that are within the
526  specified distance threshold in the model but
527  not necessarily in the target. No contact will
528  be added if the respective atom pair is not
529  resolved in the target.
530  :type add_mdl_contacts: :class:`bool`
531  :param interaction_data: Pro param - don't use
532  :type interaction_data: :class:`tuple`
533 
534  :returns: global and per-residue lDDT scores as a tuple -
535  first element is global lDDT score (None if *target* has no
536  contacts) and second element a list of per-residue scores with
537  length len(*model*.residues). None is assigned to residues that
538  are not covered by target. If a residue is covered but has no
539  contacts in *target*, 0.0 is assigned.
540  """
541  if chain_mapping is None:
542  if len(self.chain_nameschain_names) > 1 or len(model.chains) > 1:
543  raise NotImplementedError("Must provide chain mapping if "
544  "target or model have > 1 chains.")
545  chain_mapping = {model.chains[0].GetName(): self.chain_nameschain_names[0]}
546  else:
547  # check whether chains specified in mapping exist
548  for model_chain, target_chain in chain_mapping.items():
549  if target_chain not in self.chain_nameschain_names:
550  raise RuntimeError(f"Target chain specified in "
551  f"chain_mapping ({target_chain}) does "
552  f"not exist. Target has chains: "
553  f"{self.chain_names}")
554  ch = model.FindChain(model_chain)
555  if not ch.IsValid():
556  raise RuntimeError(f"Model chain specified in "
557  f"chain_mapping ({model_chain}) does "
558  f"not exist. Model has chains: "
559  f"{[c.GetName() for c in model.chains]}")
560 
561  # data objects defining model data - see _ProcessModel for rough
562  # description
563  pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
564  res_indices, ref_res_indices, symmetries = \
565  self._ProcessModel_ProcessModel(model, chain_mapping,
566  residue_mapping = residue_mapping,
567  thresholds = thresholds,
568  check_resnames = check_resnames)
569 
570  if no_interchain and no_intrachain:
571  raise RuntimeError("no_interchain and no_intrachain flags are "
572  "mutually exclusive")
573 
574 
575  sym_ref_indices = None
576  sym_ref_distances = None
577  ref_indices = None
578  ref_distances = None
579 
580  if interaction_data is None:
581  if no_interchain:
582  sym_ref_indices = self.sym_ref_indices_scsym_ref_indices_sc
583  sym_ref_distances = self.sym_ref_distances_scsym_ref_distances_sc
584  ref_indices = self.ref_indices_scref_indices_sc
585  ref_distances = self.ref_distances_scref_distances_sc
586  elif no_intrachain:
587  sym_ref_indices = self.sym_ref_indices_icsym_ref_indices_ic
588  sym_ref_distances = self.sym_ref_distances_icsym_ref_distances_ic
589  ref_indices = self.ref_indices_icref_indices_ic
590  ref_distances = self.ref_distances_icref_distances_ic
591  else:
592  sym_ref_indices = self.sym_ref_indicessym_ref_indices
593  sym_ref_distances = self.sym_ref_distancessym_ref_distances
594  ref_indices = self.ref_indicesref_indices
595  ref_distances = self.ref_distancesref_distances
596 
597  if add_mdl_contacts:
598  ref_indices, ref_distances = \
599  self._AddMdlContacts_AddMdlContacts(model, res_atom_indices, res_atom_hashes,
600  ref_indices, ref_distances,
601  no_interchain, no_intrachain)
602  # recompute symmetry related indices/distances
603  sym_ref_indices, sym_ref_distances = \
604  lDDTScorer._NonSymDistances(self.n_atomsn_atoms, self.symmetric_atomssymmetric_atoms,
605  ref_indices, ref_distances)
606  else:
607  sym_ref_indices, sym_ref_distances, ref_indices, ref_distances = \
608  interaction_data
609 
610  self._ResolveSymmetries_ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices,
611  sym_ref_distances)
612 
613  per_res_exp = np.asarray([self._GetNExp_GetNExp(res_ref_atom_indices[idx],
614  ref_indices) for idx in range(len(res_indices))], dtype=np.int32)
615  per_res_conserved = self._EvalResidues_EvalResidues(pos, thresholds,
616  res_atom_indices,
617  ref_indices, ref_distances)
618 
619  n_thresh = len(thresholds)
620 
621  # do per-residue scores
622  per_res_lDDT = [None] * len(model.residues)
623  for idx in range(len(res_indices)):
624  n_exp = n_thresh * per_res_exp[idx]
625  if n_exp > 0:
626  score = np.sum(per_res_conserved[idx,:]) / n_exp
627  per_res_lDDT[res_indices[idx]] = score
628  else:
629  per_res_lDDT[res_indices[idx]] = 0.0
630 
631  # do full model score
632  n_distances = sum([len(x) for x in ref_indices])
633  if penalize_extra_chains:
634  n_distances += self._GetExtraModelChainPenalty_GetExtraModelChainPenalty(model, chain_mapping)
635 
636  lDDT_tot = int(n_thresh * n_distances)
637  lDDT_cons = int(np.sum(per_res_conserved))
638  lDDT = None
639  if lDDT_tot > 0:
640  lDDT = float(lDDT_cons) / lDDT_tot
641 
642  # set properties if necessary
643  if local_lddt_prop:
644  residues = model.residues
645  for idx in res_indices:
646  residues[idx].SetFloatProp(local_lddt_prop, per_res_lDDT[idx])
647 
648  if local_contact_prop:
649  residues = model.residues
650  exp_prop = local_contact_prop + "_exp"
651  conserved_prop = local_contact_prop + "_cons"
652 
653  for i, r_idx in enumerate(res_indices):
654  residues[r_idx].SetIntProp(exp_prop,
655  n_thresh * int(per_res_exp[i]))
656  residues[r_idx].SetIntProp(conserved_prop,
657  int(np.sum(per_res_conserved[i,:])))
658 
659  if return_dist_test:
660  return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \
661  per_res_exp, per_res_conserved
662  else:
663  return lDDT, per_res_lDDT
664 
665  def GetNChainContacts(self, target_chain, no_interchain=False):
666  """Returns number of contacts expected for a certain chain in *target*
667 
668  :param target_chain: Chain in *target* for which you want the number
669  of expected contacts
670  :type target_chain: :class:`str`
671  :param no_interchain: Whether to exclude interchain contacts
672  :type no_interchain: :class:`bool`
673  :raises: :class:`RuntimeError` if specified chain doesnt exist
674  """
675  if target_chain not in self.chain_nameschain_names:
676  raise RuntimeError(f"Specified chain name ({target_chain}) not in "
677  f"target")
678  ch_idx = self.chain_nameschain_names.index(target_chain)
679  s = self.chain_start_indiceschain_start_indices[ch_idx]
680  e = self.n_atomsn_atoms
681  if ch_idx + 1 < len(self.chain_nameschain_names):
682  e = self.chain_start_indiceschain_start_indices[ch_idx+1]
683  if no_interchain:
684  return self._GetNExp_GetNExp(list(range(s, e)), self.ref_indices_scref_indices_sc)
685  else:
686  return self._GetNExp_GetNExp(list(range(s, e)), self.ref_indicesref_indices)
687 
688  def _ProcessModel(self, model, chain_mapping, residue_mapping = None,
689  thresholds = [0.5, 1.0, 2.0, 4.0],
690  check_resnames = True):
691  """ Helper that generates data structures from model
692  """
693 
694  # initialize positions with values far in nirvana. If a position is not
695  # set, it should be far away from any position in model.
696  max_pos = model.bounds.GetMax()
697  max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2]))
698  max_coordinate += 42 * max(thresholds)
699  pos = np.ones((self.n_atomsn_atoms, 3), dtype=np.float32) * max_coordinate
700 
701  # for each scored residue in model a list of indices describing the
702  # atoms from the reference that should be there
703  res_ref_atom_indices = list()
704 
705  # for each scored residue in model a list of indices of atoms that are
706  # actually there
707  res_atom_indices = list()
708 
709  # and the respective hash codes
710  # this is required if add_mdl_contacts is set to True
711  res_atom_hashes = list()
712 
713  # indices of the scored residues
714  res_indices = list()
715 
716  # respective residue indices in reference
717  ref_res_indices = list()
718 
719  # Will contain one element per symmetry group
720  symmetries = list()
721 
722  current_model_res_idx = -1
723  for ch in model.chains:
724  model_ch_name = ch.GetName()
725  if model_ch_name not in chain_mapping:
726  current_model_res_idx += len(ch.residues)
727  continue # additional model chain which is not mapped
728  target_ch_name = chain_mapping[model_ch_name]
729 
730  rnums = self._GetChainRNums_GetChainRNums(ch, residue_mapping, model_ch_name,
731  target_ch_name)
732 
733  for r, rnum in zip(ch.residues, rnums):
734  current_model_res_idx += 1
735  res_mapper_key = (target_ch_name, rnum)
736  if res_mapper_key not in self.res_mapperres_mapper:
737  continue
738  r_idx = self.res_mapperres_mapper[res_mapper_key]
739  if check_resnames and r.name != self.compound_namescompound_names[r_idx]:
740  raise RuntimeError(
741  f"Residue name mismatch for {r}, "
742  f" expect {self.compound_names[r_idx]}"
743  )
744  res_start_idx = self.res_start_indicesres_start_indices[r_idx]
745  rname = self.compound_namescompound_names[r_idx]
746  anames = self.compound_anamescompound_anames[rname]
747  atoms = [r.FindAtom(aname) for aname in anames]
748  res_ref_atom_indices.append(
749  list(range(res_start_idx, res_start_idx + len(anames)))
750  )
751  res_atom_indices.append(list())
752  res_atom_hashes.append(list())
753  res_indices.append(current_model_res_idx)
754  ref_res_indices.append(r_idx)
755  for a_idx, a in enumerate(atoms):
756  if a.IsValid():
757  p = a.GetPos()
758  pos[res_start_idx + a_idx][0] = p[0]
759  pos[res_start_idx + a_idx][1] = p[1]
760  pos[res_start_idx + a_idx][2] = p[2]
761  res_atom_indices[-1].append(res_start_idx + a_idx)
762  res_atom_hashes[-1].append(a.handle.GetHashCode())
763  if rname in self.compound_symmetric_atomscompound_symmetric_atoms:
764  sym_indices = list()
765  for sym_tuple in self.compound_symmetric_atomscompound_symmetric_atoms[rname]:
766  a_one = atoms[sym_tuple[0]]
767  a_two = atoms[sym_tuple[1]]
768  if a_one.IsValid() and a_two.IsValid():
769  sym_indices.append(
770  (
771  res_start_idx + sym_tuple[0],
772  res_start_idx + sym_tuple[1],
773  )
774  )
775  if len(sym_indices) > 0:
776  symmetries.append(sym_indices)
777 
778  return (pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes,
779  res_indices, ref_res_indices, symmetries)
780 
781 
782  def _GetExtraModelChainPenalty(self, model, chain_mapping):
783  """Counts n distances in extra model chains to be added as penalty
784  """
785  penalty = 0
786  for chain in model.chains:
787  ch_name = chain.GetName()
788  if ch_name not in chain_mapping:
789  sm = self.symmetry_settingssymmetry_settings
790  mdl_sel = model.Select(f"cname={mol.QueryQuoteName(ch_name)}")
791  dummy_scorer = lDDTScorer(mdl_sel, self.compound_libcompound_lib,
792  symmetry_settings = sm,
793  inclusion_radius = self.inclusion_radiusinclusion_radius,
794  bb_only = self.bb_onlybb_only)
795  penalty += sum([len(x) for x in dummy_scorer.ref_indices])
796  return penalty
797 
798  def _GetChainRNums(self, ch, residue_mapping, model_ch_name,
799  target_ch_name):
800  """Map residues in model chain to target residues
801 
802  There are two options: one is simply using residue numbers,
803  the other is a custom mapping as given in *residue_mapping*
804  """
805  if residue_mapping and model_ch_name in residue_mapping:
806  # extract residue numbers from target chain
807  ch_idx = self.chain_nameschain_names.index(target_ch_name)
808  start_idx = self.chain_res_start_indiceschain_res_start_indices[ch_idx]
809  if ch_idx < len(self.chain_nameschain_names) - 1:
810  end_idx = self.chain_res_start_indiceschain_res_start_indices[ch_idx+1]
811  else:
812  end_idx = len(self.compound_namescompound_names)
813  target_rnums = self.res_resnumsres_resnums[start_idx:end_idx]
814  # get sequences from alignment and do consistency checks
815  target_seq = residue_mapping[model_ch_name].GetSequence(0)
816  model_seq = residue_mapping[model_ch_name].GetSequence(1)
817  if len(target_seq.GetGaplessString()) != len(target_rnums):
818  raise RuntimeError(f"Try to perform residue mapping for "
819  f"model chain {model_ch_name} which "
820  f"maps to {target_ch_name} in target. "
821  f"Target sequence in alignment suggests "
822  f"{len(target_seq.GetGaplessString())} "
823  f"residues but {len(target_rnums)} are "
824  f"expected.")
825  if len(model_seq.GetGaplessString()) != len(ch.residues):
826  raise RuntimeError(f"Try to perform residue mapping for "
827  f"model chain {model_ch_name} which "
828  f"maps to {target_ch_name} in target. "
829  f"Model sequence in alignment suggests "
830  f"{len(model_seq.GetGaplessString())} "
831  f"residues but {len(ch.residues)} are "
832  f"expected.")
833  rnums = list()
834  target_idx = -1
835  for col in residue_mapping[model_ch_name]:
836  if col[0] != '-':
837  target_idx += 1
838  # handle match
839  if col[0] != '-' and col[1] != '-':
840  rnums.append(target_rnums[target_idx])
841  # insertion in model adds None to rnum
842  if col[0] == '-' and col[1] != '-':
843  rnums.append(None)
844  else:
845  rnums = [r.GetNumber() for r in ch.residues]
846 
847  return rnums
848 
849 
850  def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings,
851  seqres_mapping, bb_only):
852  """Sets target related lDDTScorer members defined in constructor
853 
854  No distance related members - see _SetupDistances
855  """
856  residue_numbers = self._GetTargetResidueNumbers_GetTargetResidueNumbers(self.targettarget,
857  seqres_mapping)
858  current_idx = 0
859  positions = list()
860  for chain in self.targettarget.chains:
861  ch_name = chain.GetName()
862  self.chain_nameschain_names.append(ch_name)
863  self.chain_start_indiceschain_start_indices.append(current_idx)
864  self.chain_res_start_indiceschain_res_start_indices.append(len(self.compound_namescompound_names))
865  for r, rnum in zip(chain.residues, residue_numbers[ch_name]):
866  if r.name not in self.compound_anamescompound_anames:
867  # sets compound info in self.compound_anames and
868  # self.compound_symmetric_atoms
869  self._SetupCompound_SetupCompound(r, compound_lib, custom_compounds,
870  symmetry_settings, bb_only)
871 
872  self.res_start_indicesres_start_indices.append(current_idx)
873  self.res_mapperres_mapper[(ch_name, rnum)] = len(self.compound_namescompound_names)
874  self.compound_namescompound_names.append(r.name)
875  self.res_resnumsres_resnums.append(rnum)
876 
877  atoms = [r.FindAtom(an) for an in self.compound_anamescompound_anames[r.name]]
878  for a in atoms:
879  if a.IsValid():
880  self.atom_indicesatom_indices[a.handle.GetHashCode()] = current_idx
881  p = a.GetPos()
882  positions.append(np.asarray([p[0], p[1], p[2]],
883  dtype=np.float32))
884  else:
885  positions.append(np.zeros(3, dtype=np.float32))
886  current_idx += 1
887 
888  if r.name in self.compound_symmetric_atomscompound_symmetric_atoms:
889  for sym_tuple in self.compound_symmetric_atomscompound_symmetric_atoms[r.name]:
890  for a_idx in sym_tuple:
891  a = atoms[a_idx]
892  if a.IsValid():
893  hashcode = a.handle.GetHashCode()
894  self.symmetric_atomssymmetric_atoms.add(
895  self.atom_indicesatom_indices[hashcode]
896  )
897  self.positionspositions = np.vstack(positions)
898  self.n_atomsn_atoms = current_idx
899 
900  def _GetTargetResidueNumbers(self, target, seqres_mapping):
901  """Returns residue numbers for each chain in target as dict
902 
903  They're either directly extracted from the raw residue number
904  from the structure or from user provided alignments
905  """
906  residue_numbers = dict()
907  for ch in target.chains:
908  ch_name = ch.GetName()
909  rnums = list()
910  if ch_name in seqres_mapping:
911  seqres = seqres_mapping[ch_name].GetSequence(0).GetString()
912  atomseq = seqres_mapping[ch_name].GetSequence(1).GetString()
913  # SEQRES must not contain gaps
914  if "-" in seqres:
915  raise RuntimeError(
916  "SEQRES in seqres_mapping must not " "contain gaps"
917  )
918  atomseq_from_chain = [r.one_letter_code for r in ch.residues]
919  if atomseq.replace("-", "") != atomseq_from_chain:
920  raise RuntimeError(
921  "ATOMSEQ in seqres_mapping must match "
922  "raw sequence extracted from chain "
923  "residues"
924  )
925  rnum = 0
926  for seqres_olc, atomseq_olc in zip(seqres, atomseq):
927  if seqres_olc != "-":
928  rnum += 1
929  if atomseq_olc != "-":
930  if seqres_olc != atomseq_olc:
931  raise RuntimeError(
932  f"Residue with number {rnum} in "
933  f"chain {ch_name} has SEQRES "
934  f"ATOMSEQ mismatch"
935  )
936  rnums.append(mol.ResNum(rnum))
937  else:
938  rnums = [r.GetNumber() for r in ch.residues]
939  assert len(rnums) == len(ch.residues)
940  residue_numbers[ch_name] = rnums
941  return residue_numbers
942 
943  def _SetupCompound(self, r, compound_lib, custom_compounds,
944  symmetry_settings, bb_only):
945  """fill self.compound_anames/self.compound_symmetric_atoms
946  """
947  if bb_only:
948  # throw away compound_lib info
949  if r.chem_class.IsPeptideLinking():
950  self.compound_anamescompound_anames[r.name] = ["CA"]
951  elif r.chem_class.IsNucleotideLinking():
952  self.compound_anamescompound_anames[r.name] = ["C3'"]
953  else:
954  raise RuntimeError(f"Only support amino acids and nucleotides "
955  f"if bb_only is True, failed with {str(r)}")
956  self.compound_symmetric_atomscompound_symmetric_atoms[r.name] = list()
957  else:
958  atom_names = list()
959  symmetric_atoms = list()
960  if custom_compounds is not None and r.GetName() in custom_compounds:
961  atom_names = list(custom_compounds[r.GetName()].atom_names)
962  else:
963  compound = compound_lib.FindCompound(r.name)
964  if compound is None:
965  raise RuntimeError(f"no entry for {r} in compound_lib")
966  for atom_spec in compound.GetAtomSpecs():
967  if atom_spec.element not in ["H", "D"]:
968  atom_names.append(atom_spec.name)
969  if r.name in symmetry_settings.symmetric_compounds:
970  for pair in symmetry_settings.symmetric_compounds[r.name]:
971  try:
972  a = atom_names.index(pair[0])
973  b = atom_names.index(pair[1])
974  except:
975  msg = f"Could not find symmetric atoms "
976  msg += f"({pair[0]}, {pair[1]}) for {r.name} "
977  msg += f"as specified in SymmetrySettings in "
978  msg += f"compound from component dictionary. "
979  msg += f"Atoms in compound: {atom_names}"
980  raise RuntimeError(msg)
981  symmetric_atoms.append((a, b))
982  self.compound_anamescompound_anames[r.name] = atom_names
983  if len(symmetric_atoms) > 0:
984  self.compound_symmetric_atomscompound_symmetric_atoms[r.name] = symmetric_atoms
985 
def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes,
                    ref_indices, ref_distances, no_interchain,
                    no_intrachain):
    """Extend reference contacts with contacts observed in *model*

    Contacts are computed on *model* with the same inclusion radius, but only
    between model atoms whose slot is also occupied by a target atom. They
    are optionally reduced to intra-chain (*no_interchain*) or inter-chain
    (*no_intrachain*) contacts. Contacts not yet in *ref_indices* are
    appended. Their reference distances are recomputed from the target
    positions (self.positions).

    :returns: Tuple (ref_indices, ref_distances); the input lists are
              updated in place and returned
    """
    # slots that are backed by an actual target atom
    target_slot = np.zeros(self.n_atoms, dtype=bool)
    target_slot[np.fromiter(self.atom_indices.values(), dtype=np.int64,
                            count=len(self.atom_indices))] = True

    # model atom hash -> slot, restricted to slots present in the target
    mdl_atom_indices = {
        h: i
        for slots, hashes in zip(res_atom_indices, res_atom_hashes)
        for i, h in zip(slots, hashes)
        if target_slot[i]
    }

    mdl_ref_indices, mdl_ref_distances = lDDTScorer._SetupDistances(
        model, self.n_atoms, mdl_atom_indices, self.inclusion_radius)
    if no_interchain:
        mdl_ref_indices, mdl_ref_distances = lDDTScorer._SetupDistancesSC(
            self.n_atoms, self.chain_start_indices,
            mdl_ref_indices, mdl_ref_distances)
    if no_intrachain:
        mdl_ref_indices, mdl_ref_distances = lDDTScorer._SetupDistancesIC(
            self.n_atoms, self.chain_start_indices,
            mdl_ref_indices, mdl_ref_distances)

    for i in range(self.n_atoms):
        is_new = np.isin(mdl_ref_indices[i], ref_indices[i],
                         assume_unique=True, invert=True)
        if not is_new.any():
            continue
        new_indices = mdl_ref_indices[i][is_new]
        ref_indices[i] = np.append(ref_indices[i], new_indices)

        # reference distances come from the target, not from the model
        deltas = self.positions.take(new_indices, axis=0)
        np.subtract(deltas, self.positions[i][None, :], out=deltas)
        np.square(deltas, out=deltas)
        new_dists = deltas.sum(axis=1)
        np.sqrt(new_dists, out=new_dists)
        ref_distances[i] = np.append(ref_distances[i], new_dists)

    return (ref_indices, ref_distances)
1038 
1039 
1040 
1041  @staticmethod
1042  def _SetupDistances(structure, n_atoms, atom_index_mapping,
1043  inclusion_radius):
1044 
1045  """Compute distance related members of lDDTScorer
1046 
1047  Brute force all vs all distance computation kills lDDT for large
1048  complexes. Instead of building some KD tree data structure, we make use
1049  of expected spatial proximity of atoms in the same chain. Distances are
1050  computed as follows:
1051 
1052  - process each chain individually
1053  - perform crude collision detection
1054  - process potentially interacting chain pairs
1055  - concatenate distances from all processing steps
1056  """
1057  ref_indices = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)]
1058  ref_distances = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)]
1059 
1060  indices = [list() for _ in range(n_atoms)]
1061  distances = [list() for _ in range(n_atoms)]
1062  per_chain_pos = list()
1063  per_chain_indices = list()
1064 
1065  # Process individual chains
1066  for ch in structure.chains:
1067  pos_list = list()
1068  atom_indices = list()
1069  mask_start = list()
1070  mask_end = list()
1071  r_start_idx = 0
1072  for r_idx, r in enumerate(ch.residues):
1073  n_valid_atoms = 0
1074  for a in r.atoms:
1075  hash_code = a.handle.GetHashCode()
1076  if hash_code in atom_index_mapping:
1077  p = a.GetPos()
1078  pos_list.append(np.asarray([p[0], p[1], p[2]]))
1079  atom_indices.append(atom_index_mapping[hash_code])
1080  n_valid_atoms += 1
1081  mask_start.extend([r_start_idx] * n_valid_atoms)
1082  mask_end.extend([r_start_idx + n_valid_atoms] * n_valid_atoms)
1083  r_start_idx += n_valid_atoms
1084 
1085  if len(pos_list) == 0:
1086  # nothing to do...
1087  continue
1088 
1089  pos = np.vstack(pos_list)
1090  atom_indices = np.asarray(atom_indices)
1091  dists = cdist(pos, pos)
1092 
1093  # apply masks
1094  far_away = 2 * inclusion_radius
1095  for idx in range(atom_indices.shape[0]):
1096  dists[idx, range(mask_start[idx], mask_end[idx])] = far_away
1097 
1098  # fish out and store close atoms within inclusion radius
1099  within_mask = dists < inclusion_radius
1100  for idx in range(atom_indices.shape[0]):
1101  indices_to_append = atom_indices[within_mask[idx,:]]
1102  if indices_to_append.shape[0] > 0:
1103  full_at_idx = atom_indices[idx]
1104  indices[full_at_idx].append(indices_to_append)
1105  distances[full_at_idx].append(dists[idx, within_mask[idx,:]])
1106 
1107  per_chain_pos.append(pos)
1108  per_chain_indices.append(atom_indices)
1109 
1110  # perform crude collision detection
1111  min_pos = [p.min(0) for p in per_chain_pos]
1112  max_pos = [p.max(0) for p in per_chain_pos]
1113  chain_pairs = list()
1114  for idx_one in range(len(per_chain_pos)):
1115  for idx_two in range(idx_one + 1, len(per_chain_pos)):
1116  if np.max(min_pos[idx_one] - max_pos[idx_two]) > inclusion_radius:
1117  continue
1118  if np.max(min_pos[idx_two] - max_pos[idx_one]) > inclusion_radius:
1119  continue
1120  chain_pairs.append((idx_one, idx_two))
1121 
1122  # process potentially interacting chains
1123  for pair in chain_pairs:
1124  dists = cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]])
1125  within = dists <= inclusion_radius
1126 
1127  # process pair[0]
1128  tmp = within.sum(axis=1)
1129  for idx in range(tmp.shape[0]):
1130  if tmp[idx] > 0:
1131  # even though not being a strict requirement, we perform an
1132  # insertion here such that the indices for each atom will be
1133  # sorted after the hstack operation
1134  at_idx = per_chain_indices[pair[0]][idx]
1135  indices_to_insert = per_chain_indices[pair[1]][within[idx,:]]
1136  distances_to_insert = dists[idx, within[idx, :]]
1137  insertion_idx = len(indices[at_idx])
1138  for i in range(insertion_idx):
1139  if indices_to_insert[0] > indices[at_idx][i][0]:
1140  insertion_idx = i
1141  break
1142  indices[at_idx].insert(insertion_idx, indices_to_insert)
1143  distances[at_idx].insert(insertion_idx, distances_to_insert)
1144 
1145  # process pair[1]
1146  tmp = within.sum(axis=0)
1147  for idx in range(tmp.shape[0]):
1148  if tmp[idx] > 0:
1149  # even though not being a strict requirement, we perform an
1150  # insertion here such that the indices for each atom will be
1151  # sorted after the hstack operation
1152  at_idx = per_chain_indices[pair[1]][idx]
1153  indices_to_insert = per_chain_indices[pair[0]][within[:, idx]]
1154  distances_to_insert = dists[within[:, idx], idx]
1155  insertion_idx = len(indices[at_idx])
1156  for i in range(insertion_idx):
1157  if indices_to_insert[0] > indices[at_idx][i][0]:
1158  insertion_idx = i
1159  break
1160  indices[at_idx].insert(insertion_idx, indices_to_insert)
1161  distances[at_idx].insert(insertion_idx, distances_to_insert)
1162 
1163  # concatenate distances from all processing steps
1164  for at_idx in range(n_atoms):
1165  if len(indices[at_idx]) > 0:
1166  ref_indices[at_idx] = np.hstack(indices[at_idx])
1167  ref_distances[at_idx] = np.hstack(distances[at_idx])
1168 
1169  return (ref_indices, ref_distances)
1170 
1171  @staticmethod
1172  def _SetupDistancesSC(n_atoms, chain_start_indices,
1173  ref_indices, ref_distances):
1174  """Select subset of contacts only covering intra-chain contacts
1175  """
1176  # init
1177  ref_indices_sc = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)]
1178  ref_distances_sc = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)]
1179 
1180  n_chains = len(chain_start_indices)
1181  for ch_idx in range(n_chains):
1182  chain_s = chain_start_indices[ch_idx]
1183  chain_e = n_atoms
1184  if ch_idx + 1 < n_chains:
1185  chain_e = chain_start_indices[ch_idx+1]
1186  for i in range(chain_s, chain_e):
1187  if len(ref_indices[i]) > 0:
1188  intra_idx = np.where(np.logical_and(ref_indices[i]>=chain_s,
1189  ref_indices[i]<chain_e))[0]
1190  ref_indices_sc[i] = ref_indices[i][intra_idx]
1191  ref_distances_sc[i] = ref_distances[i][intra_idx]
1192 
1193  return (ref_indices_sc, ref_distances_sc)
1194 
1195  @staticmethod
1196  def _SetupDistancesIC(n_atoms, chain_start_indices,
1197  ref_indices, ref_distances):
1198  """Select subset of contacts only covering inter-chain contacts
1199  """
1200  # init
1201  ref_indices_ic = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)]
1202  ref_distances_ic = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)]
1203 
1204  n_chains = len(chain_start_indices)
1205  for ch_idx in range(n_chains):
1206  chain_s = chain_start_indices[ch_idx]
1207  chain_e = n_atoms
1208  if ch_idx + 1 < n_chains:
1209  chain_e = chain_start_indices[ch_idx+1]
1210  for i in range(chain_s, chain_e):
1211  if len(ref_indices[i]) > 0:
1212  inter_idx = np.where(np.logical_or(ref_indices[i]<chain_s,
1213  ref_indices[i]>=chain_e))[0]
1214  ref_indices_ic[i] = ref_indices[i][inter_idx]
1215  ref_distances_ic[i] = ref_distances[i][inter_idx]
1216 
1217  return (ref_indices_ic, ref_distances_ic)
1218 
1219  @staticmethod
1220  def _NonSymDistances(n_atoms, symmetric_atoms, ref_indices, ref_distances):
1221  """Transfer indices/distances of non-symmetric atoms and return
1222  """
1223 
1224  sym_ref_indices = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)]
1225  sym_ref_distances = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)]
1226 
1227  for idx in symmetric_atoms:
1228  indices = list()
1229  distances = list()
1230  for i, d in zip(ref_indices[idx], ref_distances[idx]):
1231  if i not in symmetric_atoms:
1232  indices.append(i)
1233  distances.append(d)
1234  sym_ref_indices[idx] = indices
1235  sym_ref_distances[idx] = np.asarray(distances)
1236 
1237  return (sym_ref_indices, sym_ref_distances)
1238 
1239  def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances):
1240  """Computes number of distance differences within given thresholds
1241 
1242  returns np.array with len(thresholds) elements
1243  """
1244  a_p = pos[atom_idx, :]
1245  tmp = pos.take(ref_indices[atom_idx], axis=0)
1246  np.subtract(tmp, a_p[None, :], out=tmp)
1247  np.square(tmp, out=tmp)
1248  tmp = tmp.sum(axis=1)
1249  np.sqrt(tmp, out=tmp) # distances against all relevant atoms
1250  np.subtract(ref_distances[atom_idx], tmp, out=tmp)
1251  np.absolute(tmp, out=tmp) # absolute dist diffs
1252  return np.asarray([(tmp <= thresh).sum() for thresh in thresholds],
1253  dtype=np.int32)
1254 
1255  def _EvalAtoms(
1256  self, pos, atom_indices, thresholds, ref_indices, ref_distances
1257  ):
1258  """Calls _EvalAtom for several atoms and sums up the computed number
1259  of distance differences within given thresholds
1260 
1261  returns numpy matrix of shape (n_atoms, len(threshold))
1262  """
1263  conserved = np.zeros((len(atom_indices), len(thresholds)),
1264  dtype=np.int32)
1265  for a_idx, a in enumerate(atom_indices):
1266  conserved[a_idx, :] = self._EvalAtom_EvalAtom(pos, a, thresholds,
1267  ref_indices, ref_distances)
1268  return conserved
1269 
1270  def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices,
1271  ref_distances):
1272  """Calls _EvalAtoms for a bunch of residues
1273 
1274  residues are defined in *res_atom_indices* as lists of atom indices
1275  returns numpy matrix of shape (n_residues, len(thresholds)).
1276  """
1277  conserved = np.zeros((len(res_atom_indices), len(thresholds)),
1278  dtype=np.int32)
1279  for rai_idx, rai in enumerate(res_atom_indices):
1280  conserved[rai_idx,:] = np.sum(self._EvalAtoms_EvalAtoms(pos, rai, thresholds,
1281  ref_indices, ref_distances), axis=0)
1282  return conserved
1283 
1284  def _ProcessSequenceSeparation(self):
1285  if self.sequence_separationsequence_separation != 0:
1286  raise NotImplementedError("Congratulations! You're the first one "
1287  "requesting a non-default "
1288  "sequence_separation in the new and "
1289  "awesome lDDT implementation. A crate of "
1290  "beer for Gabriel and he'll implement "
1291  "it.")
1292 
1293  def _GetNExp(self, atom_idx, ref_indices):
1294  """Returns number of close atoms around one or several atoms
1295  """
1296  if isinstance(atom_idx, int):
1297  return len(ref_indices[atom_idx])
1298  elif isinstance(atom_idx, list):
1299  return sum([len(ref_indices[idx]) for idx in atom_idx])
1300  else:
1301  raise RuntimeError("invalid input type")
1302 
1303  def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices,
1304  sym_ref_distances):
1305  """Swaps symmetric positions in-place in order to maximize lDDT scores
1306  towards non-symmetric atoms.
1307  """
1308  for sym in symmetries:
1309 
1310  atom_indices = list()
1311  for sym_tuple in sym:
1312  atom_indices += [sym_tuple[0], sym_tuple[1]]
1313  tot = self._GetNExp_GetNExp(atom_indices, sym_ref_indices)
1314 
1315  if tot == 0:
1316  continue # nothing to do
1317 
1318  # score as is
1319  sym_one_conserved = self._EvalAtoms_EvalAtoms(
1320  pos,
1321  atom_indices,
1322  thresholds,
1323  sym_ref_indices,
1324  sym_ref_distances,
1325  )
1326 
1327  # switch positions and score again
1328  for pair in sym:
1329  pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
1330 
1331  sym_two_conserved = self._EvalAtoms_EvalAtoms(
1332  pos,
1333  atom_indices,
1334  thresholds,
1335  sym_ref_indices,
1336  sym_ref_distances,
1337  )
1338 
1339  sym_one_score = np.sum(sym_one_conserved) / (len(thresholds) * tot)
1340  sym_two_score = np.sum(sym_two_conserved) / (len(thresholds) * tot)
1341 
1342  if sym_one_score >= sym_two_score:
1343  # switch back, initial positions were better or equal
1344  # for the equal case: we still switch back to reproduce the old
1345  # lDDT behaviour
1346  for pair in sym:
1347  pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
def __init__(self, atom_names)
Definition: lddt.py:29
def AddSymmetricCompound(self, name, symmetric_atoms)
Definition: lddt.py:64
def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices, ref_distances)
Definition: lddt.py:1271
def _SetupCompound(self, r, compound_lib, custom_compounds, symmetry_settings, bb_only)
Definition: lddt.py:944
def _ProcessModel(self, model, chain_mapping, residue_mapping=None, thresholds=[0.5, 1.0, 2.0, 4.0], check_resnames=True)
Definition: lddt.py:690
def lDDT(self, model, thresholds=[0.5, 1.0, 2.0, 4.0], local_lddt_prop=None, local_contact_prop=None, chain_mapping=None, no_interchain=False, no_intrachain=False, penalize_extra_chains=False, residue_mapping=None, return_dist_test=False, check_resnames=True, add_mdl_contacts=False, interaction_data=None)
Definition: lddt.py:441
def _GetChainRNums(self, ch, residue_mapping, model_ch_name, target_ch_name)
Definition: lddt.py:799
def _ProcessSequenceSeparation(self)
Definition: lddt.py:1284
def sym_ref_distances(self)
Definition: lddt.py:348
def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices, sym_ref_distances)
Definition: lddt.py:1304
def ref_distances_ic(self)
Definition: lddt.py:406
def _GetTargetResidueNumbers(self, target, seqres_mapping)
Definition: lddt.py:900
def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances)
Definition: lddt.py:1239
def sym_ref_distances_ic(self)
Definition: lddt.py:426
def sym_ref_distances_sc(self)
Definition: lddt.py:386
def sym_ref_indices_ic(self)
Definition: lddt.py:416
def GetNChainContacts(self, target_chain, no_interchain=False)
Definition: lddt.py:665
def sym_ref_indices_sc(self)
Definition: lddt.py:376
def ref_distances_sc(self)
Definition: lddt.py:366
def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes, ref_indices, ref_distances, no_interchain, no_intrachain)
Definition: lddt.py:988
def _GetNExp(self, atom_idx, ref_indices)
Definition: lddt.py:1293
def __init__(self, target, compound_lib=None, custom_compounds=None, inclusion_radius=15, sequence_separation=0, symmetry_settings=None, seqres_mapping=dict(), bb_only=False)
Definition: lddt.py:212
def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings, seqres_mapping, bb_only)
Definition: lddt.py:851
def _GetExtraModelChainPenalty(self, model, chain_mapping)
Definition: lddt.py:782
def _EvalAtoms(self, pos, atom_indices, thresholds, ref_indices, ref_distances)
Definition: lddt.py:1257
def GetDefaultSymmetrySettings()
Definition: lddt.py:82
def cdist(p1, p2)
Definition: lddt.py:11