Mercurial > repos > laurenmarazzi > netisce_test
comparison tools/myTools/bin/sfa/plot/si.py @ 1:7e5c71b2e71f draft default tip
Uploaded
author | laurenmarazzi |
---|---|
date | Wed, 22 Dec 2021 16:00:34 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
0:f24d4892aaed | 1:7e5c71b2e71f |
---|---|
1 | |
2 from collections import Counter | |
3 | |
4 import numpy as np | |
5 import pandas as pd | |
6 import matplotlib.gridspec as gridspec | |
7 import matplotlib.pyplot as plt | |
8 from matplotlib.patches import Rectangle | |
9 from matplotlib.ticker import FormatStrFormatter | |
10 from matplotlib import rcParams | |
11 | |
# Module-level plotting defaults: select the sans-serif family and list
# Arial as the font to use for it.
rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial'],
})
14 | |
def siplot(df_splo,
           df_inf,
           output,
           min_splo=None,
           max_splo=None,
           thr_inf=1e-10,
           fmt_inf='%f',
           fig=None,
           cnt_max=None,
           ncol=4,
           designated=None,
           color='silver',
           dcolor='red',
           zcolor='red',
           alpha=0.7,
           xfontsize=8,
           yfontsize=8):
    """Draw a grid of horizontal bar charts of influence, one per SPLO value.

    The rows of ``df_inf`` are grouped by the SPLO value found for the same
    index label in ``df_splo``.  Each group gets its own subplot titled
    ``SPLO=<value>`` whose bars show the ``output`` column of ``df_inf``,
    sorted in ascending order and labelled with the row names.

    Parameters
    ----------
    df_splo : pandas.Series (presumably; verify against callers)
        SPLO value per index label.  It is compared elementwise against the
        bounds and then wrapped into a one-column frame named ``'SPLO'``.
    df_inf : pandas.DataFrame or pandas.Series
        Influence values; after wrapping in a DataFrame it must expose a
        column named ``output``.
    output : label
        Column of ``df_inf`` to plot.  A row with this label is removed from
        ``df_splo`` so it is never plotted as a source.
    min_splo, max_splo : number, optional
        Inclusive SPLO bounds.  ``None`` means "use the minimum/maximum found
        in ``df_splo``"; an explicit ``0`` is honoured as a real bound.
    thr_inf : float
        Rows whose absolute influence is ``<= thr_inf`` are dropped.
    fmt_inf : str
        ``%``-style format used for the x tick labels.
    fig : matplotlib.figure.Figure, optional
        Figure to draw into; a new one is created with ``plt.figure()`` when
        omitted.
    cnt_max : int, optional
        Number of bar slots in every subplot.  Defaults (also when ``0``) to
        the size of the largest SPLO group.
    ncol : int
        Number of subplot columns in the grid.
    designated : container of names, optional
        When given (and non-empty), bars are painted ``color`` except those
        whose name is in ``designated``, which are painted ``dcolor`` together
        with their tick labels.  When omitted, bars keep matplotlib's default
        colour and ``color``/``dcolor`` are unused.
    color, dcolor : colour
        Regular and highlighted bar colours (see ``designated``).
    zcolor : colour
        Colour of the vertical zero line, drawn only in subplots that contain
        both positive and negative influence values.
    alpha : float
        Bar transparency.
    xfontsize, yfontsize : number
        Tick label sizes of the x and y axes.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn into.

    Raises
    ------
    ValueError
        If no row survives the SPLO range and influence threshold, or if
        ``cnt_max`` is smaller than the largest SPLO group.
    """
    # ---- SPLO-influence data -------------------------------------------
    # Compare against None so that an explicit bound of 0 is not discarded.
    if min_splo is None:
        min_splo = df_splo.min()
    if max_splo is None:
        max_splo = df_splo.max()

    in_range = (min_splo <= df_splo) & (df_splo <= max_splo)
    df_splo = pd.DataFrame(df_splo[in_range])
    df_splo.columns = ['SPLO']

    # The output itself is not a source of influence on the output.
    if output in df_splo.index:
        df_splo = df_splo.drop(output)

    index_common = df_splo.index.intersection(df_inf.index)
    df_inf = pd.DataFrame(df_inf.loc[index_common])

    # Discard rows whose influence is negligible.
    negligible = (df_inf[output].abs() <= thr_inf).values
    df_inf = df_inf.drop(index=df_inf.index[negligible])

    df_si = df_inf.join(df_splo.loc[index_common])
    df_si.index.name = 'Source'
    df_si = df_si.reset_index()

    cnt_splo = Counter(df_si['SPLO'])
    if not cnt_splo:
        # Previously surfaced as "max() arg is an empty sequence" or as a
        # GridSpec error about zero rows; keep the type, clarify the cause.
        raise ValueError(
            "siplot: nothing to plot for output %r (no source lies within "
            "the SPLO range with |influence| > thr_inf)" % (output,))

    largest_group = max(cnt_splo.values())
    if not cnt_max:
        cnt_max = largest_group
    elif cnt_max < largest_group:
        # Previously surfaced as a NumPy broadcasting ValueError.
        raise ValueError(
            "siplot: cnt_max=%d is smaller than the largest SPLO group "
            "(%d rows)" % (cnt_max, largest_group))

    splos = sorted(cnt_splo)
    nrow = -(-len(splos) // ncol)  # integer ceiling division

    # ---- Plot ----------------------------------------------------------
    if fig is None:
        fig = plt.figure()

    gs = gridspec.GridSpec(nrow, ncol)

    yvals = np.arange(1, cnt_max + 1)
    for i, splo in enumerate(splos):
        idx_row, idx_col = divmod(i, ncol)
        ax = fig.add_subplot(gs[idx_row, idx_col])

        df_sub = df_si[df_si['SPLO'] == splo].sort_values(by=output)

        # Pad at the bottom so every subplot has exactly cnt_max slots and
        # the real bars are stacked towards the top.  len() (not count())
        # keeps the slot count equal to the number of rows even with NaN.
        num_empty = cnt_max - len(df_sub)
        influence = np.zeros((cnt_max,))
        influence[num_empty:] = df_sub[output].values
        names = [''] * num_empty + df_sub['Source'].tolist()

        # Draw on this subplot explicitly; pyplot's "current axes" may
        # belong to another figure when the caller supplied `fig`.
        bars = ax.barh(yvals, influence, align='center', alpha=alpha)

        ax.set_title('SPLO=%d' % (splo))
        ax.set_xlabel('')

        ax.xaxis.set_major_formatter(FormatStrFormatter(fmt_inf))
        ax.tick_params(axis='x',
                       which='major',
                       labelsize=xfontsize)

        ax.set_ylabel('')
        ax.yaxis.set_ticks_position('right')
        ax.tick_params(axis='y',
                       which='major',
                       labelsize=yfontsize)

        ax.set_yticks(yvals)
        ax.set_yticklabels(names)

        # Draw the zero line only when both signs are present.
        if not ((influence <= 0).all() or (influence >= 0).all()):
            ax.vlines(x=0.0, ymin=0, ymax=yvals[-1] + 1, color=zcolor)

        # Set limits.
        ax.set_ylim(0, cnt_max + 1)

        if designated:
            # `bars` holds one patch per slot, in the same order as `names`.
            for bar, name in zip(bars, names):
                bar.set_color(dcolor if name in designated else color)

            # Highlight the tick labels of the designated names as well.
            for tick_label in ax.get_yticklabels():
                if tick_label.get_text() in designated:
                    tick_label.set_color(dcolor)
    # end of for

    # ---- Simplify the zero tick label ------------------------------------
    # Tick label texts only exist after a draw.
    fig.canvas.draw()
    for ax in fig.axes:
        labels = []
        for tick_label in ax.get_xticklabels():
            text = tick_label.get_text()
            try:
                is_zero = float(text) == 0
            except ValueError:
                # Non-numeric (or empty) label: leave it untouched.
                is_zero = False
            labels.append('0' if is_zero else text)
        ax.set_xticklabels(labels)
    # end of for

    return fig