0
|
1 #!/usr/bin/env python
|
|
2
|
|
3 # Copyright (c) 2006, The Regents of the University of California, through
|
|
4 # Lawrence Berkeley National Laboratory (subject to receipt of any required
|
|
5 # approvals from the U.S. Dept. of Energy). All rights reserved.
|
|
6
|
|
7 # This software is distributed under the new BSD Open Source License.
|
|
8 # <http://www.opensource.org/licenses/bsd-license.html>
|
|
9 #
|
|
10 # Redistribution and use in source and binary forms, with or without
|
|
11 # modification, are permitted provided that the following conditions are met:
|
|
12 #
|
|
13 # (1) Redistributions of source code must retain the above copyright notice,
|
|
14 # this list of conditions and the following disclaimer.
|
|
15 #
|
|
16 # (2) Redistributions in binary form must reproduce the above copyright
|
|
17 # notice, this list of conditions and the following disclaimer in the
|
|
18 # documentation and or other materials provided with the distribution.
|
|
19 #
|
|
20 # (3) Neither the name of the University of California, Lawrence Berkeley
|
|
21 # National Laboratory, U.S. Dept. of Energy nor the names of its contributors
|
|
22 # may be used to endorse or promote products derived from this software
|
|
23 # without specific prior written permission.
|
|
24 #
|
|
25 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
26 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
27 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
|
28 # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
|
29 # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
|
30 # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
|
31 # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
|
32 # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
|
33 # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
|
34 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
|
35 # POSSIBILITY OF SUCH DAMAGE.
|
|
36
|
|
37
|
|
38 import unittest
|
|
39
|
|
40 import weblogolib
|
|
41 from weblogolib import *
|
|
42 from weblogolib import parse_prior, GhostscriptAPI
|
|
43 from weblogolib.color import *
|
|
44 from weblogolib.colorscheme import *
|
|
45 from StringIO import StringIO
|
|
46 import sys
|
|
47
|
|
48 from numpy import array, asarray, float64, ones, zeros, int32,all,any, shape
|
|
49 import numpy as na
|
|
50
|
|
51 from corebio import seq_io
|
|
52 from corebio.seq import *
|
|
53
|
|
54 # Python 2.3 compatibility
|
|
55 from corebio._future.subprocess import *
|
|
56 from corebio._future import resource_stream
|
|
57
|
|
58 from corebio.moremath import entropy
|
|
59 from math import log, sqrt
|
|
# All 64 RNA codons (U) and DNA codons (T), in lexicographic order.
# Written as whitespace-separated literals, one pair of two-letter prefixes
# per row, and split into lists.
codon_alphabetU = (
    'AAA AAC AAG AAU ACA ACC ACG ACU '
    'AGA AGC AGG AGU AUA AUC AUG AUU '
    'CAA CAC CAG CAU CCA CCC CCG CCU '
    'CGA CGC CGG CGU CUA CUC CUG CUU '
    'GAA GAC GAG GAU GCA GCC GCG GCU '
    'GGA GGC GGG GGU GUA GUC GUG GUU '
    'UAA UAC UAG UAU UCA UCC UCG UCU '
    'UGA UGC UGG UGU UUA UUC UUG UUU'
).split()
codon_alphabetT = (
    'AAA AAC AAG AAT ACA ACC ACG ACT '
    'AGA AGC AGG AGT ATA ATC ATG ATT '
    'CAA CAC CAG CAT CCA CCC CCG CCT '
    'CGA CGC CGG CGT CTA CTC CTG CTT '
    'GAA GAC GAG GAT GCA GCC GCG GCT '
    'GGA GGC GGG GGT GTA GTC GTG GTT '
    'TAA TAC TAG TAT TCA TCC TCG TCT '
    'TGA TGC TGG TGT TTA TTC TTG TTT'
).split()
|
|
62
|
|
63
|
|
def testdata_stream(name):
    """Return an open stream for the test data file *name*.

    The resource is looked up as ``tests/data/<name>`` relative to this
    module via ``resource_stream``; ``__file__`` is passed through as its
    third argument.
    """
    resource_path = '/'.join(('tests', 'data', name))
    return resource_stream(__name__, resource_path, __file__)
|
|
66
|
|
class test_logoformat(unittest.TestCase):
    """Smoke test for default LogoOptions construction."""

    def test_options(self):
        # Passes as long as constructing LogoOptions with no arguments
        # does not raise.
        LogoOptions()
