Mercurial > repos > fubar > lifelines_km_cph_tool
comparison lifelines_tool/lifelineskmcph.xml @ 1:232b874046a7 draft
Uploaded
author | fubar |
---|---|
date | Thu, 10 Aug 2023 07:15:22 +0000 |
parents | dd49a7040643 |
children | dd5e65893cb8 |
comparison
equal
deleted
inserted
replaced
0:dd49a7040643 | 1:232b874046a7 |
---|---|
1 <tool name="lifelineskmcph" id="lifelineskmcph" version="0.01"> | 1 <tool name="lifelineskmcph" id="lifelineskmcph" version="0.01"> |
2 <!--Source in git at: https://github.com/fubar2/galaxy_tf_overlay--> | 2 <!--Source in git at: https://github.com/fubar2/galaxy_tf_overlay--> |
3 <!--Created by toolfactory@galaxy.org at 09/08/2023 17:43:16 using the Galaxy Tool Factory.--> | 3 <!--Created by toolfactory@galaxy.org at 10/08/2023 15:48:43 using the Galaxy Tool Factory.--> |
4 <description>Lifelines KM and optional Cox PH models</description> | 4 <description>Lifelines KM and optional Cox PH models</description> |
5 <requirements> | 5 <requirements> |
6 <requirement version="1.5.3" type="package">pandas</requirement> | 6 <requirement version="1.5.3" type="package">pandas</requirement> |
7 <requirement version="3.7.2" type="package">matplotlib</requirement> | 7 <requirement version="3.7.2" type="package">matplotlib</requirement> |
8 <requirement version="0.27.7" type="package">lifelines</requirement> | 8 <requirement version="0.27.7" type="package">lifelines</requirement> |
38 | 38 |
39 # script for a lifelines ToolFactory KM/CPH tool for Galaxy | 39 # script for a lifelines ToolFactory KM/CPH tool for Galaxy |
40 # km models for https://github.com/galaxyproject/tools-iuc/issues/5393 | 40 # km models for https://github.com/galaxyproject/tools-iuc/issues/5393 |
41 # test as | 41 # test as |
42 # python plotlykm.py --input_tab rossi.tab --htmlout "testfoo" --time "week" --status "arrest" --title "test" --image_dir images --cphcol="prio,age,race,paro,mar,fin" | 42 # python plotlykm.py --input_tab rossi.tab --htmlout "testfoo" --time "week" --status "arrest" --title "test" --image_dir images --cphcol="prio,age,race,paro,mar,fin" |
43 | 43 # Ross Lazarus July 2023 |
44 import argparse | 44 import argparse |
45 | |
45 import os | 46 import os |
46 import sys | 47 import sys |
47 | 48 |
48 import lifelines | 49 import lifelines |
49 | 50 |
50 from matplotlib import pyplot as plt | 51 from matplotlib import pyplot as plt |
51 | 52 |
52 import pandas as pd | 53 import pandas as pd |
53 | 54 |
54 # Ross Lazarus July 2023 | 55 |
55 | 56 def trimlegend(v): |
57 """ | |
58 for int64 quintiles - must be ints - otherwise get silly legends with long float values | |
59 """ | |
60 for i, av in enumerate(v): | |
61 x = int(av) | |
62 v[i] = str(x) | |
63 return v | |
56 | 64 |
57 kmf = lifelines.KaplanMeierFitter() | 65 kmf = lifelines.KaplanMeierFitter() |
58 cph = lifelines.CoxPHFitter() | 66 cph = lifelines.CoxPHFitter() |
59 | 67 |
60 parser = argparse.ArgumentParser() | 68 parser = argparse.ArgumentParser() |
61 a = parser.add_argument | 69 a = parser.add_argument |
62 a('--input_tab', default='', required=True) | 70 a('--input_tab', default='rossi.tab', required=True) |
63 a('--header', default='') | 71 a('--header', default='') |
64 a('--htmlout', default="test_run.html") | 72 a('--htmlout', default="test_run.html") |
65 a('--group', default='') | 73 a('--group', default='') |
66 a('--time', default='', required=True) | 74 a('--time', default='', required=True) |
67 a('--status',default='', required=True) | 75 a('--status',default='', required=True) |
73 args = parser.parse_args() | 81 args = parser.parse_args() |
74 sys.stdout = open(args.readme, 'w') | 82 sys.stdout = open(args.readme, 'w') |
75 df = pd.read_csv(args.input_tab, sep='\t') | 83 df = pd.read_csv(args.input_tab, sep='\t') |
76 NCOLS = df.columns.size | 84 NCOLS = df.columns.size |
77 NROWS = len(df.index) | 85 NROWS = len(df.index) |
86 QVALS = [.2, .4, .6, .8] # for partial cox ph plots | |
78 defaultcols = ['col%d' % (x+1) for x in range(NCOLS)] | 87 defaultcols = ['col%d' % (x+1) for x in range(NCOLS)] |
79 testcols = df.columns | 88 testcols = df.columns |
80 if len(args.header.strip()) > 0: | 89 if len(args.header.strip()) > 0: |
81 newcols = args.header.split(',') | 90 newcols = args.header.split(',') |
82 if len(newcols) == NCOLS: | 91 if len(newcols) == NCOLS: |
104 fig, ax = plt.subplots() | 113 fig, ax = plt.subplots() |
105 if args.group > '': | 114 if args.group > '': |
106 names = [] | 115 names = [] |
107 times = [] | 116 times = [] |
108 events = [] | 117 events = [] |
109 rmst = [] | |
110 for name, grouped_df in df.groupby(args.group): | 118 for name, grouped_df in df.groupby(args.group): |
111 T = grouped_df[args.time] | 119 T = grouped_df[args.time] |
112 E = grouped_df[args.status] | 120 E = grouped_df[args.status] |
113 gfit = kmf.fit(T, E, label=name) | 121 gfit = kmf.fit(T, E, label=name) |
114 kmf.plot_survival_function(ax=ax) | 122 kmf.plot_survival_function(ax=ax) |
115 rst = lifelines.utils.restricted_mean_survival_time(gfit) | |
116 rmst.append(rst) | |
117 names.append(str(name)) | 123 names.append(str(name)) |
118 times.append(T) | 124 times.append(T) |
119 events.append(E) | 125 events.append(E) |
120 ax.set_title(args.title) | 126 ax.set_title(args.title) |
121 fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title)) | 127 fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title)) |
122 ngroup = len(names) | 128 ngroup = len(names) |
123 if ngroup == 2: # run logrank test if 2 groups | 129 if ngroup == 2: # run logrank test if 2 groups |
124 results = lifelines.statistics.logrank_test(times[0], times[1], events[0], events[1], alpha=.99) | 130 results = lifelines.statistics.logrank_test(times[0], times[1], events[0], events[1], alpha=.99) |
125 print('Logrank test for %s - %s vs %s\n' % (args.group, names[0], names[1])) | 131 print('Logrank test for %s - %s vs %s\n' % (args.group, names[0], names[1])) |
126 results.print_summary() | 132 results.print_summary() |
127 elif ngroup > 1: | |
128 fig, ax = plt.subplots(nrows=ngroup, ncols=1, sharex=True) | |
129 for i, rst in rmst: | |
130 lifelines.plotting.rmst_plot(rst, ax=ax) | |
131 fig.savefig(os.path.join(args.image_dir,'RMST_%s.png' % args.title)) | |
132 else: | 133 else: |
133 kmf.fit(df[args.time], df[args.status]) | 134 kmf.fit(df[args.time], df[args.status]) |
134 kmf.plot_survival_function(ax=ax) | 135 kmf.plot_survival_function(ax=ax) |
135 ax.set_title(args.title) | 136 ax.set_title(args.title) |
136 fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title)) | 137 fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title)) |
138 print('#### No grouping variable, so no log rank or other Kaplan-Meier statistical output is available') | |
137 if len(args.cphcols) > 0: | 139 if len(args.cphcols) > 0: |
138 fig, ax = plt.subplots() | 140 fig, ax = plt.subplots() |
139 ax.set_title('Cox PH model: %s' % args.title) | 141 ax.set_title('Cox-PH model: %s' % args.title) |
140 cphcols = args.cphcols.strip().split(',') | 142 cphcols = args.cphcols.strip().split(',') |
141 cphcols = [x.strip() for x in cphcols] | 143 cphcols = [x.strip() for x in cphcols] |
142 notfound = sum([(x not in df.columns) for x in cphcols]) | 144 notfound = sum([(x not in df.columns) for x in cphcols]) |
143 if notfound > 0: | 145 if notfound > 0: |
144 sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): One or more requested Cox PH columns %s not found in supplied column header %s' % (args.cphcols, df.columns)) | 146 sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): One or more requested Cox PH columns %s not found in supplied column header %s' % (args.cphcols, df.columns)) |
145 sys.exit(6) | 147 sys.exit(6) |
148 colsdf = df[cphcols] | |
146 print('### Lifelines test of Proportional Hazards results with %s as covariates on %s' % (', '.join(cphcols), args.title)) | 149 print('### Lifelines test of Proportional Hazards results with %s as covariates on %s' % (', '.join(cphcols), args.title)) |
147 cphcols += [args.time, args.status] | 150 cutcphcols = [args.time, args.status] + cphcols |
148 cphdf = df[cphcols] | 151 cphdf = df[cutcphcols] |
152 ucolcounts = colsdf.nunique(axis=0) | |
149 cph.fit(cphdf, duration_col=args.time, event_col=args.status) | 153 cph.fit(cphdf, duration_col=args.time, event_col=args.status) |
150 cph.print_summary() | 154 cph.print_summary() |
155 for i, cov in enumerate(colsdf.columns): | |
156 if ucolcounts[i] > 10: | |
157 v = pd.Series.tolist(cphdf[cov].quantile(QVALS)) | |
158 vdt = df.dtypes[cov] | |
159 if vdt == 'int64': | |
160 v = trimlegend(v) | |
161 axp = cph.plot_partial_effects_on_outcome(cov, cmap='coolwarm', values=v) | |
162 axp.set_title('Cox-PH %s quintile partials: %s' % (cov,args.title)) | |
163 figr = axp.get_figure() | |
164 oname = os.path.join(args.image_dir,'%s_CoxPH_%s.%s' % (args.title, cov, args.image_type)) | |
165 figr.savefig(oname) | |
166 else: | |
167 v = pd.unique(cphdf[cov]) | |
168 v = [str(x) for x in v] | |
169 try: | |
170 axp = cph.plot_partial_effects_on_outcome(cov, cmap='coolwarm', values=v) | |
171 axp.set_title('Cox-PH %s partials: %s' % (cov,args.title)) | |
172 figr = axp.get_figure() | |
173 oname = os.path.join(args.image_dir,'%s_CoxPH_%s.%s' % (args.title, cov, args.image_type)) | |
174 figr.savefig(oname) | |
175 except: | |
176 pass | |
151 cphaxes = cph.check_assumptions(cphdf, p_value_threshold=0.01, show_plots=True) | 177 cphaxes = cph.check_assumptions(cphdf, p_value_threshold=0.01, show_plots=True) |
152 for i, ax in enumerate(cphaxes): | 178 for i, ax in enumerate(cphaxes): |
153 figr = ax[0].get_figure() | 179 figr = ax[0].get_figure() |
154 titl = figr._suptitle.get_text().replace(' ','_').replace("'","") | 180 titl = figr._suptitle.get_text().replace(' ','_').replace("'","") |
155 oname = os.path.join(args.image_dir,'CPH%s.%s' % (titl, args.image_type)) | 181 oname = os.path.join(args.image_dir,'CPH%s.%s' % (titl, args.image_type)) |