Mercurial > repos > cafletezbrant > kmersvm
comparison kmersvm/scripts/kmersvm_train_kfb_copy.py @ 7:fd740d515502 draft default tip
Uploaded revised kmer-SVM to include modules from kmer-visual.
author | cafletezbrant |
---|---|
date | Sun, 16 Jun 2013 18:06:14 -0400 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
6:1aea7c1a9ab1 | 7:fd740d515502 |
---|---|
1 #!/usr/bin/env python | |
2 """ | |
3 kmersvm_train.py; train a support vector machine using shogun toolbox | |
4 Copyright (C) 2011 Dongwon Lee | |
5 | |
6 This program is free software: you can redistribute it and/or modify | |
7 it under the terms of the GNU General Public License as published by | |
8 the Free Software Foundation, either version 3 of the License, or | |
9 (at your option) any later version. | |
10 | |
11 This program is distributed in the hope that it will be useful, | |
12 but WITHOUT ANY WARRANTY; without even the implied warranty of | |
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
14 GNU General Public License for more details. | |
15 | |
16 You should have received a copy of the GNU General Public License | |
17 along with this program. If not, see <http://www.gnu.org/licenses/>. | |
18 | |
19 | |
20 """ | |
21 | |
22 | |
23 | |
24 import sys | |
25 import optparse | |
26 import random | |
27 import numpy | |
28 from math import log, exp | |
29 | |
30 from libkmersvm import * | |
31 try: | |
32 from shogun.PreProc import SortWordString, SortUlongString | |
33 except ImportError: | |
34 from shogun.Preprocessor import SortWordString, SortUlongString | |
35 from shogun.Kernel import CommWordStringKernel, CommUlongStringKernel, \ | |
36 CombinedKernel | |
37 | |
38 from shogun.Features import StringWordFeatures, StringUlongFeatures, \ | |
39 StringCharFeatures, CombinedFeatures, DNA, Labels | |
40 from shogun.Classifier import MSG_INFO, MSG_ERROR | |
41 try: | |
42 from shogun.Classifier import SVMLight | |
43 except ImportError: | |
44 from shogun.Classifier import LibSVM | |
45 | |
46 """ | |
47 global variables | |
48 """ | |
# Module-level lookup tables, populated by main() and the feature builders:
#   g_kmers -- list of all k-mer strings for the current k-mer length
#   g_rcmap -- per-id mapping onto the canonical (reverse-complement) k-mer id
g_kmers = []
g_rcmap = []
51 | |
52 | |
def kmerid2kmer(kmerid, kmerlen):
    """convert integer kmerid to kmer string

    Arguments:
    kmerid -- integer, id of k-mer
    kmerlen -- integer, length of k-mer

    Return:
    kmer string
    """

    nts = "ACGT"
    kmernts = []
    remaining = kmerid

    # divmod replaces the original int((x-ntid)/4): under Python 3 that is a
    # float true-division, which silently loses precision for very large ids
    for _ in range(kmerlen):
        remaining, ntid = divmod(remaining, 4)
        kmernts.append(nts[ntid])

    # digits were produced least-significant first
    return ''.join(reversed(kmernts))
74 | |
75 | |
def kmer2kmerid(kmer, kmerlen):
    """convert kmer string to integer kmerid

    Arguments:
    kmer -- string, k-mer over the alphabet ACGT
    kmerlen -- integer, length of k-mer (unused; kept for interface compatibility)

    Return:
    id of k-mer (base-4 value with A=0, C=1, G=2, T=3)
    """

    nt2id = {'A':0, 'C':1, 'G':2, 'T':3}

    # explicit accumulation instead of reduce(); reduce is not a builtin in
    # Python 3, and the loop reads more clearly
    kmerid = 0
    for nt in kmer:
        kmerid = 4*kmerid + nt2id[nt]

    return kmerid
90 | |
91 | |
def get_rcmap(kmerid, kmerlen):
    """mapping kmerid to its reverse complement k-mer on-the-fly

    Arguments:
    kmerid -- integer, id of k-mer
    kmerlen -- integer, length of k-mer

    Return:
    integer kmerid after mapping to its reverse complement
    """

    # canonical id is whichever is smaller: the k-mer's own id, or the id of
    # its reverse complement (kmerid -> kmer -> revcomp -> id)
    rc_id = kmer2kmerid(revcomp(kmerid2kmer(kmerid, kmerlen)), kmerlen)
    return min(kmerid, rc_id)