|
|
71
|
|
72
|
|
class test_ghostscript(unittest.TestCase):
    """Smoke test for GhostscriptAPI."""

    def test_version(self):
        # Passes as long as reading the ``version`` attribute of a fresh
        # GhostscriptAPI instance does not raise.
        GhostscriptAPI().version
|
|
76
|
|
77
|
|
78
|
|
class test_parse_prior(unittest.TestCase):
    """Tests for parse_prior(): none, equiprobable, GC percentage/fraction,
    automatic, explicit dictionaries and the optional weight argument."""

    def assertTrue(self, expr):
        # Stricter than a truthiness check: the value must compare equal to
        # True.  (Parameter renamed from ``bool`` to stop shadowing the
        # builtin.)
        self.assertEquals(expr, True)

    def _assert_prior(self, expected, actual):
        """Assert that *actual* (a parse_prior result) matches *expected*.

        Shapes must agree exactly; values are compared with numpy.allclose
        so that last-bit rounding differences in parse_prior's arithmetic
        do not cause spurious failures.
        """
        self.failIf(actual is None, "parse_prior() returned None")
        self.assertEquals(shape(expected), shape(actual))
        self.assertTrue(na.allclose(expected, actual))

    def test_parse_prior_none(self):
        self.assertEquals(None,
                          parse_prior(None, unambiguous_protein_alphabet))
        self.assertEquals(None,
                          parse_prior('none', unambiguous_protein_alphabet))
        self.assertEquals(None,
                          parse_prior('noNe', None))

    def test_parse_prior_equiprobable(self):
        self._assert_prior(
            20. * equiprobable_distribution(20),
            parse_prior('equiprobable', unambiguous_protein_alphabet))

        # Keyword matching ignores case and surrounding whitespace; the
        # third argument scales the distribution.
        self._assert_prior(
            1.2 * equiprobable_distribution(3),
            parse_prior(' equiprobablE ', Alphabet('123'), 1.2))

    def test_parse_prior_percentage(self):
        self._assert_prior(
            equiprobable_distribution(4),
            parse_prior('50%', unambiguous_dna_alphabet, 1.))

        self._assert_prior(
            equiprobable_distribution(4),
            parse_prior(' 50.0 % ', unambiguous_dna_alphabet, 1.))

        self._assert_prior(
            array((0.3, 0.2, 0.2, 0.3), float64),
            parse_prior(' 40.0 % ', unambiguous_dna_alphabet, 1.))

    def test_parse_prior_float(self):
        self._assert_prior(
            equiprobable_distribution(4),
            parse_prior('0.5', unambiguous_dna_alphabet, 1.))

        self._assert_prior(
            equiprobable_distribution(4),
            parse_prior(' 0.500 ', unambiguous_dna_alphabet, 1.))

        self._assert_prior(
            array((0.3, 0.2, 0.2, 0.3), float64),
            parse_prior(' 0.40 ', unambiguous_dna_alphabet, 1.))

    def test_auto(self):
        self._assert_prior(
            4. * equiprobable_distribution(4),
            parse_prior('auto', unambiguous_dna_alphabet))
        self._assert_prior(
            4. * equiprobable_distribution(4),
            parse_prior('automatic', unambiguous_dna_alphabet))

    def test_weight(self):
        self._assert_prior(
            4. * equiprobable_distribution(4),
            parse_prior('automatic', unambiguous_dna_alphabet))
        self._assert_prior(
            123.123 * equiprobable_distribution(4),
            parse_prior('auto', unambiguous_dna_alphabet, 123.123))

    def test_explicit(self):
        s = "{'A':10, 'C':40, 'G':40, 'T':10}"
        p = array((10, 40, 40, 10), float64) * 4. / 100.
        self._assert_prior(p, parse_prior(s, unambiguous_dna_alphabet))
|
|
137
|
|
138
|
|
class test_logooptions(unittest.TestCase):
    """Tests for LogoOptions construction and repr."""

    def test_create(self):
        opt = LogoOptions()
        opt.small_fontsize = 10
        # repr() must not raise after an attribute has been assigned.
        repr(opt)

        # A keyword passed to the constructor is readable as an attribute.
        # (TestCase assertion instead of a bare ``assert``, which is removed
        # when Python runs with -O.)
        opt = LogoOptions(title="sometitle")
        self.assertEquals(opt.title, "sometitle")
|
|
147
|
|
class test_logosize(unittest.TestCase):
    """Tests for LogoSize construction and repr."""

    def test_create(self):
        s = LogoSize(101.0, 10.0)
        # TestCase assertion instead of a bare ``assert``, which is removed
        # when Python runs with -O.
        self.assertEquals(s.stack_width, 101.0)
        # repr() must not raise.
        repr(s)
