Mercurial > repos > laurenmarazzi > netisce_test
comparison tools/myTools/bin/class_and_consensus.py @ 1:7e5c71b2e71f draft default tip
Uploaded
| author | laurenmarazzi |
|---|---|
| date | Wed, 22 Dec 2021 16:00:34 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| 0:f24d4892aaed | 1:7e5c71b2e71f |
|---|---|
| 1 #!/usr/bin/env python3 | |
| 2 import pandas as pd | |
| 3 from sklearn.naive_bayes import GaussianNB | |
| 4 from sklearn.ensemble import RandomForestClassifier | |
| 5 from sklearn.svm import SVC | |
| 6 from collections import Counter | |
| 7 import sys | |
def main():
    """Classify perturbation simulations into attractor clusters and write a consensus list.

    Command-line arguments:
        sys.argv[1]: comma-separated list of whitespace-delimited attractor
                     files, each indexed by a "name" column (training data).
        sys.argv[2]: whitespace-delimited perturbation file indexed by its
                     first two columns (presumably perturbation, replicate
                     when multi-indexed -- TODO confirm against the caller).
        sys.argv[3]: whitespace-delimited k-means result file with a "name"
                     index and a "clusters" column (training labels).

    Side effect: writes 'crit1perts.txt', one perturbation name per line,
    for perturbations that the classifiers assign to cluster 0 (with a
    2-of-3 consensus requirement per replicate when the perturbation index
    has a 'replicate' level).
    """
    # NOTE(review): delim_whitespace is deprecated in pandas >= 2.1;
    # sep=r"\s+" is the forward-compatible equivalent.
    train_files = sys.argv[1].split(',')  # train attractor files
    frames = [
        pd.read_csv(path, delim_whitespace=True, index_col=["name"])
        for path in train_files
    ]
    # Build the attractor landscape once instead of concatenating into an
    # initially-empty frame inside the loop.
    df_attr = pd.concat(frames, axis=0)

    df_perturb = pd.read_csv(sys.argv[2], delim_whitespace=True, index_col=[0, 1])

    # k-means clustering results supply the training labels.
    df_labels = pd.read_csv(sys.argv[3], delim_whitespace=True, index_col=["name"])
    labels = df_labels['clusters'].tolist()

    # Three independent classifiers, identical fit/predict pipeline.
    nb_label = _predict_clusters(GaussianNB(), df_attr, labels, df_perturb)
    rf_label = _predict_clusters(
        RandomForestClassifier(n_estimators=100, random_state=1),
        df_attr, labels, df_perturb)
    svm_label = _predict_clusters(
        SVC(gamma='auto', random_state=4),
        df_attr, labels, df_perturb)

    consensus = []
    for preds in (svm_label, nb_label, rf_label):
        # Keep only cluster-0 assignments; everything else becomes None, so
        # the per-row non-null count is 1 for cluster 0 and 0 otherwise.
        preds = preds.where(preds == 0, None)
        preds['count'] = preds.apply(lambda row: row.count(), axis=1)
        if preds.index.nlevels > 1:
            consensus.append(
                preds.loc[preds['count'] >= 1]
                .index.get_level_values('perturbation').to_list())
        else:
            # BUGFIX: the original used preds['count'] > 1 here, but with a
            # single 'clusters' column the per-row count is at most 1, so
            # this branch could never select anything and the output file
            # came out empty. >= 1 matches the multi-index branch above.
            consensus.append(preds.loc[preds['count'] >= 1].index.to_list())

    # How many classifier/replicate votes each perturbation received.
    cencount = Counter(name for sublist in consensus for name in sublist)

    # The original tested the stale loop variable's index here; it is a copy
    # of df_perturb's index, so test df_perturb.index explicitly.
    if df_perturb.index.nlevels > 1:
        # Require at least 2 of the 3 classifiers to agree for every replicate.
        n_replicates = len(df_perturb.index.unique('replicate'))
        keys = [k for k, v in cencount.items() if v >= n_replicates * 2]
    else:
        keys = list(cencount)

    with open('crit1perts.txt', 'w') as f:
        f.writelines("%s\n" % item for item in keys)


def _predict_clusters(model, df_attr, labels, df_perturb):
    """Fit *model* on the attractor landscape and return a one-column
    DataFrame ('clusters', indexed like df_perturb) of predicted clusters
    for the perturbations."""
    model.fit(df_attr, labels)
    out = pd.DataFrame(index=df_perturb.index)
    out['clusters'] = model.predict(df_perturb)
    return out


if __name__ == "__main__":
    main()
