1
|
1 # -*- coding: utf-8 -*-
|
|
2
|
|
3 import numpy as np
|
|
4
|
|
5
|
|
def analyze_perturb(alg, data, targets, b=None, get_trj=False):
    """Perform signal flow analysis under perturbations.

    The network is propagated twice with ``alg.propagate_iterative``:
    once under the control condition (inputs only) and once after the
    perturbations on ``targets`` have been applied. The differences
    between the two results are returned.

    Parameters
    ----------
    alg : sfa.Algorithm
        Algorithm object. It must provide ``W``, ``params.alpha``,
        ``apply_inputs``, ``apply_perturbations`` and
        ``propagate_iterative``.

    data : sfa.Data
        Data object which has perturbation data.

    targets : list
        List of node names, which are the keys of data.n2i.

    b : numpy.ndarray or array-like (optional)
        Basic vector for signaling sources or basal activities.
        Its size must equal the number of nodes (``data.A.shape[0]``).
        Defaults to a zero vector. The caller's array is never modified;
        a float copy is used internally.

    get_trj : bool (optional)
        Decide to get the trajectory of activity change.

    Returns
    -------
    act : numpy.ndarray
        Change in the activities, calculated as x2 - x1, where x1 and x2
        are the activity vectors at steady-state under the control and
        perturbed conditions, respectively.

    F : numpy.ndarray
        A matrix of signal flows, calculated element-wise as
        W2*x2 - W1*x1, where W1 and W2 are the control and perturbed
        weight matrices. Assuming W is a dense numpy.ndarray, this gives
        F[i, j] = W2[i, j]*x2[j] - W1[i, j]*x1[j]. Without link
        perturbations W2 is W1, so F reduces to W1*(x2 - x1).

    trj : numpy.ndarray (optional)
        Trajectory of activity change, which is returned
        if get_trj is True.

    Raises
    ------
    TypeError
        If the size of ``b`` differs from the number of nodes.

    Notes
    -----
    When ``data.has_link_perturb`` is true, ``alg.W`` is set to the
    perturbed weight matrix only for the duration of the perturbed
    propagation. It is restored afterwards, so repeated calls always
    start from the same control weights.
    """
    N = data.A.shape[0]

    if b is None:
        # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is equivalent.
        b = np.zeros((N,), dtype=float)
    else:
        # Work on a float copy. The entries of ``b`` are overwritten with
        # input values below, and that must not leak back to the caller.
        b = np.array(b, dtype=float)
        if b.size != N:
            raise TypeError("The size of b should be equal to %d" % (N))

    # Index/value lists that ``apply_inputs`` (and later
    # ``apply_perturbations``) are expected to fill in place.
    inds = []
    vals = []
    alg.apply_inputs(inds, vals)
    b[inds] = vals

    # Control condition: inputs only, original weights.
    W_ctrl = alg.W.copy()
    x_ctrl, trj_ctrl = alg.propagate_iterative(
        W_ctrl,
        b,
        b,
        alg.params.alpha,
        get_trj=get_trj)

    # Perturbed condition.
    link_perturb = data.has_link_perturb
    W_orig = alg.W
    if link_perturb:
        # Link perturbations edit a private copy of the weights.
        W_pert = W_ctrl.copy()
        alg.apply_perturbations(targets, inds, vals, W_pert)
        alg.W = W_pert
    else:
        W_pert = W_ctrl
        alg.apply_perturbations(targets, inds, vals)

    # Separate vector so the control run's ``b`` is never aliased/mutated.
    b_pert = b.copy()
    b_pert[inds] = vals
    try:
        x_pert, trj_pert = alg.propagate_iterative(
            W_pert,
            b_pert,
            b_pert,
            alg.params.alpha,
            get_trj=get_trj)
    finally:
        if link_perturb:
            # Undo the temporary swap so later calls see the control weights.
            alg.W = W_orig

    act_change = x_pert - x_ctrl

    if link_perturb:
        F = W_pert*x_pert - W_ctrl*x_ctrl
    else:
        F = W_ctrl*act_change

    ret = [act_change, F]  # return objects
    if get_trj:
        # The two runs may produce trajectories of different lengths.
        if trj_pert.shape[0] != trj_ctrl.shape[0]:
            trj_ctrl, trj_pert = resize_trj(trj_ctrl, trj_pert)

        ret.append(trj_pert - trj_ctrl)

    return tuple(ret)
|
|
95
|
|
96
|
|
def _pad_trj(shorter, longer):
    """Return ``shorter`` extended to ``longer``'s shape by repeating its last entry along axis 0."""
    n = shorter.shape[0]
    # A common dtype avoids truncating a float trajectory into an int buffer.
    padded = np.empty(longer.shape, dtype=np.result_type(shorter, longer))
    padded[:n] = shorter
    # Raises IndexError if ``shorter`` has no steps at all.
    padded[n:] = shorter[-1]
    return padded


def resize_trj(trj_ctrl, trj_pert):
    """Make two trajectories equally long by padding the shorter one.

    The control and perturbed trajectories can differ in length along
    axis 0, presumably because the iterative propagation can stop after a
    different number of steps. The shorter trajectory is extended by
    repeating its last entry (its final state) until it matches the
    longer one.

    Parameters
    ----------
    trj_ctrl : numpy.ndarray
        Trajectory of the control condition (steps along axis 0).
    trj_pert : numpy.ndarray
        Trajectory of the perturbed condition (steps along axis 0).

    Returns
    -------
    trj_ctrl, trj_pert : numpy.ndarray
        The two trajectories with equal length along axis 0. The longer
        (or equal-length) input is returned unchanged. The shorter one is
        replaced by a new, padded array.
    """
    n_ctrl = trj_ctrl.shape[0]
    n_pert = trj_pert.shape[0]

    if n_ctrl < n_pert:
        trj_ctrl = _pad_trj(trj_ctrl, trj_pert)
    elif n_pert < n_ctrl:
        trj_pert = _pad_trj(trj_pert, trj_ctrl)

    return trj_ctrl, trj_pert