OpenStructure
 All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
chain_mapping.py
Go to the documentation of this file.
1 """
2 Chain mapping aims to identify a one-to-one relationship between chains in a
3 reference structure and a model.
4 """
5 
6 import itertools
7 import copy
8 
9 import numpy as np
10 
11 from scipy.special import factorial
12 from scipy.special import binom # as of Python 3.8, the math module implements
13  # comb, i.e. n choose k
14 
15 import ost
16 from ost import seq
17 from ost import mol
18 from ost import geom
19 
20 from ost.mol.alg import lddt
21 from ost.mol.alg import qsscore
22 
def _CSel(ent, cnames):
    """ Select the chains named in *cnames* from *ent*

    Each chain name is passed through :func:`mol.QueryQuoteName` before it is
    put into the query, so that quotation marks surround the names and weird
    special characters do not confuse the OST query language.

    :param ent: Structure to select from, must provide ``Select``
    :param cnames: Iterable of chain names
    :returns: Whatever ``ent.Select`` returns for the ``cname=...`` query
    """
    quoted_names = (mol.QueryQuoteName(name) for name in cnames)
    return ent.Select("cname=" + ','.join(quoted_names))
31 
class MappingResult:
    """ Result object for the chain mapping functions in :class:`ChainMapper`

    Constructor is directly called within the functions, no need to construct
    such objects yourself.
    """

    def __init__(self, target, model, chem_groups, chem_mapping, mapping, alns,
                 opt_score=None):
        self._target = target
        self._model = model
        self._chem_groups = chem_groups
        self._chem_mapping = chem_mapping
        self._mapping = mapping
        self._alns = alns
        self._opt_score = opt_score

    @property
    def target(self):
        """ Target/reference structure, i.e. :attr:`ChainMapper.target`

        :type: :class:`ost.mol.EntityView`
        """
        return self._target

    @property
    def model(self):
        """ Model structure that gets mapped onto :attr:`~target`

        Underwent same processing as :attr:`ChainMapper.target`, i.e.
        only contains peptide/nucleotide chains of sufficient size.

        :type: :class:`ost.mol.EntityView`
        """
        return self._model

    @property
    def chem_groups(self):
        """ Groups of chemically equivalent chains in :attr:`~target`

        Same as :attr:`ChainMapper.chem_groups`

        :class:`list` of :class:`list` of :class:`str` (chain names)
        """
        return self._chem_groups

    @property
    def chem_mapping(self):
        """ Assigns chains in :attr:`~model` to :attr:`~chem_groups`.

        :class:`list` of :class:`list` of :class:`str` (chain names)
        """
        return self._chem_mapping

    @property
    def mapping(self):
        """ Mapping of :attr:`~model` chains onto :attr:`~target`

        Exact same shape as :attr:`~chem_groups` but containing the names of
        the mapped chains in :attr:`~model`. May contain None for
        :attr:`~target` chains that are not covered. No guarantee that all
        chains in :attr:`~model` are mapped.

        :class:`list` of :class:`list` of :class:`str` (chain names)
        """
        return self._mapping

    @property
    def alns(self):
        """ Alignments of mapped chains in :attr:`~target` and :attr:`~model`

        Each alignment is accessible with ``alns[(t_chain,m_chain)]``. First
        sequence is the sequence of :attr:`target` chain, second sequence the
        one from :attr:`~model`. The respective :class:`ost.mol.EntityView`
        are attached with :func:`ost.seq.ConstSequenceHandle.AttachView`.

        :type: :class:`dict` with key: :class:`tuple` of :class:`str`, value:
               :class:`ost.seq.AlignmentHandle`
        """
        return self._alns

    @property
    def opt_score(self):
        """ Placeholder property without any guarantee of being set

        Different scores get optimized in the various chain mapping
        algorithms. Some of them may set their final optimal score in that
        property. Consult the documentation of the respective chain mapping
        algorithm for more information. Won't be in the return dict of
        :func:`JSONSummary`.
        """
        return self._opt_score

    def GetFlatMapping(self, mdl_as_key=False):
        """ Returns flat mapping as :class:`dict` for all mapable chains

        :param mdl_as_key: Default is target chain name as key and model chain
                           name as value. This can be reversed with this flag.
        :returns: :class:`dict` with :class:`str` as key/value that describe
                  one-to-one mapping
        """
        flat_mapping = dict()
        for trg_group, mdl_group in zip(self.chem_groups, self.mapping):
            for trg_ch, mdl_ch in zip(trg_group, mdl_group):
                # None marks a chain without partner, nothing to report
                if trg_ch is None or mdl_ch is None:
                    continue
                if mdl_as_key:
                    flat_mapping[mdl_ch] = trg_ch
                else:
                    flat_mapping[trg_ch] = mdl_ch
        return flat_mapping

    def JSONSummary(self):
        """ Returns JSON serializable summary of results

        :returns: :class:`dict` with keys "chem_groups", "mapping",
                  "flat_mapping" and "alns". The latter is a list with one
                  dict per alignment holding chain names and aligned
                  sequences of target ("trg_ch", "trg_seq") and model
                  ("mdl_ch", "mdl_seq").
        """
        json_dict = dict()
        json_dict["chem_groups"] = self.chem_groups
        json_dict["mapping"] = self.mapping
        json_dict["flat_mapping"] = self.GetFlatMapping()
        json_dict["alns"] = list()
        for aln in self.alns.values():
            # first sequence belongs to target, second one to model
            trg_seq = aln.GetSequence(0)
            mdl_seq = aln.GetSequence(1)
            aln_dict = {"trg_ch": trg_seq.GetName(), "trg_seq": str(trg_seq),
                        "mdl_ch": mdl_seq.GetName(), "mdl_seq": str(mdl_seq)}
            json_dict["alns"].append(aln_dict)
        return json_dict
158 
159 
161 
class ReprResult:
    """ Result object for :func:`ChainMapper.GetRepr`

    Constructor is directly called within the function, no need to construct
    such objects yourself.

    :param lDDT: lDDT for this mapping. Depends on how you call
                 :func:`ChainMapper.GetRepr` whether this is backbone only or
                 full atom lDDT.
    :type lDDT: :class:`float`
    :param substructure: The full substructure for which we searched for a
                         representation
    :type substructure: :class:`ost.mol.EntityView`
    :param ref_view: View pointing to the same underlying entity as
                     *substructure* but only contains the stuff that is mapped
    :type ref_view: :class:`mol.EntityView`
    :param mdl_view: The matching counterpart in model
    :type mdl_view: :class:`mol.EntityView`
    """

    def __init__(self, lDDT, substructure, ref_view, mdl_view):
        self._lDDT = lDDT
        self._substructure = substructure
        # residues of both views are paired by position in everything below
        assert(len(ref_view.residues) == len(mdl_view.residues))
        self._ref_view = ref_view
        self._mdl_view = mdl_view

        # lazily evaluated attributes
        self._ref_bb_pos = None
        self._mdl_bb_pos = None
        self._ref_full_bb_pos = None
        self._mdl_full_bb_pos = None
        self._transform = None
        self._superposed_mdl_bb_pos = None
        self._bb_rmsd = None
        self._gdt_8 = None
        self._gdt_4 = None
        self._gdt_2 = None
        self._gdt_1 = None
        self._ost_query = None
        self._flat_mapping = None
        self._inconsistent_residues = None

    @property
    def lDDT(self):
        """ lDDT of representation result

        Depends on how you call :func:`ChainMapper.GetRepr` whether this is
        backbone only or full atom lDDT.

        :type: :class:`float`
        """
        return self._lDDT

    @property
    def substructure(self):
        """ The full substructure for which we searched for a
        representation

        :type: :class:`ost.mol.EntityView`
        """
        return self._substructure

    @property
    def ref_view(self):
        """ View which contains the mapped subset of :attr:`substructure`

        :type: :class:`ost.mol.EntityView`
        """
        return self._ref_view

    @property
    def mdl_view(self):
        """ The :attr:`ref_view` representation in the model

        :type: :class:`ost.mol.EntityView`
        """
        return self._mdl_view

    @property
    def ref_residues(self):
        """ The reference residues

        :type: :class:`mol.ResidueViewList`
        """
        return self.ref_view.residues

    @property
    def mdl_residues(self):
        """ The model residues

        :type: :class:`mol.ResidueViewList`
        """
        return self.mdl_view.residues

    @property
    def inconsistent_residues(self):
        """ A list of mapped residue whose names do not match (eg. ALA in the
        reference and LEU in the model).

        The mismatches are reported as a tuple of :class:`~ost.mol.ResidueView`
        (reference, model), or as an empty list if all the residue names
        match.

        :type: :class:`list`
        """
        if self._inconsistent_residues is None:
            self._inconsistent_residues = self._GetInconsistentResidues(
                self.ref_residues, self.mdl_residues)
        return self._inconsistent_residues

    @property
    def ref_bb_pos(self):
        """ Representative backbone positions for reference residues.

        Thats CA positions for peptides and C3' positions for Nucleotides.

        :type: :class:`geom.Vec3List`
        """
        if self._ref_bb_pos is None:
            self._ref_bb_pos = self._GetBBPos(self.ref_residues)
        return self._ref_bb_pos

    @property
    def mdl_bb_pos(self):
        """ Representative backbone positions for model residues.

        Thats CA positions for peptides and C3' positions for Nucleotides.

        :type: :class:`geom.Vec3List`
        """
        if self._mdl_bb_pos is None:
            self._mdl_bb_pos = self._GetBBPos(self.mdl_residues)
        return self._mdl_bb_pos

    @property
    def ref_full_bb_pos(self):
        """ Full backbone positions for reference residues.

        Thats N, CA and C positions for peptides and O5', C5', C4', C3', O3'
        positions for Nucleotides.

        :type: :class:`geom.Vec3List`
        """
        if self._ref_full_bb_pos is None:
            self._ref_full_bb_pos = self._GetFullBBPos(self.ref_residues)
        return self._ref_full_bb_pos

    @property
    def mdl_full_bb_pos(self):
        """ Full backbone positions for model residues.

        Thats N, CA and C positions for peptides and O5', C5', C4', C3', O3'
        positions for Nucleotides.

        :type: :class:`geom.Vec3List`
        """
        if self._mdl_full_bb_pos is None:
            self._mdl_full_bb_pos = self._GetFullBBPos(self.mdl_residues)
        return self._mdl_full_bb_pos

    @property
    def transform(self):
        """ Transformation to superpose mdl residues onto ref residues

        Superposition computed as minimal RMSD superposition on
        :attr:`ref_bb_pos` and :attr:`mdl_bb_pos`. If number of positions is
        smaller 3, the full_bb_pos equivalents are used instead.

        :type: :class:`ost.geom.Mat4`
        """
        if self._transform is None:
            if len(self.mdl_bb_pos) < 3:
                # too few representative positions, use all backbone atoms
                self._transform = _GetTransform(self.mdl_full_bb_pos,
                                                self.ref_full_bb_pos, False)
            else:
                self._transform = _GetTransform(self.mdl_bb_pos,
                                                self.ref_bb_pos, False)
        return self._transform

    @property
    def superposed_mdl_bb_pos(self):
        """ :attr:`mdl_bb_pos` with :attr:`transform` applied

        :type: :class:`geom.Vec3List`
        """
        if self._superposed_mdl_bb_pos is None:
            # work on a copy, the cached mdl_bb_pos must stay untransformed
            self._superposed_mdl_bb_pos = geom.Vec3List(self.mdl_bb_pos)
            self._superposed_mdl_bb_pos.ApplyTransform(self.transform)
        return self._superposed_mdl_bb_pos

    @property
    def bb_rmsd(self):
        """ RMSD between :attr:`ref_bb_pos` and :attr:`superposed_mdl_bb_pos`

        :type: :class:`float`
        """
        if self._bb_rmsd is None:
            self._bb_rmsd = self.ref_bb_pos.GetRMSD(self.superposed_mdl_bb_pos)
        return self._bb_rmsd

    @property
    def gdt_8(self):
        """ GDT with one single threshold: 8.0

        :type: :class:`float`
        """
        if self._gdt_8 is None:
            self._gdt_8 = self.ref_bb_pos.GetGDT(self.superposed_mdl_bb_pos,
                                                 8.0)
        return self._gdt_8

    @property
    def gdt_4(self):
        """ GDT with one single threshold: 4.0

        :type: :class:`float`
        """
        if self._gdt_4 is None:
            self._gdt_4 = self.ref_bb_pos.GetGDT(self.superposed_mdl_bb_pos,
                                                 4.0)
        return self._gdt_4

    @property
    def gdt_2(self):
        """ GDT with one single threshold: 2.0

        :type: :class:`float`
        """
        if self._gdt_2 is None:
            self._gdt_2 = self.ref_bb_pos.GetGDT(self.superposed_mdl_bb_pos,
                                                 2.0)
        return self._gdt_2

    @property
    def gdt_1(self):
        """ GDT with one single threshold: 1.0

        :type: :class:`float`
        """
        if self._gdt_1 is None:
            self._gdt_1 = self.ref_bb_pos.GetGDT(self.superposed_mdl_bb_pos,
                                                 1.0)
        return self._gdt_1

    @property
    def ost_query(self):
        """ query for mdl residues in OpenStructure query language

        Repr can be selected as ``full_mdl.Select(ost_query)``

        Returns invalid query if residue numbers have insertion codes.

        :type: :class:`str`
        """
        if self._ost_query is None:
            # residue numbers per chain, chains in order of first appearance
            chain_rnums = dict()
            for r in self.mdl_residues:
                chname = r.GetChain().GetName()
                rnum = r.GetNumber().GetNum()
                chain_rnums.setdefault(chname, list()).append(str(rnum))
            chain_queries = list()
            for chname, rnums in chain_rnums.items():
                q = f"(cname={mol.QueryQuoteName(chname)} and "
                q += f"rnum={','.join(rnums)})"
                chain_queries.append(q)
            self._ost_query = " or ".join(chain_queries)
        return self._ost_query

    def JSONSummary(self):
        """ Returns JSON serializable summary of results
        """
        json_dict = dict()
        json_dict["lDDT"] = self.lDDT
        json_dict["ref_residues"] = [r.GetQualifiedName() for r in
                                     self.ref_residues]
        json_dict["mdl_residues"] = [r.GetQualifiedName() for r in
                                     self.mdl_residues]
        json_dict["transform"] = list(self.transform.data)
        json_dict["bb_rmsd"] = self.bb_rmsd
        json_dict["gdt_8"] = self.gdt_8
        json_dict["gdt_4"] = self.gdt_4
        json_dict["gdt_2"] = self.gdt_2
        json_dict["gdt_1"] = self.gdt_1
        json_dict["ost_query"] = self.ost_query
        json_dict["flat_mapping"] = self.GetFlatChainMapping()
        return json_dict

    def GetFlatChainMapping(self, mdl_as_key=False):
        """ Returns flat mapping of all chains in the representation

        :param mdl_as_key: Default is target chain name as key and model chain
                           name as value. This can be reversed with this flag.
        :returns: :class:`dict` with :class:`str` as key/value that describe
                  one-to-one mapping
        """
        flat_mapping = dict()
        for trg_res, mdl_res in zip(self.ref_residues, self.mdl_residues):
            if mdl_as_key:
                flat_mapping[mdl_res.chain.name] = trg_res.chain.name
            else:
                flat_mapping[trg_res.chain.name] = mdl_res.chain.name
        return flat_mapping

    def _GetFullBBPos(self, residues):
        """ Helper to extract full backbone positions

        :param residues: Residues to process, peptide or nucleotide
        :returns: :class:`geom.Vec3List` with the positions of N, CA, C
                  (peptides) or O5', C5', C4', C3', O3' (nucleotides) for
                  each residue, in input order
        :raises: :class:`RuntimeError` if a residue is neither peptide nor
                 nucleotide or if an expected atom is missing
        """
        exp_pep_atoms = ["N", "CA", "C"]
        # Plain atom names, consistent with _GetBBPos which looks up "C3'"
        # unquoted. The quotation marks that used to be embedded here are
        # query language quoting and not part of the atom name.
        exp_nuc_atoms = ["O5'", "C5'", "C4'", "C3'", "O3'"]
        bb_pos = geom.Vec3List()
        for r in residues:
            if r.GetChemType() == mol.ChemType.NUCLEOTIDES:
                exp_atoms = exp_nuc_atoms
            elif r.GetChemType() == mol.ChemType.AMINOACIDS:
                exp_atoms = exp_pep_atoms
            else:
                raise RuntimeError("Something terrible happened... RUN...")
            for aname in exp_atoms:
                a = r.FindAtom(aname)
                if not a.IsValid():
                    raise RuntimeError("Something terrible happened... "
                                       "RUN...")
                bb_pos.append(a.GetPos())
        return bb_pos

    def _GetBBPos(self, residues):
        """ Helper to extract single representative position for each residue

        :param residues: Residues to process
        :returns: :class:`geom.Vec3List` with one position per residue: CA if
                  present, C3' otherwise
        :raises: :class:`RuntimeError` if a residue has neither CA nor C3'
        """
        bb_pos = geom.Vec3List()
        for r in residues:
            at = r.FindAtom("CA")
            if not at.IsValid():
                at = r.FindAtom("C3'")
            if not at.IsValid():
                raise RuntimeError("Something terrible happened... RUN...")
            bb_pos.append(at.GetPos())
        return bb_pos

    def _GetInconsistentResidues(self, ref_residues, mdl_residues):
        """ Helper to extract a list of inconsistent residues.

        :param ref_residues: Reference residues
        :param mdl_residues: Model residues, paired by position with
                             *ref_residues*
        :returns: :class:`list` of (reference, model) residue tuples whose
                  names differ
        :raises: :class:`ValueError` if the two inputs differ in length
        """
        if len(ref_residues) != len(mdl_residues):
            raise ValueError("Something terrible happened... Reference and "
                             "model lengths differ... RUN...")
        return [(ref_residue, mdl_residue)
                for ref_residue, mdl_residue in zip(ref_residues, mdl_residues)
                if ref_residue.name != mdl_residue.name]
