OpenStructure
lddt.py
Go to the documentation of this file.
1 import itertools
2 import numpy as np
3 
4 from ost import mol
5 from ost import conop
6 
7 # use cdist of scipy, fallback to (slower) numpy implementation if scipy is not
8 # available
try:
    from scipy.spatial.distance import cdist
except ImportError:
    def cdist(p1, p2):
        """Pairwise Euclidean distances - numpy fallback for scipy's cdist

        :param p1: Positions, one row per point - shape (m, d)
        :type p1: :class:`numpy.ndarray`
        :param p2: Positions, one row per point - shape (n, d)
        :type p2: :class:`numpy.ndarray`
        :returns: :class:`numpy.ndarray` of shape (m, n) with distances
                  between all points in *p1* and all points in *p2*
        """
        x2 = np.sum(p1**2, axis=1).reshape(-1, 1) # (m, 1)
        y2 = np.sum(p2**2, axis=1) # (n)
        xy = np.matmul(p1, p2.T) # (m, n)
        # |x-y|^2 = |x|^2 - 2xy + |y|^2. Rounding errors can make this
        # expression slightly negative for (almost) identical points which
        # would end up as NaN after the sqrt => clamp at zero
        return np.sqrt(np.maximum(x2 - 2*xy + y2, 0.0)) # (m, n)
18 
def blockwise_cdist(A, B, block_size=1000):
    """ Memory efficient cdist implementation that performs blockwise operations

    scipy cdist uses 64 bit floats (double) which can scratch at the upper
    memory end for most machines when number of positions become larger.
    E.g. ~4000 residues might for example have 35000 atom positions. That's
    almost 10GB to hold all pairwise distances in 64bit floats. This function
    calls cdist blockwise and stores the results in a 32bit float matrix.

    This function is adapted from chatgpt output

    :param A: Positions, one row per point
    :type A: :class:`numpy.ndarray` (or anything :func:`numpy.asarray` accepts)
    :param B: Positions, one row per point
    :type B: :class:`numpy.ndarray` (or anything :func:`numpy.asarray` accepts)
    :param block_size: Number of rows of *A* that are processed per cdist call
    :type block_size: :class:`int`
    :returns: :class:`numpy.ndarray` of dtype float32 and shape
              (len(*A*), len(*B*)) with all pairwise distances
    :raises: :class:`ValueError` if *block_size* is smaller than 1
    """
    # a negative block size would skip the loop below and hand back the
    # uninitialized output matrix
    if block_size < 1:
        raise ValueError("block_size must be a positive integer")
    A = np.asarray(A, dtype=np.float32)
    B = np.asarray(B, dtype=np.float32)
    n_a, n_b = A.shape[0], B.shape[0]
    D = np.empty((n_a, n_b), dtype=np.float32) # Output in float32 to save memory
    for start in range(0, n_a, block_size):
        stop = start + block_size
        # assignment casts the block to float32, no extra temporary needed
        D[start:stop, :] = cdist(A[start:stop], B)
    return D
38 
class CustomCompound:
    """ Defines atoms for custom compounds

    LDDT requires the reference atoms of a compound which are typically
    extracted from a :class:`ost.conop.CompoundLib`. This lightweight
    container allows to handle arbitrary compounds which are not
    necessarily in the compound library.

    :param atom_names: Names of atoms of custom compound
    :type atom_names: :class:`list` of :class:`str`
    """
    def __init__(self, atom_names):
        self.atom_names = atom_names

    @staticmethod
    def FromResidue(res):
        """ Construct custom compound from residue

        :param res: Residue from which reference atom names are extracted,
                    hydrogen/deuterium atoms are filtered out
        :type res: :class:`ost.mol.ResidueView`/:class:`ost.mol.ResidueHandle`
        :returns: :class:`CustomCompound`
        :raises: :class:`RuntimeError` if the remaining atom names are not
                 unique
        """
        at_names = [a.name for a in res.atoms if a.element not in ["H", "D"]]
        # atom names serve as identifiers, duplicates would be ambiguous
        if len(at_names) != len(set(at_names)):
            raise RuntimeError("Duplicate atoms detected in CustomCompound")
        return CustomCompound(at_names)
67 
class SymmetrySettings:
    """Container for symmetric compounds

    LDDT considers symmetries and selects the one resulting in the highest
    possible score.

    A symmetry is defined as a renaming operation on one or more atoms that
    leads to a chemically equivalent residue. Example would be OD1 and OD2 in
    ASP => renaming OD1 to OD2 and vice versa gives a chemically equivalent
    residue.

    Use :func:`AddSymmetricCompound` to define a symmetry which can then
    directly be accessed through the *symmetric_compounds* member.
    """
    def __init__(self):
        # key: compound name, value: list of atom name pairs
        self.symmetric_compounds = dict()

    def AddSymmetricCompound(self, name, symmetric_atoms):
        """Adds symmetry for compound with *name*

        An already registered symmetry for *name* is overwritten.

        :param name: Name of compound with symmetry
        :type name: :class:`str`
        :param symmetric_atoms: Pairs of atom names that define renaming
                                operation, i.e. after applying all switches
                                defined in the tuples, the resulting residue
                                should be chemically equivalent. Atom names
                                must refer to the PDB component dictionary.
        :type symmetric_atoms: :class:`list` of :class:`tuple`
        :raises: :class:`RuntimeError` if an element of *symmetric_atoms*
                 does not have exactly two entries
        """
        for pair in symmetric_atoms:
            if len(pair) != 2:
                raise RuntimeError("Expect pairs when defining symmetries")
        self.symmetric_compounds[name] = symmetric_atoms
101 
102 
def GetDefaultSymmetrySettings():
    """Constructs and returns :class:`SymmetrySettings` object for natural amino
    acids and nucleotides
    """
    symmetry_settings = SymmetrySettings()

    # amino acids - pairs of atom names that can be swapped
    aa_symmetries = [
        ("ASP", [("OD1", "OD2")]),
        ("GLU", [("OE1", "OE2")]),
        ("LEU", [("CD1", "CD2")]),
        ("VAL", [("CG1", "CG2")]),
        ("ARG", [("NH1", "NH2")]),
        ("PHE", [("CD1", "CD2"), ("CE1", "CE2")]),
        ("TYR", [("CD1", "CD2"), ("CE1", "CE2")]),
    ]
    for aa_name, atom_pairs in aa_symmetries:
        symmetry_settings.AddSymmetricCompound(aa_name, atom_pairs)

    # nucleotides - all share the same OP1/OP2 symmetry
    nuc_names = ["A", "C", "G", "U", "DA", "DC", "DG", "DT"]
    for nuc_name in nuc_names:
        symmetry_settings.AddSymmetricCompound(
            nuc_name, [("OP1","OP2")]
        )

    return symmetry_settings
142 
143 
class lDDTScorer:
    """LDDT scorer object for a specific target

    Sets up everything to score models of that target. LDDT (local distance
    difference test) is defined as fraction of pairwise distances which exhibit
    a difference < threshold when considering target and model. In case of
    multiple thresholds, the average is returned. See

    V. Mariani, M. Biasini, A. Barbato, T. Schwede, lDDT : A local
    superposition-free score for comparing protein structures and models using
    distance difference tests, Bioinformatics, 2013

    :param target: The target
    :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
    :param compound_lib: Compound library from which a compound for each residue
                         is extracted based on its name. Uses
                         :func:`ost.conop.GetDefaultLib` if not given, raises
                         if this returns no valid compound library. Atoms
                         defined in the compound are searched in the residue and
                         build the reference for scoring. If the residue has
                         atoms with names ["A", "B", "C"] but the corresponding
                         compound only has ["A", "B"], "A" and "B" are
                         considered for scoring. If the residue has atoms
                         ["A", "B"] but the compound has ["A", "B", "C"], "C" is
                         considered missing and does not influence scoring, even
                         if present in the model.
    :type compound_lib: :class:`ost.conop.CompoundLib`
    :param custom_compounds: Custom compounds defining reference atoms. If
                             given, *custom_compounds* take precedence over
                             *compound_lib*.
    :type custom_compounds: :class:`dict` with residue names (:class:`str`) as
                            key and :class:`CustomCompound` as value.
    :param inclusion_radius: All pairwise distances < *inclusion_radius* are
                             considered for scoring
    :type inclusion_radius: :class:`float`
    :param sequence_separation: Only pairwise distances between atoms of
                                residues which are further apart than this
                                threshold are considered. Residue distance is
                                based on resnum. The default (0) considers all
                                pairwise distances except intra-residue
                                distances.
    :type sequence_separation: :class:`int`
    :param symmetry_settings: Define residues exhibiting internal symmetry, uses
                              :func:`GetDefaultSymmetrySettings` if not given.
    :type symmetry_settings: :class:`SymmetrySettings`
    :param seqres_mapping: Mapping of model residues at the scoring stage
                           happens with residue numbers defining their location
                           in a reference sequence (SEQRES) using one based
                           indexing. If the residue numbers in *target* don't
                           correspond to that SEQRES, you can specify the
                           mapping manually. You can provide a dictionary to
                           specify a reference sequence (SEQRES) for one or more
                           chain(s). Key: chain name, value: alignment
                           (seq1: SEQRES, seq2: sequence of residues in chain).
                           Example: The residues in a chain with name "A" have
                           sequence "YEAH" and residue numbers [42,43,44,45].
                           You can provide an alignment with seq1 "``HELLYEAH``"
                           and seq2 "``----YEAH``". "Y" gets assigned residue
                           number 5, "E" gets assigned 6 and so on no matter
                           what the original residue numbers were.
    :type seqres_mapping: :class:`dict` (key: :class:`str`, value:
                          :class:`ost.seq.AlignmentHandle`)
    :param bb_only: Only consider atoms with name "CA" in case of amino acids and
                    "C3'" for Nucleotides. This invalidates *compound_lib*.
                    Raises if any residue in *target* is not
                    `r.chem_class.IsPeptideLinking()` or
                    `r.chem_class.IsNucleotideLinking()`
    :type bb_only: :class:`bool`
    :raises: :class:`RuntimeError` if *target* contains compound which is not in
             *compound_lib*, :class:`RuntimeError` if *symmetry_settings*
             specifies symmetric atoms that are not present in the according
             compound in *compound_lib*, :class:`RuntimeError` if
             *seqres_mapping* is not provided and *target* contains residue
             numbers with insertion codes or the residue numbers for each chain
             are not monotonically increasing, :class:`RuntimeError` if
             *seqres_mapping* is provided but an alignment is invalid
             (seq1 contains gaps, mismatch in seq1/seq2, seq2 does not match
             residues in corresponding chains).
    """
