comparison tools/myTools/bin/SFA_insilico.py @ 1:7e5c71b2e71f draft default tip

Uploaded
author laurenmarazzi
date Wed, 22 Dec 2021 16:00:34 +0000
parents
children
comparison
equal deleted inserted replaced
0:f24d4892aaed 1:7e5c71b2e71f
1 #!/usr/bin/env python3
2
3
4 import os
5 import numpy as np
6 import pandas as pd
7 import networkx as nx
8 import random
9 import sfa
10 import csv
11 import sys
12 import statistics
13
14
15
# Command-line inputs (positional):
#   argv[1]  path to the network file (read later with sfa.read_sif)
#   argv[2]  CSV table whose first column is used as the row index
#   argv[3]  whitespace-delimited table; the code below uses its 'name' and
#            'phenotype' columns
#   argv[4]  number of basal-state perturbations to generate
fpath = sys.argv[1]  # location of networkfile
qdata_all = pd.read_csv(sys.argv[2], index_col=0)
# sep=r'\s+' is the documented replacement for the deprecated
# delim_whitespace=True keyword of pandas.read_csv.
all_samples = pd.read_csv(sys.argv[3], sep=r'\s+', index_col=False)
npert = int(sys.argv[4])
phenotypes = all_samples.phenotype.unique()
class ThreeNodeCascade(sfa.base.Data):
    """SFA data object built from the SIF network file at the module-level ``fpath``."""

    def __init__(self):
        super().__init__()
        self._abbr = "TNC"
        self._name = "A simple three node cascade"

        # Interaction labels found in the SIF file and the sign given to each.
        edge_signs = {'activates': 1, 'inhibits': -1}
        matrix, name_to_index, graph = sfa.read_sif(
            fpath, signs=edge_signs, as_nx=True
        )
        self._A = matrix
        self._n2i = name_to_index
        self._dg = graph
        # Reverse lookup of _n2i: index -> node name.
        self._i2n = dict(zip(name_to_index.values(), name_to_index.keys()))
36
37
def gen_basal_states(npert, nnode, pre, opts, batch_size=500000):
    """Generate unique random basal-state combinations.

    Parameters
    ----------
    npert : int or int-like
        Number of distinct states requested. It is capped at the number of
        possible combinations, ``len(options) ** nnode``.
    nnode : int or int-like
        Number of nodes (columns) in each state.
    pre : str
        Prefix for the row labels; rows are named ``<pre>_<k>`` with ``k``
        zero-padded to the width of the final row count.
    opts : str
        Comma-separated list of the values a node may take, e.g. ``'0,-1,1'``.
    batch_size : int, optional
        Number of random rows drawn per iteration (default 500000).

    Returns
    -------
    pandas.DataFrame
        One row per state and one column per node. The cells hold the option
        values as strings, exactly as they were split from ``opts``.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    batch_size = int(batch_size)
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    n_nodes = int(nnode)
    choices = opts.split(',')
    # Never ask for more states than there are distinct combinations.
    n_states = max(min(int(npert), len(choices) ** n_nodes), 0)

    collected = None
    n_collected = 0
    while n_collected < n_states:
        batch = pd.DataFrame(
            np.random.choice(a=choices, size=[batch_size, n_nodes])
        )
        # Keep the unique rows found so far and add the new draw to them.
        # Redrawing from scratch each time could never yield more than
        # batch_size rows, so larger requests would loop forever.
        if collected is not None:
            batch = pd.concat([collected, batch], ignore_index=True)
        collected = batch.drop_duplicates()
        n_collected = len(collected.index)

    if collected is None:
        collected = pd.DataFrame()
    collected = collected.iloc[:n_states]

    width = len(str(n_states))
    collected.index = [
        pre + '_' + str(k).rjust(width, '0') for k in range(1, n_states + 1)
    ]

    return collected
58
59
if __name__ == "__main__":
    # Configure the SFA 'SP' algorithm on the network loaded from fpath.
    data = ThreeNodeCascade()
    algs = sfa.AlgorithmSet()
    alg = algs.create('SP')
    alg.data = data
    alg.params.apply_weight_norm = True
    alg.initialize()
    alg.params.exsol_forbidden = True
    alg.params.alpha = 0.9

    netnodes = list(data.dg.nodes)
    # Only nodes present in both the network and the expression table
    # receive a simulated basal value.
    expnodes = list(set(netnodes) & set(qdata_all.index))
    samples = gen_basal_states(npert, len(expnodes), 'attr', '0,-1,1')
    samples.columns = expnodes

    n = data.dg.number_of_nodes()  # the number of nodes
    b = np.zeros((n,))

    # Rows of `samples` handled per phenotype: ceil(len(samples) / len(phenotypes)).
    # The earlier expression took the modulo with its operands swapped, which
    # over-sized the chunks when the counts divided evenly (starving the last
    # phenotype) and raised ZeroDivisionError for an empty sample table.
    switch_num = -(-len(samples) // len(phenotypes))
    logss = pd.DataFrame(index=samples.index, columns=netnodes, copy=True)

    for i, phenotype in enumerate(phenotypes):
        start = switch_num * i
        # Names of the rows of all_samples in which any column equals this phenotype.
        qs = all_samples[all_samples.isin([phenotype]).any(axis=1)]['name'].tolist()
        qdata = qdata_all.loc[:, qs]
        pi = []
        samples2 = samples.iloc[start:start + switch_num]

        # Per-node minimum, maximum and 0.33 / 0.66 quantiles over this phenotype's samples.
        minv = pd.Series(index=qdata.index, data=[np.amin(qdata.loc[node,]) for node in qdata.index])
        maxv = pd.Series(index=qdata.index, data=[np.amax(qdata.loc[node,]) for node in qdata.index])
        q1 = pd.Series(index=qdata.index, data=[np.quantile(qdata.loc[node,], .33) for node in qdata.index])
        q2 = pd.Series(index=qdata.index, data=[np.quantile(qdata.loc[node,], .66) for node in qdata.index])

        for name, item in samples2.iterrows():  # for each simulated initial condition
            for node, raw_state in item.items():
                # gen_basal_states returns the states as strings ('0', '-1', '1').
                # Convert before comparing, otherwise neither the 1 nor the -1
                # branch can ever match and every node gets a middle-range value.
                state = int(raw_state)
                if state == 1:
                    # upper range: between the 0.66 quantile and the maximum
                    low, high = q2[node], maxv[node]
                elif state == -1:
                    # lower range: between the minimum and the 0.33 quantile
                    low, high = minv[node], q1[node]
                else:
                    # middle range: between the 0.33 and 0.66 quantiles
                    low, high = q1[node], q2[node]
                b[data.n2i[node]] = np.random.uniform(low=low, high=high)
            x = alg.compute(b, pi)  # Run SFA calculation
            logss.loc[name, netnodes] = x[0]

    logss = logss.astype(float).round(3)
    # NOTE(review): the output is named .tsv but the separator is a single space — confirm this is intended.
    logss.to_csv('attrs_insilico.tsv', sep=' ', float_format='%.3f', index_label="name", chunksize=10000)