comparison tools/myTools/bin/sfa/analysis/random/base.py @ 1:7e5c71b2e71f draft default tip

Uploaded
author laurenmarazzi
date Wed, 22 Dec 2021 16:00:34 +0000
parents
children
comparison
equal deleted inserted replaced
0:f24d4892aaed 1:7e5c71b2e71f
1 # -*- coding: utf-8 -*-
2
3 from multiprocessing import Pool
4
5 import numpy as np
6 import pandas as pd
7
8 import sfa
9
10
class BaseRandomSimulator(object):
    """Shared state and helpers for random-weight simulators.

    Subclasses provide ``_randomize`` and the public ``simulate_single`` /
    ``simulate_multiple`` entry points; the versions here only raise
    ``NotImplementedError``.
    """

    def __init__(self):
        # Make NumPy raise FloatingPointError instead of emitting a
        # RuntimeWarning, so numerically invalid random weights can be
        # caught and skipped by subclasses.
        # NOTE: np.seterr changes NumPy's shared error state, not state
        # owned by this instance.
        np.seterr(all='raise')

    def _initialize(self, alg):
        """Cache arrays derived from the network matrix ``alg.data.A``.

        Sets ``_S`` (element-wise sign of ``A``), ``_ir``/``_ic`` (row and
        column indices of its non-zero entries), ``_num_links`` (number of
        non-zero entries), ``_A`` (``A`` as a NumPy array) and ``_W`` (a
        float zero matrix shaped like ``A``; presumably filled in by a
        subclass's ``_randomize`` -- confirm against subclasses).
        """
        A = alg.data.A
        self._S = np.sign(A)  # Sign matrix
        self._ir, self._ic = A.nonzero()
        self._num_links = self._ir.size
        self._A = np.array(A)
        # ``np.float`` was merely an alias of the builtin ``float`` and was
        # removed in NumPy 1.24; ``float`` yields the same float64 dtype.
        self._W = np.zeros_like(self._A, dtype=float)

    def _randomize(self):
        """Draw a new random weight configuration (subclass hook)."""
        raise NotImplementedError()

    def _apply_norm(self, alg, use_norm):
        """Publish the current weight matrix to ``alg.W``.

        Args:
            alg: algorithm object whose ``W`` attribute is overwritten.
            use_norm: if true, assign ``sfa.normalize(self._W)``; otherwise
                assign ``self._W`` itself (the same array, not a copy).
        """
        if use_norm:
            alg.W = sfa.normalize(self._W)
        else:
            alg.W = self._W

    def _need_to_print(self, use_print, freq_print, cnt):
        """Return True when printing is enabled and ``cnt`` is a multiple
        of ``freq_print``."""
        return use_print and (cnt % freq_print == 0)

    def simulate_single(self, *args, **kwargs):
        """Simulate one dataset; must be implemented by subclasses."""
        raise NotImplementedError()

    def simulate_multiple(self, *args, **kwargs):
        """Simulate several datasets; must be implemented by subclasses."""
        raise NotImplementedError()
