comparison change_o/DefineClones.py @ 0:8a5a2abbb870 draft default tip

Uploaded
author davidvanzessen
date Mon, 29 Aug 2016 05:36:10 -0400
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:8a5a2abbb870
1 #!/usr/bin/env python3
2 """
3 Assign Ig sequences into clones
4 """
5 # Info
6 __author__ = 'Namita Gupta, Jason Anthony Vander Heiden, Gur Yaari, Mohamed Uduman'
7 from changeo import __version__, __date__
8
9 # Imports
10 import os
11 import re
12 import sys
13 import numpy as np
14 from argparse import ArgumentParser
15 from collections import OrderedDict
16 from itertools import chain
17 from textwrap import dedent
18 from time import time
19 from Bio import pairwise2
20 from Bio.Seq import translate
21
22 # Presto and changeo imports
23 from presto.Defaults import default_out_args
24 from presto.IO import getFileType, getOutputHandle, printLog, printProgress
25 from presto.Multiprocessing import manageProcesses
26 from presto.Sequence import getDNAScoreDict
27 from changeo.Commandline import CommonHelpFormatter, getCommonArgParser, parseCommonArgs
28 from changeo.Distance import getDNADistMatrix, getAADistMatrix, \
29 hs1f_model, m1n_model, hs5f_model, \
30 calcDistances, formClusters
31 from changeo.IO import getDbWriter, readDbFile, countDbFile
32 from changeo.Multiprocessing import DbData, DbResult
33
# Default parameter values
default_translate = False
default_distance = 0.0
default_bygroup_model = 'hs1f'
default_hclust_model = 'chen2010'
default_seq_field = 'JUNCTION'
default_norm = 'len'
default_sym = 'avg'
default_linkage = 'single'

# TODO: should be in Distance, but need to be after function definitions
# Hamming distance matrices: amino acid (mask distance 1, gap distance 0)
# and DNA (mask and gap distance 0)
aa_model = getAADistMatrix(gap_dist=0, mask_dist=1)
ham_model = getDNADistMatrix(gap_dist=0, mask_dist=0)
50
51
# TODO: this function is an abstraction to facilitate later cleanup
def getModelMatrix(model):
    """
    Simple wrapper to get distance matrix from model name

    Arguments:
      model = model name; one of 'aa', 'ham', 'm1n', 'hs1f' or 'hs5f'

    Returns:
      a pandas.DataFrame containing the character distance matrix

    Raises:
      ValueError if the model name is not recognized
    """
    if model == 'aa':
        return aa_model
    elif model == 'ham':
        return ham_model
    elif model == 'm1n':
        return m1n_model
    elif model == 'hs1f':
        return hs1f_model
    elif model == 'hs5f':
        return hs5f_model

    # Fail immediately rather than returning None and crashing later in the distance code
    raise ValueError('Unrecognized distance model: %s.' % model)
75
76
def indexJunctions(db_iter, fields=None, mode='gene', action='first'):
    """
    Identifies preclonal groups by V, J and junction length

    Arguments:
      db_iter = an iterator of IgRecords defined by readDbFile
      fields = additional annotation fields to use to group preclones;
               if None use only V, J and junction length
      mode = specificity of alignment call to use for assigning preclones;
             one of ('allele', 'gene')
      action = how to handle multiple value fields when assigning preclones;
               one of ('first', 'set')

    Returns:
      a dictionary of {(V, J, junction length[, fields...]):[IgRecords]};
      records with any missing or empty key component are stored under None

    Raises:
      ValueError if mode or action is not one of the supported values
    """
    if mode not in ('allele', 'gene'):
        raise ValueError('Unrecognized mode: %s.' % mode)
    if action not in ('first', 'set'):
        raise ValueError('Unrecognized action: %s.' % action)

    def _get_key(rec, act):
        # Grouping key: V call, J call, junction length, then any extra annotation fields
        if mode == 'allele':
            v_call, j_call = rec.getVAllele(act), rec.getJAllele(act)
        else:
            v_call, j_call = rec.getVGene(act), rec.getJGene(act)
        junc_len = None if rec.junction is None else len(rec.junction)
        key = (v_call, j_call, junc_len)
        if fields is not None:
            # Convert the record once rather than once per field
            rec_dict = rec.toDict()
            key += tuple(rec_dict.get(k, None) for k in fields)
        return key

    # For action == 'set': key positions that must match exactly (junction length
    # and extra fields) and positions that only need to overlap (V and J calls)
    f_range = list(range(2, 3 + (len(fields) if fields else 0)))
    vdj_range = list(range(2))

    start_time = time()
    clone_index = {}
    rec_count = 0
    for rec in db_iter:
        key = _get_key(rec, action)

        # Print progress
        if rec_count == 0:
            print('PROGRESS> Grouping sequences')
        printProgress(rec_count, step=1000, start_time=start_time)
        rec_count += 1

        # Assign passed preclone records to key and failed records to index None
        if not all(k is not None and k != '' for k in key):
            clone_index.setdefault(None, []).append(rec)
        elif action == 'first':
            clone_index.setdefault(key, []).append(rec)
        else:
            # Find existing keys with matching columns and junction length and
            # overlapping genes/alleles; the key grows as each match is merged in
            to_remove = []
            if len(clone_index) > (1 if None in clone_index else 0) and key not in clone_index:
                key = list(key)
                for k in clone_index:
                    if k is not None and all(key[i] == k[i] for i in f_range):
                        if all(not set(key[i]).isdisjoint(k[i]) for i in vdj_range):
                            for i in vdj_range:
                                key[i] = tuple(set(key[i]).union(k[i]))
                            to_remove.append(k)
            key = tuple(key)

            # Remove the merged keys and store their records under the union key
            val = [rec]
            val += list(chain(*(clone_index.pop(k) for k in to_remove)))
            clone_index[key] = clone_index.get(key, []) + val

    printProgress(rec_count, step=1000, start_time=start_time, end=True)

    return clone_index
