comparison tools/myTools/bin/class_and_consensus.py @ 1:7e5c71b2e71f draft default tip

Uploaded
author laurenmarazzi
date Wed, 22 Dec 2021 16:00:34 +0000
#!/usr/bin/env python3
import sys
from collections import Counter

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC


def main():
    train = sys.argv[1].split(',')  # comma-separated list of train attractor files
    df_attr = pd.DataFrame()
    for j in train:
        dfj = pd.read_csv(j, delim_whitespace=True, index_col=["name"])
        df_attr = pd.concat([df_attr, dfj], axis=0)
    # df_attr = df_attr.drop_duplicates()
    df_perturb = pd.read_csv(sys.argv[2], delim_whitespace=True, index_col=[0, 1])

    df_labels = pd.read_csv(sys.argv[3], delim_whitespace=True, index_col=["name"])  # k-means results
    labels = df_labels['clusters'].tolist()

    gnb = GaussianNB()
    gnb.fit(df_attr, labels)  # fit naive Bayes on the attractor landscape
    perturb_lab = gnb.predict(df_perturb)  # predict clusters for perturbations
    # dataframe of perturbations and their predicted clusters
    NB_label = pd.DataFrame(index=df_perturb.index)
    NB_label['clusters'] = perturb_lab

    regressor = RandomForestClassifier(n_estimators=100, random_state=1)
    regressor.fit(df_attr, labels)  # fit random forest on the attractor landscape
    perturb_lab = regressor.predict(df_perturb)  # predict clusters for perturbations
    # dataframe of perturbations and their predicted clusters
    RF_label = pd.DataFrame(index=df_perturb.index)
    RF_label['clusters'] = perturb_lab

    svm = SVC(gamma='auto', random_state=4)  # initialise the support vector classifier
    svm.fit(df_attr, labels)  # fit SVM on the attractor landscape
    perturb_lab = svm.predict(df_perturb)  # predict clusters for perturbations
    # dataframe of perturbations and their predicted clusters
    SVM_label = pd.DataFrame(index=df_perturb.index)
    SVM_label['clusters'] = perturb_lab

    consensus = []

    for df in (SVM_label, NB_label, RF_label):
        df = df.where(df == 0, None)  # keep only cluster-0 predictions
        df['count'] = df.apply(lambda x: x.count(), axis=1)  # 1 if the row was assigned cluster 0, else 0
        if df.index.nlevels > 1:
            consensus.append(df.loc[df['count'] >= 1].index.get_level_values('perturbation').to_list())
        else:
            consensus.append(df.loc[df['count'] > 1].index.to_list())
    # how many classifiers placed each perturbation in cluster 0
    cencount = Counter(x for sublist in consensus for x in sublist)

    if df.index.nlevels > 1:
        keys = [k for k, v in cencount.items() if v >= len(df.index.unique('replicate').tolist()) * 2]
    else:
        keys = [k for k, v in cencount.items()]

    with open('crit1perts.txt', 'w') as f:
        for item in keys:
            f.write("%s\n" % item)


if __name__ == '__main__':
    main()
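
The script takes three positional arguments: a comma-separated list of attractor (training) files, a whitespace-delimited perturbation matrix, and the k-means cluster assignments; it writes the consensus perturbations to crit1perts.txt. The sketch below is not part of the repository; the data and names are illustrative assumptions. It reproduces the core consensus idea on synthetic data: fit GaussianNB, RandomForestClassifier and SVC on the same training set, predict labels for held-out samples, and keep the samples that at least two of the three classifiers place in cluster 0.

#!/usr/bin/env python3
# Minimal sketch of the three-classifier consensus step, on synthetic data.
from collections import Counter

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC

# Synthetic stand-ins for the attractor (train) and perturbation (test) matrices.
X_train, y_train = make_classification(n_samples=60, n_features=10, random_state=0)
X_test, _ = make_classification(n_samples=20, n_features=10, random_state=1)

predictions = []
for clf in (GaussianNB(),
            RandomForestClassifier(n_estimators=100, random_state=1),
            SVC(gamma='auto', random_state=4)):
    clf.fit(X_train, y_train)
    predictions.append(clf.predict(X_test))

# A sample reaches consensus when at least two of the three classifiers
# assign it to cluster 0.
votes = Counter(i for pred in predictions for i, label in enumerate(pred) if label == 0)
consensus = [i for i, v in votes.items() if v >= 2]
print(consensus)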