506 
507 
509  """ Class to compute chain mappings
510 
511  All algorithms are performed on processed structures which fulfill
512  criteria as given in constructor arguments (*min_pep_length*,
513  *min_nuc_length*) and only contain residues which have all required backbone
514  atoms. For peptide residues that's N, CA, C and CB (no CB for GLY), for
515  nucleotide residues that's O5', C5', C4', C3' and O3'.
516 
517  Chain mapping is a three step process:
518 
519  * Group chemically identical chains in *target* using pairwise
520  alignments that are either computed with Needleman-Wunsch (NW) or
521  simply derived from residue numbers (*resnum_alignments* flag).
522  In case of NW, *pep_subst_mat*, *pep_gap_open* and *pep_gap_ext*
523  and their nucleotide equivalents are relevant. Two chains are
524  considered identical if they fulfill the thresholds given by
525  *pep_seqid_thr*, *pep_gap_thr*, their nucleotide equivalents
526  respectively. The grouping information is available as
527  attributes of this class.
528 
529  * Map chains in an input model to these groups. Generating alignments
530  and the similarity criteria are the same as above. You can either
531  get the group mapping with :func:`GetChemMapping` or directly call
532  one of the full-fledged one-to-one chain mapping functions which
533  execute that step internally.
534 
535  * Obtain one-to-one mapping for chains in an input model and
536  *target* with one of the available mapping functions. Just to get an
537  idea of complexity. If *target* and *model* are octamers, there are
538  ``8! = 40320`` possible chain mappings.
539 
540  :param target: Target structure onto which models are mapped.
541  Computations happen on a selection only containing
542  polypeptides and polynucleotides.
543  :type target: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
544  :param resnum_alignments: Use residue numbers instead of
545  Needleman-Wunsch to compute pairwise
546  alignments. Relevant for :attr:`~chem_groups`
547  and related attributes.
548  :type resnum_alignments: :class:`bool`
549  :param pep_seqid_thr: Threshold used to decide when two chains are
550  identical. 95 percent tolerates the few mutations
551  crystallographers like to do.
552  :type pep_seqid_thr: :class:`float`
553  :param pep_gap_thr: Additional threshold to avoid gappy alignments with
554  high seqid. By default this is disabled (set to 1.0).
555  This threshold checks for a maximum allowed fraction
556  of gaps in any of the two sequences after stripping
557  terminal gaps. The reason for not just normalizing
558  seqid by the longer sequence is that one sequence
559  might be a perfect subsequence of the other but only
560  cover half of it.
561  :type pep_gap_thr: :class:`float`
562  :param nuc_seqid_thr: Nucleotide equivalent for *pep_seqid_thr*
563  :type nuc_seqid_thr: :class:`float`
564  :param nuc_gap_thr: Nucleotide equivalent for *pep_gap_thr*
565  :type nuc_gap_thr: :class:`float`
566  :param pep_subst_mat: Substitution matrix to align peptide sequences,
567  irrelevant if *resnum_alignments* is True,
568  defaults to seq.alg.BLOSUM62
569  :type pep_subst_mat: :class:`ost.seq.alg.SubstWeightMatrix`
570  :param pep_gap_open: Gap open penalty to align peptide sequences,
571  irrelevant if *resnum_alignments* is True
572  :type pep_gap_open: :class:`int`
573  :param pep_gap_ext: Gap extension penalty to align peptide sequences,
574  irrelevant if *resnum_alignments* is True
575  :type pep_gap_ext: :class:`int`
576  :param nuc_subst_mat: Nucleotide equivalent for *pep_subst_mat*,
577  defaults to seq.alg.NUC44
578  :type nuc_subst_mat: :class:`ost.seq.alg.SubstWeightMatrix`
579  :param nuc_gap_open: Nucleotide equivalent for *pep_gap_open*
580  :type nuc_gap_open: :class:`int`
581  :param nuc_gap_ext: Nucleotide equivalent for *pep_gap_ext*
582  :type nuc_gap_ext: :class:`int`
583  :param min_pep_length: Minimal number of residues for a peptide chain to be
584  considered in target and in models.
585  :type min_pep_length: :class:`int`
586  :param min_nuc_length: Minimal number of residues for a nucleotide chain to be
587  considered in target and in models.
588  :type min_nuc_length: :class:`int`
589  :param n_max_naive: Max possible chain mappings that are enumerated in
590  :func:`~GetNaivelDDTMapping` /
591  :func:`~GetDecomposerlDDTMapping`. A
592  :class:`RuntimeError` is raised in case of bigger
593  complexity.
594  :type n_max_naive: :class:`int`
595  """
def __init__(self, target, resnum_alignments=False,
             pep_seqid_thr = 95., pep_gap_thr = 1.0,
             nuc_seqid_thr = 95., nuc_gap_thr = 1.0,
             pep_subst_mat = seq.alg.BLOSUM62, pep_gap_open = -11,
             pep_gap_ext = -1, nuc_subst_mat = seq.alg.NUC44,
             nuc_gap_open = -4, nuc_gap_ext = -4,
             min_pep_length = 10, min_nuc_length = 4,
             n_max_naive = 1e8):
    """ Stores the settings, sets up the pairwise aligner and preprocesses
    *target*. See class documentation for a description of the parameters.
    """

    # caches of the lazily evaluated properties
    self._chem_groups = None
    self._chem_group_alignments = None
    self._chem_group_ref_seqs = None
    self._chem_group_types = None

    # plain settings
    self.resnum_alignments = resnum_alignments
    self.min_pep_length = min_pep_length
    self.min_nuc_length = min_nuc_length
    self.n_max_naive = n_max_naive

    # thresholds that decide on chain identity
    self.pep_seqid_thr = pep_seqid_thr
    self.pep_gap_thr = pep_gap_thr
    self.nuc_seqid_thr = nuc_seqid_thr
    self.nuc_gap_thr = nuc_gap_thr

    # all pairwise alignments are generated by this helper object
    aligner_settings = dict(resnum_aln = resnum_alignments,
                            pep_subst_mat = pep_subst_mat,
                            pep_gap_open = pep_gap_open,
                            pep_gap_ext = pep_gap_ext,
                            nuc_subst_mat = nuc_subst_mat,
                            nuc_gap_open = nuc_gap_open,
                            nuc_gap_ext = nuc_gap_ext)
    self.aligner = _Aligner(**aligner_settings)

    # done last: preprocessing happens with everything above in place
    processed_target = self.ProcessStructure(target)
    self._target, self._polypep_seqs, self._polynuc_seqs = processed_target
633 
@property
def target(self):
    """Processed target structure, peptide/nucleotide chains only

    Restricted to residues that have their backbone representative (CA for
    peptides, C3' for nucleotides). This avoids ATOMSEQ alignment
    inconsistencies when switching between all atom and backbone only
    representations.

    :type: :class:`ost.mol.EntityView`
    """
    return self._target
646 
@property
def polypep_seqs(self):
    """Peptide chain sequences of :attr:`~target`

    For each sequence s, ``s.GetAttachedView()`` gives access to the
    respective :class:`EntityView` from *target*.

    :type: :class:`ost.seq.SequenceList`
    """
    return self._polypep_seqs
657 
@property
def polynuc_seqs(self):
    """Nucleotide chain sequences of :attr:`~target`

    For each sequence s, ``s.GetAttachedView()`` gives access to the
    respective :class:`EntityView` from *target*.

    :type: :class:`ost.seq.SequenceList`
    """
    return self._polynuc_seqs
668 
@property
def chem_groups(self):
    """Groups of chemically equivalent chains in :attr:`~target`

    First chain in group is the one with longest sequence.

    :getter: Computed on first use (cached)
    :type: :class:`list` of :class:`list` of :class:`str` (chain names)
    """
    if self._chem_groups is None:
        # Build the full result before caching it. If reading the alignments
        # raises, the cache stays unset instead of holding an empty or
        # half-filled list that later accesses would silently return.
        self._chem_groups = [[s.GetName() for s in aln.sequences]
                             for aln in self.chem_group_alignments]
    return self._chem_groups
683 
@property
def chem_group_alignments(self):
    """MSA for each group in :attr:`~chem_groups`

    Sequences in MSAs exhibit same order as in :attr:`~chem_groups` and
    have the respective :class:`ost.mol.EntityView` from *target* attached.

    :getter: Computed on first use (cached)
    :type: :class:`ost.seq.AlignmentList`
    """
    if self._chem_group_alignments is None:
        # NOTE(review): one helper call fills both caches; the
        # (alignments, types) return order is reconstructed, confirm against
        # _GetChemGroupAlignments
        self._chem_group_alignments, self._chem_group_types = \
        _GetChemGroupAlignments(self.polypep_seqs, self.polynuc_seqs,
                                self.aligner,
                                pep_seqid_thr=self.pep_seqid_thr,
                                pep_gap_thr=self.pep_gap_thr,
                                nuc_seqid_thr=self.nuc_seqid_thr,
                                nuc_gap_thr=self.nuc_gap_thr)

    return self._chem_group_alignments
704 
@property
def chem_group_ref_seqs(self):
    """Reference (longest) sequence for each group in :attr:`~chem_groups`

    Respective :class:`EntityView` from *target* for each sequence s are
    available as ``s.GetAttachedView()``

    :getter: Computed on first use (cached)
    :type: :class:`ost.seq.SequenceList`
    """
    if self._chem_group_ref_seqs is None:
        # assemble in a local list and publish it only once complete, an
        # exception on the way must not leave a partial list in the cache
        ref_seqs = seq.CreateSequenceList()
        for a in self.chem_group_alignments:
            # first sequence of each MSA is the reference of that group
            ref = a.GetSequence(0)
            s = seq.CreateSequence(ref.GetName(), ref.GetGaplessString())
            s.AttachView(ref.GetAttachedView())
            ref_seqs.AddSequence(s)
        self._chem_group_ref_seqs = ref_seqs
    return self._chem_group_ref_seqs
723 
@property
def chem_group_types(self):
    """ChemType of each group in :attr:`~chem_groups`

    Specifying if groups are poly-peptides/nucleotides, i.e.
    :class:`ost.mol.ChemType.AMINOACIDS` or
    :class:`ost.mol.ChemType.NUCLEOTIDES`

    :getter: Computed on first use (cached)
    :type: :class:`list` of :class:`ost.mol.ChemType`
    """
    if self._chem_group_types is None:
        # NOTE(review): one helper call fills both caches; the
        # (alignments, types) return order is reconstructed, confirm against
        # _GetChemGroupAlignments
        self._chem_group_alignments, self._chem_group_types = \
        _GetChemGroupAlignments(self.polypep_seqs, self.polynuc_seqs,
                                self.aligner,
                                pep_seqid_thr=self.pep_seqid_thr,
                                pep_gap_thr=self.pep_gap_thr,
                                nuc_seqid_thr=self.nuc_seqid_thr,
                                nuc_gap_thr=self.nuc_gap_thr)

    return self._chem_group_types
745 
def GetChemMapping(self, model):
    """Maps sequences in *model* to chem_groups of target

    :param model: Model from which to extract sequences, a
                  selection that only includes peptides and nucleotides
                  is performed and returned along other results.
    :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
    :returns: Tuple with two lists of length `len(self.chem_groups)` and
              an :class:`ost.mol.EntityView` representing *model*:
              1) Each element is a :class:`list` with mdl chain names that
              map to the chem group at that position.
              2) Each element is a :class:`ost.seq.AlignmentList` aligning
              these mdl chain sequences to the chem group ref sequences.
              3) A selection of *model* that only contains polypeptides and
              polynucleotides whose ATOMSEQ exactly matches the sequence
              info in the returned alignments.
    """
    mdl, mdl_pep_seqs, mdl_nuc_seqs = self.ProcessStructure(model)
    mapping = [list() for _ in self.chem_groups]
    alns = [seq.AlignmentList() for _ in self.chem_groups]

    # peptides are processed before nucleotides, each with its own type
    typed_seqs = ((mdl_pep_seqs, mol.ChemType.AMINOACIDS),
                  (mdl_nuc_seqs, mol.ChemType.NUCLEOTIDES))
    for mdl_seqs, chem_type in typed_seqs:
        for s in mdl_seqs:
            idx, aln = _MapSequence(self.chem_group_ref_seqs,
                                    self.chem_group_types,
                                    s, chem_type,
                                    self.aligner)
            # sequences without matching chem group are left out
            if idx is None:
                continue
            mapping[idx].append(s.GetName())
            alns[idx].append(aln)

    return (mapping, alns, mdl)
786 
787 
def GetlDDTMapping(self, model, inclusion_radius=15.0,
                   thresholds=[0.5, 1.0, 2.0, 4.0], strategy="naive",
                   steep_opt_rate = None, block_seed_size = 5,
                   block_blocks_per_chem_group = 5,
                   chem_mapping_result = None):
    """ Identify chain mapping by optimizing lDDT score

    Maps *model* chain sequences to :attr:`~chem_groups` and find mapping
    based on backbone only lDDT score (CA for amino acids C3' for
    Nucleotides).

    Either performs a naive search, i.e. enumerate all possible mappings or
    executes a greedy strategy that tries to identify a (close to) optimal
    mapping in an iterative way by starting from a start mapping (seed). In
    each iteration, the one-to-one mapping that leads to highest increase
    in number of conserved contacts is added with the additional requirement
    that this added mapping must have non-zero interface counts towards the
    already mapped chains. So basically we're "growing" the mapped structure
    by only adding connected stuff.

    The available strategies:

    * **naive**: Enumerates all possible mappings and returns best

    * **greedy_fast**: perform all vs. all single chain lDDTs within the
      respective ref/mdl chem groups. The mapping with highest number of
      conserved contacts is selected as seed for greedy extension

    * **greedy_full**: try multiple seeds for greedy extension, i.e. try
      all ref/mdl chain combinations within the respective chem groups and
      retain the mapping leading to the best lDDT.

    * **greedy_block**: try multiple seeds for greedy extension, i.e. try
      all ref/mdl chain combinations within the respective chem groups and
      extend them to *block_seed_size*. *block_blocks_per_chem_group*
      for each chem group are selected for exhaustive extension.

    Sets :attr:`MappingResult.opt_score` in case of no trivial one-to-one
    mapping.

    :param model: Model to map
    :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
    :param inclusion_radius: Inclusion radius for lDDT
    :type inclusion_radius: :class:`float`
    :param thresholds: Thresholds for lDDT
    :type thresholds: :class:`list` of :class:`float`
    :param strategy: Strategy to find mapping. Must be in ["naive",
                     "greedy_fast", "greedy_full", "greedy_block"]
    :type strategy: :class:`str`
    :param steep_opt_rate: Only relevant for greedy strategies.
                           If set, every *steep_opt_rate* mappings, a simple
                           optimization is executed with the goal of
                           avoiding local minima. The optimization
                           iteratively checks all possible swaps of mappings
                           within their respective chem groups and accepts
                           swaps that improve lDDT score. Iteration stops as
                           soon as no improvement can be achieved anymore.
    :type steep_opt_rate: :class:`int`
    :param block_seed_size: Param for *greedy_block* strategy - Initial seeds
                            are extended by that number of chains.
    :type block_seed_size: :class:`int`
    :param block_blocks_per_chem_group: Param for *greedy_block* strategy -
                                        Number of blocks per chem group that
                                        are extended in an initial search
                                        for high scoring local solutions.
    :type block_blocks_per_chem_group: :class:`int`
    :param chem_mapping_result: Pro param. The result of
                                :func:`~GetChemMapping` where you provided
                                *model*. If set, *model* parameter is not
                                used.
    :type chem_mapping_result: :class:`tuple`
    :returns: A :class:`MappingResult`
    :raises: :class:`RuntimeError` if *strategy* is not one of the available
             strategies
    """

    strategies = ["naive", "greedy_fast", "greedy_full", "greedy_block"]
    if strategy not in strategies:
        raise RuntimeError(f"Strategy must be in {strategies}")

    if chem_mapping_result is None:
        chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
    else:
        chem_mapping, chem_group_alns, mdl = chem_mapping_result

    # NOTE(review): the second argument was missing from this listing and is
    # restored here as the target chem group MSAs - confirm against the
    # signature of _GetRefMdlAlns
    ref_mdl_alns = _GetRefMdlAlns(self.chem_groups,
                                  self.chem_group_alignments,
                                  chem_mapping,
                                  chem_group_alns)

    def _CollectAlns(mdl_mapping):
        # gather the alignment of every mapped ref/mdl chain pair and attach
        # the single chain views of target and model to it
        alns = dict()
        for ref_group, mdl_group in zip(self.chem_groups, mdl_mapping):
            for ref_ch, mdl_ch in zip(ref_group, mdl_group):
                if ref_ch is not None and mdl_ch is not None:
                    aln = ref_mdl_alns[(ref_ch, mdl_ch)]
                    aln.AttachView(0, _CSel(self.target, [ref_ch]))
                    aln.AttachView(1, _CSel(mdl, [mdl_ch]))
                    alns[(ref_ch, mdl_ch)] = aln
        return alns

    # check for the simplest case
    one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping)
    if one_to_one is not None:
        alns = _CollectAlns(one_to_one)
        # trivial mapping, nothing was optimized => no opt_score
        return MappingResult(self.target, mdl, self.chem_groups, chem_mapping,
                             one_to_one, alns)

    mapping = None
    opt_lddt = None

    if strategy == "naive":
        mapping, opt_lddt = _lDDTNaive(self.target, mdl, inclusion_radius,
                                       thresholds, self.chem_groups,
                                       chem_mapping, ref_mdl_alns,
                                       self.n_max_naive)
    else:
        # its one of the greedy strategies - setup greedy searcher
        the_greed = _lDDTGreedySearcher(self.target, mdl, self.chem_groups,
                                        chem_mapping, ref_mdl_alns,
                                        inclusion_radius=inclusion_radius,
                                        thresholds=thresholds,
                                        steep_opt_rate=steep_opt_rate)
        if strategy == "greedy_fast":
            mapping = _lDDTGreedyFast(the_greed)
        elif strategy == "greedy_full":
            mapping = _lDDTGreedyFull(the_greed)
        elif strategy == "greedy_block":
            mapping = _lDDTGreedyBlock(the_greed, block_seed_size,
                                       block_blocks_per_chem_group)
        # cached => lDDT computation is fast here
        opt_lddt = the_greed.lDDT(self.chem_groups, mapping)

    alns = _CollectAlns(mapping)

    return MappingResult(self.target, mdl, self.chem_groups, chem_mapping,
                         mapping, alns, opt_score = opt_lddt)
926 
927 
def GetQSScoreMapping(self, model, contact_d = 12.0, strategy = "naive",
                      block_seed_size = 5, block_blocks_per_chem_group = 5,
                      steep_opt_rate = None, chem_mapping_result = None,
                      greedy_prune_contact_map = False):
    """ Identify chain mapping based on QSScore

    Scoring is based on CA/C3' positions which are present in all chains of
    a :attr:`chem_groups` as well as the *model* chains which are mapped to
    that respective chem group.

    The following strategies are available:

    * **naive**: Naively iterate all possible mappings and return best based
      on QS score.

    * **greedy_fast**: perform all vs. all single chain lDDTs within the
      respective ref/mdl chem groups. The mapping with highest number of
      conserved contacts is selected as seed for greedy extension.
      Extension is based on QS-score.

    * **greedy_full**: try multiple seeds for greedy extension, i.e. try
      all ref/mdl chain combinations within the respective chem groups and
      retain the mapping leading to the best QS-score.

    * **greedy_block**: try multiple seeds for greedy extension, i.e. try
      all ref/mdl chain combinations within the respective chem groups and
      extend them to *block_seed_size*. *block_blocks_per_chem_group*
      for each chem group are selected for exhaustive extension.

    Sets :attr:`MappingResult.opt_score` in case of no trivial one-to-one
    mapping.

    :param model: Model to map
    :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
    :param contact_d: Max distance between two residues to be considered as
                      contact in qs scoring
    :type contact_d: :class:`float`
    :param strategy: Strategy for sampling, must be in ["naive",
                     "greedy_fast", "greedy_full", "greedy_block"]
    :type strategy: :class:`str`
    :param block_seed_size: Param for *greedy_block* strategy - size to
                            which initial seeds are extended, see
                            strategy description above.
    :type block_seed_size: :class:`int`
    :param block_blocks_per_chem_group: Param for *greedy_block* strategy -
                                        number of blocks per chem group
                                        that are selected for exhaustive
                                        extension.
    :type block_blocks_per_chem_group: :class:`int`
    :param steep_opt_rate: Only relevant for greedy strategies. Passed on
                           unchanged to the internal greedy searcher.
    :param chem_mapping_result: Pro param. The result of
                                :func:`~GetChemMapping` where you provided
                                *model*. If set, *model* parameter is not
                                used.
    :type chem_mapping_result: :class:`tuple`
    :param greedy_prune_contact_map: Relevant for all strategies that use
                                     greedy extensions. If True, only chains
                                     with at least 3 contacts (8A CB
                                     distance) towards already mapped chains
                                     in trg/mdl are considered for
                                     extension. All chains that give a
                                     potential non-zero QS-score increase
                                     are used otherwise (at least one
                                     contact within 12A). The consequence
                                     is reduced runtime and usually no
                                     real reduction in accuracy.
    :type greedy_prune_contact_map: :class:`bool`
    :returns: A :class:`MappingResult`
    :raises: :class:`RuntimeError` if *strategy* is not one of the
             supported strategies.
    """

    strategies = ["naive", "greedy_fast", "greedy_full", "greedy_block"]
    if strategy not in strategies:
        raise RuntimeError(f"strategy must be {strategies}")

    if chem_mapping_result is None:
        chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
    else:
        chem_mapping, chem_group_alns, mdl = chem_mapping_result

    # same argument list as used by the other mapping functions of this
    # class: the target chem group alignments are passed as second argument
    ref_mdl_alns = _GetRefMdlAlns(self.chem_groups,
                                  self.chem_group_alignments,
                                  chem_mapping,
                                  chem_group_alns)

    def _CollectAlns(final_mapping):
        # pick the alignment of every mapped ref/mdl chain pair and attach
        # the respective single chain views of target and model
        alns = dict()
        for ref_group, mdl_group in zip(self.chem_groups, final_mapping):
            for ref_ch, mdl_ch in zip(ref_group, mdl_group):
                if ref_ch is not None and mdl_ch is not None:
                    aln = ref_mdl_alns[(ref_ch, mdl_ch)]
                    aln.AttachView(0, _CSel(self.target, [ref_ch]))
                    aln.AttachView(1, _CSel(mdl, [mdl_ch]))
                    alns[(ref_ch, mdl_ch)] = aln
        return alns

    # check for the simplest case - no scoring required, opt_score of the
    # result remains unset
    one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping)
    if one_to_one is not None:
        return MappingResult(self.target, mdl, self.chem_groups,
                             chem_mapping, one_to_one,
                             _CollectAlns(one_to_one))

    mapping = None
    opt_qsscore = None

    if strategy == "naive":
        mapping, opt_qsscore = _QSScoreNaive(self.target, mdl,
                                             self.chem_groups,
                                             chem_mapping, ref_mdl_alns,
                                             contact_d, self.n_max_naive)
    else:
        # its one of the greedy strategies - setup greedy searcher
        the_greed = _QSScoreGreedySearcher(self.target, mdl,
                                           self.chem_groups,
                                           chem_mapping, ref_mdl_alns,
                                           contact_d = contact_d,
                                           steep_opt_rate=steep_opt_rate,
                                           greedy_prune_contact_map = greedy_prune_contact_map)
        if strategy == "greedy_fast":
            mapping = _QSScoreGreedyFast(the_greed)
        elif strategy == "greedy_full":
            mapping = _QSScoreGreedyFull(the_greed)
        elif strategy == "greedy_block":
            mapping = _QSScoreGreedyBlock(the_greed, block_seed_size,
                                          block_blocks_per_chem_group)
        # cached => QSScore computation is fast here
        opt_qsscore = the_greed.Score(mapping, check=False)

    return MappingResult(self.target, mdl, self.chem_groups, chem_mapping,
                         mapping, _CollectAlns(mapping),
                         opt_score = opt_qsscore)