160
161
def distanceClones(records, model=default_bygroup_model, distance=default_distance,
                   dist_mat=None, norm=default_norm, sym=default_sym,
                   linkage=default_linkage, seq_field=default_seq_field):
    """
    Separates a set of IgRecords into clones

    Arguments:
      records = an iterator of IgRecords
      model = substitution model used to calculate distance
      distance = the distance threshold to assign clonal groups
      dist_mat = pandas DataFrame of pairwise nucleotide or amino acid distances
      norm = normalization method
      sym = symmetry method
      linkage = type of linkage
      seq_field = sequence field used to calculate distance between records

    Returns:
      a dictionary of lists defining {clone number: [IgRecords clonal group]};
      None if any record has an empty or missing sequence field, or if there
      are no records

    Raises:
      ValueError if the model name is not recognized
    """
    # Determine length of n-mers
    if model in ('hs1f', 'm1n', 'aa', 'ham'):
        nmer_len = 1
    elif model == 'hs5f':
        nmer_len = 5
    else:
        raise ValueError('Unrecognized distance model: %s.' % model)

    # Get distance matrix if not provided
    if dist_mat is None:
        dist_mat = getModelMatrix(model)

    # Gap characters ('.' and '-') are masked as N before comparison
    gap_regex = re.compile(r'[.-]')

    # Map each unique sequence to the records carrying it
    seq_map = {}
    for ig in records:
        seq = ig.getSeqField(seq_field)
        # An empty or missing sequence fails the whole group
        if not seq:
            return None

        seq = gap_regex.sub('N', str(seq))
        if model == 'aa':
            seq = translate(seq)

        seq_map.setdefault(seq, []).append(ig)

    if not seq_map:
        return None

    # A single unique sequence is a single clone; use the collected records so
    # this also works when records was a one-shot iterator
    if len(seq_map) == 1:
        return {1: next(iter(seq_map.values()))}

    # Define sequences
    seqs = list(seq_map.keys())

    # Calculate pairwise distance matrix
    dists = calcDistances(seqs, nmer_len, dist_mat, norm, sym)

    # Perform hierarchical clustering
    clusters = formClusters(dists, linkage, distance)

    # Turn clusters into clone dictionary
    clone_dict = {}
    for seq, c in zip(seqs, clusters):
        clone_dict.setdefault(c, []).extend(seq_map[seq])

    return clone_dict
224
225
def _levenshteinDist(seq_a, seq_b):
    """
    Calculate the Levenshtein (edit) distance between two sequences

    Arguments:
      seq_a = first sequence
      seq_b = second sequence

    Returns:
      the minimum number of single-character insertions, deletions and
      substitutions needed to turn seq_a into seq_b
    """
    seq_a, seq_b = str(seq_a), str(seq_b)
    # Keep the dynamic-programming row as short as possible
    if len(seq_a) < len(seq_b):
        seq_a, seq_b = seq_b, seq_a
    prev = list(range(len(seq_b) + 1))
    for i, char_a in enumerate(seq_a, 1):
        curr = [i]
        for j, char_b in enumerate(seq_b, 1):
            curr.append(min(prev[j] + 1,                         # deletion
                            curr[j - 1] + 1,                     # insertion
                            prev[j - 1] + (char_a != char_b)))   # substitution
        prev = curr
    return prev[-1]


