OpenStructure
Loading...
Searching...
No Matches
bb_lddt.py
Go to the documentation of this file.
1import itertools
2import numpy as np
3from scipy.spatial import distance
4
5import time
6from ost import mol
7
9 """ Helper object for BBlDDT computation
10
11 Holds structural information and getters for interacting chains, i.e.
12 interfaces. Peptide residues are represented by their CA position
13 and nucleotides by C3'.
14
15 :param ent: Structure for BBlDDT score computation
16 :type ent: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
17 :param dist_thresh: Pairwise distance of residues to be considered as contacts
18 :type dist_thresh: :class:`float`
19 """
def __init__(self, ent, dist_thresh=15.0,
             dist_diff_thresholds=[0.5, 1.0, 2.0, 4.0]):
    """ Select representative atoms and set up empty caches

    :param ent: Structure for BBlDDT score computation
    :param dist_thresh: Pairwise distance of residues to be considered as
                        contacts
    :type dist_thresh: :class:`float`
    :param dist_diff_thresholds: Distance difference thresholds for lDDT
                                 computation
    :type dist_diff_thresholds: :class:`list` of :class:`float`
    """
    pep_query = "(peptide=true and aname=\"CA\")"
    nuc_query = "(nucleotide=True and aname=\"C3'\")"
    self._view = ent.Select(" or ".join([pep_query, nuc_query]))
    self._dist_thresh = dist_thresh
    # The default is a mutable list shared between all calls; keep a
    # private copy so that no instance can alter it for the others.
    self._dist_diff_thresholds = list(dist_diff_thresholds)

    # the following attributes will be lazily evaluated
    self._chain_names = None
    # NOTE(review): the next two assignments were absent from the
    # extracted listing. They are restored because the properties of
    # the same name test these attributes against None.
    self._interacting_chains = None
    self._potentially_contributing_chains = None
    self._sequence = dict()
    self._pos = dict()
    self._pair_dist = dict()
    self._sc_dist = dict()
    self._n_pair_contacts = None
    self._n_sc_contacts = None
    self._n_contacts = None
    # min and max xyz for elements in pos used for fast collision
    # detection
    self._min_pos = dict()
    self._max_pos = dict()
43
@property
def view(self):
    """ Processed structure

    View that only contains representative atoms. That's CA for peptide
    residues and C3' for nucleotides.

    :type: :class:`ost.mol.EntityView`
    """
    return self._view
54
@property
def dist_thresh(self):
    """ Pairwise distance of residues to be considered as contacts

    Given at :class:`BBlDDTEntity` construction

    :type: :class:`float`
    """
    return self._dist_thresh
64
@property
def dist_diff_thresholds(self):
    """ Distance difference thresholds for lDDT computation

    Given at :class:`BBlDDTEntity` construction

    :type: :class:`list` of :class:`float`
    """
    # NOTE(review): the def line was missing from the extracted listing;
    # the name is taken from the ``self.dist_diff_thresholds`` access in
    # this class - confirm against the original file.
    return self._dist_diff_thresholds
74
@property
def chain_names(self):
    """ Chain names in :attr:`~view`

    Names are sorted. Computed on first access and cached.

    :type: :class:`list` of :class:`str`
    """
    if self._chain_names is None:
        self._chain_names = sorted(ch.name for ch in self.view.chains)
    return self._chain_names
86
@property
def interacting_chains(self):
    """ Pairs of chains in :attr:`~view` with at least one contact

    Computed on first access and cached. As a side effect, the number of
    contacts per returned pair is stored for :attr:`~n_pair_contacts`.

    :type: :class:`list` of :class:`tuples`
    """
    # NOTE(review): the def line was missing from the extracted listing;
    # the name is taken from the ``self.interacting_chains`` accesses
    # elsewhere in the file - confirm against the original file.
    if self._interacting_chains is None:
        # ugly hack: also computes self._n_pair_contacts
        # this assumption is made when computing the n_pair_contacts
        # attribute
        self._interacting_chains = list()
        self._n_pair_contacts = list()
        for x in itertools.combinations(self.chain_names, 2):
            # cheap bounding box test first, full distance matrix only
            # for pairs that pass it
            if self.PotentialInteraction(x[0], x[1]):
                n = np.count_nonzero(self.PairDist(x[0], x[1]) < self.dist_thresh)
                if n > 0:
                    self._interacting_chains.append(x)
                    self._n_pair_contacts.append(n)
    return self._interacting_chains