1049 
def GetRigidMapping(self, model, strategy = "greedy_single_gdtts",
                    single_chain_gdtts_thresh=0.4, subsampling=None,
                    first_complete=False, iterative_superposition=False,
                    chem_mapping_result = None):
    """Identify chain mapping based on rigid superposition

    Superposition and scoring is based on CA/C3' positions which are present
    in all chains of a :attr:`chem_groups` as well as the *model*
    chains which are mapped to that respective chem group.

    Transformations to superpose *model* onto :attr:`ChainMapper.target`
    are estimated using all possible combinations of target and model chains
    within the same chem groups and build the basis for further extension.

    There are four extension strategies:

    * **greedy_single_gdtts**: Iteratively add the model/target chain pair
      that adds the most conserved contacts based on the GDT-TS metric
      (Number of CA/C3' atoms within [8, 4, 2, 1] Angstrom). The mapping
      with highest GDT-TS score is returned. However, that mapping is not
      guaranteed to be complete (see *single_chain_gdtts_thresh*).

    * **greedy_iterative_gdtts**: Same as greedy_single_gdtts except that
      the transformation gets updated with each added chain pair.

    * **greedy_single_rmsd**: Conceptually similar to greedy_single_gdtts
      but the added chain pairs are the ones with lowest RMSD.
      The mapping with lowest overall RMSD gets returned.
      *single_chain_gdtts_thresh* is only applied to derive the initial
      transformations. After that, the minimal RMSD chain pair gets
      iteratively added without applying any threshold.

    * **greedy_iterative_rmsd**: Same as greedy_single_rmsd except that
      the transformation gets updated with each added chain pair.
      *single_chain_gdtts_thresh* is only applied to derive the initial
      transformations. After that, the minimal RMSD chain pair gets
      iteratively added without applying any threshold.

    :param model: Model to map
    :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
    :param strategy: Strategy to extend mappings from initial transforms,
                     see description above. Must be in
                     ["greedy_single_gdtts", "greedy_iterative_gdtts",
                     "greedy_single_rmsd", "greedy_iterative_rmsd"]
    :type strategy: :class:`str`
    :param single_chain_gdtts_thresh: Minimal GDT-TS score for model/target
                                      chain pair to be added to mapping.
                                      Mapping extension for a given
                                      transform stops when no pair fulfills
                                      this threshold, potentially leading to
                                      an incomplete mapping.
    :type single_chain_gdtts_thresh: :class:`float`
    :param subsampling: If given, only use an equally distributed subset
                        of all CA/C3' positions for superposition/scoring.
    :type subsampling: :class:`int`
    :param first_complete: Avoid full enumeration and return first found
                           mapping that covers all model chains or all
                           target chains. Only forwarded to the GDT-TS
                           based strategies, i.e. has no effect on
                           greedy_single_rmsd and greedy_iterative_rmsd.
    :type first_complete: :class:`bool`
    :param iterative_superposition: Whether to compute initial
                                    transformations with
                                    :func:`ost.mol.alg.IterativeSuperposeSVD`
                                    as opposed to
                                    :func:`ost.mol.alg.SuperposeSVD`
    :type iterative_superposition: :class:`bool`
    :param chem_mapping_result: Pro param. The result of
                                :func:`~GetChemMapping` where you provided
                                *model*. If set, *model* parameter is not
                                used.
    :type chem_mapping_result: :class:`tuple`
    :returns: A :class:`MappingResult`
    :raises: :class:`RuntimeError` if *strategy* is not one of the
             supported strategies.
    """

    strategies = ["greedy_single_gdtts", "greedy_iterative_gdtts",
                  "greedy_single_rmsd", "greedy_iterative_rmsd"]
    if strategy not in strategies:
        raise RuntimeError(f"strategy must be {strategies}")

    if chem_mapping_result is None:
        chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
    else:
        chem_mapping, chem_group_alns, mdl = chem_mapping_result
    ref_mdl_alns = _GetRefMdlAlns(self.chem_groups,
                                  self.chem_group_alignments,
                                  chem_mapping,
                                  chem_group_alns)

    # check for the simplest case
    one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping)
    if one_to_one is not None:
        alns = dict()
        for ref_group, mdl_group in zip(self.chem_groups, one_to_one):
            for ref_ch, mdl_ch in zip(ref_group, mdl_group):
                if ref_ch is not None and mdl_ch is not None:
                    aln = ref_mdl_alns[(ref_ch, mdl_ch)]
                    aln.AttachView(0, _CSel(self.target, [ref_ch]))
                    aln.AttachView(1, _CSel(mdl, [mdl_ch]))
                    alns[(ref_ch, mdl_ch)] = aln
        return MappingResult(self.target, mdl, self.chem_groups, chem_mapping,
                             one_to_one, alns)

    trg_group_pos, mdl_group_pos = _GetRefPos(self.target, mdl,
                                              self.chem_group_alignments,
                                              chem_group_alns,
                                              max_pos = subsampling)

    # get transforms of any mdl chain onto any trg chain in same chem group
    # that fulfills gdtts threshold
    initial_transforms = list()
    initial_mappings = list()
    for trg_pos, trg_chains, mdl_pos, mdl_chains in zip(trg_group_pos,
                                                        self.chem_groups,
                                                        mdl_group_pos,
                                                        chem_mapping):
        for t_pos, t in zip(trg_pos, trg_chains):
            for m_pos, m in zip(mdl_pos, mdl_chains):
                # chain pairs with less than 3 positions on either side
                # are not used to derive a transformation
                if len(t_pos) >= 3 and len(m_pos) >= 3:
                    transform = _GetTransform(m_pos, t_pos,
                                              iterative_superposition)
                    # score on a transformed copy, m_pos stays untouched
                    t_m_pos = geom.Vec3List(m_pos)
                    t_m_pos.ApplyTransform(transform)
                    gdt = t_pos.GetGDTTS(t_m_pos)
                    if gdt >= single_chain_gdtts_thresh:
                        initial_transforms.append(transform)
                        initial_mappings.append((t,m))

    if strategy == "greedy_single_gdtts":
        mapping = _SingleRigidGDTTS(initial_transforms, initial_mappings,
                                    self.chem_groups, chem_mapping,
                                    trg_group_pos, mdl_group_pos,
                                    single_chain_gdtts_thresh,
                                    iterative_superposition, first_complete,
                                    len(self.target.chains),
                                    len(mdl.chains))

    elif strategy == "greedy_iterative_gdtts":
        mapping = _IterativeRigidGDTTS(initial_transforms, initial_mappings,
                                       self.chem_groups, chem_mapping,
                                       trg_group_pos, mdl_group_pos,
                                       single_chain_gdtts_thresh,
                                       iterative_superposition,
                                       first_complete,
                                       len(self.target.chains),
                                       len(mdl.chains))

    elif strategy == "greedy_single_rmsd":
        mapping = _SingleRigidRMSD(initial_transforms, initial_mappings,
                                   self.chem_groups, chem_mapping,
                                   trg_group_pos, mdl_group_pos,
                                   iterative_superposition)

    elif strategy == "greedy_iterative_rmsd":
        mapping = _IterativeRigidRMSD(initial_transforms, initial_mappings,
                                      self.chem_groups, chem_mapping,
                                      trg_group_pos, mdl_group_pos,
                                      iterative_superposition)

    # translate mapping format and return: *mapping* is keyed by target
    # chain name, the result expects one list per chem group with None for
    # target chains that have no model chain assigned
    final_mapping = list()
    for ref_chains in self.chem_groups:
        mapped_mdl_chains = list()
        for ref_ch in ref_chains:
            if ref_ch in mapping:
                mapped_mdl_chains.append(mapping[ref_ch])
            else:
                mapped_mdl_chains.append(None)
        final_mapping.append(mapped_mdl_chains)

    alns = dict()
    for ref_group, mdl_group in zip(self.chem_groups, final_mapping):
        for ref_ch, mdl_ch in zip(ref_group, mdl_group):
            if ref_ch is not None and mdl_ch is not None:
                aln = ref_mdl_alns[(ref_ch, mdl_ch)]
                aln.AttachView(0, _CSel(self.target, [ref_ch]))
                aln.AttachView(1, _CSel(mdl, [mdl_ch]))
                alns[(ref_ch, mdl_ch)] = aln

    return MappingResult(self.target, mdl, self.chem_groups, chem_mapping,
                         final_mapping, alns)
1230 
def GetMapping(self, model, n_max_naive = 40320):
    """ Convenience function to get mapping with currently preferred method

    As long as the number of possible chain mappings does not exceed
    *n_max_naive*, the QS-score mapping is computed with the naive strategy
    and the optimal QS-score is guaranteed. Everything beyond that is
    mapped with the greedy_full QS-score strategy
    (greedy_prune_contact_map = True). The *n_max_naive* default of 40320
    corresponds to an octamer (8!=40320). A structure with stoichiometry
    A6B2 would be 6!*2!=1440 etc.

    :param model: Model to map
    :param n_max_naive: Max number of possible mappings for which the naive
                        strategy is still used
    :returns: Whatever :func:`GetQSScoreMapping` returns for the selected
              strategy
    """
    # the chem mapping is computed once and handed over to the actual
    # mapping function
    chem_mapping_res = self.GetChemMapping(model)
    enumerable = _NMappingsWithin(self.chem_groups, chem_mapping_res[0],
                                  n_max_naive)
    if enumerable:
        strategy_args = dict(strategy="naive")
    else:
        strategy_args = dict(strategy="greedy_full",
                             greedy_prune_contact_map=True)
    return self.GetQSScoreMapping(model,
                                  chem_mapping_result=chem_mapping_res,
                                  **strategy_args)
1249 
def GetRepr(self, substructure, model, topn=1, inclusion_radius=15.0,
            thresholds=[0.5, 1.0, 2.0, 4.0], bb_only=False,
            only_interchain=False, chem_mapping_result = None,
            global_mapping = None):
    """ Identify *topn* representations of *substructure* in *model*

    *substructure* is a subset of :attr:`~target`. Candidate
    representations in *model* are scored with lDDT and the *topn* best
    scoring ones are returned, best first.

    :param substructure: A :class:`ost.mol.EntityView` which is a subset of
                         :attr:`~target`. Should be selected with the
                         OpenStructure query language. Example: if you're
                         interested in residues with number 42,43 and 85 in
                         chain A:
                         ``substructure=mapper.target.Select("cname=A and rnum=42,43,85")``
                         A :class:`RuntimeError` is raised if *substructure*
                         does not refer to the same underlying
                         :class:`ost.mol.EntityHandle` as :attr:`~target`.
    :type substructure: :class:`ost.mol.EntityView`
    :param model: Structure in which one wants to find representations for
                  *substructure*
    :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
    :param topn: Max number of representations that are returned, must be
                 >= 1
    :type topn: :class:`int`
    :param inclusion_radius: Inclusion radius for lDDT
    :type inclusion_radius: :class:`float`
    :param thresholds: Thresholds for lDDT
    :type thresholds: :class:`list` of :class:`float`
    :param bb_only: Only consider backbone atoms in lDDT computation
    :type bb_only: :class:`bool`
    :param only_interchain: Only score interchain contacts in lDDT. Useful
                            if you want to identify interface patches.
    :type only_interchain: :class:`bool`
    :param chem_mapping_result: Pro param. The result of
                                :func:`~GetChemMapping` where you provided
                                *model*. If set, *model* parameter is not
                                used.
    :type chem_mapping_result: :class:`tuple`
    :param global_mapping: Pro param. Specify a global mapping result. This
                           fully defines the desired representation in the
                           model but extracts it and enriches it with all
                           the nice attributes of :class:`ReprResult`.
                           The target attribute in *global_mapping* must be
                           of the same entity as self.target and the model
                           attribute of *global_mapping* must be of the same
                           entity as *model*.
    :type global_mapping: :class:`MappingResult`
    :returns: :class:`list` of :class:`ReprResult`
    """

    if topn < 1:
        raise RuntimeError("topn must be >= 1")

    # a user provided global mapping must stem from the very same entities
    if global_mapping is not None:
        if global_mapping.target.handle.GetHashCode() != \
           self.target.handle.GetHashCode():
            raise RuntimeError("global_mapping.target must be the same "
                               "entity as self.target")
        if global_mapping.model.handle.GetHashCode() != \
           model.handle.GetHashCode():
            raise RuntimeError("global_mapping.model must be the same "
                               "entity as model param")

    # every residue/atom of substructure must exist in self.target and
    # refer to the same underlying handle
    for r in substructure.residues:
        target_r = self.target.FindResidue(r.GetChain().GetName(),
                                           r.GetNumber())
        if not target_r.IsValid():
            raise RuntimeError(f"substructure has residue "
                               f"{r.GetQualifiedName()} which is not in "
                               f"self.target")
        if target_r.handle.GetHashCode() != r.handle.GetHashCode():
            raise RuntimeError(f"substructure has residue "
                               f"{r.GetQualifiedName()} which has an "
                               f"equivalent in self.target but it does "
                               f"not refer to the same underlying "
                               f"EntityHandle")
        for a in r.atoms:
            target_a = target_r.FindAtom(a.GetName())
            if not target_a.IsValid():
                raise RuntimeError(f"substructure has atom "
                                   f"{a.GetQualifiedName()} which is not "
                                   f"in self.target")
            if a.handle.GetHashCode() != target_a.handle.GetHashCode():
                raise RuntimeError(f"substructure has atom "
                                   f"{a.GetQualifiedName()} which has an "
                                   f"equivalent in self.target but it does "
                                   f"not refer to the same underlying "
                                   f"EntityHandle")

        # each residue needs one of the two backbone atoms, CA or C3'
        ca = r.FindAtom("CA")
        c3 = r.FindAtom("C3'")
        if not (ca.IsValid() or c3.IsValid()):
            raise RuntimeError("All residues in substructure must contain "
                               "a backbone atom named CA or C3\'")

    # mapping and alignments are derived from the full structures
    if chem_mapping_result is None:
        chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
    else:
        chem_mapping, chem_group_alns, mdl = chem_mapping_result
    ref_mdl_alns = _GetRefMdlAlns(self.chem_groups,
                                  self.chem_group_alignments,
                                  chem_mapping,
                                  chem_group_alns)

    # per substructure chain: residue indices relative to the respective
    # full target chain
    res_idx_by_chain = dict()
    for sub_ch in substructure.chains:
        trg_ch = self.target.FindChain(sub_ch.GetName())
        res_idx_by_chain[sub_ch.GetName()] = \
            [trg_ch.GetResidueIndex(r.GetNumber()) for r in sub_ch.residues]

    # restrict chem groups (and the model chains mapped to them) to the
    # groups that have at least one chain in substructure
    sub_chain_names = set([ch.GetName() for ch in substructure.chains])
    sub_chem_groups = list()
    sub_chem_mapping = list()
    for group, mdl_chains in zip(self.chem_groups, chem_mapping):
        kept = [cname for cname in group if cname in sub_chain_names]
        if kept:
            sub_chem_groups.append(kept)
            sub_chem_mapping.append(mdl_chains)

    # nothing to do if not a single mdl chain can be mapped to substructure
    if sum(len(m) for m in sub_chem_mapping) == 0:
        return list()

    # build substructure specific alignments: the reference sequence only
    # keeps the one letter codes of residues that are part of substructure
    mdl_views = dict()
    for mdl_chain in mdl.chains:
        mdl_views[mdl_chain.GetName()] = _CSel(mdl, [mdl_chain.GetName()])

    sub_alns = dict()
    for group, mdl_chains in zip(sub_chem_groups, sub_chem_mapping):
        for ref_ch in group:
            for mdl_ch in mdl_chains:
                full_aln = ref_mdl_alns[(ref_ch, mdl_ch)]
                full_ref_seq = full_aln.GetSequence(0)
                # start from an all-gap sequence and fill in the columns
                # that belong to substructure residues
                olcs = ['-'] * len(full_aln)
                for res_idx in res_idx_by_chain[ref_ch]:
                    col_idx = full_ref_seq.GetPos(res_idx)
                    olcs[col_idx] = full_ref_seq[col_idx]
                sub_ref_seq = seq.CreateSequence(ref_ch, ''.join(olcs))
                sub_ref_seq.AttachView(_CSel(substructure, [ref_ch]))
                full_mdl_seq = full_aln.GetSequence(1)
                sub_mdl_seq = seq.CreateSequence(full_mdl_seq.GetName(),
                                                 full_mdl_seq.GetString())
                sub_mdl_seq.AttachView(mdl_views[mdl_ch])
                sub_aln = seq.CreateAlignment()
                sub_aln.AddSequence(sub_ref_seq)
                sub_aln.AddSequence(sub_mdl_seq)
                sub_alns[(ref_ch, mdl_ch)] = sub_aln

    lddt_scorer = lddt.lDDTScorer(substructure,
                                  inclusion_radius = inclusion_radius,
                                  bb_only = bb_only)

    if global_mapping:
        # the only candidate is the global mapping restricted to the chains
        # of substructure
        flat_mapping = global_mapping.GetFlatMapping()
        restricted = list()
        for group, allowed_mdl_chains in zip(sub_chem_groups,
                                             sub_chem_mapping):
            group_mapping = list()
            for ref_ch in group:
                mdl_ch = None
                if ref_ch in flat_mapping and \
                   flat_mapping[ref_ch] in allowed_mdl_chains:
                    mdl_ch = flat_mapping[ref_ch]
                group_mapping.append(mdl_ch)
            restricted.append(group_mapping)
        mappings = [restricted]
    else:
        mappings = list(_ChainMappings(sub_chem_groups,
                                       sub_chem_mapping,
                                       self.n_max_naive))

    # (lDDT, mapping) tuples - sorted by decreasing lDDT, at most topn items
    scored_mappings = list()
    for mapping in mappings:
        # chain_mapping and alns as input for lDDT computation
        lddt_chain_mapping = dict()
        lddt_alns = dict()
        n_res_aln = 0
        for ref_group, mdl_group in zip(sub_chem_groups, mapping):
            for ref_ch, mdl_ch in zip(ref_group, mdl_group):
                # some mdl chains can be None
                if mdl_ch is None:
                    continue
                lddt_chain_mapping[mdl_ch] = ref_ch
                aln = sub_alns[(ref_ch, mdl_ch)]
                lddt_alns[mdl_ch] = aln
                n_res_aln += sum(1 for c in aln
                                 if c[0] != '-' and c[1] != '-')

        # no lDDT computation without a single aligned residue pair
        if n_res_aln == 0:
            continue

        score, _ = lddt_scorer.lDDT(mdl, thresholds=thresholds,
                                    chain_mapping=lddt_chain_mapping,
                                    residue_mapping = lddt_alns,
                                    check_resnames = False,
                                    no_intrachain = only_interchain)

        if score is None:
            # not a single valid contact - treat as worst possible score
            # so that the ranking below still works
            ost.LogVerbose("No valid contacts in the reference")
            score = 0.0

        # insert if there is room left or if better than the current worst
        has_room = len(scored_mappings) < topn
        if has_room or score > scored_mappings[-1][0]:
            scored_mappings.append((score, mapping))
            scored_mappings.sort(reverse=True, key=lambda x: x[0])
            if not has_room:
                scored_mappings = scored_mappings[:topn]

    # finalize and return
    results = list()
    for score, mapping in scored_mappings:
        ref_view = substructure.handle.CreateEmptyView()
        mdl_view = mdl.handle.CreateEmptyView()
        for ref_group, mdl_group in zip(sub_chem_groups, mapping):
            for ref_ch, mdl_ch in zip(ref_group, mdl_group):
                if ref_ch is None or mdl_ch is None:
                    continue
                # only aligned residue pairs end up in the views
                for col in sub_alns[(ref_ch, mdl_ch)]:
                    if col[0] != '-' and col[1] != '-':
                        ref_view.AddResidue(col.GetResidue(0),
                                            mol.ViewAddFlag.INCLUDE_ALL)
                        mdl_view.AddResidue(col.GetResidue(1),
                                            mol.ViewAddFlag.INCLUDE_ALL)
        results.append(ReprResult(score, substructure, ref_view, mdl_view))
    return results