112 | |
113 | |
def non_redundant_word_features(feats, kmerlen):
    """convert the features from Shogun toolbox to non-redundant word features (handle reverse complements)

    Arguments:
    feats -- StringWordFeatures
    kmerlen -- integer, length of k-mer (unused here; mapping comes from g_rcmap)

    Return:
    StringWordFeatures after converting reverse complement k-mer ids
    """

    rcmap = g_rcmap

    # collapse each k-mer id onto its canonical (reverse-complement) id
    for vec_idx in range(feats.get_num_vectors()):
        mapped = [rcmap[int(kid)] for kid in feats.get_feature_vector(vec_idx)]
        feats.set_feature_vector(numpy.array(mapped, numpy.dtype('u2')), vec_idx)

    sorter = SortWordString()
    sorter.init(feats)
    # shogun renamed the preprocessor API between releases
    try:
        feats.add_preproc(sorter)
        feats.apply_preproc()
    except AttributeError:
        feats.add_preprocessor(sorter)
        feats.apply_preprocessor()

    return feats
141 | |
142 | |
def non_redundant_ulong_features(feats, kmerlen):
    """convert the features from Shogun toolbox to non-redundant ulong features

    Arguments:
    feats -- StringUlongFeatures
    kmerlen -- integer, length of k-mer

    Return:
    StringUlongFeatures after converting reverse complement k-mer ids
    """

    # long k-mers have no precomputed table; map each id on the fly
    for vec_idx in range(feats.get_num_vectors()):
        mapped = [get_rcmap(int(kid), kmerlen)
                  for kid in feats.get_feature_vector(vec_idx)]
        feats.set_feature_vector(numpy.array(mapped, numpy.dtype('u8')), vec_idx)

    sorter = SortUlongString()
    sorter.init(feats)
    # shogun renamed the preprocessor API between releases
    try:
        feats.add_preproc(sorter)
        feats.apply_preproc()
    except AttributeError:
        feats.add_preprocessor(sorter)
        feats.apply_preprocessor()

    return feats
169 | |
170 | |
def svm_learn(kernel, labels, options):
    """train SVM using SVMLight or LibSVM

    Arguments:
    kernel -- kernel object from Shogun toolbox
    labels -- list of labels
    options -- object containing option data

    Return:
    trained svm object
    """

    lab = Labels(numpy.array(labels, dtype=numpy.double))
    # SVMLight is only present in some shogun builds; fall back to LibSVM
    try:
        svm = SVMLight(options.svmC, kernel, lab)
    except NameError:
        svm = LibSVM(options.svmC, kernel, lab)

    if options.quiet == False:
        svm.io.set_loglevel(MSG_INFO)
        svm.io.set_target_to_stderr()

    svm.set_epsilon(options.epsilon)
    svm.parallel.set_num_threads(1)
    if options.weight != 1.0:
        # asymmetric soft margin: C+ = C * weight to compensate class imbalance
        svm.set_C(options.svmC, options.svmC*options.weight)
    svm.train()

    if options.quiet == False:
        svm.io.set_loglevel(MSG_ERROR)

    return svm
202 | |
203 | |
def _get_spectrum_features(seqs, kmerlen):
    """generate spectrum features (internal)

    Arguments:
    seqs -- list of sequences
    kmerlen -- integer, length of k-mer

    Return:
    StringWord(Ulong)Features after treatment of redundant reverse complement k-mers
    """

    char_feats = StringCharFeatures(seqs, DNA)

    # word features hold ids for k <= 8 (4^8 fits in 16 bits); longer k-mers
    # require the ulong variant
    if kmerlen <= 8:
        feats = StringWordFeatures(DNA)
        feats.obtain_from_char(char_feats, kmerlen-1, kmerlen, 0, False)
        return non_redundant_word_features(feats, kmerlen)

    feats = StringUlongFeatures(DNA)
    feats.obtain_from_char(char_feats, kmerlen-1, kmerlen, 0, False)
    return non_redundant_ulong_features(feats, kmerlen)
227 | |
228 | |
def get_spectrum_features(seqs, options):
    """generate spectrum features (wrapper)

    Delegates to _get_spectrum_features using the configured k-mer length.
    """
    kmerlen = options.kmerlen
    return _get_spectrum_features(seqs, kmerlen)
