9 from scipy.spatial.distance
import cdist
12 x2 = np.sum(p1**2, axis=1)
13 y2 = np.sum(p2**2, axis=1)
14 xy = np.matmul(p1, p2.T)
15 x2 = x2.reshape(-1, 1)
16 return np.sqrt(x2 - 2*xy + y2)
19 """ Defines atoms for custom compounds
21 lDDT requires the reference atoms of a compound which are typically
22 extracted from a :class:`ost.conop.CompoundLib`. This lightweight
23 container allows handling arbitrary compounds which are not
24 necessarily in the compound library.
26 :param atom_names: Names of atoms of custom compound
27 :type atom_names: :class:`list` of :class:`str`
34 """ Construct custom compound from residue
36 :param res: Residue from which reference atom names are extracted,
37 hydrogen/deuterium atoms are filtered out
38 :type res: :class:`ost.mol.ResidueView`/:class:`ost.mol.ResidueHandle`
39 :returns: :class:`CustomCompound`
41 at_names = [a.name
for a
in res.atoms
if a.element
not in [
"H",
"D"]]
42 if len(at_names) != len(set(at_names)):
43 raise RuntimeError(
"Duplicate atoms detected in CustomCompound")
48 """Container for symmetric compounds
50 lDDT considers symmetries and selects the one resulting in the highest
53 A symmetry is defined as a renaming operation on one or more atoms that
54 leads to a chemically equivalent residue. Example would be OD1 and OD2 in
55 ASP => renaming OD1 to OD2 and vice versa gives a chemically equivalent
58 Use :func:`AddSymmetricCompound` to define a symmetry which can then
59 directly be accessed through the *symmetric_compounds* member.
65 """Adds symmetry for compound with *name*
67 :param name: Name of compound with symmetry
68 :type name: :class:`str`
69 :param symmetric_atoms: Pairs of atom names that define renaming
70 operation, i.e. after applying all switches
71 defined in the tuples, the resulting residue
72 should be chemically equivalent. Atom names
73 must refer to the PDB component dictionary.
74 :type symmetric_atoms: :class:`list` of :class:`tuple`
76 for pair
in symmetric_atoms:
78 raise RuntimeError(
"Expect pairs when defining symmetries")
83 """Constructs and returns :class:`SymmetrySettings` object for natural amino
89 symmetry_settings.AddSymmetricCompound(
"ASP", [(
"OD1",
"OD2")])
92 symmetry_settings.AddSymmetricCompound(
"GLU", [(
"OE1",
"OE2")])
95 symmetry_settings.AddSymmetricCompound(
"LEU", [(
"CD1",
"CD2")])
98 symmetry_settings.AddSymmetricCompound(
"VAL", [(
"CG1",
"CG2")])
101 symmetry_settings.AddSymmetricCompound(
"ARG", [(
"NH1",
"NH2")])
104 symmetry_settings.AddSymmetricCompound(
105 "PHE", [(
"CD1",
"CD2"), (
"CE1",
"CE2")]
109 symmetry_settings.AddSymmetricCompound(
110 "TYR", [(
"CD1",
"CD2"), (
"CE1",
"CE2")]
114 nuc_names = [
"A",
"C",
"G",
"U",
"DA",
"DC",
"DG",
"DT"]
115 for nuc_name
in nuc_names:
116 symmetry_settings.AddSymmetricCompound(
117 nuc_name, [(
"OP1",
"OP2")]
120 return symmetry_settings
124 """lDDT scorer object for a specific target
126 Sets up everything to score models of that target. lDDT (local distance
127 difference test) is defined as fraction of pairwise distances which exhibit
128 a difference < threshold when considering target and model. In case of
129 multiple thresholds, the average is returned. See
131 V. Mariani, M. Biasini, A. Barbato, T. Schwede, lDDT : A local
132 superposition-free score for comparing protein structures and models using
133 distance difference tests, Bioinformatics, 2013
135 :param target: The target
136 :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
137 :param compound_lib: Compound library from which a compound for each residue
138 is extracted based on its name. Uses
139 :func:`ost.conop.GetDefaultLib` if not given, raises
140 if this returns no valid compound library. Atoms
141 defined in the compound are searched in the residue and
142 build the reference for scoring. If the residue has
143 atoms with names ["A", "B", "C"] but the corresponding
144 compound only has ["A", "B"], "A" and "B" are
145 considered for scoring. If the residue has atoms
146 ["A", "B"] but the compound has ["A", "B", "C"], "C" is
147 considered missing and does not influence scoring, even
148 if present in the model.
149 :param custom_compounds: Custom compounds defining reference atoms. If
150 given, *custom_compounds* take precedence over
152 :type custom_compounds: :class:`dict` with residue names (:class:`str`) as
153 key and :class:`CustomCompound` as value.
154 :type compound_lib: :class:`ost.conop.CompoundLib`
155 :param inclusion_radius: All pairwise distances < *inclusion_radius* are
156 considered for scoring
157 :type inclusion_radius: :class:`float`
158 :param sequence_separation: Only pairwise distances between atoms of
159 residues which are further apart than this
160 threshold are considered. Residue distance is
161 based on resnum. The default (0) considers all
162 pairwise distances except intra-residue
164 :type sequence_separation: :class:`int`
165 :param symmetry_settings: Define residues exhibiting internal symmetry, uses
166 :func:`GetDefaultSymmetrySettings` if not given.
167 :type symmetry_settings: :class:`SymmetrySettings`
168 :param seqres_mapping: Mapping of model residues at the scoring stage
169 happens with residue numbers defining their location
170 in a reference sequence (SEQRES) using one based
171 indexing. If the residue numbers in *target* don't
172 correspond to that SEQRES, you can specify the
173 mapping manually. You can provide a dictionary to
174 specify a reference sequence (SEQRES) for one or more
175 chain(s). Key: chain name, value: alignment
176 (seq1: SEQRES, seq2: sequence of residues in chain).
177 Example: The residues in a chain with name "A" have
178 sequence "YEAH" and residue numbers [42,43,44,45].
179 You can provide an alignment with seq1 "``HELLYEAH``"
180 and seq2 "``----YEAH``". "Y" gets assigned residue
181 number 5, "E" gets assigned 6 and so on no matter
182 what the original residue numbers were.
183 :type seqres_mapping: :class:`dict` (key: :class:`str`, value:
184 :class:`ost.seq.AlignmentHandle`)
185 :param bb_only: Only consider atoms with name "CA" in case of amino acids and
186 "C3'" for Nucleotides. this invalidates *compound_lib*.
187 Raises if any residue in *target* is not
188 `r.chem_class.IsPeptideLinking()` or
189 `r.chem_class.IsNucleotideLinking()`
190 :type bb_only: :class:`bool`
191 :raises: :class:`RuntimeError` if *target* contains compound which is not in
192 *compound_lib*, :class:`RuntimeError` if *symmetry_settings*
193 specifies symmetric atoms that are not present in the according
194 compound in *compound_lib*, :class:`RuntimeError` if
195 *seqres_mapping* is not provided and *target* contains residue
196 numbers with insertion codes or the residue numbers for each chain
197 are not monotonically increasing, :class:`RuntimeError` if
198 *seqres_mapping* is provided but an alignment is invalid
199 (seq1 contains gaps, mismatch in seq1/seq2, seq2 does not match
200 residues in corresponding chains).
206 custom_compounds=None,
208 sequence_separation=0,
209 symmetry_settings=None,
210 seqres_mapping=dict(),
217 if compound_lib
is None:
218 compound_lib = conop.GetDefaultLib()
219 if compound_lib
is None:
220 raise RuntimeError(
"No compound_lib given and conop.GetDefaultLib "
221 "returns no valid compound library")
224 if symmetry_settings
is None:
325 lDDTScorer._SetupDistances(self.
targettarget, self.
n_atomsn_atoms,
334 lDDTScorer._SetupDistances(self.
targettarget, self.
n_atomsn_atoms,
359 lDDTScorer._SetupDistancesSC(self.
n_atomsn_atoms,
369 lDDTScorer._SetupDistancesSC(self.
n_atomsn_atoms,
379 lDDTScorer._NonSymDistances(self.
n_atomsn_atoms,
389 lDDTScorer._NonSymDistances(self.
n_atomsn_atoms,
399 lDDTScorer._SetupDistancesIC(self.
n_atomsn_atoms,
409 lDDTScorer._SetupDistancesIC(self.
n_atomsn_atoms,
419 lDDTScorer._NonSymDistances(self.
n_atomsn_atoms,
429 lDDTScorer._NonSymDistances(self.
n_atomsn_atoms,
435 def lDDT(self, model, thresholds = [0.5, 1.0, 2.0, 4.0],
436 local_lddt_prop=None, local_contact_prop=None,
437 chain_mapping=None, no_interchain=False,
438 no_intrachain=False, penalize_extra_chains=False,
439 residue_mapping=None, return_dist_test=False,
440 check_resnames=True, add_mdl_contacts=False,
441 interaction_data=None):
442 """Computes lDDT of *model* - globally and per-residue
444 :param model: Model to be scored - models are preferably scored upon
445 performing stereo-chemistry checks in order to punish for
446 non-sensical irregularities. This must be done separately
447 as a pre-processing step. Target contacts that are not
448 covered by *model* are considered not conserved, thus
449 decreasing lDDT score. This also includes missing model
450 chains or model chains for which no mapping is provided in
452 :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
453 :param thresholds: Thresholds of distance differences to be considered
454 as correct - see docs in constructor for more info.
455 default: [0.5, 1.0, 2.0, 4.0]
456 :type thresholds: :class:`list` of :class:`float`
457 :param local_lddt_prop: If set, per-residue scores will be assigned as
458 generic float property of that name
459 :type local_lddt_prop: :class:`str`
460 :param local_contact_prop: If set, number of expected contacts as well
461 as number of conserved contacts will be
462 assigned as generic int property.
463 Expected contacts will be set as
464 <local_contact_prop>_exp, conserved contacts
465 as <local_contact_prop>_cons. Values
466 are summed over all thresholds.
467 :type local_contact_prop: :class:`str`
468 :param chain_mapping: Mapping of model chains (key) onto target chains
469 (value). This is required if target or model have
471 :type chain_mapping: :class:`dict` with :class:`str` as keys/values
472 :param no_interchain: Whether to exclude interchain contacts
473 :type no_interchain: :class:`bool`
474 :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
475 consider interface related contacts)
476 :type no_intrachain: :class:`bool`
477 :param penalize_extra_chains: Whether to include a fixed penalty for
478 additional chains in the model that are
479 not mapped to the target. ONLY AFFECTS
480 RETURNED GLOBAL SCORE. In detail: adds the
481 number of intra-chain contacts of each
482 extra chain to the expected contacts, thus
484 :type penalize_extra_chains: :class:`bool`
485 :param residue_mapping: By default, residue mapping is based on residue
486 numbers. That means, a model chain and the
487 respective target chain map to the same
488 underlying reference sequence (SEQRES).
489 Alternatively, you can specify one or
490 several alignment(s) between model and target
491 chains by providing a dictionary. key: Name
492 of chain in model (respective target chain is
493 extracted from *chain_mapping*),
494 value: Alignment with first sequence
495 corresponding to target chain and second
496 sequence to model chain. There is NO reference
497 sequence involved, so the two sequences MUST
498 exactly match the actual residues observed in
499 the respective target/model chains (ATOMSEQ).
500 :type residue_mapping: :class:`dict` with key: :class:`str`,
501 value: :class:`ost.seq.AlignmentHandle`
502 :param return_dist_test: Whether to additionally return the underlying
503 per-residue data for the distance difference
504 test. Adds five objects to the return tuple.
505 First: Number of total contacts summed over all
507 Second: Number of conserved contacts summed
509 Third: list with length of scored residues.
510 Contains indices referring to model.residues.
511 Fourth: numpy array of size
512 len(scored_residues) containing the number of
514 Fifth: numpy matrix of shape
515 (len(scored_residues), len(thresholds))
516 specifying how many for each threshold are
518 :param check_resnames: On by default. Enforces residue name matches
519 between mapped model and target residues.
520 :type check_resnames: :class:`bool`
521 :param add_mdl_contacts: Adds model contacts - Only using contacts that
522 are within a certain distance threshold in the
523 target does not penalize for added model
524 contacts. If set to True, this flag will also
525 consider target contacts that are within the
526 specified distance threshold in the model but
527 not necessarily in the target. No contact will
528 be added if the respective atom pair is not
529 resolved in the target.
530 :type add_mdl_contacts: :class:`bool`
531 :param interaction_data: Pro param - don't use
532 :type interaction_data: :class:`tuple`
534 :returns: global and per-residue lDDT scores as a tuple -
535 first element is global lDDT score (None if *target* has no
536 contacts) and second element a list of per-residue scores with
537 length len(*model*.residues). None is assigned to residues that
538 are not covered by target. If a residue is covered but has no
539 contacts in *target*, 0.0 is assigned.
541 if chain_mapping
is None:
542 if len(self.
chain_nameschain_names) > 1
or len(model.chains) > 1:
543 raise NotImplementedError(
"Must provide chain mapping if "
544 "target or model have > 1 chains.")
545 chain_mapping = {model.chains[0].GetName(): self.
chain_nameschain_names[0]}
548 for model_chain, target_chain
in chain_mapping.items():
549 if target_chain
not in self.
chain_nameschain_names:
550 raise RuntimeError(f
"Target chain specified in "
551 f
"chain_mapping ({target_chain}) does "
552 f
"not exist. Target has chains: "
553 f
"{self.chain_names}")
554 ch = model.FindChain(model_chain)
556 raise RuntimeError(f
"Model chain specified in "
557 f
"chain_mapping ({model_chain}) does "
558 f
"not exist. Model has chains: "
559 f
"{[c.GetName() for c in model.chains]}")
563 pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
564 res_indices, ref_res_indices, symmetries = \
566 residue_mapping = residue_mapping,
567 thresholds = thresholds,
568 check_resnames = check_resnames)
570 if no_interchain
and no_intrachain:
571 raise RuntimeError(
"no_interchain and no_intrachain flags are "
572 "mutually exclusive")
575 sym_ref_indices =
None
576 sym_ref_distances =
None
580 if interaction_data
is None:
598 ref_indices, ref_distances = \
599 self.
_AddMdlContacts_AddMdlContacts(model, res_atom_indices, res_atom_hashes,
600 ref_indices, ref_distances,
601 no_interchain, no_intrachain)
603 sym_ref_indices, sym_ref_distances = \
605 ref_indices, ref_distances)
607 sym_ref_indices, sym_ref_distances, ref_indices, ref_distances = \
610 self.
_ResolveSymmetries_ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices,
613 per_res_exp = np.asarray([self.
_GetNExp_GetNExp(res_ref_atom_indices[idx],
614 ref_indices)
for idx
in range(len(res_indices))], dtype=np.int32)
615 per_res_conserved = self.
_EvalResidues_EvalResidues(pos, thresholds,
617 ref_indices, ref_distances)
619 n_thresh = len(thresholds)
622 per_res_lDDT = [
None] * len(model.residues)
623 for idx
in range(len(res_indices)):
624 n_exp = n_thresh * per_res_exp[idx]
626 score = np.sum(per_res_conserved[idx,:]) / n_exp
627 per_res_lDDT[res_indices[idx]] = score
629 per_res_lDDT[res_indices[idx]] = 0.0
632 n_distances = sum([len(x)
for x
in ref_indices])
633 if penalize_extra_chains:
636 lDDT_tot = int(n_thresh * n_distances)
637 lDDT_cons = int(np.sum(per_res_conserved))
640 lDDT = float(lDDT_cons) / lDDT_tot
644 residues = model.residues
645 for idx
in res_indices:
646 residues[idx].SetFloatProp(local_lddt_prop, per_res_lDDT[idx])
648 if local_contact_prop:
649 residues = model.residues
650 exp_prop = local_contact_prop +
"_exp"
651 conserved_prop = local_contact_prop +
"_cons"
653 for i, r_idx
in enumerate(res_indices):
654 residues[r_idx].SetIntProp(exp_prop,
655 n_thresh * int(per_res_exp[i]))
656 residues[r_idx].SetIntProp(conserved_prop,
657 int(np.sum(per_res_conserved[i,:])))
660 return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \
661 per_res_exp, per_res_conserved
663 return lDDT, per_res_lDDT
666 """Returns number of contacts expected for a certain chain in *target*
668 :param target_chain: Chain in *target* for which you want the number
670 :type target_chain: :class:`str`
671 :param no_interchain: Whether to exclude interchain contacts
672 :type no_interchain: :class:`bool`
673 :raises: :class:`RuntimeError` if specified chain doesn't exist
675 if target_chain
not in self.
chain_nameschain_names:
676 raise RuntimeError(f
"Specified chain name ({target_chain}) not in "
678 ch_idx = self.
chain_nameschain_names.index(target_chain)
688 def _ProcessModel(self, model, chain_mapping, residue_mapping = None,
689 thresholds = [0.5, 1.0, 2.0, 4.0],
690 check_resnames = True):
691 """ Helper that generates data structures from model
696 max_pos = model.bounds.GetMax()
697 max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2]))
698 max_coordinate += 42 * max(thresholds)
699 pos = np.ones((self.
n_atomsn_atoms, 3), dtype=np.float32) * max_coordinate
703 res_ref_atom_indices = list()
707 res_atom_indices = list()
711 res_atom_hashes = list()
717 ref_res_indices = list()
722 current_model_res_idx = -1
723 for ch
in model.chains:
724 model_ch_name = ch.GetName()
725 if model_ch_name
not in chain_mapping:
726 current_model_res_idx += len(ch.residues)
728 target_ch_name = chain_mapping[model_ch_name]
730 rnums = self.
_GetChainRNums_GetChainRNums(ch, residue_mapping, model_ch_name,
733 for r, rnum
in zip(ch.residues, rnums):
734 current_model_res_idx += 1
735 res_mapper_key = (target_ch_name, rnum)
736 if res_mapper_key
not in self.
res_mapperres_mapper:
738 r_idx = self.
res_mapperres_mapper[res_mapper_key]
739 if check_resnames
and r.name != self.
compound_namescompound_names[r_idx]:
741 f
"Residue name mismatch for {r}, "
742 f
" expect {self.compound_names[r_idx]}"
747 atoms = [r.FindAtom(aname)
for aname
in anames]
748 res_ref_atom_indices.append(
749 list(range(res_start_idx, res_start_idx + len(anames)))
751 res_atom_indices.append(list())
752 res_atom_hashes.append(list())
753 res_indices.append(current_model_res_idx)
754 ref_res_indices.append(r_idx)
755 for a_idx, a
in enumerate(atoms):
758 pos[res_start_idx + a_idx][0] = p[0]
759 pos[res_start_idx + a_idx][1] = p[1]
760 pos[res_start_idx + a_idx][2] = p[2]
761 res_atom_indices[-1].append(res_start_idx + a_idx)
762 res_atom_hashes[-1].append(a.handle.GetHashCode())
766 a_one = atoms[sym_tuple[0]]
767 a_two = atoms[sym_tuple[1]]
768 if a_one.IsValid()
and a_two.IsValid():
771 res_start_idx + sym_tuple[0],
772 res_start_idx + sym_tuple[1],
775 if len(sym_indices) > 0:
776 symmetries.append(sym_indices)
778 return (pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes,
779 res_indices, ref_res_indices, symmetries)
782 def _GetExtraModelChainPenalty(self, model, chain_mapping):
783 """Counts n distances in extra model chains to be added as penalty
786 for chain
in model.chains:
787 ch_name = chain.GetName()
788 if ch_name
not in chain_mapping:
790 mdl_sel = model.Select(f
"cname={mol.QueryQuoteName(ch_name)}")
792 symmetry_settings = sm,
795 penalty += sum([len(x)
for x
in dummy_scorer.ref_indices])
798 def _GetChainRNums(self, ch, residue_mapping, model_ch_name,
800 """Map residues in model chain to target residues
802 There are two options: one is simply using residue numbers,
803 the other is a custom mapping as given in *residue_mapping*
805 if residue_mapping
and model_ch_name
in residue_mapping:
807 ch_idx = self.
chain_nameschain_names.index(target_ch_name)
813 target_rnums = self.
res_resnumsres_resnums[start_idx:end_idx]
815 target_seq = residue_mapping[model_ch_name].GetSequence(0)
816 model_seq = residue_mapping[model_ch_name].GetSequence(1)
817 if len(target_seq.GetGaplessString()) != len(target_rnums):
818 raise RuntimeError(f
"Try to perform residue mapping for "
819 f
"model chain {model_ch_name} which "
820 f
"maps to {target_ch_name} in target. "
821 f
"Target sequence in alignment suggests "
822 f
"{len(target_seq.GetGaplessString())} "
823 f
"residues but {len(target_rnums)} are "
825 if len(model_seq.GetGaplessString()) != len(ch.residues):
826 raise RuntimeError(f
"Try to perform residue mapping for "
827 f
"model chain {model_ch_name} which "
828 f
"maps to {target_ch_name} in target. "
829 f
"Model sequence in alignment suggests "
830 f
"{len(model_seq.GetGaplessString())} "
831 f
"residues but {len(ch.residues)} are "
835 for col
in residue_mapping[model_ch_name]:
839 if col[0] !=
'-' and col[1] !=
'-':
840 rnums.append(target_rnums[target_idx])
842 if col[0] ==
'-' and col[1] !=
'-':
845 rnums = [r.GetNumber()
for r
in ch.residues]
850 def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings,
851 seqres_mapping, bb_only):
852 """Sets target related lDDTScorer members defined in constructor
854 No distance related members - see _SetupDistances
860 for chain
in self.
targettarget.chains:
861 ch_name = chain.GetName()
865 for r, rnum
in zip(chain.residues, residue_numbers[ch_name]):
869 self.
_SetupCompound_SetupCompound(r, compound_lib, custom_compounds,
870 symmetry_settings, bb_only)
877 atoms = [r.FindAtom(an)
for an
in self.
compound_anamescompound_anames[r.name]]
880 self.
atom_indicesatom_indices[a.handle.GetHashCode()] = current_idx
882 positions.append(np.asarray([p[0], p[1], p[2]],
885 positions.append(np.zeros(3, dtype=np.float32))
890 for a_idx
in sym_tuple:
893 hashcode = a.handle.GetHashCode()
897 self.
positionspositions = np.vstack(positions)
898 self.
n_atomsn_atoms = current_idx
900 def _GetTargetResidueNumbers(self, target, seqres_mapping):
901 """Returns residue numbers for each chain in target as dict
903 They're either directly extracted from the raw residue number
904 from the structure or from user provided alignments
906 residue_numbers = dict()
907 for ch
in target.chains:
908 ch_name = ch.GetName()
910 if ch_name
in seqres_mapping:
911 seqres = seqres_mapping[ch_name].GetSequence(0).GetString()
912 atomseq = seqres_mapping[ch_name].GetSequence(1).GetString()
916 "SEQRES in seqres_mapping must not " "contain gaps"
918 atomseq_from_chain = [r.one_letter_code
for r
in ch.residues]
919 if atomseq.replace(
"-",
"") != atomseq_from_chain:
921 "ATOMSEQ in seqres_mapping must match "
922 "raw sequence extracted from chain "
926 for seqres_olc, atomseq_olc
in zip(seqres, atomseq):
927 if seqres_olc !=
"-":
929 if atomseq_olc !=
"-":
930 if seqres_olc != atomseq_olc:
932 f
"Residue with number {rnum} in "
933 f
"chain {ch_name} has SEQRES "
938 rnums = [r.GetNumber()
for r
in ch.residues]
939 assert len(rnums) == len(ch.residues)
940 residue_numbers[ch_name] = rnums
941 return residue_numbers
943 def _SetupCompound(self, r, compound_lib, custom_compounds,
944 symmetry_settings, bb_only):
945 """fill self.compound_anames/self.compound_symmetric_atoms
949 if r.chem_class.IsPeptideLinking():
951 elif r.chem_class.IsNucleotideLinking():
954 raise RuntimeError(f
"Only support amino acids and nucleotides "
955 f
"if bb_only is True, failed with {str(r)}")
959 symmetric_atoms = list()
960 if custom_compounds
is not None and r.GetName()
in custom_compounds:
961 atom_names = list(custom_compounds[r.GetName()].atom_names)
963 compound = compound_lib.FindCompound(r.name)
965 raise RuntimeError(f
"no entry for {r} in compound_lib")
966 for atom_spec
in compound.GetAtomSpecs():
967 if atom_spec.element
not in [
"H",
"D"]:
968 atom_names.append(atom_spec.name)
969 if r.name
in symmetry_settings.symmetric_compounds:
970 for pair
in symmetry_settings.symmetric_compounds[r.name]:
972 a = atom_names.index(pair[0])
973 b = atom_names.index(pair[1])
975 msg = f
"Could not find symmetric atoms "
976 msg += f
"({pair[0]}, {pair[1]}) for {r.name} "
977 msg += f
"as specified in SymmetrySettings in "
978 msg += f
"compound from component dictionary. "
979 msg += f
"Atoms in compound: {atom_names}"
980 raise RuntimeError(msg)
981 symmetric_atoms.append((a, b))
983 if len(symmetric_atoms) > 0:
986 def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes,
987 ref_indices, ref_distances, no_interchain,
991 in_target = np.zeros(self.
n_atomsn_atoms, dtype=bool)
994 mdl_atom_indices = dict()
995 for at_indices, at_hashes
in zip(res_atom_indices, res_atom_hashes):
996 for i, h
in zip(at_indices, at_hashes):
998 mdl_atom_indices[h] = i
1003 mdl_ref_indices, mdl_ref_distances = \
1004 lDDTScorer._SetupDistances(model, self.
n_atomsn_atoms, mdl_atom_indices,
1007 mdl_ref_indices, mdl_ref_distances = \
1008 lDDTScorer._SetupDistancesSC(self.
n_atomsn_atoms,
1014 mdl_ref_indices, mdl_ref_distances = \
1015 lDDTScorer._SetupDistancesIC(self.
n_atomsn_atoms,
1021 for i
in range(self.
n_atomsn_atoms):
1022 mask = np.isin(mdl_ref_indices[i], ref_indices[i],
1023 assume_unique=
True, invert=
True)
1024 if np.sum(mask) > 0:
1025 added_mdl_indices = mdl_ref_indices[i][mask]
1026 ref_indices[i] = np.append(ref_indices[i],
1030 tmp = self.
positionspositions.take(added_mdl_indices, axis=0)
1031 np.subtract(tmp, self.
positionspositions[i][
None, :], out=tmp)
1032 np.square(tmp, out=tmp)
1033 tmp = tmp.sum(axis=1)
1034 np.sqrt(tmp, out=tmp)
1035 ref_distances[i] = np.append(ref_distances[i], tmp)
1037 return (ref_indices, ref_distances)
1042 def _SetupDistances(structure, n_atoms, atom_index_mapping,
1045 """Compute distance related members of lDDTScorer
1047 Brute force all vs all distance computation kills lDDT for large
1048 complexes. Instead of building some KD tree data structure, we make use
1049 of expected spatial proximity of atoms in the same chain. Distances are
1050 computed as follows:
1052 - process each chain individually
1053 - perform crude collision detection
1054 - process potentially interacting chain pairs
1055 - concatenate distances from all processing steps
1057 ref_indices = [np.asarray([], dtype=np.int64)
for idx
in range(n_atoms)]
1058 ref_distances = [np.asarray([], dtype=np.float64)
for idx
in range(n_atoms)]
1060 indices = [list()
for _
in range(n_atoms)]
1061 distances = [list()
for _
in range(n_atoms)]
1062 per_chain_pos = list()
1063 per_chain_indices = list()
1066 for ch
in structure.chains:
1068 atom_indices = list()
1072 for r_idx, r
in enumerate(ch.residues):
1075 hash_code = a.handle.GetHashCode()
1076 if hash_code
in atom_index_mapping:
1078 pos_list.append(np.asarray([p[0], p[1], p[2]]))
1079 atom_indices.append(atom_index_mapping[hash_code])
1081 mask_start.extend([r_start_idx] * n_valid_atoms)
1082 mask_end.extend([r_start_idx + n_valid_atoms] * n_valid_atoms)
1083 r_start_idx += n_valid_atoms
1085 if len(pos_list) == 0:
1089 pos = np.vstack(pos_list)
1090 atom_indices = np.asarray(atom_indices)
1091 dists =
cdist(pos, pos)
1094 far_away = 2 * inclusion_radius
1095 for idx
in range(atom_indices.shape[0]):
1096 dists[idx, range(mask_start[idx], mask_end[idx])] = far_away
1099 within_mask = dists < inclusion_radius
1100 for idx
in range(atom_indices.shape[0]):
1101 indices_to_append = atom_indices[within_mask[idx,:]]
1102 if indices_to_append.shape[0] > 0:
1103 full_at_idx = atom_indices[idx]
1104 indices[full_at_idx].append(indices_to_append)
1105 distances[full_at_idx].append(dists[idx, within_mask[idx,:]])
1107 per_chain_pos.append(pos)
1108 per_chain_indices.append(atom_indices)
1111 min_pos = [p.min(0)
for p
in per_chain_pos]
1112 max_pos = [p.max(0)
for p
in per_chain_pos]
1113 chain_pairs = list()
1114 for idx_one
in range(len(per_chain_pos)):
1115 for idx_two
in range(idx_one + 1, len(per_chain_pos)):
1116 if np.max(min_pos[idx_one] - max_pos[idx_two]) > inclusion_radius:
1118 if np.max(min_pos[idx_two] - max_pos[idx_one]) > inclusion_radius:
1120 chain_pairs.append((idx_one, idx_two))
1123 for pair
in chain_pairs:
1124 dists =
cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]])
1125 within = dists <= inclusion_radius
1128 tmp = within.sum(axis=1)
1129 for idx
in range(tmp.shape[0]):
1134 at_idx = per_chain_indices[pair[0]][idx]
1135 indices_to_insert = per_chain_indices[pair[1]][within[idx,:]]
1136 distances_to_insert = dists[idx, within[idx, :]]
1137 insertion_idx = len(indices[at_idx])
1138 for i
in range(insertion_idx):
1139 if indices_to_insert[0] > indices[at_idx][i][0]:
1142 indices[at_idx].insert(insertion_idx, indices_to_insert)
1143 distances[at_idx].insert(insertion_idx, distances_to_insert)
1146 tmp = within.sum(axis=0)
1147 for idx
in range(tmp.shape[0]):
1152 at_idx = per_chain_indices[pair[1]][idx]
1153 indices_to_insert = per_chain_indices[pair[0]][within[:, idx]]
1154 distances_to_insert = dists[within[:, idx], idx]
1155 insertion_idx = len(indices[at_idx])
1156 for i
in range(insertion_idx):
1157 if indices_to_insert[0] > indices[at_idx][i][0]:
1160 indices[at_idx].insert(insertion_idx, indices_to_insert)
1161 distances[at_idx].insert(insertion_idx, distances_to_insert)
1164 for at_idx
in range(n_atoms):
1165 if len(indices[at_idx]) > 0:
1166 ref_indices[at_idx] = np.hstack(indices[at_idx])
1167 ref_distances[at_idx] = np.hstack(distances[at_idx])
1169 return (ref_indices, ref_distances)
1172 def _SetupDistancesSC(n_atoms, chain_start_indices,
1173 ref_indices, ref_distances):
1174 """Select subset of contacts only covering intra-chain contacts
1177 ref_indices_sc = [np.asarray([], dtype=np.int64)
for idx
in range(n_atoms)]
1178 ref_distances_sc = [np.asarray([], dtype=np.float64)
for idx
in range(n_atoms)]
1180 n_chains = len(chain_start_indices)
1181 for ch_idx
in range(n_chains):
1182 chain_s = chain_start_indices[ch_idx]
1184 if ch_idx + 1 < n_chains:
1185 chain_e = chain_start_indices[ch_idx+1]
1186 for i
in range(chain_s, chain_e):
1187 if len(ref_indices[i]) > 0:
1188 intra_idx = np.where(np.logical_and(ref_indices[i]>=chain_s,
1189 ref_indices[i]<chain_e))[0]
1190 ref_indices_sc[i] = ref_indices[i][intra_idx]
1191 ref_distances_sc[i] = ref_distances[i][intra_idx]
1193 return (ref_indices_sc, ref_distances_sc)
1196 def _SetupDistancesIC(n_atoms, chain_start_indices,
1197 ref_indices, ref_distances):
1198 """Select subset of contacts only covering inter-chain contacts
1201 ref_indices_ic = [np.asarray([], dtype=np.int64)
for idx
in range(n_atoms)]
1202 ref_distances_ic = [np.asarray([], dtype=np.float64)
for idx
in range(n_atoms)]
1204 n_chains = len(chain_start_indices)
1205 for ch_idx
in range(n_chains):
1206 chain_s = chain_start_indices[ch_idx]
1208 if ch_idx + 1 < n_chains:
1209 chain_e = chain_start_indices[ch_idx+1]
1210 for i
in range(chain_s, chain_e):
1211 if len(ref_indices[i]) > 0:
1212 inter_idx = np.where(np.logical_or(ref_indices[i]<chain_s,
1213 ref_indices[i]>=chain_e))[0]
1214 ref_indices_ic[i] = ref_indices[i][inter_idx]
1215 ref_distances_ic[i] = ref_distances[i][inter_idx]
1217 return (ref_indices_ic, ref_distances_ic)
1220 def _NonSymDistances(n_atoms, symmetric_atoms, ref_indices, ref_distances):
1221 """Transfer indices/distances of non-symmetric atoms and return
1224 sym_ref_indices = [np.asarray([], dtype=np.int64)
for idx
in range(n_atoms)]
1225 sym_ref_distances = [np.asarray([], dtype=np.float64)
for idx
in range(n_atoms)]
1227 for idx
in symmetric_atoms:
1230 for i, d
in zip(ref_indices[idx], ref_distances[idx]):
1231 if i
not in symmetric_atoms:
1234 sym_ref_indices[idx] = indices
1235 sym_ref_distances[idx] = np.asarray(distances)
1237 return (sym_ref_indices, sym_ref_distances)
1239 def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances):
1240 """Computes number of distance differences within given thresholds
1242 returns np.array with len(thresholds) elements
1244 a_p = pos[atom_idx, :]
1245 tmp = pos.take(ref_indices[atom_idx], axis=0)
1246 np.subtract(tmp, a_p[
None, :], out=tmp)
1247 np.square(tmp, out=tmp)
1248 tmp = tmp.sum(axis=1)
1249 np.sqrt(tmp, out=tmp)
1250 np.subtract(ref_distances[atom_idx], tmp, out=tmp)
1251 np.absolute(tmp, out=tmp)
1252 return np.asarray([(tmp <= thresh).sum()
for thresh
in thresholds],
1256 self, pos, atom_indices, thresholds, ref_indices, ref_distances
1258 """Calls _EvalAtom for several atoms and sums up the computed number
1259 of distance differences within given thresholds
1261 returns numpy matrix of shape (n_atoms, len(threshold))
1263 conserved = np.zeros((len(atom_indices), len(thresholds)),
1265 for a_idx, a
in enumerate(atom_indices):
1266 conserved[a_idx, :] = self.
_EvalAtom_EvalAtom(pos, a, thresholds,
1267 ref_indices, ref_distances)
1270 def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices,
1272 """Calls _EvalAtoms for a bunch of residues
1274 residues are defined in *res_atom_indices* as lists of atom indices
1275 returns numpy matrix of shape (n_residues, len(thresholds)).
1277 conserved = np.zeros((len(res_atom_indices), len(thresholds)),
1279 for rai_idx, rai
in enumerate(res_atom_indices):
1280 conserved[rai_idx,:] = np.sum(self.
_EvalAtoms_EvalAtoms(pos, rai, thresholds,
1281 ref_indices, ref_distances), axis=0)
1284 def _ProcessSequenceSeparation(self):
1286 raise NotImplementedError(
"Congratulations! You're the first one "
1287 "requesting a non-default "
1288 "sequence_separation in the new and "
1289 "awesome lDDT implementation. A crate of "
1290 "beer for Gabriel and he'll implement "
1293 def _GetNExp(self, atom_idx, ref_indices):
1294 """Returns number of close atoms around one or several atoms
1296 if isinstance(atom_idx, int):
1297 return len(ref_indices[atom_idx])
1298 elif isinstance(atom_idx, list):
1299 return sum([len(ref_indices[idx])
for idx
in atom_idx])
1301 raise RuntimeError(
"invalid input type")
1303 def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices,
1305 """Swaps symmetric positions in-place in order to maximize lDDT scores
1306 towards non-symmetric atoms.
1308 for sym
in symmetries:
1310 atom_indices = list()
1311 for sym_tuple
in sym:
1312 atom_indices += [sym_tuple[0], sym_tuple[1]]
1313 tot = self.
_GetNExp_GetNExp(atom_indices, sym_ref_indices)
1319 sym_one_conserved = self.
_EvalAtoms_EvalAtoms(
1329 pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
1331 sym_two_conserved = self.
_EvalAtoms_EvalAtoms(
1339 sym_one_score = np.sum(sym_one_conserved) / (len(thresholds) * tot)
1340 sym_two_score = np.sum(sym_two_conserved) / (len(thresholds) * tot)
1342 if sym_one_score >= sym_two_score:
1347 pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
def __init__(self, atom_names)
def AddSymmetricCompound(self, name, symmetric_atoms)
def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices, ref_distances)
def _SetupCompound(self, r, compound_lib, custom_compounds, symmetry_settings, bb_only)
def _ProcessModel(self, model, chain_mapping, residue_mapping=None, thresholds=[0.5, 1.0, 2.0, 4.0], check_resnames=True)
def sym_ref_indices(self)
def lDDT(self, model, thresholds=[0.5, 1.0, 2.0, 4.0], local_lddt_prop=None, local_contact_prop=None, chain_mapping=None, no_interchain=False, no_intrachain=False, penalize_extra_chains=False, residue_mapping=None, return_dist_test=False, check_resnames=True, add_mdl_contacts=False, interaction_data=None)
def _GetChainRNums(self, ch, residue_mapping, model_ch_name, target_ch_name)
def _ProcessSequenceSeparation(self)
def sym_ref_distances(self)
def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices, sym_ref_distances)
def ref_distances_ic(self)
def _GetTargetResidueNumbers(self, target, seqres_mapping)
def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances)
def sym_ref_distances_ic(self)
def sym_ref_distances_sc(self)
def sym_ref_indices_ic(self)
def GetNChainContacts(self, target_chain, no_interchain=False)
def sym_ref_indices_sc(self)
def ref_distances_sc(self)
def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes, ref_indices, ref_distances, no_interchain, no_intrachain)
def _GetNExp(self, atom_idx, ref_indices)
def __init__(self, target, compound_lib=None, custom_compounds=None, inclusion_radius=15, sequence_separation=0, symmetry_settings=None, seqres_mapping=dict(), bb_only=False)
def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings, seqres_mapping, bb_only)
def _GetExtraModelChainPenalty(self, model, chain_mapping)
def _EvalAtoms(self, pos, atom_indices, thresholds, ref_indices, ref_distances)
def GetDefaultSymmetrySettings()