def distChen2010(records):
    """
    Calculate pairwise distances as defined in Chen 2010

    The distance is the CDR3 edit distance plus V and J mismatch penalties
    (V: allele 0, gene 1, family 3, otherwise 5; J: allele 0, gene 1,
    otherwise 3), divided by the longer CDR3 length.

    Arguments:
      records = list of IgRecords where first is query to be compared to others in list;
                the list is not modified

    Returns:
      list of distances, one per record after the query
    """
    records = list(records)
    if not records:
        return []

    # Pull out query sequence and V/J information
    query, targets = records[0], records[1:]
    query_cdr3 = str(query.junction[3:-3])
    query_v_allele = query.getVAllele()
    query_v_gene = query.getVGene()
    query_v_family = query.getVFamily()
    query_j_allele = query.getJAllele()
    query_j_gene = query.getJGene()

    scores = []
    for rec in targets:
        rec_cdr3 = str(rec.junction[3:-3])
        ld = _levenshteinDist(query_cdr3, rec_cdr3)
        # Check V similarity
        if rec.getVAllele() == query_v_allele: ld += 0
        elif rec.getVGene() == query_v_gene: ld += 1
        elif rec.getVFamily() == query_v_family: ld += 3
        else: ld += 5
        # Check J similarity
        if rec.getJAllele() == query_j_allele: ld += 0
        elif rec.getJGene() == query_j_gene: ld += 1
        else: ld += 3
        # Divide by length; guard against two empty CDR3s
        scores.append(ld / max(len(rec_cdr3), len(query_cdr3), 1))

    return scores


def distAdemokun2011(records):
    """
    Calculate pairwise distances as defined in Ademokun 2011

    Pairs whose CDR3 lengths differ by more than 10 or whose V families differ
    get distance 1; otherwise the CDR3 edit distance is divided by the shorter
    CDR3 length.

    Arguments:
      records = list of IgRecords where first is query to be compared to others in list;
                the list is not modified

    Returns:
      list of distances, one per record after the query
    """
    records = list(records)
    if not records:
        return []

    # Pull out query sequence and V family information
    query, targets = records[0], records[1:]
    query_cdr3 = str(query.junction[3:-3])
    query_v_family = query.getVFamily()

    scores = []
    for rec in targets:
        rec_cdr3 = str(rec.junction[3:-3])
        if abs(len(query_cdr3) - len(rec_cdr3)) > 10:
            scores.append(1)
        elif query_v_family != rec.getVFamily():
            scores.append(1)
        else:
            ld = _levenshteinDist(query_cdr3, rec_cdr3)
            # Guard against an empty CDR3
            scores.append(ld / max(min(len(rec_cdr3), len(query_cdr3)), 1))

    return scores
296
297
def hierClust(dist_mat, method='chen2010'):
    """
    Cut a hierarchical clustering of CDR3 distances into clusters

    Arguments:
      dist_mat = square-formed distance matrix of pairwise CDR3 comparisons
      method = published method whose linkage and cut height to use;
               'chen2010' or 'ademokun2011'; anything else puts every
               sequence in a single cluster

    Returns:
      list of cluster ids
    """
    # (linkage type, tree cut height) for each supported method
    method_params = {'chen2010': ('average', 0.32),
                     'ademokun2011': ('complete', 0.25)}

    if method not in method_params:
        return np.ones(dist_mat.shape[0])

    link_type, cut_height = method_params[method]
    return formClusters(dist_mat, link_type, cut_height)
315
# TODO: Merge duplicate feed, process and collect functions.
def feedQueue(alive, data_queue, db_file, group_func, group_args=None):
    """
    Feeds the data queue with Ig records

    Arguments:
      alive = a multiprocessing.Value boolean controlling whether processing continues
              if False exit process
      data_queue = a multiprocessing.Queue to hold data for processing
      db_file = the Ig record database file
      group_func = the function to use for assigning preclones
      group_args = a dictionary of arguments to pass to group_func;
                   None means no extra arguments

    Returns:
      None
    """
    # Avoid a shared mutable default
    if group_args is None:
        group_args = {}

    # Open input file and perform grouping
    try:
        db_iter = readDbFile(db_file)
        clone_dict = group_func(db_iter, **group_args)
    except:
        # Signal sibling processes to stop, then propagate
        alive.value = False
        raise

    # Add groups to data queue
    try:
        clone_iter = iter(clone_dict.items())
        while alive.value:
            # Wait (by polling) until the queue has room
            if data_queue.full(): continue
            else: data = next(clone_iter, None)
            # Exit upon reaching end of iterator
            if data is None: break

            # Feed queue
            data_queue.put(DbData(*data))
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        alive.value = False
        raise

    return None