1503 
def GetNMappings(self, model):
    """ Returns number of possible mappings

    :param model: Model with chains that are mapped onto
                  :attr:`chem_groups`
    :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
    :returns: Result of the internal mapping counter for
              :attr:`chem_groups` and the chem mapping of *model*
    """
    # only the chem mapping itself is needed, alignments and processed
    # model are discarded
    chem_mapping, _, _ = self.GetChemMapping(model)
    return _NMappings(self.chem_groups, chem_mapping)
1513 
def ProcessStructure(self, ent):
    """ Entity processing for chain mapping

    * Selects view containing peptide and nucleotide residues which have
      required backbone atoms present - for peptide residues thats
      N, CA, C and CB (no CB for GLY), for nucleotide residues thats
      O5', C5', C4', C3' and O3'.
    * filters view by chain lengths, see *min_pep_length* and
      *min_nuc_length* in constructor
    * Extracts atom sequences for each chain in that view
    * Attaches corresponding :class:`ost.mol.EntityView` to each sequence
    * If residue number alignments are used, strictly increasing residue
      numbers without insertion codes are ensured in each chain

    :param ent: Entity to process
    :type ent: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
    :returns: Tuple with 3 elements: 1) :class:`ost.mol.EntityView`
              containing peptide and nucleotide residues 2)
              :class:`ost.seq.SequenceList` containing ATOMSEQ sequences
              for each polypeptide chain in returned view, sequences have
              :class:`ost.mol.EntityView` of according chains attached
              3) same for polynucleotide chains
    :raises: :class:`RuntimeError` if a chain mixes peptide and nucleotide
             residues, contains residues that are neither, violates the
             residue number restrictions (only checked if residue number
             alignments are enabled) or if no chain fulfills the minimum
             length requirement.
    """
    view = ent.CreateEmptyView()
    exp_pep_atoms = ["N", "CA", "C", "CB"]
    exp_nuc_atoms = ["\"O5'\"", "\"C5'\"", "\"C4'\"", "\"C3'\"", "\"O3'\""]
    pep_query = "peptide=true and aname=" + ','.join(exp_pep_atoms)
    nuc_query = "nucleotide=true and aname=" + ','.join(exp_nuc_atoms)

    # peptide residues: all four expected atoms, or N/CA/C only for GLY
    for r in ent.Select(pep_query).residues:
        n_atoms = len(r.atoms)
        if n_atoms == 4:
            view.AddResidue(r.handle, mol.INCLUDE_ALL)
        elif r.name == "GLY" and n_atoms == 3:
            if sorted(a.GetName() for a in r.atoms) == ["C", "CA", "N"]:
                view.AddResidue(r.handle, mol.INCLUDE_ALL)

    # nucleotide residues: all five expected atoms
    for r in ent.Select(nuc_query).residues:
        if len(r.atoms) == 5:
            view.AddResidue(r.handle, mol.INCLUDE_ALL)

    polypep_seqs = seq.CreateSequenceList()
    polynuc_seqs = seq.CreateSequenceList()

    if len(view.residues) == 0:
        # no residues survived => return
        return (view, polypep_seqs, polynuc_seqs)

    for ch in view.chains:
        residues = ch.residues
        n_res = len(residues)
        n_pep = sum(r.IsPeptideLinking() for r in residues)
        n_nuc = sum(r.IsNucleotideLinking() for r in residues)

        # a chain is either peptide or nucleotide, never a mix of the two
        if n_pep > 0 and n_nuc > 0:
            raise RuntimeError(f"Must not mix peptide and nucleotide linking "
                               f"residues in same chain ({ch.GetName()})")

        if (n_pep + n_nuc) != n_res:
            raise RuntimeError("All residues must either be peptide_linking "
                               "or nucleotide_linking")

        # short chains are silently dropped
        if 0 < n_pep < self.min_pep_length:
            continue

        if 0 < n_nuc < self.min_nuc_length:
            continue

        # the superfast residue number based alignment adds some
        # restrictions on the numbers themselves:
        # 1) no insertion codes 2) strictly increasing
        if self.resnum_alignments:
            ins_codes = [r.GetNumber().GetInsCode() for r in residues]
            if not (len(set(ins_codes)) == 1 and ins_codes[0] == '\0'):
                raise RuntimeError("Residue numbers in input structures must not "
                                   "contain insertion codes")

            nums = [r.GetNumber().GetNum() for r in residues]
            if any(i >= j for i, j in zip(nums, nums[1:])):
                raise RuntimeError("Residue numbers in input structures must be "
                                   "strictly increasing for each chain")

        # ATOMSEQ of the chain with the respective single chain view attached
        olcs = ''.join([r.one_letter_code for r in residues])
        chain_seq = seq.CreateSequence(ch.GetName(), olcs)
        chain_seq.AttachView(_CSel(view, [ch.GetName()]))
        if n_pep == n_res:
            polypep_seqs.AddSequence(chain_seq)
        elif n_nuc == n_res:
            polynuc_seqs.AddSequence(chain_seq)
        else:
            raise RuntimeError("This shouldnt happen")

    if len(polypep_seqs) == 0 and len(polynuc_seqs) == 0:
        raise RuntimeError(f"No chain fulfilled minimum length requirement "
                           f"to be considered in chain mapping "
                           f"({self.min_pep_length} for peptide chains, "
                           f"{self.min_nuc_length} for nucleotide chains) "
                           f"- mapping failed")

    # select for chains for which we actually extracted the sequence,
    # peptide chains first, nucleotide chains second
    chain_names = list()
    for seq_list in (polypep_seqs, polynuc_seqs):
        chain_names.extend(s.GetAttachedView().chains[0].name
                           for s in seq_list)
    view = _CSel(view, chain_names)

    return (view, polypep_seqs, polynuc_seqs)
1624 
def Align(self, s1, s2, stype):
    """ Access to internal sequence alignment functionality

    Alignment parameterization is setup at ChainMapper construction

    :param s1: First sequence to align - must have view attached in case
               of resnum_alignments
    :type s1: :class:`ost.seq.SequenceHandle`
    :param s2: Second sequence to align - must have view attached in case
               of resnum_alignments
    :type s2: :class:`ost.seq.SequenceHandle`
    :param stype: Type of sequences to align, must be in
                  [:class:`ost.mol.ChemType.AMINOACIDS`,
                  :class:`ost.mol.ChemType.NUCLEOTIDES`]
    :returns: Pairwise alignment of s1 and s2
    :raises: :class:`RuntimeError` if *stype* is none of the two supported
             types
    """
    supported = (mol.ChemType.AMINOACIDS, mol.ChemType.NUCLEOTIDES)
    if stype in supported:
        # the actual work is delegated to the internal aligner
        return self.aligner.Align(s1, s2, chem_type = stype)
    raise RuntimeError("stype must be ost.mol.ChemType.AMINOACIDS or "
                       "ost.mol.ChemType.NUCLEOTIDES")
1645 
1646 
1647 # INTERNAL HELPERS
1648 ##################
1649 class _Aligner:
def __init__(self, pep_subst_mat = seq.alg.BLOSUM62, pep_gap_open = -5,
             pep_gap_ext = -2, nuc_subst_mat = seq.alg.NUC44,
             nuc_gap_open = -4, nuc_gap_ext = -4, resnum_aln = False):
    """ Helper class to compute alignments

    Stores substitution matrix, gap open and gap extension penalties for
    peptide and nucleotide sequences. They are only used in default mode
    (Needleman-Wunsch aln). If *resnum_aln* is True, only residue numbers
    of views that are attached to input sequences are considered.
    """
    # alignment mode switch
    self.resnum_aln = resnum_aln
    # parameterization for peptide sequences
    self.pep_subst_mat, self.pep_gap_open, self.pep_gap_ext = \
        pep_subst_mat, pep_gap_open, pep_gap_ext
    # parameterization for nucleotide sequences
    self.nuc_subst_mat, self.nuc_gap_open, self.nuc_gap_ext = \
        nuc_subst_mat, nuc_gap_open, nuc_gap_ext
1667 
1668  def Align(self, s1, s2, chem_type=None):
1669  if self.resnum_aln:
1670  return self.ResNumAlign(s1, s2)
1671  else:
1672  if chem_type is None:
1673  raise RuntimeError("Must specify chem_type for NW alignment")
1674  return self.NWAlign(s1, s2, chem_type)
1675 
1676  def NWAlign(self, s1, s2, chem_type):
1677  """ Returns pairwise alignment using Needleman-Wunsch algorithm
1678 
1679  :param s1: First sequence to align
1680  :type s1: :class:`ost.seq.SequenceHandle`
1681  :param s2: Second sequence to align
1682  :type s2: :class:`ost.seq.SequenceHandle`
1683  :param chem_type: Must be in [:class:`ost.mol.ChemType.AMINOACIDS`,
1684  :class:`ost.mol.ChemType.NUCLEOTIDES`], determines
1685  substitution matrix and gap open/extension penalties
1686  :type chem_type: :class:`ost.mol.ChemType`
1687  :returns: Alignment with s1 as first and s2 as second sequence
1688  """
1689  if chem_type == mol.ChemType.AMINOACIDS:
1690  return seq.alg.GlobalAlign(s1, s2, self.pep_subst_mat,
1691  gap_open=self.pep_gap_open,
1692  gap_ext=self.pep_gap_ext)[0]
1693  elif chem_type == mol.ChemType.NUCLEOTIDES:
1694  return seq.alg.GlobalAlign(s1, s2, self.nuc_subst_mat,
1695  gap_open=self.nuc_gap_open,
1696  gap_ext=self.nuc_gap_ext)[0]
1697  else:
1698  raise RuntimeError("Invalid ChemType")
1699  return aln
1700 
1701  def ResNumAlign(self, s1, s2):
1702  """ Returns pairwise alignment using residue numbers of attached views
1703 
1704  Assumes that there are no insertion codes (alignment only on numerical
1705  component) and that resnums are strictly increasing (fast min/max
1706  identification). These requirements are assured if a structure has been
1707  processed by :class:`ChainMapper`.
1708 
1709  :param s1: First sequence to align, must have :class:`ost.mol.EntityView`
1710  attached
1711  :type s1: :class:`ost.seq.SequenceHandle`
1712  :param s2: Second sequence to align, must have :class:`ost.mol.EntityView`
1713  attached
1714  :type s2: :class:`ost.seq.SequenceHandle`
1715  """
1716  assert(s1.HasAttachedView())
1717  assert(s2.HasAttachedView())
1718  v1 = s1.GetAttachedView()
1719  rnums1 = [r.GetNumber().GetNum() for r in v1.residues]
1720  v2 = s2.GetAttachedView()
1721  rnums2 = [r.GetNumber().GetNum() for r in v2.residues]
1722 
1723  min_num = min(rnums1[0], rnums2[0])
1724  max_num = max(rnums1[-1], rnums2[-1])
1725  aln_length = max_num - min_num + 1
1726 
1727  aln_s1 = ['-'] * aln_length
1728  for r, rnum in zip(v1.residues, rnums1):
1729  aln_s1[rnum-min_num] = r.one_letter_code
1730 
1731  aln_s2 = ['-'] * aln_length
1732  for r, rnum in zip(v2.residues, rnums2):
1733  aln_s2[rnum-min_num] = r.one_letter_code
1734 
1735  aln = seq.CreateAlignment()
1736  aln.AddSequence(seq.CreateSequence(s1.GetName(), ''.join(aln_s1)))
1737  aln.AddSequence(seq.CreateSequence(s2.GetName(), ''.join(aln_s2)))
1738  return aln
1739 
def _GetAlnPropsTwo(aln):
    """Returns basic properties of *aln* version two...

    :param aln: Alignment to compute properties
    :type aln: :class:`seq.AlignmentHandle`
    :returns: Tuple with 2 elements. 1) sequence identity as returned by
              seq.alg.SequenceIdentity 2) Fraction of non-gap characters
              in first sequence that are covered by non-gap characters in
              second sequence.
    """
    assert(aln.GetCount() == 2)
    n_ref = 0      # columns with a non-gap character in first sequence
    n_covered = 0  # subset of those with non-gap in second sequence too
    for col in aln:
        if col[0] != '-':
            n_ref += 1
            if col[1] != '-':
                n_covered += 1
    return (seq.alg.SequenceIdentity(aln), float(n_covered)/n_ref)
1754 
def _GetAlnPropsOne(aln):
    """Returns basic properties of *aln* version one...

    :param aln: Alignment to compute properties
    :type aln: :class:`seq.AlignmentHandle`
    :returns: Tuple with 3 elements. 1) sequence identity as returned by
              seq.alg.SequenceIdentity 2) Number of gaps in s1 after
              stripping terminal gaps, divided by the gapless length of s1
              3) same for s2.
    """
    assert(aln.GetCount() == 2)
    gap_fractions = list()
    for s_idx in (0, 1):
        s = aln.GetSequence(s_idx)
        # terminal gaps are stripped, only internal gaps are counted
        n_internal_gaps = str(s).strip('-').count('-')
        n_residues = len(s.GetGaplessString())
        gap_fractions.append(float(n_internal_gaps)/n_residues)
    return (seq.alg.SequenceIdentity(aln), gap_fractions[0], gap_fractions[1])
1771 
def _GetChemGroupAlignments(pep_seqs, nuc_seqs, aligner, pep_seqid_thr=95.,
                            pep_gap_thr=0.1, nuc_seqid_thr=95.,
                            nuc_gap_thr=0.1):
    """Returns alignments with groups of chemically equivalent chains

    Peptide and nucleotide sequences are grouped independently, peptide
    groups come first in the returned list.

    :param pep_seqs: List of polypeptide sequences
    :type pep_seqs: :class:`seq.SequenceList`
    :param nuc_seqs: List of polynucleotide sequences
    :type nuc_seqs: :class:`seq.SequenceList`
    :param aligner: Helper class to generate pairwise alignments
    :type aligner: :class:`_Aligner`
    :param pep_seqid_thr: Threshold used to decide when two peptide chains are
                          identical. 95 percent tolerates the few mutations
                          crystallographers like to do.
    :type pep_seqid_thr: :class:`float`
    :param pep_gap_thr: Additional threshold to avoid gappy alignments with high
                        seqid. The reason for not just normalizing seqid by the
                        longer sequence is that one sequence might be a perfect
                        subsequence of the other but only cover half of it. This
                        threshold checks for a maximum allowed fraction of gaps
                        in any of the two sequences after stripping terminal gaps.
    :type pep_gap_thr: :class:`float`
    :param nuc_seqid_thr: Nucleotide equivalent of *pep_seqid_thr*
    :type nuc_seqid_thr: :class:`float`
    :param nuc_gap_thr: Nucleotide equivalent of *pep_gap_thr*
    :type nuc_gap_thr: :class:`float`
    :returns: Tuple with first element being an AlignmentList. Each alignment
              represents a group of chemically equivalent chains and the first
              sequence is the longest. Second element is a list of equivalent
              length specifying the types of the groups. List elements are in
              [:class:`ost.ChemType.AMINOACIDS`,
              :class:`ost.ChemType.NUCLEOTIDES`]
    """
    pep_type = mol.ChemType.AMINOACIDS
    nuc_type = mol.ChemType.NUCLEOTIDES

    groups = _GroupSequences(pep_seqs, pep_seqid_thr, pep_gap_thr, aligner,
                             pep_type)
    n_pep_groups = len(groups)
    nuc_groups = _GroupSequences(nuc_seqs, nuc_seqid_thr, nuc_gap_thr, aligner,
                                 nuc_type)

    group_types = [pep_type] * n_pep_groups + [nuc_type] * len(nuc_groups)
    # peptide groups first, nucleotide groups appended
    groups.extend(nuc_groups)
    return (groups, group_types)
1814 
def _GroupSequences(seqs, seqid_thr, gap_thr, aligner, chem_type):
    """Get list of alignments representing groups of equivalent sequences

    Sequences are processed in input order. A sequence joins the first
    existing group that contains a member it is equivalent to, a new group is
    opened otherwise.

    :param seqs: Sequences to group
    :param seqid_thr: Threshold used to decide when two chains are identical.
    :type seqid_thr:  :class:`float`
    :param gap_thr: Additional threshold to avoid gappy alignments with high
                    seqid. The reason for not just normalizing seqid by the
                    longer sequence is that one sequence might be a perfect
                    subsequence of the other but only cover half of it. This
                    threshold checks for a maximum allowed fraction of gaps
                    in any of the two sequences after stripping terminal gaps.
    :type gap_thr: :class:`float`
    :param aligner: Helper class to generate pairwise alignments
    :type aligner: :class:`_Aligner`
    :param chem_type: ChemType of seqs which is passed to *aligner*, must be in
                      [:class:`ost.mol.ChemType.AMINOACIDS`,
                      :class:`ost.mol.ChemType.NUCLEOTIDES`]
    :type chem_type: :class:`ost.mol.ChemType`
    :returns: A list of alignments, one alignment for each group
              with longest sequence (reference) as first sequence.
    :rtype: :class:`ost.seq.AlignmentList`
    """

    def _AreEquivalent(idx_a, idx_b):
        # two sequences are equivalent if seqid is high enough and none of
        # them has too many internal gaps
        aln = aligner.Align(seqs[idx_a], seqs[idx_b], chem_type)
        sid, gap_frac_a, gap_frac_b = _GetAlnPropsOne(aln)
        return sid >= seqid_thr and gap_frac_a < gap_thr and \
               gap_frac_b < gap_thr

    # groups are represented as lists of indices into seqs
    idx_groups = list()
    for s_idx in range(len(seqs)):
        for members in idx_groups:
            if any(_AreEquivalent(s_idx, m_idx) for m_idx in members):
                members.append(s_idx)
                break
        else:
            # no equivalent sequence in any group => open a new one
            idx_groups.append([s_idx])

    # longest sequence first, ties are resolved by larger index first
    idx_groups = [sorted(g, key=lambda i: (len(seqs[i]), i), reverse=True)
                  for g in idx_groups]

    # translate from indices back to sequences and directly generate alignments
    # of the groups with the longest (first) sequence as reference
    aln_list = seq.AlignmentList()
    for g in idx_groups:
        ref_idx = g[0]
        if len(g) == 1:
            # aln with one single sequence
            aln_list.append(seq.CreateAlignment(seqs[ref_idx]))
            continue
        # pairwise alns of reference to all others that get merged
        pairwise_alns = seq.AlignmentList()
        for other_idx in g[1:]:
            pairwise_alns.append(aligner.Align(seqs[ref_idx], seqs[other_idx],
                                               chem_type))
        aln_list.append(seq.alg.MergePairwiseAlignments(pairwise_alns,
                                                        seqs[ref_idx]))

    # transfer attached views, sequences are identified by name
    seqs_by_name = {s.GetName(): s for s in seqs}
    for aln_idx in range(len(aln_list)):
        aln = aln_list[aln_idx]
        for aln_s_idx in range(aln.GetCount()):
            source_seq = seqs_by_name[aln.GetSequence(aln_s_idx).GetName()]
            aln.AttachView(aln_s_idx, source_seq.GetAttachedView())

    return aln_list