106
@property
def potentially_contributing_chains(self):
    """ Pairs of chains in :attr:`view` with potential contribution to lDDT

    That are pairs of chains that have at least one interaction within
    :attr:`~dist_thresh` + max(:attr:`~dist_diff_thresholds`)
    """
    # NOTE(review): the def line, the list initialisation, the append and
    # the return statement were missing from the extracted listing. They
    # are reconstructed from the cached attribute name and by analogy
    # with the interacting chains computation - confirm against the
    # original file.
    if self._potentially_contributing_chains is None:
        self._potentially_contributing_chains = list()
        max_dist_diff_thresh = max(self.dist_diff_thresholds)
        thresh = self.dist_thresh + max_dist_diff_thresh
        for x in itertools.combinations(self.chain_names, 2):
            if self.PotentialInteraction(x[0], x[1],
                                         slack = max_dist_diff_thresh):
                n = np.count_nonzero(self.PairDist(x[0], x[1]) < thresh)
                if n > 0:
                    self._potentially_contributing_chains.append(x)
    return self._potentially_contributing_chains
126
@property
def n_pair_contacts(self):
    """ Number of contacts in :attr:`~interacting_chains`

    :type: :class:`list` of :class:`int`
    """
    # NOTE(review): the def line was missing from the extracted listing;
    # the name is taken from the ``self.n_pair_contacts`` access in this
    # class - confirm against the original file.
    if self._n_pair_contacts is None:
        # ugly hack: assumption that computing self.interacting_chains
        # also triggers computation of n_pair_contacts.
        # The test must be "is None": a plain truthiness test skips the
        # computation exactly when nothing has been computed yet.
        self.interacting_chains
    return self._n_pair_contacts
138
@property
def n_sc_contacts(self):
    """ Number of contacts for single chains in :attr:`~chain_names`

    One entry per chain, in the order of :attr:`~chain_names`. Computed
    on first access and cached.

    :type: :class:`list` of :class:`int`
    """
    if self._n_sc_contacts is None:
        self._n_sc_contacts = list()
        for cname in self.chain_names:
            dist_mat = self.Dist(cname)
            n = np.count_nonzero(dist_mat < self.dist_thresh)
            # dist_mat is symmetric => first remove the diagonal from n
            # as these are distances with itself, i.e. zeroes.
            # Division by two then removes the symmetric component.
            self._n_sc_contacts.append(int((n-dist_mat.shape[0])/2))
    return self._n_sc_contacts
155
@property
def n_contacts(self):
    """ Total number of contacts

    That's the sum of all :attr:`~n_pair_contacts` and
    :attr:`~n_sc_contacts`. Computed on first access and cached.

    :type: :class:`int`
    """
    if self._n_contacts is None:
        self._n_contacts = sum(self.n_pair_contacts) + sum(self.n_sc_contacts)
    return self._n_contacts
168
def GetChain(self, chain_name):
    """ Get chain by name

    :param chain_name: Chain in :attr:`~view`
    :type chain_name: :class:`str`
    :returns: The chain found in :attr:`~view`
    :raises: :class:`RuntimeError` if :attr:`~view` has no valid chain
             of that name
    """
    chain = self.view.FindChain(chain_name)
    if not chain.IsValid():
        raise RuntimeError(f"view has no chain named \"{chain_name}\"")
    return chain
