OpenStructure
Loading...
Searching...
No Matches
lddt.py
Go to the documentation of this file.
1import itertools
2import numpy as np
3
4from ost import mol
5from ost import conop
6
7# use cdist of scipy, fallback to (slower) numpy implementation if scipy is not
8# available
try:
    from scipy.spatial.distance import cdist
except ImportError:
    def cdist(p1, p2):
        """Pairwise Euclidean distances between two sets of points

        Pure numpy replacement for :func:`scipy.spatial.distance.cdist`
        that is only defined if scipy cannot be imported.

        :param p1: Positions, array of shape (m, d)
        :param p2: Positions, array of shape (n, d)
        :returns: Array of shape (m, n), element (i, j) is the distance
                  between p1[i] and p2[j]
        """
        x2 = np.sum(p1**2, axis=1).reshape(-1, 1)  # (m, 1)
        y2 = np.sum(p2**2, axis=1)  # (n)
        xy = np.matmul(p1, p2.T)  # (m, n)
        # |a-b|^2 = |a|^2 - 2ab + |b|^2. Rounding errors can push this
        # slightly below zero for (nearly) identical points, which would
        # give NaN in the sqrt - clamp at zero.
        sq_dist = np.maximum(x2 - 2*xy + y2, 0)
        return np.sqrt(sq_dist)  # (m, n)
18
def blockwise_cdist(A, B, block_size=1000):
    """ Blockwise cdist that keeps the full distance matrix in 32 bit floats

    scipy cdist operates on 64 bit floats (double). The full pairwise
    distance matrix can therefore hit the memory limit of most machines
    once the number of positions becomes large. E.g. ~4000 residues might
    have 35000 atom positions, that's almost 10GB to hold all pairwise
    distances in 64bit floats. Here, cdist is called on *block_size* rows of
    *A* at a time and each result is stored in a 32bit float matrix.

    This function is adapted from chatgpt output

    :param A: Positions, array with one position per row
    :param B: Positions, array with one position per row
    :param block_size: Number of rows of *A* passed to cdist per call
    :type block_size: :class:`int`
    :returns: float32 array of shape (A.shape[0], B.shape[0])
    """
    rows = A.astype(np.float32)
    cols = B.astype(np.float32)
    n_rows = rows.shape[0]
    n_cols = cols.shape[0]
    # output in float32 to save memory
    result = np.empty((n_rows, n_cols), dtype=np.float32)
    for start in range(0, n_rows, block_size):
        stop = start + block_size
        block = cdist(rows[start:stop], cols)
        result[start:stop, :] = block.astype(np.float32)
    return result
38
class CustomCompound:
    """ Defines atoms for custom compounds

    LDDT requires the reference atoms of a compound which are typically
    extracted from a :class:`ost.conop.CompoundLib`. This lightweight
    container allows to handle arbitrary compounds which are not
    necessarily in the compound library.

    :param atom_names: Names of atoms of custom compound
    :type atom_names: :class:`list` of :class:`str`
    """
    def __init__(self, atom_names):
        # stored as is, no copy is made
        self.atom_names = atom_names

    @staticmethod
    def FromResidue(res):
        """ Construct custom compound from residue

        :param res: Residue from which reference atom names are extracted,
                    hydrogen/deuterium atoms are filtered out
        :type res: :class:`ost.mol.ResidueView`/:class:`ost.mol.ResidueHandle`
        :returns: :class:`CustomCompound`
        :raises: :class:`RuntimeError` if two of the remaining atoms in
                 *res* have the same name
        """
        at_names = [a.name for a in res.atoms if a.element not in ["H", "D"]]
        # atom names must be unique to serve as reference
        if len(at_names) != len(set(at_names)):
            raise RuntimeError("Duplicate atoms detected in CustomCompound")
        return CustomCompound(at_names)
67
class SymmetrySettings:
    """Container for symmetric compounds

    LDDT considers symmetries and selects the one resulting in the highest
    possible score.

    A symmetry is defined as a renaming operation on one or more atoms that
    leads to a chemically equivalent residue. Example would be OD1 and OD2 in
    ASP => renaming OD1 to OD2 and vice versa gives a chemically equivalent
    residue.

    Use :func:`AddSymmetricCompound` to define a symmetry which can then
    directly be accessed through the *symmetric_compounds* member.
    """
    def __init__(self):
        # key: compound name, value: list of atom name pairs
        self.symmetric_compounds = dict()

    def AddSymmetricCompound(self, name, symmetric_atoms):
        """Adds symmetry for compound with *name*

        An already defined symmetry for *name* is replaced.

        :param name: Name of compound with symmetry
        :type name: :class:`str`
        :param symmetric_atoms: Pairs of atom names that define renaming
                                operation, i.e. after applying all switches
                                defined in the tuples, the resulting residue
                                should be chemically equivalent. Atom names
                                must refer to the PDB component dictionary.
        :type symmetric_atoms: :class:`list` of :class:`tuple`
        :raises: :class:`RuntimeError` if any element of *symmetric_atoms*
                 does not have length 2
        """
        # validate everything before storing, nothing is added on failure
        for pair in symmetric_atoms:
            if len(pair) != 2:
                raise RuntimeError("Expect pairs when defining symmetries")
        self.symmetric_compounds[name] = symmetric_atoms
101
102
def GetDefaultSymmetrySettings():
    """Constructs and returns :class:`SymmetrySettings` object for natural amino
    acids

    Besides the amino acids ASP, GLU, LEU, VAL, ARG, PHE and TYR, the
    OP1/OP2 symmetry of the nucleotides A, C, G, U, DA, DC, DG and DT is
    registered.

    :returns: :class:`SymmetrySettings`
    """
    symmetry_settings = SymmetrySettings()

    # ASP
    symmetry_settings.AddSymmetricCompound("ASP", [("OD1", "OD2")])

    # GLU
    symmetry_settings.AddSymmetricCompound("GLU", [("OE1", "OE2")])

    # LEU
    symmetry_settings.AddSymmetricCompound("LEU", [("CD1", "CD2")])

    # VAL
    symmetry_settings.AddSymmetricCompound("VAL", [("CG1", "CG2")])

    # ARG
    symmetry_settings.AddSymmetricCompound("ARG", [("NH1", "NH2")])

    # PHE
    symmetry_settings.AddSymmetricCompound(
        "PHE", [("CD1", "CD2"), ("CE1", "CE2")]
    )

    # TYR
    symmetry_settings.AddSymmetricCompound(
        "TYR", [("CD1", "CD2"), ("CE1", "CE2")]
    )

    # nucleotides
    nuc_names = ["A", "C", "G", "U", "DA", "DC", "DG", "DT"]
    for nuc_name in nuc_names:
        symmetry_settings.AddSymmetricCompound(
            nuc_name, [("OP1","OP2")]
        )

    return symmetry_settings
