Mercurial > repos > davidmurphy > codonlogo
comparison test_weblogo.py @ 0:c55bdc2fb9fa
Uploaded
author | davidmurphy |
---|---|
date | Thu, 27 Oct 2011 12:09:09 -0400 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:c55bdc2fb9fa |
---|---|
1 #!/usr/bin/env python | |
2 | |
3 # Copyright (c) 2006, The Regents of the University of California, through | |
4 # Lawrence Berkeley National Laboratory (subject to receipt of any required | |
5 # approvals from the U.S. Dept. of Energy). All rights reserved. | |
6 | |
7 # This software is distributed under the new BSD Open Source License. | |
8 # <http://www.opensource.org/licenses/bsd-license.html> | |
9 # | |
10 # Redistribution and use in source and binary forms, with or without | |
11 # modification, are permitted provided that the following conditions are met: | |
12 # | |
13 # (1) Redistributions of source code must retain the above copyright notice, | |
14 # this list of conditions and the following disclaimer. | |
15 # | |
16 # (2) Redistributions in binary form must reproduce the above copyright | |
17 # notice, this list of conditions and the following disclaimer in the | |
18 # documentation and or other materials provided with the distribution. | |
19 # | |
20 # (3) Neither the name of the University of California, Lawrence Berkeley | |
21 # National Laboratory, U.S. Dept. of Energy nor the names of its contributors | |
22 # may be used to endorse or promote products derived from this software | |
23 # without specific prior written permission. | |
24 # | |
25 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
26 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
27 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |
28 # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | |
29 # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |
30 # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |
31 # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |
32 # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |
33 # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |
34 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |
35 # POSSIBILITY OF SUCH DAMAGE. | |
36 | |
37 | |
38 import unittest | |
39 | |
40 import weblogolib | |
41 from weblogolib import * | |
42 from weblogolib import parse_prior, GhostscriptAPI | |
43 from weblogolib.color import * | |
44 from weblogolib.colorscheme import * | |
45 from StringIO import StringIO | |
46 import sys | |
47 | |
48 from numpy import array, asarray, float64, ones, zeros, int32,all,any, shape | |
49 import numpy as na | |
50 | |
51 from corebio import seq_io | |
52 from corebio.seq import * | |
53 | |
54 # python2.3 compatibility | |
55 from corebio._future.subprocess import * | |
56 from corebio._future import resource_stream | |
57 | |
58 from corebio.moremath import entropy | |
59 from math import log, sqrt | |
# All 64 codons in lexicographic order, over the RNA nucleotides (U) and the
# DNA nucleotides (T).  Each list is the ordered cartesian product of the four
# bases taken three at a time.  Generator expressions are used (rather than
# list comprehensions) so no loop variables leak into module scope on Python 2.
codon_alphabetU = list(a + b + c for a in 'ACGU' for b in 'ACGU' for c in 'ACGU')
codon_alphabetT = list(a + b + c for a in 'ACGT' for b in 'ACGT' for c in 'ACGT')
62 | |
63 | |
def testdata_stream(name):
    """Return a readable stream for the test data file *name*.

    The file is looked up as ``tests/data/<name>`` through corebio's
    ``resource_stream``, which is given this module's name and file
    (presumably so the path is resolved relative to this module).
    """
    relative_path = 'tests/data/' + name
    return resource_stream(__name__, relative_path, __file__)
66 | |
class test_logoformat(unittest.TestCase) :
    """Smoke test for building a default LogoOptions."""

    def test_options(self) :
        # Passes as long as default construction does not raise.
        LogoOptions()
71 | |
72 | |
class test_ghostscript(unittest.TestCase) :
    """Smoke test for the Ghostscript wrapper."""

    def test_version(self) :
        # Fails only if GhostscriptAPI cannot be constructed or its
        # `version` attribute cannot be read.
        GhostscriptAPI().version