|
|
153
|
|
154
|
|
class test_seqlogo(unittest.TestCase):
    """Run the ./weblogo script as a subprocess and inspect its output."""

    # FIXME: The version of python used by Popen may not be the
    # same as that used to run this test.
    def _exec(self, args, outputtext, returncode=0, stdin=None):
        """Run ./weblogo with *args* and check its exit status and stdout.

        args       -- list of command-line arguments (program name is prepended).
        outputtext -- sequence of substrings that must each occur in stdout.
                      Pass a list: a bare string would be iterated character
                      by character.
        returncode -- expected exit status; when 0, stderr must also be empty.
        stdin      -- file object fed to the process.  Defaults to the
                      tests/data/cap.fa stream, which is opened here and
                      closed again once the process has finished.
        """
        opened_here = False
        if not stdin:
            stdin = testdata_stream("cap.fa")
            opened_here = True
        try:
            p = Popen(["./weblogo"] + args,
                      stdin=stdin, stdout=PIPE, stderr=PIPE)
            (out, err) = p.communicate()
        finally:
            # Only close streams this helper opened itself.
            if opened_here:
                stdin.close()

        if returncode == 0 and p.returncode != 0:
            # Show the script's error output so an unexpected failure
            # (including death by signal, i.e. a negative code) is diagnosable.
            print(err)
        self.assertEquals(returncode, p.returncode)
        if returncode == 0:
            self.assertEquals(len(err), 0)

        for item in outputtext:
            self.failUnless(item in out,
                            "%r not found in weblogo output" % (item,))

    def test_malformed_options(self):
        self._exec(["--notarealoption"], [], 2)
        self._exec(["extrajunk"], [], 2)
        self._exec(["-I"], [], 2)

    def test_help_option(self):
        self._exec(["-h"], ["options"])
        self._exec(["--help"], ["options"])

    def test_version_option(self):
        # Wrapped in a list so the whole version string must appear, rather
        # than each of its characters individually.
        self._exec(['--version'], [weblogolib.__version__])

    def test_default_build(self):
        self._exec([], ["%%Title: Sequence Logo:"])

    # Format options
    def test_width(self):
        self._exec(['-W', '1234'], ["/stack_width 1234"])
        self._exec(['--stack-width', '1234'], ["/stack_width 1234"])

    def test_height(self):
        self._exec(['-H', '1234'], ["/stack_height 1234"])
        self._exec(['--stack-height', '1234'], ["/stack_height 1234"])

    def test_stacks_per_line(self):
        self._exec(['-n', '7'], ["/stacks_per_line 7 def"])
        self._exec(['--stacks-per-line', '7'], ["/stacks_per_line 7 def"])

    def test_title(self):
        self._exec(['-t', '3456'], ['/logo_title (3456) def',
                                    '/show_title True def'])
        self._exec(['-t', ''], ['/logo_title () def',
                                '/show_title False def'])
        self._exec(['--title', '3456'], ['/logo_title (3456) def',
                                         '/show_title True def'])
|
|
213
|
|
214
|
|
215
|
|
216
|
|
217
|
|
class test_which(unittest.TestCase):
    """Tests for which_alphabet() on the cap.fa / capu.fa test data."""

    def test_which(self):
        # (parsed sequences, expected alphabet) pairs.  Disabled cases kept
        # for reference:
        #   cox2.msf      -> unambiguous_protein_alphabet
        #   Rv3829c.fasta -> unambiguous_protein_alphabet
        cases = [
            (seq_io.read(testdata_stream('cap.fa')), codon_alphabetT),
            (seq_io.read(testdata_stream('capu.fa')), codon_alphabetU),
        ]
        for seqs, expected in cases:
            self.failUnlessEqual(which_alphabet(seqs), expected)
|
|
229
|
|
230
|
|
231
|
|
232
|
|
class test_colorscheme(unittest.TestCase):
    """Tests for ColorGroup and ColorScheme."""

    def test_colorgroup(self):
        group = ColorGroup("ABC", "black", "Because")
        self.assertEquals(group.description, "Because")

    def test_colorscheme(self):
        groups = [
            ColorGroup("G", "orange"),
            ColorGroup("TU", "red"),
            ColorGroup("C", "blue"),
            ColorGroup("A", "green"),
        ]
        scheme = ColorScheme(groups, title="title", description="description")

        # A symbol listed in a group gets that group's colour ...
        self.assertEquals(scheme.color('A'), Color.by_name("green"))
        # ... and an unlisted symbol gets the scheme's default colour.
        self.assertEquals(scheme.color('X'), scheme.default_color)
