Mercurial > repos > laurenmarazzi > netisce_test
diff tools/myTools/bin/class_and_consensus.py @ 1:7e5c71b2e71f draft default tip
Uploaded
author | laurenmarazzi |
---|---|
date | Wed, 22 Dec 2021 16:00:34 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/tools/myTools/bin/class_and_consensus.py Wed Dec 22 16:00:34 2021 +0000 @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +import pandas as pd +from sklearn.naive_bayes import GaussianNB +from sklearn.ensemble import RandomForestClassifier +from sklearn.svm import SVC +from collections import Counter +import sys +def main(): + train=sys.argv[1].split(',') #train attractors + df_attr=pd.DataFrame() + for j in train: + dfj=pd.read_csv(j, delim_whitespace=True,index_col = ["name"]) + df_attr=pd.concat([df_attr,dfj],axis=0) + #df_attr=df_attr.drop_duplicates() + df_perturb=pd.read_csv(sys.argv[2], delim_whitespace=True,index_col=[0,1]) + + df_labels=pd.read_csv(sys.argv[3], delim_whitespace=True,index_col = ["name"]) # kmeans results + labels=df_labels['clusters'].tolist() + + gnb = GaussianNB() + gnb.fit(df_attr,labels) #do knn with attractor landscape + perturb_lab=gnb.predict(df_perturb) # predict clusters for perturbations + #create dataframe of perturabtion and clusters + NB_label=pd.DataFrame(index=df_perturb.index) + NB_label['clusters']=perturb_lab + + regressor = RandomForestClassifier(n_estimators=100, random_state=1) + regressor.fit(df_attr,labels) #do knn with attractor landscape + perturb_lab=regressor.predict(df_perturb) # predict clusters for perturbations + #create dataframe of perturabtion and clusters + RF_label=pd.DataFrame(index=df_perturb.index) + RF_label['clusters']=perturb_lab + + + svm=SVC(gamma='auto',random_state=4) #intialize knn + svm.fit(df_attr,labels) #do knn with attractor landscape + perturb_lab=svm.predict(df_perturb) # predict clusters for perturbations + #create dataframe of perturabtion and clusters + SVM_label=pd.DataFrame(index=df_perturb.index) + SVM_label['clusters']=perturb_lab + + consensus=[] + + for df in (SVM_label,NB_label,RF_label): + df=df.where(df==0, None) + df['count']=df.apply(lambda x: x.count(), axis=1) + if df.index.nlevels>1: + consensus.append(df.loc[df['count']>=1].index.get_level_values('perturbation').to_list()) + else: + consensus.append(df.loc[df['count']>1].index.to_list()) + cencount=Counter(x for sublist in consensus for x in sublist) + + if df.index.nlevels>1: + keys = [k for k, v in cencount.items() if v >=len(df.index.unique('replicate').tolist())*2] + else: + keys = [k for k, v in cencount.items()] + + with open('crit1perts.txt','w') as f: + for item in keys: + f.write("%s\n" % item) + +main() \ No newline at end of file