76 | |
77 | |
78 | |
class test_parse_prior(unittest.TestCase) :
    """Tests for parse_prior(prior, alphabet[, weight]).

    As exercised below, parse_prior returns None for a missing or 'none'
    prior, and otherwise an array with one entry per alphabet symbol,
    scaled by the optional weight argument.
    """

    def assertTrue(self, expr, msg=None) :
        # Local override that checks `expr == True` via assertEquals; the
        # comparisons below produce numpy booleans (from numpy's `all`).
        # The (expr, msg) signature matches unittest.TestCase.assertTrue, so
        # the override stays call-compatible with the base class and does
        # not shadow the builtin `bool`.
        self.assertEquals(expr, True, msg)

    def test_parse_prior_none(self) :
        self.assertEquals( None,
            parse_prior(None, unambiguous_protein_alphabet ) )
        self.assertEquals( None,
            parse_prior( 'none', unambiguous_protein_alphabet ) )
        # Keyword matching ignores case, and no alphabet is required.
        self.assertEquals( None,
            parse_prior( 'noNe', None) )

    def test_parse_prior_equiprobable(self) :
        self.assertTrue( all(20.*equiprobable_distribution(20) ==
            parse_prior( 'equiprobable', unambiguous_protein_alphabet ) ) )

        # Surrounding whitespace and case are ignored; 1.2 is the weight.
        self.assertTrue(
            all( 1.2* equiprobable_distribution(3)
            == parse_prior( ' equiprobablE ', Alphabet('123'), 1.2 ) ) )

    def test_parse_prior_percentage(self) :
        self.assertTrue( all( equiprobable_distribution(4)
            == parse_prior( '50%', unambiguous_dna_alphabet, 1. ) ) )

        self.assertTrue( all( equiprobable_distribution(4)
            == parse_prior( ' 50.0 % ', unambiguous_dna_alphabet, 1. ) ) )

        # 40% splits as 0.2 on each of the two middle symbols and 0.3 on the
        # outer two (presumably alphabet order A, C, G, T -- i.e. GC content).
        self.assertTrue( all( array( (0.3,0.2,0.2,0.3), float64)
            == parse_prior( ' 40.0 % ', unambiguous_dna_alphabet, 1. ) ) )

    def test_parse_prior_float(self) :
        self.assertTrue( all( equiprobable_distribution(4)
            == parse_prior( '0.5', unambiguous_dna_alphabet, 1. ) ) )

        self.assertTrue( all( equiprobable_distribution(4)
            == parse_prior( ' 0.500 ', unambiguous_dna_alphabet, 1. ) ) )

        self.assertTrue( all( array( (0.3,0.2,0.2,0.3), float64)
            == parse_prior( ' 0.40 ', unambiguous_dna_alphabet, 1. ) ) )

    def test_auto(self) :
        self.assertTrue( all(4.*equiprobable_distribution(4) ==
            parse_prior( 'auto', unambiguous_dna_alphabet ) ) )
        self.assertTrue( all(4.*equiprobable_distribution(4) ==
            parse_prior( 'automatic', unambiguous_dna_alphabet ) ) )

    def test_weight(self) :
        self.assertTrue( all(4.*equiprobable_distribution(4) ==
            parse_prior( 'automatic', unambiguous_dna_alphabet ) ) )
        self.assertTrue( all(123.123*equiprobable_distribution(4) ==
            parse_prior( 'auto', unambiguous_dna_alphabet , 123.123) ) )

    def test_explicit(self) :
        s = "{'A':10, 'C':40, 'G':40, 'T':10}"
        p = array( (10, 40, 40,10), float64)*4./100.
        self.assertTrue( all(
            p == parse_prior( s, unambiguous_dna_alphabet ) ) )
137 | |
138 | |
class test_logooptions(unittest.TestCase) :
    """Tests for creating and describing LogoOptions objects."""

    def test_create(self) :
        opt = LogoOptions()
        opt.small_fontsize = 10
        # repr() of a modified options object must not raise.
        repr(opt)

        # Keyword arguments to the constructor set the matching attribute.
        # assertEqual (unlike a bare `assert`) still runs under `python -O`.
        opt = LogoOptions(title="sometitle")
        self.assertEqual(opt.title, "sometitle")
147 | |
class test_logosize(unittest.TestCase) :
    """Tests for creating and describing LogoSize objects."""

    def test_create(self) :
        s = LogoSize(101.0,10.0)
        # The first constructor argument is exposed as stack_width.
        # assertEqual (unlike a bare `assert`) still runs under `python -O`.
        self.assertEqual(s.stack_width, 101.0)
        # repr() must not raise.
        repr(s)