|
|
252
|
|
253
|
|
254
|
|
class test_color(unittest.TestCase):
    """Tests for Color: named colours, RGB/HSL construction, string parsing,
    component clipping and equality."""

    # Aliases kept for Python 2.3 compatibility.
    assertTrue = unittest.TestCase.failUnless
    assertFalse = unittest.TestCase.failIf

    def test_color_names(self):
        names = Color.names()
        self.failUnlessEqual(len(names), 147)

        for n in names:
            # Identity test: ``!= None`` would go through Color's own
            # comparison methods instead of checking for the None object.
            self.assertTrue(Color.by_name(n) is not None)

    def test_color_components(self):
        white = Color.by_name("white")
        self.failUnlessEqual(1.0, white.red)
        self.failUnlessEqual(1.0, white.green)
        self.failUnlessEqual(1.0, white.blue)

        # Float arguments are taken as-is ...
        c = Color(0.3, 0.4, 0.2)
        self.failUnlessEqual(0.3, c.red)
        self.failUnlessEqual(0.4, c.green)
        self.failUnlessEqual(0.2, c.blue)

        # ... while integer arguments are scaled from the 0-255 range.
        c = Color(0, 128, 0)
        self.failUnlessEqual(0.0, c.red)
        self.failUnlessEqual(128. / 255., c.green)
        self.failUnlessEqual(0.0, c.blue)

    def test_color_from_rgb(self):
        white = Color.by_name("white")

        self.failUnlessEqual(white, Color(1., 1., 1.))
        self.failUnlessEqual(white, Color(255, 255, 255))
        self.failUnlessEqual(white, Color.from_rgb(1., 1., 1.))
        self.failUnlessEqual(white, Color.from_rgb(255, 255, 255))

    def test_color_from_hsl(self):
        red = Color.by_name("red")
        lime = Color.by_name("lime")
        saddlebrown = Color.by_name("saddlebrown")
        darkgreen = Color.by_name("darkgreen")
        blue = Color.by_name("blue")

        self.failUnlessEqual(red, Color.from_hsl(0, 1.0, 0.5))
        self.failUnlessEqual(lime, Color.from_hsl(120, 1.0, 0.5))
        self.failUnlessEqual(blue, Color.from_hsl(240, 1.0, 0.5))
        self.failUnlessEqual(Color.by_name("gray"), Color.from_hsl(0, 0, 0.5))

        self.failUnlessEqual(saddlebrown, Color.from_hsl(25, 0.76, 0.31))

        self.failUnlessEqual(darkgreen, Color.from_hsl(120, 1.0, 0.197))

    def test_color_by_name(self):
        white = Color.by_name("white")
        # Lookup ignores case and surrounding whitespace.
        self.failUnlessEqual(white, Color.by_name("white"))
        self.failUnlessEqual(white, Color.by_name("WHITE"))
        self.failUnlessEqual(white, Color.by_name(" wHiTe \t\n\t"))

        self.failUnlessEqual(Color(255, 255, 240), Color.by_name("ivory"))
        self.failUnlessEqual(Color(70, 130, 180), Color.by_name("steelblue"))

        self.failUnlessEqual(Color(0, 128, 0), Color.by_name("green"))

    def test_color_from_invalid_name(self):
        self.failUnlessRaises(ValueError, Color.by_name, "not_a_color")

    def test_color_clipping(self):
        # Out-of-range components are clipped into range.
        red = Color.by_name("red")
        self.failUnlessEqual(red, Color(255, 0, 0))
        self.failUnlessEqual(red, Color(260, -10, 0))
        self.failUnlessEqual(red, Color(1.1, -0., -1.))

        self.failUnlessEqual(Color(1.0001, 213.0, 1.2).red, 1.0)
        self.failUnlessEqual(Color(-0.001, -2183.0, -1.0).red, 0.0)
        self.failUnlessEqual(Color(1.0001, 213.0, 1.2).green, 1.0)
        self.failUnlessEqual(Color(-0.001, -2183.0, -1.0).green, 0.0)
        self.failUnlessEqual(Color(1.0001, 213.0, 1.2).blue, 1.0)
        self.failUnlessEqual(Color(-0.001, -2183.0, -1.0).blue, 0.0)

    def test_color_fail_on_mixed_type(self):
        # from_rgb rejects a mixture of int and float components.
        self.failUnlessRaises(TypeError, Color.from_rgb, 1, 1, 1.0)
        self.failUnlessRaises(TypeError, Color.from_rgb, 1.0, 1, 1.0)

    def test_color_red(self):
        # Check Usage comment in Color
        red = Color.by_name("red")
        self.failUnlessEqual(red, Color(255, 0, 0))
        self.failUnlessEqual(red, Color(1., 0., 0.))

        self.failUnlessEqual(red, Color.from_rgb(1., 0., 0.))
        self.failUnlessEqual(red, Color.from_rgb(255, 0, 0))
        self.failUnlessEqual(red, Color.from_hsl(0., 1., 0.5))

        self.failUnlessEqual(red, Color.from_string("red"))
        self.failUnlessEqual(red, Color.from_string("RED"))
        self.failUnlessEqual(red, Color.from_string("#F00"))
        self.failUnlessEqual(red, Color.from_string("#FF0000"))
        self.failUnlessEqual(red, Color.from_string("rgb(255, 0, 0)"))
        self.failUnlessEqual(red, Color.from_string("rgb(100%, 0%, 0%)"))
        self.failUnlessEqual(red, Color.from_string("hsl(0, 100%, 50%)"))

    def test_color_from_string(self):
        red = Color(255, 0, 0)
        skyblue = Color(135, 206, 235)

        red_strings = ("red",
                       "ReD",
                       "RED",
                       "   Red \t",
                       "#F00",
                       "#FF0000",
                       "rgb(255, 0, 0)",
                       "rgb(100%, 0%, 0%)",
                       "hsl(0, 100%, 50%)")

        for s in red_strings:
            self.failUnlessEqual(red, Color.from_string(s))

        skyblue_strings = ("skyblue",
                           "SKYBLUE",
                           "  \t\n SkyBlue  \t",
                           "#87ceeb",
                           "rgb(135,206,235)")

        for s in skyblue_strings:
            self.failUnlessEqual(skyblue, Color.from_string(s))

    def test_color_equality(self):
        c1 = Color(123, 99, 12)
        c2 = Color(123, 99, 12)

        self.failUnlessEqual(c1, c2)