367
368
def feedQueueClust(alive, data_queue, db_file, group_func=None, group_args=None):
    """
    Feeds the data queue with Ig records for pairwise distance calculation

    Records are indexed by ID and sorted by junction length. Each queue item
    pairs a record ID with a list holding that record (the query) followed by
    every record that sorts after it, so each unordered pair of records is
    covered exactly once across all items.

    Arguments:
      alive = a multiprocessing.Value boolean controlling whether processing continues
              if False exit process
      data_queue = a multiprocessing.Queue to hold data for processing
      db_file = the Ig record database file
      group_func = unused; accepted so all feeder functions share one signature
      group_args = unused; accepted so all feeder functions share one signature

    Returns:
      None
    """
    # Read records and order by junction length
    try:
        records = {}
        for rec in readDbFile(db_file):
            records[rec.id] = rec
        ordered = sorted(records.items(), key=lambda i: i[1].junction_length)
        rec_ids = [k for k, __ in ordered]
        rec_list = [v for __, v in ordered]
    except:
        alive.value = False
        raise

    # Add rows to data queue
    try:
        # Build each row (query plus all later records) on demand instead of
        # materializing every row up front
        row_iter = ((rec_ids[i], rec_list[i:]) for i in range(len(rec_list)))
        while alive.value:
            # Wait (by polling) until the queue has room
            if data_queue.full(): continue
            else: data = next(row_iter, None)
            # Exit upon reaching end of iterator
            if data is None: break

            # Feed queue
            data_queue.put(DbData(*data))
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        alive.value = False
        raise

    return None
424
425
def processQueue(alive, data_queue, result_queue, clone_func, clone_args):
    """
    Pulls from data queue, performs calculations, and feeds results queue

    Arguments:
      alive = a multiprocessing.Value boolean controlling whether processing continues
              if False exit process
      data_queue = a multiprocessing.Queue holding data to process
      result_queue = a multiprocessing.Queue to hold processed results
      clone_func = the function to call for clonal assignment
      clone_args = a dictionary of arguments to pass to clone_func

    Returns:
      None
    """
    try:
        # Iterate over data queue until sentinel object reached
        while alive.value:
            # Get data from queue
            if data_queue.empty(): continue
            else: data = data_queue.get()
            # Exit upon reaching sentinel
            if data is None: break

            # Define result object for iteration and get data records
            records = data.data
            result = DbResult(data.id, records)

            # Check for invalid data (due to failed indexing) and add failed result
            if not data:
                result_queue.put(result)
                continue

            # Add V(D)J to log; a missing junction is reported as length 0
            result.log['ID'] = ','.join([str(x) for x in data.id])
            result.log['VALLELE'] = ','.join(set([(r.getVAllele() or '') for r in records]))
            result.log['DALLELE'] = ','.join(set([(r.getDAllele() or '') for r in records]))
            result.log['JALLELE'] = ','.join(set([(r.getJAllele() or '') for r in records]))
            result.log['JUNCLEN'] = ','.join(set(
                ['0' if r.junction is None else str(len(r.junction)) for r in records]))
            result.log['SEQUENCES'] = len(records)

            # Assign clones; data is known to be valid at this point
            clones = clone_func(records, **clone_args)

            if clones is not None:
                result.results = clones
                result.valid = True
                result.log['CLONES'] = len(clones)
            else:
                result.log['CLONES'] = 0

            # Feed results to result queue
            result_queue.put(result)
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        alive.value = False
        raise

    return None
494
495
def processQueueClust(alive, data_queue, result_queue, clone_func, clone_args):
    """
    Worker loop: compute one row of the distance matrix per queued item

    Each item taken from data_queue is passed to clone_func; a non-None return
    value becomes the result's payload and marks it valid. Every item produces
    exactly one result on result_queue.

    Arguments:
      alive = a multiprocessing.Value boolean controlling whether processing continues
              if False exit process
      data_queue = a multiprocessing.Queue holding data to process
      result_queue = a multiprocessing.Queue to hold processed results
      clone_func = the function to call for calculating pairwise distances between sequences
      clone_args = a dictionary of arguments to pass to clone_func

    Returns:
      None
    """
    try:
        while alive.value:
            # Poll until work arrives
            if data_queue.empty():
                continue
            item = data_queue.get()
            # A None item is the end-of-work sentinel
            if item is None:
                break

            row_records = item.data
            row_result = DbResult(item.id, row_records)

            distances = clone_func(row_records, **clone_args) if item else None
            if distances is not None:
                row_result.results = distances
                row_result.valid = True

            result_queue.put(row_result)
        else:
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        alive.value = False
        raise

    return None