179
def GetSequence(self, chain_name):
    """ Get sequence of chain

    Returns sequence of specified chain as raw :class:`str`, built from
    the one letter codes of the residues. The result is cached per chain.

    :param chain_name: Chain in :attr:`~view`
    :type chain_name: :class:`str`
    """
    if chain_name not in self._sequence:
        ch = self.GetChain(chain_name)
        self._sequence[chain_name] = ''.join(r.one_letter_code
                                             for r in ch.residues)
    return self._sequence[chain_name]
193
def GetPos(self, chain_name):
    """ Get representative positions of chain

    That's CA positions for peptide residues and C3' for
    nucleotides. Returns positions as a numpy array of shape
    (n_residues, 3). The result is cached per chain.

    :param chain_name: Chain in :attr:`~view`
    :type chain_name: :class:`str`
    """
    if chain_name not in self._pos:
        ch = self.GetChain(chain_name)
        pos = np.zeros((ch.GetResidueCount(), 3))
        for i, r in enumerate(ch.residues):
            # first atom of the residue is used - presumably the only
            # one, given that the view holds representative atoms only
            pos[i,:] = r.atoms[0].GetPos().data
        self._pos[chain_name] = pos
    return self._pos[chain_name]
211
def Dist(self, chain_name):
    """ Get pairwise distance of residues within same chain

    Returns distances as square numpy array of shape (a,a)
    where a is the number of residues in specified chain.
    The result is cached per chain.

    :param chain_name: Chain in :attr:`~view`
    :type chain_name: :class:`str`
    """
    if chain_name not in self._sc_dist:
        pos = self.GetPos(chain_name)
        self._sc_dist[chain_name] = distance.cdist(pos, pos, 'euclidean')
    return self._sc_dist[chain_name]
223
def PairDist(self, chain_name_one, chain_name_two):
    """ Get pairwise distances between two chains

    Returns distances as numpy array of shape (a, b).
    Where a is the number of residues of the chain that comes BEFORE the
    other in :attr:`~chain_names`. The order in which the two names are
    passed does not affect the result, which is cached per pair.

    :param chain_name_one: Chain in :attr:`~view`
    :type chain_name_one: :class:`str`
    :param chain_name_two: Chain in :attr:`~view`
    :type chain_name_two: :class:`str`
    """
    # order the names so that both argument orders share one cache entry
    key = (min(chain_name_one, chain_name_two),
           max(chain_name_one, chain_name_two))
    if key not in self._pair_dist:
        self._pair_dist[key] = distance.cdist(self.GetPos(key[0]),
                                              self.GetPos(key[1]),
                                              'euclidean')
    return self._pair_dist[key]
238
def GetMinPos(self, chain_name):
    """ Get min x,y,z coordinates for given chain

    Based on positions that are extracted with GetPos. The result is
    cached per chain.

    :param chain_name: Chain in :attr:`~view`
    :type chain_name: :class:`str`
    """
    if chain_name not in self._min_pos:
        self._min_pos[chain_name] = self.GetPos(chain_name).min(0)
    return self._min_pos[chain_name]
250
def GetMaxPos(self, chain_name):
    """ Get max x,y,z coordinates for given chain

    Based on positions that are extracted with GetPos. The result is
    cached per chain.

    :param chain_name: Chain in :attr:`~view`
    :type chain_name: :class:`str`
    """
    if chain_name not in self._max_pos:
        self._max_pos[chain_name] = self.GetPos(chain_name).max(0)
    return self._max_pos[chain_name]
262
def PotentialInteraction(self, chain_name_one, chain_name_two,
                         slack=0.0):
    """ Returns True if chains potentially interact

    Based on crude collision detection. There is no guarantee
    that they actually interact if True is returned. However,
    if False is returned, they don't interact for sure.

    :param chain_name_one: Chain in :attr:`~view`
    :type chain_name_one: :class:`str`
    :param chain_name_two: Chain in :attr:`~view`
    :type chain_name_two: :class:`str`
    :param slack: Add slack to interaction distance threshold
    :type slack: :class:`float`
    """
    min_one = self.GetMinPos(chain_name_one)
    max_one = self.GetMaxPos(chain_name_one)
    min_two = self.GetMinPos(chain_name_two)
    max_two = self.GetMaxPos(chain_name_two)
    max_gap = self.dist_thresh + slack
    # A gap larger than the threshold along any single axis between the
    # two bounding boxes rules out every residue pair.
    if np.max(min_one - max_two) > max_gap:
        return False
    if np.max(min_two - max_one) > max_gap:
        return False
    return True
