comparison env/lib/python3.9/site-packages/networkx/utils/tests/test_decorators.py @ 0:4f3585e2f14b draft default tip

"planemo upload commit 60cee0fc7c0cda8592644e1aad72851dec82c959"
author shellac
date Mon, 22 Mar 2021 18:12:50 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:4f3585e2f14b
1 import tempfile
2 import os
3 import pathlib
4 import random
5
6 import pytest
7
8 import networkx as nx
9 from networkx.utils.decorators import open_file, not_implemented_for
10 from networkx.utils.decorators import (
11 preserve_random_state,
12 py_random_state,
13 np_random_state,
14 random_state,
15 )
16 from networkx.utils.misc import PythonRandomInterface
17
18
def test_not_implemented_decorator():
    """An undirected ``nx.Graph`` passes a ``not_implemented_for("directed")`` guard."""

    @not_implemented_for("directed")
    def _accepts_undirected(graph):
        return None

    undirected = nx.Graph()
    _accepts_undirected(undirected)
25
26
def test_not_implemented_decorator_key():
    """An unrecognised graph-type name given to ``not_implemented_for`` raises ``KeyError``."""
    with pytest.raises(KeyError):
        # Decoration and call both live inside the ``raises`` block, so the
        # error is caught whether it is raised when decorating or when calling.
        @not_implemented_for("foo")
        def _never_runs(graph):
            return None

        _never_runs(nx.Graph())
35
36
def test_not_implemented_decorator_raise():
    """A plain ``nx.Graph`` trips a ``not_implemented_for("graph")`` guard."""
    with pytest.raises(nx.NetworkXNotImplemented):

        @not_implemented_for("graph")
        def _rejects_graph(graph):
            return None

        plain_graph = nx.Graph()
        _rejects_graph(plain_graph)
45
46
class TestOpenFileDecorator:
    """Tests for ``open_file`` with path strings, file objects and pathlib paths.

    ``setup_method`` creates a fresh temporary file for every test. Its name is
    kept in ``self.name`` and an open binary handle to it in ``self.fobj``.
    """

    def setup_method(self):
        self.text = ["Blah... ", "BLAH ", "BLAH!!!!"]
        # delete=False lets the file be reopened by name while ``self.fobj``
        # is still open; teardown_method removes it explicitly.
        self.fobj = tempfile.NamedTemporaryFile("wb+", delete=False)
        self.name = self.fobj.name

    def teardown_method(self):
        self.fobj.close()
        os.unlink(self.name)

    def write(self, path):
        """Write each string of ``self.text``, ASCII-encoded, to file object *path*."""
        for text in self.text:
            path.write(text.encode("ascii"))

    @open_file(1, "r")
    def read(self, path):
        """Return the first line of *path*, which ``open_file`` opens in text mode."""
        return path.readlines()[0]

    @staticmethod
    @open_file(0, "wb")
    def writer_arg0(path):
        path.write(b"demo")

    @open_file(1, "wb+")
    def writer_arg1(self, path):
        self.write(path)

    @open_file(2, "wb")
    def writer_arg2default(self, x, path=None):
        if path is None:
            # No path supplied: write to a throwaway temporary file instead.
            with tempfile.NamedTemporaryFile("wb+") as fh:
                self.write(fh)
        else:
            self.write(path)

    @open_file(4, "wb")
    def writer_arg4default(self, x, y, other="hello", path=None, **kwargs):
        if path is None:
            with tempfile.NamedTemporaryFile("wb+") as fh:
                self.write(fh)
        else:
            self.write(path)

    @open_file("path", "wb")
    def writer_kwarg(self, **kwargs):
        path = kwargs.get("path", None)
        if path is None:
            with tempfile.NamedTemporaryFile("wb+") as fh:
                self.write(fh)
        else:
            self.write(path)

    def test_writer_arg0_str(self):
        self.writer_arg0(self.name)
        assert self.read(self.name) == "demo"

    def test_writer_arg0_fobj(self):
        self.writer_arg0(self.fobj)
        # A file object supplied by the caller must not be closed for it.
        assert not self.fobj.closed
        self.fobj.close()
        assert self.read(self.name) == "demo"

    def test_writer_arg0_pathlib(self):
        self.writer_arg0(pathlib.Path(self.name))
        assert self.read(self.name) == "demo"

    def test_writer_arg1_str(self):
        self.writer_arg1(self.name)
        assert self.read(self.name) == "".join(self.text)

    def test_writer_arg1_fobj(self):
        self.writer_arg1(self.fobj)
        assert not self.fobj.closed
        self.fobj.close()
        assert self.read(self.name) == "".join(self.text)

    def test_writer_arg2default_str(self):
        self.writer_arg2default(0, path=None)
        self.writer_arg2default(0, path=self.name)
        assert self.read(self.name) == "".join(self.text)

    def test_writer_arg2default_fobj(self):
        self.writer_arg2default(0, path=self.fobj)
        assert not self.fobj.closed
        self.fobj.close()
        assert self.read(self.name) == "".join(self.text)

    def test_writer_arg2default_fobj_path_none(self):
        self.writer_arg2default(0, path=None)

    def test_writer_arg4default_fobj(self):
        # NOTE(review): despite the name, this passes a path string (and None),
        # together with an extra keyword that lands in **kwargs.
        self.writer_arg4default(0, 1, dog="dog", other="other")
        self.writer_arg4default(0, 1, dog="dog", other="other", path=self.name)
        assert self.read(self.name) == "".join(self.text)

    def test_writer_kwarg_str(self):
        self.writer_kwarg(path=self.name)
        assert self.read(self.name) == "".join(self.text)

    def test_writer_kwarg_fobj(self):
        self.writer_kwarg(path=self.fobj)
        # Same contract as the positional fobj tests: the caller's handle stays open.
        assert not self.fobj.closed
        self.fobj.close()
        assert self.read(self.name) == "".join(self.text)

    def test_writer_kwarg_path_none(self):
        self.writer_kwarg(path=None)
