# view tools/myTools/bin/sfa/plot/si.py @ 1:7e5c71b2e71f draft default tip
#
# Uploaded
# author laurenmarazzi
# date Wed, 22 Dec 2021 16:00:34 +0000
# parents
# children
# line wrap: on
# line source


from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.ticker import FormatStrFormatter, FuncFormatter
from matplotlib import rcParams

# Module-wide font setup: use the sans-serif family and put Arial first in
# its candidate list, so every figure produced by this module uses Arial.
rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial'],
})

def _select_si_data(df_splo, df_inf, output, min_splo, max_splo, thr_inf):
    """Join SPLO values and influences of the sources to be plotted.

    Returns a DataFrame with a 'Source' column (node names), the columns of
    *df_inf* and a 'SPLO' column.  Rows are restricted to nodes that are in
    both inputs, are not *output*, have a SPLO within the inclusive range
    [min_splo, max_splo] and an absolute influence on *output* greater than
    *thr_inf*.
    """
    # ``is None`` rather than truthiness so that an explicit bound of 0 is
    # honoured instead of being replaced by the data's min/max.
    if min_splo is None:
        min_splo = df_splo.min()
    if max_splo is None:
        max_splo = df_splo.max()

    in_range = (min_splo <= df_splo) & (df_splo <= max_splo)
    df_splo = pd.DataFrame(df_splo[in_range])
    df_splo.columns = ['SPLO']
    # The output node itself is never listed as a source.
    df_splo = df_splo.drop(output, errors='ignore')

    index_common = df_splo.index.intersection(df_inf.index)
    df_inf = pd.DataFrame(df_inf.loc[index_common])
    # Keep only non-negligible influences.  NaN compares False here, so
    # missing influences are dropped too instead of later breaking the
    # fixed-size bar layout.
    df_inf = df_inf[df_inf[output].abs() > thr_inf]

    df_si = df_inf.join(df_splo.loc[index_common])
    df_si.index.name = 'Source'
    return df_si.reset_index()


def _zero_simplified_formatter(fmt):
    """Return a tick formatter that renders ticks with the printf-style *fmt*
    but writes a plain '0' whenever the formatted text parses as zero
    (e.g. '0.000000' or '-0.000000' with the default '%f')."""
    def _format_tick(value, pos=None):
        text = fmt % value
        try:
            is_zero = float(text) == 0
        except ValueError:
            # *fmt* may produce non-numeric text; leave it untouched.
            return text
        return '0' if is_zero else text

    return FuncFormatter(_format_tick)


def _highlight_designated(ax, bars, names, designated, color, dcolor):
    """Color one panel's bars with *color*, except bars (and y tick labels)
    whose source name is in *designated*, which get *dcolor*."""
    for bar, name in zip(bars, names):
        bar.set_color(dcolor if name in designated else color)
    for label in ax.get_yticklabels():
        if label.get_text() in designated:
            label.set_color(dcolor)