223  def __init__(
224  self,
225  target,
226  compound_lib=None,
227  custom_compounds=None,
228  inclusion_radius=15,
229  sequence_separation=0,
230  symmetry_settings=None,
231  seqres_mapping=dict(),
232  bb_only=False
233  ):
234 
235  self.targettarget = target
236  self.inclusion_radiusinclusion_radius = inclusion_radius
237  self.sequence_separationsequence_separation = sequence_separation
238  if compound_lib is None:
239  compound_lib = conop.GetDefaultLib()
240  if compound_lib is None:
241  raise RuntimeError("No compound_lib given and conop.GetDefaultLib "
242  "returns no valid compound library")
243  self.compound_libcompound_lib = compound_lib
244  self.custom_compoundscustom_compounds = custom_compounds
245  if symmetry_settings is None:
247  else:
248  self.symmetry_settingssymmetry_settings = symmetry_settings
249 
250  # whether to only consider atoms with name "CA" (amino acids) or C3'
251  # (nucleotides), invalidates *compound_lib*
252  self.bb_onlybb_only=bb_only
253 
254  # names of heavy atoms of each unique compound present in *target* as
255  # extracted from *compound_lib*, e.g.
256  # self.compound_anames["GLY"] = ["N", "CA", "C", "O"]
257  self.compound_anamescompound_anames = dict()
258 
259  # stores symmetry information for those compounds as defined in
260  # *symmetry_settings*
261  self.compound_symmetric_atomscompound_symmetric_atoms = dict()
262 
263  # list of len(target.chains) containing all chain names in *target*
264  self.chain_nameschain_names = list()
265 
266  # list of len(target.residues) containing all compound names in *target*
267  self.compound_namescompound_names = list()
268 
269  # list of len(target.residues) defining start pos in internal reference
270  # positions for each residue
271  self.res_start_indicesres_start_indices = list()
272 
273  # list of len(target.residues) defining residue numbers in target
274  self.res_resnumsres_resnums = list()
275 
276  # list of len(target.chains) defining start pos in internal reference
277  # positions for each chain
278  self.chain_start_indiceschain_start_indices = list()
279 
280  # list of len(target.chains) defining start pos in self.compound_names
281  # for each chain
282  self.chain_res_start_indiceschain_res_start_indices = list()
283 
284  # maps residues in *target* to indices in
285  # self.compound_names/self.res_start_indices. A residue gets identified
286  # by a tuple (first element: chain name, second element: residue number,
287  # residue number is either the actual residue number in *target* or
288  # given by *seqres_mapping*)
289  self.res_mapperres_mapper = dict()
290 
291  # number of atoms as specified in compounds. not all are necessarily
292  # covered by structure
293  self.n_atomsn_atoms = None
294 
295  # stores an index for each AtomHandle in *target*
296  # (atom hashcode => index)
297  self.atom_indicesatom_indices = dict()
298 
299  # store indices of all atoms that have symmetry properties
300  self.symmetric_atomssymmetric_atoms = set()
301 
302  # the actual target positions in a numpy array of shape (self.n_atoms,3)
303  self.positionspositions = None
304 
305  # setup members defined above
306  self._SetupEnv_SetupEnv(self.compound_libcompound_lib, self.custom_compoundscustom_compounds,
307  self.symmetry_settingssymmetry_settings, seqres_mapping, self.bb_onlybb_only)
308 
309  # distance related members are lazily computed as they're affected
310  # by different flavours of LDDT (e.g. LDDT including inter-chain
311  # contacts or not etc.)
312 
313  # stores for each atom the other atoms within inclusion_radius
314  self._ref_indices_ref_indices = None
315  # the corresponding distances
316  self._ref_distances_ref_distances = None
317 
318  # The following lists will be sparsely populated. We keep for each
319  # symmetry related atom the distances towards all atoms which are NOT
320  # affected by symmetry. So we can evaluate two symmetric versions
321  # against the fixed stuff later on and select the better scoring one.
322  self._sym_ref_indices_sym_ref_indices = None
323  self._sym_ref_distances_sym_ref_distances = None
324 
325  # exactly the same as above but without interchain contacts
326  # => single-chain (sc)
327  self._ref_indices_sc_ref_indices_sc = None
328  self._ref_distances_sc_ref_distances_sc = None
329  self._sym_ref_indices_sc_sym_ref_indices_sc = None
330  self._sym_ref_distances_sc_sym_ref_distances_sc = None
331 
332  # exactly the same as above but without intrachain contacts
333  # => inter-chain (ic)
334  self._ref_indices_ic_ref_indices_ic = None
335  self._ref_distances_ic_ref_distances_ic = None
336  self._sym_ref_indices_ic_sym_ref_indices_ic = None
337  self._sym_ref_distances_ic_sym_ref_distances_ic = None
338 
339  # input parameter checking
340  self._ProcessSequenceSeparation_ProcessSequenceSeparation()
341 
342  @property
343  def ref_indices(self):
344  if self._ref_indices_ref_indices is None:
345  self._ref_indices_ref_indices, self._ref_distances_ref_distances = \
346  lDDTScorer._SetupDistances(self.targettarget, self.n_atomsn_atoms,
347  self.atom_indicesatom_indices,
348  self.inclusion_radiusinclusion_radius)
349  return self._ref_indices_ref_indices
350 
351  @property
352  def ref_distances(self):
353  if self._ref_distances_ref_distances is None:
354  self._ref_indices_ref_indices, self._ref_distances_ref_distances = \
355  lDDTScorer._SetupDistances(self.targettarget, self.n_atomsn_atoms,
356  self.atom_indicesatom_indices,
357  self.inclusion_radiusinclusion_radius)
358  return self._ref_distances_ref_distances
359 
360  @property
361  def sym_ref_indices(self):
362  if self._sym_ref_indices_sym_ref_indices is None:
363  self._sym_ref_indices_sym_ref_indices, self._sym_ref_distances_sym_ref_distances = \
364  lDDTScorer._NonSymDistances(self.n_atomsn_atoms, self.symmetric_atomssymmetric_atoms,
365  self.ref_indicesref_indices, self.ref_distancesref_distances)
366  return self._sym_ref_indices_sym_ref_indices
367 
368  @property
369  def sym_ref_distances(self):
370  if self._sym_ref_distances_sym_ref_distances is None:
371  self._sym_ref_indices_sym_ref_indices, self._sym_ref_distances_sym_ref_distances = \
372  lDDTScorer._NonSymDistances(self.n_atomsn_atoms, self.symmetric_atomssymmetric_atoms,
373  self.ref_indicesref_indices, self.ref_distancesref_distances)
374  return self._sym_ref_distances_sym_ref_distances
375 
376  @property
377  def ref_indices_sc(self):
378  if self._ref_indices_sc_ref_indices_sc is None:
379  self._ref_indices_sc_ref_indices_sc, self._ref_distances_sc_ref_distances_sc = \
380  lDDTScorer._SetupDistancesSC(self.n_atomsn_atoms,
381  self.chain_start_indiceschain_start_indices,
382  self.ref_indicesref_indices,
383  self.ref_distancesref_distances)
384  return self._ref_indices_sc_ref_indices_sc
385 
386  @property
387  def ref_distances_sc(self):
388  if self._ref_distances_sc_ref_distances_sc is None:
389  self._ref_indices_sc_ref_indices_sc, self._ref_distances_sc_ref_distances_sc = \
390  lDDTScorer._SetupDistancesSC(self.n_atomsn_atoms,
391  self.chain_start_indiceschain_start_indices,
392  self.ref_indicesref_indices,
393  self.ref_distancesref_distances)
394  return self._ref_distances_sc_ref_distances_sc
395 
396  @property
398  if self._sym_ref_indices_sc_sym_ref_indices_sc is None:
399  self._sym_ref_indices_sc_sym_ref_indices_sc, self._sym_ref_distances_sc_sym_ref_distances_sc = \
400  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
401  self.symmetric_atomssymmetric_atoms,
402  self.ref_indices_scref_indices_sc,
403  self.ref_distances_scref_distances_sc)
404  return self._sym_ref_indices_sc_sym_ref_indices_sc
405 
406  @property
408  if self._sym_ref_distances_sc_sym_ref_distances_sc is None:
409  self._sym_ref_indices_sc_sym_ref_indices_sc, self._sym_ref_distances_sc_sym_ref_distances_sc = \
410  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
411  self.symmetric_atomssymmetric_atoms,
412  self.ref_indices_scref_indices_sc,
413  self.ref_distances_scref_distances_sc)
414  return self._sym_ref_distances_sc_sym_ref_distances_sc
415 
416  @property
417  def ref_indices_ic(self):
418  if self._ref_indices_ic_ref_indices_ic is None:
419  self._ref_indices_ic_ref_indices_ic, self._ref_distances_ic_ref_distances_ic = \
420  lDDTScorer._SetupDistancesIC(self.n_atomsn_atoms,
421  self.chain_start_indiceschain_start_indices,
422  self.ref_indicesref_indices,
423  self.ref_distancesref_distances)
424  return self._ref_indices_ic_ref_indices_ic
425 
426  @property
427  def ref_distances_ic(self):
428  if self._ref_distances_ic_ref_distances_ic is None:
429  self._ref_indices_ic_ref_indices_ic, self._ref_distances_ic_ref_distances_ic = \
430  lDDTScorer._SetupDistancesIC(self.n_atomsn_atoms,
431  self.chain_start_indiceschain_start_indices,
432  self.ref_indicesref_indices,
433  self.ref_distancesref_distances)
434  return self._ref_distances_ic_ref_distances_ic
435 
436  @property
438  if self._sym_ref_indices_ic_sym_ref_indices_ic is None:
439  self._sym_ref_indices_ic_sym_ref_indices_ic, self._sym_ref_distances_ic_sym_ref_distances_ic = \
440  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
441  self.symmetric_atomssymmetric_atoms,
442  self.ref_indices_icref_indices_ic,
443  self.ref_distances_icref_distances_ic)
444  return self._sym_ref_indices_ic_sym_ref_indices_ic
445 
446  @property
448  if self._sym_ref_distances_ic_sym_ref_distances_ic is None:
449  self._sym_ref_indices_ic_sym_ref_indices_ic, self._sym_ref_distances_ic_sym_ref_distances_ic = \
450  lDDTScorer._NonSymDistances(self.n_atomsn_atoms,
451  self.symmetric_atomssymmetric_atoms,
452  self.ref_indices_icref_indices_ic,
453  self.ref_distances_icref_distances_ic)
454  return self._sym_ref_distances_ic_sym_ref_distances_ic
455 
456  def lDDT(self, model, thresholds = [0.5, 1.0, 2.0, 4.0],
457  local_lddt_prop=None, local_contact_prop=None,
458  chain_mapping=None, no_interchain=False,
459  no_intrachain=False, penalize_extra_chains=False,
460  residue_mapping=None, return_dist_test=False,
461  check_resnames=True, add_mdl_contacts=False,
462  interaction_data=None, set_atom_props=False):
463  """Computes LDDT of *model* - globally and per-residue
464 
465  :param model: Model to be scored - models are preferably scored upon
466  performing stereo-chemistry checks in order to punish for
467  non-sensical irregularities. This must be done separately
468  as a pre-processing step. Target contacts that are not
469  covered by *model* are considered not conserved, thus
470  decreasing LDDT score. This also includes missing model
471  chains or model chains for which no mapping is provided in
472  *chain_mapping*.
473  :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
474  :param thresholds: Thresholds of distance differences to be considered
475  as correct - see docs in constructor for more info.
476  default: [0.5, 1.0, 2.0, 4.0]
477  :type thresholds: :class:`list` of :class:`floats`
478  :param local_lddt_prop: If set, per-residue scores will be assigned as
479  generic float property of that name
480  :type local_lddt_prop: :class:`str`
481  :param local_contact_prop: If set, number of expected contacts as well
482  as number of conserved contacts will be
483  assigned as generic int property.
484  Excected contacts will be set as
485  <local_contact_prop>_exp, conserved contacts
486  as <local_contact_prop>_cons. Values
487  are summed over all thresholds.
488  :type local_contact_prop: :class:`str`
489  :param chain_mapping: Mapping of model chains (key) onto target chains
490  (value). This is required if target or model have
491  more than one chain.
492  :type chain_mapping: :class:`dict` with :class:`str` as keys/values
493  :param no_interchain: Whether to exclude interchain contacts
494  :type no_interchain: :class:`bool`
495  :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
496  consider interface related contacts)
497  :type no_intrachain: :class:`bool`
498  :param penalize_extra_chains: Whether to include a fixed penalty for
499  additional chains in the model that are
500  not mapped to the target. ONLY AFFECTS
501  RETURNED GLOBAL SCORE. In detail: adds the
502  number of intra-chain contacts of each
503  extra chain to the expected contacts, thus
504  adding a penalty.
505  :type penalize_extra_chains: :class:`bool`
506  :param residue_mapping: By default, residue mapping is based on residue
507  numbers. That means, a model chain and the
508  respective target chain map to the same
509  underlying reference sequence (SEQRES).
510  Alternatively, you can specify one or
511  several alignment(s) between model and target
512  chains by providing a dictionary. key: Name
513  of chain in model (respective target chain is
514  extracted from *chain_mapping*),
515  value: Alignment with first sequence
516  corresponding to target chain and second
517  sequence to model chain. There is NO reference
518  sequence involved, so the two sequences MUST
519  exactly match the actual residues observed in
520  the respective target/model chains (ATOMSEQ).
521  :type residue_mapping: :class:`dict` with key: :class:`str`,
522  value: :class:`ost.seq.AlignmentHandle`
523  :param return_dist_test: Whether to additionally return the underlying
524  per-residue data for the distance difference
525  test. Adds five objects to the return tuple.
526  First: Number of total contacts summed over all
527  thresholds
528  Second: Number of conserved contacts summed
529  over all thresholds
530  Third: list with length of scored residues.
531  Contains indices referring to model.residues.
532  Fourth: numpy array of size
533  len(scored_residues) containing the number of
534  total contacts,
535  Fifth: numpy matrix of shape
536  (len(scored_residues), len(thresholds))
537  specifying how many for each threshold are
538  conserved.
539  :param check_resnames: On by default. Enforces residue name matches
540  between mapped model and target residues.
541  :type check_resnames: :class:`bool`
542  :param add_mdl_contacts: Adds model contacts - Only using contacts that
543  are within a certain distance threshold in the
544  target does not penalize for added model
545  contacts. If set to True, this flag will also
546  consider target contacts that are within the
547  specified distance threshold in the model but
548  not necessarily in the target. No contact will
549  be added if the respective atom pair is not
550  resolved in the target.
551  :type add_mdl_contacts: :class:`bool`
552  :param interaction_data: Pro param - don't use
553  :type interaction_data: :class:`tuple`
554  :param set_atom_props: If True, sets generic properties on a per atom
555  level if *local_lddt_prop*/*local_contact_prop*
556  are set as well.
557  In other words: this is the only way you can
558  get per-atom LDDT values.
559  :type set_atom_props: :class:`bool`
560 
561  :returns: global and per-residue LDDT scores as a tuple -
562  first element is global LDDT score (None if *target* has no
563  contacts) and second element a list of per-residue scores with
564  length len(*model*.residues). None is assigned to residues that
565  are not covered by target. If a residue is covered but has no
566  contacts in *target*, 0.0 is assigned.
567  """
568  if chain_mapping is None:
569  if len(self.chain_nameschain_names) > 1 or len(model.chains) > 1:
570  raise NotImplementedError("Must provide chain mapping if "
571  "target or model have > 1 chains.")
572  chain_mapping = {model.chains[0].GetName(): self.chain_nameschain_names[0]}
573  else:
574  # check whether chains specified in mapping exist
575  for model_chain, target_chain in chain_mapping.items():
576  if target_chain not in self.chain_nameschain_names:
577  raise RuntimeError(f"Target chain specified in "
578  f"chain_mapping ({target_chain}) does "
579  f"not exist. Target has chains: "
580  f"{self.chain_names}")
581  ch = model.FindChain(model_chain)
582  if not ch.IsValid():
583  raise RuntimeError(f"Model chain specified in "
584  f"chain_mapping ({model_chain}) does "
585  f"not exist. Model has chains: "
586  f"{[c.GetName() for c in model.chains]}")
587 
588  # data objects defining model data - see _ProcessModel for rough
589  # description
590  pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
591  res_indices, ref_res_indices, symmetries = \
592  self._ProcessModel_ProcessModel(model, chain_mapping,
593  residue_mapping = residue_mapping,
594  nirvana_dist = self.inclusion_radiusinclusion_radius + max(thresholds),
595  check_resnames = check_resnames)
596 
597  if no_interchain and no_intrachain:
598  raise RuntimeError("no_interchain and no_intrachain flags are "
599  "mutually exclusive")
600 
601  sym_ref_indices = None
602  sym_ref_distances = None
603  ref_indices = None
604  ref_distances = None
605 
606  if interaction_data is None:
607  if no_interchain:
608  sym_ref_indices = self.sym_ref_indices_scsym_ref_indices_sc
609  sym_ref_distances = self.sym_ref_distances_scsym_ref_distances_sc
610  ref_indices = self.ref_indices_scref_indices_sc
611  ref_distances = self.ref_distances_scref_distances_sc
612  elif no_intrachain:
613  sym_ref_indices = self.sym_ref_indices_icsym_ref_indices_ic
614  sym_ref_distances = self.sym_ref_distances_icsym_ref_distances_ic
615  ref_indices = self.ref_indices_icref_indices_ic
616  ref_distances = self.ref_distances_icref_distances_ic
617  else:
618  sym_ref_indices = self.sym_ref_indicessym_ref_indices
619  sym_ref_distances = self.sym_ref_distancessym_ref_distances
620  ref_indices = self.ref_indicesref_indices
621  ref_distances = self.ref_distancesref_distances
622 
623  if add_mdl_contacts:
624  ref_indices, ref_distances = \
625  self._AddMdlContacts_AddMdlContacts(model, res_atom_indices, res_atom_hashes,
626  ref_indices, ref_distances,
627  no_interchain, no_intrachain)
628  # recompute symmetry related indices/distances
629  sym_ref_indices, sym_ref_distances = \
630  lDDTScorer._NonSymDistances(self.n_atomsn_atoms, self.symmetric_atomssymmetric_atoms,
631  ref_indices, ref_distances)
632  else:
633  sym_ref_indices, sym_ref_distances, ref_indices, ref_distances = \
634  interaction_data
635 
636  self._ResolveSymmetries_ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices,
637  sym_ref_distances)
638 
639  atom_indices = list(itertools.chain.from_iterable(res_atom_indices))
640 
641  per_atom_exp = np.asarray([self._GetNExp_GetNExp(i, ref_indices)
642  for i in atom_indices], dtype=np.int32)
643  per_res_exp = np.asarray([self._GetNExp_GetNExp(res_ref_atom_indices[idx],
644  ref_indices) for idx in range(len(res_indices))], dtype=np.int32)
645 
646  per_atom_conserved = self._EvalAtoms_EvalAtoms(pos, atom_indices, thresholds,
647  ref_indices, ref_distances)
648  per_res_conserved = np.zeros((len(res_atom_indices), len(thresholds)),
649  dtype=np.int32)
650  start_idx = 0
651  for r_idx in range(len(res_atom_indices)):
652  end_idx = start_idx + len(res_atom_indices[r_idx])
653  per_res_conserved[r_idx] = np.sum(per_atom_conserved[start_idx:end_idx,:],
654  axis=0)
655  start_idx = end_idx
656 
657  n_thresh = len(thresholds)
658 
659  # do per-residue scores
660  per_res_lDDT = [None] * model.GetResidueCount()
661  for idx in range(len(res_indices)):
662  n_exp = n_thresh * per_res_exp[idx]
663  if n_exp > 0:
664  score = np.sum(per_res_conserved[idx,:]) / n_exp
665  per_res_lDDT[res_indices[idx]] = score
666  else:
667  per_res_lDDT[res_indices[idx]] = 0.0
668 
669  # do full model score
670  n_distances = sum([len(x) for x in ref_indices])
671  if penalize_extra_chains:
672  n_distances += self._GetExtraModelChainPenalty_GetExtraModelChainPenalty(model, chain_mapping)
673 
674  lDDT_tot = int(n_thresh * n_distances)
675  lDDT_cons = int(np.sum(per_res_conserved))
676  lDDT = None
677  if lDDT_tot > 0:
678  lDDT = float(lDDT_cons) / lDDT_tot
679 
680  # set properties if necessary
681  if local_lddt_prop:
682  residues = model.residues
683  for idx in res_indices:
684  residues[idx].SetFloatProp(local_lddt_prop, per_res_lDDT[idx])
685 
686  if local_contact_prop:
687  residues = model.residues
688  exp_prop = local_contact_prop + "_exp"
689  conserved_prop = local_contact_prop + "_cons"
690 
691  for i, r_idx in enumerate(res_indices):
692  residues[r_idx].SetIntProp(exp_prop,
693  n_thresh * int(per_res_exp[i]))
694  residues[r_idx].SetIntProp(conserved_prop,
695  int(np.sum(per_res_conserved[i,:])))
696 
697  if set_atom_props and (local_lddt_prop or local_contact_prop):
698  atom_list = list()
699  residues = model.residues
700  for i, indices in enumerate(res_atom_indices):
701  r = residues[res_indices[i]]
702  r_idx = ref_res_indices[i]
703  res_start_idx = self.res_start_indicesres_start_indices[r_idx]
704  anames = self.compound_anamescompound_anames[self.compound_namescompound_names[r_idx]]
705  for a_i in indices:
706  a = r.FindAtom(anames[a_i - res_start_idx])
707  assert(a.IsValid())
708  atom_list.append(a)
709 
710  summed_per_atom_conserved = per_atom_conserved.sum(axis=1)
711  if local_lddt_prop:
712  # the only place where actually need to compute per-atom LDDT
713  # scores
714  for a_idx in range(len(atom_list)):
715  if per_atom_exp[a_idx] != 0:
716  tmp = summed_per_atom_conserved[a_idx] / per_atom_exp[a_idx]
717  tmp = tmp / n_thresh
718  atom_list[a_idx].SetFloatProp(local_lddt_prop, tmp)
719 
720  if local_contact_prop:
721  conserved_prop = local_contact_prop + "_cons"
722  exp_prop = local_contact_prop + "_exp"
723  for a_idx in range(len(atom_list)):
724  # do number of conserved contacts
725  tmp = summed_per_atom_conserved[a_idx]
726  atom_list[a_idx].SetIntProp(conserved_prop, tmp)
727  # do number of expected contacts
728  tmp = per_atom_exp[a_idx] * n_thresh
729  atom_list[a_idx].SetIntProp(exp_prop, tmp)
730 
731  if return_dist_test:
732  return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \
733  per_res_exp, per_res_conserved
734  else:
735  return lDDT, per_res_lDDT
736 
def DRMSD(self, model, dist_cap=5,
          chain_mapping=None, no_interchain=False,
          no_intrachain=False, residue_mapping=None,
          check_resnames=True, add_mdl_contacts=False,
          interaction_data=None):
    """ Distance RMSD of *model* - one global value and one per residue

    Operates on the same interatomic reference distances as LDDT (same
    inclusion radius). Instead of counting conserved distances, the
    differences between reference and model distances are squared, averaged
    and square-rooted. Reference distances involving atoms that are missing
    in *model* contribute *dist_cap* as distance difference.

    :param model: Model to be scored. Stereo-chemistry checks are not
                  performed here and must be done separately as a
                  pre-processing step if desired. Target contacts that are
                  not covered by *model* (including missing chains or
                  chains without entry in *chain_mapping*) are penalized
                  and thus increase DRMSD.
    :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
    :param dist_cap: Cap for distance differences, also used as distance
                     difference for missing distances.
    :type dist_cap: :class:`float`
    :param chain_mapping: Mapping of model chains (key) onto target chains
                          (value). Required if target or model have more
                          than one chain.
    :type chain_mapping: :class:`dict` with :class:`str` as keys/values
    :param no_interchain: Whether to exclude interchain contacts
    :type no_interchain: :class:`bool`
    :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
                          consider interface related contacts)
    :type no_intrachain: :class:`bool`
    :param residue_mapping: By default, residues are mapped by residue
                            number, i.e. model and target chain refer to
                            the same reference sequence (SEQRES).
                            Alternatively provide a dictionary with
                            key: name of chain in model (the target chain
                            is taken from *chain_mapping*) and
                            value: alignment with the target chain as first
                            and the model chain as second sequence. No
                            reference sequence is involved, both sequences
                            MUST exactly match the residues observed in
                            the respective chains (ATOMSEQ).
    :type residue_mapping: :class:`dict` with key: :class:`str`,
                           value: :class:`ost.seq.AlignmentHandle`
    :param check_resnames: On by default. Enforces residue name matches
                           between mapped model and target residues.
    :type check_resnames: :class:`bool`
    :param add_mdl_contacts: If True, atom pairs that are within the
                             inclusion radius in the model but not
                             necessarily in the target are scored too.
                             No contact is added if the respective atom
                             pair is not resolved in the target. Ignored
                             if *interaction_data* is given.
    :type add_mdl_contacts: :class:`bool`
    :param interaction_data: Pro param - don't use
    :type interaction_data: :class:`tuple`

    :returns: :class:`tuple` - first element is the global DRMSD (None if
              there are no reference contacts), second element is a list
              with length *model*.GetResidueCount() holding per-residue
              DRMSD. None is assigned to residues that are not covered by
              target or that have no reference contacts.
    :raises: :class:`NotImplementedError` if *chain_mapping* is None and
             target or model have more than one chain,
             :class:`RuntimeError` if *chain_mapping* refers to
             non-existent chains or if *no_interchain* and
             *no_intrachain* are both set.
    """
    if chain_mapping is not None:
        # every chain named in the mapping must exist on its side
        for model_chain, target_chain in chain_mapping.items():
            if target_chain not in self.chain_names:
                raise RuntimeError(f"Target chain specified in "
                                   f"chain_mapping ({target_chain}) does "
                                   f"not exist. Target has chains: "
                                   f"{self.chain_names}")
            ch = model.FindChain(model_chain)
            if not ch.IsValid():
                raise RuntimeError(f"Model chain specified in "
                                   f"chain_mapping ({model_chain}) does "
                                   f"not exist. Model has chains: "
                                   f"{[c.GetName() for c in model.chains]}")
    else:
        # the trivial one-to-one mapping is only defined for single chains
        if len(self.chain_names) > 1 or len(model.chains) > 1:
            raise NotImplementedError("Must provide chain mapping if "
                                      "target or model have > 1 chains.")
        chain_mapping = {model.chains[0].GetName(): self.chain_names[0]}

    # data objects defining model data - see _ProcessModel for rough
    # description
    (pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes,
     res_indices, ref_res_indices, symmetries) = self._ProcessModel(
        model, chain_mapping,
        residue_mapping=residue_mapping,
        nirvana_dist=self.inclusion_radius + dist_cap,
        check_resnames=check_resnames)

    if no_interchain and no_intrachain:
        raise RuntimeError("no_interchain and no_intrachain flags are "
                           "mutually exclusive")

    if interaction_data is not None:
        # caller provides the full set of reference data
        sym_ref_indices, sym_ref_distances, ref_indices, ref_distances = \
            interaction_data
    else:
        # select the precomputed reference contacts matching the flags
        if no_interchain:
            sym_ref_indices = self.sym_ref_indices_sc
            sym_ref_distances = self.sym_ref_distances_sc
            ref_indices = self.ref_indices_sc
            ref_distances = self.ref_distances_sc
        elif no_intrachain:
            sym_ref_indices = self.sym_ref_indices_ic
            sym_ref_distances = self.sym_ref_distances_ic
            ref_indices = self.ref_indices_ic
            ref_distances = self.ref_distances_ic
        else:
            sym_ref_indices = self.sym_ref_indices
            sym_ref_distances = self.sym_ref_distances
            ref_indices = self.ref_indices
            ref_distances = self.ref_distances

        if add_mdl_contacts:
            ref_indices, ref_distances = self._AddMdlContacts(
                model, res_atom_indices, res_atom_hashes,
                ref_indices, ref_distances,
                no_interchain, no_intrachain)
            # the contact set changed => symmetry related
            # indices/distances must be derived again
            sym_ref_indices, sym_ref_distances = \
                lDDTScorer._NonSymDistances(self.n_atoms,
                                            self.symmetric_atoms,
                                            ref_indices, ref_distances)

    self._ResolveSymmetriesSSD(pos, dist_cap, symmetries, sym_ref_indices,
                               sym_ref_distances)

    # flat list of all model atoms that are present, residue by residue
    atom_indices = list(itertools.chain.from_iterable(res_atom_indices))

    per_atom_exp = np.asarray(
        [self._GetNExp(a_idx, ref_indices) for a_idx in atom_indices],
        dtype=np.int32)
    per_res_exp = np.asarray(
        [self._GetNExp(res_ref_atom_indices[r_idx], ref_indices)
         for r_idx in range(len(res_indices))],
        dtype=np.int32)
    per_atom_ssd = self._EvalAtomsSSD(pos, atom_indices, dist_cap,
                                      ref_indices, ref_distances)

    # per residue scores - per_atom_* arrays are laid out residue by
    # residue, so a running offset selects the atoms of each residue
    per_res_drmsd = [None] * model.GetResidueCount()
    first = 0
    for r_idx, present_atoms in enumerate(res_atom_indices):
        last = first + len(present_atoms)
        n_tot = per_res_exp[r_idx]
        if n_tot > 0:
            ssd = np.sum(per_atom_ssd[first:last])
            # distances involving atoms that are expected from the
            # reference but absent in the model are penalized with the cap
            n_missing = n_tot - np.sum(per_atom_exp[first:last])
            ssd += n_missing*dist_cap*dist_cap
            per_res_drmsd[res_indices[r_idx]] = np.sqrt(ssd/n_tot)
        first = last

    # full model score over all reference contacts
    drmsd = None
    n_tot = sum([len(x) for x in ref_indices])
    if n_tot > 0:
        ssd = np.sum(per_atom_ssd)
        # same penalty as above for everything the model does not cover
        n_missing = n_tot - np.sum(per_atom_exp)
        ssd += (dist_cap*dist_cap*n_missing)
        drmsd = np.sqrt(ssd/n_tot)

    return drmsd, per_res_drmsd
