diff lifelines_tool/lifelineskmcph.xml @ 1:232b874046a7 draft

Uploaded
author fubar
date Thu, 10 Aug 2023 07:15:22 +0000
parents dd49a7040643
children dd5e65893cb8
line wrap: on
line diff
--- a/lifelines_tool/lifelineskmcph.xml	Wed Aug 09 11:12:16 2023 +0000
+++ b/lifelines_tool/lifelineskmcph.xml	Thu Aug 10 07:15:22 2023 +0000
@@ -1,6 +1,6 @@
 <tool name="lifelineskmcph" id="lifelineskmcph" version="0.01">
   <!--Source in git at: https://github.com/fubar2/galaxy_tf_overlay-->
-  <!--Created by toolfactory@galaxy.org at 09/08/2023 17:43:16 using the Galaxy Tool Factory.-->
+  <!--Created by toolfactory@galaxy.org at 10/08/2023 15:48:43 using the Galaxy Tool Factory.-->
   <description>Lifelines KM and optional Cox PH models</description>
   <requirements>
     <requirement version="1.5.3" type="package">pandas</requirement>
@@ -40,8 +40,9 @@
 # km models for https://github.com/galaxyproject/tools-iuc/issues/5393
 # test as
 # python plotlykm.py --input_tab rossi.tab --htmlout "testfoo" --time "week" --status "arrest" --title "test" --image_dir images --cphcol="prio,age,race,paro,mar,fin"
+# Ross Lazarus July 2023
+import argparse
 
-import argparse
 import os
 import sys
 
@@ -51,15 +52,22 @@
 
 import pandas as pd
 
-# Ross Lazarus July 2023
 
+def trimlegend(v):
+    """
+    for int64 quintiles - must be ints - otherwise get silly legends with long float values
+    """
+    for i, av in enumerate(v):
+        x = int(av)
+        v[i] = str(x)
+    return v
 
 kmf = lifelines.KaplanMeierFitter()
 cph = lifelines.CoxPHFitter()
 
 parser = argparse.ArgumentParser()
 a = parser.add_argument
-a('--input_tab', default='', required=True)
+a('--input_tab', default='rossi.tab', required=True)
 a('--header', default='')
 a('--htmlout', default="test_run.html")
 a('--group', default='')
@@ -75,6 +83,7 @@
 df = pd.read_csv(args.input_tab, sep='\t')
 NCOLS = df.columns.size
 NROWS = len(df.index)
+QVALS = [.2, .4, .6, .8] # for partial cox ph plots
 defaultcols = ['col%d' % (x+1) for x in range(NCOLS)]
 testcols = df.columns
 if len(args.header.strip()) > 0:
@@ -106,14 +115,11 @@
     names = []
     times = []
     events = []
-    rmst = []
     for name, grouped_df in df.groupby(args.group):
         T = grouped_df[args.time]
         E = grouped_df[args.status]
         gfit = kmf.fit(T, E, label=name)
         kmf.plot_survival_function(ax=ax)
-        rst = lifelines.utils.restricted_mean_survival_time(gfit)
-        rmst.append(rst)
         names.append(str(name))
         times.append(T)
         events.append(E)
@@ -124,30 +130,50 @@
         results = lifelines.statistics.logrank_test(times[0], times[1], events[0], events[1], alpha=.99)
         print('Logrank test for %s - %s vs %s\n' % (args.group, names[0], names[1]))
         results.print_summary()
-    elif ngroup > 1:
-        fig, ax = plt.subplots(nrows=ngroup, ncols=1, sharex=True)
-        for i, rst in rmst:
-            lifelines.plotting.rmst_plot(rst, ax=ax)
-        fig.savefig(os.path.join(args.image_dir,'RMST_%s.png' % args.title))
 else:
     kmf.fit(df[args.time], df[args.status])
     kmf.plot_survival_function(ax=ax)
     ax.set_title(args.title)
     fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title))
+    print('#### No grouping variable, so no log rank or other Kaplan-Meier statistical output is available')
 if len(args.cphcols) > 0:
     fig, ax = plt.subplots()
-    ax.set_title('Cox PH model: %s' % args.title)
+    ax.set_title('Cox-PH model: %s' % args.title)
     cphcols = args.cphcols.strip().split(',')
     cphcols = [x.strip() for x in cphcols]
     notfound = sum([(x not in df.columns) for x in cphcols])
     if notfound > 0:
         sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): One or more requested Cox PH columns %s not found in supplied column header %s' % (args.cphcols, df.columns))
         sys.exit(6)
+    colsdf = df[cphcols]
     print('### Lifelines test of Proportional Hazards results with %s as covariates on %s' % (', '.join(cphcols), args.title))
-    cphcols += [args.time, args.status]
-    cphdf = df[cphcols]
+    cutcphcols = [args.time, args.status] + cphcols
+    cphdf = df[cutcphcols]
+    ucolcounts = colsdf.nunique(axis=0)
     cph.fit(cphdf, duration_col=args.time, event_col=args.status)
     cph.print_summary()
+    for i, cov in enumerate(colsdf.columns):
+         if ucolcounts[i] > 10:
+             v = pd.Series.tolist(cphdf[cov].quantile(QVALS))
+             vdt = df.dtypes[cov]
+             if vdt == 'int64':
+                 v = trimlegend(v)
+             axp = cph.plot_partial_effects_on_outcome(cov, cmap='coolwarm', values=v)
+             axp.set_title('Cox-PH %s quintile partials: %s' % (cov,args.title))
+             figr = axp.get_figure()
+             oname = os.path.join(args.image_dir,'%s_CoxPH_%s.%s' % (args.title, cov, args.image_type))
+             figr.savefig(oname)
+         else:
+             v = pd.unique(cphdf[cov])
+             v = [str(x) for x in v]
+             try:
+                 axp = cph.plot_partial_effects_on_outcome(cov, cmap='coolwarm', values=v)
+                 axp.set_title('Cox-PH %s partials: %s' % (cov,args.title))
+                 figr = axp.get_figure()
+                 oname = os.path.join(args.image_dir,'%s_CoxPH_%s.%s' % (args.title, cov, args.image_type))
+                 figr.savefig(oname)
+             except:
+                 pass
     cphaxes = cph.check_assumptions(cphdf, p_value_threshold=0.01, show_plots=True)
     for i, ax in enumerate(cphaxes):
         figr = ax[0].get_figure()