287
288
290 """ Helper object to compute Backbone only lDDT score
291
292 Tightly integrated into the mechanisms from the chain_mapping module.
293 The preferred way to derive an object of type :class:`BBlDDTScorer` is
294 through the static constructor: :func:`~FromMappingResult`.
295
296 lDDT computation in :func:`BBlDDTScorer.Score` implements caching.
297 Repeated computations with alternative chain mappings thus become faster.
298
299 :param target: Structure designated as "target". Can be fetched from
300 :class:`ost.mol.alg.chain_mapping.MappingResult`
301 :type target: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
302 :param chem_groups: Groups of chemically equivalent chains in *target*.
303 Can be fetched from
304 :class:`ost.mol.alg.chain_mapping.MappingResult`
305 :type chem_groups: :class:`list` of :class:`list` of :class:`str`
306 :param model: Structure designated as "model". Can be fetched from
307 :class:`ost.mol.alg.chain_mapping.MappingResult`
308 :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
309 :param alns: Each alignment is accessible with ``alns[(t_chain,m_chain)]``.
310 First sequence is the sequence of the respective chain in
311 *target*, second sequence the one from *model*.
312 Can be fetched from
313 :class:`ost.mol.alg.chain_mapping.MappingResult`
314 :type alns: :class:`dict` with key: :class:`tuple` of :class:`str`, value:
315 :class:`ost.seq.AlignmentHandle`
316 :param dist_thresh: Max distance of a pairwise interaction in target
317 to be considered as contact in lDDT
318 :type dist_thresh: :class:`float`
319 :param dist_diff_thresholds: Distance difference thresholds for
320 lDDT computations
321 :type dist_diff_thresholds: :class:`list` of :class:`float`
322 """
def __init__(self, target, chem_groups, model, alns, dist_thresh=15.0,
             dist_diff_thresholds=[0.5, 1.0, 2.0, 4.0]):
    """ Process target and model and set up empty score caches

    :param target: Structure designated as "target"
    :param chem_groups: Groups of chemically equivalent chains in *target*
    :type chem_groups: :class:`list` of :class:`list` of :class:`str`
    :param model: Structure designated as "model"
    :param alns: Alignments accessible with ``alns[(t_chain,m_chain)]``
    :param dist_thresh: Max distance of a pairwise interaction in target
                        to be considered as contact in lDDT
    :type dist_thresh: :class:`float`
    :param dist_diff_thresholds: Distance difference thresholds for
                                 lDDT computations
    :type dist_diff_thresholds: :class:`list` of :class:`float`
    :raises: :class:`RuntimeError` if the chain names of the processed
             target differ from the ones in *chem_groups*
    """
    # The default is a mutable list shared between all calls; work on a
    # private copy so that no instance can alter it for the others.
    dist_diff_thresholds = list(dist_diff_thresholds)

    self._trg = BBlDDTEntity(target, dist_thresh = dist_thresh,
                             dist_diff_thresholds=dist_diff_thresholds)

    # ensure that target chain names match the ones in chem_groups
    chem_group_ch_names = list(itertools.chain.from_iterable(chem_groups))
    if self._trg.chain_names != sorted(chem_group_ch_names):
        raise RuntimeError(f"Expect exact same chain names in chem_groups "
                           f"and in target (which is processed to only "
                           f"contain peptides/nucleotides). target: "
                           f"{self._trg.chain_names}, chem_groups: "
                           f"{chem_group_ch_names}")

    self._chem_groups = chem_groups
    self._mdl = BBlDDTEntity(model, dist_thresh = dist_thresh,
                             dist_diff_thresholds=dist_diff_thresholds)
    self._alns = alns
    self._dist_diff_thresholds = dist_diff_thresholds
    self._dist_thresh = dist_thresh

    # cache for mapped interface scores
    # key: tuple of tuple ((trg_ch1, trg_ch2),
    #                      ((mdl_ch1, mdl_ch2))
    # where the first tuple is lexicographically sorted
    # the second tuple is mapping dependent
    # value: numpy array of len(dist_thresholds) representing the
    #        respective numbers of fulfilled contacts
    self._pairwise_cache = dict()

    # cache for mapped single chain scores
    # key: tuple (trg_ch, mdl_ch)
    # value: numpy array of len(dist_thresholds) representing the
    #        respective numbers of fulfilled contacts
    self._sc_cache = dict()