153 | |
154 | |
class test_seqlogo(unittest.TestCase) :
    """Command-line tests that run the ./weblogo script as a subprocess."""

    # FIXME: The version of python used by Popen may not be the
    # same as that used to run this test.
    def _exec(self, args, outputtext, returncode =0, stdin=None) :
        """Run ``./weblogo`` with *args* and check its result.

        args       -- list of command-line arguments (program name excluded).
        outputtext -- iterable of substrings that must all appear on stdout.
                      Pass a list: a bare string would be iterated character
                      by character.
        returncode -- expected exit status (default 0).
        stdin      -- stream fed to the child process; defaults to the
                      cap.fa test data file, which is opened and closed here.

        When a zero exit status is expected, stderr must also be empty.
        """
        opened_here = stdin is None
        if opened_here :
            stdin = testdata_stream("cap.fa")
        try :
            p = Popen(["./weblogo"] + args,
                      stdin=stdin, stdout=PIPE, stderr=PIPE)
            (out, err) = p.communicate()
        finally :
            # Only close the stream opened above; a caller-supplied stdin
            # belongs to the caller.
            if opened_here :
                stdin.close()

        if returncode == 0 and p.returncode != 0 :
            # Surface the child's error output to diagnose unexpected
            # failures (including death by signal, i.e. negative codes).
            print(err)
        self.assertEquals(returncode, p.returncode)
        if returncode == 0 :
            self.assertEquals(len(err), 0)

        for item in outputtext :
            self.failUnless(item in out,
                            "%r not found in weblogo output" % (item,))

    def test_malformed_options(self) :
        self._exec( ["--notarealoption"], [], 2)
        self._exec( ["extrajunk"], [], 2)
        self._exec( ["-I"], [], 2)

    def test_help_option(self) :
        self._exec( ["-h"], ["options"])
        self._exec( ["--help"], ["options"])

    def test_version_option(self) :
        # Wrapped in a list so the whole version string is searched for,
        # not each of its characters separately.
        self._exec( ['--version'], [weblogolib.__version__])

    def test_default_build(self) :
        self._exec( [], ["%%Title: Sequence Logo:"] )

    # Format options
    def test_width(self) :
        self._exec( ['-W','1234'], ["/stack_width 1234"] )
        self._exec( ['--stack-width','1234'], ["/stack_width 1234"] )

    def test_height(self) :
        self._exec( ['-H','1234'], ["/stack_height 1234"] )
        self._exec( ['--stack-height','1234'], ["/stack_height 1234"] )

    def test_stacks_per_line(self) :
        self._exec( ['-n','7'], ["/stacks_per_line 7 def"] )
        self._exec( ['--stacks-per-line','7'], ["/stacks_per_line 7 def"] )

    def test_title(self) :
        self._exec( ['-t', '3456'], ['/logo_title (3456) def',
                                     '/show_title True def'])
        self._exec( ['-t', ''], ['/logo_title () def',
                                 '/show_title False def'])
        self._exec( ['--title', '3456'], ['/logo_title (3456) def',
                                          '/show_title True def'])
213 | |
214 | |
215 | |
216 | |
217 | |
class test_which(unittest.TestCase) :
    """Tests that which_alphabet detects the codon alphabet of a file."""

    def test_which(self):
        # Pair each test data file with the codon alphabet expected for it;
        # all files are read before any comparison is made.
        expectations = (
            ('cap.fa', codon_alphabetT),
            ('capu.fa', codon_alphabetU),
        )
        cases = [(seq_io.read(testdata_stream(fname)), alphabet)
                 for fname, alphabet in expectations]

        for seqs, alphabet in cases :
            self.failUnlessEqual(which_alphabet(seqs), alphabet)
229 | |
230 | |
231 | |
232 | |
class test_colorscheme(unittest.TestCase) :
    """Tests for ColorGroup and ColorScheme."""

    def test_colorgroup(self) :
        group = ColorGroup("ABC", "black", "Because")
        self.assertEquals(group.description, "Because")

    def test_colorscheme(self) :
        groups = [
            ColorGroup("G", "orange"),
            ColorGroup("TU", "red"),
            ColorGroup("C", "blue"),
            ColorGroup("A", "green"),
        ]
        scheme = ColorScheme(groups,
                             title="title",
                             description="description")

        # A symbol listed in a group takes that group's color ...
        self.assertEquals(scheme.color('A'), Color.by_name("green"))
        # ... while a symbol in no group gets the scheme's default color.
        self.assertEquals(scheme.color('X'), scheme.default_color)
