diff tools/myTools/bin/sfa/plot/si.py @ 1:7e5c71b2e71f draft default tip

Uploaded
author laurenmarazzi
date Wed, 22 Dec 2021 16:00:34 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tools/myTools/bin/sfa/plot/si.py	Wed Dec 22 16:00:34 2021 +0000
@@ -0,0 +1,165 @@
+
+from collections import Counter
+
+import numpy as np
+import pandas as pd
+import matplotlib.gridspec as gridspec
+import matplotlib.pyplot as plt
+from matplotlib.patches import Rectangle
+from matplotlib.ticker import FormatStrFormatter
+from matplotlib import rcParams
+
+rcParams['font.family'] = 'sans-serif'
+rcParams['font.sans-serif'] = ['Arial']
+
+def siplot(df_splo,
+           df_inf,
+           output,
+           min_splo=None,
+           max_splo=None,
+           thr_inf=1e-10,
+           fmt_inf='%f',
+           fig=None,
+           cnt_max=None,
+           ncol=4,
+           designated=None,
+           color='silver',
+           dcolor='red',
+           zcolor='red',
+           alpha=0.7,
+           xfontsize=8,
+           yfontsize=8):
+
+    # SPLO-Influence Data
+    if not min_splo:
+        min_splo = df_splo.min()
+
+    if not max_splo:
+        max_splo = df_splo.max()
+
+    mask_splo = (min_splo <= df_splo) & (df_splo <= max_splo)
+    df_splo = df_splo[mask_splo]
+
+    df_splo = pd.DataFrame(df_splo)
+    df_splo.columns = ['SPLO']
+
+
+    if output in df_splo.index:
+        df_splo.drop(output, inplace=True)
+
+    index_common = df_splo.index.intersection(df_inf.index)
+    df_inf = pd.DataFrame(df_inf.loc[index_common])
+
+    mark_drop = df_inf[output].abs() <= thr_inf
+    df_inf.drop(df_inf.loc[mark_drop, output].index,
+                inplace=True)
+
+
+    df_si = df_inf.join(df_splo.loc[index_common])
+    df_si.index.name = 'Source'
+    df_si.reset_index(inplace=True)
+
+    cnt_splo = Counter(df_si['SPLO'])
+    if not cnt_max:
+        cnt_max = max(cnt_splo.values())
+
+    splos = sorted(cnt_splo.keys())
+    nrow = int(np.ceil(len(splos)/ncol))
+
+    # Plot
+    if not fig:
+        fig = plt.figure()
+
+    gs = gridspec.GridSpec(nrow, ncol)
+
+    yvals = np.arange(1, cnt_max +1)
+    for i, splo in enumerate(splos):
+        idx_row = int(i / ncol)
+        idx_col = int(i % ncol)
+        ax = fig.add_subplot(gs[idx_row, idx_col])
+        df_sub = df_si[df_si['SPLO'] == splo]
+        df_sub = df_sub.sort_values(by=output)
+        num_items = df_sub[output].count()
+
+        influence = np.zeros((cnt_max,))  # Influence
+        num_empty = cnt_max - num_items
+        influence[num_empty:] = df_sub[output]
+        names = df_sub['Source'].tolist()
+        names = ['' ] *(num_empty) + names
+
+        # Plot bars
+        plt.barh(yvals, influence, align='center',
+                 alpha=alpha)
+
+        ax.set_title('SPLO=%d'%(splo))
+        ax.set_xlabel('')
+
+        ax.xaxis.set_major_formatter(FormatStrFormatter(fmt_inf))
+        ax.tick_params(axis='x',
+                       which='major',
+                       labelsize=xfontsize)
+
+        ax.set_ylabel('')
+        ax.yaxis.set_ticks_position('right')
+        ax.tick_params(axis='y',
+                       which='major',
+                       labelsize=yfontsize)
+
+        plt.yticks(yvals, names)
+
+        # Draw zero line.
+        if not((influence <= 0).all() or (influence >= 0).all()):
+            ax.vlines(x=0.0, ymin=0, ymax=yvals[-1]+1, color=zcolor)
+
+        # Set limitations
+        ax.set_ylim(0, cnt_max +1)
+
+        if designated:
+            # Filter bar graphics.
+            bars = []
+            cnt_bars = 0
+            for obj in ax.get_children():
+                if cnt_bars == cnt_max:
+                    break
+                if isinstance(obj, Rectangle):
+                    bars.append(obj)
+                    obj.set_color(color)
+                    cnt_bars += 1
+            # end of for
+
+            # Change the bars of the designated names.
+            for i, name in enumerate(names):
+                if name in designated:
+                    bars[i].set_color(dcolor)
+            # end of for
+
+            # Change the text colors of the designated names.
+            for obj in ax.get_yticklabels():
+                name = obj.get_text()
+                if name in designated:
+                    obj.set_color(dcolor)
+            # end of for
+    # end of for
+
+    # Make zero notation more simple.
+    fig.canvas.draw()
+    for ax in fig.axes:
+        labels = []
+        for obj in ax.get_xticklabels():
+            try:
+                text = obj.get_text()
+                num = float(text)
+            except ValueError:
+                labels.append(text)
+                continue
+
+            if num == 0:
+                labels.append('0')
+            else:
+                labels.append(text)
+        # end of for
+        ax.set_xticklabels(labels)
+        # end of for
+    # end of for
+
+    return fig