1890 
def _MapSequence(ref_seqs, ref_types, s, s_type, aligner):
    """Tries to map *s* onto any of the sequences in *ref_seqs*

    Computes alignments of *s* to each of the reference sequences of equal type
    and scores them by seqid*fraction_covered (seqid: sequence identity of
    aligned columns in alignment, fraction_covered: Fraction of non-gap
    characters in reference sequence that are covered by non-gap characters in
    *s*). Best scoring mapping is returned. If several reference sequences
    reach the same best score, the one that comes first in *ref_seqs* wins.

    :param ref_seqs: Reference sequences
    :type ref_seqs: :class:`ost.seq.SequenceList`
    :param ref_types: Types of reference sequences, e.g.
                      ost.mol.ChemType.AminoAcids
    :type ref_types: :class:`list` of :class:`ost.mol.ChemType`
    :param s: Sequence to map
    :type s: :class:`ost.seq.SequenceHandle`
    :param s_type: Type of *s*, only try mapping to sequences in *ref_seqs*
                   with equal type as defined in *ref_types*
    :param aligner: Helper class to generate pairwise alignments
    :type aligner: :class:`_Aligner`
    :returns: Tuple with two elements. 1) index of sequence in *ref_seqs* to
              which *s* can be mapped 2) Pairwise sequence alignment with
              sequence from *ref_seqs* as first sequence. Both elements are
              None if no mapping can be found, i.e. no sequence in *ref_seqs*
              is of type *s_type*.
    """
    best = None  # tuple (score, ref_idx, aln) of best candidate so far
    for ref_idx, ref_seq in enumerate(ref_seqs):
        if ref_types[ref_idx] != s_type:
            continue
        aln = aligner.Align(ref_seq, s, s_type)
        seqid, fraction_covered = _GetAlnPropsTwo(aln)
        score = seqid * fraction_covered
        # strict comparison: earlier candidates win ties
        if best is None or score > best[0]:
            best = (score, ref_idx, aln)

    if best is None:
        return (None, None) # no mapping possible...

    return (best[1], best[2])
1931 
def _GetRefMdlAlns(ref_chem_groups, ref_chem_group_msas, mdl_chem_groups,
                   mdl_chem_group_alns, pairs=None):
    """ Get all possible ref/mdl chain alignments given chem group mapping

    :param ref_chem_groups: :attr:`ChainMapper.chem_groups`
    :type ref_chem_groups: :class:`list` of :class:`list` of :class:`str`
    :param ref_chem_group_msas: :attr:`ChainMapper.chem_group_alignments`
    :type ref_chem_group_msas: :class:`ost.seq.AlignmentList`
    :param mdl_chem_groups: Groups of model chains that are mapped to
                            *ref_chem_groups*. Return value of
                            :func:`ChainMapper.GetChemMapping`.
    :type mdl_chem_groups: :class:`list` of :class:`list` of :class:`str`
    :param mdl_chem_group_alns: A pairwise sequence alignment for every chain
                                in *mdl_chem_groups* that aligns these sequences
                                to the respective reference sequence.
                                Return values of
                                :func:`ChainMapper.GetChemMapping`.
    :type mdl_chem_group_alns: :class:`list` of :class:`ost.seq.AlignmentList`
    :param pairs: Pro param - restrict return dict to specified pairs. A set of
                  tuples in form (<trg_ch>, <mdl_ch>)
    :type pairs: :class:`set`
    :returns: A dictionary holding all possible ref/mdl chain alignments. Keys
              in that dictionary are tuples of the form (ref_ch, mdl_ch) and
              values are the respective pairwise alignments with first sequence
              being from ref, the second from mdl.
    """
    # alignment of each model chain to chem_group reference sequence,
    # the model chain is the second sequence in these alignments
    mdl_alns = {aln.GetSequence(1).GetName(): aln
                for alns in mdl_chem_group_alns for aln in alns}

    # generate all alignments between ref/mdl chain atomseqs that we will
    # ever observe
    ref_mdl_alns = dict()
    for ref_chains, mdl_chains, ref_aln in zip(ref_chem_groups, mdl_chem_groups,
                                               ref_chem_group_msas):
        for ref_ch, mdl_ch in itertools.product(ref_chains, mdl_chains):
            if pairs is not None and (ref_ch, mdl_ch) not in pairs:
                continue

            # obtain alignments of mdl and ref chains towards chem
            # group ref sequence and merge them
            group_ref = ref_aln.GetSequence(0)
            ref_ch_seq = ref_aln.GetSequence(ref_chains.index(ref_ch))
            to_merge = seq.AlignmentList()
            to_merge.append(seq.CreateAlignment(group_ref, ref_ch_seq))
            to_merge.append(mdl_alns[mdl_ch])
            gapless_group_ref = seq.CreateSequence(group_ref.GetName(),
                                                   group_ref.GetGaplessString())
            merged_aln = seq.alg.MergePairwiseAlignments(to_merge,
                                                         gapless_group_ref)

            # merged_aln:
            # seq1: ref seq of chem group
            # seq2: seq of ref chain
            # seq3: seq of mdl chain
            # => we need the alignment between seq2 and seq3
            s2 = merged_aln.GetSequence(1)
            s3 = merged_aln.GetSequence(2)
            n_cols = len(s2)

            # count leading and trailing columns which are gaps in both
            n_lead = 0
            while n_lead < n_cols and s2[n_lead] == '-' and \
                  s3[n_lead] == '-':
                n_lead += 1
            n_trail = 0
            while n_trail < n_cols and s2[n_cols-1-n_trail] == '-' and \
                  s3[n_cols-1-n_trail] == '-':
                n_trail += 1

            # and cut them
            cut_s2 = seq.CreateSequence(s2.GetName(),
                                        s2[n_lead: len(s2)-n_trail])
            cut_s3 = seq.CreateSequence(s3.GetName(),
                                        s3[n_lead: len(s3)-n_trail])
            ref_mdl_alns[(ref_ch, mdl_ch)] = seq.CreateAlignment(cut_s2, cut_s3)

    return ref_mdl_alns
2011 
2012 def _CheckOneToOneMapping(ref_chains, mdl_chains):
2013  """ Checks whether we already have a perfect one to one mapping
2014 
2015  That means each list in *ref_chains* has exactly one element and each
2016  list in *mdl_chains* has either one element (it's mapped) or is empty
2017  (ref chain has no mapped mdl chain). Returns None if no such mapping
2018  can be found.
2019 
2020  :param ref_chains: corresponds to :attr:`ChainMapper.chem_groups`
2021  :type ref_chains: :class:`list` of :class:`list` of :class:`str`
2022  :param mdl_chains: mdl chains mapped to chem groups in *ref_chains*, i.e.
2023  the return value of :func:`ChainMapper.GetChemMapping`
2024  :type mdl_chains: class:`list` of :class:`list` of :class:`str`
2025  :returns: A :class:`list` of :class:`list` if a one to one mapping is found,
2026  None otherwise
2027  """
2028  only_one_to_one = True
2029  one_to_one = list()
2030  for ref, mdl in zip(ref_chains, mdl_chains):
2031  if len(ref) == 1 and len(mdl) == 1:
2032  one_to_one.append(mdl)
2033  elif len(ref) == 1 and len(mdl) == 0:
2034  one_to_one.append([None])
2035  else:
2036  only_one_to_one = False
2037  break
2038  if only_one_to_one:
2039  return one_to_one
2040  else:
2041  return None
2042 
2044 
2045  def __init__(self, ref, mdl, ref_mdl_alns, inclusion_radius = 15.0,
2046  thresholds = [0.5, 1.0, 2.0, 4.0]):
2047  """ Compute backbone only lDDT scores for ref/mdl
2048 
2049  Uses the pairwise decomposable property of backbone only lDDT and
2050  implements a caching mechanism to efficiently enumerate different
2051  chain mappings.
2052  """
2053 
2054  self.ref = ref
2055  self.mdl = mdl
2056  self.ref_mdl_alns = ref_mdl_alns
2057  self.inclusion_radius = inclusion_radius
2058  self.thresholds = thresholds
2059 
2060  # keep track of single chains and interfaces in ref
2061  self.ref_chains = list() # e.g. ['A', 'B', 'C']
2062  self.ref_interfaces = list() # e.g. [('A', 'B'), ('A', 'C')]
2063 
2064  # holds lDDT scorer for each chain in ref
2065  # key: chain name, value: scorer
2066  self.single_chain_scorer = dict()
2067 
2068  # cache for single chain conserved contacts
2069  # key: tuple (ref_ch, mdl_ch) value: number of conserved contacts
2070  self.single_chain_cache = dict()
2071 
2072  # holds lDDT scorer for each pairwise interface in target
2073  # key: tuple (ref_ch1, ref_ch2), value: scorer
2074  self.interface_scorer = dict()
2075 
2076  # cache for interface conserved contacts
2077  # key: tuple of tuple ((ref_ch1, ref_ch2),((mdl_ch1, mdl_ch2))
2078  # value: number of conserved contacts
2079  self.interface_cache = dict()
2080 
2081  self.n = 0
2082 
2083  self._SetupScorer()
2084 
2085  def _SetupScorer(self):
2086  for ch in self.ref.chains:
2087  # Select everything close to that chain
2088  query = f"{self.inclusion_radius} <> "
2089  query += f"[cname={mol.QueryQuoteName(ch.GetName())}] "
2090  query += f"and cname!={mol.QueryQuoteName(ch.GetName())}"
2091  for close_ch in self.ref.Select(query).chains:
2092  k1 = (ch.GetName(), close_ch.GetName())
2093  k2 = (close_ch.GetName(), ch.GetName())
2094  if k1 not in self.interface_scorer and \
2095  k2 not in self.interface_scorer:
2096  dimer_ref = _CSel(self.ref, [k1[0], k1[1]])
2097  s = lddt.lDDTScorer(dimer_ref, bb_only=True)
2098  self.interface_scorer[k1] = s
2099  self.interface_scorer[k2] = s
2100  self.n += self.interface_scorer[k1].n_distances_ic
2101  self.ref_interfaces.append(k1)
2102  # single chain scorer are actually interface scorers to save
2103  # some distance calculations
2104  if ch.GetName() not in self.single_chain_scorer:
2105  self.single_chain_scorer[ch.GetName()] = s
2106  self.n += s.GetNChainContacts(ch.GetName(),
2107  no_interchain=True)
2108  self.ref_chains.append(ch.GetName())
2109  if close_ch.GetName() not in self.single_chain_scorer:
2110  self.single_chain_scorer[close_ch.GetName()] = s
2111  self.n += s.GetNChainContacts(close_ch.GetName(),
2112  no_interchain=True)
2113  self.ref_chains.append(close_ch.GetName())
2114 
2115  # add any missing single chain scorer
2116  for ch in self.ref.chains:
2117  if ch.GetName() not in self.single_chain_scorer:
2118  single_chain_ref = _CSel(self.ref, [ch.GetName()])
2119  self.single_chain_scorer[ch.GetName()] = \
2120  lddt.lDDTScorer(single_chain_ref, bb_only = True)
2121  self.n += self.single_chain_scorer[ch.GetName()].n_distances
2122  self.ref_chains.append(ch.GetName())
2123 
2124  def lDDT(self, ref_chain_groups, mdl_chain_groups):
2125 
2126  flat_map = dict()
2127  for ref_chains, mdl_chains in zip(ref_chain_groups, mdl_chain_groups):
2128  for ref_ch, mdl_ch in zip(ref_chains, mdl_chains):
2129  flat_map[ref_ch] = mdl_ch
2130 
2131  return self.lDDTFromFlatMap(flat_map)
2132 
2133 
2134  def lDDTFromFlatMap(self, flat_map):
2135  conserved = 0
2136 
2137  # do single chain scores
2138  for ref_ch in self.ref_chains:
2139  if ref_ch in flat_map and flat_map[ref_ch] is not None:
2140  conserved += self.SCCounts(ref_ch, flat_map[ref_ch])
2141 
2142  # do interfaces
2143  for ref_ch1, ref_ch2 in self.ref_interfaces:
2144  if ref_ch1 in flat_map and ref_ch2 in flat_map:
2145  mdl_ch1 = flat_map[ref_ch1]
2146  mdl_ch2 = flat_map[ref_ch2]
2147  if mdl_ch1 is not None and mdl_ch2 is not None:
2148  conserved += self.IntCounts(ref_ch1, ref_ch2, mdl_ch1,
2149  mdl_ch2)
2150 
2151  return conserved / (len(self.thresholds) * self.n)
2152 
2153  def SCCounts(self, ref_ch, mdl_ch):
2154  if not (ref_ch, mdl_ch) in self.single_chain_cache:
2155  alns = dict()
2156  alns[mdl_ch] = self.ref_mdl_alns[(ref_ch, mdl_ch)]
2157  mdl_sel = _CSel(self.mdl, [mdl_ch])
2158  s = self.single_chain_scorer[ref_ch]
2159  _,_,_,conserved,_,_,_ = s.lDDT(mdl_sel,
2160  residue_mapping=alns,
2161  return_dist_test=True,
2162  no_interchain=True,
2163  chain_mapping={mdl_ch: ref_ch},
2164  check_resnames=False)
2165  self.single_chain_cache[(ref_ch, mdl_ch)] = conserved
2166  return self.single_chain_cache[(ref_ch, mdl_ch)]
2167 
2168  def IntCounts(self, ref_ch1, ref_ch2, mdl_ch1, mdl_ch2):
2169  k1 = ((ref_ch1, ref_ch2),(mdl_ch1, mdl_ch2))
2170  k2 = ((ref_ch2, ref_ch1),(mdl_ch2, mdl_ch1))
2171  if k1 not in self.interface_cache and k2 not in self.interface_cache:
2172  alns = dict()
2173  alns[mdl_ch1] = self.ref_mdl_alns[(ref_ch1, mdl_ch1)]
2174  alns[mdl_ch2] = self.ref_mdl_alns[(ref_ch2, mdl_ch2)]
2175  mdl_sel = _CSel(self.mdl, [mdl_ch1, mdl_ch2])
2176  s = self.interface_scorer[(ref_ch1, ref_ch2)]
2177  _,_,_,conserved,_,_,_ = s.lDDT(mdl_sel,
2178  residue_mapping=alns,
2179  return_dist_test=True,
2180  no_intrachain=True,
2181  chain_mapping={mdl_ch1: ref_ch1,
2182  mdl_ch2: ref_ch2},
2183  check_resnames=False)
2184  self.interface_cache[k1] = conserved
2185  self.interface_cache[k2] = conserved
2186  return self.interface_cache[k1]
2187 
2189  def __init__(self, ref, mdl, ref_chem_groups, mdl_chem_groups,
2190  ref_mdl_alns, inclusion_radius = 15.0,
2191  thresholds = [0.5, 1.0, 2.0, 4.0],
2192  steep_opt_rate = None):
2193  """ Greedy extension of already existing but incomplete chain mappings
2194  """
2195  super().__init__(ref, mdl, ref_mdl_alns,
2196  inclusion_radius = inclusion_radius,
2197  thresholds = thresholds)
2198  self.steep_opt_rate = steep_opt_rate
2199  self.neighbors = {k: set() for k in self.ref_chains}
2200  for k in self.interface_scorer.keys():
2201  self.neighbors[k[0]].add(k[1])
2202  self.neighbors[k[1]].add(k[0])
2203 
2204  assert(len(ref_chem_groups) == len(mdl_chem_groups))
2205  self.ref_chem_groups = ref_chem_groups
2206  self.mdl_chem_groups = mdl_chem_groups
2207  self.ref_ch_group_mapper = dict()
2208  self.mdl_ch_group_mapper = dict()
2209  for g_idx, (ref_g, mdl_g) in enumerate(zip(ref_chem_groups,
2210  mdl_chem_groups)):
2211  for ch in ref_g:
2212  self.ref_ch_group_mapper[ch] = g_idx
2213  for ch in mdl_g:
2214  self.mdl_ch_group_mapper[ch] = g_idx
2215 
2216  # keep track of mdl chains that potentially give lDDT contributions,
2217  # i.e. they have locations within inclusion_radius + max(thresholds)
2218  self.mdl_neighbors = dict()
2219  d = self.inclusion_radius + max(self.thresholds)
2220  for ch in self.mdl.chains:
2221  ch_name = ch.GetName()
2222  self.mdl_neighbors[ch_name] = set()
2223  query = f"{d} <> [cname={mol.QueryQuoteName(ch_name)}]"
2224  query += f" and cname !={mol.QueryQuoteName(ch_name)}"
2225  for close_ch in self.mdl.Select(query).chains:
2226  self.mdl_neighbors[ch_name].add(close_ch.GetName())
2227 
2228 
2229  def ExtendMapping(self, mapping, max_ext = None):
2230 
2231  if len(mapping) == 0:
2232  raise RuntimError("Mapping must contain a starting point")
2233 
2234  # Ref chains onto which we can map. The algorithm starts with a mapping
2235  # on ref_ch. From there we can start to expand to connected neighbors.
2236  # All neighbors that we can reach from the already mapped chains are
2237  # stored in this set which will be updated during runtime
2238  map_targets = set()
2239  for ref_ch in mapping.keys():
2240  map_targets.update(self.neighbors[ref_ch])
2241 
2242  # remove the already mapped chains
2243  for ref_ch in mapping.keys():
2244  map_targets.discard(ref_ch)
2245 
2246  if len(map_targets) == 0:
2247  return mapping # nothing to extend
2248 
2249  # keep track of what model chains are not yet mapped for each chem group
2250  free_mdl_chains = list()
2251  for chem_group in self.mdl_chem_groups:
2252  tmp = [x for x in chem_group if x not in mapping.values()]
2253  free_mdl_chains.append(set(tmp))
2254 
2255  # keep track of what ref chains got a mapping
2256  newly_mapped_ref_chains = list()
2257 
2258  something_happened = True
2259  while something_happened:
2260  something_happened=False
2261 
2262  if self.steep_opt_rate is not None:
2263  n_chains = len(newly_mapped_ref_chains)
2264  if n_chains > 0 and n_chains % self.steep_opt_rate == 0:
2265  mapping = self._SteepOpt(mapping, newly_mapped_ref_chains)
2266 
2267  if max_ext is not None and len(newly_mapped_ref_chains) >= max_ext:
2268  break
2269 
2270  max_n = 0
2271  max_mapping = None
2272  for ref_ch in map_targets:
2273  chem_group_idx = self.ref_ch_group_mapper[ref_ch]
2274  for mdl_ch in free_mdl_chains[chem_group_idx]:
2275  # single chain score
2276  n_single = self.SCCounts(ref_ch, mdl_ch)
2277  # scores towards neighbors that are already mapped
2278  n_inter = 0
2279  for neighbor in self.neighbors[ref_ch]:
2280  if neighbor in mapping and mapping[neighbor] in \
2281  self.mdl_neighbors[mdl_ch]:
2282  n_inter += self.IntCounts(ref_ch, neighbor, mdl_ch,
2283  mapping[neighbor])
2284  n = n_single + n_inter
2285 
2286  if n_inter > 0 and n > max_n:
2287  # Only accept a new solution if its actually connected
2288  # i.e. n_inter > 0. Otherwise we could just map a big
2289  # fat mdl chain sitting somewhere in Nirvana
2290  max_n = n
2291  max_mapping = (ref_ch, mdl_ch)
2292 
2293  if max_n > 0:
2294  something_happened = True
2295  # assign new found mapping
2296  mapping[max_mapping[0]] = max_mapping[1]
2297 
2298  # add all neighboring chains to map targets as they are now
2299  # reachable
2300  for neighbor in self.neighbors[max_mapping[0]]:
2301  if neighbor not in mapping:
2302  map_targets.add(neighbor)
2303 
2304  # remove the ref chain from map targets
2305  map_targets.remove(max_mapping[0])
2306 
2307  # remove the mdl chain from free_mdl_chains - its taken...
2308  chem_group_idx = self.ref_ch_group_mapper[max_mapping[0]]
2309  free_mdl_chains[chem_group_idx].remove(max_mapping[1])
2310 
2311  # keep track of what ref chains got a mapping
2312  newly_mapped_ref_chains.append(max_mapping[0])
2313 
2314  return mapping
2315 
2316  def _SteepOpt(self, mapping, chains_to_optimize=None):
2317 
2318  # just optimize ALL ref chains if nothing specified
2319  if chains_to_optimize is None:
2320  chains_to_optimize = mapping.keys()
2321 
2322  # make sure that we only have ref chains which are actually mapped
2323  ref_chains = [x for x in chains_to_optimize if mapping[x] is not None]
2324 
2325  # group ref chains to be optimized into chem groups
2326  tmp = dict()
2327  for ch in ref_chains:
2328  chem_group_idx = self.ref_ch_group_mapper[ch]
2329  if chem_group_idx in tmp:
2330  tmp[chem_group_idx].append(ch)
2331  else:
2332  tmp[chem_group_idx] = [ch]
2333  chem_groups = list(tmp.values())
2334 
2335  # try all possible mapping swaps. Swaps that improve the score are
2336  # immediately accepted and we start all over again
2337  current_lddt = self.lDDTFromFlatMap(mapping)
2338  something_happened = True
2339  while something_happened:
2340  something_happened = False
2341  for chem_group in chem_groups:
2342  if something_happened:
2343  break
2344  for ch1, ch2 in itertools.combinations(chem_group, 2):
2345  swapped_mapping = dict(mapping)
2346  swapped_mapping[ch1] = mapping[ch2]
2347  swapped_mapping[ch2] = mapping[ch1]
2348  score = self.lDDTFromFlatMap(swapped_mapping)
2349  if score > current_lddt:
2350  something_happened = True
2351  mapping = swapped_mapping
2352  current_lddt = score
2353  break
2354 
2355  return mapping
2356 
2357 
def _lDDTNaive(trg, mdl, inclusion_radius, thresholds, chem_groups,
               chem_mapping, ref_mdl_alns, n_max_naive):
    """ Naively iterates all possible chain mappings and returns the best

    :returns: Tuple with 1) best scoring mapping as provided by
              _ChainMappings, None if no mapping scored above -1.0
              2) the respective lDDT
    """
    # Benchmarks on homo-oligomers indicate that full blown lDDT
    # computation is faster up to tetramers => 4!=24 possible mappings.
    # For stuff bigger than that, the decomposer approach should be used
    if _NMappingsWithin(chem_groups, chem_mapping, 24):
        full_scorer = lddt.lDDTScorer(trg, bb_only = True)

        def _Score(mapping):
            # chain_mapping and alns as input for lDDT computation
            lddt_chain_mapping = dict()
            lddt_alns = dict()
            for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
                for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
                    # some mdl chains can be None
                    if mdl_ch is None:
                        continue
                    lddt_chain_mapping[mdl_ch] = ref_ch
                    lddt_alns[mdl_ch] = ref_mdl_alns[(ref_ch, mdl_ch)]
            score, _ = full_scorer.lDDT(mdl, thresholds=thresholds,
                                        chain_mapping=lddt_chain_mapping,
                                        residue_mapping = lddt_alns,
                                        check_resnames = False)
            return score
    else:
        decomposer = _lDDTDecomposer(trg, mdl, ref_mdl_alns,
                                     inclusion_radius=inclusion_radius,
                                     thresholds = thresholds)

        def _Score(mapping):
            return decomposer.lDDT(chem_groups, mapping)

    best_mapping = None
    best_lddt = -1.0
    for mapping in _ChainMappings(chem_groups, chem_mapping, n_max_naive):
        score = _Score(mapping)
        if score > best_lddt:
            best_mapping = mapping
            best_lddt = score

    return (best_mapping, best_lddt)
