diff lifelines_tool/plotlykm.py @ 1:232b874046a7 draft

Uploaded
author fubar
date Thu, 10 Aug 2023 07:15:22 +0000
parents dd49a7040643
children dd5e65893cb8
line wrap: on
line diff
--- a/lifelines_tool/plotlykm.py	Wed Aug 09 11:12:16 2023 +0000
+++ b/lifelines_tool/plotlykm.py	Thu Aug 10 07:15:22 2023 +0000
@@ -2,8 +2,9 @@
 # km models for https://github.com/galaxyproject/tools-iuc/issues/5393
 # test as
 # python plotlykm.py --input_tab rossi.tab --htmlout "testfoo" --time "week" --status "arrest" --title "test" --image_dir images --cphcol="prio,age,race,paro,mar,fin"
+# Ross Lazarus July 2023
+import argparse
 
-import argparse
 import os
 import sys
 
@@ -13,15 +14,22 @@
 
 import pandas as pd
 
-# Ross Lazarus July 2023
 
+def trimlegend(v):
+    """
+    for int64 quintiles - must be ints - otherwise get silly legends with long float values
+    """
+    for i, av in enumerate(v):
+        x = int(av)
+        v[i] = str(x)
+    return v
 
 kmf = lifelines.KaplanMeierFitter()
 cph = lifelines.CoxPHFitter()
 
 parser = argparse.ArgumentParser()
 a = parser.add_argument
-a('--input_tab', default='', required=True)
+a('--input_tab', default='rossi.tab', required=True)
 a('--header', default='')
 a('--htmlout', default="test_run.html")
 a('--group', default='')
@@ -37,6 +45,7 @@
 df = pd.read_csv(args.input_tab, sep='\t')
 NCOLS = df.columns.size
 NROWS = len(df.index)
+QVALS = [.2, .4, .6, .8] # for partial cox ph plots
 defaultcols = ['col%d' % (x+1) for x in range(NCOLS)]
 testcols = df.columns
 if len(args.header.strip()) > 0:
@@ -57,62 +66,79 @@
     else:
         colsok = (args.time in defaultcols) and (args.status in defaultcols)
         if colsok:
-            sys.stderr.write('replacing first row of data derived header %s with %s' % (testcols, defaultcols))
+            print('Replacing first row of data derived header %s with %s' % (testcols, defaultcols))
             df.columns = defaultcols
         else:
             sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): time %s and status %s do not match anything in the file header, supplied header or automatic default column names %s' % (args.time, args.status, defaultcols))
-print('## Lifelines tool starting.\nUsing data header =', df.columns, 'time column =', args.time, 'status column =', args.status)
+print('## Lifelines tool\nInput data header =', df.columns, 'time column =', args.time, 'status column =', args.status)
 os.makedirs(args.image_dir, exist_ok=True)
 fig, ax = plt.subplots()
 if args.group > '':
     names = []
     times = []
     events = []
-    rmst = []
     for name, grouped_df in df.groupby(args.group):
         T = grouped_df[args.time]
         E = grouped_df[args.status]
         gfit = kmf.fit(T, E, label=name)
         kmf.plot_survival_function(ax=ax)
-        rst = lifelines.utils.restricted_mean_survival_time(gfit)
-        rmst.append(rst)
         names.append(str(name))
         times.append(T)
         events.append(E)
+    ax.set_title(args.title)
+    fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title))
     ngroup = len(names)
     if  ngroup == 2: # run logrank test if 2 groups
         results = lifelines.statistics.logrank_test(times[0], times[1], events[0], events[1], alpha=.99)
-        print(' vs '.join(names), results)
+        print('Logrank test for %s - %s vs %s\n' % (args.group, names[0], names[1]))
         results.print_summary()
-    elif ngroup > 1:
-        fig, ax = plt.subplots(nrows=ngroup, ncols=1, sharex=True)
-        for i, rst in rmst:
-            lifelines.plotting.rmst_plot(rst, ax=ax)
-        fig.savefig(os.path.join(args.image_dir,'RMST_%s.png' % args.title))
 else:
     kmf.fit(df[args.time], df[args.status])
     kmf.plot_survival_function(ax=ax)
-fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title))
+    ax.set_title(args.title)
+    fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title))
+    print('#### No grouping variable, so no log rank or other Kaplan-Meier statistical output is available')
 if len(args.cphcols) > 0:
     fig, ax = plt.subplots()
+    ax.set_title('Cox-PH model: %s' % args.title)
     cphcols = args.cphcols.strip().split(',')
     cphcols = [x.strip() for x in cphcols]
     notfound = sum([(x not in df.columns) for x in cphcols])
     if notfound > 0:
         sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): One or more requested Cox PH columns %s not found in supplied column header %s' % (args.cphcols, df.columns))
         sys.exit(6)
+    colsdf = df[cphcols]
     print('### Lifelines test of Proportional Hazards results with %s as covariates on %s' % (', '.join(cphcols), args.title))
-    cphcols += [args.time, args.status]
-    cphdf = df[cphcols]
+    cutcphcols = [args.time, args.status] + cphcols
+    cphdf = df[cutcphcols]
+    ucolcounts = colsdf.nunique(axis=0)
     cph.fit(cphdf, duration_col=args.time, event_col=args.status)
     cph.print_summary()
+    for i, cov in enumerate(colsdf.columns):
+         if ucolcounts[i] > 10: # a hack - assume categories are sparse - if not imaginary quintiles will have to do
+             v = pd.Series.tolist(cphdf[cov].quantile(QVALS))
+             vdt = df.dtypes[cov]
+             if vdt == 'int64':
+                 v = trimlegend(v)
+             axp = cph.plot_partial_effects_on_outcome(cov, cmap='coolwarm', values=v)
+             axp.set_title('Cox-PH %s quintile partials: %s' % (cov,args.title))
+             figr = axp.get_figure()
+             oname = os.path.join(args.image_dir,'%s_CoxPH_%s.%s' % (args.title, cov, args.image_type))
+             figr.savefig(oname)
+         else:
+             v = pd.unique(cphdf[cov])
+             v = [str(x) for x in v]
+             try:
+                 axp = cph.plot_partial_effects_on_outcome(cov, cmap='coolwarm', values=v)
+                 axp.set_title('Cox-PH %s partials: %s' % (cov,args.title))
+                 figr = axp.get_figure()
+                 oname = os.path.join(args.image_dir,'%s_CoxPH_%s.%s' % (args.title, cov, args.image_type))
+                 figr.savefig(oname)
+             except:
+                 pass
     cphaxes = cph.check_assumptions(cphdf, p_value_threshold=0.01, show_plots=True)
     for i, ax in enumerate(cphaxes):
         figr = ax[0].get_figure()
         titl = figr._suptitle.get_text().replace(' ','_').replace("'","")
         oname = os.path.join(args.image_dir,'CPH%s.%s' % (titl, args.image_type))
         figr.savefig(oname)
-
-
-
-