545
546
def collectQueue(alive, result_queue, collect_queue, db_file, out_args, cluster_func=None, cluster_args=None):
    """
    Assembles results from a queue of individual sequence results and manages log/file I/O

    Arguments:
      alive = a multiprocessing.Value boolean controlling whether processing continues
              if False exit process
      result_queue = a multiprocessing.Queue holding processQueue results
      collect_queue = a multiprocessing.Queue to store collector return values
      db_file = the input database file name
      out_args = common output argument dictionary from parseCommonArgs
      cluster_func = unused by this collector; accepted so all collectors share one signature
      cluster_args = unused by this collector; accepted so all collectors share one signature

    Returns:
      None
      (adds 'log' and 'out_files' to collect_dict)
    """
    pass_handle = fail_handle = log_handle = None

    # Open output files
    try:
        # Count records and define output format
        out_type = getFileType(db_file) if out_args['out_type'] is None \
                   else out_args['out_type']
        result_count = countDbFile(db_file)

        # Defined successful output handle
        pass_handle = getOutputHandle(db_file,
                                      out_label='clone-pass',
                                      out_dir=out_args['out_dir'],
                                      out_name=out_args['out_name'],
                                      out_type=out_type)
        pass_writer = getDbWriter(pass_handle, db_file, add_fields='CLONE')

        # Defined failed alignment output handle
        if out_args['failed']:
            fail_handle = getOutputHandle(db_file,
                                          out_label='clone-fail',
                                          out_dir=out_args['out_dir'],
                                          out_name=out_args['out_name'],
                                          out_type=out_type)
            fail_writer = getDbWriter(fail_handle, db_file)
        else:
            fail_writer = None

        # Define log handle
        if out_args['log_file'] is not None:
            log_handle = open(out_args['log_file'], 'w')
    except:
        alive.value = False
        # Do not leak handles that were opened before the failure
        for handle in (pass_handle, fail_handle, log_handle):
            if handle is not None:
                handle.close()
        raise

    # Get results from queue and write to files
    try:
        start_time = time()
        rec_count = clone_count = pass_count = fail_count = 0
        while alive.value:
            # Get result from queue
            if result_queue.empty(): continue
            else: result = result_queue.get()
            # Exit upon reaching sentinel
            if result is None: break

            # Print progress for previous iteration and update record count
            if rec_count == 0:
                print('PROGRESS> Assigning clones')
            printProgress(rec_count, result_count, 0.05, start_time)
            rec_count += len(result.data)

            # Write passed and failed records
            if result:
                for clone in result.results.values():
                    clone_count += 1
                    for i, rec in enumerate(clone):
                        rec.annotations['CLONE'] = clone_count
                        pass_writer.writerow(rec.toDict())
                        pass_count += 1
                        result.log['CLONE%i-%i' % (clone_count, i + 1)] = str(rec.junction)
            else:
                for i, rec in enumerate(result.data):
                    if fail_writer is not None: fail_writer.writerow(rec.toDict())
                    fail_count += 1
                    result.log['CLONE0-%i' % (i + 1)] = str(rec.junction)

            # Write log
            printLog(result.log, handle=log_handle)
        else:
            # Handles are closed by the finally clause below
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None

        # Print total counts
        printProgress(rec_count, result_count, 0.05, start_time)

        # Close file handles before reporting results
        for handle in (pass_handle, fail_handle, log_handle):
            if handle is not None:
                handle.close()

        # Update return list
        log = OrderedDict()
        log['OUTPUT'] = os.path.basename(pass_handle.name)
        log['CLONES'] = clone_count
        log['RECORDS'] = rec_count
        log['PASS'] = pass_count
        log['FAIL'] = fail_count
        collect_dict = {'log':log, 'out_files': [pass_handle.name]}
        collect_queue.put(collect_dict)
    except:
        alive.value = False
        raise
    finally:
        # Closing an already-closed file is a no-op
        for handle in (pass_handle, fail_handle, log_handle):
            if handle is not None:
                handle.close()

    return None
