Mercurial > repos > fubar > lifelines_km_cph_tool
comparison lifelines_tool/plotlykm.py @ 1:232b874046a7 draft
Uploaded
author | fubar |
---|---|
date | Thu, 10 Aug 2023 07:15:22 +0000 |
parents | dd49a7040643 |
children | dd5e65893cb8 |
comparison
equal
deleted
inserted
replaced
0:dd49a7040643 | 1:232b874046a7 |
---|---|
1 # script for a lifelines ToolFactory KM/CPH tool for Galaxy | 1 # script for a lifelines ToolFactory KM/CPH tool for Galaxy |
2 # km models for https://github.com/galaxyproject/tools-iuc/issues/5393 | 2 # km models for https://github.com/galaxyproject/tools-iuc/issues/5393 |
3 # test as | 3 # test as |
4 # python plotlykm.py --input_tab rossi.tab --htmlout "testfoo" --time "week" --status "arrest" --title "test" --image_dir images --cphcol="prio,age,race,paro,mar,fin" | 4 # python plotlykm.py --input_tab rossi.tab --htmlout "testfoo" --time "week" --status "arrest" --title "test" --image_dir images --cphcol="prio,age,race,paro,mar,fin" |
5 # Ross Lazarus July 2023 | |
6 import argparse | |
5 | 7 |
6 import argparse | |
7 import os | 8 import os |
8 import sys | 9 import sys |
9 | 10 |
10 import lifelines | 11 import lifelines |
11 | 12 |
12 from matplotlib import pyplot as plt | 13 from matplotlib import pyplot as plt |
13 | 14 |
14 import pandas as pd | 15 import pandas as pd |
15 | 16 |
16 # Ross Lazarus July 2023 | |
17 | 17 |
18 def trimlegend(v): | |
19 """ | |
20 for int64 quintiles - must be ints - otherwise get silly legends with long float values | |
21 """ | |
22 for i, av in enumerate(v): | |
23 x = int(av) | |
24 v[i] = str(x) | |
25 return v | |
18 | 26 |
19 kmf = lifelines.KaplanMeierFitter() | 27 kmf = lifelines.KaplanMeierFitter() |
20 cph = lifelines.CoxPHFitter() | 28 cph = lifelines.CoxPHFitter() |
21 | 29 |
22 parser = argparse.ArgumentParser() | 30 parser = argparse.ArgumentParser() |
23 a = parser.add_argument | 31 a = parser.add_argument |
24 a('--input_tab', default='', required=True) | 32 a('--input_tab', default='rossi.tab', required=True) |
25 a('--header', default='') | 33 a('--header', default='') |
26 a('--htmlout', default="test_run.html") | 34 a('--htmlout', default="test_run.html") |
27 a('--group', default='') | 35 a('--group', default='') |
28 a('--time', default='', required=True) | 36 a('--time', default='', required=True) |
29 a('--status',default='', required=True) | 37 a('--status',default='', required=True) |
35 args = parser.parse_args() | 43 args = parser.parse_args() |
36 sys.stdout = open(args.readme, 'w') | 44 sys.stdout = open(args.readme, 'w') |
37 df = pd.read_csv(args.input_tab, sep='\t') | 45 df = pd.read_csv(args.input_tab, sep='\t') |
38 NCOLS = df.columns.size | 46 NCOLS = df.columns.size |
39 NROWS = len(df.index) | 47 NROWS = len(df.index) |
48 QVALS = [.2, .4, .6, .8] # for partial cox ph plots | |
40 defaultcols = ['col%d' % (x+1) for x in range(NCOLS)] | 49 defaultcols = ['col%d' % (x+1) for x in range(NCOLS)] |
41 testcols = df.columns | 50 testcols = df.columns |
42 if len(args.header.strip()) > 0: | 51 if len(args.header.strip()) > 0: |
43 newcols = args.header.split(',') | 52 newcols = args.header.split(',') |
44 if len(newcols) == NCOLS: | 53 if len(newcols) == NCOLS: |
55 if colsok: | 64 if colsok: |
56 df.columns = testcols # use actual header | 65 df.columns = testcols # use actual header |
57 else: | 66 else: |
58 colsok = (args.time in defaultcols) and (args.status in defaultcols) | 67 colsok = (args.time in defaultcols) and (args.status in defaultcols) |
59 if colsok: | 68 if colsok: |
60 sys.stderr.write('replacing first row of data derived header %s with %s' % (testcols, defaultcols)) | 69 print('Replacing first row of data derived header %s with %s' % (testcols, defaultcols)) |
61 df.columns = defaultcols | 70 df.columns = defaultcols |
62 else: | 71 else: |
63 sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): time %s and status %s do not match anything in the file header, supplied header or automatic default column names %s' % (args.time, args.status, defaultcols)) | 72 sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): time %s and status %s do not match anything in the file header, supplied header or automatic default column names %s' % (args.time, args.status, defaultcols)) |
64 print('## Lifelines tool starting.\nUsing data header =', df.columns, 'time column =', args.time, 'status column =', args.status) | 73 print('## Lifelines tool\nInput data header =', df.columns, 'time column =', args.time, 'status column =', args.status) |
65 os.makedirs(args.image_dir, exist_ok=True) | 74 os.makedirs(args.image_dir, exist_ok=True) |
66 fig, ax = plt.subplots() | 75 fig, ax = plt.subplots() |
67 if args.group > '': | 76 if args.group > '': |
68 names = [] | 77 names = [] |
69 times = [] | 78 times = [] |
70 events = [] | 79 events = [] |
71 rmst = [] | |
72 for name, grouped_df in df.groupby(args.group): | 80 for name, grouped_df in df.groupby(args.group): |
73 T = grouped_df[args.time] | 81 T = grouped_df[args.time] |
74 E = grouped_df[args.status] | 82 E = grouped_df[args.status] |
75 gfit = kmf.fit(T, E, label=name) | 83 gfit = kmf.fit(T, E, label=name) |
76 kmf.plot_survival_function(ax=ax) | 84 kmf.plot_survival_function(ax=ax) |
77 rst = lifelines.utils.restricted_mean_survival_time(gfit) | |
78 rmst.append(rst) | |
79 names.append(str(name)) | 85 names.append(str(name)) |
80 times.append(T) | 86 times.append(T) |
81 events.append(E) | 87 events.append(E) |
88 ax.set_title(args.title) | |
89 fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title)) | |
82 ngroup = len(names) | 90 ngroup = len(names) |
83 if ngroup == 2: # run logrank test if 2 groups | 91 if ngroup == 2: # run logrank test if 2 groups |
84 results = lifelines.statistics.logrank_test(times[0], times[1], events[0], events[1], alpha=.99) | 92 results = lifelines.statistics.logrank_test(times[0], times[1], events[0], events[1], alpha=.99) |
85 print(' vs '.join(names), results) | 93 print('Logrank test for %s - %s vs %s\n' % (args.group, names[0], names[1])) |
86 results.print_summary() | 94 results.print_summary() |
87 elif ngroup > 1: | |
88 fig, ax = plt.subplots(nrows=ngroup, ncols=1, sharex=True) | |
89 for i, rst in rmst: | |
90 lifelines.plotting.rmst_plot(rst, ax=ax) | |
91 fig.savefig(os.path.join(args.image_dir,'RMST_%s.png' % args.title)) | |
92 else: | 95 else: |
93 kmf.fit(df[args.time], df[args.status]) | 96 kmf.fit(df[args.time], df[args.status]) |
94 kmf.plot_survival_function(ax=ax) | 97 kmf.plot_survival_function(ax=ax) |
95 fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title)) | 98 ax.set_title(args.title) |
99 fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title)) | |
100 print('#### No grouping variable, so no log rank or other Kaplan-Meier statistical output is available') | |
96 if len(args.cphcols) > 0: | 101 if len(args.cphcols) > 0: |
97 fig, ax = plt.subplots() | 102 fig, ax = plt.subplots() |
103 ax.set_title('Cox-PH model: %s' % args.title) | |
98 cphcols = args.cphcols.strip().split(',') | 104 cphcols = args.cphcols.strip().split(',') |
99 cphcols = [x.strip() for x in cphcols] | 105 cphcols = [x.strip() for x in cphcols] |
100 notfound = sum([(x not in df.columns) for x in cphcols]) | 106 notfound = sum([(x not in df.columns) for x in cphcols]) |
101 if notfound > 0: | 107 if notfound > 0: |
102 sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): One or more requested Cox PH columns %s not found in supplied column header %s' % (args.cphcols, df.columns)) | 108 sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): One or more requested Cox PH columns %s not found in supplied column header %s' % (args.cphcols, df.columns)) |
103 sys.exit(6) | 109 sys.exit(6) |
110 colsdf = df[cphcols] | |
104 print('### Lifelines test of Proportional Hazards results with %s as covariates on %s' % (', '.join(cphcols), args.title)) | 111 print('### Lifelines test of Proportional Hazards results with %s as covariates on %s' % (', '.join(cphcols), args.title)) |
105 cphcols += [args.time, args.status] | 112 cutcphcols = [args.time, args.status] + cphcols |
106 cphdf = df[cphcols] | 113 cphdf = df[cutcphcols] |
114 ucolcounts = colsdf.nunique(axis=0) | |
107 cph.fit(cphdf, duration_col=args.time, event_col=args.status) | 115 cph.fit(cphdf, duration_col=args.time, event_col=args.status) |
108 cph.print_summary() | 116 cph.print_summary() |
117 for i, cov in enumerate(colsdf.columns): | |
118 if ucolcounts[i] > 10: # a hack - assume categories are sparse - if not imaginary quintiles will have to do | |
119 v = pd.Series.tolist(cphdf[cov].quantile(QVALS)) | |
120 vdt = df.dtypes[cov] | |
121 if vdt == 'int64': | |
122 v = trimlegend(v) | |
123 axp = cph.plot_partial_effects_on_outcome(cov, cmap='coolwarm', values=v) | |
124 axp.set_title('Cox-PH %s quintile partials: %s' % (cov,args.title)) | |
125 figr = axp.get_figure() | |
126 oname = os.path.join(args.image_dir,'%s_CoxPH_%s.%s' % (args.title, cov, args.image_type)) | |
127 figr.savefig(oname) | |
128 else: | |
129 v = pd.unique(cphdf[cov]) | |
130 v = [str(x) for x in v] | |
131 try: | |
132 axp = cph.plot_partial_effects_on_outcome(cov, cmap='coolwarm', values=v) | |
133 axp.set_title('Cox-PH %s partials: %s' % (cov,args.title)) | |
134 figr = axp.get_figure() | |
135 oname = os.path.join(args.image_dir,'%s_CoxPH_%s.%s' % (args.title, cov, args.image_type)) | |
136 figr.savefig(oname) | |
137 except: | |
138 pass | |
109 cphaxes = cph.check_assumptions(cphdf, p_value_threshold=0.01, show_plots=True) | 139 cphaxes = cph.check_assumptions(cphdf, p_value_threshold=0.01, show_plots=True) |
110 for i, ax in enumerate(cphaxes): | 140 for i, ax in enumerate(cphaxes): |
111 figr = ax[0].get_figure() | 141 figr = ax[0].get_figure() |
112 titl = figr._suptitle.get_text().replace(' ','_').replace("'","") | 142 titl = figr._suptitle.get_text().replace(' ','_').replace("'","") |
113 oname = os.path.join(args.image_dir,'CPH%s.%s' % (titl, args.image_type)) | 143 oname = os.path.join(args.image_dir,'CPH%s.%s' % (titl, args.image_type)) |
114 figr.savefig(oname) | 144 figr.savefig(oname) |
115 | |
116 | |
117 | |
118 |