917 
def GetNChainContacts(self, target_chain, no_interchain=False):
    """Returns number of contacts expected for a certain chain in *target*

    :param target_chain: Chain in *target* for which you want the number
                         of expected contacts
    :type target_chain: :class:`str`
    :param no_interchain: Whether to exclude interchain contacts
    :type no_interchain: :class:`bool`
    :returns: Summed number of reference contacts of all atoms that belong
              to *target_chain*
    :raises: :class:`RuntimeError` if specified chain doesnt exist
    """
    if target_chain not in self.chain_names:
        raise RuntimeError(f"Specified chain name ({target_chain}) not in "
                           f"target")
    ch_idx = self.chain_names.index(target_chain)
    first_atom = self.chain_start_indices[ch_idx]
    # a chain ends where the next one starts, the last one at n_atoms
    atom_end = self.n_atoms
    if ch_idx + 1 < len(self.chain_names):
        atom_end = self.chain_start_indices[ch_idx + 1]
    chain_atoms = list(range(first_atom, atom_end))
    if no_interchain:
        return self._GetNExp(chain_atoms, self.ref_indices_sc)
    return self._GetNExp(chain_atoms, self.ref_indices)
940 
941  def _ProcessModel(self, model, chain_mapping, residue_mapping = None,
942  nirvana_dist = 100,
943  check_resnames = True):
944  """ Helper that generates data structures from model
945  """
946 
947  # initialize positions with values far in nirvana. If a position is not
948  # set, it should be far away from any position in model.
949  max_pos = model.bounds.GetMax()
950  max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2]))
951  max_coordinate += 42 * nirvana_dist
952  pos = np.ones((self.n_atomsn_atoms, 3), dtype=np.float32) * max_coordinate
953 
954  # for each scored residue in model a list of indices describing the
955  # atoms from the reference that should be there
956  res_ref_atom_indices = list()
957 
958  # for each scored residue in model a list of indices of atoms that are
959  # actually there
960  res_atom_indices = list()
961 
962  # and the respective hash codes
963  # this is required if add_mdl_contacts is set to True
964  res_atom_hashes = list()
965 
966  # indices of the scored residues
967  res_indices = list()
968 
969  # respective residue indices in reference
970  ref_res_indices = list()
971 
972  # Will contain one element per symmetry group
973  symmetries = list()
974 
975  current_model_res_idx = -1
976  for ch in model.chains:
977  model_ch_name = ch.GetName()
978  if model_ch_name not in chain_mapping:
979  current_model_res_idx += len(ch.residues)
980  continue # additional model chain which is not mapped
981  target_ch_name = chain_mapping[model_ch_name]
982 
983  rnums = self._GetChainRNums_GetChainRNums(ch, residue_mapping, model_ch_name,
984  target_ch_name)
985 
986  for r, rnum in zip(ch.residues, rnums):
987  current_model_res_idx += 1
988  res_mapper_key = (target_ch_name, rnum)
989  if res_mapper_key not in self.res_mapperres_mapper:
990  continue
991  r_idx = self.res_mapperres_mapper[res_mapper_key]
992  if check_resnames and r.name != self.compound_namescompound_names[r_idx]:
993  raise RuntimeError(
994  f"Residue name mismatch for {r}, "
995  f" expect {self.compound_names[r_idx]}"
996  )
997  res_start_idx = self.res_start_indicesres_start_indices[r_idx]
998  rname = self.compound_namescompound_names[r_idx]
999  anames = self.compound_anamescompound_anames[rname]
1000  atoms = [r.FindAtom(aname) for aname in anames]
1001  res_ref_atom_indices.append(
1002  list(range(res_start_idx, res_start_idx + len(anames)))
1003  )
1004  res_atom_indices.append(list())
1005  res_atom_hashes.append(list())
1006  res_indices.append(current_model_res_idx)
1007  ref_res_indices.append(r_idx)
1008  for a_idx, a in enumerate(atoms):
1009  if a.IsValid():
1010  p = a.GetPos()
1011  pos[res_start_idx + a_idx][0] = p[0]
1012  pos[res_start_idx + a_idx][1] = p[1]
1013  pos[res_start_idx + a_idx][2] = p[2]
1014  res_atom_indices[-1].append(res_start_idx + a_idx)
1015  res_atom_hashes[-1].append(a.handle.GetHashCode())
1016  if rname in self.compound_symmetric_atomscompound_symmetric_atoms:
1017  sym_indices = list()
1018  for sym_tuple in self.compound_symmetric_atomscompound_symmetric_atoms[rname]:
1019  a_one = atoms[sym_tuple[0]]
1020  a_two = atoms[sym_tuple[1]]
1021  if a_one.IsValid() and a_two.IsValid():
1022  sym_indices.append(
1023  (
1024  res_start_idx + sym_tuple[0],
1025  res_start_idx + sym_tuple[1],
1026  )
1027  )
1028  if len(sym_indices) > 0:
1029  symmetries.append(sym_indices)
1030 
1031  return (pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes,
1032  res_indices, ref_res_indices, symmetries)
1033 
1034 
1035  def _GetExtraModelChainPenalty(self, model, chain_mapping):
1036  """Counts n distances in extra model chains to be added as penalty
1037  """
1038  penalty = 0
1039  for chain in model.chains:
1040  ch_name = chain.GetName()
1041  if ch_name not in chain_mapping:
1042  sm = self.symmetry_settingssymmetry_settings
1043  mdl_sel = model.Select(f"cname={mol.QueryQuoteName(ch_name)}")
1044  dummy_scorer = lDDTScorer(mdl_sel, self.compound_libcompound_lib,
1045  symmetry_settings = sm,
1046  inclusion_radius = self.inclusion_radiusinclusion_radius,
1047  bb_only = self.bb_onlybb_only)
1048  penalty += sum([len(x) for x in dummy_scorer.ref_indices])
1049  return penalty
1050 
1051  def _GetChainRNums(self, ch, residue_mapping, model_ch_name,
1052  target_ch_name):
1053  """Map residues in model chain to target residues
1054 
1055  There are two options: one is simply using residue numbers,
1056  the other is a custom mapping as given in *residue_mapping*
1057  """
1058  if residue_mapping and model_ch_name in residue_mapping:
1059  # extract residue numbers from target chain
1060  ch_idx = self.chain_nameschain_names.index(target_ch_name)
1061  start_idx = self.chain_res_start_indiceschain_res_start_indices[ch_idx]
1062  if ch_idx < len(self.chain_nameschain_names) - 1:
1063  end_idx = self.chain_res_start_indiceschain_res_start_indices[ch_idx+1]
1064  else:
1065  end_idx = len(self.compound_namescompound_names)
1066  target_rnums = self.res_resnumsres_resnums[start_idx:end_idx]
1067  # get sequences from alignment and do consistency checks
1068  target_seq = residue_mapping[model_ch_name].GetSequence(0)
1069  model_seq = residue_mapping[model_ch_name].GetSequence(1)
1070  if len(target_seq.GetGaplessString()) != len(target_rnums):
1071  raise RuntimeError(f"Try to perform residue mapping for "
1072  f"model chain {model_ch_name} which "
1073  f"maps to {target_ch_name} in target. "
1074  f"Target sequence in alignment suggests "
1075  f"{len(target_seq.GetGaplessString())} "
1076  f"residues but {len(target_rnums)} are "
1077  f"expected.")
1078  if len(model_seq.GetGaplessString()) != len(ch.residues):
1079  raise RuntimeError(f"Try to perform residue mapping for "
1080  f"model chain {model_ch_name} which "
1081  f"maps to {target_ch_name} in target. "
1082  f"Model sequence in alignment suggests "
1083  f"{len(model_seq.GetGaplessString())} "
1084  f"residues but {len(ch.residues)} are "
1085  f"expected.")
1086  rnums = list()
1087  target_idx = -1
1088  for col in residue_mapping[model_ch_name]:
1089  if col[0] != '-':
1090  target_idx += 1
1091  # handle match
1092  if col[0] != '-' and col[1] != '-':
1093  rnums.append(target_rnums[target_idx])
1094  # insertion in model adds None to rnum
1095  if col[0] == '-' and col[1] != '-':
1096  rnums.append(None)
1097  else:
1098  rnums = [r.GetNumber() for r in ch.residues]
1099 
1100  return rnums
1101 
1102 
1103  def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings,
1104  seqres_mapping, bb_only):
1105  """Sets target related lDDTScorer members defined in constructor
1106 
1107  No distance related members - see _SetupDistances
1108  """
1109  residue_numbers = self._GetTargetResidueNumbers_GetTargetResidueNumbers(self.targettarget,
1110  seqres_mapping)
1111  current_idx = 0
1112  positions = list()
1113  for chain in self.targettarget.chains:
1114  ch_name = chain.GetName()
1115  self.chain_nameschain_names.append(ch_name)
1116  self.chain_start_indiceschain_start_indices.append(current_idx)
1117  self.chain_res_start_indiceschain_res_start_indices.append(len(self.compound_namescompound_names))
1118  for r, rnum in zip(chain.residues, residue_numbers[ch_name]):
1119  if r.name not in self.compound_anamescompound_anames:
1120  # sets compound info in self.compound_anames and
1121  # self.compound_symmetric_atoms
1122  self._SetupCompound_SetupCompound(r, compound_lib, custom_compounds,
1123  symmetry_settings, bb_only)
1124 
1125  self.res_start_indicesres_start_indices.append(current_idx)
1126  self.res_mapperres_mapper[(ch_name, rnum)] = len(self.compound_namescompound_names)
1127  self.compound_namescompound_names.append(r.name)
1128  self.res_resnumsres_resnums.append(rnum)
1129 
1130  atoms = [r.FindAtom(an) for an in self.compound_anamescompound_anames[r.name]]
1131  for a in atoms:
1132  if a.IsValid():
1133  self.atom_indicesatom_indices[a.handle.GetHashCode()] = current_idx
1134  p = a.GetPos()
1135  positions.append(np.asarray([p[0], p[1], p[2]],
1136  dtype=np.float32))
1137  else:
1138  positions.append(np.zeros(3, dtype=np.float32))
1139  current_idx += 1
1140 
1141  if r.name in self.compound_symmetric_atomscompound_symmetric_atoms:
1142  for sym_tuple in self.compound_symmetric_atomscompound_symmetric_atoms[r.name]:
1143  for a_idx in sym_tuple:
1144  a = atoms[a_idx]
1145  if a.IsValid():
1146  hashcode = a.handle.GetHashCode()
1147  self.symmetric_atomssymmetric_atoms.add(
1148  self.atom_indicesatom_indices[hashcode]
1149  )
1150  self.positionspositions = np.vstack(positions)
1151  self.n_atomsn_atoms = current_idx
1152 
1153  def _GetTargetResidueNumbers(self, target, seqres_mapping):
1154  """Returns residue numbers for each chain in target as dict
1155 
1156  They're either directly extracted from the raw residue number
1157  from the structure or from user provided alignments
1158  """
1159  residue_numbers = dict()
1160  for ch in target.chains:
1161  ch_name = ch.GetName()
1162  rnums = list()
1163  if ch_name in seqres_mapping:
1164  seqres = seqres_mapping[ch_name].GetSequence(0).GetString()
1165  atomseq = seqres_mapping[ch_name].GetSequence(1).GetString()
1166  # SEQRES must not contain gaps
1167  if "-" in seqres:
1168  raise RuntimeError(
1169  "SEQRES in seqres_mapping must not " "contain gaps"
1170  )
1171  atomseq_from_chain = [r.one_letter_code for r in ch.residues]
1172  if atomseq.replace("-", "") != atomseq_from_chain:
1173  raise RuntimeError(
1174  "ATOMSEQ in seqres_mapping must match "
1175  "raw sequence extracted from chain "
1176  "residues"
1177  )
1178  rnum = 0
1179  for seqres_olc, atomseq_olc in zip(seqres, atomseq):
1180  if seqres_olc != "-":
1181  rnum += 1
1182  if atomseq_olc != "-":
1183  if seqres_olc != atomseq_olc:
1184  raise RuntimeError(
1185  f"Residue with number {rnum} in "
1186  f"chain {ch_name} has SEQRES "
1187  f"ATOMSEQ mismatch"
1188  )
1189  rnums.append(mol.ResNum(rnum))
1190  else:
1191  rnums = [r.GetNumber() for r in ch.residues]
1192  assert len(rnums) == len(ch.residues)
1193  residue_numbers[ch_name] = rnums
1194  return residue_numbers
1195 
1196  def _SetupCompound(self, r, compound_lib, custom_compounds,
1197  symmetry_settings, bb_only):
1198  """fill self.compound_anames/self.compound_symmetric_atoms
1199  """
1200  if bb_only:
1201  # throw away compound_lib info
1202  if r.chem_class.IsPeptideLinking():
1203  self.compound_anamescompound_anames[r.name] = ["CA"]
1204  elif r.chem_class.IsNucleotideLinking():
1205  self.compound_anamescompound_anames[r.name] = ["C3'"]
1206  else:
1207  raise RuntimeError(f"Only support amino acids and nucleotides "
1208  f"if bb_only is True, failed with {str(r)}")
1209  self.compound_symmetric_atomscompound_symmetric_atoms[r.name] = list()
1210  else:
1211  atom_names = list()
1212  symmetric_atoms = list()
1213  if custom_compounds is not None and r.GetName() in custom_compounds:
1214  atom_names = list(custom_compounds[r.GetName()].atom_names)
1215  else:
1216  compound = compound_lib.FindCompound(r.name)
1217  if compound is None:
1218  raise RuntimeError(f"no entry for {r} in compound_lib")
1219  for atom_spec in compound.GetAtomSpecs():
1220  if atom_spec.element not in ["H", "D"]:
1221  atom_names.append(atom_spec.name)
1222  if r.name in symmetry_settings.symmetric_compounds:
1223  for pair in symmetry_settings.symmetric_compounds[r.name]:
1224  try:
1225  a = atom_names.index(pair[0])
1226  b = atom_names.index(pair[1])
1227  except:
1228  msg = f"Could not find symmetric atoms "
1229  msg += f"({pair[0]}, {pair[1]}) for {r.name} "
1230  msg += f"as specified in SymmetrySettings in "
1231  msg += f"compound from component dictionary. "
1232  msg += f"Atoms in compound: {atom_names}"
1233  raise RuntimeError(msg)
1234  symmetric_atoms.append((a, b))
1235  self.compound_anamescompound_anames[r.name] = atom_names
1236  if len(symmetric_atoms) > 0:
1237  self.compound_symmetric_atomscompound_symmetric_atoms[r.name] = symmetric_atoms
1238 
1239  def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes,
1240  ref_indices, ref_distances, no_interchain,
1241  no_intrachain):
1242 
1243  # buildup an index map for mdl atoms that are also present in target
1244  in_target = np.zeros(self.n_atomsn_atoms, dtype=bool)
1245  for i in self.atom_indicesatom_indices.values():
1246  in_target[i] = True
1247  mdl_atom_indices = dict()
1248  for at_indices, at_hashes in zip(res_atom_indices, res_atom_hashes):
1249  for i, h in zip(at_indices, at_hashes):
1250  if in_target[i]:
1251  mdl_atom_indices[h] = i
1252 
1253  # get contacts for mdl - the contacts are only from atom pairs that
1254  # are also present in target, as we only provide the respective
1255  # hashes in mdl_atom_indices
1256  mdl_ref_indices, mdl_ref_distances = \
1257  lDDTScorer._SetupDistances(model, self.n_atomsn_atoms, mdl_atom_indices,
1258  self.inclusion_radiusinclusion_radius)
1259  if no_interchain:
1260  mdl_ref_indices, mdl_ref_distances = \
1261  lDDTScorer._SetupDistancesSC(self.n_atomsn_atoms,
1262  self.chain_start_indiceschain_start_indices,
1263  mdl_ref_indices,
1264  mdl_ref_distances)
1265 
1266  if no_intrachain:
1267  mdl_ref_indices, mdl_ref_distances = \
1268  lDDTScorer._SetupDistancesIC(self.n_atomsn_atoms,
1269  self.chain_start_indiceschain_start_indices,
1270  mdl_ref_indices,
1271  mdl_ref_distances)
1272 
1273  # update ref_indices/ref_distances => add mdl contacts
1274  for i in range(self.n_atomsn_atoms):
1275  mask = np.isin(mdl_ref_indices[i], ref_indices[i],
1276  assume_unique=True, invert=True)
1277  if np.sum(mask) > 0:
1278  added_mdl_indices = mdl_ref_indices[i][mask]
1279  ref_indices[i] = np.append(ref_indices[i],
1280  added_mdl_indices)
1281 
1282  # distances need to be recomputed from ref positions
1283  tmp = self.positionspositions.take(added_mdl_indices, axis=0)
1284  np.subtract(tmp, self.positionspositions[i][None, :], out=tmp)
1285  np.square(tmp, out=tmp)
1286  tmp = tmp.sum(axis=1)
1287  np.sqrt(tmp, out=tmp) # distances against all relevant atoms
1288  ref_distances[i] = np.append(ref_distances[i], tmp)
1289 
1290  return (ref_indices, ref_distances)
1291 
1292 
1293 
1294  @staticmethod
1295  def _SetupDistances(structure, n_atoms, atom_index_mapping,
1296  inclusion_radius):
1297 
1298  """Compute distance related members of lDDTScorer
1299 
1300  Brute force all vs all distance computation kills LDDT for large
1301  complexes. Instead of building some KD tree data structure, we make use
1302  of expected spatial proximity of atoms in the same chain. Distances are
1303  computed as follows:
1304 
1305  - process each chain individually
1306  - perform crude collision detection
1307  - process potentially interacting chain pairs
1308  - concatenate distances from all processing steps
1309  """
1310  ref_indices = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)]
1311  ref_distances = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)]
1312 
1313  indices = [list() for _ in range(n_atoms)]
1314  distances = [list() for _ in range(n_atoms)]
1315  per_chain_pos = list()
1316  per_chain_indices = list()
1317 
1318  # Process individual chains
1319  for ch in structure.chains:
1320  pos_list = list()
1321  atom_indices = list()
1322  mask_start = list()
1323  mask_end = list()
1324  r_start_idx = 0
1325  for r_idx, r in enumerate(ch.residues):
1326  n_valid_atoms = 0
1327  for a in r.atoms:
1328  hash_code = a.handle.GetHashCode()
1329  if hash_code in atom_index_mapping:
1330  p = a.GetPos()
1331  pos_list.append(np.asarray([p[0], p[1], p[2]], dtype=np.float32))
1332  atom_indices.append(atom_index_mapping[hash_code])
1333  n_valid_atoms += 1
1334  mask_start.extend([r_start_idx] * n_valid_atoms)
1335  mask_end.extend([r_start_idx + n_valid_atoms] * n_valid_atoms)
1336  r_start_idx += n_valid_atoms
1337 
1338  if len(pos_list) == 0:
1339  # nothing to do...
1340  continue
1341 
1342  pos = np.vstack(pos_list)
1343  atom_indices = np.asarray(atom_indices, dtype=np.int32)
1344 
1345  if atom_indices.shape[0] > 20000:
1346  dists = blockwise_cdist(pos, pos)
1347  else:
1348  dists = cdist(pos, pos)
1349 
1350  # apply masks
1351  far_away = 2 * inclusion_radius
1352  for idx in range(atom_indices.shape[0]):
1353  dists[idx, range(mask_start[idx], mask_end[idx])] = far_away
1354 
1355  # fish out and store close atoms within inclusion radius
1356  within_mask = dists < inclusion_radius
1357  for idx in range(atom_indices.shape[0]):
1358  indices_to_append = atom_indices[within_mask[idx,:]]
1359  if indices_to_append.shape[0] > 0:
1360  full_at_idx = atom_indices[idx]
1361  indices[full_at_idx].append(indices_to_append)
1362  distances[full_at_idx].append(dists[idx, within_mask[idx,:]])
1363 
1364  dists = None
1365 
1366  per_chain_pos.append(pos)
1367  per_chain_indices.append(atom_indices)
1368 
1369  # perform crude collision detection
1370  min_pos = [p.min(0) for p in per_chain_pos]
1371  max_pos = [p.max(0) for p in per_chain_pos]
1372  chain_pairs = list()
1373  for idx_one in range(len(per_chain_pos)):
1374  for idx_two in range(idx_one + 1, len(per_chain_pos)):
1375  if np.max(min_pos[idx_one] - max_pos[idx_two]) > inclusion_radius:
1376  continue
1377  if np.max(min_pos[idx_two] - max_pos[idx_one]) > inclusion_radius:
1378  continue
1379  chain_pairs.append((idx_one, idx_two))
1380 
1381  # process potentially interacting chains
1382  for pair in chain_pairs:
1383  if per_chain_pos[pair[0]].shape[0] > 20000 or per_chain_pos[pair[1]].shape[0] > 20000:
1384  dists = blockwise_cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]])
1385  else:
1386  dists = cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]])
1387  within = dists <= inclusion_radius
1388 
1389  # process pair[0]
1390  tmp = within.sum(axis=1)
1391  for idx in range(tmp.shape[0]):
1392  if tmp[idx] > 0:
1393  # even though not being a strict requirement, we perform an
1394  # insertion here such that the indices for each atom will be
1395  # sorted after the hstack operation
1396  at_idx = per_chain_indices[pair[0]][idx]
1397  indices_to_insert = per_chain_indices[pair[1]][within[idx,:]]
1398  distances_to_insert = dists[idx, within[idx, :]]
1399  insertion_idx = len(indices[at_idx])
1400  for i in range(insertion_idx):
1401  if indices_to_insert[0] > indices[at_idx][i][0]:
1402  insertion_idx = i
1403  break
1404  indices[at_idx].insert(insertion_idx, indices_to_insert)
1405  distances[at_idx].insert(insertion_idx, distances_to_insert)
1406 
1407  # process pair[1]
1408  tmp = within.sum(axis=0)
1409  for idx in range(tmp.shape[0]):
1410  if tmp[idx] > 0:
1411  # even though not being a strict requirement, we perform an
1412  # insertion here such that the indices for each atom will be
1413  # sorted after the hstack operation
1414  at_idx = per_chain_indices[pair[1]][idx]
1415  indices_to_insert = per_chain_indices[pair[0]][within[:, idx]]
1416  distances_to_insert = dists[within[:, idx], idx]
1417  insertion_idx = len(indices[at_idx])
1418  for i in range(insertion_idx):
1419  if indices_to_insert[0] > indices[at_idx][i][0]:
1420  insertion_idx = i
1421  break
1422  indices[at_idx].insert(insertion_idx, indices_to_insert)
1423  distances[at_idx].insert(insertion_idx, distances_to_insert)
1424 
1425  dists = None
1426 
1427  # concatenate distances from all processing steps
1428  for at_idx in range(n_atoms):
1429  if len(indices[at_idx]) > 0:
1430  ref_indices[at_idx] = np.hstack(indices[at_idx])
1431  ref_distances[at_idx] = np.hstack(distances[at_idx])
1432 
1433  return (ref_indices, ref_distances)
1434 
1435  @staticmethod
1436  def _SetupDistancesSC(n_atoms, chain_start_indices,
1437  ref_indices, ref_distances):
1438  """Select subset of contacts only covering intra-chain contacts
1439  """
1440  # init
1441  ref_indices_sc = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)]
1442  ref_distances_sc = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)]
1443 
1444  n_chains = len(chain_start_indices)
1445  for ch_idx in range(n_chains):
1446  chain_s = chain_start_indices[ch_idx]
1447  chain_e = n_atoms
1448  if ch_idx + 1 < n_chains:
1449  chain_e = chain_start_indices[ch_idx+1]
1450  for i in range(chain_s, chain_e):
1451  if len(ref_indices[i]) > 0:
1452  intra_idx = np.where(np.logical_and(ref_indices[i]>=chain_s,
1453  ref_indices[i]<chain_e))[0]
1454  ref_indices_sc[i] = ref_indices[i][intra_idx]
1455  ref_distances_sc[i] = ref_distances[i][intra_idx]
1456 
1457  return (ref_indices_sc, ref_distances_sc)
1458 
1459  @staticmethod
1460  def _SetupDistancesIC(n_atoms, chain_start_indices,
1461  ref_indices, ref_distances):
1462  """Select subset of contacts only covering inter-chain contacts
1463  """
1464  # init
1465  ref_indices_ic = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)]
1466  ref_distances_ic = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)]
1467 
1468  n_chains = len(chain_start_indices)
1469  for ch_idx in range(n_chains):
1470  chain_s = chain_start_indices[ch_idx]
1471  chain_e = n_atoms
1472  if ch_idx + 1 < n_chains:
1473  chain_e = chain_start_indices[ch_idx+1]
1474  for i in range(chain_s, chain_e):
1475  if len(ref_indices[i]) > 0:
1476  inter_idx = np.where(np.logical_or(ref_indices[i]<chain_s,
1477  ref_indices[i]>=chain_e))[0]
1478  ref_indices_ic[i] = ref_indices[i][inter_idx]
1479  ref_distances_ic[i] = ref_distances[i][inter_idx]
1480 
1481  return (ref_indices_ic, ref_distances_ic)
1482 
1483  @staticmethod
1484  def _NonSymDistances(n_atoms, symmetric_atoms, ref_indices, ref_distances):
1485  """Transfer indices/distances of non-symmetric atoms and return
1486  """
1487 
1488  sym_ref_indices = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)]
1489  sym_ref_distances = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)]
1490 
1491  for idx in symmetric_atoms:
1492  indices = list()
1493  distances = list()
1494  for i, d in zip(ref_indices[idx], ref_distances[idx]):
1495  if i not in symmetric_atoms:
1496  indices.append(i)
1497  distances.append(d)
1498  sym_ref_indices[idx] = indices
1499  sym_ref_distances[idx] = np.asarray(distances)
1500 
1501  return (sym_ref_indices, sym_ref_distances)
1502 
1503  def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances):
1504  """Computes number of distance differences within given thresholds
1505 
1506  returns np.array with len(thresholds) elements
1507  """
1508  a_p = pos[atom_idx, :]
1509  tmp = pos.take(ref_indices[atom_idx], axis=0)
1510  np.subtract(tmp, a_p[None, :], out=tmp)
1511  np.square(tmp, out=tmp)
1512  tmp = tmp.sum(axis=1)
1513  np.sqrt(tmp, out=tmp) # distances against all relevant atoms
1514  np.subtract(ref_distances[atom_idx], tmp, out=tmp)
1515  np.absolute(tmp, out=tmp) # absolute dist diffs
1516  return np.asarray([(tmp <= thresh).sum() for thresh in thresholds],
1517  dtype=np.int32)
1518 
def _EvalAtoms(
    self, pos, atom_indices, thresholds, ref_indices, ref_distances
):
    """ Calls _EvalAtom for several atoms and collects the results

    The per-atom results are not summed up here, each atom gets its own
    row in the returned matrix.

    :param pos: Positions, passed on to _EvalAtom
    :param atom_indices: Indices of the atoms to evaluate
    :param thresholds: Distance difference thresholds, passed on to
                       _EvalAtom
    :param ref_indices: Passed on to _EvalAtom
    :param ref_distances: Passed on to _EvalAtom
    :returns: int32 numpy matrix of shape
              (len(atom_indices), len(thresholds)), row i holds the
              result of _EvalAtom for atom_indices[i]
    """
    conserved = np.zeros((len(atom_indices), len(thresholds)),
                         dtype=np.int32)
    for row, atom_idx in enumerate(atom_indices):
        conserved[row, :] = self._EvalAtom(pos, atom_idx, thresholds,
                                           ref_indices, ref_distances)
    return conserved
