10 from scipy.spatial.distance
import cdist
13 x2 = np.sum(p1**2, axis=1)
14 y2 = np.sum(p2**2, axis=1)
15 xy = np.matmul(p1, p2.T)
16 x2 = x2.reshape(-1, 1)
17 return np.sqrt(x2 - 2*xy + y2)
20 """ Defines atoms for custom compounds
22 lDDT requires the reference atoms of a compound which are typically
23 extracted from a :class:`ost.conop.CompoundLib`. This lightweight
24 container allows handling arbitrary compounds which are not
25 necessarily in the compound library.
27 :param atom_names: Names of atoms of custom compound
28 :type atom_names: :class:`list` of :class:`str`
35 """ Construct custom compound from residue
37 :param res: Residue from which reference atom names are extracted,
38 hydrogen/deuterium atoms are filtered out
39 :type res: :class:`ost.mol.ResidueView`/:class:`ost.mol.ResidueHandle`
40 :returns: :class:`CustomCompound`
42 at_names = [a.name
for a
in res.atoms
if a.element
not in [
"H",
"D"]]
43 if len(at_names) != len(set(at_names)):
44 raise RuntimeError(
"Duplicate atoms detected in CustomCompound")
49 """Container for symmetric compounds
51 lDDT considers symmetries and selects the one resulting in the highest
54 A symmetry is defined as a renaming operation on one or more atoms that
55 leads to a chemically equivalent residue. An example would be OD1 and OD2 in
56 ASP => renaming OD1 to OD2 and vice versa gives a chemically equivalent
59 Use :func:`AddSymmetricCompound` to define a symmetry which can then
60 directly be accessed through the *symmetric_compounds* member.
66 """Adds symmetry for compound with *name*
68 :param name: Name of compound with symmetry
69 :type name: :class:`str`
70 :param symmetric_atoms: Pairs of atom names that define renaming
71 operation, i.e. after applying all switches
72 defined in the tuples, the resulting residue
73 should be chemically equivalent. Atom names
74 must refer to the PDB component dictionary.
75 :type symmetric_atoms: :class:`list` of :class:`tuple`
77 for pair
in symmetric_atoms:
79 raise RuntimeError(
"Expect pairs when defining symmetries")
84 """Constructs and returns :class:`SymmetrySettings` object for natural amino
90 symmetry_settings.AddSymmetricCompound(
"ASP", [(
"OD1",
"OD2")])
93 symmetry_settings.AddSymmetricCompound(
"GLU", [(
"OE1",
"OE2")])
96 symmetry_settings.AddSymmetricCompound(
"LEU", [(
"CD1",
"CD2")])
99 symmetry_settings.AddSymmetricCompound(
"VAL", [(
"CG1",
"CG2")])
102 symmetry_settings.AddSymmetricCompound(
"ARG", [(
"NH1",
"NH2")])
105 symmetry_settings.AddSymmetricCompound(
106 "PHE", [(
"CD1",
"CD2"), (
"CE1",
"CE2")]
110 symmetry_settings.AddSymmetricCompound(
111 "TYR", [(
"CD1",
"CD2"), (
"CE1",
"CE2")]
115 nuc_names = [
"A",
"C",
"G",
"U",
"DA",
"DC",
"DG",
"DT"]
116 for nuc_name
in nuc_names:
117 symmetry_settings.AddSymmetricCompound(
118 nuc_name, [(
"OP1",
"OP2")]
121 return symmetry_settings
125 """lDDT scorer object for a specific target
127 Sets up everything to score models of that target. lDDT (local distance
128 difference test) is defined as fraction of pairwise distances which exhibit
129 a difference < threshold when considering target and model. In case of
130 multiple thresholds, the average is returned. See
132 V. Mariani, M. Biasini, A. Barbato, T. Schwede, lDDT : A local
133 superposition-free score for comparing protein structures and models using
134 distance difference tests, Bioinformatics, 2013
136 :param target: The target
137 :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
138 :param compound_lib: Compound library from which a compound for each residue
139 is extracted based on its name. Uses
140 :func:`ost.conop.GetDefaultLib` if not given, raises
141 if this returns no valid compound library. Atoms
142 defined in the compound are searched in the residue and
143 build the reference for scoring. If the residue has
144 atoms with names ["A", "B", "C"] but the corresponding
145 compound only has ["A", "B"], "A" and "B" are
146 considered for scoring. If the residue has atoms
147 ["A", "B"] but the compound has ["A", "B", "C"], "C" is
148 considered missing and does not influence scoring, even
149 if present in the model.
150 :param custom_compounds: Custom compounds defining reference atoms. If
151 given, *custom_compounds* take precedence over
153 :type custom_compounds: :class:`dict` with residue names (:class:`str`) as
154 key and :class:`CustomCompound` as value.
155 :type compound_lib: :class:`ost.conop.CompoundLib`
156 :param inclusion_radius: All pairwise distances < *inclusion_radius* are
157 considered for scoring
158 :type inclusion_radius: :class:`float`
159 :param sequence_separation: Only pairwise distances between atoms of
160 residues which are further apart than this
161 threshold are considered. Residue distance is
162 based on resnum. The default (0) considers all
163 pairwise distances except intra-residue
165 :type sequence_separation: :class:`int`
166 :param symmetry_settings: Define residues exhibiting internal symmetry, uses
167 :func:`GetDefaultSymmetrySettings` if not given.
168 :type symmetry_settings: :class:`SymmetrySettings`
169 :param seqres_mapping: Mapping of model residues at the scoring stage
170 happens with residue numbers defining their location
171 in a reference sequence (SEQRES) using one-based
172 indexing. If the residue numbers in *target* don't
173 correspond to that SEQRES, you can specify the
174 mapping manually. You can provide a dictionary to
175 specify a reference sequence (SEQRES) for one or more
176 chain(s). Key: chain name, value: alignment
177 (seq1: SEQRES, seq2: sequence of residues in chain).
178 Example: The residues in a chain with name "A" have
179 sequence "YEAH" and residue numbers [42,43,44,45].
180 You can provide an alignment with seq1 "``HELLYEAH``"
181 and seq2 "``----YEAH``". "Y" gets assigned residue
182 number 5, "E" gets assigned 6 and so on no matter
183 what the original residue numbers were.
184 :type seqres_mapping: :class:`dict` (key: :class:`str`, value:
185 :class:`ost.seq.AlignmentHandle`)
186 :param bb_only: Only consider atoms with name "CA" in case of amino acids and
187 "C3'" for nucleotides. This invalidates *compound_lib*.
188 Raises if any residue in *target* is not
189 `r.chem_class.IsPeptideLinking()` or
190 `r.chem_class.IsNucleotideLinking()`
191 :type bb_only: :class:`bool`
192 :raises: :class:`RuntimeError` if *target* contains compound which is not in
193 *compound_lib*, :class:`RuntimeError` if *symmetry_settings*
194 specifies symmetric atoms that are not present in the corresponding
195 compound in *compound_lib*, :class:`RuntimeError` if
196 *seqres_mapping* is not provided and *target* contains residue
197 numbers with insertion codes or the residue numbers for each chain
198 are not monotonically increasing, :class:`RuntimeError` if
199 *seqres_mapping* is provided but an alignment is invalid
200 (seq1 contains gaps, mismatch in seq1/seq2, seq2 does not match
201 residues in corresponding chains).
207 custom_compounds=None,
209 sequence_separation=0,
210 symmetry_settings=None,
211 seqres_mapping=dict(),
218 if compound_lib
is None:
219 compound_lib = conop.GetDefaultLib()
220 if compound_lib
is None:
221 raise RuntimeError(
"No compound_lib given and conop.GetDefaultLib "
222 "returns no valid compound library")
225 if symmetry_settings
is None:
326 lDDTScorer._SetupDistances(self.
targettarget, self.
n_atomsn_atoms,
335 lDDTScorer._SetupDistances(self.
targettarget, self.
n_atomsn_atoms,
360 lDDTScorer._SetupDistancesSC(self.
n_atomsn_atoms,
370 lDDTScorer._SetupDistancesSC(self.
n_atomsn_atoms,
380 lDDTScorer._NonSymDistances(self.
n_atomsn_atoms,
390 lDDTScorer._NonSymDistances(self.
n_atomsn_atoms,
400 lDDTScorer._SetupDistancesIC(self.
n_atomsn_atoms,
410 lDDTScorer._SetupDistancesIC(self.
n_atomsn_atoms,
420 lDDTScorer._NonSymDistances(self.
n_atomsn_atoms,
430 lDDTScorer._NonSymDistances(self.
n_atomsn_atoms,
436 def lDDT(self, model, thresholds = [0.5, 1.0, 2.0, 4.0],
437 local_lddt_prop=None, local_contact_prop=None,
438 chain_mapping=None, no_interchain=False,
439 no_intrachain=False, penalize_extra_chains=False,
440 residue_mapping=None, return_dist_test=False,
441 check_resnames=True, add_mdl_contacts=False,
442 interaction_data=None, set_atom_props=False):
443 """Computes lDDT of *model* - globally and per-residue
445 :param model: Model to be scored - models are preferably scored upon
446 performing stereo-chemistry checks in order to punish for
447 non-sensical irregularities. This must be done separately
448 as a pre-processing step. Target contacts that are not
449 covered by *model* are considered not conserved, thus
450 decreasing lDDT score. This also includes missing model
451 chains or model chains for which no mapping is provided in
453 :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
454 :param thresholds: Thresholds of distance differences to be considered
455 as correct - see docs in constructor for more info.
456 default: [0.5, 1.0, 2.0, 4.0]
457 :type thresholds: :class:`list` of :class:`float`
458 :param local_lddt_prop: If set, per-residue scores will be assigned as
459 generic float property of that name
460 :type local_lddt_prop: :class:`str`
461 :param local_contact_prop: If set, number of expected contacts as well
462 as number of conserved contacts will be
463 assigned as generic int property.
464 Expected contacts will be set as
465 <local_contact_prop>_exp, conserved contacts
466 as <local_contact_prop>_cons. Values
467 are summed over all thresholds.
468 :type local_contact_prop: :class:`str`
469 :param chain_mapping: Mapping of model chains (key) onto target chains
470 (value). This is required if target or model have
472 :type chain_mapping: :class:`dict` with :class:`str` as keys/values
473 :param no_interchain: Whether to exclude interchain contacts
474 :type no_interchain: :class:`bool`
475 :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
476 consider interface related contacts)
477 :type no_intrachain: :class:`bool`
478 :param penalize_extra_chains: Whether to include a fixed penalty for
479 additional chains in the model that are
480 not mapped to the target. ONLY AFFECTS
481 RETURNED GLOBAL SCORE. In detail: adds the
482 number of intra-chain contacts of each
483 extra chain to the expected contacts, thus
485 :type penalize_extra_chains: :class:`bool`
486 :param residue_mapping: By default, residue mapping is based on residue
487 numbers. That means, a model chain and the
488 respective target chain map to the same
489 underlying reference sequence (SEQRES).
490 Alternatively, you can specify one or
491 several alignment(s) between model and target
492 chains by providing a dictionary. key: Name
493 of chain in model (respective target chain is
494 extracted from *chain_mapping*),
495 value: Alignment with first sequence
496 corresponding to target chain and second
497 sequence to model chain. There is NO reference
498 sequence involved, so the two sequences MUST
499 exactly match the actual residues observed in
500 the respective target/model chains (ATOMSEQ).
501 :type residue_mapping: :class:`dict` with key: :class:`str`,
502 value: :class:`ost.seq.AlignmentHandle`
503 :param return_dist_test: Whether to additionally return the underlying
504 per-residue data for the distance difference
505 test. Adds five objects to the return tuple.
506 First: Number of total contacts summed over all
508 Second: Number of conserved contacts summed
510 Third: list with length of scored residues.
511 Contains indices referring to model.residues.
512 Fourth: numpy array of size
513 len(scored_residues) containing the number of
515 Fifth: numpy matrix of shape
516 (len(scored_residues), len(thresholds))
517 specifying how many for each threshold are
519 :param check_resnames: On by default. Enforces residue name matches
520 between mapped model and target residues.
521 :type check_resnames: :class:`bool`
522 :param add_mdl_contacts: Adds model contacts - Only using contacts that
523 are within a certain distance threshold in the
524 target does not penalize for added model
525 contacts. If set to True, this flag will also
526 consider target contacts that are within the
527 specified distance threshold in the model but
528 not necessarily in the target. No contact will
529 be added if the respective atom pair is not
530 resolved in the target.
531 :type add_mdl_contacts: :class:`bool`
532 :param interaction_data: Pro param - don't use
533 :type interaction_data: :class:`tuple`
534 :param set_atom_props: If True, sets generic properties on a per atom
535 level if *local_lddt_prop*/*local_contact_prop*
537 In other words: this is the only way you can
538 get per-atom lDDT values.
539 :type set_atom_props: :class:`bool`
541 :returns: global and per-residue lDDT scores as a tuple -
542 first element is global lDDT score (None if *target* has no
543 contacts) and second element a list of per-residue scores with
544 length len(*model*.residues). None is assigned to residues that
545 are not covered by target. If a residue is covered but has no
546 contacts in *target*, 0.0 is assigned.
548 if chain_mapping
is None:
549 if len(self.
chain_nameschain_names) > 1
or len(model.chains) > 1:
550 raise NotImplementedError(
"Must provide chain mapping if "
551 "target or model have > 1 chains.")
552 chain_mapping = {model.chains[0].GetName(): self.
chain_nameschain_names[0]}
555 for model_chain, target_chain
in chain_mapping.items():
556 if target_chain
not in self.
chain_nameschain_names:
557 raise RuntimeError(f
"Target chain specified in "
558 f
"chain_mapping ({target_chain}) does "
559 f
"not exist. Target has chains: "
560 f
"{self.chain_names}")
561 ch = model.FindChain(model_chain)
563 raise RuntimeError(f
"Model chain specified in "
564 f
"chain_mapping ({model_chain}) does "
565 f
"not exist. Model has chains: "
566 f
"{[c.GetName() for c in model.chains]}")
570 pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
571 res_indices, ref_res_indices, symmetries = \
573 residue_mapping = residue_mapping,
574 thresholds = thresholds,
575 check_resnames = check_resnames)
577 if no_interchain
and no_intrachain:
578 raise RuntimeError(
"no_interchain and no_intrachain flags are "
579 "mutually exclusive")
582 sym_ref_indices =
None
583 sym_ref_distances =
None
587 if interaction_data
is None:
605 ref_indices, ref_distances = \
606 self.
_AddMdlContacts_AddMdlContacts(model, res_atom_indices, res_atom_hashes,
607 ref_indices, ref_distances,
608 no_interchain, no_intrachain)
610 sym_ref_indices, sym_ref_distances = \
612 ref_indices, ref_distances)
614 sym_ref_indices, sym_ref_distances, ref_indices, ref_distances = \
617 self.
_ResolveSymmetries_ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices,
620 atom_indices = list(itertools.chain.from_iterable(res_atom_indices))
622 per_atom_exp = np.asarray([self.
_GetNExp_GetNExp(i, ref_indices)
623 for i
in atom_indices], dtype=np.int32)
624 per_res_exp = np.asarray([self.
_GetNExp_GetNExp(res_ref_atom_indices[idx],
625 ref_indices)
for idx
in range(len(res_indices))], dtype=np.int32)
627 per_atom_conserved = self.
_EvalAtoms_EvalAtoms(pos, atom_indices, thresholds,
628 ref_indices, ref_distances)
629 per_res_conserved = np.zeros((len(res_atom_indices), len(thresholds)),
632 for r_idx
in range(len(res_atom_indices)):
633 end_idx = start_idx + len(res_atom_indices[r_idx])
634 per_res_conserved[r_idx] = np.sum(per_atom_conserved[start_idx:end_idx,:],
638 n_thresh = len(thresholds)
641 per_res_lDDT = [
None] * model.GetResidueCount()
642 for idx
in range(len(res_indices)):
643 n_exp = n_thresh * per_res_exp[idx]
645 score = np.sum(per_res_conserved[idx,:]) / n_exp
646 per_res_lDDT[res_indices[idx]] = score
648 per_res_lDDT[res_indices[idx]] = 0.0
651 n_distances = sum([len(x)
for x
in ref_indices])
652 if penalize_extra_chains:
655 lDDT_tot = int(n_thresh * n_distances)
656 lDDT_cons = int(np.sum(per_res_conserved))
659 lDDT = float(lDDT_cons) / lDDT_tot
663 residues = model.residues
664 for idx
in res_indices:
665 residues[idx].SetFloatProp(local_lddt_prop, per_res_lDDT[idx])
667 if local_contact_prop:
668 residues = model.residues
669 exp_prop = local_contact_prop +
"_exp"
670 conserved_prop = local_contact_prop +
"_cons"
672 for i, r_idx
in enumerate(res_indices):
673 residues[r_idx].SetIntProp(exp_prop,
674 n_thresh * int(per_res_exp[i]))
675 residues[r_idx].SetIntProp(conserved_prop,
676 int(np.sum(per_res_conserved[i,:])))
678 if set_atom_props
and (local_lddt_prop
or local_contact_prop):
680 residues = model.residues
681 for i, indices
in enumerate(res_atom_indices):
682 r = residues[res_indices[i]]
683 r_idx = ref_res_indices[i]
687 a = r.FindAtom(anames[a_i - res_start_idx])
691 summed_per_atom_conserved = per_atom_conserved.sum(axis=1)
695 for a_idx
in range(len(atom_list)):
696 if per_atom_exp[a_idx] != 0:
697 tmp = summed_per_atom_conserved[a_idx] / per_atom_exp[a_idx]
699 atom_list[a_idx].SetFloatProp(local_lddt_prop, tmp)
701 if local_contact_prop:
702 conserved_prop = local_contact_prop +
"_cons"
703 exp_prop = local_contact_prop +
"_exp"
704 for a_idx
in range(len(atom_list)):
706 tmp = summed_per_atom_conserved[a_idx]
707 atom_list[a_idx].SetIntProp(conserved_prop, tmp)
709 tmp = per_atom_exp[a_idx] * n_thresh
710 atom_list[a_idx].SetIntProp(exp_prop, tmp)
713 return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \
714 per_res_exp, per_res_conserved
716 return lDDT, per_res_lDDT
719 """Returns number of contacts expected for a certain chain in *target*
721 :param target_chain: Chain in *target* for which you want the number
723 :type target_chain: :class:`str`
724 :param no_interchain: Whether to exclude interchain contacts
725 :type no_interchain: :class:`bool`
726 :raises: :class:`RuntimeError` if specified chain doesn't exist
728 if target_chain
not in self.
chain_nameschain_names:
729 raise RuntimeError(f
"Specified chain name ({target_chain}) not in "
731 ch_idx = self.
chain_nameschain_names.index(target_chain)
741 def _ProcessModel(self, model, chain_mapping, residue_mapping = None,
742 thresholds = [0.5, 1.0, 2.0, 4.0],
743 check_resnames = True):
744 """ Helper that generates data structures from model
749 max_pos = model.bounds.GetMax()
750 max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2]))
751 max_coordinate += 42 * max(thresholds)
752 pos = np.ones((self.
n_atomsn_atoms, 3), dtype=np.float32) * max_coordinate
756 res_ref_atom_indices = list()
760 res_atom_indices = list()
764 res_atom_hashes = list()
770 ref_res_indices = list()
775 current_model_res_idx = -1
776 for ch
in model.chains:
777 model_ch_name = ch.GetName()
778 if model_ch_name
not in chain_mapping:
779 current_model_res_idx += len(ch.residues)
781 target_ch_name = chain_mapping[model_ch_name]
783 rnums = self.
_GetChainRNums_GetChainRNums(ch, residue_mapping, model_ch_name,
786 for r, rnum
in zip(ch.residues, rnums):
787 current_model_res_idx += 1
788 res_mapper_key = (target_ch_name, rnum)
789 if res_mapper_key
not in self.
res_mapperres_mapper:
791 r_idx = self.
res_mapperres_mapper[res_mapper_key]
792 if check_resnames
and r.name != self.
compound_namescompound_names[r_idx]:
794 f
"Residue name mismatch for {r}, "
795 f
" expect {self.compound_names[r_idx]}"
800 atoms = [r.FindAtom(aname)
for aname
in anames]
801 res_ref_atom_indices.append(
802 list(range(res_start_idx, res_start_idx + len(anames)))
804 res_atom_indices.append(list())
805 res_atom_hashes.append(list())
806 res_indices.append(current_model_res_idx)
807 ref_res_indices.append(r_idx)
808 for a_idx, a
in enumerate(atoms):
811 pos[res_start_idx + a_idx][0] = p[0]
812 pos[res_start_idx + a_idx][1] = p[1]
813 pos[res_start_idx + a_idx][2] = p[2]
814 res_atom_indices[-1].append(res_start_idx + a_idx)
815 res_atom_hashes[-1].append(a.handle.GetHashCode())
819 a_one = atoms[sym_tuple[0]]
820 a_two = atoms[sym_tuple[1]]
821 if a_one.IsValid()
and a_two.IsValid():
824 res_start_idx + sym_tuple[0],
825 res_start_idx + sym_tuple[1],
828 if len(sym_indices) > 0:
829 symmetries.append(sym_indices)
831 return (pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes,
832 res_indices, ref_res_indices, symmetries)
835 def _GetExtraModelChainPenalty(self, model, chain_mapping):
836 """Counts n distances in extra model chains to be added as penalty
839 for chain
in model.chains:
840 ch_name = chain.GetName()
841 if ch_name
not in chain_mapping:
843 mdl_sel = model.Select(f
"cname={mol.QueryQuoteName(ch_name)}")
845 symmetry_settings = sm,
848 penalty += sum([len(x)
for x
in dummy_scorer.ref_indices])
851 def _GetChainRNums(self, ch, residue_mapping, model_ch_name,
853 """Map residues in model chain to target residues
855 There are two options: one is simply using residue numbers,
856 the other is a custom mapping as given in *residue_mapping*
858 if residue_mapping
and model_ch_name
in residue_mapping:
860 ch_idx = self.
chain_nameschain_names.index(target_ch_name)
866 target_rnums = self.
res_resnumsres_resnums[start_idx:end_idx]
868 target_seq = residue_mapping[model_ch_name].GetSequence(0)
869 model_seq = residue_mapping[model_ch_name].GetSequence(1)
870 if len(target_seq.GetGaplessString()) != len(target_rnums):
871 raise RuntimeError(f
"Try to perform residue mapping for "
872 f
"model chain {model_ch_name} which "
873 f
"maps to {target_ch_name} in target. "
874 f
"Target sequence in alignment suggests "
875 f
"{len(target_seq.GetGaplessString())} "
876 f
"residues but {len(target_rnums)} are "
878 if len(model_seq.GetGaplessString()) != len(ch.residues):
879 raise RuntimeError(f
"Try to perform residue mapping for "
880 f
"model chain {model_ch_name} which "
881 f
"maps to {target_ch_name} in target. "
882 f
"Model sequence in alignment suggests "
883 f
"{len(model_seq.GetGaplessString())} "
884 f
"residues but {len(ch.residues)} are "
888 for col
in residue_mapping[model_ch_name]:
892 if col[0] !=
'-' and col[1] !=
'-':
893 rnums.append(target_rnums[target_idx])
895 if col[0] ==
'-' and col[1] !=
'-':
898 rnums = [r.GetNumber()
for r
in ch.residues]
903 def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings,
904 seqres_mapping, bb_only):
905 """Sets target related lDDTScorer members defined in constructor
907 No distance related members - see _SetupDistances
913 for chain
in self.
targettarget.chains:
914 ch_name = chain.GetName()
918 for r, rnum
in zip(chain.residues, residue_numbers[ch_name]):
922 self.
_SetupCompound_SetupCompound(r, compound_lib, custom_compounds,
923 symmetry_settings, bb_only)
930 atoms = [r.FindAtom(an)
for an
in self.
compound_anamescompound_anames[r.name]]
933 self.
atom_indicesatom_indices[a.handle.GetHashCode()] = current_idx
935 positions.append(np.asarray([p[0], p[1], p[2]],
938 positions.append(np.zeros(3, dtype=np.float32))
943 for a_idx
in sym_tuple:
946 hashcode = a.handle.GetHashCode()
950 self.
positionspositions = np.vstack(positions)
951 self.
n_atomsn_atoms = current_idx
953 def _GetTargetResidueNumbers(self, target, seqres_mapping):
954 """Returns residue numbers for each chain in target as dict
956 They're either directly extracted from the raw residue number
957 from the structure or from user provided alignments
959 residue_numbers = dict()
960 for ch
in target.chains:
961 ch_name = ch.GetName()
963 if ch_name
in seqres_mapping:
964 seqres = seqres_mapping[ch_name].GetSequence(0).GetString()
965 atomseq = seqres_mapping[ch_name].GetSequence(1).GetString()
969 "SEQRES in seqres_mapping must not " "contain gaps"
971 atomseq_from_chain = [r.one_letter_code
for r
in ch.residues]
972 if atomseq.replace(
"-",
"") != atomseq_from_chain:
974 "ATOMSEQ in seqres_mapping must match "
975 "raw sequence extracted from chain "
979 for seqres_olc, atomseq_olc
in zip(seqres, atomseq):
980 if seqres_olc !=
"-":
982 if atomseq_olc !=
"-":
983 if seqres_olc != atomseq_olc:
985 f
"Residue with number {rnum} in "
986 f
"chain {ch_name} has SEQRES "
991 rnums = [r.GetNumber()
for r
in ch.residues]
992 assert len(rnums) == len(ch.residues)
993 residue_numbers[ch_name] = rnums
994 return residue_numbers
996 def _SetupCompound(self, r, compound_lib, custom_compounds,
997 symmetry_settings, bb_only):
998 """fill self.compound_anames/self.compound_symmetric_atoms
1002 if r.chem_class.IsPeptideLinking():
1004 elif r.chem_class.IsNucleotideLinking():
1007 raise RuntimeError(f
"Only support amino acids and nucleotides "
1008 f
"if bb_only is True, failed with {str(r)}")
1012 symmetric_atoms = list()
1013 if custom_compounds
is not None and r.GetName()
in custom_compounds:
1014 atom_names = list(custom_compounds[r.GetName()].atom_names)
1016 compound = compound_lib.FindCompound(r.name)
1017 if compound
is None:
1018 raise RuntimeError(f
"no entry for {r} in compound_lib")
1019 for atom_spec
in compound.GetAtomSpecs():
1020 if atom_spec.element
not in [
"H",
"D"]:
1021 atom_names.append(atom_spec.name)
1022 if r.name
in symmetry_settings.symmetric_compounds:
1023 for pair
in symmetry_settings.symmetric_compounds[r.name]:
1025 a = atom_names.index(pair[0])
1026 b = atom_names.index(pair[1])
1028 msg = f
"Could not find symmetric atoms "
1029 msg += f
"({pair[0]}, {pair[1]}) for {r.name} "
1030 msg += f
"as specified in SymmetrySettings in "
1031 msg += f
"compound from component dictionary. "
1032 msg += f
"Atoms in compound: {atom_names}"
1033 raise RuntimeError(msg)
1034 symmetric_atoms.append((a, b))
1036 if len(symmetric_atoms) > 0:
1039 def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes,
1040 ref_indices, ref_distances, no_interchain,
1044 in_target = np.zeros(self.
n_atomsn_atoms, dtype=bool)
1047 mdl_atom_indices = dict()
1048 for at_indices, at_hashes
in zip(res_atom_indices, res_atom_hashes):
1049 for i, h
in zip(at_indices, at_hashes):
1051 mdl_atom_indices[h] = i
1056 mdl_ref_indices, mdl_ref_distances = \
1057 lDDTScorer._SetupDistances(model, self.
n_atomsn_atoms, mdl_atom_indices,
1060 mdl_ref_indices, mdl_ref_distances = \
1061 lDDTScorer._SetupDistancesSC(self.
n_atomsn_atoms,
1067 mdl_ref_indices, mdl_ref_distances = \
1068 lDDTScorer._SetupDistancesIC(self.
n_atomsn_atoms,
1074 for i
in range(self.
n_atomsn_atoms):
1075 mask = np.isin(mdl_ref_indices[i], ref_indices[i],
1076 assume_unique=
True, invert=
True)
1077 if np.sum(mask) > 0:
1078 added_mdl_indices = mdl_ref_indices[i][mask]
1079 ref_indices[i] = np.append(ref_indices[i],
1083 tmp = self.
positionspositions.take(added_mdl_indices, axis=0)
1084 np.subtract(tmp, self.
positionspositions[i][
None, :], out=tmp)
1085 np.square(tmp, out=tmp)
1086 tmp = tmp.sum(axis=1)
1087 np.sqrt(tmp, out=tmp)
1088 ref_distances[i] = np.append(ref_distances[i], tmp)
1090 return (ref_indices, ref_distances)
1095 def _SetupDistances(structure, n_atoms, atom_index_mapping,
1098 """Compute distance related members of lDDTScorer
1100 Brute force all vs all distance computation kills lDDT for large
1101 complexes. Instead of building some KD tree data structure, we make use
1102 of expected spatial proximity of atoms in the same chain. Distances are
1103 computed as follows:
1105 - process each chain individually
1106 - perform crude collision detection
1107 - process potentially interacting chain pairs
1108 - concatenate distances from all processing steps
1110 ref_indices = [np.asarray([], dtype=np.int64)
for idx
in range(n_atoms)]
1111 ref_distances = [np.asarray([], dtype=np.float64)
for idx
in range(n_atoms)]
1113 indices = [list()
for _
in range(n_atoms)]
1114 distances = [list()
for _
in range(n_atoms)]
1115 per_chain_pos = list()
1116 per_chain_indices = list()
1119 for ch
in structure.chains:
1121 atom_indices = list()
1125 for r_idx, r
in enumerate(ch.residues):
1128 hash_code = a.handle.GetHashCode()
1129 if hash_code
in atom_index_mapping:
1131 pos_list.append(np.asarray([p[0], p[1], p[2]]))
1132 atom_indices.append(atom_index_mapping[hash_code])
1134 mask_start.extend([r_start_idx] * n_valid_atoms)
1135 mask_end.extend([r_start_idx + n_valid_atoms] * n_valid_atoms)
1136 r_start_idx += n_valid_atoms
1138 if len(pos_list) == 0:
1142 pos = np.vstack(pos_list)
1143 atom_indices = np.asarray(atom_indices)
1144 dists =
cdist(pos, pos)
1147 far_away = 2 * inclusion_radius
1148 for idx
in range(atom_indices.shape[0]):
1149 dists[idx, range(mask_start[idx], mask_end[idx])] = far_away
1152 within_mask = dists < inclusion_radius
1153 for idx
in range(atom_indices.shape[0]):
1154 indices_to_append = atom_indices[within_mask[idx,:]]
1155 if indices_to_append.shape[0] > 0:
1156 full_at_idx = atom_indices[idx]
1157 indices[full_at_idx].append(indices_to_append)
1158 distances[full_at_idx].append(dists[idx, within_mask[idx,:]])
1160 per_chain_pos.append(pos)
1161 per_chain_indices.append(atom_indices)
1164 min_pos = [p.min(0)
for p
in per_chain_pos]
1165 max_pos = [p.max(0)
for p
in per_chain_pos]
1166 chain_pairs = list()
1167 for idx_one
in range(len(per_chain_pos)):
1168 for idx_two
in range(idx_one + 1, len(per_chain_pos)):
1169 if np.max(min_pos[idx_one] - max_pos[idx_two]) > inclusion_radius:
1171 if np.max(min_pos[idx_two] - max_pos[idx_one]) > inclusion_radius:
1173 chain_pairs.append((idx_one, idx_two))
1176 for pair
in chain_pairs:
1177 dists =
cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]])
1178 within = dists <= inclusion_radius
1181 tmp = within.sum(axis=1)
1182 for idx
in range(tmp.shape[0]):
1187 at_idx = per_chain_indices[pair[0]][idx]
1188 indices_to_insert = per_chain_indices[pair[1]][within[idx,:]]
1189 distances_to_insert = dists[idx, within[idx, :]]
1190 insertion_idx = len(indices[at_idx])
1191 for i
in range(insertion_idx):
1192 if indices_to_insert[0] > indices[at_idx][i][0]:
1195 indices[at_idx].insert(insertion_idx, indices_to_insert)
1196 distances[at_idx].insert(insertion_idx, distances_to_insert)
1199 tmp = within.sum(axis=0)
1200 for idx
in range(tmp.shape[0]):
1205 at_idx = per_chain_indices[pair[1]][idx]
1206 indices_to_insert = per_chain_indices[pair[0]][within[:, idx]]
1207 distances_to_insert = dists[within[:, idx], idx]
1208 insertion_idx = len(indices[at_idx])
1209 for i
in range(insertion_idx):
1210 if indices_to_insert[0] > indices[at_idx][i][0]:
1213 indices[at_idx].insert(insertion_idx, indices_to_insert)
1214 distances[at_idx].insert(insertion_idx, distances_to_insert)
1217 for at_idx
in range(n_atoms):
1218 if len(indices[at_idx]) > 0:
1219 ref_indices[at_idx] = np.hstack(indices[at_idx])
1220 ref_distances[at_idx] = np.hstack(distances[at_idx])
1222 return (ref_indices, ref_distances)
1225 def _SetupDistancesSC(n_atoms, chain_start_indices,
1226 ref_indices, ref_distances):
1227 """Select subset of contacts only covering intra-chain contacts
1230 ref_indices_sc = [np.asarray([], dtype=np.int64)
for idx
in range(n_atoms)]
1231 ref_distances_sc = [np.asarray([], dtype=np.float64)
for idx
in range(n_atoms)]
1233 n_chains = len(chain_start_indices)
1234 for ch_idx
in range(n_chains):
1235 chain_s = chain_start_indices[ch_idx]
1237 if ch_idx + 1 < n_chains:
1238 chain_e = chain_start_indices[ch_idx+1]
1239 for i
in range(chain_s, chain_e):
1240 if len(ref_indices[i]) > 0:
1241 intra_idx = np.where(np.logical_and(ref_indices[i]>=chain_s,
1242 ref_indices[i]<chain_e))[0]
1243 ref_indices_sc[i] = ref_indices[i][intra_idx]
1244 ref_distances_sc[i] = ref_distances[i][intra_idx]
1246 return (ref_indices_sc, ref_distances_sc)
1249 def _SetupDistancesIC(n_atoms, chain_start_indices,
1250 ref_indices, ref_distances):
1251 """Select subset of contacts only covering inter-chain contacts
1254 ref_indices_ic = [np.asarray([], dtype=np.int64)
for idx
in range(n_atoms)]
1255 ref_distances_ic = [np.asarray([], dtype=np.float64)
for idx
in range(n_atoms)]
1257 n_chains = len(chain_start_indices)
1258 for ch_idx
in range(n_chains):
1259 chain_s = chain_start_indices[ch_idx]
1261 if ch_idx + 1 < n_chains:
1262 chain_e = chain_start_indices[ch_idx+1]
1263 for i
in range(chain_s, chain_e):
1264 if len(ref_indices[i]) > 0:
1265 inter_idx = np.where(np.logical_or(ref_indices[i]<chain_s,
1266 ref_indices[i]>=chain_e))[0]
1267 ref_indices_ic[i] = ref_indices[i][inter_idx]
1268 ref_distances_ic[i] = ref_distances[i][inter_idx]
1270 return (ref_indices_ic, ref_distances_ic)
1273 def _NonSymDistances(n_atoms, symmetric_atoms, ref_indices, ref_distances):
1274 """Transfer indices/distances of non-symmetric atoms and return
1277 sym_ref_indices = [np.asarray([], dtype=np.int64)
for idx
in range(n_atoms)]
1278 sym_ref_distances = [np.asarray([], dtype=np.float64)
for idx
in range(n_atoms)]
1280 for idx
in symmetric_atoms:
1283 for i, d
in zip(ref_indices[idx], ref_distances[idx]):
1284 if i
not in symmetric_atoms:
1287 sym_ref_indices[idx] = indices
1288 sym_ref_distances[idx] = np.asarray(distances)
1290 return (sym_ref_indices, sym_ref_distances)
1292 def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances):
1293 """Computes number of distance differences within given thresholds
1295 returns np.array with len(thresholds) elements
1297 a_p = pos[atom_idx, :]
1298 tmp = pos.take(ref_indices[atom_idx], axis=0)
1299 np.subtract(tmp, a_p[
None, :], out=tmp)
1300 np.square(tmp, out=tmp)
1301 tmp = tmp.sum(axis=1)
1302 np.sqrt(tmp, out=tmp)
1303 np.subtract(ref_distances[atom_idx], tmp, out=tmp)
1304 np.absolute(tmp, out=tmp)
1305 return np.asarray([(tmp <= thresh).sum()
for thresh
in thresholds],
1309 self, pos, atom_indices, thresholds, ref_indices, ref_distances
1311 """Calls _EvalAtom for several atoms and sums up the computed number
1312 of distance differences within given thresholds
1314 returns numpy matrix of shape (n_atoms, len(thresholds))
1316 conserved = np.zeros((len(atom_indices), len(thresholds)),
1318 for a_idx, a
in enumerate(atom_indices):
1319 conserved[a_idx, :] = self.
_EvalAtom_EvalAtom(pos, a, thresholds,
1320 ref_indices, ref_distances)
1323 def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices,
1325 """Calls _EvalAtoms for a bunch of residues
1327 residues are defined in *res_atom_indices* as lists of atom indices
1328 returns numpy matrix of shape (n_residues, len(thresholds)).
1330 conserved = np.zeros((len(res_atom_indices), len(thresholds)),
1332 for rai_idx, rai
in enumerate(res_atom_indices):
1333 conserved[rai_idx,:] = np.sum(self.
_EvalAtoms_EvalAtoms(pos, rai, thresholds,
1334 ref_indices, ref_distances), axis=0)
1337 def _ProcessSequenceSeparation(self):
1339 raise NotImplementedError(
"Congratulations! You're the first one "
1340 "requesting a non-default "
1341 "sequence_separation in the new and "
1342 "awesome lDDT implementation. A crate of "
1343 "beer for Gabriel and he'll implement "
1346 def _GetNExp(self, atom_idx, ref_indices):
1347 """Returns number of close atoms around one or several atoms
1349 if isinstance(atom_idx, int):
1350 return len(ref_indices[atom_idx])
1351 elif isinstance(atom_idx, list):
1352 return sum([len(ref_indices[idx])
for idx
in atom_idx])
1354 raise RuntimeError(
"invalid input type")
1356 def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices,
1358 """Swaps symmetric positions in-place in order to maximize lDDT scores
1359 towards non-symmetric atoms.
1361 for sym
in symmetries:
1363 atom_indices = list()
1364 for sym_tuple
in sym:
1365 atom_indices += [sym_tuple[0], sym_tuple[1]]
1366 tot = self.
_GetNExp_GetNExp(atom_indices, sym_ref_indices)
1372 sym_one_conserved = self.
_EvalAtoms_EvalAtoms(
1382 pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
1384 sym_two_conserved = self.
_EvalAtoms_EvalAtoms(
1392 sym_one_score = np.sum(sym_one_conserved) / (len(thresholds) * tot)
1393 sym_two_score = np.sum(sym_two_conserved) / (len(thresholds) * tot)
1395 if sym_one_score >= sym_two_score:
1400 pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
def __init__(self, atom_names)
def AddSymmetricCompound(self, name, symmetric_atoms)
def _SetupCompound(self, r, compound_lib, custom_compounds, symmetry_settings, bb_only)
def lDDT(self, model, thresholds=[0.5, 1.0, 2.0, 4.0], local_lddt_prop=None, local_contact_prop=None, chain_mapping=None, no_interchain=False, no_intrachain=False, penalize_extra_chains=False, residue_mapping=None, return_dist_test=False, check_resnames=True, add_mdl_contacts=False, interaction_data=None, set_atom_props=False)
def _ProcessModel(self, model, chain_mapping, residue_mapping=None, thresholds=[0.5, 1.0, 2.0, 4.0], check_resnames=True)
def sym_ref_indices(self)
def _GetChainRNums(self, ch, residue_mapping, model_ch_name, target_ch_name)
def _ProcessSequenceSeparation(self)
def sym_ref_distances(self)
def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices, sym_ref_distances)
def ref_distances_ic(self)
def _GetTargetResidueNumbers(self, target, seqres_mapping)
def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances)
def sym_ref_distances_ic(self)
def sym_ref_distances_sc(self)
def sym_ref_indices_ic(self)
def GetNChainContacts(self, target_chain, no_interchain=False)
def sym_ref_indices_sc(self)
def ref_distances_sc(self)
def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes, ref_indices, ref_distances, no_interchain, no_intrachain)
def _GetNExp(self, atom_idx, ref_indices)
def __init__(self, target, compound_lib=None, custom_compounds=None, inclusion_radius=15, sequence_separation=0, symmetry_settings=None, seqres_mapping=dict(), bb_only=False)
def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings, seqres_mapping, bb_only)
def _GetExtraModelChainPenalty(self, model, chain_mapping)
def _EvalAtoms(self, pos, atom_indices, thresholds, ref_indices, ref_distances)
def GetDefaultSymmetrySettings()