Mercurial > repos > fubar > lifelines_km_cph_tool
diff lifelines_tool/plotlykm.py @ 0:dd49a7040643 draft
Initial commit
author | fubar |
---|---|
date | Wed, 09 Aug 2023 11:12:16 +0000 |
parents | |
children | 232b874046a7 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/lifelines_tool/plotlykm.py Wed Aug 09 11:12:16 2023 +0000 @@ -0,0 +1,118 @@ +# script for a lifelines ToolFactory KM/CPH tool for Galaxy +# km models for https://github.com/galaxyproject/tools-iuc/issues/5393 +# test as +# python plotlykm.py --input_tab rossi.tab --htmlout "testfoo" --time "week" --status "arrest" --title "test" --image_dir images --cphcol="prio,age,race,paro,mar,fin" + +import argparse +import os +import sys + +import lifelines + +from matplotlib import pyplot as plt + +import pandas as pd + +# Ross Lazarus July 2023 + + +kmf = lifelines.KaplanMeierFitter() +cph = lifelines.CoxPHFitter() + +parser = argparse.ArgumentParser() +a = parser.add_argument +a('--input_tab', default='', required=True) +a('--header', default='') +a('--htmlout', default="test_run.html") +a('--group', default='') +a('--time', default='', required=True) +a('--status',default='', required=True) +a('--cphcols',default='') +a('--title', default='Default plot title') +a('--image_type', default='png') +a('--image_dir', default='images') +a('--readme', default='run_log.txt') +args = parser.parse_args() +sys.stdout = open(args.readme, 'w') +df = pd.read_csv(args.input_tab, sep='\t') +NCOLS = df.columns.size +NROWS = len(df.index) +defaultcols = ['col%d' % (x+1) for x in range(NCOLS)] +testcols = df.columns +if len(args.header.strip()) > 0: + newcols = args.header.split(',') + if len(newcols) == NCOLS: + if (args.time in newcols) and (args.status in newcols): + df.columns = newcols + else: + sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): time %s and/or status %s not found in supplied header parameter %s' % (args.time, args.status, args.header)) + sys.exit(4) + else: + sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): Supplied header %s has %d comma delimited header names - does not match the input tabular file %d columns' % (args.header, len(newcols), NCOLS)) + sys.exit(5) +else: # no header supplied - check for a real one that matches the x and y axis column names + colsok = (args.time in testcols) and (args.status in testcols) # if they match, probably ok...should use more code and logic.. + if colsok: + df.columns = testcols # use actual header + else: + colsok = (args.time in defaultcols) and (args.status in defaultcols) + if colsok: + sys.stderr.write('replacing first row of data derived header %s with %s' % (testcols, defaultcols)) + df.columns = defaultcols + else: + sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): time %s and status %s do not match anything in the file header, supplied header or automatic default column names %s' % (args.time, args.status, defaultcols)) +print('## Lifelines tool starting.\nUsing data header =', df.columns, 'time column =', args.time, 'status column =', args.status) +os.makedirs(args.image_dir, exist_ok=True) +fig, ax = plt.subplots() +if args.group > '': + names = [] + times = [] + events = [] + rmst = [] + for name, grouped_df in df.groupby(args.group): + T = grouped_df[args.time] + E = grouped_df[args.status] + gfit = kmf.fit(T, E, label=name) + kmf.plot_survival_function(ax=ax) + rst = lifelines.utils.restricted_mean_survival_time(gfit) + rmst.append(rst) + names.append(str(name)) + times.append(T) + events.append(E) + ngroup = len(names) + if ngroup == 2: # run logrank test if 2 groups + results = lifelines.statistics.logrank_test(times[0], times[1], events[0], events[1], alpha=.99) + print(' vs '.join(names), results) + results.print_summary() + elif ngroup > 1: + fig, ax = plt.subplots(nrows=ngroup, ncols=1, sharex=True) + for i, rst in rmst: + lifelines.plotting.rmst_plot(rst, ax=ax) + fig.savefig(os.path.join(args.image_dir,'RMST_%s.png' % args.title)) +else: + kmf.fit(df[args.time], df[args.status]) + kmf.plot_survival_function(ax=ax) +fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title)) +if len(args.cphcols) > 0: + fig, ax = plt.subplots() + cphcols = args.cphcols.strip().split(',') + cphcols = [x.strip() for x in cphcols] + notfound = sum([(x not in df.columns) for x in cphcols]) + if notfound > 0: + sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): One or more requested Cox PH columns %s not found in supplied column header %s' % (args.cphcols, df.columns)) + sys.exit(6) + print('### Lifelines test of Proportional Hazards results with %s as covariates on %s' % (', '.join(cphcols), args.title)) + cphcols += [args.time, args.status] + cphdf = df[cphcols] + cph.fit(cphdf, duration_col=args.time, event_col=args.status) + cph.print_summary() + cphaxes = cph.check_assumptions(cphdf, p_value_threshold=0.01, show_plots=True) + for i, ax in enumerate(cphaxes): + figr = ax[0].get_figure() + titl = figr._suptitle.get_text().replace(' ','_').replace("'","") + oname = os.path.join(args.image_dir,'CPH%s.%s' % (titl, args.image_type)) + figr.savefig(oname) + + + +