2401 
2402 
2403 def _GetSeeds(ref_chem_groups, mdl_chem_groups, mapped_ref_chains = set(),
2404  mapped_mdl_chains = set()):
2405  seeds = list()
2406  for ref_chains, mdl_chains in zip(ref_chem_groups,
2407  mdl_chem_groups):
2408  for ref_ch in ref_chains:
2409  if ref_ch not in mapped_ref_chains:
2410  for mdl_ch in mdl_chains:
2411  if mdl_ch not in mapped_mdl_chains:
2412  seeds.append((ref_ch, mdl_ch))
2413  return seeds
2414 
2415 
def _lDDTGreedyFast(the_greed):
    """ Greedy mapping starting from the best scoring single chain pairs

    Repeatedly selects the not yet mapped (ref_ch, mdl_ch) pair with the
    highest single chain count as seed and greedily extends the mapping from
    there. Stops if no seed with a count above zero is left.

    :param the_greed: Searcher object providing *ref_chem_groups*,
                      *mdl_chem_groups*, SCCounts and ExtendMapping
    :returns: :class:`list` of :class:`list` with mdl chain names (None if
              unmapped), same layout as *the_greed.ref_chem_groups*
    """
    mapping = dict()

    while True:
        seeds = _GetSeeds(the_greed.ref_chem_groups,
                          the_greed.mdl_chem_groups,
                          mapped_ref_chains = set(mapping.keys()),
                          mapped_mdl_chains = set(mapping.values()))

        # search for best scoring starting point
        best_seed = None
        n_best = 0
        for seed in seeds:
            n_seed = the_greed.SCCounts(seed[0], seed[1])
            if n_seed > n_best:
                n_best = n_seed
                best_seed = seed

        if n_best == 0:
            break # no proper seed found anymore...

        # add seed to mapping and start the greed
        ref_ch, mdl_ch = best_seed
        mapping[ref_ch] = mdl_ch
        mapping = the_greed.ExtendMapping(mapping)

    # translate mapping format and return, unmapped ref chains get None
    return [[mapping.get(ref_ch) for ref_ch in ref_chains]
            for ref_chains in the_greed.ref_chem_groups]
2454 
2455 
def _lDDTGreedyFull(the_greed):
    """ Greedy mapping that tries every possible seed as starting point

    Every (ref_ch, mdl_ch) pair is used once as initial seed and extended
    with *the_greed.ExtendMapping*. As long as unmapped pairs remain, each of
    them is tentatively added and extended; the variant with the highest
    *the_greed.lDDTFromFlatMap* score is kept. The best scoring mapping over
    all initial seeds is returned.

    :param the_greed: Searcher object providing *ref_chem_groups*,
                      *mdl_chem_groups*, *ExtendMapping* and
                      *lDDTFromFlatMap*
    :returns: :class:`list` of :class:`list` with the same layout as
              *the_greed.ref_chem_groups*. Each reference chain is replaced
              by its mapped model chain or None if unmapped.
    """
    start_pairs = _GetSeeds(the_greed.ref_chem_groups,
                            the_greed.mdl_chem_groups)
    top_score = -1.0
    top_map = dict()

    for start_pair in start_pairs:

        # initial extension from the single pair
        current = the_greed.ExtendMapping({start_pair[0]: start_pair[1]})

        # keep adding remaining pairs until nothing can be added anymore
        while True:
            open_pairs = _GetSeeds(the_greed.ref_chem_groups,
                                   the_greed.mdl_chem_groups,
                                   mapped_ref_chains = set(current.keys()),
                                   mapped_mdl_chains = set(current.values()))
            if len(open_pairs) == 0:
                break

            round_score = -1.0
            round_map = None
            for open_pair in open_pairs:
                trial = dict(current)
                trial[open_pair[0]] = open_pair[1]
                trial = the_greed.ExtendMapping(trial)
                trial_score = the_greed.lDDTFromFlatMap(trial)
                if trial_score > round_score:
                    round_score = trial_score
                    round_map = trial

            if round_map is None:
                break
            current = round_map

        final_score = the_greed.lDDTFromFlatMap(current)
        if final_score > top_score:
            top_score = final_score
            top_map = current

    # translate flat mapping into chem group layout
    return [[top_map[ref_ch] if ref_ch in top_map else None
             for ref_ch in ref_group]
            for ref_group in the_greed.ref_chem_groups]
2512 
2513 
def _lDDTGreedyBlock(the_greed, seed_size, blocks_per_chem_group):
    """ Greedy mapping based on small starting blocks, scored by lDDT

    For each chem group, up to *blocks_per_chem_group* starting blocks are
    built: every ref/mdl chain pair of the group is added to the current
    mapping and extended by at most *seed_size* - 1 further chains, the best
    scoring block according to *the_greed.lDDTFromFlatMap* is retained and its
    reference chain is excluded from the next block search to enforce
    diversity. All retained blocks are then fully extended and the best one
    is merged into the mapping. This is repeated until no block with an lDDT
    above zero is found.

    :param the_greed: Searcher object providing *ref_chem_groups*,
                      *mdl_chem_groups*, *ExtendMapping* and
                      *lDDTFromFlatMap*
    :param seed_size: Number of chains in a starting block, must be >= 1
    :param blocks_per_chem_group: Number of starting blocks per chem group,
                                  must be >= 1
    :returns: :class:`list` of :class:`list` with the same layout as
              *the_greed.ref_chem_groups*. Each reference chain is replaced
              by its mapped model chain or None if unmapped.
    :raises: :class:`RuntimeError` if *seed_size* or *blocks_per_chem_group*
             is None or smaller than 1
    """
    if seed_size is None or seed_size < 1:
        raise RuntimeError(f"seed_size must be an int >= 1 (got {seed_size})")

    if blocks_per_chem_group is None or blocks_per_chem_group < 1:
        raise RuntimeError(f"blocks_per_chem_group must be an int >= 1 "
                           f"(got {blocks_per_chem_group})")

    # the start seed itself already has size 1
    n_ext = seed_size - 1

    # working copies, chains are removed as soon as they are mapped
    open_ref_groups = copy.deepcopy(the_greed.ref_chem_groups)
    open_mdl_groups = copy.deepcopy(the_greed.mdl_chem_groups)

    flat_map = dict()
    while True:
        blocks = list()
        for ref_group, mdl_group in zip(open_ref_groups, open_mdl_groups):
            if len(mdl_group) == 0:
                continue # nothing to map
            ref_pool = list(ref_group)
            for _ in range(blocks_per_chem_group):
                if len(ref_pool) == 0:
                    break
                pairs = [(r, m) for r in ref_pool for m in mdl_group]
                # grow each pair to block size and remember the best block
                top_score = -1.0
                top_block = None
                top_pair = None
                for pair in pairs:
                    block = dict(flat_map)
                    block[pair[0]] = pair[1]
                    block = the_greed.ExtendMapping(block, max_ext = n_ext)
                    block_lddt = the_greed.lDDTFromFlatMap(block)
                    if block_lddt > top_score:
                        top_score = block_lddt
                        top_block = block
                        top_pair = pair
                if top_block is not None:
                    blocks.append(top_block)
                    if top_pair[0] in ref_pool:
                        # drop that ref chain to enforce diversity
                        ref_pool.remove(top_pair[0])

        # exhaustively extend all blocks, best one wins
        top_lddt = 0.0
        top_map = None
        for block in blocks:
            grown = the_greed.ExtendMapping(block)
            grown_lddt = the_greed.lDDTFromFlatMap(grown)
            if grown_lddt > top_lddt:
                top_lddt = grown_lddt
                top_map = grown

        if top_lddt == 0.0:
            break # no proper mapping found anymore

        flat_map.update(top_map)
        for ref_ch, mdl_ch in top_map.items():
            for idx, ref_group in enumerate(open_ref_groups):
                if ref_ch in ref_group:
                    ref_group.remove(ref_ch)
                if mdl_ch in open_mdl_groups[idx]:
                    open_mdl_groups[idx].remove(mdl_ch)

    # translate flat mapping into chem group layout
    return [[flat_map[ref_ch] if ref_ch in flat_map else None
             for ref_ch in ref_group]
            for ref_group in the_greed.ref_chem_groups]
2602 
2603 
2605  def __init__(self, ref, mdl, ref_chem_groups, mdl_chem_groups,
2606  ref_mdl_alns, contact_d = 12.0,
2607  steep_opt_rate = None, greedy_prune_contact_map=False):
2608  """ Greedy extension of already existing but incomplete chain mappings
2609  """
2610  super().__init__(ref, ref_chem_groups, mdl, ref_mdl_alns,
2611  contact_d = contact_d)
2612  self.ref = ref
2613  self.mdl = mdl
2614  self.ref_mdl_alns = ref_mdl_alns
2615  self.steep_opt_rate = steep_opt_rate
2616 
2617  if greedy_prune_contact_map:
2618  self.neighbors = {k: set() for k in self.qsent1.chain_names}
2619  for p in self.qsent1.interacting_chains:
2620  if np.count_nonzero(self.qsent1.PairDist(p[0], p[1])<=8) >= 3:
2621  self.neighbors[p[0]].add(p[1])
2622  self.neighbors[p[1]].add(p[0])
2623 
2624  self.mdl_neighbors = {k: set() for k in self.qsent2.chain_names}
2625  for p in self.qsent2.interacting_chains:
2626  if np.count_nonzero(self.qsent2.PairDist(p[0], p[1])<=8) >= 3:
2627  self.mdl_neighbors[p[0]].add(p[1])
2628  self.mdl_neighbors[p[1]].add(p[0])
2629 
2630 
2631  else:
2632  self.neighbors = {k: set() for k in self.qsent1.chain_names}
2633  for p in self.qsent1.interacting_chains:
2634  self.neighbors[p[0]].add(p[1])
2635  self.neighbors[p[1]].add(p[0])
2636 
2637  self.mdl_neighbors = {k: set() for k in self.qsent2.chain_names}
2638  for p in self.qsent2.interacting_chains:
2639  self.mdl_neighbors[p[0]].add(p[1])
2640  self.mdl_neighbors[p[1]].add(p[0])
2641 
2642  assert(len(ref_chem_groups) == len(mdl_chem_groups))
2643  self.ref_chem_groups = ref_chem_groups
2644  self.mdl_chem_groups = mdl_chem_groups
2645  self.ref_ch_group_mapper = dict()
2646  self.mdl_ch_group_mapper = dict()
2647  for g_idx, (ref_g, mdl_g) in enumerate(zip(ref_chem_groups,
2648  mdl_chem_groups)):
2649  for ch in ref_g:
2650  self.ref_ch_group_mapper[ch] = g_idx
2651  for ch in mdl_g:
2652  self.mdl_ch_group_mapper[ch] = g_idx
2653 
2654  # cache for lDDT based single chain conserved contacts
2655  # used to identify starting points for further extension by QS score
2656  # key: tuple (ref_ch, mdl_ch) value: number of conserved contacts
2657  self.single_chain_scorer = dict()
2658  self.single_chain_cache = dict()
2659  for ch in self.ref.chains:
2660  single_chain_ref = _CSel(self.ref, [ch.GetName()])
2661  self.single_chain_scorer[ch.GetName()] = \
2662  lddt.lDDTScorer(single_chain_ref, bb_only = True)
2663 
2664  def SCCounts(self, ref_ch, mdl_ch):
2665  if not (ref_ch, mdl_ch) in self.single_chain_cache:
2666  alns = dict()
2667  alns[mdl_ch] = self.ref_mdl_alns[(ref_ch, mdl_ch)]
2668  mdl_sel = _CSel(self.mdl, [mdl_ch])
2669  s = self.single_chain_scorer[ref_ch]
2670  _,_,_,conserved,_,_,_ = s.lDDT(mdl_sel,
2671  residue_mapping=alns,
2672  return_dist_test=True,
2673  no_interchain=True,
2674  chain_mapping={mdl_ch: ref_ch},
2675  check_resnames=False)
2676  self.single_chain_cache[(ref_ch, mdl_ch)] = conserved
2677  return self.single_chain_cache[(ref_ch, mdl_ch)]
2678 
2679  def ExtendMapping(self, mapping, max_ext = None):
2680 
2681  if len(mapping) == 0:
2682  raise RuntimError("Mapping must contain a starting point")
2683 
2684  # Ref chains onto which we can map. The algorithm starts with a mapping
2685  # on ref_ch. From there we can start to expand to connected neighbors.
2686  # All neighbors that we can reach from the already mapped chains are
2687  # stored in this set which will be updated during runtime
2688  map_targets = set()
2689  for ref_ch in mapping.keys():
2690  map_targets.update(self.neighbors[ref_ch])
2691 
2692  # remove the already mapped chains
2693  for ref_ch in mapping.keys():
2694  map_targets.discard(ref_ch)
2695 
2696  if len(map_targets) == 0:
2697  return mapping # nothing to extend
2698 
2699  # keep track of what model chains are not yet mapped for each chem group
2700  free_mdl_chains = list()
2701  for chem_group in self.mdl_chem_groups:
2702  tmp = [x for x in chem_group if x not in mapping.values()]
2703  free_mdl_chains.append(set(tmp))
2704 
2705  # keep track of what ref chains got a mapping
2706  newly_mapped_ref_chains = list()
2707 
2708  something_happened = True
2709  while something_happened:
2710  something_happened=False
2711 
2712  if self.steep_opt_rate is not None:
2713  n_chains = len(newly_mapped_ref_chains)
2714  if n_chains > 0 and n_chains % self.steep_opt_rate == 0:
2715  mapping = self._SteepOpt(mapping, newly_mapped_ref_chains)
2716 
2717  if max_ext is not None and len(newly_mapped_ref_chains) >= max_ext:
2718  break
2719 
2720  score_result = self.FromFlatMapping(mapping)
2721  old_score = score_result.QS_global
2722  nominator = score_result.weighted_scores
2723  denominator = score_result.weight_sum + score_result.weight_extra_all
2724 
2725  max_diff = 0.0
2726  max_mapping = None
2727  for ref_ch in map_targets:
2728  chem_group_idx = self.ref_ch_group_mapper[ref_ch]
2729  for mdl_ch in free_mdl_chains[chem_group_idx]:
2730  # we're not computing full QS-score here, we directly hack
2731  # into the QS-score formula to compute a diff
2732  nominator_diff = 0.0
2733  denominator_diff = 0.0
2734  for neighbor in self.neighbors[ref_ch]:
2735  if neighbor in mapping and mapping[neighbor] in \
2736  self.mdl_neighbors[mdl_ch]:
2737  # it's a newly added interface if (ref_ch, mdl_ch)
2738  # are added to mapping
2739  int1 = (ref_ch, neighbor)
2740  int2 = (mdl_ch, mapping[neighbor])
2741  a, b, c, d = self._MappedInterfaceScores(int1, int2)
2742  nominator_diff += a # weighted_scores
2743  denominator_diff += b # weight_sum
2744  denominator_diff += d # weight_extra_all
2745  # the respective interface penalties are subtracted
2746  # from denominator
2747  denominator_diff -= self._InterfacePenalty1(int1)
2748  denominator_diff -= self._InterfacePenalty2(int2)
2749 
2750  if nominator_diff > 0:
2751  # Only accept a new solution if its actually connected
2752  # i.e. nominator_diff > 0.
2753  new_nominator = nominator + nominator_diff
2754  new_denominator = denominator + denominator_diff
2755  new_score = 0.0
2756  if new_denominator != 0.0:
2757  new_score = new_nominator/new_denominator
2758  diff = new_score - old_score
2759  if diff > max_diff:
2760  max_diff = diff
2761  max_mapping = (ref_ch, mdl_ch)
2762 
2763  if max_mapping is not None:
2764  something_happened = True
2765  # assign new found mapping
2766  mapping[max_mapping[0]] = max_mapping[1]
2767 
2768  # add all neighboring chains to map targets as they are now
2769  # reachable
2770  for neighbor in self.neighbors[max_mapping[0]]:
2771  if neighbor not in mapping:
2772  map_targets.add(neighbor)
2773 
2774  # remove the ref chain from map targets
2775  map_targets.remove(max_mapping[0])
2776 
2777  # remove the mdl chain from free_mdl_chains - its taken...
2778  chem_group_idx = self.ref_ch_group_mapper[max_mapping[0]]
2779  free_mdl_chains[chem_group_idx].remove(max_mapping[1])
2780 
2781  # keep track of what ref chains got a mapping
2782  newly_mapped_ref_chains.append(max_mapping[0])
2783 
2784  return mapping
2785 
2786  def _SteepOpt(self, mapping, chains_to_optimize=None):
2787 
2788  # just optimize ALL ref chains if nothing specified
2789  if chains_to_optimize is None:
2790  chains_to_optimize = mapping.keys()
2791 
2792  # make sure that we only have ref chains which are actually mapped
2793  ref_chains = [x for x in chains_to_optimize if mapping[x] is not None]
2794 
2795  # group ref chains to be optimized into chem groups
2796  tmp = dict()
2797  for ch in ref_chains:
2798  chem_group_idx = self.ref_ch_group_mapper[ch]
2799  if chem_group_idx in tmp:
2800  tmp[chem_group_idx].append(ch)
2801  else:
2802  tmp[chem_group_idx] = [ch]
2803  chem_groups = list(tmp.values())
2804 
2805  # try all possible mapping swaps. Swaps that improve the score are
2806  # immediately accepted and we start all over again
2807  score_result = self.FromFlatMapping(mapping)
2808  current_score = score_result.QS_global
2809  something_happened = True
2810  while something_happened:
2811  something_happened = False
2812  for chem_group in chem_groups:
2813  if something_happened:
2814  break
2815  for ch1, ch2 in itertools.combinations(chem_group, 2):
2816  swapped_mapping = dict(mapping)
2817  swapped_mapping[ch1] = mapping[ch2]
2818  swapped_mapping[ch2] = mapping[ch1]
2819  score_result = self.FromFlatMapping(swapped_mapping)
2820  if score_result.QS_global > current_score:
2821  something_happened = True
2822  mapping = swapped_mapping
2823  current_score = score_result.QS_global
2824  break
2825  return mapping
2826 
2827 
def _QSScoreNaive(trg, mdl, chem_groups, chem_mapping, ref_mdl_alns, contact_d,
                  n_max_naive):
    """ Enumerates chain mappings and returns the one with best QS-score

    :param trg: Target structure
    :param mdl: Model structure
    :param chem_groups: Chem groups of target chain names
    :param chem_mapping: Model chain names per chem group
    :param ref_mdl_alns: Alignments handed to the QS scorer
    :param contact_d: Contact distance handed to the QS scorer
    :param n_max_naive: Forwarded to the mapping enumeration
                        (*_ChainMappings*)
    :returns: :class:`tuple` (best_mapping, best_score). best_mapping is None
              and best_score is -1.0 if no mapping has been enumerated.
    """
    best_mapping = None
    best_score = -1.0
    # qs_scorer implements caching, score calculation is thus as fast as it gets
    # you'll just hit a wall when the number of possible mappings becomes large
    # contact_d must be forwarded, otherwise the scorer silently falls back to
    # its default and disagrees with the greedy strategies
    qs_scorer = qsscore.QSScorer(trg, chem_groups, mdl, ref_mdl_alns,
                                 contact_d = contact_d)
    for mapping in _ChainMappings(chem_groups, chem_mapping, n_max_naive):
        score_result = qs_scorer.Score(mapping, check=False)
        if score_result.QS_global > best_score:
            best_mapping = mapping
            best_score = score_result.QS_global
    return (best_mapping, best_score)
