comparison Marea/marea_cluster.py @ 33:abf0bfe01c78 draft

Uploaded
author bimib
date Wed, 16 Oct 2019 16:25:56 -0400
parents 944e15aa970a
children 1a97d1537623
comparison
equal deleted inserted replaced
32:b795e3e163e0 33:abf0bfe01c78
70 type = str, 70 type = str,
71 required = True, 71 required = True,
72 help = 'your tool directory') 72 help = 'your tool directory')
73 73
74 parser.add_argument('-ms', '--min_samples', 74 parser.add_argument('-ms', '--min_samples',
75 type = int, 75 type = float,
76 help = 'min samples for dbscan (optional)') 76 help = 'min samples for dbscan (optional)')
77 77
78 parser.add_argument('-ep', '--eps', 78 parser.add_argument('-ep', '--eps',
79 type = int, 79 type = float,
80 help = 'eps for dbscan (optional)') 80 help = 'eps for dbscan (optional)')
81 81
82 parser.add_argument('-bc', '--best_cluster', 82 parser.add_argument('-bc', '--best_cluster',
83 type = str, 83 type = str,
84 help = 'output of best cluster tsv') 84 help = 'output of best cluster tsv')
308 308
309 plt.savefig(path, bbox_inches='tight') 309 plt.savefig(path, bbox_inches='tight')
310 310
311 ######################## dbscan ############################################## 311 ######################## dbscan ##############################################
312 312
313 def dbscan(dataset, eps, min_samples): 313 def dbscan(dataset, eps, min_samples, best_cluster):
314 if not os.path.exists('clustering'): 314 if not os.path.exists('clustering'):
315 os.makedirs('clustering') 315 os.makedirs('clustering')
316 316
317 if eps is not None: 317 if eps is not None:
318 clusterer = DBSCAN(eps = eps, min_samples = min_samples) 318 clusterer = DBSCAN(eps = eps, min_samples = min_samples)
329 n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0) 329 n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
330 330
331 331
332 ##TODO: PLOT SU DBSCAN (no centers) e HIERARCHICAL 332 ##TODO: PLOT SU DBSCAN (no centers) e HIERARCHICAL
333 333
334 334 labels = labels
335 write_to_csv(dataset, labels, 'clustering/dbscan_results.tsv') 335 predict = [x+1 for x in labels]
336 classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str)
337 classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class'])
338
336 339
337 ########################## hierachical ####################################### 340 ########################## hierachical #######################################
338 341
339 def hierachical_agglomerative(dataset, k_min, k_max): 342 def hierachical_agglomerative(dataset, k_min, k_max, best_cluster):
340 343
341 if not os.path.exists('clustering'): 344 if not os.path.exists('clustering'):
342 os.makedirs('clustering') 345 os.makedirs('clustering')
343 346
344 plt.figure(figsize=(10, 7)) 347 plt.figure(figsize=(10, 7))
347 fig = plt.gcf() 350 fig = plt.gcf()
348 fig.savefig('clustering/dendogram.png', dpi=200) 351 fig.savefig('clustering/dendogram.png', dpi=200)
349 352
350 range_n_clusters = [i for i in range(k_min, k_max+1)] 353 range_n_clusters = [i for i in range(k_min, k_max+1)]
351 354
352 for n_clusters in range_n_clusters: 355 scores = []
353 356 labels = []
357 for n_clusters in range_n_clusters:
354 cluster = AgglomerativeClustering(n_clusters=n_clusters, affinity='euclidean', linkage='ward') 358 cluster = AgglomerativeClustering(n_clusters=n_clusters, affinity='euclidean', linkage='ward')
355 cluster.fit_predict(dataset) 359 cluster.fit_predict(dataset)
356 cluster_labels = cluster.labels_ 360 cluster_labels = cluster.labels_
357 361 labels.append(cluster_labels)
358 silhouette_avg = silhouette_score(dataset, cluster_labels) 362 silhouette_avg = silhouette_score(dataset, cluster_labels)
359 write_to_csv(dataset, cluster_labels, 'clustering/hierarchical_with_' + str(n_clusters) + '_clusters.tsv') 363 write_to_csv(dataset, cluster_labels, 'clustering/hierarchical_with_' + str(n_clusters) + '_clusters.tsv')
364 scores.append(silhouette_avg)
360 #warning("For n_clusters =", n_clusters, 365 #warning("For n_clusters =", n_clusters,
361 #"The average silhouette_score is :", silhouette_avg) 366 #"The average silhouette_score is :", silhouette_avg)
367
368 best = max_index(scores) + k_min
369
370 for i in range(len(labels)):
371 if (i + k_min == best):
372 labels = labels[i]
373 predict = [x+1 for x in labels]
374 classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str)
375 classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class'])
376
362 377
363 378
364 379
365 380
366 381
388 403
389 if args.cluster_type == 'kmeans': 404 if args.cluster_type == 'kmeans':
390 kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.davies, args.best_cluster) 405 kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.davies, args.best_cluster)
391 406
392 if args.cluster_type == 'dbscan': 407 if args.cluster_type == 'dbscan':
393 dbscan(X, args.eps, args.min_samples) 408 dbscan(X, args.eps, args.min_samples, args.best_cluster)
394 409
395 if args.cluster_type == 'hierarchy': 410 if args.cluster_type == 'hierarchy':
396 hierachical_agglomerative(X, args.k_min, args.k_max) 411 hierachical_agglomerative(X, args.k_min, args.k_max, args.best_cluster)
397 412
398 ############################################################################## 413 ##############################################################################
399 414
400 if __name__ == "__main__": 415 if __name__ == "__main__":
401 main() 416 main()