1533 
def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices,
                  ref_distances):
    """ Calls _EvalAtoms for a bunch of residues

    Residues are defined in *res_atom_indices* as lists of atom indices.
    The per-atom rows returned by _EvalAtoms are summed up for each
    residue.

    :param pos: Positions, passed on to _EvalAtoms
    :param thresholds: Distance difference thresholds
    :param res_atom_indices: One list of atom indices per residue
    :param ref_indices: Passed on to _EvalAtoms
    :param ref_distances: Passed on to _EvalAtoms
    :returns: int32 numpy matrix of shape
              (len(res_atom_indices), len(thresholds))
    """
    conserved = np.zeros((len(res_atom_indices), len(thresholds)),
                         dtype=np.int32)
    for res_idx, atom_indices in enumerate(res_atom_indices):
        per_atom = self._EvalAtoms(pos, atom_indices, thresholds,
                                   ref_indices, ref_distances)
        conserved[res_idx, :] = np.sum(per_atom, axis=0)
    return conserved
1547 
def _ProcessSequenceSeparation(self):
    """ Checks the sequence_separation setting of this object

    Only the value 0 is accepted, there is no further processing.

    :raises: :class:`NotImplementedError` if sequence_separation is not 0
    """
    if self.sequence_separation != 0:
        raise NotImplementedError("Congratulations! You're the first one "
                                  "requesting a non-default "
                                  "sequence_separation in the new and "
                                  "awesome LDDT implementation. A crate of "
                                  "beer for Gabriel and he'll implement "
                                  "it.")
