# Lifelines ToolFactory Kaplan-Meier (KM) / Cox proportional hazards (CPH) tool for Galaxy.
# KM models for https://github.com/galaxyproject/tools-iuc/issues/5393
# Test as:
# python plotlykm.py --input_tab rossi.tab --htmlout "testfoo" --time "week" --status "arrest" --title "test" --image_dir images --cphcols="prio,age,race,paro,mar,fin"
#
# Ross Lazarus July 2023

import argparse
import contextlib
import os
import sys

import pandas as pd

# lifelines and matplotlib are imported inside the functions that fit and plot, so the
# argument and header handling below can be imported (e.g. by tests) without them.


def usage_error(message, code):
    """Write a user-input error to stderr and exit the process with status ``code``."""
    sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): %s\n' % message)
    sys.exit(code)


def parse_args(argv=None):
    """Parse the command-line options; ``argv`` defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser()
    a = parser.add_argument
    a('--input_tab', default='', required=True)
    a('--header', default='')
    a('--htmlout', default="test_run.html")
    a('--group', default='')
    a('--time', default='', required=True)
    a('--status', default='', required=True)
    a('--cphcols', default='')
    a('--title', default='Default plot title')
    a('--image_type', default='png')
    a('--image_dir', default='images')
    a('--readme', default='run_log.txt')
    return parser.parse_args(argv)


def parse_column_list(text):
    """Split a comma-separated list of column names, stripping whitespace and dropping empty entries."""
    return [name.strip() for name in text.split(',') if name.strip()]


def cox_model_columns(covariates, time_col, status_col):
    """Return the covariates followed by the duration and event columns, each column named once.

    CoxPHFitter needs the duration/event columns inside the frame it fits; listing a
    column twice would give the frame duplicate columns.
    """
    return list(dict.fromkeys(list(covariates) + [time_col, status_col]))


def load_table(input_tab, header, time_col, status_col):
    """Read the tab-separated ``input_tab`` and give it column names.

    Names come from the first of these schemes that contains both ``time_col``
    and ``status_col``:

    1. ``header``: a comma-separated list that must name every column
       (exit status 5 if the count differs, 4 if time/status are missing from it);
    2. the file's own first row;
    3. automatic names ``col1 .. colN``. In that case the first row was data,
       not a header, so the file is re-read without a header to keep that row.

    Exits with status 7 when no scheme matches. Returns the named DataFrame.
    """
    df = pd.read_csv(input_tab, sep='\t')
    ncols = df.columns.size
    defaultcols = ['col%d' % (x + 1) for x in range(ncols)]
    if header.strip():
        # NOTE(review): the file's first row was still consumed as a header above; if
        # files given with --header have no header row of their own, that row is dropped.
        newcols = [name.strip() for name in header.split(',')]
        if len(newcols) != ncols:
            usage_error('Supplied header %s has %d comma delimited header names - does not match the input tabular file %d columns' % (header, len(newcols), ncols), 5)
        if time_col not in newcols or status_col not in newcols:
            usage_error('time %s and/or status %s not found in supplied header parameter %s' % (time_col, status_col, header), 4)
        df.columns = newcols
    elif time_col in df.columns and status_col in df.columns:
        pass  # the file's own header already names both columns
    elif time_col in defaultcols and status_col in defaultcols:
        sys.stderr.write('replacing first row of data derived header %s with %s\n' % (df.columns, defaultcols))
        # the first read consumed a data row as the header - read again so that row is kept
        df = pd.read_csv(input_tab, sep='\t', header=None)
        df.columns = defaultcols
    else:
        usage_error('time %s and status %s do not match anything in the file header, supplied header or automatic default column names %s' % (time_col, status_col, defaultcols), 7)
    return df


def run_km(df, args):
    """Plot Kaplan-Meier survival curves and save them under ``args.image_dir``.

    With ``--group`` one curve per group is drawn on a shared plot and each group's
    restricted mean survival time (RMST) is printed; two groups also get a log-rank
    test, more than two get an RMST plot per group. Exits with status 8 if the
    group column is not in ``df``.
    """
    from lifelines import KaplanMeierFitter
    from lifelines.plotting import rmst_plot
    from lifelines.statistics import logrank_test
    from lifelines.utils import restricted_mean_survival_time
    from matplotlib import pyplot as plt

    fig, ax = plt.subplots()
    if args.group:
        if args.group not in df.columns:
            usage_error('group column %s not found in data header %s' % (args.group, df.columns), 8)
        # one fitter per group, so every fitted model is still available after the loop
        groups = []
        for name, grouped_df in df.groupby(args.group):
            T = grouped_df[args.time]
            E = grouped_df[args.status]
            kmf = KaplanMeierFitter().fit(T, E, label=name)
            kmf.plot_survival_function(ax=ax)
            groups.append((str(name), kmf, T, E))
        names = [name for name, _, _, _ in groups]
        # integrate RMST only up to the shortest group's last observed time, so every
        # group's curve covers the whole interval (lifelines' default bound is t=inf)
        tau = min(durations.max() for _, _, durations, _ in groups)
        for name, kmf, _, _ in groups:
            print('RMST for %s up to time %s = %s' % (name, tau, restricted_mean_survival_time(kmf, t=tau)))
        if len(groups) == 2:  # run logrank test if 2 groups
            (_, _, T0, E0), (_, _, T1, E1) = groups
            results = logrank_test(T0, T1, E0, E1, alpha=.99)
            print(' vs '.join(names), results)
            results.print_summary()
        elif len(groups) > 2:
            # a separate figure, so the KM plot held in `fig` is not overwritten
            rfig, raxes = plt.subplots(nrows=len(groups), ncols=1, sharex=True)
            for rax, (_, kmf, _, _) in zip(raxes, groups):
                rmst_plot(kmf, t=tau, ax=rax)
            rfig.savefig(os.path.join(args.image_dir, 'RMST_%s.%s' % (args.title, args.image_type)))
    else:
        kmf = KaplanMeierFitter().fit(df[args.time], df[args.status])
        kmf.plot_survival_function(ax=ax)
    fig.savefig(os.path.join(args.image_dir, 'KM_%s.%s' % (args.title, args.image_type)))


def run_cph(df, args):
    """Fit a Cox proportional hazards model on ``--cphcols`` and save assumption-check plots.

    Does nothing when no covariates are requested; exits with status 6 if any
    requested column is missing from ``df``.
    """
    covariates = parse_column_list(args.cphcols)
    if not covariates:
        return
    if any(name not in df.columns for name in covariates):
        usage_error('One or more requested Cox PH columns %s not found in supplied column header %s' % (args.cphcols, df.columns), 6)
    from lifelines import CoxPHFitter

    print('### Lifelines test of Proportional Hazards results with %s as covariates on %s' % (', '.join(covariates), args.title))
    cphdf = df[cox_model_columns(covariates, args.time, args.status)]
    cph = CoxPHFitter()
    cph.fit(cphdf, duration_col=args.time, event_col=args.status)
    cph.print_summary()
    # each returned entry is a list of axes belonging to one assumption-check figure
    cphaxes = cph.check_assumptions(cphdf, p_value_threshold=0.01, show_plots=True)
    for i, axes in enumerate(cphaxes or []):
        figr = axes[0].get_figure()
        suptitle = getattr(figr, '_suptitle', None)  # private matplotlib attribute - guard it
        titl = suptitle.get_text() if suptitle is not None else 'assumptions_%d' % i
        titl = titl.replace(' ', '_').replace("'", "")
        figr.savefig(os.path.join(args.image_dir, 'CPH%s.%s' % (titl, args.image_type)))


def main(argv=None):
    """Run the tool: KM plots always, a Cox PH analysis when ``--cphcols`` is given."""
    args = parse_args(argv)
    # the run log replaces stdout; `with` guarantees it is flushed and closed on every exit path
    with open(args.readme, 'w') as readme, contextlib.redirect_stdout(readme):
        df = load_table(args.input_tab, args.header, args.time, args.status)
        print('## Lifelines tool starting.\nUsing data header =', df.columns, 'time column =', args.time, 'status column =', args.status)
        os.makedirs(args.image_dir, exist_ok=True)
        run_km(df, args)
        run_cph(df, args)


if __name__ == '__main__':
    main()