comparison tools/myTools/bin/sfa/analysis/perturb.py @ 1:7e5c71b2e71f draft default tip

Uploaded
author laurenmarazzi
date Wed, 22 Dec 2021 16:00:34 +0000
parents
children
comparison
equal deleted inserted replaced
0:f24d4892aaed 1:7e5c71b2e71f
1 # -*- coding: utf-8 -*-
2
3 import numpy as np
4
5
def analyze_perturb(alg, data, targets, b=None, get_trj=False):
    """Perform signal flow analysis under perturbations.

    The network is propagated twice with ``alg.propagate_iterative``:
    once in the control condition (only the inputs registered by
    ``alg.apply_inputs``) and once in the perturbed condition (inputs plus
    the perturbations registered by ``alg.apply_perturbations`` for
    ``targets``). The differences between the two runs are returned.

    Parameters
    ----------
    alg : sfa.Algorithm
        Algorithm object.

    data : sfa.Data
        Data object which has perturbation data.

    targets : list
        List of node names, which are the keys of data.n2i.

    b : numpy.ndarray
        Basic vector for signaling sources or basal activities.
        Its size must equal the number of nodes (``data.A.shape[0]``).
        The array is copied internally, so the caller's array is never
        modified. If None, a zero vector is used.

    get_trj : bool (optional)
        Decide to get the trajectory of activity change.

    Returns
    -------
    act : numpy.ndarray
        Change in the activities, calculated as x2 - x1, where x1 and x2
        are the vectors of activities at steady-state in the control and
        perturbed conditions, respectively.

    F : numpy.ndarray
        A matrix of signal flows, calculated as W2*x2 - W1*x1, where W1
        and W2 are the control and perturbed weight matrices. Without link
        perturbation W2 is W1, so this reduces to W1*(x2 - x1).

    trj : numpy.ndarray (optional)
        Trajectory of activity change (perturbed minus control), which is
        returned only if get_trj is True.

    Raises
    ------
    TypeError
        If ``b`` is given and ``b.size`` differs from the number of nodes.
    """
    N = data.A.shape[0]

    if b is None:
        # Builtin float is what the removed ``np.float`` alias pointed to.
        b = np.zeros((N,), dtype=float)
    elif b.size != N:
        raise TypeError("The size of b should be equal to %d" % N)
    else:
        # Work on a private copy: b is overwritten with input and
        # perturbation values below, which must not leak back into the
        # caller's array (e.g. when one b is reused for many target sets).
        b = b.copy()

    # apply_inputs fills these lists with node indices and their values.
    inds = []
    vals = []
    alg.apply_inputs(inds, vals)
    b[inds] = vals

    # Control condition.
    W_ctrl = alg.W.copy()
    x_ctrl, trj_ctrl = alg.propagate_iterative(
        W_ctrl,
        b,
        b,
        alg.params.alpha,
        get_trj=get_trj)

    if data.has_link_perturb:
        # Link perturbations edit a private copy of the weight matrix.
        W_pert = W_ctrl.copy()
        alg.apply_perturbations(targets, inds, vals, W_pert)
        # NOTE(review): alg.W stays set to the perturbed matrix after this
        # call; a caller reusing ``alg`` may need to restore it — confirm
        # this is intended.
        alg.W = W_pert
    else:
        W_pert = W_ctrl
        alg.apply_perturbations(targets, inds, vals)

    # inds/vals are shared with apply_perturbations; presumably it appends
    # the perturbed nodes, so this writes inputs and perturbations into b.
    b[inds] = vals
    x_pert, trj_pert = alg.propagate_iterative(
        W_pert,
        b,
        b,
        alg.params.alpha,
        get_trj=get_trj)

    act_change = x_pert - x_ctrl

    if data.has_link_perturb:
        F = W_pert*x_pert - W_ctrl*x_ctrl
    else:
        # W_pert is W_ctrl here, so W2*x2 - W1*x1 == W1*(x2 - x1).
        F = W_ctrl*act_change

    ret = [act_change, F]  # return objects
    if get_trj:
        # The two runs may converge after different numbers of steps.
        if trj_pert.shape[0] != trj_ctrl.shape[0]:
            trj_ctrl, trj_pert = resize_trj(trj_ctrl, trj_pert)

        trj_change = trj_pert - trj_ctrl
        ret.append(trj_change)

    return tuple(ret)
95
96
def resize_trj(trj_ctrl, trj_pert):
    """Bring two trajectories to a common number of rows.

    The shorter trajectory is extended by repeating its final row until it
    has as many rows as the longer one. Its new buffer is created with
    ``np.zeros_like`` of the longer trajectory, so it takes that array's
    dtype and shape. On a tie, ``trj_pert`` is treated as the shorter one
    and is replaced by a copy built on its own template.

    Parameters
    ----------
    trj_ctrl : numpy.ndarray
        Control trajectory (2-D: steps x nodes).
    trj_pert : numpy.ndarray
        Perturbed trajectory (2-D: steps x nodes).

    Returns
    -------
    tuple of numpy.ndarray
        ``(trj_ctrl, trj_pert)`` with equal numbers of rows.
    """
    def _extend(short, template):
        # Copy ``short`` into a template-shaped buffer and repeat its final
        # row over the remaining steps.
        n_rows = short.shape[0]
        out = np.zeros_like(template)
        out[:n_rows, :] = short
        out[n_rows:, :] = short[-1, :]
        return out

    len_ctrl = trj_ctrl.shape[0]
    len_pert = trj_pert.shape[0]

    if len_pert > len_ctrl:
        return _extend(trj_ctrl, trj_pert), trj_pert

    # Perturbed run is not longer: pad it against the control run, or
    # against itself when both have the same length.
    template = trj_ctrl if len_pert < len_ctrl else trj_pert
    return trj_ctrl, _extend(trj_pert, template)
113
114 if ind_smaller == 0:
115 trj_pert = smaller_resized
116 elif ind_smaller == 1:
117 trj_ctrl = smaller_resized
118 else:
119 err_msg = "Invalid index for trajectories: %d" % (ind_smaller)
120 raise IndexError(err_msg)
121
122 return trj_ctrl, trj_pert