diff tools/myTools/bin/sfa/analysis/perturb.py @ 1:7e5c71b2e71f draft default tip

Uploaded
author laurenmarazzi
date Wed, 22 Dec 2021 16:00:34 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tools/myTools/bin/sfa/analysis/perturb.py	Wed Dec 22 16:00:34 2021 +0000
@@ -0,0 +1,122 @@
+# -*- coding: utf-8 -*-
+
+import numpy as np
+
+
+def analyze_perturb(alg, data, targets, b=None, get_trj=False):
+    """Perform signal flow analysis under perturbations.
+
+    Parameters
+    ----------
+    alg : sfa.Algorithm
+        Algorithm object.
+
+    data : sfa.Data
+        Data object which has perturbation data.
+
+    targets : list
+        List of node names, which are the keys of data.n2i.
+
+    b : numpy.ndarray
+        Basic vector for signaling sources or basal activities.
+
+    get_trj : bool (optional)
+        Decide to get the trajectory of activity change.
+        
+    Returns
+    -------
+    act : numpy.ndarray
+        Change in the activities. It is usually calculated
+        as x2 - x1, where x is
+        the a vector of activities at steady-state.
+
+    F : numpy.ndarray
+        A matrix of signal flows.        
+        It is usually calculated as W2*x1 - W1*x1,
+        where W is weight matrix and
+        x is a vector of activities at steady-state. 
+
+    trj : numpy.ndarray (optional)
+        Trajectory of activity change, which is returned
+        if get_trj is True.
+    """
+    N = data.A.shape[0]
+
+    if b is None:
+        b = np.zeros((N,), dtype=np.float)
+    elif b.size != N:
+        raise TypeError("The size of b should be equal to %d"%(N))
+
+    inds = []
+    vals = []
+    alg.apply_inputs(inds, vals)
+    b[inds] = vals
+
+    W_ctrl = alg.W.copy()
+    x_ctrl, trj_ctrl = alg.propagate_iterative(
+                                W_ctrl,
+                                b,
+                                b,
+                                alg.params.alpha,
+                                get_trj=get_trj)
+
+    if data.has_link_perturb:
+        W_pert = W_ctrl.copy()
+        alg.apply_perturbations(targets, inds, vals, W_pert)
+        alg.W = W_pert
+    else:
+        W_pert = W_ctrl
+        alg.apply_perturbations(targets, inds, vals)
+
+    b[inds] = vals
+    x_pert, trj_pert = alg.propagate_iterative(
+                                W_pert,
+                                b,
+                                b,
+                                alg.params.alpha,
+                                get_trj=get_trj)
+
+    act_change = x_pert - x_ctrl
+
+    if data.has_link_perturb:
+        F = W_pert*x_pert - W_ctrl*x_ctrl
+    else:
+        F = W_ctrl*act_change
+
+    ret = [act_change, F]  # return objects
+    if get_trj:
+        if trj_pert.shape[0] != trj_ctrl.shape[0]:
+            trj_ctrl, trj_pert = resize_trj(trj_ctrl, trj_pert)
+    
+        trj_change = trj_pert - trj_ctrl
+        ret.append(trj_change)
+
+    return tuple(ret)
+
+
+def resize_trj(trj_ctrl, trj_pert):
+    # Prepare the comparison
+    trjs = [trj_pert, trj_ctrl]
+    ind_trjs = [0, 1]
+    func_key = lambda x: trjs[x].shape[0]
+
+    # Find smaller and bigger arrays.
+    ind_smaller = min(ind_trjs, key=func_key)
+    ind_bigger = max(ind_trjs, key=func_key)
+    smaller = trjs[ind_smaller]
+    bigger = trjs[ind_bigger]
+
+    # Resize the smaller one.
+    smaller_resized = np.zeros_like(bigger)
+    smaller_resized[:smaller.shape[0], :] = smaller
+    smaller_resized[smaller.shape[0]:, :] = smaller[-1, :]
+
+    if ind_smaller == 0:
+        trj_pert = smaller_resized
+    elif ind_smaller == 1:
+        trj_ctrl = smaller_resized
+    else:
+        err_msg = "Invalid index for trajectories: %d" % (ind_smaller)
+        raise IndexError(err_msg)
+
+    return trj_ctrl, trj_pert
\ No newline at end of file