2841 
2842 
def _QSScoreGreedyFast(the_greed):
    """ Greedy QS-score mapping that always continues from the best seed

    In each round, all still unmapped (ref_ch, mdl_ch) pairs are scored with
    *the_greed.SCCounts* (lDDT based). The first pair with the highest count
    is added to the mapping which is then handed to
    *the_greed.ExtendMapping*. The procedure stops as soon as no pair with a
    count above zero is left.

    :param the_greed: Searcher object providing *ref_chem_groups*,
                      *mdl_chem_groups*, *SCCounts* and *ExtendMapping*
    :returns: :class:`list` of :class:`list` with the same layout as
              *the_greed.ref_chem_groups*. Each reference chain is replaced
              by its mapped model chain or None if unmapped.
    """
    flat_map = dict()

    while True:
        candidates = _GetSeeds(the_greed.ref_chem_groups,
                               the_greed.mdl_chem_groups,
                               mapped_ref_chains = set(flat_map.keys()),
                               mapped_mdl_chains = set(flat_map.values()))

        # best scoring starting point, first candidate wins in case of ties
        top_count = 0
        top_pair = None
        for candidate in candidates:
            count = the_greed.SCCounts(candidate[0], candidate[1])
            if count > top_count:
                top_count = count
                top_pair = candidate

        if top_count == 0:
            # no proper seed found anymore...
            break

        # register the seed and let the greed do its thing
        flat_map[top_pair[0]] = top_pair[1]
        flat_map = the_greed.ExtendMapping(flat_map)

    # translate flat mapping into chem group layout
    return [[flat_map[ref_ch] if ref_ch in flat_map else None
             for ref_ch in ref_group]
            for ref_group in the_greed.ref_chem_groups]
2880 
2881 
def _QSScoreGreedyFull(the_greed):
    """ Greedy QS-score mapping that tries every possible seed

    Every (ref_ch, mdl_ch) pair is used once as initial seed and extended
    with *the_greed.ExtendMapping*. As long as unmapped pairs remain, each of
    them is tentatively added and extended; the variant with the highest
    QS_global from *the_greed.FromFlatMapping* is kept. The best scoring
    mapping over all initial seeds is returned.

    :param the_greed: Searcher object providing *ref_chem_groups*,
                      *mdl_chem_groups*, *ExtendMapping* and
                      *FromFlatMapping*
    :returns: :class:`list` of :class:`list` with the same layout as
              *the_greed.ref_chem_groups*. Each reference chain is replaced
              by its mapped model chain or None if unmapped.
    """
    start_pairs = _GetSeeds(the_greed.ref_chem_groups,
                            the_greed.mdl_chem_groups)
    top_score = -1.0
    top_map = dict()

    for start_pair in start_pairs:

        # initial extension from the single pair
        current = the_greed.ExtendMapping({start_pair[0]: start_pair[1]})

        # keep adding remaining pairs until nothing can be added anymore
        while True:
            open_pairs = _GetSeeds(the_greed.ref_chem_groups,
                                   the_greed.mdl_chem_groups,
                                   mapped_ref_chains = set(current.keys()),
                                   mapped_mdl_chains = set(current.values()))
            if len(open_pairs) == 0:
                break

            round_score = -1.0
            round_map = None
            for open_pair in open_pairs:
                trial = dict(current)
                trial[open_pair[0]] = open_pair[1]
                trial = the_greed.ExtendMapping(trial)
                trial_result = the_greed.FromFlatMapping(trial)
                if trial_result.QS_global > round_score:
                    round_score = trial_result.QS_global
                    round_map = trial

            if round_map is None:
                break
            current = round_map

        final_result = the_greed.FromFlatMapping(current)
        if final_result.QS_global > top_score:
            top_score = final_result.QS_global
            top_map = current

    # translate flat mapping into chem group layout
    return [[top_map[ref_ch] if ref_ch in top_map else None
             for ref_ch in ref_group]
            for ref_group in the_greed.ref_chem_groups]
2938 
2939 
def _QSScoreGreedyBlock(the_greed, seed_size, blocks_per_chem_group):
    """ Greedy mapping based on small starting blocks, scored by QS-score

    For each chem group, up to *blocks_per_chem_group* starting blocks are
    built: every ref/mdl chain pair of the group is added to the current
    mapping and extended by at most *seed_size* - 1 further chains, the best
    scoring block according to QS_global from *the_greed.FromFlatMapping* is
    retained and its reference chain is excluded from the next block search
    to enforce diversity. All retained blocks are then fully extended and the
    best one is merged into the mapping, provided it contains more chains
    than the current mapping. This is repeated until no such extension is
    found.

    :param the_greed: Searcher object providing *ref_chem_groups*,
                      *mdl_chem_groups*, *ExtendMapping* and
                      *FromFlatMapping*
    :param seed_size: Number of chains in a starting block, must be >= 1
    :param blocks_per_chem_group: Number of starting blocks per chem group,
                                  must be >= 1
    :returns: :class:`list` of :class:`list` with the same layout as
              *the_greed.ref_chem_groups*. Each reference chain is replaced
              by its mapped model chain or None if unmapped.
    :raises: :class:`RuntimeError` if *seed_size* or *blocks_per_chem_group*
             is None or smaller than 1
    """
    if seed_size is None or seed_size < 1:
        raise RuntimeError(f"seed_size must be an int >= 1 (got {seed_size})")

    if blocks_per_chem_group is None or blocks_per_chem_group < 1:
        raise RuntimeError(f"blocks_per_chem_group must be an int >= 1 "
                           f"(got {blocks_per_chem_group})")

    # the start seed itself already has size 1
    n_ext = seed_size - 1

    # working copies, chains are removed as soon as they are mapped
    open_ref_groups = copy.deepcopy(the_greed.ref_chem_groups)
    open_mdl_groups = copy.deepcopy(the_greed.mdl_chem_groups)

    flat_map = dict()

    extended = True
    while extended:
        extended = False
        blocks = list()
        for ref_group, mdl_group in zip(open_ref_groups, open_mdl_groups):
            if len(mdl_group) == 0:
                continue # nothing to map
            ref_pool = list(ref_group)
            for _ in range(blocks_per_chem_group):
                if len(ref_pool) == 0:
                    break
                pairs = [(r, m) for r in ref_pool for m in mdl_group]
                # grow each pair to block size and remember the best block
                top_score = -1.0
                top_block = None
                top_pair = None
                for pair in pairs:
                    block = dict(flat_map)
                    block[pair[0]] = pair[1]
                    block = the_greed.ExtendMapping(block, max_ext = n_ext)
                    block_result = the_greed.FromFlatMapping(block)
                    if block_result.QS_global > top_score:
                        top_score = block_result.QS_global
                        top_block = block
                        top_pair = pair
                if top_block is not None:
                    blocks.append(top_block)
                    if top_pair[0] in ref_pool:
                        # drop selected ref chain to enforce diversity
                        ref_pool.remove(top_pair[0])

        # exhaustively extend all blocks, best one wins
        top_score = -1.0
        top_map = None
        for block in blocks:
            grown = the_greed.ExtendMapping(block)
            grown_result = the_greed.FromFlatMapping(grown)
            if grown_result.QS_global > top_score:
                top_score = grown_result.QS_global
                top_map = grown

        if top_map is not None and len(top_map) > len(flat_map):
            # this even accepts extensions that lead to no increase in QS-score
            # at least they make sense from an lDDT perspective
            extended = True
            flat_map.update(top_map)
            for ref_ch, mdl_ch in top_map.items():
                for idx, ref_group in enumerate(open_ref_groups):
                    if ref_ch in ref_group:
                        ref_group.remove(ref_ch)
                    if mdl_ch in open_mdl_groups[idx]:
                        open_mdl_groups[idx].remove(mdl_ch)

    # translate flat mapping into chem group layout
    return [[flat_map[ref_ch] if ref_ch in flat_map else None
             for ref_ch in ref_group]
            for ref_group in the_greed.ref_chem_groups]
3029 
3030 
def _SingleRigidGDTTS(initial_transforms, initial_mappings, chem_groups,
                      chem_mapping, trg_group_pos, mdl_group_pos,
                      single_chain_gdtts_thresh, iterative_superposition,
                      first_complete, n_trg_chains, n_mdl_chains):
    """ Takes initial transforms and sequentially adds chain pairs with
    best scoring gdtts that fulfill single_chain_gdtts_thresh. The mapping
    from the transform that leads to best overall gdtts score is returned.
    Optionally, the first complete mapping, i.e. a mapping that covers all
    target chains or all model chains, is returned.

    :param initial_transforms: Transforms to try, each is applied on the
                               model positions
    :param initial_mappings: Unused in this function
    :param chem_groups: Target chain names per chem group
    :param chem_mapping: Model chain names per chem group
    :param trg_group_pos: Per chem group: positions of each target chain
    :param mdl_group_pos: Per chem group: positions of each model chain
    :param single_chain_gdtts_thresh: Minimal single chain GDTTS for a chain
                                      pair to be considered
    :param iterative_superposition: Unused in this function
    :param first_complete: If True, stop at the first transform that gives
                           a new best mapping which covers *n_trg_chains* or
                           *n_mdl_chains* chains
    :param n_trg_chains: Number of target chains
    :param n_mdl_chains: Number of model chains
    :returns: :class:`dict` with target chain names as keys and model chain
              names as values, empty if no transform scored above zero
    """
    best_mapping = dict()
    best_gdt = 0
    for transform in initial_transforms:
        mapping = dict()
        mapped_mdl_chains = set()
        # accumulated score over all chem groups for this transform
        # must not be touched by the per pair scores computed below
        gdt = 0.0

        for trg_chains, mdl_chains, trg_pos, mdl_pos, in zip(chem_groups,
                                                             chem_mapping,
                                                             trg_group_pos,
                                                             mdl_group_pos):

            if len(trg_pos) == 0 or len(mdl_pos) == 0:
                continue # cannot compute valid gdt

            gdt_scores = list()

            # transform all model chains of this group once
            t_mdl_pos = list()
            for m_pos in mdl_pos:
                t_m_pos = geom.Vec3List(m_pos)
                t_m_pos.ApplyTransform(transform)
                t_mdl_pos.append(t_m_pos)

            for t_pos, t in zip(trg_pos, trg_chains):
                for t_m_pos, m in zip(t_mdl_pos, mdl_chains):
                    pair_gdt = t_pos.GetGDTTS(t_m_pos)
                    if pair_gdt >= single_chain_gdtts_thresh:
                        gdt_scores.append((pair_gdt, (t,m)))

            # weight of a chain pair in the accumulated score
            n_gdt_contacts = 4 * len(trg_pos[0])
            # best pairs first, each chain can only be assigned once
            gdt_scores.sort(reverse=True)
            for item in gdt_scores:
                p = item[1]
                if p[0] not in mapping and p[1] not in mapped_mdl_chains:
                    mapping[p[0]] = p[1]
                    mapped_mdl_chains.add(p[1])
                    gdt += (item[0] * n_gdt_contacts)

        if gdt > best_gdt:
            best_gdt = gdt
            best_mapping = mapping
            if first_complete:
                n = len(mapping)
                if n == n_mdl_chains or n == n_trg_chains:
                    break

    return best_mapping
3088 
3089 
def _IterativeRigidGDTTS(initial_transforms, initial_mappings, chem_groups,
                         chem_mapping, trg_group_pos, mdl_group_pos,
                         single_chain_gdtts_thresh, iterative_superposition,
                         first_complete, n_trg_chains, n_mdl_chains):
    """ Iteratively grows a mapping from each initial transform using GDTTS

    Every initial transform comes with an initial (trg_ch, mdl_ch) pair.
    Starting from there, the free chain pair within the same chem group with
    the highest single chain GDTTS above *single_chain_gdtts_thresh* is added
    and the transform is recomputed from all mapped positions. This is
    repeated until no pair qualifies. The mapping with the best overall
    (non-normalized) GDTTS is returned. If *first_complete* is True, the
    search stops at the first new best mapping that covers *n_trg_chains* or
    *n_mdl_chains* chains.

    :param initial_transforms: Transforms to start from
    :param initial_mappings: One (trg_ch, mdl_ch) pair per initial transform
    :param chem_groups: Target chain names per chem group
    :param chem_mapping: Model chain names per chem group
    :param trg_group_pos: Per chem group: positions of each target chain
    :param mdl_group_pos: Per chem group: positions of each model chain
    :param single_chain_gdtts_thresh: Single chain GDTTS a pair must exceed
    :param iterative_superposition: Forwarded to *_GetTransform*
    :param first_complete: Whether to stop at first complete mapping
    :param n_trg_chains: Number of target chains
    :param n_mdl_chains: Number of model chains
    :returns: :class:`dict` with target chain names as keys and model chain
              names as values
    """

    # chain name => positions
    trg_lookup = {name: pos
                  for group_pos, group_names in zip(trg_group_pos, chem_groups)
                  for pos, name in zip(group_pos, group_names)}
    mdl_lookup = {name: pos
                  for group_pos, group_names in zip(mdl_group_pos, chem_mapping)
                  for pos, name in zip(group_pos, group_names)}

    top_map = dict()
    top_gdt = 0
    for start_transform, start_pair in zip(initial_transforms,
                                           initial_mappings):
        current_map = {start_pair[0]: start_pair[1]}
        current_transform = geom.Mat4(start_transform)
        joint_trg_pos = geom.Vec3List(trg_lookup[start_pair[0]])
        joint_mdl_pos = geom.Vec3List(mdl_lookup[start_pair[1]])

        # chains that are still available for mapping
        open_trg = [set(names) for names in chem_groups]
        open_mdl = [set(names) for names in chem_mapping]

        # the initial pair is already taken
        for names in open_trg:
            if start_pair[0] in names:
                names.remove(start_pair[0])
                break
        for names in open_mdl:
            if start_pair[1] in names:
                names.remove(start_pair[1])
                break

        while True:
            # best free pair under the current transform
            pick = None
            pick_group = None
            pick_gdt = 0.0
            for idx, (trg_names, mdl_names) in enumerate(zip(open_trg,
                                                             open_mdl)):
                for t in trg_names:
                    t_pos = trg_lookup[t]
                    for m in mdl_names:
                        moved = geom.Vec3List(mdl_lookup[m])
                        moved.ApplyTransform(current_transform)
                        pair_gdt = t_pos.GetGDTTS(moved)
                        if pair_gdt > single_chain_gdtts_thresh and \
                           pair_gdt > pick_gdt:
                            pick_gdt = pair_gdt
                            pick = (t,m)
                            pick_group = idx

            if pick is None:
                break

            current_map[pick[0]] = pick[1]
            joint_trg_pos.extend(trg_lookup[pick[0]])
            joint_mdl_pos.extend(mdl_lookup[pick[1]])
            open_trg[pick_group].remove(pick[0])
            open_mdl[pick_group].remove(pick[1])

            # update transform with everything mapped so far
            current_transform = _GetTransform(joint_mdl_pos, joint_trg_pos,
                                              iterative_superposition)

        # overall gdt for final transform (non-normalized gdt!!!)
        joint_mdl_pos.ApplyTransform(current_transform)
        total_gdt = joint_trg_pos.GetGDTTS(joint_mdl_pos, norm=False)

        if total_gdt > top_gdt:
            top_gdt = total_gdt
            top_map = current_map
            if first_complete:
                n_mapped = len(current_map)
                if n_mapped == n_mdl_chains or n_mapped == n_trg_chains:
                    break

    return top_map