233 | |
234 | |
def get_weighted_spectrum_features(seqs, options):
    """generate weighted spectrum features

    Builds one spectrum feature object per k in [kmerlen, kmerlen2].

    Arguments:
    seqs -- list of sequences
    options -- object containing option data (kmerlen, kmerlen2)

    Return:
    list of StringWord(Ulong)Features, one per k-mer length
    """
    global g_kmers
    global g_rcmap

    subfeats_list = []

    for k in range(options.kmerlen, options.kmerlen2+1):
        # short k-mers use the precomputed global tables; rebuild them per k.
        # (the original also built an unused StringCharFeatures here each
        # iteration -- dead work, removed)
        if k <= 8:
            g_kmers = generate_kmers(k)
            g_rcmap = generate_rcmap_table(k, g_kmers)

        subfeats_list.append(_get_spectrum_features(seqs, k))

    return subfeats_list
253 | |
254 | |
def get_spectrum_kernel(feats, options):
    """build spectrum kernel with non-redundant k-mer list (removing reverse complement)

    Arguments:
    feats -- feature object
    options -- object containing option data

    Return:
    CommWord(Ulong)StringKernel matching the feature width
    """
    # long k-mers were stored as ulong features; pick the matching kernel
    if options.kmerlen > 8:
        return CommUlongStringKernel(feats, feats)
    return CommWordStringKernel(feats, feats)
269 | |
270 | |
def get_weighted_spectrum_kernel(subfeats_list, options):
    """build weighted spectrum kernel with non-redundant k-mer list (removing reverse complement)

    Arguments:
    subfeats_list -- list of sub-feature objects
    options -- object containing option data

    Return:
    CombinedKernel of CommWord(Ulong)StringKernel, one sub-kernel per k
    """
    kernel = CombinedKernel()
    feats = CombinedFeatures()

    for subfeats in subfeats_list:
        feats.append_feature_obj(subfeats)

    nsub = 0
    for k in range(options.kmerlen, options.kmerlen2+1):
        # cache size 10, use_sign False for every sub-kernel
        if k <= 8:
            kernel.append_kernel(CommWordStringKernel(10, False))
        else:
            kernel.append_kernel(CommUlongStringKernel(10, False))
        nsub += 1

    kernel.init(feats, feats)

    # weight every sub-kernel uniformly
    kernel.set_subkernel_weights(numpy.array([1.0/nsub]*nsub, numpy.dtype('float64')))

    return kernel
305 | |
306 | |
def init_spectrum_kernel(kern, feats_lhs, feats_rhs):
    """initialize spectrum kernel (wrapper function)

    Simply forwards the left/right feature objects to the kernel's init.
    """
    kern.init(feats_lhs, feats_rhs)
311 | |
312 | |
def init_weighted_spectrum_kernel(kern, subfeats_list_lhs, subfeats_list_rhs):
    """initialize weighted spectrum kernel (wrapper function)

    Wraps both sub-feature lists into CombinedFeatures before kernel init.
    """
    feats_lhs = CombinedFeatures()
    feats_rhs = CombinedFeatures()

    for sf in subfeats_list_lhs:
        feats_lhs.append_feature_obj(sf)
    for sf in subfeats_list_rhs:
        feats_rhs.append_feature_obj(sf)

    kern.init(feats_lhs, feats_rhs)
326 | |
327 | |
def get_sksvm_weights(svm, feats, options):
    """calculate the SVM weight vector of spectrum kernel

    Arguments:
    svm -- trained svm object (provides alphas and support vector ids)
    feats -- feature object (provides per-vector k-mer ids)
    options -- object containing option data (kmerlen)

    Return:
    numpy array of length 4**kmerlen with the explicit linear SVM weights
    """
    kmerlen = options.kmerlen
    alphas = svm.get_alphas()
    support_vector_ids = svm.get_support_vectors()

    dim = 2**(2*kmerlen)  # 4**kmerlen possible k-mer ids
    w = numpy.zeros(dim, numpy.double)

    # w = sum_i alpha_i * x_i / ||x_i|| over unit-normalized count vectors
    # (zip replaces the py2-only xrange index loop of the original)
    for alpha, svid in zip(alphas, support_vector_ids):
        x = numpy.zeros(dim, numpy.double)
        for kmerid in feats.get_feature_vector(int(svid)):
            x[int(kmerid)] += 1
        w += alpha*x/numpy.sqrt(numpy.sum(x**2))

    return w
