Mercurial > repos > laurenmarazzi > netisce_test
diff tools/myTools/bin/sfa/plot/table_hierarchical_clustering.py @ 1:7e5c71b2e71f draft default tip
Uploaded
author | laurenmarazzi |
---|---|
date | Wed, 22 Dec 2021 16:00:34 +0000 |
parents | |
children |
line wrap: on
line diff
# Provenance: added as a new file (diff against /dev/null) at
# b/tools/myTools/bin/sfa/plot/table_hierarchical_clustering.py,
# Wed Dec 22 16:00:34 2021 +0000.

import numpy as np
import scipy.spatial.distance as distance
import scipy.cluster.hierarchy as sch

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import seaborn as sns

from .table_condition import ConditionTable


class HierarchicalClusteringTable(ConditionTable):
    """Condition table drawn next to a hierarchically clustered heatmap.

    The figure is made of up to five axes, stored in ``self._axes`` under
    the keys ``'condition'``, ``'heatmap'``, ``'row_dendrogram'``,
    ``'col_dendrogram'`` and (created separately) the colorbar axes.

    NOTE(review): ``self._fig``, ``self._gridspec``, ``self._dfc`` and
    ``self.tables``/``self._tables`` are not defined in this module; they
    are presumably provided by ``ConditionTable`` -- confirm there.
    """

    def __init__(self, conds, samples, *args, **kwargs):
        """Build the table, the clustered heatmap and the colorbar.

        Parameters
        ----------
        conds
            Condition data, forwarded unchanged to ``ConditionTable``.
        samples
            DataFrame of samples to be clustered and shown as a heatmap.
        *args, **kwargs
            Forwarded to ``ConditionTable``; see ``_parse_kwargs`` for the
            keyword options read by this class.
        """
        # The samples must be stored before the base initializer runs,
        # because the hooks in this class read ``self._dfs``.
        self._dfs = samples
        super().__init__(conds, *args, **kwargs)
        self._create_colorbar()
        self.column_tick_fontsize = self._table_tick_fontsize

    def _parse_kwargs(self, **kwargs):
        """Parse the keyword arguments.

        Recognized keys (with defaults): ``vmin``/``vmax`` (None),
        ``cmap`` (red-white-green blend), ``dim`` ((2, 5)),
        ``wspace``/``hspace`` (0.005), ``width_ratios``,
        ``height_ratios``, ``axes_position``, ``row_cluster``/
        ``col_cluster`` (True), ``row_method``/``col_method``
        ('single'), ``row_metric``/``col_metric`` ('cityblock'),
        ``row_dend_linewidth``/``col_dend_linewidth`` (0.5),
        ``table_linewidth`` (0.5), ``table_tick_fontsize`` (5) and
        ``colorbar_tick_fontsize`` (5).
        """
        self._vmin = kwargs.get('vmin', None)
        self._vmax = kwargs.get('vmax', None)

        # Default colormap: red -> white -> green (RGBA [0, 1, 0, 1]).
        colors_blend = ['red', 'white', np.array([0, 1, 0, 1])]
        default_cmap = sns.blend_palette(colors_blend,
                                         n_colors=100,
                                         as_cmap=True)
        self._cmap = kwargs.get('cmap', default_cmap)

        self._dim = kwargs.get('dim', (2, 5))
        self._wspace = kwargs.get('wspace', 0.005)
        self._hspace = kwargs.get('hspace', 0.005)

        # Column widths: condition table, spacer, heatmap,
        # row dendrogram, colorbar.
        default_width_ratios = [self._dfc.shape[1],
                                0.25,
                                self._dfs.shape[1],
                                0.5*self._dfs.shape[1],
                                0.05*self._dfs.shape[1]]

        # Row heights: column dendrogram on top, tables below.
        default_height_ratios = [self._dfs.shape[1],
                                 self._dfc.shape[0]]

        self._width_ratios = kwargs.get('width_ratios',
                                        default_width_ratios)
        self._height_ratios = kwargs.get('height_ratios',
                                         default_height_ratios)

        # (row, column) grid cell of each axes.
        default_position = {'condition': np.array([1, 0]),
                            'heatmap': np.array([1, 2]),
                            'row_dendrogram': np.array([1, 3]),
                            'col_dendrogram': np.array([0, 2]),
                            'colorbar': np.array([1, 4])}

        self._axes_position = kwargs.get('axes_position',
                                         default_position)

        self._row_cluster = kwargs.get('row_cluster', True)
        self._col_cluster = kwargs.get('col_cluster', True)

        # Clustering options are only stored for the enabled axes.
        if self._row_cluster:
            self._row_method = kwargs.get('row_method', 'single')
            self._row_metric = kwargs.get('row_metric', 'cityblock')
            self._row_dend_linewidth = kwargs.get('row_dend_linewidth',
                                                  0.5)

        if self._col_cluster:
            self._col_method = kwargs.get('col_method', 'single')
            self._col_metric = kwargs.get('col_metric', 'cityblock')
            self._col_dend_linewidth = kwargs.get('col_dend_linewidth',
                                                  0.5)

        self._table_linewidth = kwargs.get('table_linewidth', 0.5)
        self._table_tick_fontsize = kwargs.get('table_tick_fontsize', 5)
        self._colorbar_tick_fontsize = kwargs.get('colorbar_tick_fontsize',
                                                  5)

    def _create_axes(self):
        """Create the condition axes, then the clustering axes."""
        self._axes = {}

        pos = self._axes_position['condition']
        ax_conds = self._fig.add_subplot(self._gridspec[pos[0], pos[1]])
        self._axes['condition'] = ax_conds
        # Positional argument: the old ``b=`` keyword was renamed to
        # ``visible=`` in Matplotlib 3.5 and is rejected by newer releases.
        ax_conds.grid(False)
        ax_conds.set_frame_on(False)
        ax_conds.invert_yaxis()
        ax_conds.xaxis.tick_bottom()

        self._perform_clustering()

    def _create_tables(self):
        """Create the tables and draw cell border lines on the heatmap."""
        super()._create_tables()
        ax_heatmap = self._axes['heatmap']

        # Draw lines on table and heatmap
        self.tables[0].linewidth = self._table_linewidth
        n_rows, n_cols = self._dfs.shape
        # Cells are centered on integer coordinates, hence the -0.5 shift.
        for x in range(n_cols + 1):
            ax_heatmap.axvline(x - 0.5,
                               linewidth=self._table_linewidth,
                               color='k', zorder=10)

        for y in range(n_rows + 1):
            ax_heatmap.axhline(y - 0.5,
                               linewidth=self._table_linewidth,
                               color='k', zorder=10)

    def _draw_dendrogram(self, name, data, method, metric, linewidth,
                         orientation):
        """Cluster the rows of ``data`` and draw the tree on axes ``name``.

        Parameters
        ----------
        name
            Key into ``self._axes_position``/``self._axes``.
        data
            Observations to cluster, one per row.
        method, metric
            Linkage method and distance metric.
        linewidth
            Line width of the dendrogram links.
        orientation
            Orientation passed to ``scipy.cluster.hierarchy.dendrogram``.

        Returns
        -------
        list
            Leaf order of the dendrogram (indices into the rows of
            ``data``).
        """
        pairwise_dists = distance.pdist(data, metric=metric)
        clusters = sch.linkage(pairwise_dists,
                               metric=metric,
                               method=method)

        with plt.rc_context({'lines.linewidth': linewidth}):
            pos = self._axes_position[name]
            ax_den = self._fig.add_subplot(self._gridspec[pos[0], pos[1]])
            # Draw on the new axes explicitly instead of relying on
            # pyplot's "current axes", and color the links per call
            # rather than by mutating SciPy's global color palette.
            den = sch.dendrogram(clusters,
                                 color_threshold=np.inf,
                                 orientation=orientation,
                                 link_color_func=lambda k: 'black',
                                 ax=ax_den)

        ax_den.set_facecolor("white")
        self._clean_axis(ax_den)
        if orientation in ('left', 'right'):
            # SciPy lays the leaves of a horizontal dendrogram out from
            # bottom to top, whereas the heatmap (matshow) puts its first
            # row at the top; flip the tree so both use the same order.
            ax_den.invert_yaxis()
        self._axes[name] = ax_den

        return den['leaves']

    def _perform_clustering(self):
        """Cluster rows/columns, draw dendrograms and the heatmap."""
        if self._row_cluster:
            ind_row = self._draw_dendrogram(
                'row_dendrogram',
                self._dfs,
                method=self._row_method,
                metric=self._row_metric,
                linewidth=self._row_dend_linewidth,
                orientation='right')
            # Rearrange the DataFrame for condition according to
            # the clustering result.
            self._dfc = self._dfc.iloc[ind_row, :]
        else:
            ind_row = list(range(self._dfs.index.size))

        if self._col_cluster:
            ind_col = self._draw_dendrogram(
                'col_dendrogram',
                self._dfs.T,
                method=self._col_method,
                metric=self._col_metric,
                linewidth=self._col_dend_linewidth,
                orientation='top')
        else:
            ind_col = list(range(self._dfs.columns.size))

        # Heatmap
        pos = self._axes_position['heatmap']
        subgs = self._gridspec[pos[0], pos[1]]
        ax_heatmap = self._fig.add_subplot(subgs)
        self._heatmap = ax_heatmap.matshow(self._dfs.iloc[ind_row, ind_col],
                                           vmin=self._vmin,
                                           vmax=self._vmax,
                                           interpolation='nearest',
                                           aspect='auto',
                                           cmap=self._cmap)

        # Positional argument for compatibility across Matplotlib versions
        # (see ``_create_axes``).
        ax_heatmap.grid(False)
        ax_heatmap.set_frame_on(True)
        ax_heatmap.xaxis.tick_bottom()
        self._axes['heatmap'] = ax_heatmap
        self._clean_axis(ax_heatmap)

        # Add column labels (in clustered order) below the heatmap.
        ax_heatmap.set_xticks(np.arange(0, self._dfs.shape[1], 1))
        ax_heatmap.set_xticklabels(np.array(self._dfs.columns[ind_col]),
                                   rotation=90, minor=False)

        ax_heatmap.tick_params(axis='x', which='major', pad=-2)

        # Remove the tick lines
        for line in ax_heatmap.get_xticklines():
            line.set_markersize(0)

        for line in ax_heatmap.get_yticklines():
            line.set_markersize(0)

    def _create_colorbar(self):
        """Draw the colorbar of the heatmap in its own grid cell."""
        pos = self._axes_position['colorbar']
        subgs = self._gridspec[pos[0], pos[1]]
        ax_colorbar = self._fig.add_subplot(subgs)

        cb = self._fig.colorbar(self._heatmap, cax=ax_colorbar,
                                drawedges=False)
        self._colorbar = cb
        cb.ax.yaxis.set_ticks_position('none')
        cb.ax.yaxis.set_tick_params(pad=-2)
        cb.ax.yaxis.set_label_position('right')
        cb.outline.set_edgecolor('black')
        cb.outline.set_linewidth(self._table_linewidth)
        # Go through the property setter so the tick labels are resized.
        self.colorbar_fontsize = self._colorbar_tick_fontsize

    def _clean_axis(self, ax):
        """Remove ticks, tick labels, and frame from axis
        """
        ax.xaxis.set_ticks_position('none')
        ax.yaxis.set_ticks_position('none')
        ax.xaxis.set_ticks([])
        ax.yaxis.set_ticks([])
        for sp in ax.spines.values():
            sp.set_visible(False)

    def _add_labels(self):
        """Add only column labels for condition table.
        """
        tb = self._tables[0]
        tb.add_column_labels()

    @property
    def table_linewidth(self):
        """Line width stored for the table and heatmap cell borders."""
        return self._table_linewidth

    @table_linewidth.setter
    def table_linewidth(self, val):
        # Only the stored value changes; lines already drawn are not
        # updated here.
        self._table_linewidth = val

    @property
    def colorbar(self):
        """Colorbar object created by ``_create_colorbar``."""
        return self._colorbar

    @property
    def colorbar_fontsize(self):
        """Font size of the colorbar tick labels."""
        return self._colorbar_tick_fontsize

    @colorbar_fontsize.setter
    def colorbar_fontsize(self, val):
        self._colorbar_tick_fontsize = val
        ticks = self._colorbar.ax.yaxis.get_ticklabels()
        for t in ticks:
            t.set_fontsize(self._colorbar_tick_fontsize)