1556 
def _GetNExp(self, atom_idx, ref_indices):
    """ Returns number of close atoms around one or several atoms

    That is the number of entries in ``ref_indices[idx]``, summed up over
    all requested atoms.

    :param atom_idx: Single atom index (:class:`int` or numpy integer) or
                     several of them (:class:`list`, :class:`tuple` or
                     numpy array)
    :param ref_indices: Per atom, the indices of its reference partners
    :returns: Number of entries as :class:`int`
    :raises: :class:`RuntimeError` if *atom_idx* is of any other type
    """
    # numpy integer scalars are no subclass of int but are what you get
    # when iterating over or indexing into an index array
    if isinstance(atom_idx, (int, np.integer)):
        return len(ref_indices[atom_idx])
    if isinstance(atom_idx, (list, tuple, np.ndarray)):
        return sum(len(ref_indices[idx]) for idx in atom_idx)
    raise RuntimeError("invalid input type")
1566 
def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices,
                       sym_ref_distances):
    """ Swaps symmetric positions in-place in order to maximize LDDT scores
    towards non-symmetric atoms.

    Each element of *symmetries* is processed on its own: the listed atoms
    are scored with _EvalAtoms as they are and again after swapping the
    positions of every pair. The swap is only kept if it gives a strictly
    higher score.

    :param pos: Positions, one row per atom. Modified in-place.
    :param thresholds: Distance difference thresholds
    :param symmetries: Iterable of symmetries, each one being a sequence
                       of index pairs whose positions are swapped together
    :param sym_ref_indices: Per atom, indices of its reference partners
    :param sym_ref_distances: Per atom, distances matching
                              *sym_ref_indices*
    """
    for sym in symmetries:

        atom_indices = list()
        for sym_tuple in sym:
            atom_indices += [sym_tuple[0], sym_tuple[1]]
        tot = self._GetNExp(atom_indices, sym_ref_indices)

        if tot == 0:
            continue  # no reference partners at all, nothing to decide

        # score as is
        sym_one_conserved = self._EvalAtoms(
            pos,
            atom_indices,
            thresholds,
            sym_ref_indices,
            sym_ref_distances,
        )

        # switch positions and score again
        # (fancy indexing copies the right hand side before assigning)
        for pair in sym:
            pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]

        sym_two_conserved = self._EvalAtoms(
            pos,
            atom_indices,
            thresholds,
            sym_ref_indices,
            sym_ref_distances,
        )

        norm = len(thresholds) * tot
        sym_one_score = np.sum(sym_one_conserved) / norm
        sym_two_score = np.sum(sym_two_conserved) / norm

        if sym_one_score >= sym_two_score:
            # switch back, initial positions were better or equal
            # for the equal case: we still switch back to reproduce the old
            # LDDT behaviour
            for pair in sym:
                pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