359
@staticmethod
def FromMappingResult(mapping_result, dist_thresh=15.0,
                      dist_diff_thresholds=[0.5, 1.0, 2.0, 4.0]):
    """ The preferred way to get a :class:`BBlDDTScorer`

    Static constructor that derives an object of type
    :class:`BBlDDTScorer` using a
    :class:`ost.mol.alg.chain_mapping.MappingResult`. Target, chem
    groups, model and alignments are read from *mapping_result*.

    :param mapping_result: Data source
    :type mapping_result: :class:`ost.mol.alg.chain_mapping.MappingResult`
    :param dist_thresh: The lDDT distance threshold
    :type dist_thresh: :class:`float`
    :param dist_diff_thresholds: The lDDT distance difference thresholds
    :type dist_diff_thresholds: :class:`list` of :class:`float`
    """
    return BBlDDTScorer(mapping_result.target, mapping_result.chem_groups,
                        mapping_result.model, alns = mapping_result.alns,
                        dist_thresh = dist_thresh,
                        dist_diff_thresholds = dist_diff_thresholds)
380
@property
def trg(self):
    """ The :class:`BBlDDTEntity` representing target

    :type: :class:`BBlDDTEntity`
    """
    return self._trg
388
@property
def mdl(self):
    """ The :class:`BBlDDTEntity` representing model

    :type: :class:`BBlDDTEntity`
    """
    return self._mdl
396
@property
def alns(self):
    """ Alignments between chains in :attr:`~trg` and :attr:`~mdl`

    Provided at object construction. Each alignment is accessible with
    ``alns[(t_chain,m_chain)]``. First sequence is the sequence of the
    respective chain in :attr:`~trg`, second sequence the one from
    :attr:`~mdl`.

    :type: :class:`dict` with key: :class:`tuple` of :class:`str`, value:
           :class:`ost.seq.AlignmentHandle`
    """
    return self._alns
410
@property
def chem_groups(self):
    """ Groups of chemically equivalent chains in :attr:`~trg`

    Provided at object construction

    :type: :class:`list` of :class:`list` of :class:`str`
    """
    return self._chem_groups
420
def Score(self, mapping, check=True):
    """ Computes Backbone lDDT given chain mapping

    Again, the preferred way is to get *mapping* is from an object
    of type :class:`ost.mol.alg.chain_mapping.MappingResult`.

    :param mapping: see
                    :attr:`ost.mol.alg.chain_mapping.MappingResult.mapping`
    :type mapping: :class:`list` of :class:`list` of :class:`str`
    :param check: Perform input checks, can be disabled for speed purposes
                  if you know what you're doing.
    :type check: :class:`bool`
    :returns: The score
    :raises: :class:`RuntimeError` if *check* is enabled and *mapping*
             does not match :attr:`~chem_groups` in shape or names a
             chain that is not in :attr:`~mdl`
    """
    if check:
        # ensure that dimensionality of mapping matches self.chem_groups
        if len(self.chem_groups) != len(mapping):
            raise RuntimeError("Dimensions of self.chem_groups and mapping "
                               "must match")
        for a,b in zip(self.chem_groups, mapping):
            if len(a) != len(b):
                raise RuntimeError("Dimensions of self.chem_groups and "
                                   "mapping must match")
        # ensure that chain names in mapping are all present in mdl
        mdl_chain_names = set(self.mdl.chain_names)
        for name in itertools.chain.from_iterable(mapping):
            if name is not None and name not in mdl_chain_names:
                raise RuntimeError(f"Each chain in mapping must be present "
                                   f"in self.mdl. No match for "
                                   f"\"{name}\"")

    # target chain name -> model chain name; None entries in mapping
    # denote target chains without a model partner and are left out
    flat_mapping = dict()
    for a, b in zip(self.chem_groups, mapping):
        flat_mapping.update({x: y for x, y in zip(a, b) if y is not None})

    return self.FromFlatMapping(flat_mapping)