345 | |
def get_feature_counts(svm, feats, options):
    """calculate feature counts for SVs and write them to <outputname>_counts.out

    Arguments:
    svm -- trained svm object
    feats -- feature object
    options -- object containing option data (kmerlen, outputname, sort)

    Return:
    numpy array of summed k-mer counts over all support vectors
    """
    kmerlen = options.kmerlen
    support_vector_ids = svm.get_support_vectors()
    output = options.outputname + "_counts.out"

    global g_kmers
    global g_rcmap

    dim = 2**(2*kmerlen)
    w = numpy.zeros(dim, numpy.double)

    # accumulate raw (un-normalized, un-weighted) k-mer counts over all SVs
    for svid in support_vector_ids:
        x = [0]*dim
        for kmerid in feats.get_feature_vector(int(svid)):
            x[int(kmerid)] += 1
        w += numpy.array(x, numpy.double)

    if options.sort:
        w_sorted = sorted(zip(range(len(w)), w), key=lambda p: p[1], reverse=True)
    else:
        w_sorted = list(zip(range(len(w)), w))

    # BUGFIX: the original wrote to 'f' without ever opening the output file
    # (NameError at runtime)
    f = open(output, 'w')
    try:
        for i in (p[0] for p in w_sorted):
            # emit only canonical k-mers (id equals its revcomp mapping)
            if i == g_rcmap[i]:
                f.write('\t'.join([g_kmers[i], revcomp(g_kmers[i]), str(w[i])]) + '\n')
    finally:
        f.close()

    return w
377 | |
378 | |
379 | |
def get_wsksvm_weights(svm, subfeats_list, options):
    """calculate the SVM weight vector of weighted spectrum kernel

    Arguments:
    svm -- trained svm object
    subfeats_list -- list of sub-feature objects, one per k-mer length
    options -- object containing option data (kmerlen, kmerlen2)

    Return:
    list of numpy weight vectors, one per k-mer length
    """
    alphas = svm.get_alphas()
    support_vector_ids = svm.get_support_vectors()
    kmerlens = range(options.kmerlen, options.kmerlen2+1)
    nk = options.kmerlen2 - options.kmerlen + 1

    weights = []
    for k, subfeats in zip(kmerlens, subfeats_list):
        dim = 2**(2*k)
        w = numpy.zeros(dim, numpy.double)

        # w = sum_i alpha_i * x_i / ||x_i||, later scaled by the uniform
        # sub-kernel weight 1/nk (zip replaces py2-only xrange index loops)
        for alpha, svid in zip(alphas, support_vector_ids):
            x = numpy.zeros(dim, numpy.double)
            for kmerid in subfeats.get_feature_vector(int(svid)):
                x[int(kmerid)] += 1
            w += alpha*x/numpy.sqrt(numpy.sum(x**2))

        w /= nk
        weights.append(w)

    return weights
407 | |
408 | |
def save_header(f, bias, A, B, options):
    """write the parameter header of a weights output file

    Arguments:
    f -- writable file object
    bias -- SVM bias term
    A, B -- sigmoid parameters for posterior probability estimation
    options -- object containing option data (ktype, kmerlen, kmerlen2)
    """
    lines = ["#parameters:",
             "#kernel=" + str(options.ktype),
             "#kmerlen=" + str(options.kmerlen)]
    # kmerlen2 only applies to the weighted spectrum kernel
    if options.ktype == 2:
        lines.append("#kmerlen2=" + str(options.kmerlen2))
    lines.append("#bias=" + str(bias))
    lines.append("#A=" + str(A))
    lines.append("#B=" + str(B))
    lines.append("#NOTE: k-mers with large negative weights are also important. They can be found at the bottom of the list.")
    lines.append("#k-mer\trevcomp\tSVM-weight")
    f.write("\n".join(lines) + "\n")
420 | |
421 | |
def save_sksvm_weights(w, bias, A, B, options):
    """save the SVM weight vector from spectrum kernel

    Writes canonical k-mers (one per reverse-complement pair) and their
    weights to <outputname>_weights.out.
    """
    kmerlen = options.kmerlen
    outfile = options.outputname + "_weights.out"

    f = open(outfile, 'w')
    save_header(f, bias, A, B, options)

    global g_kmers
    global g_rcmap

    # iteration order: by descending weight if requested, else by k-mer id
    if options.sort:
        order = [i for i, _ in sorted(zip(range(len(w)), w),
                                      key=lambda p: p[1], reverse=True)]
    else:
        order = list(range(len(w)))

    if kmerlen <= 8:
        # short k-mers: use the precomputed global tables
        for i in order:
            if i == g_rcmap[i]:
                f.write('\t'.join([g_kmers[i], revcomp(g_kmers[i]), str(w[i])]) + '\n')
    else:
        # long k-mers: compute mapping and string on the fly
        for i in order:
            if i == get_rcmap(i, kmerlen):
                kmer = kmerid2kmer(i, kmerlen)
                f.write('\t'.join([kmer, revcomp(kmer), str(w[i])]) + '\n')

    f.close()