1612 
def _EvalAtomSSD(self, pos, atom_idx, dist_cap, ref_indices, ref_distances):
    """ Computes summed squared distance differences of a single atom

    The distances between row *atom_idx* of *pos* and the rows listed in
    ``ref_indices[atom_idx]`` are subtracted from
    ``ref_distances[atom_idx]``. The squared differences are capped at
    dist_cap*dist_cap before they are summed up.

    :param pos: Positions, one row per atom
    :param atom_idx: Index of the atom to evaluate
    :param dist_cap: Cap for the distance differences
    :param ref_indices: Per atom, row indices in *pos* of its partners
    :param ref_distances: Per atom, expected distances to its partners
    :returns: Sum of the capped squared distance differences
    """
    cap_sq = dist_cap*dist_cap
    center = pos[atom_idx, :]
    # take() hands back a copy, the in-place operations below therefore
    # never modify pos itself
    work = pos.take(ref_indices[atom_idx], axis=0)
    work -= center[None, :]
    np.square(work, out=work)
    sq_diffs = work.sum(axis=1)
    np.sqrt(sq_diffs, out=sq_diffs)  # distances against all relevant atoms
    np.subtract(ref_distances[atom_idx], sq_diffs, out=sq_diffs)
    np.square(sq_diffs, out=sq_diffs)  # squared distance differences
    above_cap = sq_diffs > cap_sq
    sq_diffs[above_cap] = cap_sq
    return sq_diffs.sum()