142
143
145 """LDDT scorer object for a specific target
146
147 Sets up everything to score models of that target. LDDT (local distance
148 difference test) is defined as fraction of pairwise distances which exhibit
149 a difference < threshold when considering target and model. In case of
150 multiple thresholds, the average is returned. See
151
152 V. Mariani, M. Biasini, A. Barbato, T. Schwede, lDDT : A local
153 superposition-free score for comparing protein structures and models using
154 distance difference tests, Bioinformatics, 2013
155
156 :param target: The target
157 :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
158 :param compound_lib: Compound library from which a compound for each residue
159 is extracted based on its name. Uses
160 :func:`ost.conop.GetDefaultLib` if not given, raises
161 if this returns no valid compound library. Atoms
162 defined in the compound are searched in the residue and
163 build the reference for scoring. If the residue has
164 atoms with names ["A", "B", "C"] but the corresponding
165 compound only has ["A", "B"], "A" and "B" are
166 considered for scoring. If the residue has atoms
167 ["A", "B"] but the compound has ["A", "B", "C"], "C" is
168 considered missing and does not influence scoring, even
169 if present in the model.
170 :param custom_compounds: Custom compounds defining reference atoms. If
171 given, *custom_compounds* take precedent over
172 *compound_lib*.
173 :type custom_compounds: :class:`dict` with residue names (:class:`str`) as
174 key and :class:`CustomCompound` as value.
175 :type compound_lib: :class:`ost.conop.CompoundLib`
176 :param inclusion_radius: All pairwise distances < *inclusion_radius* are
177 considered for scoring
178 :type inclusion_radius: :class:`float`
179 :param sequence_separation: Only pairwise distances between atoms of
180 residues which are further apart than this
181 threshold are considered. Residue distance is
182 based on resnum. The default (0) considers all
183 pairwise distances except intra-residue
184 distances.
185 :type sequence_separation: :class:`int`
186 :param symmetry_settings: Define residues exhibiting internal symmetry, uses
187 :func:`GetDefaultSymmetrySettings` if not given.
188 :type symmetry_settings: :class:`SymmetrySettings`
189 :param seqres_mapping: Mapping of model residues at the scoring stage
190 happens with residue numbers defining their location
191 in a reference sequence (SEQRES) using one based
192 indexing. If the residue numbers in *target* don't
193 correspond to that SEQRES, you can specify the
194 mapping manually. You can provide a dictionary to
195 specify a reference sequence (SEQRES) for one or more
196 chain(s). Key: chain name, value: alignment
197 (seq1: SEQRES, seq2: sequence of residues in chain).
198 Example: The residues in a chain with name "A" have
199 sequence "YEAH" and residue numbers [42,43,44,45].
200 You can provide an alignment with seq1 "``HELLYEAH``"
201 and seq2 "``----YEAH``". "Y" gets assigned residue
202 number 5, "E" gets assigned 6 and so on no matter
203 what the original residue numbers were.
204 :type seqres_mapping: :class:`dict` (key: :class:`str`, value:
205 :class:`ost.seq.AlignmentHandle`)
206 :param bb_only: Only consider atoms with name "CA" in case of amino acids and
207 "C3'" for Nucleotides. this invalidates *compound_lib*.
208 Raises if any residue in *target* is not
209 `r.chem_class.IsPeptideLinking()` or
210 `r.chem_class.IsNucleotideLinking()`
211 :type bb_only: :class:`bool`
212 :raises: :class:`RuntimeError` if *target* contains compound which is not in
213 *compound_lib*, :class:`RuntimeError` if *symmetry_settings*
214 specifies symmetric atoms that are not present in the according
215 compound in *compound_lib*, :class:`RuntimeError` if
216 *seqres_mapping* is not provided and *target* contains residue
217 numbers with insertion codes or the residue numbers for each chain
218 are not monotonically increasing, :class:`RuntimeError` if
219 *seqres_mapping* is provided but an alignment is invalid
220 (seq1 contains gaps, mismatch in seq1/seq2, seq2 does not match
221 residues in corresponding chains).
222 """
224 self,
225 target,
226 compound_lib=None,
227 custom_compounds=None,
228 inclusion_radius=15,
229 sequence_separation=0,
230 symmetry_settings=None,
231 seqres_mapping=dict(),
232 bb_only=False
233 ):
234
235 self.target = target
236 self.inclusion_radius = inclusion_radius
237 self.sequence_separation = sequence_separation
238 if compound_lib is None:
239 compound_lib = conop.GetDefaultLib()
240 if compound_lib is None:
241 raise RuntimeError("No compound_lib given and conop.GetDefaultLib "
242 "returns no valid compound library")
243 self.compound_lib = compound_lib
244 self.custom_compounds = custom_compounds
245 if symmetry_settings is None:
247 else:
248 self.symmetry_settings = symmetry_settings
249
250 # whether to only consider atoms with name "CA" (amino acids) or C3'
251 # (nucleotides), invalidates *compound_lib*
252 self.bb_only=bb_only
253
254 # names of heavy atoms of each unique compound present in *target* as
255 # extracted from *compound_lib*, e.g.
256 # self.compound_anames["GLY"] = ["N", "CA", "C", "O"]
257 self.compound_anames = dict()
258
259 # stores symmetry information for those compounds as defined in
260 # *symmetry_settings*
262
263 # list of len(target.chains) containing all chain names in *target*
264 self.chain_names = list()
265
266 # list of len(target.residues) containing all compound names in *target*
267 self.compound_names = list()
268
269 # list of len(target.residues) defining start pos in internal reference
270 # positions for each residue
271 self.res_start_indices = list()
272
273 # list of len(target.residues) defining residue numbers in target
274 self.res_resnums = list()
275
276 # list of len(target.chains) defining start pos in internal reference
277 # positions for each chain
279
280 # list of len(target.chains) defining start pos in self.compound_names
281 # for each chain
283
284 # maps residues in *target* to indices in
285 # self.compound_names/self.res_start_indices. A residue gets identified
286 # by a tuple (first element: chain name, second element: residue number,
287 # residue number is either the actual residue number in *target* or
288 # given by *seqres_mapping*)
289 self.res_mapper = dict()
290
291 # number of atoms as specified in compounds. not all are necessarily
292 # covered by structure
293 self.n_atoms = None
294
295 # stores an index for each AtomHandle in *target*
296 # (atom hashcode => index)
297 self.atom_indices = dict()
298
299 # store indices of all atoms that have symmetry properties
300 self.symmetric_atoms = set()
301
302 # the actual target positions in a numpy array of shape (self.n_atoms,3)
303 self.positions = None
304
305 # setup members defined above
307 self.symmetry_settings, seqres_mapping, self.bb_only)
308
309 # distance related members are lazily computed as they're affected
310 # by different flavours of LDDT (e.g. LDDT including inter-chain
311 # contacts or not etc.)
312
313 # stores for each atom the other atoms within inclusion_radius
314 self._ref_indices = None
315 # the corresponding distances
316 self._ref_distances = None
317
318 # The following lists will be sparsely populated. We keep for each
319 # symmetry related atom the distances towards all atoms which are NOT
320 # affected by symmetry. So we can evaluate two symmetric versions
321 # against the fixed stuff later on and select the better scoring one.
324
325 # exactly the same as above but without interchain contacts
326 # => single-chain (sc)
327 self._ref_indices_sc = None
331
332 # exactly the same as above but without intrachain contacts
333 # => inter-chain (ic)
334 self._ref_indices_ic = None
338
339 # input parameter checking
341
@property
def ref_indices(self):
    """Lazily computed reference indices

    Computed together with the corresponding reference distances on first
    access by calling :func:`_SetupDistances`, both results are cached.
    """
    if self._ref_indices is not None:
        return self._ref_indices
    indices, distances = lDDTScorer._SetupDistances(self.target,
                                                    self.n_atoms,
                                                    self.atom_indices,
                                                    self.inclusion_radius)
    self._ref_indices = indices
    self._ref_distances = distances
    return self._ref_indices
350
@property
def ref_distances(self):
    """Lazily computed reference distances

    Computed together with the corresponding reference indices on first
    access by calling :func:`_SetupDistances`, both results are cached.
    """
    if self._ref_distances is not None:
        return self._ref_distances
    indices, distances = lDDTScorer._SetupDistances(self.target,
                                                    self.n_atoms,
                                                    self.atom_indices,
                                                    self.inclusion_radius)
    self._ref_indices = indices
    self._ref_distances = distances
    return self._ref_distances