# end of class
151
152
class BaseRandomBatchSimulator(BaseRandomSimulator):
    """Random-weight simulator scored via ``alg.compute_batch()``.

    Subclasses must implement ``_randomize``. For every dataset the weights
    are redrawn until ``num_samp`` non-zero accuracy values are collected.
    """

    def __init__(self):
        super().__init__()

    def _randomize(self):
        raise NotImplementedError()

    def _simulate_single(self, args):
        """Collect ``num_samp`` accuracy samples for a single dataset.

        Args:
            args: tuple ``(num_samp, alg, data, use_norm, use_print,
                freq_print)``. The arguments are packed into one tuple so
                this bound method can be handed directly to ``Pool.map``.

        Returns:
            pandas.DataFrame with a single column named ``data.abbr`` and
            ``num_samp`` rows indexed ``1..num_samp``.

        Draws whose evaluation raises ``FloatingPointError`` or
        ``RuntimeWarning``, or whose accuracy is exactly 0, are discarded
        and redrawn. NOTE(review): if every draw is discarded this loop
        never terminates.
        """
        num_samp, alg, data, use_norm, use_print, freq_print = args

        alg.data = data
        alg.initialize(network=False)

        results = np.zeros((num_samp,), dtype=float)
        cnt = 0

        if self._need_to_print(use_print, freq_print, cnt):
            print("%s simulation for %s starts..." % (alg.abbr, data.abbr))

        # Re-assert NumPy's 'raise' mode locally: worker processes created
        # with the 'spawn' start method do not inherit the np.seterr() call
        # made in __init__, so FloatingPointError would never fire there.
        with np.errstate(all='raise'):
            while cnt < num_samp:
                self._randomize()
                self._apply_norm(alg, use_norm)
                try:
                    alg.compute_batch()
                    # Use the ``alg`` configured above instead of relying on
                    # ``self._alg`` being the very same object.
                    acc = sfa.calc_accuracy(alg.result.df_sim,
                                            alg.data.df_exp)
                except (FloatingPointError, RuntimeWarning) as err:
                    # Numerically invalid weights: skip them and redraw.
                    if self._need_to_print(use_print, freq_print, cnt):
                        print("%s: skipped..." % (err))
                    continue

                # Skip these weights assuming acc cannot be exactly 0.
                if acc == 0:
                    if self._need_to_print(use_print, freq_print, cnt):
                        print("Zero accuracy: skipped...")
                    continue

                results[cnt] = acc
                cnt += 1
                if self._need_to_print(use_print, freq_print, cnt):
                    print("[Iteration #%d] acc: %f" % (cnt, acc))

        df = pd.DataFrame(results)
        df.index = range(1, num_samp + 1)
        df.columns = [data.abbr]
        return df

    def simulate_single(self, num_samp, alg, data, use_norm=False,
                        use_print=False, freq_print=100):
        """Run the random-weight simulation for one dataset.

        Args:
            num_samp: number of accepted accuracy samples to collect.
            alg: algorithm object; its ``data`` is set to ``data`` and it is
                (re)initialized before sampling.
            data: dataset object (must expose ``abbr``).
            use_norm: pass weights through ``sfa.normalize`` first.
            use_print: enable progress printing.
            freq_print: print every ``freq_print`` accepted samples.

        Returns:
            Single-column DataFrame of accuracies (see ``_simulate_single``).
        """
        self._alg = alg
        self._alg.data = data
        self._alg.initialize()
        self._initialize(alg)
        return self._simulate_single((num_samp, alg, data, use_norm,
                                      use_print, freq_print))

    def simulate_multiple(self, num_samp, alg, mdata, use_norm=False,
                          use_print=False, freq_print=100,
                          max_workers=1):
        """Run the random-weight simulation for several datasets.

        Args:
            mdata: list of dataset objects, or dict whose values are
                dataset objects; results keep that iteration order.
            max_workers: 1 runs sequentially; >1 uses a process pool.
            Other args: as in ``simulate_single``.

        Returns:
            DataFrame with one column per dataset (``pd.concat`` along
            ``axis=1`` of the per-dataset frames).

        Raises:
            TypeError: if ``mdata`` is neither a list nor a dict.
            ValueError: if ``max_workers`` is less than 1.
        """
        if isinstance(mdata, list):
            datasets = list(mdata)
        elif isinstance(mdata, dict):
            datasets = list(mdata.values())
        else:
            raise TypeError("mdata should be a list or a dict of data "
                            "objects, not %s." % type(mdata).__name__)

        self._alg = alg

        # Initialize network information only; presumably get_avalue picks
        # one representative dataset from mdata -- confirm in sfa.
        self._alg.data = sfa.get_avalue(mdata)
        self._alg.initialize()
        self._initialize(alg)

        tasks = [(num_samp, alg, data, use_norm, use_print, freq_print)
                 for data in datasets]

        if max_workers == 1:
            # _simulate_single takes ONE tuple argument (Pool.map style).
            dfs = [self._simulate_single(task) for task in tasks]
        elif max_workers > 1:
            # The context manager shuts the pool down even if map() raises;
            # after a successful map() every task has already finished.
            with Pool(processes=max_workers) as pool:
                dfs = pool.map(self._simulate_single, tasks)
        else:
            raise ValueError("max_workers should be a positive integer.")

        return pd.concat(dfs, axis=1)
# end of class