1629 
def _EvalAtomsSSD(
    self, pos, atom_indices, dist_cap, ref_indices, ref_distances
):
    """ Calls _EvalAtomSSD for several atoms

    :param pos: Positions, passed on to _EvalAtomSSD
    :param atom_indices: Indices of the atoms to evaluate
    :param dist_cap: Passed on to _EvalAtomSSD
    :param ref_indices: Passed on to _EvalAtomSSD
    :param ref_distances: Passed on to _EvalAtomSSD
    :returns: float32 numpy array with len(atom_indices) elements, element
              i holds the result of _EvalAtomSSD for atom_indices[i]
    """
    per_atom = [self._EvalAtomSSD(pos, atom_idx, dist_cap, ref_indices,
                                  ref_distances)
                for atom_idx in atom_indices]
    return np.asarray(per_atom, dtype=np.float32)
1638 
def _ResolveSymmetriesSSD(self, pos, dist_cap, symmetries, sym_ref_indices,
                          sym_ref_distances):
    """ Swaps symmetric positions in-place in order to minimize summed
    squared distance differences towards non-symmetric atoms.

    Each element of *symmetries* is processed on its own: the listed atoms
    are evaluated with _EvalAtomsSSD as they are and again after swapping
    the positions of every pair. The swap is undone only if the initial
    positions gave a strictly smaller sum, i.e. it is kept if both sums
    are equal.

    :param pos: Positions, one row per atom. Modified in-place.
    :param dist_cap: Cap for distance differences, passed on to
                     _EvalAtomsSSD
    :param symmetries: Iterable of symmetries, each one being a sequence
                       of index pairs whose positions are swapped together
    :param sym_ref_indices: Per atom, indices of its reference partners
    :param sym_ref_distances: Per atom, distances matching
                              *sym_ref_indices*
    """
    for sym in symmetries:

        atom_indices = list()
        for sym_tuple in sym:
            atom_indices += [sym_tuple[0], sym_tuple[1]]
        tot = self._GetNExp(atom_indices, sym_ref_indices)

        if tot == 0:
            continue  # no reference partners at all, nothing to decide

        # evaluate as is
        sym_one_ssd = self._EvalAtomsSSD(
            pos,
            atom_indices,
            dist_cap,
            sym_ref_indices,
            sym_ref_distances,
        )

        # switch positions and evaluate again
        # (fancy indexing copies the right hand side before assigning)
        for pair in sym:
            pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]

        sym_two_ssd = self._EvalAtomsSSD(
            pos,
            atom_indices,
            dist_cap,
            sym_ref_indices,
            sym_ref_distances,
        )

        sym_one_score = np.sum(sym_one_ssd)
        sym_two_score = np.sum(sym_two_ssd)

        # lower is better here
        if sym_one_score < sym_two_score:
            # switch back, initial positions were better
            for pair in sym:
                pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
