Mercurial > repos > laurenmarazzi > netisce_test
view tools/myTools/bin/sfa/plot/si.py @ 1:7e5c71b2e71f draft default tip
Uploaded
author | laurenmarazzi |
---|---|
date | Wed, 22 Dec 2021 16:00:34 +0000 |
parents | |
children |
line wrap: on
line source
from collections import Counter import numpy as np import pandas as pd import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt from matplotlib.patches import Rectangle from matplotlib.ticker import FormatStrFormatter from matplotlib import rcParams rcParams['font.family'] = 'sans-serif' rcParams['font.sans-serif'] = ['Arial'] def siplot(df_splo, df_inf, output, min_splo=None, max_splo=None, thr_inf=1e-10, fmt_inf='%f', fig=None, cnt_max=None, ncol=4, designated=None, color='silver', dcolor='red', zcolor='red', alpha=0.7, xfontsize=8, yfontsize=8): # SPLO-Influence Data if not min_splo: min_splo = df_splo.min() if not max_splo: max_splo = df_splo.max() mask_splo = (min_splo <= df_splo) & (df_splo <= max_splo) df_splo = df_splo[mask_splo] df_splo = pd.DataFrame(df_splo) df_splo.columns = ['SPLO'] if output in df_splo.index: df_splo.drop(output, inplace=True) index_common = df_splo.index.intersection(df_inf.index) df_inf = pd.DataFrame(df_inf.loc[index_common]) mark_drop = df_inf[output].abs() <= thr_inf df_inf.drop(df_inf.loc[mark_drop, output].index, inplace=True) df_si = df_inf.join(df_splo.loc[index_common]) df_si.index.name = 'Source' df_si.reset_index(inplace=True) cnt_splo = Counter(df_si['SPLO']) if not cnt_max: cnt_max = max(cnt_splo.values()) splos = sorted(cnt_splo.keys()) nrow = int(np.ceil(len(splos)/ncol)) # Plot if not fig: fig = plt.figure() gs = gridspec.GridSpec(nrow, ncol) yvals = np.arange(1, cnt_max +1) for i, splo in enumerate(splos): idx_row = int(i / ncol) idx_col = int(i % ncol) ax = fig.add_subplot(gs[idx_row, idx_col]) df_sub = df_si[df_si['SPLO'] == splo] df_sub = df_sub.sort_values(by=output) num_items = df_sub[output].count() influence = np.zeros((cnt_max,)) # Influence num_empty = cnt_max - num_items influence[num_empty:] = df_sub[output] names = df_sub['Source'].tolist() names = ['' ] *(num_empty) + names # Plot bars plt.barh(yvals, influence, align='center', alpha=alpha) ax.set_title('SPLO=%d'%(splo)) ax.set_xlabel('') ax.xaxis.set_major_formatter(FormatStrFormatter(fmt_inf)) ax.tick_params(axis='x', which='major', labelsize=xfontsize) ax.set_ylabel('') ax.yaxis.set_ticks_position('right') ax.tick_params(axis='y', which='major', labelsize=yfontsize) plt.yticks(yvals, names) # Draw zero line. if not((influence <= 0).all() or (influence >= 0).all()): ax.vlines(x=0.0, ymin=0, ymax=yvals[-1]+1, color=zcolor) # Set limitations ax.set_ylim(0, cnt_max +1) if designated: # Filter bar graphics. bars = [] cnt_bars = 0 for obj in ax.get_children(): if cnt_bars == cnt_max: break if isinstance(obj, Rectangle): bars.append(obj) obj.set_color(color) cnt_bars += 1 # end of for # Change the bars of the designated names. for i, name in enumerate(names): if name in designated: bars[i].set_color(dcolor) # end of for # Change the text colors of the designated names. for obj in ax.get_yticklabels(): name = obj.get_text() if name in designated: obj.set_color(dcolor) # end of for # end of for # Make zero notation more simple. fig.canvas.draw() for ax in fig.axes: labels = [] for obj in ax.get_xticklabels(): try: text = obj.get_text() num = float(text) except ValueError: labels.append(text) continue if num == 0: labels.append('0') else: labels.append(text) # end of for ax.set_xticklabels(labels) # end of for # end of for return fig