456
def FromFlatMapping(self, flat_mapping):
    """ Same as :func:`Score` but with flat mapping

    Sums the conserved contacts per distance difference threshold over
    all mapped single chains and mapped interfaces, divides by the total
    number of contacts in the target and averages over thresholds.

    :param flat_mapping: Dictionary with target chain names as keys and
                         the mapped model chain names as value
    :type flat_mapping: :class:`dict` with :class:`str` as key and value
    :returns: :class:`float` representing lDDT
    """
    n_conserved = np.zeros(len(self._dist_diff_thresholds), dtype=np.int32)

    # process single chains
    for cname in self.trg.chain_names:
        if cname in flat_mapping:
            n_conserved += self._NSCConserved(cname, flat_mapping[cname])

    # process interfaces, only the ones with both chains mapped
    for interface in self.trg.interacting_chains:
        if interface[0] in flat_mapping and interface[1] in flat_mapping:
            mdl_interface = (flat_mapping[interface[0]],
                             flat_mapping[interface[1]])
            n_conserved += self._NPairConserved(interface, mdl_interface)

    return np.mean(n_conserved / self.trg.n_contacts)
480
def _NSCConserved(self, trg_ch, mdl_ch):
    """ Number of conserved intra-chain contacts per threshold

    Compares the distances within target chain *trg_ch* with the
    distances of the aligned residues within model chain *mdl_ch*.
    Results are cached per (trg_ch, mdl_ch).

    :returns: numpy array with one count per entry of the distance
              difference thresholds
    """
    cache_key = (trg_ch, mdl_ch)
    if cache_key in self._sc_cache:
        return self._sc_cache[cache_key]
    trg_dist = self.trg.Dist(trg_ch)
    mdl_dist = self.mdl.Dist(mdl_ch)
    # reduce both matrices to the residues that are aligned
    trg_indices, mdl_indices = self._IndexMapping(trg_ch, mdl_ch)
    trg_dist = trg_dist[np.ix_(trg_indices, trg_indices)]
    mdl_dist = mdl_dist[np.ix_(mdl_indices, mdl_indices)]
    # mask to select relevant distances (dist in trg < dist_thresh)
    # np.triu zeroes the values below the diagonal
    mask = np.triu(trg_dist < self._dist_thresh)
    n_diag = trg_dist.shape[0]
    dist_diffs = np.absolute(trg_dist[mask] - mdl_dist[mask])
    n_conserved = np.zeros(len(self._dist_diff_thresholds), dtype=np.int32)
    for thresh_idx, thresh in enumerate(self._dist_diff_thresholds):
        N = (dist_diffs < thresh).sum()
        # the mask keeps the diagonal, whose 0.0 dist diffs pass every
        # threshold => subtract them again
        n_conserved[thresh_idx] = int((N - n_diag))
    self._sc_cache[cache_key] = n_conserved
    return n_conserved