252 | |
253 | |
254 | |
class test_color(unittest.TestCase) :
    """Tests for Color: named colors, RGB and HSL constructors, parsing from
    strings, component clipping and equality."""

    # 2.3 Python compatibility
    assertTrue = unittest.TestCase.failUnless
    assertFalse = unittest.TestCase.failIf

    def test_color_names(self) :
        names = Color.names()
        self.failUnlessEqual( len(names), 147)

        for n in names:
            c = Color.by_name(n)
            # Identity check (PEP 8): `!=` could dispatch to a custom
            # comparison defined on Color.
            self.assertTrue( c is not None )

    def test_color_components(self) :
        white = Color.by_name("white")
        self.failUnlessEqual( 1.0, white.red)
        self.failUnlessEqual( 1.0, white.green)
        self.failUnlessEqual( 1.0, white.blue)

        # Float arguments are stored as fractions in [0, 1] ...
        c = Color(0.3, 0.4, 0.2)
        self.failUnlessEqual( 0.3, c.red)
        self.failUnlessEqual( 0.4, c.green)
        self.failUnlessEqual( 0.2, c.blue)

        # ... while integer arguments are scaled down from the 0-255 range.
        c = Color(0,128,0)
        self.failUnlessEqual( 0.0, c.red)
        self.failUnlessEqual( 128./255., c.green)
        self.failUnlessEqual( 0.0, c.blue)

    def test_color_from_rgb(self) :
        white = Color.by_name("white")

        self.failUnlessEqual(white, Color(1.,1.,1.) )
        self.failUnlessEqual(white, Color(255,255,255) )
        self.failUnlessEqual(white, Color.from_rgb(1.,1.,1.) )
        self.failUnlessEqual(white, Color.from_rgb(255,255,255) )

    def test_color_from_hsl(self) :
        red = Color.by_name("red")
        lime = Color.by_name("lime")
        saddlebrown = Color.by_name("saddlebrown")
        darkgreen = Color.by_name("darkgreen")
        blue = Color.by_name("blue")

        self.failUnlessEqual(red, Color.from_hsl(0, 1.0,0.5) )
        self.failUnlessEqual(lime, Color.from_hsl(120, 1.0, 0.5) )
        self.failUnlessEqual(blue, Color.from_hsl(240, 1.0, 0.5) )
        self.failUnlessEqual(Color.by_name("gray"), Color.from_hsl(0,0,0.5) )

        self.failUnlessEqual(saddlebrown, Color.from_hsl(25, 0.76, 0.31) )

        self.failUnlessEqual(darkgreen, Color.from_hsl(120, 1.0, 0.197) )

    def test_color_by_name(self):
        white = Color.by_name("white")
        # Name lookup ignores case and surrounding whitespace.
        self.failUnlessEqual(white, Color.by_name("white"))
        self.failUnlessEqual(white, Color.by_name("WHITE"))
        self.failUnlessEqual(white, Color.by_name(" wHiTe \t\n\t"))

        self.failUnlessEqual(Color(255,255,240), Color.by_name("ivory"))
        self.failUnlessEqual(Color(70,130,180), Color.by_name("steelblue"))

        self.failUnlessEqual(Color(0,128,0), Color.by_name("green"))

    def test_color_from_invalid_name(self):
        self.failUnlessRaises( ValueError, Color.by_name, "not_a_color")

    def test_color_clipping(self):
        # Out-of-range components are clipped into the valid range.
        red = Color.by_name("red")
        self.failUnlessEqual(red, Color(255,0,0) )
        self.failUnlessEqual(red, Color(260,-10,0) )
        self.failUnlessEqual(red, Color(1.1,-0.,-1.) )

        self.failUnlessEqual( Color(1.0001, 213.0, 1.2).red, 1.0 )
        self.failUnlessEqual( Color(-0.001, -2183.0, -1.0).red, 0.0 )
        self.failUnlessEqual( Color(1.0001, 213.0, 1.2).green, 1.0 )
        self.failUnlessEqual( Color(-0.001, -2183.0, -1.0).green, 0.0 )
        self.failUnlessEqual( Color(1.0001, 213.0, 1.2).blue, 1.0 )
        self.failUnlessEqual( Color(-0.001, -2183.0, -1.0).blue, 0.0 )

    def test_color_fail_on_mixed_type(self):
        # Mixing int and float components is ambiguous and rejected.
        self.failUnlessRaises( TypeError, Color.from_rgb, 1,1,1.0 )
        self.failUnlessRaises( TypeError, Color.from_rgb, 1.0,1,1.0 )

    def test_color_red(self) :
        # Check Usage comment in Color
        red = Color.by_name("red")
        self.failUnlessEqual( red , Color(255,0,0) )
        self.failUnlessEqual( red, Color(1., 0., 0.) )

        self.failUnlessEqual( red , Color.from_rgb(1.,0.,0.) )
        self.failUnlessEqual( red , Color.from_rgb(255,0,0) )
        self.failUnlessEqual( red , Color.from_hsl(0.,1., 0.5) )

        self.failUnlessEqual( red , Color.from_string("red") )
        self.failUnlessEqual( red , Color.from_string("RED") )
        self.failUnlessEqual( red , Color.from_string("#F00") )
        self.failUnlessEqual( red , Color.from_string("#FF0000") )
        self.failUnlessEqual( red , Color.from_string("rgb(255, 0, 0)") )
        self.failUnlessEqual( red , Color.from_string("rgb(100%, 0%, 0%)") )
        self.failUnlessEqual( red , Color.from_string("hsl(0, 100%, 50%)") )

    def test_color_from_string(self) :
        red = Color(255,0,0)
        skyblue = Color(135,206,235)

        red_strings = ("red",
                       "ReD",
                       "RED",
                       "   Red \t",
                       "#F00",
                       "#FF0000",
                       "rgb(255, 0, 0)",
                       "rgb(100%, 0%, 0%)",
                       "hsl(0, 100%, 50%)")

        for s in red_strings:
            self.failUnlessEqual( red, Color.from_string(s) )

        skyblue_strings = ("skyblue",
                           "SKYBLUE",
                           "  \t\n SkyBlue  \t",
                           "#87ceeb",
                           "rgb(135,206,235)"
                           )

        for s in skyblue_strings:
            self.failUnlessEqual( skyblue, Color.from_string(s) )

    def test_color_equality(self):
        # Separately constructed colors with equal components compare equal.
        c1 = Color(123,99,12)
        c2 = Color(123,99,12)

        self.failUnlessEqual(c1,c2)