|
|
403
|
|
404
|
|
405
|
|
406
|
|
407
|
|
408
|
|
class test_Dirichlet(unittest.TestCase):
    """Tests for Dirichlet: analytic moments checked directly and against
    random samples."""

    # Aliases kept for Python 2.3 compatibility.
    assertTrue = unittest.TestCase.failUnless
    assertFalse = unittest.TestCase.failIf

    def test_init(self):
        # Passes as long as construction from a tuple does not raise.
        d = Dirichlet((1, 1, 1, 1,))

    def test_random(self):

        def do_test(alpha, samples=1000):
            """Compare sampled entropy moments with the analytic ones.

            Draws *samples* vectors from Dirichlet(alpha), takes the entropy
            of each, and compares the sample mean and variance with
            d.mean_entropy() and d.variance_entropy().
            """
            ent = zeros((samples,), float64)
            d = Dirichlet(alpha)
            for s in range(samples):
                ent[s] = entropy(d.sample())

            m = mean(ent)
            v = var(ent)

            dm = d.mean_entropy()
            dv = d.variance_entropy()

            # Four standard errors of the sample mean.  The same tolerance
            # is reused for the variance, which is only a rough bound there.
            error = 4. * sqrt(v / samples)
            self.assertTrue(abs(m - dm) < error)
            self.assertTrue(abs(v - dv) < error)  # dodgy error estimate

        do_test((1., 1.))
        do_test((2., 1.))
        do_test((3., 1.))
        do_test((4., 1.))
        do_test((5., 1.))
        do_test((6., 1.))

        do_test((20., 20.))
        do_test((1., 1., 1., 1., 1., 1., 1., 1., 1., 1.))
        do_test((.1, .1, .1, .1, .1, .1, .1, .1, .1, .1))
        do_test((.01, .01, .01, .01, .01, .01, .01, .01, .01, .01))
        do_test((2.0, 6.0, 1.0, 1.0))

    def test_mean(self):
        alpha = ones((10,), float64) * 23.
        d = Dirichlet(alpha)
        m = d.mean()
        self.assertAlmostEqual(m[2], 1. / 10)
        self.assertAlmostEqual(sum(m), 1.0)

    def test_covariance(self):
        alpha = ones((4,), float64)
        d = Dirichlet(alpha)
        cv = d.covariance()
        self.assertEqual(cv.shape, (4, 4))
        self.assertAlmostEqual(cv[0, 0], 1.0 * (1.0 - 1. / 4.0) / (4.0 * 5.0))
        self.assertAlmostEqual(cv[0, 1], - 1 / (4. * 4. * 5.))

    def test_mean_x(self):
        alpha = (1.0, 2.0, 3.0, 4.0)
        xx = (2.0, 2.0, 2.0, 2.0)
        m = Dirichlet(alpha).mean_x(xx)
        self.assertEquals(m, 2.0)

        alpha = (1.0, 1.0, 1.0, 1.0)
        xx = (2.0, 3.0, 4.0, 3.0)
        m = Dirichlet(alpha).mean_x(xx)
        self.assertEquals(m, 3.0)

    def test_variance_x(self):
        alpha = (1.0, 1.0, 1.0, 1.0)
        xx = (2.0, 2.0, 2.0, 2.0)
        v = Dirichlet(alpha).variance_x(xx)
        self.assertAlmostEquals(v, 0.0)

        # Expected value from the Dirichlet covariance used in
        # test_covariance.  With a0 = sum(alpha):
        #   Var(sum_i x_i p_i)
        #     = (a0 * sum_i alpha_i x_i**2 - (sum_i alpha_i x_i)**2)
        #       / (a0**2 * (a0 + 1))
        # Here a0 = 10, sum alpha*x**2 = 407, sum alpha*x = 45, so the
        # variance is (4070 - 2025) / 1100 = 2045 / 1100.
        alpha = (1.0, 2.0, 3.0, 4.0)
        xx = (2.0, 0.0, 1.0, 10.0)
        v = Dirichlet(alpha).variance_x(xx)
        self.assertAlmostEquals(v, 2045. / 1100.)

    def test_relative_entropy(self):
        alpha = (2.0, 10.0, 1.0, 1.0)
        d = Dirichlet(alpha)
        pvec = (0.1, 0.2, 0.3, 0.4)

        rent = d.mean_relative_entropy(pvec)
        vrent = d.variance_relative_entropy(pvec)
        low, high = d.interval_relative_entropy(pvec, 0.95)

        # This test can fail randomly: the precision obtainable from a few
        # thousand samples is low, so the sample count was raised from
        # 1000 to 2000.
        samples = 2000
        sent = zeros((samples,), float64)

        for s in range(samples):
            post = d.sample()
            # -entropy(post) - sum_k post[k] * log(pvec[k]); assuming
            # entropy() uses the natural log like math.log, this is the
            # relative entropy of post with respect to pvec.
            e = -entropy(post)
            for k in range(4):
                e += - post[k] * log(pvec[k])
            sent[s] = e
        sent.sort()
        self.assertTrue(abs(sent.mean() - rent) < 4. * sqrt(vrent))
        self.assertAlmostEqual(sent.std(), sqrt(vrent), 1)
        # Empirical 2.5% / 97.5% quantiles vs. the analytic 95% interval.
        self.assertTrue(abs(low - sent[int(samples * 0.025)]) < 0.2)
        self.assertTrue(abs(high - sent[int(samples * 0.975)]) < 0.2)