450 | |
451 | |
def save_wsksvm_weights(w, bias, A, B, options):
    """save the SVM weight vector from weighted spectrum kernel

    Arguments:
    w -- list of weight vectors, one per k-mer length
    bias -- SVM bias term
    A, B -- sigmoid parameters for posterior probability estimation
    options -- object containing option data
    """
    global g_kmers
    global g_rcmap

    output = options.outputname + "_weights.out"

    f = open(output, 'w')
    save_header(f, bias, A, B, options)

    kmerlens = range(options.kmerlen, options.kmerlen2+1)
    for idx, k in enumerate(kmerlens):
        subw = w[idx]

        if options.sort:
            subw_sorted = sorted(zip(range(len(subw)), subw), key=lambda p: p[1], reverse=True)
        else:
            subw_sorted = list(zip(range(len(subw)), subw))

        if k <= 8:
            # short k-mers: rebuild the global tables for this k
            g_kmers = generate_kmers(k)
            g_rcmap = generate_rcmap_table(k, g_kmers)
            for i in [p[0] for p in subw_sorted]:
                if i == g_rcmap[i]:
                    f.write('\t'.join([g_kmers[i], revcomp(g_kmers[i]), str(subw[i])]) + "\n")
        else:
            for i in [p[0] for p in subw_sorted]:
                if i == get_rcmap(i, k):
                    kmer = kmerid2kmer(i, k)
                    # BUGFIX: original referenced the undefined name 'kmers'
                    # here, raising NameError on this branch
                    f.write('\t'.join([kmer, revcomp(kmer), str(subw[i])]) + "\n")

    f.close()
488 | |
489 | |
def save_predictions(output, preds, cvs):
    """save prediction

    Arguments:
    output -- string, output file name
    preds -- list of (original index, seq_id, SVM score, label) tuples,
             sorted by original index
    cvs -- list of cross-validation fold ids, parallel to preds
    """
    f = open(output, 'w')
    f.write('\t'.join(["#seq_id", "SVM score", "label", "NCV"]) + "\n")
    # zip replaces the py2-only xrange index loop; lists are parallel
    for pred, cv in zip(preds, cvs):
        f.write('\t'.join([pred[1], str(pred[2]), str(pred[3]), str(cv)]) + "\n")
    f.close()
498 | |
499 | |
def generate_cv_list(ncv, n1, n2):
    """generate the N-fold cross validation list

    Arguments:
    ncv -- integer, number of cross-validation folds
    n1 -- integer, number of positives
    n2 -- integer, number of negatives

    Return:
    list of length n1+n2; element i is the fold id (0..ncv-1) of sample i
    """

    # shuffle positives and negatives separately so both classes are spread
    # evenly across folds; list() is required because random.shuffle cannot
    # mutate a range object under Python 3
    idx_pos = list(range(n1))
    idx_neg = list(range(n1, n1+n2))

    random.shuffle(idx_pos)
    random.shuffle(idx_neg)

    shuffled = idx_pos + idx_neg

    # deal fold ids 0..ncv-1 round-robin over the shuffled sample order
    cv = [0]*(n1+n2)
    for pos, sample in enumerate(shuffled):
        cv[sample] = pos % ncv

    return cv
532 | |
533 | |
def split_cv_list(cvlist, icv, data):
    """split data into training and test based on cross-validation list

    Arguments:
    cvlist -- list, cross-validation fold id per sample
    icv -- integer, cross-validation fold held out as the test set
    data -- list, data set to be split (parallel to cvlist)

    Return:
    (training list, test list)
    """

    # comprehensions over zip replace the py2-only xrange index loop
    tr_data = [d for fold, d in zip(cvlist, data) if fold != icv]
    te_data = [d for fold, d in zip(cvlist, data) if fold == icv]

    return tr_data, te_data
