OpenStructure
Loading...
Searching...
No Matches
contact_score.py
Go to the documentation of this file.
1import itertools
2import numpy as np
3
4import time
5from ost import mol
6from ost import geom
7from ost import io
8
class ContactEntity:
    """ Helper object for Contact-score computation

    Identifies interchain contacts in a structure. Contacts and interface
    residues are lazily evaluated on first access.
    """
    def __init__(self, ent, contact_d = 5.0, contact_mode="aa"):
        """ Construct from a structure

        :param ent: Structure to analyze. Must provide ``residues``,
                    ``CreateFullView()`` and ``Select()``, e.g. an
                    :class:`ost.mol.EntityHandle`
        :param contact_d: Two residues are in contact if the distance between
                          any pair of their considered atoms is <= contact_d
        :type contact_d: :class:`float`
        :param contact_mode: "aa" or "repr", see :attr:`contact_mode`
        :type contact_mode: :class:`str`
        :raises: :class:`RuntimeError` if *contact_mode* is invalid. In "repr"
                 mode also if any residue is neither peptide nor nucleotide
                 linking, or if it misses its representative atom.
        """
        if contact_mode not in ["aa", "repr"]:
            raise RuntimeError("contact_mode must be in [\"aa\", \"repr\"]")

        if contact_mode == "repr":
            self._CheckReprAtoms(ent)

        self._contact_mode = contact_mode

        if self.contact_mode == "aa":
            self._view = ent.CreateFullView()
        else:
            # "repr": only keep representative atoms
            pep_query = "(peptide=true and (aname=\"CB\" or (rname=\"GLY\" and aname=\"CA\")))"
            nuc_query = "(nucleotide=True and aname=\"C3'\")"
            self._view = ent.Select(" or ".join([pep_query, nuc_query]))
        self._contact_d = contact_d

        # the following attributes will be lazily evaluated
        self._chain_names = None
        self._interacting_chains = None
        self._sequence = dict()
        self._contacts = None
        self._hr_contacts = None
        self._interface_residues = None
        self._hr_interface_residues = None

    @property
    def view(self):
        """ The structure depending on *contact_mode*

        Full view in case of "aa", view that only contains representative
        atoms in case of "repr".

        :type: :class:`ost.mol.EntityView`
        """
        return self._view

    @property
    def contact_mode(self):
        """ The contact mode

        Can either be "aa", meaning that all atoms are considered to identify
        contacts, or "repr" which only considers distances between
        representative atoms. For peptides thats CB (CA for GLY), for
        nucleotides thats C3'.

        :type: :class:`str`
        """
        return self._contact_mode

    @property
    def contact_d(self):
        """ Pairwise distance of residues to be considered as contacts

        Given at :class:`ContactEntity` construction

        :type: :class:`float`
        """
        return self._contact_d

    @property
    def chain_names(self):
        """ Chain names in :attr:`~view`

        Names are sorted

        :type: :class:`list` of :class:`str`
        """
        if self._chain_names is None:
            self._chain_names = sorted([ch.name for ch in self.view.chains])
        return self._chain_names

    @property
    def interacting_chains(self):
        """ Pairs of chains in :attr:`~view` with at least one contact

        :type: :class:`list` of :class:`tuples`
        """
        if self._interacting_chains is None:
            self._interacting_chains = list(self.contacts.keys())
        return self._interacting_chains

    @property
    def contacts(self):
        """ Interchain contacts

        Organized as :class:`dict` with key (cname1, cname2) and values being
        a set of tuples with the respective residue indices.
        cname1 < cname2 evaluates to True.
        """
        if self._contacts is None:
            self._SetupContacts()
        return self._contacts

    @property
    def hr_contacts(self):
        """ Human readable interchain contacts

        Human readable version of :attr:`~contacts`. Simple list with tuples
        containing two strings specifying the residues in contact. Format:
        <cname>.<rnum>.<ins_code>
        """
        if self._hr_contacts is None:
            self._SetupContacts()
        return self._hr_contacts

    @property
    def interface_residues(self):
        """ Interface residues

        Residues in each chain that are in contact with any other chain.
        Organized as :class:`dict` with key cname and values the respective
        residue indices in a :class:`set`.
        """
        if self._interface_residues is None:
            self._SetupInterfaces()
        return self._interface_residues

    @property
    def hr_interface_residues(self):
        """ Human readable interface residues

        Human readable version of :attr:`interface_residues`. Sorted
        :class:`list` of strings specifying the interface residues in format:
        <cname>.<rnum>.<ins_code>
        """
        if self._hr_interface_residues is None:
            self._SetupInterfaces()
        return self._hr_interface_residues

    def GetChain(self, chain_name):
        """ Get chain by name

        :param chain_name: Chain in :attr:`~view`
        :type chain_name: :class:`str`
        :raises: :class:`RuntimeError` if :attr:`~view` has no such chain
        """
        chain = self.view.FindChain(chain_name)
        if not chain.IsValid():
            raise RuntimeError(f"view has no chain named \"{chain_name}\"")
        return chain

    def GetSequence(self, chain_name):
        """ Get sequence of chain

        Returns sequence of specified chain as raw :class:`str`. The result
        is cached per chain name.

        :param chain_name: Chain in :attr:`~view`
        :type chain_name: :class:`str`
        """
        if chain_name not in self._sequence:
            ch = self.GetChain(chain_name)
            s = ''.join([r.one_letter_code for r in ch.residues])
            self._sequence[chain_name] = s
        return self._sequence[chain_name]

    @staticmethod
    def _CheckReprAtoms(ent):
        # Raise if any residue of *ent* lacks the atom used in "repr" mode:
        # CB for peptide residues (CA for GLY), C3' for nucleotide residues.
        for r in ent.residues:
            repr_at = None
            if r.IsPeptideLinking():
                cb = r.FindAtom("CB")
                if cb.IsValid():
                    repr_at = cb
                elif r.GetName() == "GLY":
                    ca = r.FindAtom("CA")
                    if ca.IsValid():
                        repr_at = ca
            elif r.IsNucleotideLinking():
                c3 = r.FindAtom("C3'")
                if c3.IsValid():
                    repr_at = c3
            else:
                raise RuntimeError(f"Only support peptide and nucleotide "
                                   f"residues in \"repr\" contact mode. "
                                   f"Problematic residue: {r}")
            if repr_at is None:
                raise RuntimeError(f"Residue {r} has no required "
                                   f"representative atom (CB for peptide "
                                   f"residues (CA for GLY), C3' for "
                                   f"nucleotide residues).")

    @staticmethod
    def _AtomPositions(r):
        # Atom positions of residue *r* as float array of shape (n_atoms, 3)
        coords = [[p[0], p[1], p[2]] for p in (a.GetPos() for a in r.atoms)]
        return np.array(coords, dtype=np.float64).reshape(-1, 3)

    @staticmethod
    def _HRResidue(cname, r):
        # Human readable residue identifier <cname>.<rnum>.<ins_code>; an
        # empty insertion code shows up as "\u0000" and is stripped
        rnum = r.GetNumber()
        return f"{cname}.{rnum.num}.{rnum.ins_code}".strip("\u0000")

    def _SetupContacts(self):
        self._contacts = dict()
        self._hr_contacts = list()

        # residue indices stored in contacts refer to the residue position
        # within the respective chain of self.view
        for ch in self.view.chains:
            for r_idx, r in enumerate(ch.residues):
                r.SetIntProp("contact_idx", r_idx)

        residue_lists = list()  # per chain: residues in order
        per_res_pos = list()    # per chain: list of (n_atoms, 3) arrays
        min_res_pos = list()    # per chain: (n_res, 3) per-residue minima
        max_res_pos = list()    # per chain: (n_res, 3) per-residue maxima
        min_chain_pos = list()  # per chain: (3,) bounding box minimum
        max_chain_pos = list()  # per chain: (3,) bounding box maximum

        for cname in self.chain_names:
            ch = self.view.FindChain(cname)
            if ch.GetAtomCount() == 0:
                raise RuntimeError(f"Chain without atoms observed: \"{cname}\"")
            residues = [r for r in ch.residues]
            res_pos = [self._AtomPositions(r) for r in residues]
            ch_min = np.vstack([p.min(0) for p in res_pos])
            ch_max = np.vstack([p.max(0) for p in res_pos])
            residue_lists.append(residues)
            per_res_pos.append(res_pos)
            min_res_pos.append(ch_min)
            max_res_pos.append(ch_max)
            min_chain_pos.append(ch_min.min(0))
            max_chain_pos.append(ch_max.max(0))

        d = self.contact_d
        # operate on squared contact_d (scd) to save some square roots
        scd = d * d

        n_chains = len(self.chain_names)
        for ch1_idx, ch2_idx in itertools.combinations(range(n_chains), 2):
            # chains whose bounding boxes are separated by more than
            # contact_d along any axis have no contact within contact_d
            if np.max(min_chain_pos[ch1_idx] - max_chain_pos[ch2_idx]) > d:
                continue
            if np.max(min_chain_pos[ch2_idx] - max_chain_pos[ch1_idx]) > d:
                continue

            # same thing for residue bounding boxes, all pairs at once;
            # broadcasting yields (n_res_ch1, n_res_ch2, 3) differences
            min1 = min_res_pos[ch1_idx][:, np.newaxis, :]
            max1 = max_res_pos[ch1_idx][:, np.newaxis, :]
            min2 = min_res_pos[ch2_idx][np.newaxis, :, :]
            max2 = max_res_pos[ch2_idx][np.newaxis, :, :]
            skip = np.logical_or(np.any(min1 - max2 > d, axis=2),
                                 np.any(min2 - max1 > d, axis=2))

            cname1 = self.chain_names[ch1_idx]
            cname2 = self.chain_names[ch2_idx]
            # residue pairs for which we cannot exclude a contact
            r1_indices, r2_indices = np.nonzero(np.logical_not(skip))
            for r1_idx, r2_idx in zip(r1_indices, r2_indices):
                p1 = per_res_pos[ch1_idx][r1_idx]
                p2 = per_res_pos[ch2_idx][r2_idx]
                # squared pairwise distances as |x|^2 - 2xy + |y|^2
                x2 = np.sum(p1**2, axis=1).reshape(-1, 1)  # (m, 1)
                y2 = np.sum(p2**2, axis=1)                 # (n)
                xy = np.matmul(p1, p2.T)                   # (m, n)
                squared_distances = x2 - 2*xy + y2         # (m, n)
                if np.min(squared_distances) <= scd:
                    # its a contact!
                    r1 = residue_lists[ch1_idx][r1_idx]
                    r2 = residue_lists[ch2_idx][r2_idx]
                    cname_key = (cname1, cname2)
                    if cname_key not in self._contacts:
                        self._contacts[cname_key] = set()
                    self._contacts[cname_key].add((r1.GetIntProp("contact_idx"),
                                                   r2.GetIntProp("contact_idx")))
                    self._hr_contacts.append((self._HRResidue(cname1, r1),
                                              self._HRResidue(cname2, r2)))

    def _SetupInterfaces(self):
        # per chain: indices of residues involved in any interchain contact
        self._interface_residues = {cname: set() for cname in self.chain_names}
        for (cname1, cname2), contacts in self.contacts.items():
            for r1_idx, r2_idx in contacts:
                self._interface_residues[cname1].add(r1_idx)
                self._interface_residues[cname2].add(r2_idx)

        # human readable version; sorted so the output does not depend on
        # set iteration order (string hashing is randomized per process)
        hr_interface_residues = set()
        for hr1, hr2 in self.hr_contacts:
            hr_interface_residues.add(hr1)
            hr_interface_residues.add(hr2)
        self._hr_interface_residues = sorted(hr_interface_residues)
