1
|
1
|
|
2 from collections import Counter
|
|
3
|
|
4 import numpy as np
|
|
5 import pandas as pd
|
|
6 import matplotlib.gridspec as gridspec
|
|
7 import matplotlib.pyplot as plt
|
|
8 from matplotlib.patches import Rectangle
|
|
9 from matplotlib.ticker import FormatStrFormatter
|
|
10 from matplotlib import rcParams
|
|
11
|
|
# Module-wide font configuration: render figure text with the sans-serif
# family and list Arial as the sans-serif face to use.
rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial'],
})
|
|
14
|
|
def _zero_aware_formatter(fmt):
    """Return a tick formatter that applies *fmt* but prints zero as '0'."""
    # Imported locally so this helper needs no new file-level import.
    from matplotlib.ticker import FuncFormatter

    def _format(value, pos=None):
        text = fmt % value
        try:
            is_zero = float(text) == 0
        except ValueError:
            # The format embeds non-numeric text; keep the label verbatim.
            return text
        return '0' if is_zero else text

    return FuncFormatter(_format)


def siplot(df_splo,
           df_inf,
           output,
           min_splo=None,
           max_splo=None,
           thr_inf=1e-10,
           fmt_inf='%f',
           fig=None,
           cnt_max=None,
           ncol=4,
           designated=None,
           color='silver',
           dcolor='red',
           zcolor='red',
           alpha=0.7,
           xfontsize=8,
           yfontsize=8):
    """Draw one horizontal bar chart of influences per SPLO value.

    The sources are grouped by their SPLO value; each group gets its own
    subplot (laid out on a grid with ``ncol`` columns) in which the
    influence of every source on ``output`` is drawn as a horizontal bar,
    sorted in ascending order of influence.

    Parameters
    ----------
    df_splo : pandas object indexed by source name
        SPLO value of each source.  It is filtered by the SPLO bounds and
        converted to a one-column frame named ``'SPLO'``.
        NOTE(review): presumably a ``pandas.Series`` -- confirm with callers.
    df_inf : pandas object indexed by source name
        Influence values; must expose a column labelled ``output``.
    output : hashable
        Column of ``df_inf`` to plot.  If it also appears in the index of
        ``df_splo`` that row is excluded.
    min_splo, max_splo : number, optional
        Inclusive SPLO bounds.  ``None`` means the minimum/maximum found in
        ``df_splo``.  An explicit ``0`` is honoured.
    thr_inf : float
        Sources whose absolute influence is not greater than this threshold
        (including NaN influences) are left out.
    fmt_inf : str
        ``%``-style format for the x tick labels.  Labels that evaluate to
        zero are shown as ``'0'``.
    fig : matplotlib.figure.Figure, optional
        Figure to draw into; a new one is created when ``None``.
    cnt_max : int, optional
        Number of bar slots per subplot.  Falsy values mean "size of the
        largest SPLO group".
    ncol : int
        Number of subplot columns.
    designated : str or collection, optional
        Source names to highlight.  A single string is treated as one name.
    color : color
        Bar color of the non-designated sources.  It is applied only when
        ``designated`` is non-empty; otherwise matplotlib's default bar
        color is kept.
    dcolor : color
        Bar and tick-label color of the designated sources.
    zcolor : color
        Color of the vertical zero line, drawn only in subplots holding
        both positive and negative influences.
    alpha : float
        Bar transparency.
    xfontsize, yfontsize : number
        Tick label sizes of the x and y axes.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn into.

    Raises
    ------
    ValueError
        If no source is left after filtering, or if ``cnt_max`` is smaller
        than the largest SPLO group.
    """
    # ----- SPLO-influence data -------------------------------------------
    # Compare against None: a bound of 0 is a legitimate request.
    if min_splo is None:
        min_splo = df_splo.min()
    if max_splo is None:
        max_splo = df_splo.max()

    mask_splo = (min_splo <= df_splo) & (df_splo <= max_splo)
    df_splo = pd.DataFrame(df_splo[mask_splo])
    df_splo.columns = ['SPLO']

    if output in df_splo.index:
        df_splo = df_splo.drop(output)

    index_common = df_splo.index.intersection(df_inf.index)
    df_inf = pd.DataFrame(df_inf.loc[index_common])

    # Keep only influences above the threshold.  NaN fails the comparison
    # and is dropped as well; a NaN row would otherwise make the padded
    # bar array and the data disagree in length further below.
    df_inf = df_inf[df_inf[output].abs() > thr_inf]

    df_si = df_inf.join(df_splo.loc[index_common])
    df_si = df_si.rename_axis('Source').reset_index()

    cnt_splo = Counter(df_si['SPLO'])
    if not cnt_splo:
        # Fail before a figure is created so no empty figure is left behind.
        raise ValueError("No source is left to plot after filtering "
                         "by SPLO range and influence threshold.")

    largest_group = max(cnt_splo.values())
    if not cnt_max:
        cnt_max = largest_group
    elif cnt_max < largest_group:
        raise ValueError("cnt_max (%d) is smaller than the largest "
                         "SPLO group (%d)." % (cnt_max, largest_group))

    splos = sorted(cnt_splo)
    nrow = int(np.ceil(len(splos) / ncol))

    # A lone string means a single name, not a set of substrings to match.
    if isinstance(designated, str):
        designated = [designated]

    # ----- Plot -----------------------------------------------------------
    if fig is None:
        fig = plt.figure()

    gs = gridspec.GridSpec(nrow, ncol)
    yvals = np.arange(1, cnt_max + 1)

    for i, splo in enumerate(splos):
        idx_row, idx_col = divmod(i, ncol)
        ax = fig.add_subplot(gs[idx_row, idx_col])

        df_sub = df_si[df_si['SPLO'] == splo].sort_values(by=output)
        num_empty = cnt_max - len(df_sub)

        # Pad at the bottom so every subplot has cnt_max slots and the
        # actual bars are stacked towards the top.
        influence = np.zeros((cnt_max,))
        influence[num_empty:] = df_sub[output].values
        names = [''] * num_empty + df_sub['Source'].tolist()

        # Draw on this axes explicitly; pyplot's "current axes" belongs to
        # the current figure, which need not be the `fig` passed in.
        bars = ax.barh(yvals, influence, align='center', alpha=alpha)

        ax.set_title('SPLO=%d' % (splo))
        ax.set_xlabel('')

        # One formatter per axes: a formatter instance is bound to its axis.
        ax.xaxis.set_major_formatter(_zero_aware_formatter(fmt_inf))
        ax.tick_params(axis='x',
                       which='major',
                       labelsize=xfontsize)

        ax.set_ylabel('')
        ax.yaxis.set_ticks_position('right')
        ax.tick_params(axis='y',
                       which='major',
                       labelsize=yfontsize)

        ax.set_yticks(yvals)
        ax.set_yticklabels(names)

        # Draw the zero line only when bars extend to both sides of it.
        if (influence > 0).any() and (influence < 0).any():
            ax.vlines(x=0.0, ymin=0, ymax=yvals[-1] + 1, color=zcolor)

        # Set limits.
        ax.set_ylim(0, cnt_max + 1)

        if designated:
            # Bars, tick labels and names share the same slot order.
            for bar, label, name in zip(bars, ax.get_yticklabels(), names):
                if name in designated:
                    bar.set_color(dcolor)
                    label.set_color(dcolor)
                else:
                    bar.set_color(color)

    return fig
|