annotate tools/myTools/bin/sfa/control/influence.py @ 1:7e5c71b2e71f draft default tip

Uploaded
author laurenmarazzi
date Wed, 22 Dec 2021 16:00:34 +0000
```python
from collections import Counter

import numpy as np
import scipy as sp
import scipy.sparse         # make sp.sparse available
import scipy.sparse.linalg  # make sp.sparse.linalg.norm available
import pandas as pd


def compute_influence(W,
                      alpha=0.9,
                      beta=0.1,
                      S=None,
                      rtype='df',
                      outputs=None,
                      n2i=None,
                      max_iter=1000,
                      tol=1e-7,
                      get_iter=False,
                      device="cpu",
                      sparse=False):
    r"""Compute the influence.

    Estimate the effect of each node on the other nodes by computing
    the partial derivatives with respect to the source nodes, using a
    simple iterative method.

    The model is the difference equation

        x(t+1) = alpha*W.dot(x(t)) + (1-alpha)*b

    The influence matrix, S, follows from the chain rule of partial
    derivatives:

    \begin{align}
    S_{ij} &= \beta \left( I + \alpha W + \alpha^2 W^2 + \cdots \right)_{ij}
            = \beta \left[ (I - \alpha W)^{-1} \right]_{ij} \\
           &\approx \beta \left( I + \alpha W + \alpha^2 W^2 + \cdots
                                 + \alpha^{l} W^{l} \right)_{ij}
    \end{align}

    The series converges when the spectral radius of alpha*W is less
    than 1. With beta = 1 - alpha (as with the defaults), $S_{ij}$ equals
    the steady-state derivative $\partial x_i / \partial b_j$.

    This is the sum of the products of the weights along all paths,
    including cycles. $S_{ij}$ denotes the influence of node j on node i.

    The approximate solution is obtained by iterating

        S(t+1) = \alpha S(t) W + I,

    where $S(0) = I$ (or the given S), so that $S(1) = I + \alpha W$.
    The iteration stops when $||S(t+1) - S(t)||_F < tol$ or after
    max_iter iterations, and the result is then multiplied by $\beta$.

    Parameters
    ----------
    W : numpy.ndarray
        Weight matrix; W[i, j] is the effect of node j on node i.
    alpha : float, optional
        Hyperparameter for adjusting the effect of signal flow.
    beta : float, optional
        Hyperparameter for adjusting the effect of basal activity.
    S : numpy.ndarray, optional
        Initial influence matrix for the iteration (identity if None).
    rtype : str, optional
        Return object type: 'df' or 'array'.
    outputs : list (or iterable) of str, optional
        Names of the output nodes; required for the 'df' rtype.
    n2i : dict, optional
        Mapping from node name to index; required for the 'df' rtype.
    max_iter : int, optional
        Maximum number of iterations (at least 2).
    tol : float, optional
        Tolerance for terminating the iteration.
    get_iter : bool, optional
        Whether to also return the iteration count.
    device : str, optional, {'cpu', 'gpu:0', 'gpu:1', ...}
        Device to use (case-insensitive); 'cpu' is the default and
        'gpu' alone means 'gpu:0'. The GPU path requires CuPy.
    sparse : bool, optional
        Use SciPy sparse matrices for the computation (CPU only).

    Returns
    -------
    S : numpy.ndarray
        2D influence array, returned when rtype is 'array' (a SciPy
        sparse matrix if sparse is True, a CuPy array on the GPU path).
    df : pandas.DataFrame
        Returned when rtype is 'df': source nodes as rows, outputs as
        columns; entry (src, trg) is the influence of src on trg, and
        self-influence entries are set to inf.
    num_iter : int, optional
        Zero-based index of the last iteration; returned only if
        get_iter is True.
    """
    # TODO: Test rendering the above mathematical expressions in LaTeX form.

    if max_iter < 2:
        raise ValueError("max_iter should be at least 2.")

    device = device.lower()

    if 'cpu' in device:
        if sparse:
            ret = _compute_influence_cpu_sparse(W, alpha, beta, S,
                                                max_iter, tol, get_iter)
        else:
            ret = _compute_influence_cpu(W, alpha, beta, S,
                                         max_iter, tol, get_iter)
    elif 'gpu' in device:
        # 'gpu' selects device 0; 'gpu:N' selects device N.
        _, _, id_device = device.partition(':')
        id_device = int(id_device) if id_device else 0
        ret = _compute_influence_gpu(W, alpha, beta, S,
                                     max_iter, tol, get_iter, id_device)
    else:
        raise ValueError("Unknown device: %s" % (device))

    if get_iter:
        S_ret, num_iter = ret
    else:
        S_ret = ret

    if rtype == 'array':
        return ret
    elif rtype == 'df':
        if not outputs or not n2i:
            err_msg = "outputs and n2i should be designated for 'df' return type."
            raise ValueError(err_msg)

        if 'gpu' in device:
            # Copy the GPU result back to host memory for pandas.
            import cupy as cp
            S_ret = cp.asnumpy(S_ret)

        df = pd.DataFrame(columns=outputs)

        for trg in outputs:
            for src in n2i:
                if src == trg:
                    # Self-influence is reported as inf.
                    df.loc[src, trg] = np.inf
                    continue

                idx_src = n2i[src]
                idx_trg = n2i[trg]
                df.loc[src, trg] = S_ret[idx_trg, idx_src]

        if get_iter:
            return df, num_iter
        else:
            return df
    else:
        raise ValueError("Unknown return type: %s" % (rtype))
```
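The docstring's claim that beta*(I - alpha*W)^{-1} is the steady-state sensitivity of x to the basal activity b (when beta = 1 - alpha) can be checked numerically. This is a minimal sketch, not part of the module: `steady_state` and `influence_closed_form` are hypothetical helpers, the 3-node network is made up, and it assumes W[i, j] is the weight of the link from node j to node i and that the spectral radius of alpha*W is below 1.

```python
import numpy as np


def steady_state(W, b, alpha):
    """Fixed point of x(t+1) = alpha*W.dot(x(t)) + (1-alpha)*b."""
    n = W.shape[0]
    return np.linalg.solve(np.eye(n) - alpha * W, (1 - alpha) * b)


def influence_closed_form(W, alpha, beta):
    """Limit of the series beta*(I + alpha*W + alpha^2*W^2 + ...)."""
    n = W.shape[0]
    return beta * np.linalg.inv(np.eye(n) - alpha * W)


# Hypothetical 3-node chain 0 -> 1 -> 2 where the second link is inhibitory.
# Convention assumed here: W[i, j] is the weight of the link j -> i.
W = np.array([[0.0, 0.0, 0.0],
              [0.5, 0.0, 0.0],
              [0.0, -0.8, 0.0]])
alpha = 0.9
beta = 1 - alpha

S = influence_closed_form(W, alpha, beta)

# Perturb the basal activity of node 0 and measure the steady-state response.
eps = 1e-6
b = np.zeros(3)
db = np.array([eps, 0.0, 0.0])
dx = (steady_state(W, b + db, alpha) - steady_state(W, b, alpha)) / eps

print(np.round(S[:, 0], 4))      # influence of node 0 on nodes 0, 1, 2
print(np.allclose(dx, S[:, 0]))  # True
```

Because this W is strictly lower triangular (nilpotent), the series stops after the alpha^2 W^2 term, so the truncated sum with l = 2 is already exact here; node 0 ends up with a negative influence on node 2 through the inhibitory link.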
```python
def _compute_influence_cpu(W, alpha=0.5, beta=0.5, S=None,
                           max_iter=1000, tol=1e-6, get_iter=False):
    N = W.shape[0]
    if S is not None:
        # Copy, because S1 is updated in place below.
        S1 = np.array(S, dtype=np.float64)
    else:
        S1 = np.eye(N, dtype=np.float64)

    I = np.eye(N, dtype=np.float64)
    # A float buffer; np.zeros_like(W) would inherit an integer dtype
    # from an integer W and truncate the results.
    S2 = np.zeros((N, N), dtype=np.float64)
    aW = alpha * W
    for cnt in range(max_iter):
        S2[:, :] = S1.dot(aW) + I
        norm = np.linalg.norm(S2 - S1)  # Frobenius norm
        if norm < tol:
            break
        # end of if
        S1[:, :] = S2
    # end of for

    S_fin = beta * S2
    if get_iter:
        return S_fin, cnt
    else:
        return S_fin
```
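As a cross-check of the dense update in `_compute_influence_cpu`, the sketch below runs the same recurrence S(t+1) = S(t)(alpha*W) + I from S(0) = I and compares the scaled result with a direct solve of (I - alpha*W) S = beta*I. It is a standalone illustration: `influence_iterative` is a hypothetical name, the random weight matrix is made up, and the matrix is rescaled so that its maximum absolute row sum is 1, which (with alpha < 1) keeps the spectral radius of alpha*W below 1 so the iteration converges.

```python
import numpy as np


def influence_iterative(W, alpha, beta, max_iter=1000, tol=1e-7):
    """Iterate S <- S @ (alpha*W) + I from S = I; return (beta*S, n_updates)."""
    n = W.shape[0]
    I = np.eye(n)
    aW = alpha * np.asarray(W, dtype=np.float64)
    S1 = I.copy()
    for n_updates in range(1, max_iter + 1):
        S2 = S1 @ aW + I
        if np.linalg.norm(S2 - S1) < tol:  # Frobenius norm
            break
        S1 = S2
    return beta * S2, n_updates


# Hypothetical random signed network with about 20% of links present.
rng = np.random.default_rng(0)
n = 20
W = rng.uniform(-1.0, 1.0, size=(n, n)) * (rng.random((n, n)) < 0.2)
W /= np.abs(W).sum(axis=1).max()  # max absolute row sum becomes 1

alpha, beta = 0.9, 0.1
rho = np.max(np.abs(np.linalg.eigvals(alpha * W)))  # spectral radius of alpha*W

S_iter, n_updates = influence_iterative(W, alpha, beta)
S_direct = np.linalg.solve(np.eye(n) - alpha * W, beta * np.eye(n))

print(rho < 1, n_updates, np.abs(S_iter - S_direct).max())
```

For small networks the direct solve is usually simpler; the iteration mainly helps when only a truncated series (paths up to a given length) is wanted or when matrix products are cheaper than factorization.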
```python
def _compute_influence_cpu_sparse(W, alpha, beta, S,
                                  max_iter, tol, get_iter):
    N = W.shape[0]
    if S is not None:
        S1 = sp.sparse.csr_matrix(S, dtype=np.float64)
    else:
        S1 = sp.sparse.identity(N, dtype=np.float64, format='csr')

    I = sp.sparse.identity(N, dtype=np.float64, format='csr')
    aW = sp.sparse.csc_matrix(alpha * W)
    for cnt in range(max_iter):
        # Rebind S1/S2 each step instead of assigning into LIL slices,
        # which is slow and would also modify a caller-supplied S.
        S2 = sp.sparse.csr_matrix(S1.dot(aW) + I)
        norm = sp.sparse.linalg.norm(S2 - S1)  # Frobenius norm
        if norm < tol:
            break
        # end of if
        S1 = S2
    # end of for

    S_fin = beta * S2
    if get_iter:
        return S_fin, cnt
    else:
        return S_fin
```
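The sparse recurrence can be tried in isolation with SciPy sparse matrices, rebinding the result at each step. This is an illustrative sketch: `influence_sparse` is a hypothetical helper, and the 50-node chain (weights only on the subdiagonal, W[i + 1, i] = 0.5) is a made-up network. Note that S can be considerably denser than W, since every pair of nodes connected by a path gets a nonzero entry, so sparse storage pays off mainly for large W with short-range influence.

```python
import numpy as np
import scipy.sparse
import scipy.sparse.linalg


def influence_sparse(W, alpha, beta, max_iter=1000, tol=1e-7):
    """Sparse S <- S @ (alpha*W) + I from S = I; returns (beta*S, n_updates)."""
    n = W.shape[0]
    I = scipy.sparse.identity(n, dtype=np.float64, format="csr")
    aW = scipy.sparse.csr_matrix(W, dtype=np.float64) * alpha
    S1 = I.copy()
    for n_updates in range(1, max_iter + 1):
        S2 = (S1 @ aW + I).tocsr()
        if scipy.sparse.linalg.norm(S2 - S1) < tol:  # Frobenius norm
            break
        S1 = S2
    return beta * S2, n_updates


# Hypothetical chain 0 -> 1 -> ... -> 49 with weight 0.5 on every link.
n = 50
W = scipy.sparse.diags(np.full(n - 1, 0.5), -1, shape=(n, n), format="csr")
alpha, beta = 0.9, 0.1

S, n_updates = influence_sparse(W, alpha, beta)
S_dense = beta * np.linalg.inv(np.eye(n) - alpha * W.toarray())

print(W.nnz, S.nnz, n_updates, np.abs(S.toarray() - S_dense).max())
```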
```python
def _compute_influence_gpu(W, alpha=0.5, beta=0.5, S=None,
                           max_iter=1000, tol=1e-6, get_iter=False,
                           id_device=0):
    import cupy as cp
    cp.cuda.Device(int(id_device)).use()
    N = W.shape[0]
    # All GPU arrays use single precision (float32).
    I = cp.eye(N, dtype=cp.float32)
    if S is not None:
        S1 = cp.array(S, dtype=cp.float32)
    else:
        S1 = cp.eye(N, dtype=cp.float32)

    S2 = cp.zeros((N, N), dtype=cp.float32)
    aW = alpha * cp.array(W, dtype=cp.float32)

    tol_gpu = cp.array(tol)

    for cnt in range(max_iter):
        S2[:, :] = cp.dot(S1, aW) + I
        mat_norm = cp.linalg.norm(S2 - S1)
        if mat_norm < tol_gpu:
            break
        # end of if
        S1[:, :] = S2
    # end of for

    S_fin = beta*S2
    if get_iter:
        return S_fin, cnt
    else:
        return S_fin
```
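The GPU path casts W and S to `cp.float32`. CuPy is not needed to see what single precision does to this recurrence: the CPU-only sketch below runs the same update with NumPy in float32 and in float64 on a made-up, row-normalized random network (`influence_in_dtype` is a hypothetical helper). The two results agree to roughly single-precision accuracy, but a `tol` close to float32 resolution (about 1e-7 for entries of order one) may never be met in float32, in which case the loop simply runs until `max_iter`.

```python
import numpy as np


def influence_in_dtype(W, alpha, beta, dtype, max_iter=1000, tol=1e-7):
    """Run S <- S @ (alpha*W) + I from S = I in the given floating dtype."""
    n = W.shape[0]
    I = np.eye(n, dtype=dtype)
    aW = (alpha * W).astype(dtype)
    S1 = I.copy()
    for n_updates in range(1, max_iter + 1):
        S2 = S1 @ aW + I
        if np.linalg.norm(S2 - S1) < tol:  # Frobenius norm
            break
        S1 = S2
    return (beta * S2).astype(dtype), n_updates


# Hypothetical random network, rescaled so the max absolute row sum is 1.
rng = np.random.default_rng(1)
W = rng.uniform(-1.0, 1.0, size=(30, 30)) * (rng.random((30, 30)) < 0.2)
W /= np.abs(W).sum(axis=1).max()

S64, it64 = influence_in_dtype(W, 0.9, 0.1, np.float64)
S32, it32 = influence_in_dtype(W, 0.9, 0.1, np.float32)

print(S32.dtype, it64, it32, np.abs(S32.astype(np.float64) - S64).max())
```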
```python
def arrange_si(
        df_splo,
        df_inf,
        output,
        min_splo=None,
        max_splo=None,
        thr_inf=1e-10,
        ascending=True):
    """Group the source nodes by SPLO and sort each group by influence.

    Returns a dict that maps each SPLO value (in ascending order) to a
    DataFrame of the source nodes with that SPLO (node names in the
    'Source' column), sorted by their influence on `output`. The output
    node itself and nodes whose absolute influence is at most `thr_inf`
    are excluded.
    """

    # SPLO-Influence data
    if min_splo is None:
        min_splo = df_splo.min()

    if max_splo is None:
        max_splo = df_splo.max()

    mask_splo = (min_splo <= df_splo) & (df_splo <= max_splo)
    df_splo = df_splo[mask_splo]

    df_splo = pd.DataFrame(df_splo)
    df_splo.columns = ['SPLO']

    if output in df_splo.index:
        df_splo.drop(output, inplace=True)

    index_common = df_splo.index.intersection(df_inf.index)
    df_inf = pd.DataFrame(df_inf.loc[index_common])

    # Drop the sources whose influence on the output is negligible.
    mark_drop = df_inf[output].abs() <= thr_inf
    df_inf.drop(df_inf.loc[mark_drop, output].index,
                inplace=True)

    df_si = df_inf.join(df_splo.loc[index_common])
    df_si.index.name = 'Source'
    df_si.reset_index(inplace=True)

    cnt_splo = Counter(df_si['SPLO'])
    splos = sorted(cnt_splo.keys())

    si = {}
    for splo in splos:
        df_sub = df_si[df_si['SPLO'] == splo]
        df_sub = df_sub.sort_values(by=output,
                                    ascending=ascending)
        si[splo] = df_sub

    return si
```
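The grouping done by `arrange_si` can also be expressed with `DataFrame.groupby`, which yields the SPLO groups in sorted key order. The sketch below is a simplified, hypothetical re-implementation (`group_by_splo` is not part of the module): it omits the `min_splo`/`max_splo` filter, assumes `df_splo` is a `pandas.Series` of SPLO values indexed by node name, and uses made-up nodes A to D and an output node OUT.

```python
import pandas as pd


def group_by_splo(df_splo, df_inf, output, thr_inf=1e-10, ascending=True):
    """Map each SPLO value to the sources with that SPLO, sorted by influence."""
    splo = df_splo.drop(output, errors="ignore").rename("SPLO")
    common = splo.index.intersection(df_inf.index)
    inf = df_inf.loc[common]
    inf = inf[inf[output].abs() > thr_inf]  # drop negligible influences
    df_si = inf.join(splo).rename_axis("Source").reset_index()
    return {key: grp.sort_values(output, ascending=ascending)
            for key, grp in df_si.groupby("SPLO", sort=True)}


# Hypothetical SPLO values and influences on the output node "OUT".
df_splo = pd.Series({"A": 1, "B": 1, "C": 2, "D": 2, "OUT": 0})
df_inf = pd.DataFrame({"OUT": {"A": 0.3, "B": -0.2, "C": 0.05,
                               "D": 0.0, "OUT": float("inf")}})

groups = group_by_splo(df_splo, df_inf, "OUT")
for key, grp in groups.items():
    print(key, grp["Source"].tolist())
```

Node D is dropped because its influence is zero, and within the SPLO 1 group B comes before A because the sort is ascending.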
```python
def prioritize(df_splo,
               df_inf,
               output,
               dac,
               thr_rank=3,
               min_group_size=0,
               min_splo=None,
               max_splo=None,
               thr_inf=1e-10,
               ):
    """Prioritize target candidates.

    Parameters
    ----------
    df_splo : pandas.Series
        SPLO of each node, indexed by node name.
    df_inf : pandas.DataFrame
        Influence information: source nodes as rows and output nodes as
        columns (e.g., the 'df' result of compute_influence).
    output : str
        Name of the output node; must be a column of df_inf.
    dac : int
        Direction of activity change (DAC) of the output: 1 or -1.
    thr_rank : int or float
        Number of top-ranked entities to keep in each SPLO group, ranked
        by influence in the direction of dac (strongest first).
        If 0 < thr_rank < 1, it is the fraction of the group to keep;
        the whole group is kept if this fraction rounds down to zero.
    min_group_size : int
        Minimum group size; smaller SPLO groups are skipped.
    min_splo, max_splo : int, optional
        Inclusive range of SPLO values to consider; defaults to the
        minimum and maximum of df_splo.
    thr_inf : float, optional
        Entities whose absolute influence is at most thr_inf are ignored.

    Returns
    -------
    targets : list of str
        Names of the prioritized source nodes, ordered by SPLO and then
        by rank within each SPLO group.
    """
    ascending = True if dac < 0 else False

    df_inf_dac = df_inf[np.sign(df_inf[output]) == dac]
    si = arrange_si(df_splo,
                    df_inf_dac,
                    output,
                    min_splo,
                    max_splo,
                    thr_inf,
                    ascending)
    targets = []
    for splo in si:
        # Get the group of this SPLO.
        df_sub = si[splo]

        if df_sub.shape[0] < min_group_size:
            continue

        # Get the entities that have the designated DAC.
        df_sub = df_sub[np.sign(df_sub[output]) == dac]

        # Get the entities whose rank is within the threshold.
        if 0 < thr_rank < 1:
            ix_max_rank = int(thr_rank * df_sub.shape[0])
            if ix_max_rank == 0:
                ix_max_rank = df_sub.shape[0]
        else:
            ix_max_rank = int(thr_rank)

        df_top = df_sub.iloc[:ix_max_rank, :]

        targets.extend(df_top['Source'].tolist())
    # end of for
    return targets
```
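The selection rule in `prioritize` (keep the sources whose influence has the sign given by `dac`, rank them by influence in that direction, then keep `thr_rank` of them, or a fraction when 0 < `thr_rank` < 1) can be exercised on its own. The function below is a hypothetical standalone sketch that works on a dict of per-SPLO DataFrames shaped like the output of `arrange_si` (a 'Source' column plus the output column); the nodes A to E and their values are made up.

```python
import numpy as np
import pandas as pd


def select_top(groups, output, dac, thr_rank=3, min_group_size=0):
    """Pick top-ranked sources per SPLO group, mirroring prioritize's rule."""
    targets = []
    for splo in sorted(groups):
        df_sub = groups[splo]
        if df_sub.shape[0] < min_group_size:
            continue
        # Keep sources whose influence has the sign given by dac,
        # strongest first (descending for dac > 0, ascending for dac < 0).
        df_sub = df_sub[np.sign(df_sub[output]) == dac]
        df_sub = df_sub.sort_values(by=output, ascending=dac < 0)
        if 0 < thr_rank < 1:
            # A fraction that rounds down to zero keeps the whole group.
            n_keep = int(thr_rank * df_sub.shape[0]) or df_sub.shape[0]
        else:
            n_keep = int(thr_rank)
        targets.extend(df_sub["Source"].iloc[:n_keep].tolist())
    return targets


# Hypothetical SPLO groups for an output node "OUT".
groups = {
    1: pd.DataFrame({"Source": ["A", "B", "C"], "OUT": [0.3, 0.1, -0.2]}),
    2: pd.DataFrame({"Source": ["D", "E"], "OUT": [0.05, 0.5]}),
}

print(select_top(groups, "OUT", dac=1, thr_rank=1))   # ['A', 'E']
print(select_top(groups, "OUT", dac=-1, thr_rank=1))  # ['C']
```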