300
301
class ContactScorerResultICS:
    """
    Holds data relevant to compute ics
    """
    def __init__(self, n_trg_contacts, n_mdl_contacts, n_union, n_intersection):
        self._n_trg_contacts = n_trg_contacts
        self._n_mdl_contacts = n_mdl_contacts
        self._n_union = n_union
        self._n_intersection = n_intersection

    @property
    def n_trg_contacts(self):
        """ Number of contacts in target

        :type: :class:`int`
        """
        return self._n_trg_contacts

    @property
    def n_mdl_contacts(self):
        """ Number of contacts in model

        :type: :class:`int`
        """
        return self._n_mdl_contacts

    @property
    def n_union(self):
        """ Size of the union of mapped target and model contacts

        As provided at construction.

        :type: :class:`int`
        """
        return self._n_union

    @property
    def n_intersection(self):
        """ Number of contacts shared by target and model after mapping

        As provided at construction.

        :type: :class:`int`
        """
        return self._n_intersection

    @property
    def precision(self):
        """ Precision of model contacts

        The fraction of model contacts that are also present in target.
        0.0 if the model has no contacts.

        :type: :class:`float`
        """
        if self._n_mdl_contacts != 0:
            return self._n_intersection / self._n_mdl_contacts
        else:
            return 0.0

    @property
    def recall(self):
        """ Recall of model contacts

        The fraction of target contacts that are also present in model.
        0.0 if the target has no contacts.

        :type: :class:`float`
        """
        if self._n_trg_contacts != 0:
            return self._n_intersection / self._n_trg_contacts
        else:
            return 0.0

    @property
    def ics(self):
        """ The Interface Contact Similarity score (ICS)

        Combination of :attr:`precision` and :attr:`recall` using the
        F1-measure. 0.0 if both are 0.0.

        :type: :class:`float`
        """
        p = self.precision
        r = self.recall
        numerator = p*r
        denominator = p + r
        if denominator != 0.0:
            return 2*numerator/denominator
        else:
            return 0.0