403 | |
404 | |
405 | |
406 | |
407 | |
408 | |
class test_Dirichlet(unittest.TestCase) :
    """Tests for the Dirichlet distribution helper.

    Several tests draw random samples and compare sample statistics against
    the analytic values reported by Dirichlet, so they are probabilistic and
    can occasionally fail even for a correct implementation.
    """

    # 2.3 Python compatibility
    assertTrue = unittest.TestCase.failUnless
    assertFalse = unittest.TestCase.failIf

    def test_init(self) :
        # Construction alone must succeed.
        Dirichlet( (1, 1, 1, 1,) )

    def test_random(self) :

        def check(alpha, samples=1000) :
            # Record the entropy of `samples` distributions drawn from
            # Dirichlet(alpha).
            entropies = zeros( (samples,), float64)
            dist = Dirichlet(alpha)
            for i in range(samples) :
                entropies[i] = entropy(dist.sample())

            sample_mean = mean(entropies)
            sample_var = var(entropies)

            expected_mean = dist.mean_entropy()
            expected_var = dist.variance_entropy()

            # Tolerance: four standard errors of the sample mean.  Reusing it
            # for the variance comparison is only a rough heuristic.
            tolerance = 4. * sqrt(sample_var / samples)
            self.assertTrue( abs(sample_mean - expected_mean) < tolerance)
            self.assertTrue( abs(sample_var - expected_var) < tolerance)

        alphas = (
            (1., 1.),
            (2., 1.),
            (3., 1.),
            (4., 1.),
            (5., 1.),
            (6., 1.),
            (1., 1.),
            (20., 20.),
            (1.,) * 10,
            (.1,) * 10,
            (.01,) * 10,
            (2.0, 6.0, 1.0, 1.0),
        )
        for alpha in alphas :
            check(alpha)

    def test_mean(self) :
        means = Dirichlet(ones( ( 10,), float64 ) * 23.).mean()
        # Symmetric concentration: every component has mean 1/10.
        self.assertAlmostEqual( means[2], 1./10)
        self.assertAlmostEqual( sum(means), 1.0)

    def test_covariance(self) :
        cov = Dirichlet(ones( ( 4,), float64 )).covariance()
        self.assertEqual( cov.shape, (4,4) )
        # For alpha = (1, 1, 1, 1) the expected entries are 3/80 on the
        # diagonal and -1/80 off it.
        self.assertAlmostEqual( cov[0,0], 1.0 * (1.0 - 1./4.0)/ (4.0 * 5.0) )
        self.assertAlmostEqual( cov[0,1], - 1 / ( 4. * 4. * 5.) )

    def test_mean_x(self) :
        result = Dirichlet((1.0, 2.0, 3.0, 4.0)).mean_x((2.0, 2.0, 2.0, 2.0))
        self.assertEquals( result, 2.0)

        result = Dirichlet((1.0, 1.0, 1.0, 1.0)).mean_x((2.0, 3.0, 4.0, 3.0))
        self.assertEquals( result, 3.0)

    def test_variance_x(self) :
        result = Dirichlet((1.0, 1.0, 1.0, 1.0)).variance_x((2.0, 2.0, 2.0, 2.0))
        self.assertAlmostEquals( result, 0.0)

        # Smoke test only: the correct value for this case has not been
        # worked out, so the result is computed but not checked.
        # TODO: Don't actually know if this is correct
        Dirichlet((1.0, 2.0, 3.0, 4.0)).variance_x((2.0, 0.0, 1.0, 10.0))

    def test_relative_entropy(self):
        d = Dirichlet((2.0, 10.0, 1.0, 1.0))
        pvec = (0.1, 0.2, 0.3, 0.4)

        rent = d.mean_relative_entropy(pvec)
        vrent = d.variance_relative_entropy(pvec)
        low, high = d.interval_relative_entropy(pvec, 0.95)

        # Monte Carlo check of the analytic values above.  This test can
        # fail randomly: the precision from a few thousand samples is low,
        # which is why the sample count was raised from 1000 to 2000.
        samples = 2000
        sent = zeros( (samples,), float64)
        log_p = [log(q) for q in pvec]

        for s in range(samples) :
            post = d.sample()
            value = -entropy(post)
            for k, lp in enumerate(log_p) :
                value += - post[k] * lp
            sent[s] = value
        sent.sort()

        self.assertTrue( abs(sent.mean() - rent) < 4.*sqrt(vrent) )
        self.assertAlmostEqual( sent.std(), sqrt(vrent), 1 )
        # Compare the analytic 95% interval with the empirical quantiles.
        self.assertTrue( abs(low - sent[int(samples * 0.025)]) < 0.2 )
        self.assertTrue( abs(high - sent[int(samples * 0.975)]) < 0.2 )
531 | |
532 | |
533 | |
def mean(a) :
    """Return the arithmetic mean of the numbers in *a*.

    *a* may be a numpy array or any sized sequence of numbers.  The divisor
    is converted to float so that under Python 2 a sequence of integers does
    not get truncating integer division (mean([1, 2]) is 1.5, not 1).
    Raises ZeroDivisionError if *a* is empty.
    """
    return sum(a) / float(len(a))
536 | |
def var(a) :
    """Return the population variance (divisor N) of the numbers in *a*.

    *a* may be a numpy array or any plain sequence of numbers.  It is first
    converted to a float64 array, so element-wise arithmetic works for lists
    and tuples too (the old ``a*a`` raised TypeError on a list).  The
    deviations from the mean are squared directly (two-pass form).  This
    avoids the cancellation error of E[x^2] - E[x]^2 when the mean is large
    relative to the spread.
    """
    x = asarray(a, float64)
    deviations = x - x.mean()
    return (deviations * deviations).mean()
539 | |
540 | |
541 | |
542 | |
if __name__ == '__main__':
    # When executed as a script, run every TestCase defined in this module.
    unittest.main()