3182 
def _SingleRigidRMSD(initial_transforms, initial_mappings, chem_groups,
                     chem_mapping, trg_group_pos, mdl_group_pos,
                     iterative_superposition):
    """
    Takes initial transforms and sequentially adds chain pairs with lowest RMSD.
    The mapping from the transform that leads to lowest overall RMSD is
    returned.

    :param initial_transforms: Transforms to try, each is applied on the
                               model positions
    :param initial_mappings: Unused in this function
    :param chem_groups: Target chain names per chem group
    :param chem_mapping: Model chain names per chem group
    :param trg_group_pos: Per chem group: positions of each target chain
    :param mdl_group_pos: Per chem group: positions of each model chain
    :param iterative_superposition: Unused in this function
    :returns: :class:`dict` with target chain names as keys and model chain
              names as values, empty if *initial_transforms* is empty
    """
    best_mapping = dict()
    best_ssd = float("inf") # we're actually going for summed squared distances
                            # Since all positions have same lengths and we do a
                            # full mapping, lowest SSD has a guarantee of also
                            # being lowest RMSD
    for transform in initial_transforms:
        mapping = dict()
        mapped_mdl_chains = set()
        # accumulated SSD over all chem groups for this transform
        # must not be touched by the per pair values computed below
        ssd = 0.0
        for trg_chains, mdl_chains, trg_pos, mdl_pos, in zip(chem_groups,
                                                             chem_mapping,
                                                             trg_group_pos,
                                                             mdl_group_pos):
            if len(trg_pos) == 0 or len(mdl_pos) == 0:
                continue # cannot compute valid rmsd
            ssds = list()
            # transform all model chains of this group once
            t_mdl_pos = list()
            for m_pos in mdl_pos:
                t_m_pos = geom.Vec3List(m_pos)
                t_m_pos.ApplyTransform(transform)
                t_mdl_pos.append(t_m_pos)
            for t_pos, t in zip(trg_pos, trg_chains):
                for t_m_pos, m in zip(t_mdl_pos, mdl_chains):
                    pair_ssd = t_pos.GetSummedSquaredDistances(t_m_pos)
                    ssds.append((pair_ssd, (t,m)))
            # lowest SSD first, each chain can only be assigned once
            ssds.sort()
            for item in ssds:
                p = item[1]
                if p[0] not in mapping and p[1] not in mapped_mdl_chains:
                    mapping[p[0]] = p[1]
                    mapped_mdl_chains.add(p[1])
                    ssd += item[0]

        if ssd < best_ssd:
            best_ssd = ssd
            best_mapping = mapping

    return best_mapping
3229 
def _IterativeRigidRMSD(initial_transforms, initial_mappings, chem_groups,
                        chem_mapping, trg_group_pos, mdl_group_pos,
                        iterative_superposition):
    """ Iteratively grows a mapping from each initial transform using RMSD

    Every initial transform comes with an initial (trg_ch, mdl_ch) pair.
    Starting from there, the free chain pair within the same chem group with
    the lowest RMSD under the current transform is added and the transform is
    recomputed from all mapped positions. This is repeated until no free pair
    is left. The mapping with the lowest overall RMSD is returned.

    :param initial_transforms: Transforms to start from
    :param initial_mappings: One (trg_ch, mdl_ch) pair per initial transform
    :param chem_groups: Target chain names per chem group
    :param chem_mapping: Model chain names per chem group
    :param trg_group_pos: Per chem group: positions of each target chain
    :param mdl_group_pos: Per chem group: positions of each model chain
    :param iterative_superposition: Forwarded to *_GetTransform*
    :returns: :class:`dict` with target chain names as keys and model chain
              names as values
    """

    # chain name => positions
    trg_lookup = {name: pos
                  for group_pos, group_names in zip(trg_group_pos, chem_groups)
                  for pos, name in zip(group_pos, group_names)}
    mdl_lookup = {name: pos
                  for group_pos, group_names in zip(mdl_group_pos, chem_mapping)
                  for pos, name in zip(group_pos, group_names)}

    top_map = dict()
    top_rmsd = float("inf")
    for start_transform, start_pair in zip(initial_transforms,
                                           initial_mappings):
        current_map = {start_pair[0]: start_pair[1]}
        current_transform = geom.Mat4(start_transform)
        joint_trg_pos = geom.Vec3List(trg_lookup[start_pair[0]])
        joint_mdl_pos = geom.Vec3List(mdl_lookup[start_pair[1]])

        # chains that are still available for mapping
        open_trg = [set(names) for names in chem_groups]
        open_mdl = [set(names) for names in chem_mapping]

        # the initial pair is already taken
        for names in open_trg:
            if start_pair[0] in names:
                names.remove(start_pair[0])
                break
        for names in open_mdl:
            if start_pair[1] in names:
                names.remove(start_pair[1])
                break

        while True:
            # best free pair under the current transform
            pick = None
            pick_group = None
            pick_rmsd = float("inf")
            for idx, (trg_names, mdl_names) in enumerate(zip(open_trg,
                                                             open_mdl)):
                for t in trg_names:
                    t_pos = trg_lookup[t]
                    for m in mdl_names:
                        moved = geom.Vec3List(mdl_lookup[m])
                        moved.ApplyTransform(current_transform)
                        pair_rmsd = t_pos.GetRMSD(moved)
                        if pair_rmsd < pick_rmsd:
                            pick_rmsd = pair_rmsd
                            pick = (t,m)
                            pick_group = idx

            if pick is None:
                break

            current_map[pick[0]] = pick[1]
            joint_trg_pos.extend(trg_lookup[pick[0]])
            joint_mdl_pos.extend(mdl_lookup[pick[1]])
            open_trg[pick_group].remove(pick[0])
            open_mdl[pick_group].remove(pick[1])

            # update transform with everything mapped so far
            current_transform = _GetTransform(joint_mdl_pos, joint_trg_pos,
                                              iterative_superposition)

        # overall RMSD for final transform
        joint_mdl_pos.ApplyTransform(current_transform)
        total_rmsd = joint_trg_pos.GetRMSD(joint_mdl_pos)

        if total_rmsd < top_rmsd:
            top_rmsd = total_rmsd
            top_map = current_map

    return top_map
3315 
3316 
def _GetRefPos(trg, mdl, trg_msas, mdl_alns, max_pos = None):
    """ Extracts reference positions which are present in trg and mdl

    For every pair of entries in *trg_msas* / *mdl_alns*, the positions of
    residues that are covered in all target sequences AND all model sequences
    are extracted. Only CA (peptides) and C3' (nucleotides) atoms are
    considered.

    :param trg: Target structure, must support Select()
    :param mdl: Model structure, must support Select()
    :param trg_msas: One MSA per group, first sequence is the reference
    :param mdl_alns: One list of pairwise alignments per group. The first
                     sequence of each alignment is expected to be the same
                     reference sequence as in the respective *trg_msas* entry
    :param max_pos: If given and more positions than that are available in a
                    group, positions are subsampled with a stride of
                    int(n_positions/max_pos)
    :returns: Tuple with two lists (trg_pos, mdl_pos). Each contains one list
              per group which in turn contains one position list per sequence
              (result of :func:`_ExtractMSAPos`).
    """
    # restrict to one backbone atom per residue so that the first atom of
    # each residue can directly serve as reference position
    bb_query = "aname=\"CA\",\"C3'\""
    bb_trg = trg.Select(bb_query)
    bb_mdl = mdl.Select(bb_query)

    # the model alignments are pairwise - merge them to one MSA per group
    mdl_msas = list()
    for pairwise_alns in mdl_alns:
        if len(pairwise_alns) > 0:
            first = pairwise_alns[0].GetSequence(0)
            ref_seq = seq.CreateSequence(first.GetName(),
                                         first.GetGaplessString())
            merged = seq.alg.MergePairwiseAlignments(pairwise_alns, ref_seq)
        else:
            # no model chain for this group => empty alignment
            merged = seq.CreateAlignment()
        mdl_msas.append(merged)

    trg_pos = list()
    mdl_pos = list()

    for trg_msa, mdl_msa in zip(trg_msas, mdl_msas):

        # An empty mdl_msa means that no model chain maps to the group
        # represented by trg_msa. Nothing special is done in that case, the
        # code below simply ends up with empty position lists.
        if mdl_msa.GetCount() > 0:
            # both MSAs must share the same reference sequence
            # (should be a given...)
            trg_ref = trg_msa.GetSequence(0).GetGaplessString()
            mdl_ref = mdl_msa.GetSequence(0).GetGaplessString()
            assert trg_ref == mdl_ref

        # reference sequence indices that are fully covered in trg AND mdl
        trg_covered = _GetFullyCoveredIndices(trg_msa)
        mdl_covered = _GetFullyCoveredIndices(mdl_msa)
        indices = sorted(trg_covered.intersection(mdl_covered))

        # subsample if necessary
        if max_pos is not None and len(indices) > max_pos:
            step = int(len(indices)/max_pos)
            indices = [indices[i] for i in range(0, len(indices), step)]

        # translate to column indices in the respective MSAs
        trg_columns = _RefIndicesToColumnIndices(trg_msa, indices)
        mdl_columns = _RefIndicesToColumnIndices(mdl_msa, indices)

        # extract positions, all sequences in trg_msa belong to trg
        trg_pos.append([_ExtractMSAPos(trg_msa, s_idx, trg_columns, bb_trg)
                        for s_idx in range(trg_msa.GetCount())])
        # first seq in mdl_msa is the reference sequence from trg and does
        # not belong to mdl => start at 1
        mdl_pos.append([_ExtractMSAPos(mdl_msa, s_idx, mdl_columns, bb_mdl)
                        for s_idx in range(1, mdl_msa.GetCount())])

    return (trg_pos, mdl_pos)
3380 
3381 def _GetFullyCoveredIndices(msa):
3382  """ Helper for _GetRefPos
3383 
3384  Returns a set containing the indices relative to first sequence in msa which
3385  are fully covered in all other sequences
3386 
3387  --AA-A-A
3388  -BBBB-BB
3389  CCCC-C-C
3390 
3391  => (0,1,3)
3392  """
3393  indices = set()
3394  ref_idx = 0
3395  for col in msa:
3396  if sum([1 for olc in col if olc != '-']) == col.GetRowCount():
3397  indices.add(ref_idx)
3398  if col[0] != '-':
3399  ref_idx += 1
3400  return indices
3401 
3402 def _RefIndicesToColumnIndices(msa, indices):
3403  """ Helper for _GetRefPos
3404 
3405  Returns a list of mapped indices. indices refer to non-gap one letter
3406  codes in the first msa sequence. The returnes mapped indices are translated
3407  to the according msa column indices
3408  """
3409  ref_idx = 0
3410  mapping = dict()
3411  for col_idx, col in enumerate(msa):
3412  if col[0] != '-':
3413  mapping[ref_idx] = col_idx
3414  ref_idx += 1
3415  return [mapping[i] for i in indices]
3416 
def _ExtractMSAPos(msa, s_idx, indices, view):
    """ Helper for _GetRefPos

    Returns a geom.Vec3List with positions for sequence *s_idx* in *msa*.
    => The chain in *view* that is named like the sequence is selected and,
    for each entry in *indices*, the position of the first atom of the
    according residue is extracted.
    *indices* refers to column indices in msa!
    """
    aligned_seq = msa.GetSequence(s_idx)
    chain_view = _CSel(view, [aligned_seq.GetName()])
    chain_residues = chain_view.residues

    # sanity check: sequence and selected chain must have the same number of
    # residues
    assert len(aligned_seq.GetGaplessString()) == len(chain_residues)

    # column indices => residue indices in the chain
    residue_indices = [aligned_seq.GetResidueIndex(col_idx)
                       for col_idx in indices]
    positions = [chain_residues[r_idx].atoms[0].pos
                 for r_idx in residue_indices]
    return geom.Vec3List(positions)
3433 
3434 def _NChemGroupMappings(ref_chains, mdl_chains):
3435  """ Number of mappings within one chem group
3436 
3437  :param ref_chains: Reference chains
3438  :type ref_chains: :class:`list` of :class:`str`
3439  :param mdl_chains: Model chains that are mapped onto *ref_chains*
3440  :type mdl_chains: :class:`list` of :class:`str`
3441  :returns: Number of possible mappings of *mdl_chains* onto *ref_chains*
3442  """
3443  n_ref = len(ref_chains)
3444  n_mdl = len(mdl_chains)
3445  if n_ref == n_mdl:
3446  return factorial(n_ref)
3447  elif n_ref > n_mdl:
3448  n_choose_k = binom(n_ref, n_mdl)
3449  return n_choose_k * factorial(n_mdl)
3450  else:
3451  n_choose_k = binom(n_mdl, n_ref)
3452  return n_choose_k * factorial(n_ref)
3453 
def _NMappings(ref_chains, mdl_chains):
    """ Number of mappings for a full chem mapping

    That's the product of the number of mappings in the single chem groups
    as computed by :func:`_NChemGroupMappings`.

    :param ref_chains: Chem groups of reference
    :type ref_chains: :class:`list` of :class:`list` of :class:`str`
    :param mdl_chains: Model chains that map onto those chem groups
    :type mdl_chains: :class:`list` of :class:`list` of :class:`str`
    :returns: Number of possible mappings of *mdl_chains* onto *ref_chains*
    """
    assert len(ref_chains) == len(mdl_chains)
    total = 1
    for ref_group, mdl_group in zip(ref_chains, mdl_chains):
        total = total * _NChemGroupMappings(ref_group, mdl_group)
    return total
3468 
def _NMappingsWithin(ref_chains, mdl_chains, max_mappings):
    """ Check whether total number of mappings does not exceed given maximum

    In principle the same as :func:`_NMappings` but the group-wise
    multiplication stops as soon as the running product exceeds the maximum.

    :param ref_chains: Chem groups of reference
    :type ref_chains: :class:`list` of :class:`list` of :class:`str`
    :param mdl_chains: Model chains that map onto those chem groups
    :type mdl_chains: :class:`list` of :class:`list` of :class:`str`
    :param max_mappings: Number of max allowed mappings
    :returns: Whether number of possible mappings of *mdl_chains* onto
              *ref_chains* is below or equal *max_mappings*.
    """
    assert len(ref_chains) == len(mdl_chains)
    running_product = 1
    for ref_group, mdl_group in zip(ref_chains, mdl_chains):
        running_product = running_product * _NChemGroupMappings(ref_group,
                                                                mdl_group)
        # early exit, no need to process the remaining groups
        if running_product > max_mappings:
            return False
    return True
3490 
3491 def _RefSmallerGenerator(ref_chains, mdl_chains):
3492  """ Returns all possible ways to map mdl_chains onto ref_chains
3493 
3494  Specific for the case where len(ref_chains) < len(mdl_chains)
3495  """
3496  for c in itertools.combinations(mdl_chains, len(ref_chains)):
3497  for p in itertools.permutations(c):
3498  yield list(p)
3499 
3500 def _RefLargerGenerator(ref_chains, mdl_chains):
3501  """ Returns all possible ways to map mdl_chains onto ref_chains
3502 
3503  Specific for the case where len(ref_chains) > len(mdl_chains)
3504  Ref chains without mapped mdl chain are assigned None
3505  """
3506  n_ref = len(ref_chains)
3507  n_mdl = len(mdl_chains)
3508  for c in itertools.combinations(range(n_ref), n_mdl):
3509  for p in itertools.permutations(mdl_chains):
3510  ret_list = [None] * n_ref
3511  for idx, ch in zip(c, p):
3512  ret_list[idx] = ch
3513  yield ret_list
3514 
3515 def _RefEqualGenerator(ref_chains, mdl_chains):
3516  """ Returns all possible ways to map mdl_chains onto ref_chains
3517 
3518  Specific for the case where len(ref_chains) == len(mdl_chains)
3519  """
3520  for p in itertools.permutations(mdl_chains):
3521  yield list(p)
3522 
3523 def _ConcatIterators(iterators):
3524  for item in itertools.product(*iterators):
3525  yield list(item)
3526 
3527 def _ChainMappings(ref_chains, mdl_chains, n_max=None):
3528  """Returns all possible ways to map *mdl_chains* onto fixed *ref_chains*
3529 
3530  :param ref_chains: List of list of chemically equivalent chains in reference
3531  :type ref_chains: :class:`list` of :class:`list`
3532  :param mdl_chains: Equally long list of list of chemically equivalent chains
3533  in model that map on those ref chains.
3534  :type mdl_chains: :class:`list` of :class:`list`
3535  :param n_max: Aborts and raises :class:`RuntimeError` if max number of
3536  mappings is above this threshold.
3537  :type n_max: :class:`int`
3538  :returns: Iterator over all possible mappings of *mdl_chains* onto fixed
3539  *ref_chains*. Potentially contains None as padding when number of
3540  model chains for a certain mapping is smaller than the according
3541  reference chains.
3542  Example: _ChainMappings([['A', 'B', 'C'], ['D', 'E']],
3543  [['x', 'y'], ['i', 'j']])
3544  gives an iterator over: [[['x', 'y', None], ['i', 'j']],
3545  [['x', 'y', None], ['j', 'i']],
3546  [['y', 'x', None], ['i', 'j']],
3547  [['y', 'x', None], ['j', 'i']],
3548  [['x', None, 'y'], ['i', 'j']],
3549  [['x', None, 'y'], ['j', 'i']],
3550  [['y', None, 'x'], ['i', 'j']],
3551  [['y', None, 'x'], ['j', 'i']],
3552  [[None, 'x', 'y'], ['i', 'j']],
3553  [[None, 'x', 'y'], ['j', 'i']],
3554  [[None, 'y', 'x'], ['i', 'j']],
3555  [[None, 'y', 'x'], ['j', 'i']]]
3556  """
3557  assert(len(ref_chains) == len(mdl_chains))
3558 
3559  if n_max is not None:
3560  if not _NMappingsWithin(ref_chains, mdl_chains, n_max):
3561  raise RuntimeError(f"Too many mappings. Max allowed: {n_max}")
3562 
3563  # one iterator per mapping representing all mdl combinations relative to
3564  # reference
3565  iterators = list()
3566  for ref, mdl in zip(ref_chains, mdl_chains):
3567  if len(ref) == 0:
3568  raise RuntimeError("Expext at least one chain in ref chem group")
3569  if len(ref) == len(mdl):
3570  iterators.append(_RefEqualGenerator(ref, mdl))
3571  elif len(ref) < len(mdl):
3572  iterators.append(_RefSmallerGenerator(ref, mdl))
3573  else:
3574  iterators.append(_RefLargerGenerator(ref, mdl))
3575 
3576  return _ConcatIterators(iterators)
3577 
3578 
def _GetTransform(pos_one, pos_two, iterative):
    """ Computes minimal RMSD superposition for pos_one onto pos_two

    :param pos_one: Positions that should be superposed onto *pos_two*
    :type pos_one: :class:`geom.Vec3List`
    :param pos_two: Reference positions
    :type pos_two: :class:`geom.Vec3List`
    :param iterative: Whether iterative superposition should be used.
                      Iterative potentially raises, uses standard
                      superposition as fallback.
    :type iterative: :class:`bool`
    :returns: Transformation matrix to superpose *pos_one* onto *pos_two*
    :rtype: :class:`geom.Mat4`
    """
    res = None
    if iterative:
        try:
            res = mol.alg.IterativeSuperposeSVD(pos_one, pos_two)
        except Exception:
            # Any regular error triggers the fallback below. Only Exception
            # is caught so that KeyboardInterrupt/SystemExit still propagate
            # instead of being silently swallowed.
            res = None
    if res is None:
        res = mol.alg.SuperposeSVD(pos_one, pos_two)
    return res.transformation
3601 
# specify public interface: only these names are exported by a star-import
# (from <module> import *) of this module
__all__ = ('ChainMapper', 'ReprResult', 'MappingResult')