370
class ContactScorerResultIPS:
    """
    Holds data relevant to compute ips
    """
    def __init__(self, n_trg_int_res, n_mdl_int_res, n_union, n_intersection):
        self._n_trg_int_res = n_trg_int_res
        self._n_mdl_int_res = n_mdl_int_res
        self._n_union = n_union
        self._n_intersection = n_intersection

    @property
    def n_trg_int_res(self):
        """ Number of interface residues in target

        :type: :class:`int`
        """
        return self._n_trg_int_res

    @property
    def n_mdl_int_res(self):
        """ Number of interface residues in model

        :type: :class:`int`
        """
        return self._n_mdl_int_res

    @property
    def n_union(self):
        """ Size of the union of mapped target and model interface residues

        As provided at construction.

        :type: :class:`int`
        """
        return self._n_union

    @property
    def n_intersection(self):
        """ Number of interface residues shared by target and model

        As provided at construction.

        :type: :class:`int`
        """
        return self._n_intersection

    @property
    def precision(self):
        """ Precision of model interface residues

        The fraction of model interface residues that are also interface
        residues in target. 0.0 if the model has no interface residues.

        :type: :class:`float`
        """
        if self._n_mdl_int_res != 0:
            return self._n_intersection / self._n_mdl_int_res
        else:
            return 0.0

    @property
    def recall(self):
        """ Recall of model interface residues

        The fraction of target interface residues that are also interface
        residues in model. 0.0 if the target has no interface residues.

        :type: :class:`float`
        """
        if self._n_trg_int_res != 0:
            return self._n_intersection / self._n_trg_int_res
        else:
            return 0.0

    @property
    def ips(self):
        """ The Interface Patch Similarity score (IPS)

        Jaccard coefficient of interface residues in model/target.
        Technically thats :attr:`n_intersection`/:attr:`n_union`, 0.0 if
        the union is empty.

        :type: :class:`float`
        """
        if self._n_union > 0:
            return self._n_intersection/self._n_union
        return 0.0
