1
|
1 # -*- coding: utf-8 -*-
|
|
2
|
|
3 import matplotlib
|
|
4 import matplotlib.pyplot as plt
|
|
5 import seaborn as sns
|
|
6
|
|
7 from .base import BaseGridPlot
|
|
8
|
|
9
|
|
class Heatmap(BaseGridPlot):
    """Annotated heatmap of a DataFrame drawn with ``seaborn.heatmap``.

    The heatmap is drawn into the axes registered under ``'heatmap'``
    (set by :meth:`_create_axes` to the same object as ``'base'``). Column
    tick labels are moved to the top and rotated by 90 degrees, row tick
    labels are kept horizontal, and a vertical colorbar is attached.
    """

    def __init__(self, df, *args, fmt='.3f',
                 cmap=None, vmin=0.0, vmax=1.0,
                 annot=True, **kwargs):
        """Create the plot and draw the heatmap.

        Parameters
        ----------
        df
            Data to plot; passed unchanged to ``seaborn.heatmap``
            (presumably a pandas DataFrame).
        *args, **kwargs
            Forwarded to ``BaseGridPlot.__init__``.
        fmt : str
            Format spec used for the cell annotations.
        cmap
            Colormap passed to seaborn; ``None`` lets seaborn choose.
        vmin, vmax : float
            Values anchoring the colormap range.
        annot : bool
            Whether to write each cell's value into the cell.
        """
        super().__init__(*args, **kwargs)

        # Set references for data objects
        self._df = df  # the plotted data
        ax = self._axes['heatmap']

        # Remember what is already on the axes so that only the artists
        # added by seaborn below are picked up afterwards.
        n_collections_before = len(ax.collections)
        n_texts_before = len(ax.texts)

        sns.heatmap(self._df,
                    ax=ax,
                    annot=annot,
                    fmt=fmt,
                    annot_kws={"size": 10},
                    linecolor=self._colors['table_edge_color'],
                    vmin=vmin, vmax=vmax,
                    cmap=cmap,
                    cbar_kws={"orientation": "vertical",
                              "pad": 0.02, })

        # The first collection seaborn adds is the cell mesh (it draws via
        # pcolormesh); the colorbar is attached to it.
        self._qm = ax.collections[n_collections_before]
        self._style_colorbar(self._qm.colorbar)

        # Column labels on top and vertical, row labels horizontal. This is
        # applied to this axes explicitly: plt.xticks/plt.yticks would act on
        # whatever axes pyplot currently considers "current".
        ax.xaxis.tick_top()
        ax.tick_params(axis='x', which='major', labelrotation=90, pad=-2)
        ax.tick_params(axis='y', which='major', labelrotation=0, pad=3)

        # Hide axis labels
        ax.set_xlabel('')
        ax.set_ylabel('')

        # Cell annotation Text artists added by seaborn (empty ones skipped).
        self._texts = [text for text in ax.texts[n_texts_before:]
                       if text.get_text() != '']

        # Set default values using properties
        self.row_tick_fontsize = 10
        self.column_tick_fontsize = 10
        # NOTE(review): no property of this name exists in this class; the
        # assignment is kept for backward compatibility in case callers or
        # BaseGridPlot read it. The effective setting is the next line.
        self.colorbar_label_fontsize = 10
        self.colorbar_tick_fontsize = 10
        self.text_fontsize = 10  # matches annot_kws size above
        self.linewidth = 0.5

    # end of __init__

    def _style_colorbar(self, cb):
        """Thin and recolour the colorbar outline, stretch it, hide dividers."""
        cb.outline.set_linewidth(0.5)
        cb.outline.set_edgecolor(self._colors['colorbar_edge_color'])
        self._scale_colorbar_aspect(cb, 2.0)

        # Zero the width of every LineCollection inside the colorbar axes
        # (presumably the divider lines between colour levels).
        for child in cb.ax.get_children():
            if isinstance(child, matplotlib.collections.LineCollection):
                child.set_linewidth(0)

    @staticmethod
    def _scale_colorbar_aspect(cb, factor):
        """Multiply the colorbar axes' aspect ratio by ``factor``.

        Older matplotlib releases store the colorbar shape as a numeric data
        aspect, so ``get_aspect()`` returns a number. Newer releases store it
        as a box aspect and leave ``get_aspect()`` at the string ``'auto'``,
        which cannot be multiplied by a float.
        """
        cax = cb.ax
        aspect = cax.get_aspect()
        if not isinstance(aspect, str):
            cax.set_aspect(aspect * factor)
            return

        get_box_aspect = getattr(cax, 'get_box_aspect', None)
        box_aspect = get_box_aspect() if get_box_aspect is not None else None
        if box_aspect is not None:
            cax.set_box_aspect(box_aspect * factor)

    def _set_colors(self, colors):
        """Assign default color values for the heatmap and colorbar.

        ``colors`` is not read here; only defaults are registered through
        ``_set_default_color``.
        """
        self._set_default_color('table_edge_color', 'black')
        self._set_default_color('colorbar_edge_color', 'black')

    def _create_axes(self):
        """Create the base axes and register them under ``'heatmap'`` too."""
        super()._create_axes()
        self._axes['heatmap'] = self._axes['base']

    # Properties
    @property
    def text_fontsize(self):
        """Font size of the cell annotation texts."""
        return self._text_fontsize

    @text_fontsize.setter
    def text_fontsize(self, val):
        """Resize the cell annotation fonts."""
        self._text_fontsize = val
        for text in self._texts:
            text.set_fontsize(val)

    @property
    def column_tick_fontsize(self):
        """Font size of the column (x-axis) tick labels."""
        return self._column_tick_fontsize

    @column_tick_fontsize.setter
    def column_tick_fontsize(self, val):
        self._column_tick_fontsize = val
        self._axes['heatmap'].tick_params(
            axis='x',
            which='major',
            labelsize=self._column_tick_fontsize)

    @property
    def row_tick_fontsize(self):
        """Font size of the row (y-axis) tick labels."""
        return self._row_tick_fontsize

    @row_tick_fontsize.setter
    def row_tick_fontsize(self, val):
        self._row_tick_fontsize = val
        self._axes['heatmap'].tick_params(
            axis='y',
            which='major',
            labelsize=self._row_tick_fontsize)

    @property
    def colorbar_tick_fontsize(self):
        """Font size of the colorbar's y-axis tick labels."""
        return self._colorbar_tick_fontsize

    @colorbar_tick_fontsize.setter
    def colorbar_tick_fontsize(self, val):
        self._colorbar_tick_fontsize = val
        self._qm.colorbar.ax.tick_params(axis='y', labelsize=val)

    @property
    def linewidth(self):
        """Width of the lines drawn between heatmap cells."""
        return self._linewidth

    @linewidth.setter
    def linewidth(self, val):
        """Adjust the width of table lines."""
        self._linewidth = val
        self._qm.set_linewidth(self._linewidth)

    # Read-only properties
    @property
    def colorbar(self):
        """The colorbar attached to the heatmap mesh."""
        return self._qm.colorbar