556 | |
557 | |
def LMAI(svms, labels, prior0, prior1):
    """fitting svms to sigmoid function (improved version introduced by Lin 2003)

    Newton's method with a backtracking line search on the regularized
    negative log-likelihood (Platt scaling, Lin/Lin/Weng variant).
    Only py2-only xrange calls were replaced; the numeric structure is
    preserved exactly.

    Arguments:
    svms -- list of svm scores
    labels -- list of labels (+1/-1)
    prior0 -- prior (count) of negative set
    prior1 -- prior (count) of positive set

    Return:
    A, B parameters of the sigmoid 1/(1+exp(A*SVM+B))
    """

    #parameter settings
    maxiter = 100    # maximum number of Newton iterations
    minstep = 1e-10  # minimum step size in the line search
    sigma = 1e-3     # Hessian regularizer (H' = H + sigma*I)

    # soft target probabilities instead of hard 0/1 (reduces overfitting)
    hiTarget = (prior1+1.0)/float(prior1+2.0)
    loTarget = 1/float(prior0+2.0)

    t = [0]*len(labels)
    for i in range(len(labels)):
        if labels[i] == 1:
            t[i] = hiTarget
        else:
            t[i] = loTarget

    A = 0.0
    B = log((prior0+1.0)/float(prior1+1.0))
    fval = 0.0

    # initial objective; the two algebraically-equal forms are chosen by the
    # sign of fApB for numerical stability
    for i in range(len(labels)):
        fApB = svms[i]*A+B
        if fApB >= 0:
            fval += (t[i]*fApB+log(1+exp(-fApB)))
        else:
            fval += ((t[i]-1)*fApB+log(1+exp(fApB)))

    for it in range(maxiter):
        #Update Gradient and Hessian (use H'= H + sigma I)
        h11 = sigma
        h22 = sigma
        h21 = 0.0
        g1 = 0.0
        g2 = 0.0

        for i in range(len(labels)):
            fApB = svms[i]*A+B
            if fApB >= 0:
                p = exp(-fApB) / float(1.0+exp(-fApB))
                q = 1.0 / float(1.0 + exp(-fApB))
            else:
                p = 1.0 / float(1.0 + exp(fApB))
                q = exp(fApB) / float(1.0+exp(fApB))
            d2 = p*q
            h11 += (svms[i]*svms[i]*d2)
            h22 += d2
            h21 += (svms[i]*d2)
            d1 = t[i]-p
            g1 += (svms[i]*d1)
            g2 += d1

        #Stopping criteria: gradient is effectively zero
        if (abs(g1) < 1e-5) and (abs(g2) < 1e-5):
            break

        # Newton direction from the regularized 2x2 system
        det = h11*h22-h21*h21
        dA = -(h22*g1-h21*g2)/float(det)
        dB = -(-h21*g1+h11*g2)/float(det)
        gd = g1*dA+g2*dB

        # backtracking line search with a sufficient-decrease condition
        stepsize = 1
        while stepsize >= minstep:
            newA = A+stepsize*dA
            newB = B+stepsize*dB
            newf = 0.0

            for i in range(len(labels)):
                fApB = svms[i]*newA+newB
                if fApB >= 0:
                    newf += (t[i]*fApB + log(1+exp(-fApB)))
                else:
                    newf += ((t[i]-1)*fApB + log(1+exp(fApB)))

            if newf < (fval+0.0001*stepsize*gd):
                A = newA
                B = newB
                fval = newf
                break
            else:
                stepsize = stepsize/float(2.0)

        #Line search fails
        if stepsize < minstep:
            break

    return A, B
661 | |
662 | |
def wsksvm_classify(seqs, svm, kern, feats, options):
    """classify sequences with the trained weighted-spectrum-kernel SVM

    Rebuilds test-side features, re-initializes the kernel against the
    training features, and returns the raw SVM scores as a list.
    """
    test_feats = get_weighted_spectrum_features(seqs, options)
    init_weighted_spectrum_kernel(kern, feats, test_feats)
    return svm.apply().get_labels().tolist()
668 | |
669 | |
def score_seq(s, svmw, kmerlen):
    """calculate SVM score of given sequence using single set of svm weights

    Arguments:
    s -- string, DNA sequence
    svmw -- numpy array, SVM weights
    kmerlen -- integer, length of k-mer of SVM weight

    Return:
    SVM score (dot product of weights with the unit-normalized count vector)
    """

    global g_rcmap

    # count canonical k-mer occurrences over a sliding window
    counts = [0]*(2**(2*kmerlen))
    for j in range(len(s)-kmerlen+1):
        counts[g_rcmap[kmer2kmerid(s[j:j+kmerlen], kmerlen)]] += 1

    counts = numpy.array(counts, numpy.double)
    # normalization matches the per-vector scaling used at training time
    return numpy.dot(svmw, counts)/numpy.sqrt(numpy.sum(counts**2))
