Mercurial > repos > laurenmarazzi > netisce_test
comparison tools/myTools/bin/sfa/analysis/perturb.py @ 1:7e5c71b2e71f draft default tip
Uploaded
author | laurenmarazzi |
---|---|
date | Wed, 22 Dec 2021 16:00:34 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
0:f24d4892aaed | 1:7e5c71b2e71f |
---|---|
1 # -*- coding: utf-8 -*- | |
2 | |
3 import numpy as np | |
4 | |
5 | |
def analyze_perturb(alg, data, targets, b=None, get_trj=False):
    """Perform signal flow analysis under perturbations.

    The function propagates the signal twice with ``alg.propagate_iterative``:
    once for the control condition (inputs only) and once after
    ``alg.apply_perturbations`` has been applied, and then reports the
    difference between the two results.

    Parameters
    ----------
    alg : sfa.Algorithm
        Algorithm object. This function uses ``alg.W``, ``alg.params.alpha``,
        ``alg.apply_inputs``, ``alg.apply_perturbations`` and
        ``alg.propagate_iterative``.

    data : sfa.Data
        Data object which has perturbation data. ``data.A.shape[0]`` defines
        the number of nodes and ``data.has_link_perturb`` selects the
        link-perturbation code path.

    targets : list
        List of node names, which are the keys of data.n2i.

    b : numpy.ndarray (optional)
        Basic vector for signaling sources or basal activities.
        If None, a zero vector is used. The given array is copied
        (as floating point), so the caller's array is never modified.

    get_trj : bool (optional)
        Decide to get the trajectory of activity change.

    Returns
    -------
    act : numpy.ndarray
        Change in the activities, calculated as x_pert - x_ctrl,
        where x is the result of the propagation under each condition.

    F : numpy.ndarray
        A matrix of signal flows.
        With link perturbations it is calculated as
        W_pert*x_pert - W_ctrl*x_ctrl; otherwise as
        W_ctrl*(x_pert - x_ctrl). ``*`` is the element-wise product.

    trj : numpy.ndarray (optional)
        Trajectory of activity change, which is returned
        only if get_trj is True.

    Raises
    ------
    TypeError
        If the number of elements of b is different from the number of nodes.

    Notes
    -----
    When ``data.has_link_perturb`` is true, ``alg.W`` is replaced by the
    perturbed copy of the weight matrix and is left that way on return.
    """
    N = data.A.shape[0]

    if b is None:
        # np.float was removed from NumPy; the builtin float is equivalent.
        b = np.zeros((N,), dtype=float)
    else:
        # Work on a private floating-point copy: b is written to below and
        # the caller's array must not accumulate input/perturbation values.
        b = np.array(b, dtype=float)
        if b.size != N:
            raise TypeError("The size of b should be equal to %d"%(N))

    # Indices and values are filled in place by the algorithm object.
    inds = []
    vals = []
    alg.apply_inputs(inds, vals)
    b[inds] = vals

    # Control condition: inputs only.
    W_ctrl = alg.W.copy()
    x_ctrl, trj_ctrl = alg.propagate_iterative(
        W_ctrl,
        b,
        b,
        alg.params.alpha,
        get_trj=get_trj)

    # Perturbed condition: apply_perturbations may extend inds/vals and,
    # for link perturbations, modify the copy of the weight matrix.
    if data.has_link_perturb:
        W_pert = W_ctrl.copy()
        alg.apply_perturbations(targets, inds, vals, W_pert)
        alg.W = W_pert
    else:
        W_pert = W_ctrl
        alg.apply_perturbations(targets, inds, vals)

    b[inds] = vals
    x_pert, trj_pert = alg.propagate_iterative(
        W_pert,
        b,
        b,
        alg.params.alpha,
        get_trj=get_trj)

    act_change = x_pert - x_ctrl

    if data.has_link_perturb:
        F = W_pert*x_pert - W_ctrl*x_ctrl
    else:
        F = W_ctrl*act_change

    ret = [act_change, F]  # return objects
    if get_trj:
        # The two runs may take a different number of iterations; pad the
        # shorter trajectory before subtracting.
        if trj_pert.shape[0] != trj_ctrl.shape[0]:
            trj_ctrl, trj_pert = resize_trj(trj_ctrl, trj_pert)

        trj_change = trj_pert - trj_ctrl
        ret.append(trj_change)

    return tuple(ret)
95 | |
96 | |
def resize_trj(trj_ctrl, trj_pert):
    """Make two trajectories have the same number of rows.

    The trajectory with fewer rows is copied into a new array that has the
    shape and dtype of the longer one; the extra rows are filled with the
    last row of the shorter trajectory. The longer trajectory is returned
    as it is.

    Parameters
    ----------
    trj_ctrl : numpy.ndarray
        Two-dimensional trajectory of the control condition.

    trj_pert : numpy.ndarray
        Two-dimensional trajectory of the perturbed condition.

    Returns
    -------
    trj_ctrl : numpy.ndarray
    trj_pert : numpy.ndarray
        The trajectories in the same order as the arguments,
        one of them replaced by its padded copy.
    """
    # Position 0 is the perturbed trajectory, position 1 the control one.
    candidates = (trj_pert, trj_ctrl)
    lengths = [trj.shape[0] for trj in candidates]

    # list.index returns the first match, so position 0 wins on a tie.
    pos_short = lengths.index(min(lengths))
    pos_long = lengths.index(max(lengths))
    short = candidates[pos_short]
    n_short = lengths[pos_short]

    # Copy the shorter trajectory and repeat its last row up to the end.
    padded = np.zeros_like(candidates[pos_long])
    padded[:n_short, :] = short
    padded[n_short:, :] = short[-1, :]

    if pos_short not in (0, 1):
        err_msg = "Invalid index for trajectories: %d" % (pos_short)
        raise IndexError(err_msg)

    if pos_short == 0:
        return trj_ctrl, padded
    return padded, trj_pert