359
360 @property
362 if self._sym_ref_indices is None:
364 lDDTScorer._NonSymDistances(self.n_atoms, self.symmetric_atoms,
366 return self._sym_ref_indices
367
368 @property
370 if self._sym_ref_distances is None:
372 lDDTScorer._NonSymDistances(self.n_atoms, self.symmetric_atoms,
374 return self._sym_ref_distances
375
376 @property
377 def ref_indices_sc(self):
378 if self._ref_indices_sc is None:
380 lDDTScorer._SetupDistancesSC(self.n_atoms,
384 return self._ref_indices_sc
385
386 @property
388 if self._ref_distances_sc is None:
390 lDDTScorer._SetupDistancesSC(self.n_atoms,
394 return self._ref_distances_sc
395
396 @property
398 if self._sym_ref_indices_sc is None:
400 lDDTScorer._NonSymDistances(self.n_atoms,
401 self.symmetric_atoms,
404 return self._sym_ref_indices_sc
405
406 @property
408 if self._sym_ref_distances_sc is None:
410 lDDTScorer._NonSymDistances(self.n_atoms,
411 self.symmetric_atoms,
414 return self._sym_ref_distances_sc
415
416 @property
417 def ref_indices_ic(self):
418 if self._ref_indices_ic is None:
420 lDDTScorer._SetupDistancesIC(self.n_atoms,
424 return self._ref_indices_ic
425
426 @property
428 if self._ref_distances_ic is None:
430 lDDTScorer._SetupDistancesIC(self.n_atoms,
434 return self._ref_distances_ic
435
436 @property
438 if self._sym_ref_indices_ic is None:
440 lDDTScorer._NonSymDistances(self.n_atoms,
441 self.symmetric_atoms,
444 return self._sym_ref_indices_ic
445
446 @property
448 if self._sym_ref_distances_ic is None:
450 lDDTScorer._NonSymDistances(self.n_atoms,
451 self.symmetric_atoms,
454 return self._sym_ref_distances_ic
455
def lDDT(self, model, thresholds = [0.5, 1.0, 2.0, 4.0],
         local_lddt_prop=None, local_contact_prop=None,
         chain_mapping=None, no_interchain=False,
         no_intrachain=False, penalize_extra_chains=False,
         residue_mapping=None, return_dist_test=False,
         check_resnames=True, add_mdl_contacts=False,
         interaction_data=None, set_atom_props=False):
    """Computes LDDT of *model* - globally and per-residue

    :param model: Model to be scored - models are preferably scored upon
                  performing stereo-chemistry checks in order to punish for
                  non-sensical irregularities. This must be done separately
                  as a pre-processing step. Target contacts that are not
                  covered by *model* are considered not conserved, thus
                  decreasing LDDT score. This also includes missing model
                  chains or model chains for which no mapping is provided in
                  *chain_mapping*.
    :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
    :param thresholds: Thresholds of distance differences to be considered
                       as correct - see docs in constructor for more info.
                       default: [0.5, 1.0, 2.0, 4.0]
    :type thresholds: :class:`list` of :class:`floats`
    :param local_lddt_prop: If set, per-residue scores will be assigned as
                            generic float property of that name
    :type local_lddt_prop: :class:`str`
    :param local_contact_prop: If set, number of expected contacts as well
                               as number of conserved contacts will be
                               assigned as generic int property.
                               Expected contacts will be set as
                               <local_contact_prop>_exp, conserved contacts
                               as <local_contact_prop>_cons. Values
                               are summed over all thresholds.
    :type local_contact_prop: :class:`str`
    :param chain_mapping: Mapping of model chains (key) onto target chains
                          (value). This is required if target or model have
                          more than one chain.
    :type chain_mapping: :class:`dict` with :class:`str` as keys/values
    :param no_interchain: Whether to exclude interchain contacts
    :type no_interchain: :class:`bool`
    :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
                          consider interface related contacts)
    :type no_intrachain: :class:`bool`
    :param penalize_extra_chains: Whether to include a fixed penalty for
                                  additional chains in the model that are
                                  not mapped to the target. ONLY AFFECTS
                                  RETURNED GLOBAL SCORE. In detail: adds the
                                  number of intra-chain contacts of each
                                  extra chain to the expected contacts, thus
                                  adding a penalty.
    :type penalize_extra_chains: :class:`bool`
    :param residue_mapping: By default, residue mapping is based on residue
                            numbers. That means, a model chain and the
                            respective target chain map to the same
                            underlying reference sequence (SEQRES).
                            Alternatively, you can specify one or
                            several alignment(s) between model and target
                            chains by providing a dictionary. key: Name
                            of chain in model (respective target chain is
                            extracted from *chain_mapping*),
                            value: Alignment with first sequence
                            corresponding to target chain and second
                            sequence to model chain. There is NO reference
                            sequence involved, so the two sequences MUST
                            exactly match the actual residues observed in
                            the respective target/model chains (ATOMSEQ).
    :type residue_mapping: :class:`dict` with key: :class:`str`,
                           value: :class:`ost.seq.AlignmentHandle`
    :param return_dist_test: Whether to additionally return the underlying
                             per-residue data for the distance difference
                             test. Adds five objects to the return tuple.
                             First: Number of total contacts summed over all
                             thresholds
                             Second: Number of conserved contacts summed
                             over all thresholds
                             Third: list with length of scored residues.
                             Contains indices referring to model.residues.
                             Fourth: numpy array of size
                             len(scored_residues) containing the number of
                             total contacts,
                             Fifth: numpy matrix of shape
                             (len(scored_residues), len(thresholds))
                             specifying how many for each threshold are
                             conserved.
    :param check_resnames: On by default. Enforces residue name matches
                           between mapped model and target residues.
    :type check_resnames: :class:`bool`
    :param add_mdl_contacts: Adds model contacts - Only using contacts that
                             are within a certain distance threshold in the
                             target does not penalize for added model
                             contacts. If set to True, this flag will also
                             consider target contacts that are within the
                             specified distance threshold in the model but
                             not necessarily in the target. No contact will
                             be added if the respective atom pair is not
                             resolved in the target.
    :type add_mdl_contacts: :class:`bool`
    :param interaction_data: Pro param - don't use
    :type interaction_data: :class:`tuple`
    :param set_atom_props: If True, sets generic properties on a per atom
                           level if *local_lddt_prop*/*local_contact_prop*
                           are set as well.
                           In other words: this is the only way you can
                           get per-atom LDDT values.
    :type set_atom_props: :class:`bool`

    :returns: global and per-residue LDDT scores as a tuple -
              first element is global LDDT score (None if *target* has no
              contacts) and second element a list of per-residue scores with
              length len(*model*.residues). None is assigned to residues that
              are not covered by target. If a residue is covered but has no
              contacts in *target*, 0.0 is assigned.
    :raises: :class:`RuntimeError` if *no_interchain* and *no_intrachain*
             are both set, :class:`NotImplementedError` if no
             *chain_mapping* is given but target or model have more than
             one chain, :class:`RuntimeError` if *chain_mapping* refers to
             chains that do not exist in target/model.
    """
    # cheap parameter check first, no need to process the model if the
    # flags are contradictory
    if no_interchain and no_intrachain:
        raise RuntimeError("no_interchain and no_intrachain flags are "
                           "mutually exclusive")

    if chain_mapping is None:
        if len(self.chain_names) > 1 or len(model.chains) > 1:
            raise NotImplementedError("Must provide chain mapping if "
                                      "target or model have > 1 chains.")
        chain_mapping = {model.chains[0].GetName(): self.chain_names[0]}
    else:
        # check whether chains specified in mapping exist
        for model_chain, target_chain in chain_mapping.items():
            if target_chain not in self.chain_names:
                raise RuntimeError(f"Target chain specified in "
                                   f"chain_mapping ({target_chain}) does "
                                   f"not exist. Target has chains: "
                                   f"{self.chain_names}")
            ch = model.FindChain(model_chain)
            if not ch.IsValid():
                raise RuntimeError(f"Model chain specified in "
                                   f"chain_mapping ({model_chain}) does "
                                   f"not exist. Model has chains: "
                                   f"{[c.GetName() for c in model.chains]}")

    # data objects defining model data - see _ProcessModel for rough
    # description
    pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
    res_indices, ref_res_indices, symmetries = \
    self._ProcessModel(model, chain_mapping,
                       residue_mapping = residue_mapping,
                       nirvana_dist = self.inclusion_radius + max(thresholds),
                       check_resnames = check_resnames)

    sym_ref_indices = None
    sym_ref_distances = None
    ref_indices = None
    ref_distances = None

    if interaction_data is None:
        # select the lazily computed reference data matching the
        # requested flavour of LDDT
        if no_interchain:
            sym_ref_indices = self.sym_ref_indices_sc
            sym_ref_distances = self.sym_ref_distances_sc
            ref_indices = self.ref_indices_sc
            ref_distances = self.ref_distances_sc
        elif no_intrachain:
            sym_ref_indices = self.sym_ref_indices_ic
            sym_ref_distances = self.sym_ref_distances_ic
            ref_indices = self.ref_indices_ic
            ref_distances = self.ref_distances_ic
        else:
            sym_ref_indices = self.sym_ref_indices
            sym_ref_distances = self.sym_ref_distances
            ref_indices = self.ref_indices
            ref_distances = self.ref_distances

        if add_mdl_contacts:
            ref_indices, ref_distances = \
            self._AddMdlContacts(model, res_atom_indices, res_atom_hashes,
                                 ref_indices, ref_distances,
                                 no_interchain, no_intrachain)
            # recompute symmetry related indices/distances
            sym_ref_indices, sym_ref_distances = \
            lDDTScorer._NonSymDistances(self.n_atoms, self.symmetric_atoms,
                                        ref_indices, ref_distances)
    else:
        sym_ref_indices, sym_ref_distances, ref_indices, ref_distances = \
        interaction_data

    self._ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices,
                            sym_ref_distances)

    # flat list of atom indices, ordered residue by residue
    atom_indices = list(itertools.chain.from_iterable(res_atom_indices))

    per_atom_exp = np.asarray([self._GetNExp(i, ref_indices)
        for i in atom_indices], dtype=np.int32)
    per_res_exp = np.asarray([self._GetNExp(res_ref_atom_indices[idx],
        ref_indices) for idx in range(len(res_indices))], dtype=np.int32)

    per_atom_conserved = self._EvalAtoms(pos, atom_indices, thresholds,
                                         ref_indices, ref_distances)
    per_res_conserved = np.zeros((len(res_atom_indices), len(thresholds)),
                                 dtype=np.int32)
    # per_atom_conserved follows the order of atom_indices, so the atoms
    # of each residue form a consecutive slice
    start_idx = 0
    for r_idx in range(len(res_atom_indices)):
        end_idx = start_idx + len(res_atom_indices[r_idx])
        per_res_conserved[r_idx] = np.sum(per_atom_conserved[start_idx:end_idx,:],
                                          axis=0)
        start_idx = end_idx

    n_thresh = len(thresholds)

    # do per-residue scores
    per_res_lDDT = [None] * model.GetResidueCount()
    for idx in range(len(res_indices)):
        n_exp = n_thresh * per_res_exp[idx]
        if n_exp > 0:
            score = np.sum(per_res_conserved[idx,:]) / n_exp
            per_res_lDDT[res_indices[idx]] = score
        else:
            per_res_lDDT[res_indices[idx]] = 0.0

    # do full model score
    n_distances = sum(len(x) for x in ref_indices)
    if penalize_extra_chains:
        n_distances += self._GetExtraModelChainPenalty(model, chain_mapping)

    lDDT_tot = int(n_thresh * n_distances)
    lDDT_cons = int(np.sum(per_res_conserved))
    lDDT = None
    if lDDT_tot > 0:
        lDDT = float(lDDT_cons) / lDDT_tot

    # set properties if necessary
    if local_lddt_prop:
        residues = model.residues
        for idx in res_indices:
            residues[idx].SetFloatProp(local_lddt_prop, per_res_lDDT[idx])

    if local_contact_prop:
        residues = model.residues
        exp_prop = local_contact_prop + "_exp"
        conserved_prop = local_contact_prop + "_cons"

        for i, r_idx in enumerate(res_indices):
            residues[r_idx].SetIntProp(exp_prop,
                                       n_thresh * int(per_res_exp[i]))
            residues[r_idx].SetIntProp(conserved_prop,
                                       int(np.sum(per_res_conserved[i,:])))

    if set_atom_props and (local_lddt_prop or local_contact_prop):
        # collect model atoms in the same order as atom_indices
        atom_list = list()
        residues = model.residues
        for i, indices in enumerate(res_atom_indices):
            r = residues[res_indices[i]]
            r_idx = ref_res_indices[i]
            res_start_idx = self.res_start_indices[r_idx]
            anames = self.compound_anames[self.compound_names[r_idx]]
            for a_i in indices:
                a = r.FindAtom(anames[a_i - res_start_idx])
                assert a.IsValid()
                atom_list.append(a)

        summed_per_atom_conserved = per_atom_conserved.sum(axis=1)
        if local_lddt_prop:
            # the only place where actually need to compute per-atom LDDT
            # scores, atoms without expected contacts get no property
            for a_idx, a in enumerate(atom_list):
                if per_atom_exp[a_idx] != 0:
                    tmp = summed_per_atom_conserved[a_idx] / per_atom_exp[a_idx]
                    tmp = tmp / n_thresh
                    a.SetFloatProp(local_lddt_prop, tmp)

        if local_contact_prop:
            conserved_prop = local_contact_prop + "_cons"
            exp_prop = local_contact_prop + "_exp"
            for a_idx, a in enumerate(atom_list):
                # do number of conserved contacts, plain Python int as in
                # the per-residue properties above
                a.SetIntProp(conserved_prop,
                             int(summed_per_atom_conserved[a_idx]))
                # do number of expected contacts
                a.SetIntProp(exp_prop, int(per_atom_exp[a_idx]) * n_thresh)

    if return_dist_test:
        return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \
               per_res_exp, per_res_conserved
    else:
        return lDDT, per_res_lDDT
736
737 def DRMSD(self, model, dist_cap = 5,
738 chain_mapping=None, no_interchain=False,
739 no_intrachain=False, residue_mapping=None,
740 check_resnames=True, add_mdl_contacts=False,
741 interaction_data=None):
742 """ DRMSD of *model* - globally and per-residue
743
744 Very similar to LDDT as we operate on distance differences for all
745 interatomic distances within the same inclusion radius as in LDDT.
746 DRMSD is the distance rmsd, i.e. the RMSD of distance differences.
747 Distance differences are capped at *dist_cap* which is also the default
748 value for missing distances.
749
750 :param model: Model to be scored - models are preferably scored upon
751 performing stereo-chemistry checks in order to punish for
752 non-sensical irregularities. This must be done separately
753 as a pre-processing step. Target contacts that are not
754 covered by *model* are considered not conserved, thus
755 increasing DRMSD score. This also includes missing model
756 chains or model chains for which no mapping is provided in
757 *chain_mapping*.
758 :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
759 :param dist_cap: Cap for distance differences.
760 :type dist_cap: :class:`float`
761 :param chain_mapping: Mapping of model chains (key) onto target chains
762 (value). This is required if target or model have
763 more than one chain.
764 :type chain_mapping: :class:`dict` with :class:`str` as keys/values
765 :param no_interchain: Whether to exclude interchain contacts
766 :type no_interchain: :class:`bool`
767 :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
768 consider interface related contacts)
769 :type no_intrachain: :class:`bool`
770 :param residue_mapping: By default, residue mapping is based on residue
771 numbers. That means, a model chain and the
772 respective target chain map to the same
773 underlying reference sequence (SEQRES).
774 Alternatively, you can specify one or
775 several alignment(s) between model and target
776 chains by providing a dictionary. key: Name
777 of chain in model (respective target chain is
778 extracted from *chain_mapping*),
779 value: Alignment with first sequence
780 corresponding to target chain and second
781 sequence to model chain. There is NO reference
782 sequence involved, so the two sequences MUST
783 exactly match the actual residues observed in
784 the respective target/model chains (ATOMSEQ).
785 :type residue_mapping: :class:`dict` with key: :class:`str`,
786 value: :class:`ost.seq.AlignmentHandle`
787 :param check_resnames: On by default. Enforces residue name matches
788 between mapped model and target residues.
789 :type check_resnames: :class:`bool`
790 :param add_mdl_contacts: Adds model contacts - Only using contacts that
791 are within a certain distance threshold in the
792 target does not penalize for added model
793 contacts. If set to True, this flag will also
794 consider target contacts that are within the
795 specified distance threshold in the model but
796 not necessarily in the target. No contact will
797 be added if the respective atom pair is not
798 resolved in the target.
799 :type add_mdl_contacts: :class:`bool`
800 :param interaction_data: Pro param - don't use
801 :type interaction_data: :class:`tuple`
802
803 :returns: global and per-residue DRMSD scores as a tuple -
804 first element is global DRMSD score (None if *target* has no
805 contacts) and second element a list of per-residue scores with
806 length len(*model*.residues). None is assigned to residues that
807 are not covered by target. If a residue is covered but has no
808 contacts in *target*, None is assigned.
809 """
810 if chain_mapping is None:
811 if len(self.chain_names) > 1 or len(model.chains) > 1:
812 raise NotImplementedError("Must provide chain mapping if "
813 "target or model have > 1 chains.")
814 chain_mapping = {model.chains[0].GetName(): self.chain_names[0]}
815 else:
816 # check whether chains specified in mapping exist
817 for model_chain, target_chain in chain_mapping.items():
818 if target_chain not in self.chain_names:
819 raise RuntimeError(f"Target chain specified in "
820 f"chain_mapping ({target_chain}) does "
821 f"not exist. Target has chains: "
822 f"{self.chain_names}")
823 ch = model.FindChain(model_chain)
824 if not ch.IsValid():
825 raise RuntimeError(f"Model chain specified in "
826 f"chain_mapping ({model_chain}) does "
827 f"not exist. Model has chains: "
828 f"{[c.GetName() for c in model.chains]}")
829
830 # data objects defining model data - see _ProcessModel for rough
831 # description
832 pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
833 res_indices, ref_res_indices, symmetries = \
834 self._ProcessModel(model, chain_mapping,
835 residue_mapping = residue_mapping,
836 nirvana_dist = self.inclusion_radius + dist_cap,
837 check_resnames = check_resnames)
838
839 if no_interchain and no_intrachain:
840 raise RuntimeError("no_interchain and no_intrachain flags are "
841 "mutually exclusive")
842
843 sym_ref_indices = None
844 sym_ref_distances = None
845 ref_indices = None
846 ref_distances = None
847
848 if interaction_data is None:
849 if no_interchain:
850 sym_ref_indices = self.sym_ref_indices_sc
851 sym_ref_distances = self.sym_ref_distances_sc
852 ref_indices = self.ref_indices_scref_indices_sc
853 ref_distances = self.ref_distances_scref_distances_sc
854 elif no_intrachain:
855 sym_ref_indices = self.sym_ref_indices_ic
856 sym_ref_distances = self.sym_ref_distances_ic
857 ref_indices = self.ref_indices_icref_indices_ic
858 ref_distances = self.ref_distances_icref_distances_ic
859 else:
860 sym_ref_indices = self.sym_ref_indices
861 sym_ref_distances = self.sym_ref_distances
862 ref_indices = self.ref_indicesref_indices
863 ref_distances = self.ref_distancesref_distances
864
865 if add_mdl_contacts:
866 ref_indices, ref_distances = \
867 self._AddMdlContacts(model, res_atom_indices, res_atom_hashes,
868 ref_indices, ref_distances,
869 no_interchain, no_intrachain)
870 # recompute symmetry related indices/distances
871 sym_ref_indices, sym_ref_distances = \
872 lDDTScorer._NonSymDistances(self.n_atoms, self.symmetric_atoms,
873 ref_indices, ref_distances)
874 else:
875 sym_ref_indices, sym_ref_distances, ref_indices, ref_distances = \
876 interaction_data
877
878 self._ResolveSymmetriesSSD(pos, dist_cap, symmetries, sym_ref_indices,
879 sym_ref_distances)
880
881 atom_indices = list(itertools.chain.from_iterable(res_atom_indices))
882
883 per_atom_exp = np.asarray([self._GetNExp(i, ref_indices)
884 for i in atom_indices], dtype=np.int32)
885 per_res_exp = np.asarray([self._GetNExp(res_ref_atom_indices[idx],
886 ref_indices) for idx in range(len(res_indices))], dtype=np.int32)
887 per_atom_ssd = self._EvalAtomsSSD(pos, atom_indices, dist_cap,
888 ref_indices, ref_distances)
889
890 # do per residue scores
891 start_idx = 0
892 per_res_drmsd = [None] * model.GetResidueCount()
893 for r_idx in range(len(res_atom_indices)):
894 end_idx = start_idx + len(res_atom_indices[r_idx])
895 n_tot = per_res_exp[r_idx]
896 if n_tot > 0:
897 ssd = np.sum(per_atom_ssd[start_idx:end_idx])
898 # add penalties from distances involving atoms that are not
899 # present in the model
900 n_missing = n_tot - np.sum(per_atom_exp[start_idx:end_idx])
901 ssd += n_missing*dist_cap*dist_cap
902 per_res_drmsd[res_indices[r_idx]] = np.sqrt(ssd/n_tot)
903 start_idx = end_idx
904
905 # do full model score
906 drmsd = None
907 n_tot = sum([len(x) for x in ref_indices])
908 if n_tot > 0:
909 ssd = np.sum(per_atom_ssd)
910 # add penalties from distances involving atoms that are not
911 # present in the model
912 n_missing = n_tot - np.sum(per_atom_exp)
913 ssd += (dist_cap*dist_cap*n_missing)
914 drmsd = np.sqrt(ssd/n_tot)
915
916 return drmsd, per_res_drmsd
917
def GetNChainContacts(self, target_chain, no_interchain=False):
    """Returns number of contacts expected for a certain chain in *target*

    :param target_chain: Chain in *target* for which you want the number
                         of expected contacts
    :type target_chain: :class:`str`
    :param no_interchain: Whether to exclude interchain contacts
    :type no_interchain: :class:`bool`
    :returns: Number of reference contacts summed over all atoms of the
              chain
    :raises: :class:`RuntimeError` if specified chain doesnt exist
    """
    if target_chain not in self.chain_names:
        raise RuntimeError(f"Specified chain name ({target_chain}) not in "
                           f"target")
    ch_idx = self.chain_names.index(target_chain)

    # atoms of a chain occupy the half-open index range [s, e); the last
    # chain ends at the total number of atoms
    s = self.chain_start_indices[ch_idx]
    if ch_idx + 1 < len(self.chain_names):
        e = self.chain_start_indices[ch_idx + 1]
    else:
        e = self.n_atoms

    # only the requested contact set is accessed
    if no_interchain:
        ref_indices = self.ref_indices_sc
    else:
        ref_indices = self.ref_indices
    return self._GetNExp(list(range(s, e)), ref_indices)
940
941 def _ProcessModel(self, model, chain_mapping, residue_mapping = None,
942 nirvana_dist = 100,
943 check_resnames = True):
944 """ Helper that generates data structures from model
945 """
946
947 # initialize positions with values far in nirvana. If a position is not
948 # set, it should be far away from any position in model.
949 max_pos = model.bounds.GetMax()
950 max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2]))
951 max_coordinate += 42 * nirvana_dist
952 pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate
953
954 # for each scored residue in model a list of indices describing the
955 # atoms from the reference that should be there
956 res_ref_atom_indices = list()
957
958 # for each scored residue in model a list of indices of atoms that are
959 # actually there
960 res_atom_indices = list()
961
962 # and the respective hash codes
963 # this is required if add_mdl_contacts is set to True
964 res_atom_hashes = list()
965
966 # indices of the scored residues
967 res_indices = list()
968
969 # respective residue indices in reference
970 ref_res_indices = list()
971
972 # Will contain one element per symmetry group
973 symmetries = list()
974
975 current_model_res_idx = -1
976 for ch in model.chains:
977 model_ch_name = ch.GetName()
978 if model_ch_name not in chain_mapping:
979 current_model_res_idx += len(ch.residues)
980 continue # additional model chain which is not mapped
981 target_ch_name = chain_mapping[model_ch_name]
982
983 rnums = self._GetChainRNums(ch, residue_mapping, model_ch_name,
984 target_ch_name)
985
986 for r, rnum in zip(ch.residues, rnums):
987 current_model_res_idx += 1
988 res_mapper_key = (target_ch_name, rnum)
989 if res_mapper_key not in self.res_mapper:
990 continue
991 r_idx = self.res_mapper[res_mapper_key]
992 if check_resnames and r.name != self.compound_names[r_idx]:
993 raise RuntimeError(
994 f"Residue name mismatch for {r}, "
995 f" expect {self.compound_names[r_idx]}"
996 )
997 res_start_idx = self.res_start_indices[r_idx]
998 rname = self.compound_names[r_idx]
999 anames = self.compound_anames[rname]
1000 atoms = [r.FindAtom(aname) for aname in anames]
1001 res_ref_atom_indices.append(
1002 list(range(res_start_idx, res_start_idx + len(anames)))
1003 )
1004 res_atom_indices.append(list())
1005 res_atom_hashes.append(list())
1006 res_indices.append(current_model_res_idx)
1007 ref_res_indices.append(r_idx)
1008 for a_idx, a in enumerate(atoms):
1009 if a.IsValid():
1010 p = a.GetPos()
1011 pos[res_start_idx + a_idx][0] = p[0]
1012 pos[res_start_idx + a_idx][1] = p[1]
1013 pos[res_start_idx + a_idx][2] = p[2]
1014 res_atom_indices[-1].append(res_start_idx + a_idx)
1015 res_atom_hashes[-1].append(a.handle.GetHashCode())
1016 if rname in self.compound_symmetric_atoms:
1017 sym_indices = list()
1018 for sym_tuple in self.compound_symmetric_atoms[rname]:
1019 a_one = atoms[sym_tuple[0]]
1020 a_two = atoms[sym_tuple[1]]
1021 if a_one.IsValid() and a_two.IsValid():
1022 sym_indices.append(
1023 (
1024 res_start_idx + sym_tuple[0],
1025 res_start_idx + sym_tuple[1],
1026 )
1027 )
1028 if len(sym_indices) > 0:
1029 symmetries.append(sym_indices)
1030
1031 return (pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes,
1032 res_indices, ref_res_indices, symmetries)
1033
1034
1035 def _GetExtraModelChainPenalty(self, model, chain_mapping):
1036 """Counts n distances in extra model chains to be added as penalty
1037 """
1038 penalty = 0
1039 for chain in model.chains:
1040 ch_name = chain.GetName()
1041 if ch_name not in chain_mapping:
1042 sm = self.symmetry_settings
1043 mdl_sel = model.Select(f"cname={mol.QueryQuoteName(ch_name)}")
1044 dummy_scorer = lDDTScorer(mdl_sel, self.compound_lib,
1045 symmetry_settings = sm,
1046 inclusion_radius = self.inclusion_radius,
1047 bb_only = self.bb_only)
1048 penalty += sum([len(x) for x in dummy_scorer.ref_indices])
1049 return penalty
1050
1051 def _GetChainRNums(self, ch, residue_mapping, model_ch_name,
1052 target_ch_name):
1053 """Map residues in model chain to target residues
1054
1055 There are two options: one is simply using residue numbers,
1056 the other is a custom mapping as given in *residue_mapping*
1057 """
1058 if residue_mapping and model_ch_name in residue_mapping:
1059 # extract residue numbers from target chain
1060 ch_idx = self.chain_names.index(target_ch_name)
1061 start_idx = self.chain_res_start_indices[ch_idx]
1062 if ch_idx < len(self.chain_names) - 1:
1063 end_idx = self.chain_res_start_indices[ch_idx+1]
1064 else:
1065 end_idx = len(self.compound_names)
1066 target_rnums = self.res_resnums[start_idx:end_idx]
1067 # get sequences from alignment and do consistency checks
1068 target_seq = residue_mapping[model_ch_name].GetSequence(0)
1069 model_seq = residue_mapping[model_ch_name].GetSequence(1)
1070 if len(target_seq.GetGaplessString()) != len(target_rnums):
1071 raise RuntimeError(f"Try to perform residue mapping for "
1072 f"model chain {model_ch_name} which "
1073 f"maps to {target_ch_name} in target. "
1074 f"Target sequence in alignment suggests "
1075 f"{len(target_seq.GetGaplessString())} "
1076 f"residues but {len(target_rnums)} are "
1077 f"expected.")
1078 if len(model_seq.GetGaplessString()) != len(ch.residues):
1079 raise RuntimeError(f"Try to perform residue mapping for "
1080 f"model chain {model_ch_name} which "
1081 f"maps to {target_ch_name} in target. "
1082 f"Model sequence in alignment suggests "
1083 f"{len(model_seq.GetGaplessString())} "
1084 f"residues but {len(ch.residues)} are "
1085 f"expected.")
1086 rnums = list()
1087 target_idx = -1
1088 for col in residue_mapping[model_ch_name]:
1089 if col[0] != '-':
1090 target_idx += 1
1091 # handle match
1092 if col[0] != '-' and col[1] != '-':
1093 rnums.append(target_rnums[target_idx])
1094 # insertion in model adds None to rnum
1095 if col[0] == '-' and col[1] != '-':
1096 rnums.append(None)
1097 else:
1098 rnums = [r.GetNumber() for r in ch.residues]
1099
1100 return rnums
1101
1102
1103 def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings,
1104 seqres_mapping, bb_only):
1105 """Sets target related lDDTScorer members defined in constructor
1106
1107 No distance related members - see _SetupDistances
1108 """
1109 residue_numbers = self._GetTargetResidueNumbers(self.target,
1110 seqres_mapping)
1111 current_idx = 0
1112 positions = list()
1113 for chain in self.target.chains:
1114 ch_name = chain.GetName()
1115 self.chain_names.append(ch_name)
1116 self.chain_start_indices.append(current_idx)
1117 self.chain_res_start_indices.append(len(self.compound_names))
1118 for r, rnum in zip(chain.residues, residue_numbers[ch_name]):
1119 if r.name not in self.compound_anames:
1120 # sets compound info in self.compound_anames and
1121 # self.compound_symmetric_atoms
1122 self._SetupCompound(r, compound_lib, custom_compounds,
1123 symmetry_settings, bb_only)
1124
1125 self.res_start_indices.append(current_idx)
1126 self.res_mapper[(ch_name, rnum)] = len(self.compound_names)
1127 self.compound_names.append(r.name)
1128 self.res_resnums.append(rnum)
1129
1130 atoms = [r.FindAtom(an) for an in self.compound_anames[r.name]]
1131 for a in atoms:
1132 if a.IsValid():
1133 self.atom_indices[a.handle.GetHashCode()] = current_idx
1134 p = a.GetPos()
1135 positions.append(np.asarray([p[0], p[1], p[2]],
1136 dtype=np.float32))
1137 else:
1138 positions.append(np.zeros(3, dtype=np.float32))
1139 current_idx += 1
1140
1141 if r.name in self.compound_symmetric_atoms:
1142 for sym_tuple in self.compound_symmetric_atoms[r.name]:
1143 for a_idx in sym_tuple:
1144 a = atoms[a_idx]
1145 if a.IsValid():
1146 hashcode = a.handle.GetHashCode()
1147 self.symmetric_atoms.add(
1148 self.atom_indices[hashcode]
1149 )
1150 self.positions = np.vstack(positions)
1151 self.n_atoms = current_idx
1152
1153 def _GetTargetResidueNumbers(self, target, seqres_mapping):
1154 """Returns residue numbers for each chain in target as dict
1155
1156 They're either directly extracted from the raw residue number
1157 from the structure or from user provided alignments
1158 """
1159 residue_numbers = dict()
1160 for ch in target.chains:
1161 ch_name = ch.GetName()
1162 rnums = list()
1163 if ch_name in seqres_mapping:
1164 seqres = seqres_mapping[ch_name].GetSequence(0).GetString()
1165 atomseq = seqres_mapping[ch_name].GetSequence(1).GetString()
1166 # SEQRES must not contain gaps
1167 if "-" in seqres:
1168 raise RuntimeError(
1169 "SEQRES in seqres_mapping must not " "contain gaps"
1170 )
1171 atomseq_from_chain = [r.one_letter_code for r in ch.residues]
1172 if atomseq.replace("-", "") != atomseq_from_chain:
1173 raise RuntimeError(
1174 "ATOMSEQ in seqres_mapping must match "
1175 "raw sequence extracted from chain "
1176 "residues"
1177 )
1178 rnum = 0
1179 for seqres_olc, atomseq_olc in zip(seqres, atomseq):
1180 if seqres_olc != "-":
1181 rnum += 1
1182 if atomseq_olc != "-":
1183 if seqres_olc != atomseq_olc:
1184 raise RuntimeError(
1185 f"Residue with number {rnum} in "
1186 f"chain {ch_name} has SEQRES "
1187 f"ATOMSEQ mismatch"
1188 )
1189 rnums.append(mol.ResNum(rnum))
1190 else:
1191 rnums = [r.GetNumber() for r in ch.residues]
1192 assert len(rnums) == len(ch.residues)
1193 residue_numbers[ch_name] = rnums
1194 return residue_numbers
1195
1196 def _SetupCompound(self, r, compound_lib, custom_compounds,
1197 symmetry_settings, bb_only):
1198 """fill self.compound_anames/self.compound_symmetric_atoms
1199 """
1200 if bb_only:
1201 # throw away compound_lib info
1202 if r.chem_class.IsPeptideLinking():
1203 self.compound_anames[r.name] = ["CA"]
1204 elif r.chem_class.IsNucleotideLinking():
1205 self.compound_anames[r.name] = ["C3'"]
1206 else:
1207 raise RuntimeError(f"Only support amino acids and nucleotides "
1208 f"if bb_only is True, failed with {str(r)}")
1209 self.compound_symmetric_atoms[r.name] = list()
1210 else:
1211 atom_names = list()
1212 symmetric_atoms = list()
1213 if custom_compounds is not None and r.GetName() in custom_compounds:
1214 atom_names = list(custom_compounds[r.GetName()].atom_names)
1215 else:
1216 compound = compound_lib.FindCompound(r.name)
1217 if compound is None:
1218 raise RuntimeError(f"no entry for {r} in compound_lib")
1219 for atom_spec in compound.GetAtomSpecs():
1220 if atom_spec.element not in ["H", "D"]:
1221 atom_names.append(atom_spec.name)
1222 if r.name in symmetry_settings.symmetric_compounds:
1223 for pair in symmetry_settings.symmetric_compounds[r.name]:
1224 try:
1225 a = atom_names.index(pair[0])
1226 b = atom_names.index(pair[1])
1227 except:
1228 msg = f"Could not find symmetric atoms "
1229 msg += f"({pair[0]}, {pair[1]}) for {r.name} "
1230 msg += f"as specified in SymmetrySettings in "
1231 msg += f"compound from component dictionary. "
1232 msg += f"Atoms in compound: {atom_names}"
1233 raise RuntimeError(msg)
1234 symmetric_atoms.append((a, b))
1235 self.compound_anames[r.name] = atom_names
1236 if len(symmetric_atoms) > 0:
1237 self.compound_symmetric_atoms[r.name] = symmetric_atoms
1238
def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes,
                    ref_indices, ref_distances, no_interchain,
                    no_intrachain):
    """Extends reference contacts with contacts observed in the model

    Contacts are searched in *model* among atoms that are present in both,
    model and target. Model contacts that are not yet in *ref_indices* are
    added, with the respective distance computed from the target positions
    (self.positions).

    The input lists are not modified: new outer lists are returned and
    extended entries are fresh arrays. Entries without added contacts are
    shared with the input.

    :returns: :class:`tuple` with extended ref_indices and ref_distances
    """
    # shallow copies - the input may be cached data of the scorer that
    # must not pick up contacts from a specific model
    ref_indices = list(ref_indices)
    ref_distances = list(ref_distances)

    # buildup an index map for mdl atoms that are also present in target
    in_target = np.zeros(self.n_atoms, dtype=bool)
    for i in self.atom_indices.values():
        in_target[i] = True
    mdl_atom_indices = dict()
    for at_indices, at_hashes in zip(res_atom_indices, res_atom_hashes):
        for i, h in zip(at_indices, at_hashes):
            if in_target[i]:
                mdl_atom_indices[h] = i

    # get contacts for mdl - the contacts are only from atom pairs that
    # are also present in target, as we only provide the respective
    # hashes in mdl_atom_indices
    mdl_ref_indices, mdl_ref_distances = \
    lDDTScorer._SetupDistances(model, self.n_atoms, mdl_atom_indices,
                               self.inclusion_radius)
    if no_interchain:
        mdl_ref_indices, mdl_ref_distances = \
        lDDTScorer._SetupDistancesSC(self.n_atoms,
                                     self.chain_start_indices,
                                     mdl_ref_indices,
                                     mdl_ref_distances)

    if no_intrachain:
        mdl_ref_indices, mdl_ref_distances = \
        lDDTScorer._SetupDistancesIC(self.n_atoms,
                                     self.chain_start_indices,
                                     mdl_ref_indices,
                                     mdl_ref_distances)

    # update ref_indices/ref_distances => add mdl contacts
    for i in range(self.n_atoms):
        # True for model contacts that are not yet a reference contact
        mask = np.isin(mdl_ref_indices[i], ref_indices[i],
                       assume_unique=True, invert=True)
        if np.sum(mask) > 0:
            added_mdl_indices = mdl_ref_indices[i][mask]
            ref_indices[i] = np.append(ref_indices[i],
                                       added_mdl_indices)

            # distances need to be recomputed from ref positions
            tmp = self.positions.take(added_mdl_indices, axis=0)
            np.subtract(tmp, self.positions[i][None, :], out=tmp)
            np.square(tmp, out=tmp)
            tmp = tmp.sum(axis=1)
            np.sqrt(tmp, out=tmp)  # distances against all relevant atoms
            ref_distances[i] = np.append(ref_distances[i], tmp)

    return (ref_indices, ref_distances)
1291
1292
1293
1294 @staticmethod
1295 def _SetupDistances(structure, n_atoms, atom_index_mapping,
1296 inclusion_radius):
1297
1298 """Compute distance related members of lDDTScorer
1299
1300 Brute force all vs all distance computation kills LDDT for large
1301 complexes. Instead of building some KD tree data structure, we make use
1302 of expected spatial proximity of atoms in the same chain. Distances are
1303 computed as follows:
1304
1305 - process each chain individually
1306 - perform crude collision detection
1307 - process potentially interacting chain pairs
1308 - concatenate distances from all processing steps
1309 """
1310 ref_indices = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)]
1311 ref_distances = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)]
1312
1313 indices = [list() for _ in range(n_atoms)]
1314 distances = [list() for _ in range(n_atoms)]
1315 per_chain_pos = list()
1316 per_chain_indices = list()
1317
1318 # Process individual chains
1319 for ch in structure.chains:
1320 pos_list = list()
1321 atom_indices = list()
1322 mask_start = list()
1323 mask_end = list()
1324 r_start_idx = 0
1325 for r_idx, r in enumerate(ch.residues):
1326 n_valid_atoms = 0
1327 for a in r.atoms:
1328 hash_code = a.handle.GetHashCode()
1329 if hash_code in atom_index_mapping:
1330 p = a.GetPos()
1331 pos_list.append(np.asarray([p[0], p[1], p[2]], dtype=np.float32))
1332 atom_indices.append(atom_index_mapping[hash_code])
1333 n_valid_atoms += 1
1334 mask_start.extend([r_start_idx] * n_valid_atoms)
1335 mask_end.extend([r_start_idx + n_valid_atoms] * n_valid_atoms)
1336 r_start_idx += n_valid_atoms
1337
1338 if len(pos_list) == 0:
1339 # nothing to do...
1340 continue
1341
1342 pos = np.vstack(pos_list)
1343 atom_indices = np.asarray(atom_indices, dtype=np.int32)
1344
1345 if atom_indices.shape[0] > 20000:
1346 dists = blockwise_cdist(pos, pos)
1347 else:
1348 dists = cdist(pos, pos)
1349
1350 # apply masks
1351 far_away = 2 * inclusion_radius
1352 for idx in range(atom_indices.shape[0]):
1353 dists[idx, range(mask_start[idx], mask_end[idx])] = far_away
1354
1355 # fish out and store close atoms within inclusion radius
1356 within_mask = dists < inclusion_radius
1357 for idx in range(atom_indices.shape[0]):
1358 indices_to_append = atom_indices[within_mask[idx,:]]
1359 if indices_to_append.shape[0] > 0:
1360 full_at_idx = atom_indices[idx]
1361 indices[full_at_idx].append(indices_to_append)
1362 distances[full_at_idx].append(dists[idx, within_mask[idx,:]])
1363
1364 dists = None
1365
1366 per_chain_pos.append(pos)
1367 per_chain_indices.append(atom_indices)
1368
1369 # perform crude collision detection
1370 min_pos = [p.min(0) for p in per_chain_pos]
1371 max_pos = [p.max(0) for p in per_chain_pos]
1372 chain_pairs = list()
1373 for idx_one in range(len(per_chain_pos)):
1374 for idx_two in range(idx_one + 1, len(per_chain_pos)):
1375 if np.max(min_pos[idx_one] - max_pos[idx_two]) > inclusion_radius:
1376 continue
1377 if np.max(min_pos[idx_two] - max_pos[idx_one]) > inclusion_radius:
1378 continue
1379 chain_pairs.append((idx_one, idx_two))
1380
1381 # process potentially interacting chains
1382 for pair in chain_pairs:
1383 if per_chain_pos[pair[0]].shape[0] > 20000 or per_chain_pos[pair[1]].shape[0] > 20000:
1384 dists = blockwise_cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]])
1385 else:
1386 dists = cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]])
1387 within = dists <= inclusion_radius
1388
1389 # process pair[0]
1390 tmp = within.sum(axis=1)
1391 for idx in range(tmp.shape[0]):
1392 if tmp[idx] > 0:
1393 # even though not being a strict requirement, we perform an
1394 # insertion here such that the indices for each atom will be
1395 # sorted after the hstack operation
1396 at_idx = per_chain_indices[pair[0]][idx]
1397 indices_to_insert = per_chain_indices[pair[1]][within[idx,:]]
1398 distances_to_insert = dists[idx, within[idx, :]]
1399 insertion_idx = len(indices[at_idx])
1400 for i in range(insertion_idx):
1401 if indices_to_insert[0] > indices[at_idx][i][0]:
1402 insertion_idx = i
1403 break
1404 indices[at_idx].insert(insertion_idx, indices_to_insert)
1405 distances[at_idx].insert(insertion_idx, distances_to_insert)
1406
1407 # process pair[1]
1408 tmp = within.sum(axis=0)
1409 for idx in range(tmp.shape[0]):
1410 if tmp[idx] > 0:
1411 # even though not being a strict requirement, we perform an
1412 # insertion here such that the indices for each atom will be
1413 # sorted after the hstack operation
1414 at_idx = per_chain_indices[pair[1]][idx]
1415 indices_to_insert = per_chain_indices[pair[0]][within[:, idx]]
1416 distances_to_insert = dists[within[:, idx], idx]
1417 insertion_idx = len(indices[at_idx])
1418 for i in range(insertion_idx):
1419 if indices_to_insert[0] > indices[at_idx][i][0]:
1420 insertion_idx = i
1421 break
1422 indices[at_idx].insert(insertion_idx, indices_to_insert)
1423 distances[at_idx].insert(insertion_idx, distances_to_insert)
1424
1425 dists = None
1426
1427 # concatenate distances from all processing steps
1428 for at_idx in range(n_atoms):
1429 if len(indices[at_idx]) > 0:
1430 ref_indices[at_idx] = np.hstack(indices[at_idx])
1431 ref_distances[at_idx] = np.hstack(distances[at_idx])
1432
1433 return (ref_indices, ref_distances)
1434
1435 @staticmethod
1436 def _SetupDistancesSC(n_atoms, chain_start_indices,
1437 ref_indices, ref_distances):
1438 """Select subset of contacts only covering intra-chain contacts
1439 """
1440 # init
1441 ref_indices_sc = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)]
1442 ref_distances_sc = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)]
1443
1444 n_chains = len(chain_start_indices)
1445 for ch_idx in range(n_chains):
1446 chain_s = chain_start_indices[ch_idx]
1447 chain_e = n_atoms
1448 if ch_idx + 1 < n_chains:
1449 chain_e = chain_start_indices[ch_idx+1]
1450 for i in range(chain_s, chain_e):
1451 if len(ref_indices[i]) > 0:
1452 intra_idx = np.where(np.logical_and(ref_indices[i]>=chain_s,
1453 ref_indices[i]<chain_e))[0]
1454 ref_indices_sc[i] = ref_indices[i][intra_idx]
1455 ref_distances_sc[i] = ref_distances[i][intra_idx]
1456
1457 return (ref_indices_sc, ref_distances_sc)
1458
1459 @staticmethod
1460 def _SetupDistancesIC(n_atoms, chain_start_indices,
1461 ref_indices, ref_distances):
1462 """Select subset of contacts only covering inter-chain contacts
1463 """
1464 # init
1465 ref_indices_ic = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)]
1466 ref_distances_ic = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)]
1467
1468 n_chains = len(chain_start_indices)
1469 for ch_idx in range(n_chains):
1470 chain_s = chain_start_indices[ch_idx]
1471 chain_e = n_atoms
1472 if ch_idx + 1 < n_chains:
1473 chain_e = chain_start_indices[ch_idx+1]
1474 for i in range(chain_s, chain_e):
1475 if len(ref_indices[i]) > 0:
1476 inter_idx = np.where(np.logical_or(ref_indices[i]<chain_s,
1477 ref_indices[i]>=chain_e))[0]
1478 ref_indices_ic[i] = ref_indices[i][inter_idx]
1479 ref_distances_ic[i] = ref_distances[i][inter_idx]
1480
1481 return (ref_indices_ic, ref_distances_ic)
1482
1483 @staticmethod
1484 def _NonSymDistances(n_atoms, symmetric_atoms, ref_indices, ref_distances):
1485 """Transfer indices/distances of non-symmetric atoms and return
1486 """
1487
1488 sym_ref_indices = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)]
1489 sym_ref_distances = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)]
1490
1491 for idx in symmetric_atoms:
1492 indices = list()
1493 distances = list()
1494 for i, d in zip(ref_indices[idx], ref_distances[idx]):
1495 if i not in symmetric_atoms:
1496 indices.append(i)
1497 distances.append(d)
1498 sym_ref_indices[idx] = indices
1499 sym_ref_distances[idx] = np.asarray(distances)
1500
1501 return (sym_ref_indices, sym_ref_distances)
1502
1503 def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances):
1504 """Computes number of distance differences within given thresholds
1505
1506 returns np.array with len(thresholds) elements
1507 """
1508 a_p = pos[atom_idx, :]
1509 tmp = pos.take(ref_indices[atom_idx], axis=0)
1510 np.subtract(tmp, a_p[None, :], out=tmp)
1511 np.square(tmp, out=tmp)
1512 tmp = tmp.sum(axis=1)
1513 np.sqrt(tmp, out=tmp) # distances against all relevant atoms
1514 np.subtract(ref_distances[atom_idx], tmp, out=tmp)
1515 np.absolute(tmp, out=tmp) # absolute dist diffs
1516 return np.asarray([(tmp <= thresh).sum() for thresh in thresholds],
1517 dtype=np.int32)
1518
1520 self, pos, atom_indices, thresholds, ref_indices, ref_distances
1521 ):
1522 """Calls _EvalAtom for several atoms and sums up the computed number
1523 of distance differences within given thresholds
1524
1525 returns numpy matrix of shape (n_atoms, len(threshold))
1526 """
1527 conserved = np.zeros((len(atom_indices), len(thresholds)),
1528 dtype=np.int32)
1529 for a_idx, a in enumerate(atom_indices):
1530 conserved[a_idx, :] = self._EvalAtom(pos, a, thresholds,
1531 ref_indices, ref_distances)
1532 return conserved
1533
1534 def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices,
1535 ref_distances):
1536 """Calls _EvalAtoms for a bunch of residues
1537
1538 residues are defined in *res_atom_indices* as lists of atom indices
1539 returns numpy matrix of shape (n_residues, len(thresholds)).
1540 """
1541 conserved = np.zeros((len(res_atom_indices), len(thresholds)),
1542 dtype=np.int32)
1543 for rai_idx, rai in enumerate(res_atom_indices):
1544 conserved[rai_idx,:] = np.sum(self._EvalAtoms(pos, rai, thresholds,
1545 ref_indices, ref_distances), axis=0)
1546 return conserved
1547
1549 if self.sequence_separation != 0:
1550 raise NotImplementedError("Congratulations! You're the first one "
1551 "requesting a non-default "
1552 "sequence_separation in the new and "
1553 "awesome LDDT implementation. A crate of "
1554 "beer for Gabriel and he'll implement "
1555 "it.")
1556
1557 def _GetNExp(self, atom_idx, ref_indices):
1558 """Returns number of close atoms around one or several atoms
1559 """
1560 if isinstance(atom_idx, int):
1561 return len(ref_indices[atom_idx])
1562 elif isinstance(atom_idx, list):
1563 return sum([len(ref_indices[idx]) for idx in atom_idx])
1564 else:
1565 raise RuntimeError("invalid input type")
1566
1567 def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices,
1568 sym_ref_distances):
1569 """Swaps symmetric positions in-place in order to maximize LDDT scores
1570 towards non-symmetric atoms.
1571 """
1572 for sym in symmetries:
1573
1574 atom_indices = list()
1575 for sym_tuple in sym:
1576 atom_indices += [sym_tuple[0], sym_tuple[1]]
1577 tot = self._GetNExp(atom_indices, sym_ref_indices)
1578
1579 if tot == 0:
1580 continue # nothing to do
1581
1582 # score as is
1583 sym_one_conserved = self._EvalAtoms(
1584 pos,
1585 atom_indices,
1586 thresholds,
1587 sym_ref_indices,
1588 sym_ref_distances,
1589 )
1590
1591 # switch positions and score again
1592 for pair in sym:
1593 pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
1594
1595 sym_two_conserved = self._EvalAtoms(
1596 pos,
1597 atom_indices,
1598 thresholds,
1599 sym_ref_indices,
1600 sym_ref_distances,
1601 )
1602
1603 sym_one_score = np.sum(sym_one_conserved) / (len(thresholds) * tot)
1604 sym_two_score = np.sum(sym_two_conserved) / (len(thresholds) * tot)
1605
1606 if sym_one_score >= sym_two_score:
1607 # switch back, initial positions were better or equal
1608 # for the equal case: we still switch back to reproduce the old
1609 # LDDT behaviour
1610 for pair in sym:
1611 pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
1612
1613 def _EvalAtomSSD(self, pos, atom_idx, dist_cap, ref_indices, ref_distances):
1614 """ Computes summed squared distances
1615
1616 distances are capped at dist_cap
1617 """
1618 a_p = pos[atom_idx, :]
1619 tmp = pos.take(ref_indices[atom_idx], axis=0)
1620 np.subtract(tmp, a_p[None, :], out=tmp)
1621 np.square(tmp, out=tmp)
1622 tmp = tmp.sum(axis=1)
1623 np.sqrt(tmp, out=tmp) # distances against all relevant atoms
1624 np.subtract(ref_distances[atom_idx], tmp, out=tmp) # distance difference
1625 np.square(tmp, out=tmp) # squared distance difference
1626 squared_dist_cap = dist_cap*dist_cap
1627 tmp[tmp > squared_dist_cap] = squared_dist_cap
1628 return tmp.sum()
1629
1631 self, pos, atom_indices, dist_cap, ref_indices, ref_distances
1632 ):
1633 """Calls _EvalAtomSSD for several atoms
1634 """
1635 return np.asarray([self._EvalAtomSSD(pos, a, dist_cap, ref_indices,
1636 ref_distances) for a in atom_indices],
1637 dtype=np.float32)
1638
1639 def _ResolveSymmetriesSSD(self, pos, dist_cap, symmetries, sym_ref_indices,
1640 sym_ref_distances):
1641 """Swaps symmetric positions in-place in order to maximize summed
1642 squared distances towards non-symmetric atoms.
1643 """
1644 for sym in symmetries:
1645
1646 atom_indices = list()
1647 for sym_tuple in sym:
1648 atom_indices += [sym_tuple[0], sym_tuple[1]]
1649 tot = self._GetNExp(atom_indices, sym_ref_indices)
1650
1651 if tot == 0:
1652 continue # nothing to do
1653
1654 # score as is
1655 sym_one_ssd = self._EvalAtomsSSD(
1656 pos,
1657 atom_indices,
1658 dist_cap,
1659 sym_ref_indices,
1660 sym_ref_distances,
1661 )
1662
1663 # switch positions and score again
1664 for pair in sym:
1665 pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
1666
1667 sym_two_ssd = self._EvalAtomsSSD(
1668 pos,
1669 atom_indices,
1670 dist_cap,
1671 sym_ref_indices,
1672 sym_ref_distances,
1673 )
1674
1675 sym_one_score = np.sum(sym_one_ssd)
1676 sym_two_score = np.sum(sym_two_ssd)
1677
1678 if sym_one_score < sym_two_score:
1679 # switch back, initial positions were better
1680 for pair in sym:
1681 pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
__init__(self, atom_names)
Definition lddt.py:50
AddSymmetricCompound(self, name, symmetric_atoms)
Definition lddt.py:85
DRMSD(self, model, dist_cap=5, chain_mapping=None, no_interchain=False, no_intrachain=False, residue_mapping=None, check_resnames=True, add_mdl_contacts=False, interaction_data=None)
Definition lddt.py:741
_AddMdlContacts(self, model, res_atom_indices, res_atom_hashes, ref_indices, ref_distances, no_interchain, no_intrachain)
Definition lddt.py:1241
_ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices, sym_ref_distances)
Definition lddt.py:1568
GetNChainContacts(self, target_chain, no_interchain=False)
Definition lddt.py:918
_ProcessModel(self, model, chain_mapping, residue_mapping=None, nirvana_dist=100, check_resnames=True)
Definition lddt.py:943
_EvalAtomSSD(self, pos, atom_idx, dist_cap, ref_indices, ref_distances)
Definition lddt.py:1613
_SetupEnv(self, compound_lib, custom_compounds, symmetry_settings, seqres_mapping, bb_only)
Definition lddt.py:1104
lDDT(self, model, thresholds=[0.5, 1.0, 2.0, 4.0], local_lddt_prop=None, local_contact_prop=None, chain_mapping=None, no_interchain=False, no_intrachain=False, penalize_extra_chains=False, residue_mapping=None, return_dist_test=False, check_resnames=True, add_mdl_contacts=False, interaction_data=None, set_atom_props=False)
Definition lddt.py:462
_EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances)
Definition lddt.py:1503
_GetChainRNums(self, ch, residue_mapping, model_ch_name, target_ch_name)
Definition lddt.py:1052
__init__(self, target, compound_lib=None, custom_compounds=None, inclusion_radius=15, sequence_separation=0, symmetry_settings=None, seqres_mapping=dict(), bb_only=False)
Definition lddt.py:233
_SetupDistances(structure, n_atoms, atom_index_mapping, inclusion_radius)
Definition lddt.py:1296
_GetExtraModelChainPenalty(self, model, chain_mapping)
Definition lddt.py:1035
_SetupDistancesIC(n_atoms, chain_start_indices, ref_indices, ref_distances)
Definition lddt.py:1461
_EvalAtoms(self, pos, atom_indices, thresholds, ref_indices, ref_distances)
Definition lddt.py:1521
_GetTargetResidueNumbers(self, target, seqres_mapping)
Definition lddt.py:1153
_EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices, ref_distances)
Definition lddt.py:1535
_NonSymDistances(n_atoms, symmetric_atoms, ref_indices, ref_distances)
Definition lddt.py:1484
_SetupCompound(self, r, compound_lib, custom_compounds, symmetry_settings, bb_only)
Definition lddt.py:1197
_ResolveSymmetriesSSD(self, pos, dist_cap, symmetries, sym_ref_indices, sym_ref_distances)
Definition lddt.py:1640
_GetNExp(self, atom_idx, ref_indices)
Definition lddt.py:1557
_EvalAtomsSSD(self, pos, atom_indices, dist_cap, ref_indices, ref_distances)
Definition lddt.py:1632
_SetupDistancesSC(n_atoms, chain_start_indices, ref_indices, ref_distances)
Definition lddt.py:1437
blockwise_cdist(A, B, block_size=1000)
Definition lddt.py:19
GetDefaultSymmetrySettings()
Definition lddt.py:103