693 | |
694 | |
def sksvm_classify(seqs, svm, kern, feats, options):
    """classify the given sequences

    For short k-mers, scoring against the explicit weight vector is much
    faster than kernel evaluation when support vectors are numerous;
    otherwise fall back to the kernel-based prediction path.
    """
    if options.kmerlen > 8:
        feats_te = get_spectrum_features(seqs, options)
        init_spectrum_kernel(kern, feats, feats_te)
        return svm.apply().get_labels().tolist()

    svmw = get_sksvm_weights(svm, feats, options)
    return [score_seq(s, svmw, options.kmerlen)+svm.get_bias() for s in seqs]
707 | |
708 | |
def main(argv = sys.argv):
    """parse command-line options, train the SVM, optionally run N-fold
    cross-validation, and save the trained weights"""
    usage = "Usage: %prog [options] POSITIVE_SEQ NEGATIVE_SEQ"
    desc = "1. take two files(FASTA format) as input, 2. train an SVM and store the trained SVM weights"
    parser = optparse.OptionParser(usage=usage, description=desc)
    parser.add_option("-t", dest="ktype", type="int", default=1, \
            help="set the type of kernel, 1:Spectrum, 2:Weighted Spectrums (default=1.Spectrum)")

    parser.add_option("-C", dest="svmC", type="float", default=1, \
            help="set the regularization parameter svmC (default=1)")

    parser.add_option("-e", dest="epsilon", type="float", default=0.00001, \
            help="set the precision parameter epsilon (default=0.00001)")

    parser.add_option("-w", dest="weight", type="float", default=0.0, \
            help="set the weight for positive set (default=auto, 1+log(N/P))")

    parser.add_option("-k", dest="kmerlen", type="int", default=6, \
            help="set the (min) length of k-mer for (weighted) spectrum kernel (default = 6)")

    parser.add_option("-K", dest="kmerlen2", type="int", default=8, \
            help="set the max length of k-mer for weighted spectrum kernel (default = 8)")

    parser.add_option("-n", dest="outputname", default="kmersvm_output", \
            help="set the name of output files (default=kmersvm_output)")

    parser.add_option("-v", dest="ncv", type="int", default=0, \
            help="if set, it will perform N-fold cross-validation and generate a prediction file (default = 0)")

    parser.add_option("-p", dest="posteriorp", default=False, action="store_true", \
            help="estimate parameters for posterior probability with N-CV. this option requires -v option to be set (default=false)")

    parser.add_option("-r", dest="rseed", type="int", default=1, \
            help="set the random number seed for cross-validation (-p option) (default=1)")

    parser.add_option("-q", dest="quiet", default=False, action="store_true", \
            help="supress messages (default=false)")

    parser.add_option("-s", dest="sort", default=False, action="store_true", \
            help="sort the kmers by absolute values of SVM weights (default=false)")

    ktype_str = ["", "Spectrum", "Weighted Spectrums"]

    (options, args) = parser.parse_args()

    if len(args) == 0:
        parser.print_help()
        sys.exit(0)

    if len(args) != 2:
        parser.error("incorrect number of arguments")
        parser.print_help()
        sys.exit(0)

    if options.posteriorp and options.ncv == 0:
        parser.error("posterior probability estimation requires N-fold CV process (-v option should be set)")
        parser.print_help()
        sys.exit(0)

    random.seed(options.rseed)

    #set global k-mer tables; only the plain spectrum kernel with short
    #k-mers uses them here (the weighted kernel rebuilds them per k)
    if (options.ktype == 1) and (options.kmerlen <= 8):
        global g_kmers
        global g_rcmap

        g_kmers = generate_kmers(options.kmerlen)
        g_rcmap = generate_rcmap_table(options.kmerlen, g_kmers)

    posf = args[0]
    negf = args[1]

    seqs_pos, sids_pos = read_fastafile(posf)
    seqs_neg, sids_neg = read_fastafile(negf)
    npos = len(seqs_pos)
    nneg = len(seqs_neg)
    seqs = seqs_pos + seqs_neg
    sids = sids_pos + sids_neg

    if options.weight == 0:
        #auto positive-class weight compensating class imbalance
        #(DEBUGGED by dlee 02/17/13)
        options.weight = 1 + log(nneg/float(npos))

    if options.quiet == False:
        sys.stderr.write('SVM parameters:\n')
        sys.stderr.write(' kernel-type: ' + str(options.ktype) + "." + ktype_str[options.ktype] + '\n')
        sys.stderr.write(' svm-C: ' + str(options.svmC) + '\n')
        sys.stderr.write(' epsilon: ' + str(options.epsilon) + '\n')
        sys.stderr.write(' weight: ' + str(options.weight) + '\n')
        sys.stderr.write('\n')

        sys.stderr.write('Other options:\n')
        sys.stderr.write(' kmerlen: ' + str(options.kmerlen) + '\n')
        if options.ktype == 2:
            sys.stderr.write(' kmerlen2: ' + str(options.kmerlen2) + '\n')
        sys.stderr.write(' outputname: ' + options.outputname + '\n')
        sys.stderr.write(' posteriorp: ' + str(options.posteriorp) + '\n')
        if options.ncv > 0:
            sys.stderr.write(' ncv: ' + str(options.ncv) + '\n')
            sys.stderr.write(' rseed: ' + str(options.rseed) + '\n')
        sys.stderr.write(' sorted-weight: ' + str(options.sort) + '\n')

        sys.stderr.write('\n')

        sys.stderr.write('Input args:\n')
        sys.stderr.write(' positive sequence file: ' + posf + '\n')
        sys.stderr.write(' negative sequence file: ' + negf + '\n')
        sys.stderr.write('\n')

        #TYPO FIX: original printed 'numer of total ...'
        sys.stderr.write('number of total positive seqs: ' + str(npos) + '\n')
        sys.stderr.write('number of total negative seqs: ' + str(nneg) + '\n')
        sys.stderr.write('\n')

    #generate labels
    labels = [1]*npos + [-1]*nneg

    #select the kernel-specific implementations
    if options.ktype == 1:
        get_features = get_spectrum_features
        get_kernel = get_spectrum_kernel
        get_weights = get_sksvm_weights
        save_weights = save_sksvm_weights
        svm_classify = sksvm_classify
    elif options.ktype == 2:
        get_features = get_weighted_spectrum_features
        get_kernel = get_weighted_spectrum_kernel
        get_weights = get_wsksvm_weights
        save_weights = save_wsksvm_weights
        svm_classify = wsksvm_classify
    else:
        sys.stderr.write('..unknown kernel..\n')
        sys.exit(0)

    A = B = 0
    if options.ncv > 0:
        if options.quiet == False:
            sys.stderr.write('..Cross-validation\n')

        cvlist = generate_cv_list(options.ncv, npos, nneg)
        labels_cv = []
        preds_cv = []
        sids_cv = []
        indices_cv = []
        for icv in range(options.ncv):
            #split data into training and test set
            seqs_tr, seqs_te = split_cv_list(cvlist, icv, seqs)
            labs_tr, labs_te = split_cv_list(cvlist, icv, labels)
            sids_tr, sids_te = split_cv_list(cvlist, icv, sids)
            indices_tr, indices_te = split_cv_list(cvlist, icv, range(len(seqs)))

            #train SVM on the N-1 remaining folds
            feats_tr = get_features(seqs_tr, options)
            kernel_tr = get_kernel(feats_tr, options)
            svm_cv = svm_learn(kernel_tr, labs_tr, options)

            preds_cv = preds_cv + svm_classify(seqs_te, svm_cv, kernel_tr, feats_tr, options)

            labels_cv = labels_cv + labs_te
            sids_cv = sids_cv + sids_te
            indices_cv = indices_cv + indices_te

        output_cvpred = options.outputname + "_cvpred.out"
        #restore original sample order before writing predictions
        prediction_results = sorted(zip(indices_cv, sids_cv, preds_cv, labels_cv), key=lambda p: p[0])
        save_predictions(output_cvpred, prediction_results, cvlist)

        if options.posteriorp:
            #fit sigmoid 1/(1+exp(A*s+B)) to the CV scores (Platt scaling)
            A, B = LMAI(preds_cv, labels_cv, labels_cv.count(-1), labels_cv.count(1))

            if options.quiet == False:
                sys.stderr.write('Estimated Parameters:\n')
                sys.stderr.write(' A: ' + str(A) + '\n')
                sys.stderr.write(' B: ' + str(B) + '\n')

    if options.quiet == False:
        sys.stderr.write('..SVM weights\n')

    #final training on the full data set
    feats = get_features(seqs, options)
    kernel = get_kernel(feats, options)
    svm = svm_learn(kernel, labels, options)
    #writes <outputname>_counts.out as a side effect (return value unused;
    #the original bound it to a never-used variable 'jj')
    get_feature_counts(svm, feats, options)
    w = get_weights(svm, feats, options)
    b = svm.get_bias()

    save_weights(w, b, A, B, options)

if __name__=='__main__': main()