Mercurial > repos > laurenmarazzi > netisce_test
comparison tools/myTools/bin/sfa/plot/table_hierarchical_clustering.py @ 1:7e5c71b2e71f draft default tip
Uploaded
author | laurenmarazzi |
---|---|
date | Wed, 22 Dec 2021 16:00:34 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
0:f24d4892aaed | 1:7e5c71b2e71f |
---|---|
1 | |
2 import numpy as np | |
3 import scipy.spatial.distance as distance | |
4 import scipy.cluster.hierarchy as sch | |
5 | |
6 import matplotlib | |
7 import matplotlib.pyplot as plt | |
8 import matplotlib.gridspec as gridspec | |
9 | |
10 import seaborn as sns | |
11 | |
12 from .table_condition import ConditionTable | |
13 | |
14 | |
class HierarchicalClusteringTable(ConditionTable):
    """Condition table drawn next to a hierarchically clustered heatmap.

    The figure is laid out on a grid (see ``_parse_kwargs``) holding the
    condition table, a heatmap of ``samples``, an optional row dendrogram,
    an optional column dendrogram and a colorbar.  When row clustering is
    enabled, the rows of the condition table (``self._dfc``) are reordered
    to follow the row dendrogram.
    """

    def __init__(self, conds, samples, *args, **kwargs):
        """Create the clustered table.

        Parameters
        ----------
        conds
            Condition data, forwarded to ``ConditionTable.__init__``.
        samples : pandas.DataFrame
            Samples to be clustered and drawn as a heatmap.
        *args, **kwargs
            Forwarded to ``ConditionTable.__init__`` (the keyword arguments
            are presumably passed on to ``_parse_kwargs`` by the base class
            -- TODO confirm).
        """
        # ``_dfs`` must be set before the base initializer runs, because
        # ``_parse_kwargs`` and ``_perform_clustering`` read it.
        # NOTE(review): assumes the base class invokes those hooks.
        self._dfs = samples
        super().__init__(conds, *args, **kwargs)
        # The colorbar is built from ``self._heatmap``, which is created
        # during clustering, so it must come after the base initializer.
        self._create_colorbar()
        self.column_tick_fontsize = self._table_tick_fontsize

    def _parse_kwargs(self, **kwargs):
        """Parse the keyword arguments into layout and styling attributes.

        Recognized keys: ``vmin``, ``vmax``, ``cmap``, ``dim``, ``wspace``,
        ``hspace``, ``width_ratios``, ``height_ratios``, ``axes_position``,
        ``row_cluster``, ``col_cluster``, ``row_method``, ``row_metric``,
        ``row_dend_linewidth``, ``col_method``, ``col_metric``,
        ``col_dend_linewidth``, ``table_linewidth``, ``table_tick_fontsize``
        and ``colorbar_tick_fontsize``.
        """
        self._vmin = kwargs.get('vmin', None)
        self._vmax = kwargs.get('vmax', None)

        # Default colormap: red -> white -> green (RGBA [0, 1, 0, 1]).
        colors_blend = ['red', 'white', np.array([0, 1, 0, 1])]
        default_cmap = sns.blend_palette(colors_blend,
                                         n_colors=100,
                                         as_cmap=True)
        self._cmap = kwargs.get('cmap', default_cmap)

        # Presumably the (rows, cols) of the layout grid; the defaults
        # below provide 2 height ratios and 5 width ratios to match.
        self._dim = kwargs.get('dim', (2, 5))
        self._wspace = kwargs.get('wspace', 0.005)
        self._hspace = kwargs.get('hspace', 0.005)

        # Column widths: condition table, spacer, heatmap, row dendrogram,
        # colorbar (see ``default_position`` below).
        default_width_ratios = [self._dfc.shape[1],
                                0.25,
                                self._dfs.shape[1],
                                0.5*self._dfs.shape[1],
                                0.05*self._dfs.shape[1]]

        # Row heights: column dendrogram, then the table/heatmap row.
        default_height_ratios = [self._dfs.shape[1],
                                 self._dfc.shape[0]]

        self._width_ratios = kwargs.get('width_ratios',
                                        default_width_ratios)
        self._height_ratios = kwargs.get('height_ratios',
                                         default_height_ratios)

        # (row, column) grid cell of each sub-axes.
        default_position = {'condition': np.array([1, 0]),
                            'heatmap': np.array([1, 2]),
                            'row_dendrogram': np.array([1, 3]),
                            'col_dendrogram': np.array([0, 2]),
                            'colorbar': np.array([1, 4])}

        self._axes_position = kwargs.get('axes_position',
                                         default_position)

        self._row_cluster = kwargs.get('row_cluster', True)
        self._col_cluster = kwargs.get('col_cluster', True)

        # Clustering options are only read when that axis is clustered.
        if self._row_cluster:
            self._row_method = kwargs.get('row_method', 'single')
            self._row_metric = kwargs.get('row_metric', 'cityblock')
            self._row_dend_linewidth = kwargs.get('row_dend_linewidth', 0.5)

        if self._col_cluster:
            self._col_method = kwargs.get('col_method', 'single')
            self._col_metric = kwargs.get('col_metric', 'cityblock')
            self._col_dend_linewidth = kwargs.get('col_dend_linewidth', 0.5)

        self._table_linewidth = kwargs.get('table_linewidth', 0.5)
        self._table_tick_fontsize = kwargs.get('table_tick_fontsize', 5)
        self._colorbar_tick_fontsize = kwargs.get('colorbar_tick_fontsize', 5)

    def _create_axes(self):
        """Create the condition-table axes, then cluster and draw the rest.

        Populates ``self._axes`` with the ``'condition'`` axes here and the
        dendrogram/heatmap axes via ``_perform_clustering``.
        """
        self._axes = {}

        pos = self._axes_position['condition']
        ax_conds = self._fig.add_subplot(self._gridspec[pos[0], pos[1]])
        self._axes['condition'] = ax_conds
        # Pass the flag positionally: the ``b=`` keyword was deprecated in
        # Matplotlib 3.5 and later removed.
        ax_conds.grid(False)
        ax_conds.set_frame_on(False)
        ax_conds.invert_yaxis()
        ax_conds.xaxis.tick_bottom()

        self._perform_clustering()

    def _create_tables(self):
        """Create the condition tables and draw cell borders on the heatmap."""
        super()._create_tables()
        ax_heatmap = self._axes['heatmap']

        # Draw lines on table and heatmap; matshow centres cells on
        # integer coordinates, so borders sit at half-integers.
        self.tables[0].linewidth = self._table_linewidth
        for x in range(self._dfs.shape[1]+1):
            ax_heatmap.axvline(x-0.5,
                               linewidth=self._table_linewidth,
                               color='k', zorder=10)

        for y in range(self._dfs.shape[0]+1):
            ax_heatmap.axhline(y-0.5,
                               linewidth=self._table_linewidth,
                               color='k', zorder=10)

    def _cluster_and_draw_dendrogram(self, data, metric, method, linewidth,
                                     axes_key, orientation):
        """Cluster the rows of ``data`` and draw the dendrogram.

        The dendrogram is drawn on a new axes placed at
        ``self._axes_position[axes_key]`` and stored in
        ``self._axes[axes_key]``.

        Returns
        -------
        list of int
            Row indices of ``data`` in dendrogram leaf order.
        """
        pairwise_dists = distance.pdist(data, metric=metric)
        clusters = sch.linkage(pairwise_dists,
                               metric=metric,
                               method=method)

        # The dendrogram's line collection takes its width from
        # ``lines.linewidth`` at creation time.
        with plt.rc_context({'lines.linewidth': linewidth}):
            pos = self._axes_position[axes_key]
            ax_den = self._fig.add_subplot(self._gridspec[pos[0], pos[1]])
            # Draw on ``ax_den`` explicitly instead of pyplot's current
            # axes, which may belong to a different figure.
            den = sch.dendrogram(clusters,
                                 color_threshold=np.inf,
                                 orientation=orientation,
                                 ax=ax_den)

        ax_den.set_facecolor("white")
        self._clean_axis(ax_den)
        self._axes[axes_key] = ax_den
        return den['leaves']

    def _perform_clustering(self):
        """Cluster the samples and draw the dendrograms and the heatmap.

        Rows and/or columns are reordered by their dendrogram leaf order
        when the corresponding clustering is enabled.
        """
        # NOTE: this changes SciPy's module-global link colour palette.
        sch.set_link_color_palette(['black'])

        if self._row_cluster:
            ind_row = self._cluster_and_draw_dendrogram(
                self._dfs,
                self._row_metric,
                self._row_method,
                self._row_dend_linewidth,
                'row_dendrogram',
                'right')
            # Reorder the condition rows to match the row dendrogram.
            self._dfc = self._dfc.iloc[ind_row, :]
        else:
            ind_row = list(range(self._dfs.shape[0]))

        if self._col_cluster:
            # Cluster columns by clustering the rows of the transpose.
            ind_col = self._cluster_and_draw_dendrogram(
                self._dfs.T,
                self._col_metric,
                self._col_method,
                self._col_dend_linewidth,
                'col_dendrogram',
                'top')
        else:
            ind_col = list(range(self._dfs.shape[1]))

        # Heatmap of the samples in dendrogram order.
        pos = self._axes_position['heatmap']
        ax_heatmap = self._fig.add_subplot(self._gridspec[pos[0], pos[1]])
        self._heatmap = ax_heatmap.matshow(self._dfs.iloc[ind_row, ind_col],
                                           vmin=self._vmin,
                                           vmax=self._vmax,
                                           interpolation='nearest',
                                           aspect='auto',
                                           cmap=self._cmap)

        ax_heatmap.grid(False)
        ax_heatmap.set_frame_on(True)
        # matshow puts x ticks on top; move them below the heatmap.
        ax_heatmap.xaxis.tick_bottom()
        self._axes['heatmap'] = ax_heatmap
        self._clean_axis(ax_heatmap)

        # Column labels, in the (possibly reordered) column order.
        ax_heatmap.set_xticks(np.arange(0, self._dfs.shape[1], 1))
        ax_heatmap.set_xticklabels(np.array(self._dfs.columns[ind_col]),
                                   rotation=90, minor=False)
        ax_heatmap.tick_params(axis='x', which='major', pad=-2)

        # Hide the tick marks but keep the labels.
        for line in ax_heatmap.get_xticklines():
            line.set_markersize(0)

        for line in ax_heatmap.get_yticklines():
            line.set_markersize(0)

    def _create_colorbar(self):
        """Draw the heatmap colorbar in its own grid cell."""
        pos = self._axes_position['colorbar']
        ax_colorbar = self._fig.add_subplot(self._gridspec[pos[0], pos[1]])

        # Pass ``cax`` by keyword so the axes is unambiguously used as the
        # colorbar's own axes rather than one to steal space from.
        cb = self._fig.colorbar(self._heatmap, cax=ax_colorbar,
                                drawedges=False)
        self._colorbar = cb
        cb.ax.yaxis.set_ticks_position('none')
        cb.ax.yaxis.set_tick_params(pad=-2)
        cb.ax.yaxis.set_label_position('right')
        cb.outline.set_edgecolor('black')
        cb.outline.set_linewidth(self._table_linewidth)
        # Go through the property setter so the tick labels are resized.
        self.colorbar_fontsize = self._colorbar_tick_fontsize

    def _clean_axis(self, ax):
        """Remove ticks, tick labels, and frame from axis."""
        ax.xaxis.set_ticks_position('none')
        ax.yaxis.set_ticks_position('none')
        ax.xaxis.set_ticks([])
        ax.yaxis.set_ticks([])
        for sp in ax.spines.values():
            sp.set_visible(False)

    def _add_labels(self):
        """Add only column labels for condition table."""
        tb = self._tables[0]
        tb.add_column_labels()

    @property
    def table_linewidth(self):
        """Line width used for table borders, heatmap grid and colorbar."""
        return self._table_linewidth

    @table_linewidth.setter
    def table_linewidth(self, val):
        self._table_linewidth = val

    @property
    def colorbar(self):
        """The ``Colorbar`` created for the heatmap."""
        return self._colorbar

    @property
    def colorbar_fontsize(self):
        """Font size of the colorbar tick labels."""
        return self._colorbar_tick_fontsize

    @colorbar_fontsize.setter
    def colorbar_fontsize(self, val):
        self._colorbar_tick_fontsize = val
        for t in self._colorbar.ax.yaxis.get_ticklabels():
            t.set_fontsize(self._colorbar_tick_fontsize)