503
def _NPairConserved(self, trg_int, mdl_int):
    """ Number of conserved inter-chain contacts per threshold

    Compares the distances between the two target chains in *trg_int*
    with the distances of the aligned residues between the two model
    chains in *mdl_int*. Results are cached.

    :param trg_int: Pair of target chain names
    :param mdl_int: Pair of model chain names, element-wise mapped to
                    *trg_int*
    :returns: numpy array with one count per entry of the distance
              difference thresholds
    """
    key_one = (trg_int, mdl_int)
    if key_one in self._pairwise_cache:
        return self._pairwise_cache[key_one]
    # swapping the chains in both pairs describes the same interface
    key_two = ((trg_int[1], trg_int[0]), (mdl_int[1], mdl_int[0]))
    if key_two in self._pairwise_cache:
        return self._pairwise_cache[key_two]
    trg_dist = self.trg.PairDist(trg_int[0], trg_int[1])
    mdl_dist = self.mdl.PairDist(mdl_int[0], mdl_int[1])
    # PairDist puts the chain with the smaller name on the rows;
    # transpose where needed so that rows belong to the first chain of
    # the respective pair
    if trg_int[0] > trg_int[1]:
        trg_dist = trg_dist.transpose()
    if mdl_int[0] > mdl_int[1]:
        mdl_dist = mdl_dist.transpose()
    trg_indices_1, mdl_indices_1 = self._IndexMapping(trg_int[0], mdl_int[0])
    trg_indices_2, mdl_indices_2 = self._IndexMapping(trg_int[1], mdl_int[1])
    trg_dist = trg_dist[np.ix_(trg_indices_1, trg_indices_2)]
    mdl_dist = mdl_dist[np.ix_(mdl_indices_1, mdl_indices_2)]
    # reduce to relevant distances (dist in trg < dist_thresh)
    mask = trg_dist < self._dist_thresh
    dist_diffs = np.absolute(trg_dist[mask] - mdl_dist[mask])
    n_conserved = np.zeros(len(self._dist_diff_thresholds), dtype=np.int32)
    for thresh_idx, thresh in enumerate(self._dist_diff_thresholds):
        n_conserved[thresh_idx] = (dist_diffs < thresh).sum()
    self._pairwise_cache[key_one] = n_conserved
    return n_conserved
531
def _IndexMapping(self, ch1, ch2):
    """ Fetches aln and returns indices of aligned residues

    returns 2 numpy arrays containing the indices of residues in
    ch1 and ch2 which are aligned

    The alignment is looked up as ``self.alns[(ch1, ch2)]``. A column
    counts as aligned if neither of its two entries is a gap ('-').
    """
    mapped_indices_1 = list()
    mapped_indices_2 = list()
    idx_1 = 0
    idx_2 = 0
    for col in self.alns[(ch1, ch2)]:
        if col[0] != '-' and col[1] != '-':
            mapped_indices_1.append(idx_1)
            mapped_indices_2.append(idx_2)
        # residue indices only advance on non-gap entries
        if col[0] != '-':
            idx_1 +=1
        if col[1] != '-':
            idx_2 +=1
    # explicit integer dtype: an empty list would otherwise become a
    # float array, which is no valid index array
    return (np.array(mapped_indices_1, dtype=np.intp),
            np.array(mapped_indices_2, dtype=np.intp))
__init__(self, ent, dist_thresh=15.0, dist_diff_thresholds=[0.5, 1.0, 2.0, 4.0])
Definition bb_lddt.py:21
GetSequence(self, chain_name)
Definition bb_lddt.py:180
PairDist(self, chain_name_one, chain_name_two)
Definition bb_lddt.py:224
PotentialInteraction(self, chain_name_one, chain_name_two, slack=0.0)
Definition bb_lddt.py:264
FromFlatMapping(self, flat_mapping)
Definition bb_lddt.py:457
FromMappingResult(mapping_result, dist_thresh=15.0, dist_diff_thresholds=[0.5, 1.0, 2.0, 4.0])
Definition bb_lddt.py:362
_NPairConserved(self, trg_int, mdl_int)
Definition bb_lddt.py:504
__init__(self, target, chem_groups, model, alns, dist_thresh=15.0, dist_diff_thresholds=[0.5, 1.0, 2.0, 4.0])
Definition bb_lddt.py:324
_NSCConserved(self, trg_ch, mdl_ch)
Definition bb_lddt.py:481
Score(self, mapping, check=True)
Definition bb_lddt.py:421