Mercurial > repos > laurenmarazzi > netisce_test
comparison tools/myTools/bin/class_and_consensus.py @ 1:7e5c71b2e71f draft default tip
Uploaded
author | laurenmarazzi |
---|---|
date | Wed, 22 Dec 2021 16:00:34 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
0:f24d4892aaed | 1:7e5c71b2e71f |
---|---|
1 #!/usr/bin/env python3 | |
2 import pandas as pd | |
3 from sklearn.naive_bayes import GaussianNB | |
4 from sklearn.ensemble import RandomForestClassifier | |
5 from sklearn.svm import SVC | |
6 from collections import Counter | |
7 import sys | |
def main():
    """Classify perturbation simulations by attractor-cluster membership.

    Command-line arguments:
        sys.argv[1] -- comma-separated list of training attractor files
                       (whitespace-delimited tables indexed by "name")
        sys.argv[2] -- perturbation simulation table; read with a two-column
                       index (expected levels: perturbation, replicate)
        sys.argv[3] -- k-means cluster assignment for each training row
                       ("name" index, "clusters" column)

    Trains three classifiers (Gaussian naive Bayes, random forest, SVM) on
    the attractor landscape, predicts a cluster for every perturbation row,
    and writes the perturbation names that reach a cluster-0 consensus to
    'crit1perts.txt'.
    """
    # Stack every training attractor table row-wise into one frame.
    train_files = sys.argv[1].split(',')
    df_attr = pd.concat(
        [pd.read_csv(path, delim_whitespace=True, index_col=["name"])
         for path in train_files],
        axis=0,
    )

    # Perturbation simulations to classify.
    df_perturb = pd.read_csv(sys.argv[2], delim_whitespace=True, index_col=[0, 1])

    # k-means cluster labels for the training rows.
    df_labels = pd.read_csv(sys.argv[3], delim_whitespace=True, index_col=["name"])
    labels = df_labels['clusters'].tolist()

    def predict_clusters(clf):
        """Fit clf on the attractor landscape and label every perturbation."""
        clf.fit(df_attr, labels)
        out = pd.DataFrame(index=df_perturb.index)
        out['clusters'] = clf.predict(df_perturb)
        return out

    # Same three models and hyperparameters as the original pipeline.
    nb_label = predict_clusters(GaussianNB())
    rf_label = predict_clusters(RandomForestClassifier(n_estimators=100, random_state=1))
    svm_label = predict_clusters(SVC(gamma='auto', random_state=4))

    # Use the perturbation index itself (not a stale loop variable) to decide
    # whether we are in the multi-replicate (MultiIndex) layout.
    multilevel = df_perturb.index.nlevels > 1

    consensus = []
    for df in (svm_label, nb_label, rf_label):
        # Keep only cluster-0 assignments; everything else becomes NaN, so the
        # per-row non-null count is 1 iff the classifier put the row in cluster 0.
        df = df.where(df == 0, None)
        df['count'] = df['clusters'].notna().astype(int)
        if multilevel:
            hits = df.loc[df['count'] >= 1].index.get_level_values('perturbation')
        else:
            # BUG FIX: the original tested count > 1, which is unreachable
            # (count is at most 1 with a single 'clusters' column), so the
            # flat-index path always produced an empty file. Use >= 1 to
            # mirror the multi-index branch. NOTE(review): with
            # index_col=[0, 1] this branch may be dead in practice — confirm.
            hits = df.loc[df['count'] >= 1].index
        consensus.append(hits.to_list())

    # How many classifier/replicate votes each perturbation received.
    cencount = Counter(x for sublist in consensus for x in sublist)

    if multilevel:
        # Require cluster-0 agreement from at least 2 of the 3 classifiers
        # across every replicate (same threshold as the original: 2 * n_reps).
        n_reps = len(df_perturb.index.unique('replicate').tolist())
        keys = [k for k, v in cencount.items() if v >= n_reps * 2]
    else:
        keys = list(cencount)

    with open('crit1perts.txt', 'w') as f:
        for item in keys:
            f.write("%s\n" % item)
61 | |
# Guard the entry point so importing this module for reuse/testing does not
# immediately run the pipeline (the original called main() unconditionally).
if __name__ == "__main__":
    main()