|
|
531
|
|
532
|
|
533
|
|
def mean(a):
    """Return the arithmetic mean of the values in *a*.

    *a* may be a list, tuple or 1-D numpy array of numbers.  The values are
    converted to float64 first, so integer input is not truncated by
    Python 2 integer division.

    Raises ValueError if *a* is empty.
    """
    values = na.asarray(a, float64)
    if len(values) == 0:
        raise ValueError("mean() of an empty sequence")
    return values.sum() / len(values)
|
|
536
|
|
def var(a):
    """Return the population variance (divisor len(a)) of the values in *a*.

    *a* may be a list, tuple or 1-D numpy array of numbers.  Uses the
    two-pass formula mean((a - mean(a))**2).  The one-pass
    E[a**2] - E[a]**2 loses precision to cancellation when the mean is large
    relative to the spread, and can even return a small negative number.

    Raises ValueError if *a* is empty.
    """
    values = na.asarray(a, float64)
    n = len(values)
    if n == 0:
        raise ValueError("var() of an empty sequence")
    deviation = values - values.sum() / n
    return (deviation * deviation).sum() / n
|
|
539
|
|
540
|
|
541
|
|
542
|
|
if __name__ == '__main__':
    # Executed as a script: let unittest discover and run the TestCase
    # classes defined in this module.
    unittest.main()
|