456 def lDDT(self, model, thresholds = [0.5, 1.0, 2.0, 4.0],
457 local_lddt_prop=None, local_contact_prop=None,
458 chain_mapping=None, no_interchain=False,
459 no_intrachain=False, penalize_extra_chains=False,
460 residue_mapping=None, return_dist_test=False,
461 check_resnames=True, add_mdl_contacts=False,
462 interaction_data=None, set_atom_props=False):
463 """Computes LDDT of *model* - globally and per-residue
465 :param model: Model to be scored - models are preferably scored upon
466 performing stereo-chemistry checks in order to punish for
467 non-sensical irregularities. This must be done separately
468 as a pre-processing step. Target contacts that are not
469 covered by *model* are considered not conserved, thus
470 decreasing LDDT score. This also includes missing model
471 chains or model chains for which no mapping is provided in
473 :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
474 :param thresholds: Thresholds of distance differences to be considered
475 as correct - see docs in constructor for more info.
476 default: [0.5, 1.0, 2.0, 4.0]
477 :type thresholds: :class:`list` of :class:`float`
478 :param local_lddt_prop: If set, per-residue scores will be assigned as
479 generic float property of that name
480 :type local_lddt_prop: :class:`str`
481 :param local_contact_prop: If set, number of expected contacts as well
482 as number of conserved contacts will be
483 assigned as generic int property.
484 Expected contacts will be set as
485 <local_contact_prop>_exp, conserved contacts
486 as <local_contact_prop>_cons. Values
487 are summed over all thresholds.
488 :type local_contact_prop: :class:`str`
489 :param chain_mapping: Mapping of model chains (key) onto target chains
490 (value). This is required if target or model have
492 :type chain_mapping: :class:`dict` with :class:`str` as keys/values
493 :param no_interchain: Whether to exclude interchain contacts
494 :type no_interchain: :class:`bool`
495 :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
496 consider interface related contacts)
497 :type no_intrachain: :class:`bool`
498 :param penalize_extra_chains: Whether to include a fixed penalty for
499 additional chains in the model that are
500 not mapped to the target. ONLY AFFECTS
501 RETURNED GLOBAL SCORE. In detail: adds the
502 number of intra-chain contacts of each
503 extra chain to the expected contacts, thus
505 :type penalize_extra_chains: :class:`bool`
506 :param residue_mapping: By default, residue mapping is based on residue
507 numbers. That means, a model chain and the
508 respective target chain map to the same
509 underlying reference sequence (SEQRES).
510 Alternatively, you can specify one or
511 several alignment(s) between model and target
512 chains by providing a dictionary. key: Name
513 of chain in model (respective target chain is
514 extracted from *chain_mapping*),
515 value: Alignment with first sequence
516 corresponding to target chain and second
517 sequence to model chain. There is NO reference
518 sequence involved, so the two sequences MUST
519 exactly match the actual residues observed in
520 the respective target/model chains (ATOMSEQ).
521 :type residue_mapping: :class:`dict` with key: :class:`str`,
522 value: :class:`ost.seq.AlignmentHandle`
523 :param return_dist_test: Whether to additionally return the underlying
524 per-residue data for the distance difference
525 test. Adds five objects to the return tuple.
526 First: Number of total contacts summed over all
528 Second: Number of conserved contacts summed
530 Third: list with length of scored residues.
531 Contains indices referring to model.residues.
532 Fourth: numpy array of size
533 len(scored_residues) containing the number of
535 Fifth: numpy matrix of shape
536 (len(scored_residues), len(thresholds))
537 specifying how many for each threshold are
539 :param check_resnames: On by default. Enforces residue name matches
540 between mapped model and target residues.
541 :type check_resnames: :class:`bool`
542 :param add_mdl_contacts: Adds model contacts - Only using contacts that
543 are within a certain distance threshold in the
544 target does not penalize for added model
545 contacts. If set to True, this flag will also
546 consider target contacts that are within the
547 specified distance threshold in the model but
548 not necessarily in the target. No contact will
549 be added if the respective atom pair is not
550 resolved in the target.
551 :type add_mdl_contacts: :class:`bool`
552 :param interaction_data: Pro param - don't use
553 :type interaction_data: :class:`tuple`
554 :param set_atom_props: If True, sets generic properties on a per atom
555 level if *local_lddt_prop*/*local_contact_prop*
557 In other words: this is the only way you can
558 get per-atom LDDT values.
559 :type set_atom_props: :class:`bool`
561 :returns: global and per-residue LDDT scores as a tuple -
562 first element is global LDDT score (None if *target* has no
563 contacts) and second element a list of per-residue scores with
564 length len(*model*.residues). None is assigned to residues that
565 are not covered by target. If a residue is covered but has no
566 contacts in *target*, 0.0 is assigned.
568 if chain_mapping
is None:
569 if len(self.
chain_names) > 1
or len(model.chains) > 1:
570 raise NotImplementedError(
"Must provide chain mapping if "
571 "target or model have > 1 chains.")
572 chain_mapping = {model.chains[0].GetName(): self.
chain_names[0]}
575 for model_chain, target_chain
in chain_mapping.items():
577 raise RuntimeError(f
"Target chain specified in "
578 f
"chain_mapping ({target_chain}) does "
579 f
"not exist. Target has chains: "
580 f
"{self.chain_names}")
581 ch = model.FindChain(model_chain)
583 raise RuntimeError(f
"Model chain specified in "
584 f
"chain_mapping ({model_chain}) does "
585 f
"not exist. Model has chains: "
586 f
"{[c.GetName() for c in model.chains]}")
590 pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
591 res_indices, ref_res_indices, symmetries = \
593 residue_mapping = residue_mapping,
595 check_resnames = check_resnames)
597 if no_interchain
and no_intrachain:
598 raise RuntimeError(
"no_interchain and no_intrachain flags are "
599 "mutually exclusive")
601 sym_ref_indices =
None
602 sym_ref_distances =
None
606 if interaction_data
is None:
624 ref_indices, ref_distances = \
626 ref_indices, ref_distances,
627 no_interchain, no_intrachain)
629 sym_ref_indices, sym_ref_distances = \
631 ref_indices, ref_distances)
633 sym_ref_indices, sym_ref_distances, ref_indices, ref_distances = \
639 atom_indices = list(itertools.chain.from_iterable(res_atom_indices))
641 per_atom_exp = np.asarray([self.
_GetNExp(i, ref_indices)
642 for i
in atom_indices], dtype=np.int32)
643 per_res_exp = np.asarray([self.
_GetNExp(res_ref_atom_indices[idx],
644 ref_indices)
for idx
in range(len(res_indices))], dtype=np.int32)
646 per_atom_conserved = self.
_EvalAtoms(pos, atom_indices, thresholds,
647 ref_indices, ref_distances)
648 per_res_conserved = np.zeros((len(res_atom_indices), len(thresholds)),
651 for r_idx
in range(len(res_atom_indices)):
652 end_idx = start_idx + len(res_atom_indices[r_idx])
653 per_res_conserved[r_idx] = np.sum(per_atom_conserved[start_idx:end_idx,:],
657 n_thresh = len(thresholds)
660 per_res_lDDT = [
None] * model.GetResidueCount()
661 for idx
in range(len(res_indices)):
662 n_exp = n_thresh * per_res_exp[idx]
664 score = np.sum(per_res_conserved[idx,:]) / n_exp
665 per_res_lDDT[res_indices[idx]] = score
667 per_res_lDDT[res_indices[idx]] = 0.0
670 n_distances = sum([len(x)
for x
in ref_indices])
671 if penalize_extra_chains:
674 lDDT_tot = int(n_thresh * n_distances)
675 lDDT_cons = int(np.sum(per_res_conserved))
678 lDDT = float(lDDT_cons) / lDDT_tot
682 residues = model.residues
683 for idx
in res_indices:
684 residues[idx].SetFloatProp(local_lddt_prop, per_res_lDDT[idx])
686 if local_contact_prop:
687 residues = model.residues
688 exp_prop = local_contact_prop +
"_exp"
689 conserved_prop = local_contact_prop +
"_cons"
691 for i, r_idx
in enumerate(res_indices):
692 residues[r_idx].SetIntProp(exp_prop,
693 n_thresh * int(per_res_exp[i]))
694 residues[r_idx].SetIntProp(conserved_prop,
695 int(np.sum(per_res_conserved[i,:])))
697 if set_atom_props
and (local_lddt_prop
or local_contact_prop):
699 residues = model.residues
700 for i, indices
in enumerate(res_atom_indices):
701 r = residues[res_indices[i]]
702 r_idx = ref_res_indices[i]
706 a = r.FindAtom(anames[a_i - res_start_idx])
710 summed_per_atom_conserved = per_atom_conserved.sum(axis=1)
714 for a_idx
in range(len(atom_list)):
715 if per_atom_exp[a_idx] != 0:
716 tmp = summed_per_atom_conserved[a_idx] / per_atom_exp[a_idx]
718 atom_list[a_idx].SetFloatProp(local_lddt_prop, tmp)
720 if local_contact_prop:
721 conserved_prop = local_contact_prop +
"_cons"
722 exp_prop = local_contact_prop +
"_exp"
723 for a_idx
in range(len(atom_list)):
725 tmp = summed_per_atom_conserved[a_idx]
726 atom_list[a_idx].SetIntProp(conserved_prop, tmp)
728 tmp = per_atom_exp[a_idx] * n_thresh
729 atom_list[a_idx].SetIntProp(exp_prop, tmp)
732 return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \
733 per_res_exp, per_res_conserved
735 return lDDT, per_res_lDDT
737 def DRMSD(self, model, dist_cap = 5,
738 chain_mapping=None, no_interchain=False,
739 no_intrachain=False, residue_mapping=None,
740 check_resnames=True, add_mdl_contacts=False,
741 interaction_data=None):
742 """ DRMSD of *model* - globally and per-residue
744 Very similar to LDDT as we operate on distance differences for all
745 interatomic distances within the same inclusion radius as in LDDT.
746 DRMSD is the distance rmsd, i.e. the RMSD of distance differences.
747 Distance differences are capped at *dist_cap* which is also the default
748 value for missing distances.
750 :param model: Model to be scored - models are preferably scored upon
751 performing stereo-chemistry checks in order to punish for
752 non-sensical irregularities. This must be done separately
753 as a pre-processing step. Target contacts that are not
754 covered by *model* are considered not conserved, thus
755 increasing DRMSD score. This also includes missing model
756 chains or model chains for which no mapping is provided in
758 :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
759 :param dist_cap: Cap for distance differences.
760 :type dist_cap: :class:`float`
761 :param chain_mapping: Mapping of model chains (key) onto target chains
762 (value). This is required if target or model have
764 :type chain_mapping: :class:`dict` with :class:`str` as keys/values
765 :param no_interchain: Whether to exclude interchain contacts
766 :type no_interchain: :class:`bool`
767 :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
768 consider interface related contacts)
769 :type no_intrachain: :class:`bool`
770 :param residue_mapping: By default, residue mapping is based on residue
771 numbers. That means, a model chain and the
772 respective target chain map to the same
773 underlying reference sequence (SEQRES).
774 Alternatively, you can specify one or
775 several alignment(s) between model and target
776 chains by providing a dictionary. key: Name
777 of chain in model (respective target chain is
778 extracted from *chain_mapping*),
779 value: Alignment with first sequence
780 corresponding to target chain and second
781 sequence to model chain. There is NO reference
782 sequence involved, so the two sequences MUST
783 exactly match the actual residues observed in
784 the respective target/model chains (ATOMSEQ).
785 :type residue_mapping: :class:`dict` with key: :class:`str`,
786 value: :class:`ost.seq.AlignmentHandle`
787 :param check_resnames: On by default. Enforces residue name matches
788 between mapped model and target residues.
789 :type check_resnames: :class:`bool`
790 :param add_mdl_contacts: Adds model contacts - Only using contacts that
791 are within a certain distance threshold in the
792 target does not penalize for added model
793 contacts. If set to True, this flag will also
794 consider target contacts that are within the
795 specified distance threshold in the model but
796 not necessarily in the target. No contact will
797 be added if the respective atom pair is not
798 resolved in the target.
799 :type add_mdl_contacts: :class:`bool`
800 :param interaction_data: Pro param - don't use
801 :type interaction_data: :class:`tuple`
803 :returns: global and per-residue DRMSD scores as a tuple -
804 first element is global DRMSD score (None if *target* has no
805 contacts) and second element a list of per-residue scores with
806 length len(*model*.residues). None is assigned to residues that
807 are not covered by target. If a residue is covered but has no
808 contacts in *target*, None is assigned.
810 if chain_mapping
is None:
811 if len(self.
chain_names) > 1
or len(model.chains) > 1:
812 raise NotImplementedError(
"Must provide chain mapping if "
813 "target or model have > 1 chains.")
814 chain_mapping = {model.chains[0].GetName(): self.
chain_names[0]}
817 for model_chain, target_chain
in chain_mapping.items():
819 raise RuntimeError(f
"Target chain specified in "
820 f
"chain_mapping ({target_chain}) does "
821 f
"not exist. Target has chains: "
822 f
"{self.chain_names}")
823 ch = model.FindChain(model_chain)
825 raise RuntimeError(f
"Model chain specified in "
826 f
"chain_mapping ({model_chain}) does "
827 f
"not exist. Model has chains: "
828 f
"{[c.GetName() for c in model.chains]}")
832 pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
833 res_indices, ref_res_indices, symmetries = \
835 residue_mapping = residue_mapping,
837 check_resnames = check_resnames)
839 if no_interchain
and no_intrachain:
840 raise RuntimeError(
"no_interchain and no_intrachain flags are "
841 "mutually exclusive")
843 sym_ref_indices =
None
844 sym_ref_distances =
None
848 if interaction_data
is None:
866 ref_indices, ref_distances = \
868 ref_indices, ref_distances,
869 no_interchain, no_intrachain)
871 sym_ref_indices, sym_ref_distances = \
873 ref_indices, ref_distances)
875 sym_ref_indices, sym_ref_distances, ref_indices, ref_distances = \
881 atom_indices = list(itertools.chain.from_iterable(res_atom_indices))
883 per_atom_exp = np.asarray([self.
_GetNExp(i, ref_indices)
884 for i
in atom_indices], dtype=np.int32)
885 per_res_exp = np.asarray([self.
_GetNExp(res_ref_atom_indices[idx],
886 ref_indices)
for idx
in range(len(res_indices))], dtype=np.int32)
887 per_atom_ssd = self.
_EvalAtomsSSD(pos, atom_indices, dist_cap,
888 ref_indices, ref_distances)
892 per_res_drmsd = [
None] * model.GetResidueCount()
893 for r_idx
in range(len(res_atom_indices)):
894 end_idx = start_idx + len(res_atom_indices[r_idx])
895 n_tot = per_res_exp[r_idx]
897 ssd = np.sum(per_atom_ssd[start_idx:end_idx])
900 n_missing = n_tot - np.sum(per_atom_exp[start_idx:end_idx])
901 ssd += n_missing*dist_cap*dist_cap
902 per_res_drmsd[res_indices[r_idx]] = np.sqrt(ssd/n_tot)
907 n_tot = sum([len(x)
for x
in ref_indices])
909 ssd = np.sum(per_atom_ssd)
912 n_missing = n_tot - np.sum(per_atom_exp)
913 ssd += (dist_cap*dist_cap*n_missing)
914 drmsd = np.sqrt(ssd/n_tot)
916 return drmsd, per_res_drmsd
943 check_resnames = True):
944 """ Helper that generates data structures from model
949 max_pos = model.bounds.GetMax()
950 max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2]))
951 max_coordinate += 42 * nirvana_dist
952 pos = np.ones((self.
n_atoms, 3), dtype=np.float32) * max_coordinate
956 res_ref_atom_indices = list()
960 res_atom_indices = list()
964 res_atom_hashes = list()
970 ref_res_indices = list()
975 current_model_res_idx = -1
976 for ch
in model.chains:
977 model_ch_name = ch.GetName()
978 if model_ch_name
not in chain_mapping:
979 current_model_res_idx += len(ch.residues)
981 target_ch_name = chain_mapping[model_ch_name]
986 for r, rnum
in zip(ch.residues, rnums):
987 current_model_res_idx += 1
988 res_mapper_key = (target_ch_name, rnum)
994 f
"Residue name mismatch for {r}, "
995 f
" expect {self.compound_names[r_idx]}"
1000 atoms = [r.FindAtom(aname)
for aname
in anames]
1001 res_ref_atom_indices.append(
1002 list(range(res_start_idx, res_start_idx + len(anames)))
1004 res_atom_indices.append(list())
1005 res_atom_hashes.append(list())
1006 res_indices.append(current_model_res_idx)
1007 ref_res_indices.append(r_idx)
1008 for a_idx, a
in enumerate(atoms):
1011 pos[res_start_idx + a_idx][0] = p[0]
1012 pos[res_start_idx + a_idx][1] = p[1]
1013 pos[res_start_idx + a_idx][2] = p[2]
1014 res_atom_indices[-1].append(res_start_idx + a_idx)
1015 res_atom_hashes[-1].append(a.handle.GetHashCode())
1017 sym_indices = list()
1019 a_one = atoms[sym_tuple[0]]
1020 a_two = atoms[sym_tuple[1]]
1021 if a_one.IsValid()
and a_two.IsValid():
1024 res_start_idx + sym_tuple[0],
1025 res_start_idx + sym_tuple[1],
1028 if len(sym_indices) > 0:
1029 symmetries.append(sym_indices)
1031 return (pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes,
1032 res_indices, ref_res_indices, symmetries)
1240 ref_indices, ref_distances, no_interchain,
1244 in_target = np.zeros(self.
n_atoms, dtype=bool)
1247 mdl_atom_indices = dict()
1248 for at_indices, at_hashes
in zip(res_atom_indices, res_atom_hashes):
1249 for i, h
in zip(at_indices, at_hashes):
1251 mdl_atom_indices[h] = i
1256 mdl_ref_indices, mdl_ref_distances = \
1257 lDDTScorer._SetupDistances(model, self.
n_atoms, mdl_atom_indices,
1260 mdl_ref_indices, mdl_ref_distances = \
1261 lDDTScorer._SetupDistancesSC(self.
n_atoms,
1267 mdl_ref_indices, mdl_ref_distances = \
1268 lDDTScorer._SetupDistancesIC(self.
n_atoms,
1275 mask = np.isin(mdl_ref_indices[i], ref_indices[i],
1276 assume_unique=
True, invert=
True)
1277 if np.sum(mask) > 0:
1278 added_mdl_indices = mdl_ref_indices[i][mask]
1279 ref_indices[i] = np.append(ref_indices[i],
1283 tmp = self.
positions.take(added_mdl_indices, axis=0)
1284 np.subtract(tmp, self.
positions[i][
None, :], out=tmp)
1285 np.square(tmp, out=tmp)
1286 tmp = tmp.sum(axis=1)
1287 np.sqrt(tmp, out=tmp)
1288 ref_distances[i] = np.append(ref_distances[i], tmp)
1290 return (ref_indices, ref_distances)
1298 """Compute distance related members of lDDTScorer
1300 Brute force all vs all distance computation kills LDDT for large
1301 complexes. Instead of building some KD tree data structure, we make use
1302 of expected spatial proximity of atoms in the same chain. Distances are
1303 computed as follows:
1305 - process each chain individually
1306 - perform crude collision detection
1307 - process potentially interacting chain pairs
1308 - concatenate distances from all processing steps
1310 ref_indices = [np.asarray([], dtype=np.int32)
for idx
in range(n_atoms)]
1311 ref_distances = [np.asarray([], dtype=np.float32)
for idx
in range(n_atoms)]
1313 indices = [list()
for _
in range(n_atoms)]
1314 distances = [list()
for _
in range(n_atoms)]
1315 per_chain_pos = list()
1316 per_chain_indices = list()
1319 for ch
in structure.chains:
1321 atom_indices = list()
1325 for r_idx, r
in enumerate(ch.residues):
1328 hash_code = a.handle.GetHashCode()
1329 if hash_code
in atom_index_mapping:
1331 pos_list.append(np.asarray([p[0], p[1], p[2]], dtype=np.float32))
1332 atom_indices.append(atom_index_mapping[hash_code])
1334 mask_start.extend([r_start_idx] * n_valid_atoms)
1335 mask_end.extend([r_start_idx + n_valid_atoms] * n_valid_atoms)
1336 r_start_idx += n_valid_atoms
1338 if len(pos_list) == 0:
1342 pos = np.vstack(pos_list)
1343 atom_indices = np.asarray(atom_indices, dtype=np.int32)
1345 if atom_indices.shape[0] > 20000:
1348 dists = cdist(pos, pos)
1351 far_away = 2 * inclusion_radius
1352 for idx
in range(atom_indices.shape[0]):
1353 dists[idx, range(mask_start[idx], mask_end[idx])] = far_away
1356 within_mask = dists < inclusion_radius
1357 for idx
in range(atom_indices.shape[0]):
1358 indices_to_append = atom_indices[within_mask[idx,:]]
1359 if indices_to_append.shape[0] > 0:
1360 full_at_idx = atom_indices[idx]
1361 indices[full_at_idx].append(indices_to_append)
1362 distances[full_at_idx].append(dists[idx, within_mask[idx,:]])
1366 per_chain_pos.append(pos)
1367 per_chain_indices.append(atom_indices)
1370 min_pos = [p.min(0)
for p
in per_chain_pos]
1371 max_pos = [p.max(0)
for p
in per_chain_pos]
1372 chain_pairs = list()
1373 for idx_one
in range(len(per_chain_pos)):
1374 for idx_two
in range(idx_one + 1, len(per_chain_pos)):
1375 if np.max(min_pos[idx_one] - max_pos[idx_two]) > inclusion_radius:
1377 if np.max(min_pos[idx_two] - max_pos[idx_one]) > inclusion_radius:
1379 chain_pairs.append((idx_one, idx_two))
1382 for pair
in chain_pairs:
1383 if per_chain_pos[pair[0]].shape[0] > 20000
or per_chain_pos[pair[1]].shape[0] > 20000:
1384 dists =
blockwise_cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]])
1386 dists = cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]])
1387 within = dists <= inclusion_radius
1390 tmp = within.sum(axis=1)
1391 for idx
in range(tmp.shape[0]):
1396 at_idx = per_chain_indices[pair[0]][idx]
1397 indices_to_insert = per_chain_indices[pair[1]][within[idx,:]]
1398 distances_to_insert = dists[idx, within[idx, :]]
1399 insertion_idx = len(indices[at_idx])
1400 for i
in range(insertion_idx):
1401 if indices_to_insert[0] > indices[at_idx][i][0]:
1404 indices[at_idx].insert(insertion_idx, indices_to_insert)
1405 distances[at_idx].insert(insertion_idx, distances_to_insert)
1408 tmp = within.sum(axis=0)
1409 for idx
in range(tmp.shape[0]):
1414 at_idx = per_chain_indices[pair[1]][idx]
1415 indices_to_insert = per_chain_indices[pair[0]][within[:, idx]]
1416 distances_to_insert = dists[within[:, idx], idx]
1417 insertion_idx = len(indices[at_idx])
1418 for i
in range(insertion_idx):
1419 if indices_to_insert[0] > indices[at_idx][i][0]:
1422 indices[at_idx].insert(insertion_idx, indices_to_insert)
1423 distances[at_idx].insert(insertion_idx, distances_to_insert)
1428 for at_idx
in range(n_atoms):
1429 if len(indices[at_idx]) > 0:
1430 ref_indices[at_idx] = np.hstack(indices[at_idx])
1431 ref_distances[at_idx] = np.hstack(distances[at_idx])
1433 return (ref_indices, ref_distances)