668
669
def collectQueueClust(alive, result_queue, collect_queue, db_file, out_args, cluster_func, cluster_args):
    """
    Assembles distance-matrix rows from the result queue, clusters them and writes output

    Arguments:
      alive = a multiprocessing.Value boolean controlling whether processing continues
              if False exit process
      result_queue = a multiprocessing.Queue holding processQueueClust results
      collect_queue = a multiprocessing.Queue to store collector return values
      db_file = the input database file name
      out_args = common output argument dictionary from parseCommonArgs
      cluster_func = the function to call for carrying out clustering on distance matrix
      cluster_args = a dictionary of arguments to pass to cluster_func

    Returns:
      None
      (adds 'log' and 'out_files' to collect_dict)
    """
    pass_handle = fail_handle = log_handle = None

    # Read input records and open output files
    try:
        # Index records by ID and order by junction length, as the feeder does.
        # Duplicate IDs collapse to one record, so the matrix is sized by unique IDs.
        records = {}
        for rec in readDbFile(db_file):
            records[rec.id] = rec
        rec_list = [v for __, v in sorted(records.items(), key=lambda i: i[1].junction_length)]
        result_count = len(rec_list)

        # Define empty matrix to store assembled results
        dist_mat = np.zeros((result_count, result_count))

        # Define output format
        out_type = getFileType(db_file) if out_args['out_type'] is None \
                   else out_args['out_type']

        # Defined successful output handle
        pass_handle = getOutputHandle(db_file,
                                      out_label='clone-pass',
                                      out_dir=out_args['out_dir'],
                                      out_name=out_args['out_name'],
                                      out_type=out_type)
        pass_writer = getDbWriter(pass_handle, db_file, add_fields='CLONE')

        # Defined failed cloning output handle; no records are routed to it in
        # this mode, but the file and writer are still created as in collectQueue
        if out_args['failed']:
            fail_handle = getOutputHandle(db_file,
                                          out_label='clone-fail',
                                          out_dir=out_args['out_dir'],
                                          out_name=out_args['out_name'],
                                          out_type=out_type)
            getDbWriter(fail_handle, db_file)

        # Open log file
        if out_args['log_file'] is not None:
            log_handle = open(out_args['log_file'], 'w')
    except:
        alive.value = False
        # Do not leak handles that were opened before the failure
        for handle in (pass_handle, fail_handle, log_handle):
            if handle is not None:
                handle.close()
        raise

    try:
        # Iterate over results queue until sentinel object reached
        start_time = time()
        row_count = 0
        while alive.value:
            # Get result from queue
            if result_queue.empty(): continue
            else: result = result_queue.get()
            # Exit upon reaching sentinel
            if result is None: break

            # Print progress for previous iteration
            if row_count == 0:
                print('PROGRESS> Assigning clones')
            printProgress(row_count, result_count, 0.05, start_time)
            row_count += 1

            # Each row was sent as the query plus every record after it, so the
            # query's position is recovered from the row length (assumes
            # len(DbResult) == len(result.data)). Its distances to the later
            # records fill the column below and the row right of the diagonal,
            # keeping the matrix symmetric.
            if result:
                pos = result_count - len(result)
                dist_mat[pos + 1:, pos] = result.results
                dist_mat[pos, pos + 1:] = result.results
        else:
            # Handles are closed by the finally clause below
            sys.stderr.write('PID %s: Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None

        # Calculate linkage and carry out clustering
        clusters = cluster_func(dist_mat, **cluster_args) if result_count > 0 else []
        clones = {}
        for rec, c in zip(rec_list, clusters):
            clones.setdefault(c, []).append(rec)

        # Write clonal assignments
        clone_count = pass_count = fail_count = 0
        for clone in clones.values():
            clone_count += 1
            for rec in clone:
                rec.annotations['CLONE'] = clone_count
                pass_writer.writerow(rec.toDict())
                pass_count += 1

        # Print final progress
        printProgress(row_count, result_count, 0.05, start_time)

        # Close file handles before reporting results
        for handle in (pass_handle, fail_handle, log_handle):
            if handle is not None:
                handle.close()

        # Update return list
        log = OrderedDict()
        log['OUTPUT'] = os.path.basename(pass_handle.name)
        log['CLONES'] = clone_count
        log['RECORDS'] = row_count
        log['PASS'] = pass_count
        log['FAIL'] = fail_count
        collect_dict = {'log':log, 'out_files': [pass_handle.name]}
        collect_queue.put(collect_dict)
    except:
        alive.value = False
        raise
    finally:
        # Closing an already-closed file is a no-op
        for handle in (pass_handle, fail_handle, log_handle):
            if handle is not None:
                handle.close()

    return None
813
814
def defineClones(db_file, feed_func, work_func, collect_func, clone_func, cluster_func=None,
                 group_func=None, group_args=None, clone_args=None, cluster_args=None,
                 out_args=default_out_args, nproc=None, queue_size=None):
    """
    Define clonally related sequences

    Arguments:
      db_file = filename of input database
      feed_func = the function that feeds the queue
      work_func = the worker function that will run on each CPU
      collect_func = the function that collects results from the workers
      clone_func = the function to use for determining clones within preclonal groups
      cluster_func = the function the collector uses to cluster a distance matrix;
                     None if the collector does not cluster
      group_func = the function to use for assigning preclones
      group_args = a dictionary of arguments to pass to group_func; None means {}
      clone_args = a dictionary of arguments to pass to clone_func; None means {}
      cluster_args = a dictionary of arguments to pass to cluster_func; None means {}
      out_args = common output argument dictionary from parseCommonArgs
      nproc = the number of processQueue processes;
              if None defaults to the number of CPUs
      queue_size = maximum size of the argument queue;
                   if None defaults to 2*nproc

    Returns:
      a list of successful output file names
    """
    # Use fresh dictionaries rather than shared mutable defaults
    group_args = {} if group_args is None else group_args
    clone_args = {} if clone_args is None else clone_args
    cluster_args = {} if cluster_args is None else cluster_args

    # Print parameter info
    log = OrderedDict()
    log['START'] = 'DefineClones'
    log['DB_FILE'] = os.path.basename(db_file)
    if group_func is not None:
        log['GROUP_FUNC'] = group_func.__name__
        log['GROUP_ARGS'] = group_args
    log['CLONE_FUNC'] = clone_func.__name__

    # TODO: this is yucky, but can be fixed by using a model class
    # The distance matrix is left out of the logged clone arguments
    log['CLONE_ARGS'] = {k: v for k, v in clone_args.items() if k != 'dist_mat'}

    if cluster_func is not None:
        log['CLUSTER_FUNC'] = cluster_func.__name__
        log['CLUSTER_ARGS'] = cluster_args
    log['NPROC'] = nproc
    printLog(log)

    # Define feeder function and arguments
    feed_args = {'db_file': db_file,
                 'group_func': group_func,
                 'group_args': group_args}
    # Define worker function and arguments
    work_args = {'clone_func': clone_func,
                 'clone_args': clone_args}
    # Define collector function and arguments
    collect_args = {'db_file': db_file,
                    'out_args': out_args,
                    'cluster_func': cluster_func,
                    'cluster_args': cluster_args}

    # Call process manager
    result = manageProcesses(feed_func, work_func, collect_func,
                             feed_args, work_args, collect_args,
                             nproc, queue_size)

    # Print log
    result['log']['END'] = 'DefineClones'
    printLog(result['log'])

    return result['out_files']
882
883
def getArgParser():
    """
    Defines the ArgumentParser

    Arguments:
      None

    Returns:
      an ArgumentParser object
    """
    # Define input and output fields
    fields = dedent(
             '''
             output files:
                 clone-pass
                     database with assigned clonal group numbers.
                 clone-fail
                     database with records failing clonal grouping.

             required fields:
                 SEQUENCE_ID, V_CALL or V_CALL_GENOTYPED, D_CALL, J_CALL, JUNCTION_LENGTH

                 <field>
                     sequence field specified by the --sf parameter

             output fields:
                 CLONE
             ''')

    # Define ArgumentParser
    parser = ArgumentParser(description=__doc__, epilog=fields,
                            formatter_class=CommonHelpFormatter)
    parser.add_argument('--version', action='version',
                        version='%(prog)s:' + ' %s-%s' %(__version__, __date__))
    subparsers = parser.add_subparsers(title='subcommands', dest='command', metavar='',
                                       help='Cloning method')
    # TODO: This is a temporary fix for Python issue 9253
    subparsers.required = True

    # Parent parser
    parser_parent = getCommonArgParser(seq_in=False, seq_out=False, db_in=True,
                                       multiproc=True)

    # Distance cloning method
    parser_bygroup = subparsers.add_parser('bygroup', parents=[parser_parent],
                                           formatter_class=CommonHelpFormatter,
                                           help='''Defines clones as having same V assignment,
                                                J assignment, and junction length with
                                                specified substitution distance model.''')
    parser_bygroup.add_argument('-f', nargs='+', action='store', dest='fields', default=None,
                                help='Additional fields to use for grouping clones (non VDJ)')
    parser_bygroup.add_argument('--mode', action='store', dest='mode',
                                choices=('allele', 'gene'), default='gene',
                                help='''Specifies whether to use the V(D)J allele or gene for
                                     initial grouping.''')
    parser_bygroup.add_argument('--act', action='store', dest='action', default='set',
                                choices=('first', 'set'),
                                help='''Specifies how to handle multiple V(D)J assignments
                                     for initial grouping.''')
    parser_bygroup.add_argument('--model', action='store', dest='model',
                                choices=('aa', 'ham', 'm1n', 'hs1f', 'hs5f'),
                                default=default_bygroup_model,
                                help='''Specifies which substitution model to use for
                                     calculating distance between sequences. Where m1n is the
                                     mouse single nucleotide transition/transversion model
                                     of Smith et al, 1996; hs1f is the human single
                                     nucleotide model derived from Yaari et al, 2013; hs5f
                                     is the human S5F model of Yaari et al, 2013; ham is
                                     nucleotide Hamming distance; and aa is amino acid
                                     Hamming distance. The hs5f data should be
                                     considered experimental.''')
    parser_bygroup.add_argument('--dist', action='store', dest='distance', type=float,
                                default=default_distance,
                                help='The distance threshold for clonal grouping')
    parser_bygroup.add_argument('--norm', action='store', dest='norm',
                                choices=('len', 'mut', 'none'), default=default_norm,
                                help='''Specifies how to normalize distances. One of none
                                     (do not normalize), len (normalize by length),
                                     or mut (normalize by number of mutations between sequences).''')
    parser_bygroup.add_argument('--sym', action='store', dest='sym',
                                choices=('avg', 'min'), default=default_sym,
                                help='''Specifies how to combine asymmetric distances. One of avg
                                     (average of A->B and B->A) or min (minimum of A->B and B->A).''')
    parser_bygroup.add_argument('--link', action='store', dest='linkage',
                                choices=('single', 'average', 'complete'), default=default_linkage,
                                help='''Type of linkage to use for hierarchical clustering.''')
    parser_bygroup.add_argument('--sf', action='store', dest='seq_field',
                                default=default_seq_field,
                                help='''The name of the field to be used to calculate
                                     distance between records''')
    parser_bygroup.set_defaults(feed_func=feedQueue)
    parser_bygroup.set_defaults(work_func=processQueue)
    parser_bygroup.set_defaults(collect_func=collectQueue)
    parser_bygroup.set_defaults(group_func=indexJunctions)
    parser_bygroup.set_defaults(clone_func=distanceClones)

    # Hierarchical clustering cloning method
    parser_hclust = subparsers.add_parser('hclust', parents=[parser_parent],
                                          formatter_class=CommonHelpFormatter,
                                          help='Defines clones by specified distance metric on CDR3s and \
                                                cutting of hierarchical clustering tree')
    parser_hclust.add_argument('--method', action='store', dest='method',
                               choices=('chen2010', 'ademokun2011'), default=default_hclust_model,
                               help='Specifies which cloning method to use for calculating distance \
                                     between CDR3s, computing linkage, and cutting clusters')
    parser_hclust.set_defaults(feed_func=feedQueueClust)
    parser_hclust.set_defaults(work_func=processQueueClust)
    parser_hclust.set_defaults(collect_func=collectQueueClust)
    parser_hclust.set_defaults(cluster_func=hierClust)

    return parser
998
999
if __name__ == '__main__':
    # Command line entry point: parse arguments, then run defineClones once per input file
    parser = getArgParser()
    args = parser.parse_args()
    args_dict = parseCommonArgs(args)

    # Field names are matched in upper case
    if 'seq_field' in args_dict:
        args_dict['seq_field'] = args_dict['seq_field'].upper()
    if args_dict.get('fields') is not None:
        args_dict['fields'] = [f.upper() for f in args_dict['fields']]

    if args.command == 'bygroup':
        # Move the grouping options out of the top-level arguments into group_args
        group_keys = ('fields', 'action', 'mode')
        args_dict['group_args'] = {k: args_dict.pop(k) for k in group_keys}

        # Move the distance options into clone_args
        clone_keys = ('model', 'distance', 'norm', 'sym', 'linkage', 'seq_field')
        clone_args = {k: args_dict.pop(k) for k in clone_keys}
        # TODO: can be cleaned up with abstract model class
        clone_args['dist_mat'] = getModelMatrix(clone_args['model'])
        args_dict['clone_args'] = clone_args

    if args.command == 'hclust':
        # The method selects both the distance function and the cluster cutting rule
        method = args_dict.pop('method')
        dist_funcs = {'chen2010': distChen2010, 'ademokun2011': distAdemokun2011}
        args_dict['clone_func'] = dist_funcs[method]
        args_dict['cluster_args'] = {'method': method}

    # Call defineClones for each input database
    del args_dict['command']
    del args_dict['db_files']
    for db_file in args.db_files:
        args_dict['db_file'] = db_file
        defineClones(**args_dict)