def __init__(self, atom_names)
Definition: lddt.py:50
def AddSymmetricCompound(self, name, symmetric_atoms)
Definition: lddt.py:85
def _SetupCompound(self, r, compound_lib, custom_compounds, symmetry_settings, bb_only)
Definition: lddt.py:1197
def lDDT(self, model, thresholds=[0.5, 1.0, 2.0, 4.0], local_lddt_prop=None, local_contact_prop=None, chain_mapping=None, no_interchain=False, no_intrachain=False, penalize_extra_chains=False, residue_mapping=None, return_dist_test=False, check_resnames=True, add_mdl_contacts=False, interaction_data=None, set_atom_props=False)
Definition: lddt.py:462
def _ResolveSymmetriesSSD(self, pos, dist_cap, symmetries, sym_ref_indices, sym_ref_distances)
Definition: lddt.py:1640
def _GetChainRNums(self, ch, residue_mapping, model_ch_name, target_ch_name)
Definition: lddt.py:1052
def _ProcessSequenceSeparation(self)
Definition: lddt.py:1548
def sym_ref_distances(self)
Definition: lddt.py:369
def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices, sym_ref_distances)
Definition: lddt.py:1568
def ref_distances_ic(self)
Definition: lddt.py:427
def _GetTargetResidueNumbers(self, target, seqres_mapping)
Definition: lddt.py:1153
def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances)
Definition: lddt.py:1503
def sym_ref_distances_ic(self)
Definition: lddt.py:447
def sym_ref_distances_sc(self)
Definition: lddt.py:407
def sym_ref_indices_ic(self)
Definition: lddt.py:437
def GetNChainContacts(self, target_chain, no_interchain=False)
Definition: lddt.py:918
def sym_ref_indices_sc(self)
Definition: lddt.py:397
def ref_distances_sc(self)
Definition: lddt.py:387
def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes, ref_indices, ref_distances, no_interchain, no_intrachain)
Definition: lddt.py:1241
def DRMSD(self, model, dist_cap=5, chain_mapping=None, no_interchain=False, no_intrachain=False, residue_mapping=None, check_resnames=True, add_mdl_contacts=False, interaction_data=None)
Definition: lddt.py:741
def _GetNExp(self, atom_idx, ref_indices)
Definition: lddt.py:1557
def __init__(self, target, compound_lib=None, custom_compounds=None, inclusion_radius=15, sequence_separation=0, symmetry_settings=None, seqres_mapping=dict(), bb_only=False)
Definition: lddt.py:233
def _EvalAtomsSSD(self, pos, atom_indices, dist_cap, ref_indices, ref_distances)
Definition: lddt.py:1632
def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings, seqres_mapping, bb_only)
Definition: lddt.py:1104
def _EvalAtomSSD(self, pos, atom_idx, dist_cap, ref_indices, ref_distances)
Definition: lddt.py:1613
def _ProcessModel(self, model, chain_mapping, residue_mapping=None, nirvana_dist=100, check_resnames=True)
Definition: lddt.py:943
def _GetExtraModelChainPenalty(self, model, chain_mapping)
Definition: lddt.py:1035
def _EvalAtoms(self, pos, atom_indices, thresholds, ref_indices, ref_distances)
Definition: lddt.py:1521
def GetDefaultSymmetrySettings()
Definition: lddt.py:103
def blockwise_cdist(A, B, block_size=1000)
Definition: lddt.py:19
def cdist(p1, p2)
Definition: lddt.py:12