Mercurial > repos > laurenmarazzi > netisce_test
diff tools/myTools/bin/sfa/analysis/perturb.py @ 1:7e5c71b2e71f draft default tip
Uploaded
author | laurenmarazzi |
---|---|
date | Wed, 22 Dec 2021 16:00:34 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/tools/myTools/bin/sfa/analysis/perturb.py Wed Dec 22 16:00:34 2021 +0000 @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- + +import numpy as np + + +def analyze_perturb(alg, data, targets, b=None, get_trj=False): + """Perform signal flow analysis under perturbations. + + Parameters + ---------- + alg : sfa.Algorithm + Algorithm object. + + data : sfa.Data + Data object which has perturbation data. + + targets : list + List of node names, which are the keys of data.n2i. + + b : numpy.ndarray + Basic vector for signaling sources or basal activities. + + get_trj : bool (optional) + Decide to get the trajectory of activity change. + + Returns + ------- + act : numpy.ndarray + Change in the activities. It is usually calculated + as x2 - x1, where x is + the a vector of activities at steady-state. + + F : numpy.ndarray + A matrix of signal flows. + It is usually calculated as W2*x1 - W1*x1, + where W is weight matrix and + x is a vector of activities at steady-state. + + trj : numpy.ndarray (optional) + Trajectory of activity change, which is returned + if get_trj is True. + """ + N = data.A.shape[0] + + if b is None: + b = np.zeros((N,), dtype=np.float) + elif b.size != N: + raise TypeError("The size of b should be equal to %d"%(N)) + + inds = [] + vals = [] + alg.apply_inputs(inds, vals) + b[inds] = vals + + W_ctrl = alg.W.copy() + x_ctrl, trj_ctrl = alg.propagate_iterative( + W_ctrl, + b, + b, + alg.params.alpha, + get_trj=get_trj) + + if data.has_link_perturb: + W_pert = W_ctrl.copy() + alg.apply_perturbations(targets, inds, vals, W_pert) + alg.W = W_pert + else: + W_pert = W_ctrl + alg.apply_perturbations(targets, inds, vals) + + b[inds] = vals + x_pert, trj_pert = alg.propagate_iterative( + W_pert, + b, + b, + alg.params.alpha, + get_trj=get_trj) + + act_change = x_pert - x_ctrl + + if data.has_link_perturb: + F = W_pert*x_pert - W_ctrl*x_ctrl + else: + F = W_ctrl*act_change + + ret = [act_change, F] # return objects + if get_trj: + if trj_pert.shape[0] != trj_ctrl.shape[0]: + trj_ctrl, trj_pert = resize_trj(trj_ctrl, trj_pert) + + trj_change = trj_pert - trj_ctrl + ret.append(trj_change) + + return tuple(ret) + + +def resize_trj(trj_ctrl, trj_pert): + # Prepare the comparison + trjs = [trj_pert, trj_ctrl] + ind_trjs = [0, 1] + func_key = lambda x: trjs[x].shape[0] + + # Find smaller and bigger arrays. + ind_smaller = min(ind_trjs, key=func_key) + ind_bigger = max(ind_trjs, key=func_key) + smaller = trjs[ind_smaller] + bigger = trjs[ind_bigger] + + # Resize the smaller one. + smaller_resized = np.zeros_like(bigger) + smaller_resized[:smaller.shape[0], :] = smaller + smaller_resized[smaller.shape[0]:, :] = smaller[-1, :] + + if ind_smaller == 0: + trj_pert = smaller_resized + elif ind_smaller == 1: + trj_ctrl = smaller_resized + else: + err_msg = "Invalid index for trajectories: %d" % (ind_smaller) + raise IndexError(err_msg) + + return trj_ctrl, trj_pert \ No newline at end of file