437
class ContactScorer:
    """ Helper object to compute Contact scores

    Tightly integrated into the mechanisms from the chain_mapping module.
    The preferred way to derive an object of type :class:`ContactScorer` is
    through the static constructor: :func:`~FromMappingResult`.

    Usage is the same as for :class:`ost.mol.alg.QSScorer`
    """

    def __init__(self, target, chem_groups, model, alns,
                 contact_mode="aa", contact_d=5.0):
        """ Set up scorer

        :param target: Target structure, see :attr:`cent1`
        :param chem_groups: Groups of chemically equivalent chains in
                            *target*, see :attr:`chem_groups`
        :type chem_groups: :class:`list` of :class:`list` of :class:`str`
        :param model: Model structure, see :attr:`cent2`
        :param alns: Alignments between target and model chains, see
                     :attr:`alns`
        :param contact_mode: Passed to :class:`ContactEntity`
        :param contact_d: Passed to :class:`ContactEntity`
        :raises: :class:`RuntimeError` if chain names in *chem_groups* do not
                 exactly match the chain names of the processed *target*
        """
        self._cent1 = ContactEntity(target, contact_mode = contact_mode,
                                    contact_d = contact_d)
        # ensure that target chain names match the ones in chem_groups
        chem_group_ch_names = list(itertools.chain.from_iterable(chem_groups))
        if self._cent1.chain_names != sorted(chem_group_ch_names):
            raise RuntimeError(f"Expect exact same chain names in chem_groups "
                               f"and in target (which is processed to only "
                               f"contain peptides/nucleotides). target: "
                               f"{self._cent1.chain_names}, chem_groups: "
                               f"{chem_group_ch_names}")

        self._chem_groups = chem_groups
        self._cent2 = ContactEntity(model, contact_mode = contact_mode,
                                    contact_d = contact_d)
        self._alns = alns

        # cache for mapped interface scores
        # key: tuple of tuple ((cent1_ch1, cent1_ch2), (cent2_ch1, cent2_ch2))
        # value: tuple with four numbers required for computation of
        #        per-interface scores. The first two are relevant for ICS,
        #        the others for per interface IPS.
        #        1: n_union_contacts
        #        2: n_intersection_contacts
        #        3: n_union_interface_residues
        #        4: n_intersection_interface_residues
        self._mapped_cache_interface = dict()

        # cache for mapped single chain scores
        # for interface residues of single chains
        # key: tuple: (cent1_ch, cent2_ch)
        # value: tuple with two numbers required for computation of IPS
        #        1: n_union
        #        2: n_intersection
        self._mapped_cache_sc = dict()

    @staticmethod
    def FromMappingResult(mapping_result, contact_mode="aa", contact_d = 5.0):
        """ The preferred way to get a :class:`ContactScorer`

        Static constructor that derives an object of type :class:`ContactScorer`
        using a :class:`ost.mol.alg.chain_mapping.MappingResult`

        :param mapping_result: Data source
        :type mapping_result: :class:`ost.mol.alg.chain_mapping.MappingResult`
        :param contact_mode: Passed to :class:`ContactEntity`
        :param contact_d: Passed to :class:`ContactEntity`
        """
        contact_scorer = ContactScorer(mapping_result.target,
                                       mapping_result.chem_groups,
                                       mapping_result.model,
                                       mapping_result.alns,
                                       contact_mode = contact_mode,
                                       contact_d = contact_d)
        return contact_scorer

    @property
    def cent1(self):
        """ Represents *target*

        :type: :class:`ContactEntity`
        """
        return self._cent1

    @property
    def chem_groups(self):
        """ Groups of chemically equivalent chains in *target*

        Provided at object construction

        :type: :class:`list` of :class:`list` of :class:`str`
        """
        return self._chem_groups

    @property
    def cent2(self):
        """ Represents *model*

        :type: :class:`ContactEntity`
        """
        return self._cent2

    @property
    def alns(self):
        """ Alignments between chains in :attr:`~cent1` and :attr:`~cent2`

        Provided at object construction. Each alignment is accessible with
        ``alns[(t_chain,m_chain)]``. First sequence is the sequence of the
        respective chain in :attr:`~cent1`, second sequence the one from
        :attr:`~cent2`.

        :type: :class:`dict` with key: :class:`tuple` of :class:`str`, value:
               :class:`ost.seq.AlignmentHandle`
        """
        return self._alns

    def ScoreICS(self, mapping, check=True):
        """ Computes ICS given chain mapping

        Again, the preferred way to get *mapping* is from an object
        of type :class:`ost.mol.alg.chain_mapping.MappingResult`.

        :param mapping: see
                        :attr:`ost.mol.alg.chain_mapping.MappingResult.mapping`
        :type mapping: :class:`list` of :class:`list` of :class:`str`
        :param check: Perform input checks, can be disabled for speed purposes
                      if you know what you're doing.
        :type check: :class:`bool`
        :returns: Result object of type :class:`ContactScorerResultICS`
        :raises: :class:`RuntimeError` if *check* is enabled and *mapping*
                 does not match :attr:`chem_groups` or references chains
                 not present in :attr:`cent2`
        """
        return self.ICSFromFlatMapping(self._FlatMapping(mapping, check))

    def ScoreICSInterface(self, trg_ch1, trg_ch2, mdl_ch1, mdl_ch2):
        """ Computes ICS scores only considering one interface

        This only works for interfaces that are computed in :func:`ScoreICS`,
        i.e. interfaces for which the alignments are set up correctly.

        :param trg_ch1: Name of first interface chain in target
        :type trg_ch1: :class:`str`
        :param trg_ch2: Name of second interface chain in target
        :type trg_ch2: :class:`str`
        :param mdl_ch1: Name of first interface chain in model
        :type mdl_ch1: :class:`str`
        :param mdl_ch2: Name of second interface chain in model
        :type mdl_ch2: :class:`str`
        :returns: Result object of type :class:`ContactScorerResultICS`
        :raises: :class:`RuntimeError` if no aln for trg_ch1/mdl_ch1 or
                 trg_ch2/mdl_ch2 is available.
        """
        self._CheckInterfaceAlns(trg_ch1, trg_ch2, mdl_ch1, mdl_ch2)
        trg_int = (trg_ch1, trg_ch2)
        mdl_int = (mdl_ch1, mdl_ch2)
        n_trg = len(self._InterfaceContacts(self.cent1, trg_int))
        n_mdl = len(self._InterfaceContacts(self.cent2, mdl_int))
        n_union, n_intersection, _, _ = self._MappedInterfaceScores(trg_int, mdl_int)
        return ContactScorerResultICS(n_trg, n_mdl, n_union, n_intersection)

    def ICSFromFlatMapping(self, flat_mapping):
        """ Same as :func:`ScoreICS` but with flat mapping

        :param flat_mapping: Dictionary with target chain names as keys and
                             the mapped model chain names as value
        :type flat_mapping: :class:`dict` with :class:`str` as key and value
        :returns: Result object of type :class:`ContactScorerResultICS`
        """
        n_trg = sum([len(x) for x in self.cent1.contacts.values()])
        n_mdl = sum([len(x) for x in self.cent2.contacts.values()])
        n_union = 0
        n_intersection = 0

        processed_cent2_interfaces = set()
        for int1 in self.cent1.interacting_chains:
            if int1[0] in flat_mapping and int1[1] in flat_mapping:
                int2 = (flat_mapping[int1[0]], flat_mapping[int1[1]])
                a, b, _, _ = self._MappedInterfaceScores(int1, int2)
                n_union += a
                n_intersection += b
                # keys of cent2.contacts have lexicographically sorted names
                processed_cent2_interfaces.add((min(int2), max(int2)))

        # process interfaces that only exist in cent2
        r_flat_mapping = {v:k for k,v in flat_mapping.items()} # reverse mapping
        for int2 in self.cent2.interacting_chains:
            if int2 not in processed_cent2_interfaces:
                if int2[0] in r_flat_mapping and int2[1] in r_flat_mapping:
                    int1 = (r_flat_mapping[int2[0]], r_flat_mapping[int2[1]])
                    a, b, _, _ = self._MappedInterfaceScores(int1, int2)
                    n_union += a
                    n_intersection += b

        return ContactScorerResultICS(n_trg, n_mdl,
                                      n_union, n_intersection)

    def ScoreIPS(self, mapping, check=True):
        """ Computes IPS given chain mapping

        Again, the preferred way to get *mapping* is from an object
        of type :class:`ost.mol.alg.chain_mapping.MappingResult`.

        :param mapping: see
                        :attr:`ost.mol.alg.chain_mapping.MappingResult.mapping`
        :type mapping: :class:`list` of :class:`list` of :class:`str`
        :param check: Perform input checks, can be disabled for speed purposes
                      if you know what you're doing.
        :type check: :class:`bool`
        :returns: Result object of type :class:`ContactScorerResultIPS`
        :raises: :class:`RuntimeError` if *check* is enabled and *mapping*
                 does not match :attr:`chem_groups` or references chains
                 not present in :attr:`cent2`
        """
        return self.IPSFromFlatMapping(self._FlatMapping(mapping, check))

    def ScoreIPSInterface(self, trg_ch1, trg_ch2, mdl_ch1, mdl_ch2):
        """ Computes IPS scores only considering one interface

        This only works for interfaces that are computed in :func:`ScoreIPS`,
        i.e. interfaces for which the alignments are set up correctly.

        :param trg_ch1: Name of first interface chain in target
        :type trg_ch1: :class:`str`
        :param trg_ch2: Name of second interface chain in target
        :type trg_ch2: :class:`str`
        :param mdl_ch1: Name of first interface chain in model
        :type mdl_ch1: :class:`str`
        :param mdl_ch2: Name of second interface chain in model
        :type mdl_ch2: :class:`str`
        :returns: Result object of type :class:`ContactScorerResultIPS`
        :raises: :class:`RuntimeError` if no aln for trg_ch1/mdl_ch1 or
                 trg_ch2/mdl_ch2 is available.
        """
        self._CheckInterfaceAlns(trg_ch1, trg_ch2, mdl_ch1, mdl_ch2)
        trg_int = (trg_ch1, trg_ch2)
        mdl_int = (mdl_ch1, mdl_ch2)

        # interface residues counted separately on both sides of the interface
        trg_contacts = self._InterfaceContacts(self.cent1, trg_int)
        n_trg = len({x[0] for x in trg_contacts}) + len({x[1] for x in trg_contacts})
        mdl_contacts = self._InterfaceContacts(self.cent2, mdl_int)
        n_mdl = len({x[0] for x in mdl_contacts}) + len({x[1] for x in mdl_contacts})

        _, _, n_union, n_intersection = self._MappedInterfaceScores(trg_int, mdl_int)
        return ContactScorerResultIPS(n_trg, n_mdl, n_union, n_intersection)

    def IPSFromFlatMapping(self, flat_mapping):
        """ Same as :func:`ScoreIPS` but with flat mapping

        :param flat_mapping: Dictionary with target chain names as keys and
                             the mapped model chain names as value
        :type flat_mapping: :class:`dict` with :class:`str` as key and value
        :returns: Result object of type :class:`ContactScorerResultIPS`
        """
        n_trg = sum([len(x) for x in self.cent1.interface_residues.values()])
        n_mdl = sum([len(x) for x in self.cent2.interface_residues.values()])
        n_union = 0
        n_intersection = 0

        processed_cent2_chains = set()
        for trg_ch in self.cent1.chain_names:
            if trg_ch in flat_mapping:
                a, b = self._MappedSCScores(trg_ch, flat_mapping[trg_ch])
                n_union += a
                n_intersection += b
                processed_cent2_chains.add(flat_mapping[trg_ch])
            else:
                # unmapped target chain: all its interface residues are
                # only part of the union
                n_union += len(self.cent1.interface_residues[trg_ch])

        # unmapped model chains: same as above
        for mdl_ch in self.cent2.chain_names:
            if mdl_ch not in processed_cent2_chains:
                n_union += len(self.cent2.interface_residues[mdl_ch])

        return ContactScorerResultIPS(n_trg, n_mdl,
                                      n_union, n_intersection)

    def _FlatMapping(self, mapping, check):
        # Convert nested *mapping* (parallel to chem_groups) into a dict
        # target chain -> model chain, omitting target chains mapped to None.
        # With *check*, validate dimensions and model chain names first.
        if check:
            # ensure that dimensionality of mapping matches self.chem_groups
            if len(self.chem_groups) != len(mapping):
                raise RuntimeError("Dimensions of self.chem_groups and mapping "
                                   "must match")
            for a, b in zip(self.chem_groups, mapping):
                if len(a) != len(b):
                    raise RuntimeError("Dimensions of self.chem_groups and "
                                       "mapping must match")
            # ensure that chain names in mapping are all present in cent2
            mdl_chain_names = set(self.cent2.chain_names)
            for name in itertools.chain.from_iterable(mapping):
                if name is not None and name not in mdl_chain_names:
                    raise RuntimeError(f"Each chain in mapping must be present "
                                       f"in self.cent2. No match for "
                                       f"\"{name}\"")

        flat_mapping = dict()
        for a, b in zip(self.chem_groups, mapping):
            flat_mapping.update({x: y for x, y in zip(a, b) if y is not None})
        return flat_mapping

    def _CheckInterfaceAlns(self, trg_ch1, trg_ch2, mdl_ch1, mdl_ch2):
        # Raise if no alignment is available for one of the two chain pairs
        for idx, trg_ch, mdl_ch in ((1, trg_ch1, mdl_ch1), (2, trg_ch2, mdl_ch2)):
            if (trg_ch, mdl_ch) not in self.alns:
                raise RuntimeError(f"No aln between trg_ch{idx} ({trg_ch}) and "
                                   f"mdl_ch{idx} ({mdl_ch}) available. Did you "
                                   f"construct the ContactScorer object from a "
                                   f"MappingResult and are trg_ch{idx} and "
                                   f"mdl_ch{idx} mapped to each other?")

    @staticmethod
    def _InterfaceContacts(cent, interface):
        # Contacts of *interface* in *cent* as set of residue index tuples,
        # oriented like *interface* (reversed if stored the other way round).
        # Empty set if the chains have no contacts at all.
        if interface in cent.contacts:
            return cent.contacts[interface]
        reverse = (interface[1], interface[0])
        if reverse in cent.contacts:
            return {(x[1], x[0]) for x in cent.contacts[reverse]}
        return set()

    def _MappedInterfaceScores(self, int1, int2):
        # Cached _InterfaceScores; a cached result of the reversed interface
        # pair is reused since all four numbers are symmetric under reversal
        key_one = (int1, int2)
        if key_one in self._mapped_cache_interface:
            return self._mapped_cache_interface[key_one]
        key_two = ((int1[1], int1[0]), (int2[1], int2[0]))
        if key_two in self._mapped_cache_interface:
            return self._mapped_cache_interface[key_two]

        a, b, c, d = self._InterfaceScores(int1, int2)
        self._mapped_cache_interface[key_one] = (a, b, c, d)
        return (a, b, c, d)

    def _InterfaceScores(self, int1, int2):
        # Union/intersection of contacts and of interface residues between
        # interface int1 in cent1 and int2 in cent2
        ref_contacts = self._InterfaceContacts(self.cent1, int1)
        mdl_contacts = self._InterfaceContacts(self.cent2, int2)

        # indices in contacts are specific to the respective structures,
        # map them to common alignment columns
        ch1_aln = self.alns[(int1[0], int2[0])]
        ch2_aln = self.alns[(int1[1], int2[1])]
        mapped_ref_contacts = {(ch1_aln.GetPos(0, c[0]), ch2_aln.GetPos(0, c[1]))
                               for c in ref_contacts}
        mapped_mdl_contacts = {(ch1_aln.GetPos(1, c[0]), ch2_aln.GetPos(1, c[1]))
                               for c in mdl_contacts}

        contact_union = len(mapped_ref_contacts.union(mapped_mdl_contacts))
        contact_intersection = len(mapped_ref_contacts.intersection(mapped_mdl_contacts))

        # same on interface residues, separately for both interface chains
        intres_union = 0
        intres_intersection = 0
        for side in (0, 1):
            tmp_ref = {x[side] for x in mapped_ref_contacts}
            tmp_mdl = {x[side] for x in mapped_mdl_contacts}
            intres_union += len(tmp_ref.union(tmp_mdl))
            intres_intersection += len(tmp_ref.intersection(tmp_mdl))

        return (contact_union, contact_intersection,
                intres_union, intres_intersection)

    def _MappedSCScores(self, ref_ch, mdl_ch):
        # Cached _SCScores
        if (ref_ch, mdl_ch) in self._mapped_cache_sc:
            return self._mapped_cache_sc[(ref_ch, mdl_ch)]
        n_union, n_intersection = self._SCScores(ref_ch, mdl_ch)
        self._mapped_cache_sc[(ref_ch, mdl_ch)] = (n_union, n_intersection)
        return (n_union, n_intersection)

    def _SCScores(self, ch1, ch2):
        # Union/intersection of interface residues of target chain ch1 and
        # model chain ch2 after mapping to common alignment columns
        aln = self.alns[(ch1, ch2)]
        mapped_ref_int_res = {aln.GetPos(0, r_idx)
                              for r_idx in self.cent1.interface_residues[ch1]}
        mapped_mdl_int_res = {aln.GetPos(1, r_idx)
                              for r_idx in self.cent2.interface_residues[ch2]}
        return (len(mapped_ref_int_res.union(mapped_mdl_int_res)),
                len(mapped_ref_int_res.intersection(mapped_mdl_int_res)))
884
# specify public interface
__all__ = ('ContactEntity', 'ContactScorerResultICS', 'ContactScorerResultIPS', 'ContactScorer')
__init__(self, ent, contact_d=5.0, contact_mode="aa")
ScoreIPSInterface(self, trg_ch1, trg_ch2, mdl_ch1, mdl_ch2)
ScoreICSInterface(self, trg_ch1, trg_ch2, mdl_ch1, mdl_ch2)
__init__(self, target, chem_groups, model, alns, contact_mode="aa", contact_d=5.0)
FromMappingResult(mapping_result, contact_mode="aa", contact_d=5.0)
__init__(self, n_trg_contacts, n_mdl_contacts, n_union, n_intersection)
__init__(self, n_trg_int_res, n_mdl_int_res, n_union, n_intersection)