def siplot(df_splo,
           df_inf,
           output,
           min_splo=None,
           max_splo=None,
           thr_inf=1e-10,
           fmt_inf='%f',
           fig=None,
           cnt_max=None,
           ncol=4,
           designated=None,
           color='silver',
           dcolor='red',
           zcolor='red',
           alpha=0.7,
           xfontsize=8,
           yfontsize=8):
    """Plot source influences on *output*, one bar-chart panel per SPLO value.

    Sources are the nodes present in both *df_splo* and *df_inf* (other than
    *output* itself) whose SPLO lies in ``[min_splo, max_splo]`` and whose
    absolute influence on *output* exceeds *thr_inf*.  They are grouped by
    SPLO, and each group is drawn as a horizontal bar chart in a grid with
    *ncol* columns.  Within a panel, bars are sorted so the largest (signed)
    influence is at the top.

    Parameters
    ----------
    df_splo : single-column data, presumably a pandas Series
        SPLO value per node, indexed by node name.  Panel titles format the
        value with ``%d``.
    df_inf : pandas.DataFrame
        Influence values indexed by node name; must contain a column named
        *output*.
    output : hashable
        Output node: the column of *df_inf* that is plotted.  It is excluded
        from the sources.
    min_splo, max_splo : number, optional
        Inclusive SPLO range to plot.  Defaults to the data's min/max.  An
        explicit 0 is honoured.
    thr_inf : float
        Sources with ``|influence| <= thr_inf`` (or NaN influence) are dropped.
    fmt_inf : str
        printf-style format for x tick labels.  Ticks whose text parses as
        zero are shown as ``'0'``.
    fig : matplotlib.figure.Figure, optional
        Figure to draw into.  A new pyplot figure is created when omitted.
    cnt_max : int, optional
        Number of bar slots per panel.  None or 0 means "size of the largest
        SPLO group".
    ncol : int
        Number of panel columns.
    designated : str or iterable of str, optional
        Source names to highlight in *dcolor*.  A single string is treated as
        one name.
    color : color
        Color of non-designated bars.  Applied only when *designated* is
        non-empty; otherwise matplotlib's default bar color is used.
    dcolor : color
        Color of designated bars and of their y tick labels.
    zcolor : color
        Color of the vertical zero line.  The line is drawn only in panels
        that mix positive and negative influences.
    alpha : float
        Bar transparency.
    xfontsize, yfontsize : number
        Tick label font sizes for the x and y axes.

    Returns
    -------
    matplotlib.figure.Figure
        The figure the panels were added to.

    Raises
    ------
    ValueError
        If no source is left after filtering, or if *cnt_max* is smaller
        than the number of sources that share one SPLO value.
    """
    df_si = _select_si_data(df_splo, df_inf, output,
                            min_splo, max_splo, thr_inf)
    if df_si.empty:
        raise ValueError(
            'No source nodes left to plot for output %r after applying the '
            'SPLO range and the influence threshold.' % (output,))

    cnt_splo = Counter(df_si['SPLO'])
    largest_group = max(cnt_splo.values())
    # None and 0 both mean "fit the largest SPLO group".
    if not cnt_max:
        cnt_max = largest_group
    elif cnt_max < largest_group:
        raise ValueError(
            'cnt_max=%s is smaller than the %s sources sharing one SPLO '
            'value.' % (cnt_max, largest_group))

    # Normalize to a set: a bare string would otherwise do substring
    # matching (and match the '' padding names).
    if designated is None:
        designated = set()
    elif isinstance(designated, str):
        designated = {designated}
    else:
        designated = set(designated)

    splos = sorted(cnt_splo)
    nrow = (len(splos) + ncol - 1) // ncol  # ceiling division

    if fig is None:
        fig = plt.figure()
    gs = gridspec.GridSpec(nrow, ncol)

    yvals = np.arange(1, cnt_max + 1)
    for panel, splo in enumerate(splos):
        row, col = divmod(panel, ncol)
        ax = fig.add_subplot(gs[row, col])
        df_sub = df_si[df_si['SPLO'] == splo].sort_values(by=output)

        # Pad at the bottom so every panel has cnt_max slots; after the
        # ascending sort the largest influence lands in the top slot.
        num_empty = cnt_max - len(df_sub)
        influence = np.zeros(cnt_max)
        influence[num_empty:] = df_sub[output].to_numpy()
        names = [''] * num_empty + df_sub['Source'].tolist()

        # Draw on ``ax`` explicitly: pyplot's current axes belongs to the
        # current figure, which need not be the caller-supplied *fig*.
        bars = ax.barh(yvals, influence, align='center', alpha=alpha)

        ax.set_title('SPLO=%d' % (splo))
        ax.set_xlabel('')
        # One formatter per axis (formatters are bound to their axis).
        ax.xaxis.set_major_formatter(_zero_simplified_formatter(fmt_inf))
        ax.tick_params(axis='x', which='major', labelsize=xfontsize)

        ax.set_ylabel('')
        ax.yaxis.set_ticks_position('right')
        ax.tick_params(axis='y', which='major', labelsize=yfontsize)
        ax.set_yticks(yvals)
        ax.set_yticklabels(names)

        # Zero line only when the panel has both positive and negative bars.
        if (influence < 0).any() and (influence > 0).any():
            ax.vlines(x=0.0, ymin=0, ymax=cnt_max + 1, color=zcolor)

        ax.set_ylim(0, cnt_max + 1)

        if designated:
            _highlight_designated(ax, bars, names, designated, color, dcolor)

    return fig