Mercurial > repos > fubar > lifelines_km_cph_tool
view lifelines_tool/plotlykm.py @ 2:dd5e65893cb8 draft default tip
add survival and collapsed life table outputs suggested by Wolfgang
author | fubar |
---|---|
date | Thu, 10 Aug 2023 22:52:45 +0000 (19 months ago) |
parents | 232b874046a7 |
children |
line wrap: on
line source
# script for a lifelines ToolFactory KM/CPH tool for Galaxy
# km models for https://github.com/galaxyproject/tools-iuc/issues/5393
# test as
# python plotlykm.py --input_tab rossi.tab --htmlout "testfoo" --time "week" --status "arrest" --title "test" --image_dir images --cphcol="prio,age,race,paro,mar,fin"
# Ross Lazarus July 2023
"""Kaplan-Meier, life table and Cox proportional hazards outputs for a Galaxy tool.

Reads a tab-delimited file and resolves which columns hold the survival time
and the event status. It then writes plots and tables into --image_dir.
Everything the script prints, including lifelines model summaries, goes to
the --readme log file. Usage errors are written to stderr and end the run
with a non-zero exit code.
"""
import argparse
import contextlib
import os
import sys

import lifelines
import lifelines.statistics  # explicit: provides logrank_test used below
import lifelines.utils  # explicit: provides survival_table_from_events used below

from matplotlib import pyplot as plt

import pandas as pd

QVALS = [.2, .4, .6, .8]  # covariate quantiles plotted for many-valued covariates in partial Cox PH plots


def trimlegend(v):
    """Replace each numeric value in ``v`` with its integer string, in place.

    Used for int64 quintiles so legends show whole numbers instead of
    long float values.

    :param v: mutable sequence of numbers; each entry becomes ``str(int(entry))``.
    :returns: the same, now mutated, sequence.
    """
    for i, av in enumerate(v):
        v[i] = str(int(av))
    return v


kmf = lifelines.KaplanMeierFitter()
cph = lifelines.CoxPHFitter()


def _parse_args(argv=None):
    """Parse the command line (``sys.argv[1:]`` when ``argv`` is None)."""
    parser = argparse.ArgumentParser()
    a = parser.add_argument
    a('--input_tab', default='rossi.tab', required=True)
    a('--header', default='')  # optional comma separated replacement column names
    a('--htmlout', default="test_run.html")  # accepted for the tool wrapper; not used in this script
    a('--group', default='')
    a('--time', default='', required=True)
    a('--status', default='', required=True)
    a('--cphcols', default='')
    a('--title', default='Default plot title')
    a('--image_type', default='png')
    a('--image_dir', default='images')
    a('--readme', default='run_log.txt')
    return parser.parse_args(argv)


def _usage_error(msg, code):
    """Write a user-facing usage error to stderr and exit with ``code``."""
    sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): %s' % msg)
    sys.exit(code)


def _load_data(args):
    """Read --input_tab and make sure the --time and --status columns exist.

    Column names are resolved in this order:
    1. --header if given. It must have one name per column and contain time and status.
    2. The file's own first row, if it contains time and status.
    3. Automatic names col1..colN, if time and status use them. The file is then
       treated as headerless and re-read, so its first row is kept as data.

    Exits with code 5 (wrong --header length), 4 (time/status not in --header),
    or 7 (time/status not found anywhere).
    """
    df = pd.read_csv(args.input_tab, sep='\t')
    ncols = df.columns.size
    defaultcols = ['col%d' % (x + 1) for x in range(ncols)]
    testcols = df.columns
    if args.header.strip():
        newcols = args.header.split(',')
        if len(newcols) != ncols:
            _usage_error('Supplied header %s has %d comma delimited header names - does not match the input tabular file %d columns'
                         % (args.header, len(newcols), ncols), 5)
        if not ((args.time in newcols) and (args.status in newcols)):
            _usage_error('time %s and/or status %s not found in supplied header parameter %s'
                         % (args.time, args.status, args.header), 4)
        df.columns = newcols
    elif (args.time in testcols) and (args.status in testcols):
        # no header supplied, but the file's own header matches - keep it as is
        pass
    elif (args.time in defaultcols) and (args.status in defaultcols):
        print('Replacing first row of data derived header %s with %s' % (testcols, defaultcols))
        # pandas consumed the first data row as a header above; re-read the file
        # without a header so that row is not silently dropped from the analysis
        df = pd.read_csv(args.input_tab, sep='\t', header=None, names=defaultcols)
    else:
        # previously this only warned and carried on, failing later with a KeyError
        _usage_error('time %s and status %s do not match anything in the file header, supplied header or automatic default column names %s'
                     % (args.time, args.status, defaultcols), 7)
    return df


