| 1 | 1 | 
|  | 2 from collections import Counter | 
|  | 3 | 
|  | 4 import numpy as np | 
|  | 5 import pandas as pd | 
|  | 6 import matplotlib.gridspec as gridspec | 
|  | 7 import matplotlib.pyplot as plt | 
|  | 8 from matplotlib.patches import Rectangle | 
|  | 9 from matplotlib.ticker import FormatStrFormatter | 
|  | 10 from matplotlib import rcParams | 
|  | 11 | 
# Module-wide font defaults: make sans-serif the default family for all text
# and put Arial first in the sans-serif lookup list.
rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial'],
})
|  | 14 | 
def _zero_aware_label(text):
    """Return ``'0'`` when *text* reads as zero, otherwise *text* unchanged.

    A printf-style format such as ``'%f'`` renders zero as ``'0.000000'``;
    collapsing that to ``'0'`` keeps the axis tidy. A U+2212 MINUS SIGN is
    mapped to ``'-'`` before parsing so a negative-zero label collapses too.
    Text that does not parse as a number is returned as-is.
    """
    try:
        value = float(text.replace('\u2212', '-'))
    except ValueError:
        return text
    return '0' if value == 0 else text


def _build_si_table(df_splo, df_inf, output, min_splo, max_splo, thr_inf):
    """Join in-range SPLO values with significant influence values.

    Returns a DataFrame with one row per retained source and the columns
    ``'Source'`` (the former index), the columns of *df_inf* (including
    *output*) and ``'SPLO'``.
    """
    # Explicit ``is None`` checks so that 0 is honoured as a real bound.
    if min_splo is None:
        min_splo = df_splo.min()
    if max_splo is None:
        max_splo = df_splo.max()

    in_range = (min_splo <= df_splo) & (df_splo <= max_splo)
    splo = pd.DataFrame(df_splo[in_range])
    splo.columns = ['SPLO']

    # The output entry itself is never plotted as a source.
    if output in splo.index:
        splo = splo.drop(output)

    common = splo.index.intersection(df_inf.index)
    inf = pd.DataFrame(df_inf.loc[common])

    # Keep only |influence| > thr_inf. Written as a "keep" mask (instead of
    # dropping |influence| <= thr_inf) so NaN influences are discarded too;
    # they would otherwise break the fixed-size bar arrays built later.
    inf = inf[inf[output].abs() > thr_inf]

    si = inf.join(splo)
    return si.rename_axis('Source').reset_index()


def siplot(df_splo,
           df_inf,
           output,
           min_splo=None,
           max_splo=None,
           thr_inf=1e-10,
           fmt_inf='%f',
           fig=None,
           cnt_max=None,
           ncol=4,
           designated=None,
           color='silver',
           dcolor='red',
           zcolor='red',
           alpha=0.7,
           xfontsize=8,
           yfontsize=8):
    """Draw a grid of horizontal bar charts of influence, one panel per SPLO.

    Parameters
    ----------
    df_splo : Series
        SPLO value per source, indexed by source name (presumably a
        ``pandas.Series``; it is wrapped into a one-column frame 'SPLO').
    df_inf : DataFrame or Series
        Influence values indexed by source name; must provide a column (or
        be a Series named) *output*.
    output : hashable
        Column of *df_inf* to plot. An entry with this label is also removed
        from *df_splo* so it is not plotted as a source.
    min_splo, max_splo : scalar, optional
        Inclusive SPLO range to plot; default to the data min / max.
    thr_inf : float
        Sources with ``|influence| <= thr_inf`` (or NaN influence) are
        dropped.
    fmt_inf : str
        printf-style format of the x tick labels; labels that read as zero
        are shown as ``'0'``.
    fig : matplotlib Figure, optional
        Figure to draw on; a new one is created when None.
    cnt_max : int, optional
        Number of bar slots per panel. None (or 0) means "largest panel
        size". If a panel has more sources than slots, only the *cnt_max*
        sources with the largest |influence| are shown.
    ncol : int
        Number of panels per grid row.
    designated : str or iterable of str, optional
        Source names whose bars and tick labels are drawn in *dcolor*.
    color, dcolor, zcolor : color
        Bar color, highlight color, and color of the zero line (drawn only
        when a panel has both positive and negative bars).
    alpha : float
        Bar transparency.
    xfontsize, yfontsize : float
        Tick label font sizes.

    Returns
    -------
    fig : matplotlib Figure

    Raises
    ------
    ValueError
        If no source remains after the SPLO-range and threshold filters.
    """
    from matplotlib.ticker import FuncFormatter

    df_si = _build_si_table(df_splo, df_inf, output,
                            min_splo, max_splo, thr_inf)
    if df_si.empty:
        raise ValueError('siplot: no sources left to plot after filtering '
                         'by SPLO range and |influence| > thr_inf.')

    cnt_splo = Counter(df_si['SPLO'])
    if not cnt_max:
        cnt_max = max(cnt_splo.values())

    splos = sorted(cnt_splo)
    nrow = (len(splos) + ncol - 1) // ncol

    # A plain str would otherwise be matched by substring ('' included).
    if isinstance(designated, str):
        designated = [designated]
    highlighted = set() if designated is None else set(designated)

    if fig is None:
        fig = plt.figure()

    gs = gridspec.GridSpec(nrow, ncol)
    yvals = np.arange(1, cnt_max + 1)
    x_text = FormatStrFormatter(fmt_inf)

    for i, splo in enumerate(splos):
        row, col = divmod(i, ncol)
        ax = fig.add_subplot(gs[row, col])

        df_sub = df_si[df_si['SPLO'] == splo]
        if len(df_sub) > cnt_max:
            # Not enough slots: keep the strongest influences only.
            df_sub = df_sub.loc[df_sub[output].abs().nlargest(cnt_max).index]
        df_sub = df_sub.sort_values(by=output)

        # Bars are bottom-aligned: empty slots are zero-width bars with
        # blank labels at the bottom of the panel.
        num_empty = cnt_max - len(df_sub)
        influence = np.zeros(cnt_max)
        influence[num_empty:] = df_sub[output].to_numpy()
        names = [''] * num_empty + df_sub['Source'].tolist()

        # Draw on this panel's axes explicitly (not pyplot's current axes),
        # so a caller-supplied figure is honoured.
        bars = ax.barh(yvals, influence, align='center',
                       color=color, alpha=alpha)

        ax.set_title('SPLO=%d' % splo)
        ax.set_xlabel('')
        ax.xaxis.set_major_formatter(
            FuncFormatter(lambda x, pos: _zero_aware_label(x_text(x, pos))))
        ax.tick_params(axis='x', which='major', labelsize=xfontsize)

        ax.set_ylabel('')
        ax.yaxis.set_ticks_position('right')
        ax.tick_params(axis='y', which='major', labelsize=yfontsize)
        ax.set_yticks(yvals)
        ax.set_yticklabels(names)

        # Zero line only when the panel mixes positive and negative bars.
        if (influence < 0).any() and (influence > 0).any():
            ax.vlines(x=0.0, ymin=0, ymax=yvals[-1] + 1, color=zcolor)

        ax.set_ylim(0, cnt_max + 1)

        if highlighted:
            for bar, name in zip(bars, names):
                if name in highlighted:
                    bar.set_color(dcolor)
            for label in ax.get_yticklabels():
                if label.get_text() in highlighted:
                    label.set_color(dcolor)

    return fig