148
149
@preserve_random_state
def test_preserve_random_state():
    """Inside ``preserve_random_state`` numpy's global RNG yields a pinned first draw.

    When numpy is unavailable the test is reported as skipped rather than
    silently passing.
    """
    numpy = pytest.importorskip("numpy")
    r = numpy.random.random()
    # The decorator is expected to seed numpy's global RNG deterministically;
    # this constant is the first draw from that seed, as pinned by the test.
    assert abs(r - 0.61879477158568) < 1e-16
159
160
class TestRandomState:
    """Check the numpy and Python random-state decorators for each seed type."""

    @classmethod
    def setup_class(cls):
        # Bind ``np`` at module level so the helper methods below can use it;
        # importorskip skips the whole class when numpy is missing.
        global np
        np = pytest.importorskip("numpy")

    @random_state(1)
    def instantiate_random_state(self, random_state):
        assert isinstance(random_state, np.random.RandomState)
        return random_state.random_sample()

    @np_random_state(1)
    def instantiate_np_random_state(self, random_state):
        assert isinstance(random_state, np.random.RandomState)
        return random_state.random_sample()

    @py_random_state(1)
    def instantiate_py_random_state(self, random_state):
        assert isinstance(random_state, (random.Random, PythonRandomInterface))
        return random_state.random()

    def test_random_state_None(self):
        np.random.seed(42)
        rv = np.random.random_sample()
        np.random.seed(42)
        assert rv == self.instantiate_random_state(None)
        np.random.seed(42)
        assert rv == self.instantiate_np_random_state(None)

        random.seed(42)
        rv = random.random()
        random.seed(42)
        assert rv == self.instantiate_py_random_state(None)

    def test_random_state_np_random(self):
        np.random.seed(42)
        rv = np.random.random_sample()
        np.random.seed(42)
        assert rv == self.instantiate_random_state(np.random)
        np.random.seed(42)
        assert rv == self.instantiate_np_random_state(np.random)
        np.random.seed(42)
        assert rv == self.instantiate_py_random_state(np.random)

    def test_random_state_int(self):
        np.random.seed(42)
        np_rv = np.random.random_sample()
        random.seed(42)
        py_rv = random.random()

        np.random.seed(42)
        seed = 1
        rval = self.instantiate_random_state(seed)
        rval_expected = np.random.RandomState(seed).rand()
        assert rval == rval_expected

        rval = self.instantiate_np_random_state(seed)
        rval_expected = np.random.RandomState(seed).rand()
        assert rval == rval_expected
        # test that global seed wasn't changed in function
        assert np_rv == np.random.random_sample()

        random.seed(42)
        rval = self.instantiate_py_random_state(seed)
        rval_expected = random.Random(seed).random()
        assert rval == rval_expected
        # test that global seed wasn't changed in function
        assert py_rv == random.random()

    def test_random_state_np_random_RandomState(self):
        np.random.seed(42)
        np_rv = np.random.random_sample()

        np.random.seed(42)
        seed = 1
        # One shared generator. Each decorator must draw from the instance it
        # is given (not a copy or a reseeded one), so successive calls consume
        # successive values of a single RandomState(seed) stream.
        rng = np.random.RandomState(seed)
        expected = np.random.RandomState(seed).random_sample(size=3)

        assert self.instantiate_random_state(rng) == expected[0]
        assert self.instantiate_np_random_state(rng) == expected[1]
        assert self.instantiate_py_random_state(rng) == expected[2]
        # test that global seed wasn't changed in function
        assert np_rv == np.random.random_sample()

    def test_random_state_py_random(self):
        seed = 1
        rng = random.Random(seed)
        rv = self.instantiate_py_random_state(rng)
        assert rv == random.Random(seed).random()

        # A stdlib random.Random is not accepted by the numpy-state decorators.
        with pytest.raises(ValueError):
            self.instantiate_random_state(rng)
        with pytest.raises(ValueError):
            self.instantiate_np_random_state(rng)
261
262
def test_random_state_string_arg_index():
    """A non-integer ``random_state`` index makes the decorated call raise ``NetworkXError``."""
    with pytest.raises(nx.NetworkXError):

        @random_state("a")
        def make_random_state(rs):
            pass

        make_random_state(1)
271
272
def test_py_random_state_string_arg_index():
    """A non-integer ``py_random_state`` index makes the decorated call raise ``NetworkXError``."""
    with pytest.raises(nx.NetworkXError):

        @py_random_state("a")
        def make_random_state(rs):
            pass

        make_random_state(1)
281
282
def test_random_state_invalid_arg_index():
    """A ``random_state`` index past the function's single argument raises ``NetworkXError``."""
    with pytest.raises(nx.NetworkXError):

        @random_state(2)
        def make_random_state(rs):
            pass

        make_random_state(1)
291
292
def test_py_random_state_invalid_arg_index():
    """A ``py_random_state`` index past the function's single argument raises ``NetworkXError``."""
    with pytest.raises(nx.NetworkXError):

        @py_random_state(2)
        def make_random_state(rs):
            pass

        make_random_state(1)