def _km_analysis(df, args):
    """Plot Kaplan-Meier survival curves and save them as KM_<title>.png in --image_dir.

    With --group there is one curve per group level. A log-rank test summary is
    printed when there are exactly two levels. Without --group a single curve
    is fitted to all rows.
    """
    fig, ax = plt.subplots()
    kmname = os.path.join(args.image_dir, 'KM_%s.png' % args.title)  # NOTE: always .png, unlike the Cox PH plots
    if args.group:
        names = []
        times = []
        events = []
        for name, grouped_df in df.groupby(args.group):
            T = grouped_df[args.time]
            E = grouped_df[args.status]
            kmf.fit(T, E, label=name)
            kmf.plot_survival_function(ax=ax)
            names.append(str(name))
            times.append(T)
            events.append(E)
        ax.set_title(args.title)
        fig.savefig(kmname)
        if len(names) == 2:  # run logrank test if 2 groups
            # NOTE(review): alpha=.99 is passed straight through to lifelines; confirm the
            # intended meaning (confidence level vs significance) for the installed version
            results = lifelines.statistics.logrank_test(times[0], times[1], events[0], events[1], alpha=.99)
            print('Logrank test for %s - %s vs %s\n' % (args.group, names[0], names[1]))
            results.print_summary()
    else:
        kmf.fit(df[args.time], df[args.status])
        kmf.plot_survival_function(ax=ax)
        ax.set_title(args.title)
        fig.savefig(kmname)
        print('#### No grouping variable, so no log rank or other Kaplan-Meier statistical output is available')


def _life_tables(df, args):
    """Print the survival table and the collapsed life table, then save both in --image_dir."""
    survdf = lifelines.utils.survival_table_from_events(df[args.time], df[args.status])
    lifedf = lifelines.utils.survival_table_from_events(df[args.time], df[args.status], collapse=True)
    with pd.option_context('display.max_rows', None,
                           'display.max_columns', None,
                           'display.precision', 3,
                           ):
        print("Survival table using time %s and event %s" % (args.time, args.status))
        print(survdf)
        print("Life table using time %s and event %s" % (args.time, args.status))
        print(lifedf)
    survdf.to_csv(os.path.join(args.image_dir, 'survival_table.tabular'), sep='\t')
    lifedf.to_csv(os.path.join(args.image_dir, 'life_table.tabular'), sep='\t')


def _save_partial_plot(cov, values, title, args):
    """Plot fitted Cox PH partial effects of ``cov`` at ``values`` and save as <title>_CoxPH_<cov>.<image_type>."""
    axp = cph.plot_partial_effects_on_outcome(cov, cmap='coolwarm', values=values)
    axp.set_title(title)
    oname = os.path.join(args.image_dir, '%s_CoxPH_%s.%s' % (args.title, cov, args.image_type))
    axp.get_figure().savefig(oname)


def _cph_analysis(df, args):
    """Fit a Cox PH model on the --cphcols covariates, print its summary and save plots.

    Saves one partial-effects plot per covariate, plus the figures returned by
    ``check_assumptions``. Exits with code 6 if a requested covariate column is missing.
    """
    # NOTE(review): this figure is never saved; kept unchanged in case lifelines draws
    # the first partial-effects plot on the current axes - confirm before removing
    _cph_fig, cph_ax = plt.subplots()
    cph_ax.set_title('Cox-PH model: %s' % args.title)
    cphcols = [x.strip() for x in args.cphcols.strip().split(',')]
    if any(x not in df.columns for x in cphcols):
        _usage_error('One or more requested Cox PH columns %s not found in supplied column header %s'
                     % (args.cphcols, df.columns), 6)
    colsdf = df[cphcols]
    print('### Lifelines test of Proportional Hazards results with %s as covariates on %s' % (', '.join(cphcols), args.title))
    cphdf = df[[args.time, args.status] + cphcols]
    ucolcounts = colsdf.nunique(axis=0)  # distinct values per covariate: chooses continuous vs categorical plotting
    cph.fit(cphdf, duration_col=args.time, event_col=args.status)
    cph.print_summary()
    for i, cov in enumerate(colsdf.columns):
        # .iloc: ucolcounts is indexed by column name, so plain [i] relies on a deprecated pandas fallback
        if ucolcounts.iloc[i] > 10:
            # a hack - assume categories are sparse - if not imaginary quintiles will have to do
            v = cphdf[cov].quantile(QVALS).tolist()
            if df.dtypes[cov] == 'int64':
                v = trimlegend(v)
            _save_partial_plot(cov, v, 'Cox-PH %s quintile partials: %s' % (cov, args.title), args)
        else:
            v = [str(x) for x in pd.unique(cphdf[cov])]
            try:
                _save_partial_plot(cov, v, 'Cox-PH %s partials: %s' % (cov, args.title), args)
            except Exception as e:
                # still best effort, but record in the log why this plot is missing
                print('## Cox-PH partial effects plot for %s skipped: %s' % (cov, e))
    cphaxes = cph.check_assumptions(cphdf, p_value_threshold=0.01, show_plots=True)
    for axes in cphaxes:
        figr = axes[0].get_figure()
        # file name is taken from the figure suptitle (matplotlib private attribute)
        titl = figr._suptitle.get_text().replace(' ', '_').replace("'", "")
        oname = os.path.join(args.image_dir, 'CPH%s.%s' % (titl, args.image_type))
        figr.savefig(oname)


def main(argv=None):
    """Run the whole analysis. Printed output goes to --readme, which is always closed on exit."""
    args = _parse_args(argv)
    with open(args.readme, 'w') as log, contextlib.redirect_stdout(log):
        df = _load_data(args)
        print('## Lifelines tool\nInput data header =', df.columns, 'time column =', args.time, 'status column =', args.status)
        os.makedirs(args.image_dir, exist_ok=True)
        _km_analysis(df, args)
        _life_tables(df, args)
        if args.cphcols:
            _cph_analysis(df, args)


if __name__ == "__main__":
    main()