comparison lifelines_tool/plotlykm.py @ 1:232b874046a7 draft

Uploaded
author fubar
date Thu, 10 Aug 2023 07:15:22 +0000
parents dd49a7040643
children dd5e65893cb8
comparison
equal deleted inserted replaced
0:dd49a7040643 1:232b874046a7
1 # script for a lifelines ToolFactory KM/CPH tool for Galaxy 1 # script for a lifelines ToolFactory KM/CPH tool for Galaxy
2 # km models for https://github.com/galaxyproject/tools-iuc/issues/5393 2 # km models for https://github.com/galaxyproject/tools-iuc/issues/5393
3 # test as 3 # test as
4 # python plotlykm.py --input_tab rossi.tab --htmlout "testfoo" --time "week" --status "arrest" --title "test" --image_dir images --cphcol="prio,age,race,paro,mar,fin" 4 # python plotlykm.py --input_tab rossi.tab --htmlout "testfoo" --time "week" --status "arrest" --title "test" --image_dir images --cphcol="prio,age,race,paro,mar,fin"
5 # Ross Lazarus July 2023
6 import argparse
5 7
6 import argparse
7 import os 8 import os
8 import sys 9 import sys
9 10
10 import lifelines 11 import lifelines
11 12
12 from matplotlib import pyplot as plt 13 from matplotlib import pyplot as plt
13 14
14 import pandas as pd 15 import pandas as pd
15 16
16 # Ross Lazarus July 2023
17 17
18 def trimlegend(v):
19 """
20 for int64 quintiles - must be ints - otherwise get silly legends with long float values
21 """
22 for i, av in enumerate(v):
23 x = int(av)
24 v[i] = str(x)
25 return v
18 26
19 kmf = lifelines.KaplanMeierFitter() 27 kmf = lifelines.KaplanMeierFitter()
20 cph = lifelines.CoxPHFitter() 28 cph = lifelines.CoxPHFitter()
21 29
22 parser = argparse.ArgumentParser() 30 parser = argparse.ArgumentParser()
23 a = parser.add_argument 31 a = parser.add_argument
24 a('--input_tab', default='', required=True) 32 a('--input_tab', default='rossi.tab', required=True)
25 a('--header', default='') 33 a('--header', default='')
26 a('--htmlout', default="test_run.html") 34 a('--htmlout', default="test_run.html")
27 a('--group', default='') 35 a('--group', default='')
28 a('--time', default='', required=True) 36 a('--time', default='', required=True)
29 a('--status',default='', required=True) 37 a('--status',default='', required=True)
35 args = parser.parse_args() 43 args = parser.parse_args()
36 sys.stdout = open(args.readme, 'w') 44 sys.stdout = open(args.readme, 'w')
37 df = pd.read_csv(args.input_tab, sep='\t') 45 df = pd.read_csv(args.input_tab, sep='\t')
38 NCOLS = df.columns.size 46 NCOLS = df.columns.size
39 NROWS = len(df.index) 47 NROWS = len(df.index)
48 QVALS = [.2, .4, .6, .8] # for partial cox ph plots
40 defaultcols = ['col%d' % (x+1) for x in range(NCOLS)] 49 defaultcols = ['col%d' % (x+1) for x in range(NCOLS)]
41 testcols = df.columns 50 testcols = df.columns
42 if len(args.header.strip()) > 0: 51 if len(args.header.strip()) > 0:
43 newcols = args.header.split(',') 52 newcols = args.header.split(',')
44 if len(newcols) == NCOLS: 53 if len(newcols) == NCOLS:
55 if colsok: 64 if colsok:
56 df.columns = testcols # use actual header 65 df.columns = testcols # use actual header
57 else: 66 else:
58 colsok = (args.time in defaultcols) and (args.status in defaultcols) 67 colsok = (args.time in defaultcols) and (args.status in defaultcols)
59 if colsok: 68 if colsok:
60 sys.stderr.write('replacing first row of data derived header %s with %s' % (testcols, defaultcols)) 69 print('Replacing first row of data derived header %s with %s' % (testcols, defaultcols))
61 df.columns = defaultcols 70 df.columns = defaultcols
62 else: 71 else:
63 sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): time %s and status %s do not match anything in the file header, supplied header or automatic default column names %s' % (args.time, args.status, defaultcols)) 72 sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): time %s and status %s do not match anything in the file header, supplied header or automatic default column names %s' % (args.time, args.status, defaultcols))
64 print('## Lifelines tool starting.\nUsing data header =', df.columns, 'time column =', args.time, 'status column =', args.status) 73 print('## Lifelines tool\nInput data header =', df.columns, 'time column =', args.time, 'status column =', args.status)
65 os.makedirs(args.image_dir, exist_ok=True) 74 os.makedirs(args.image_dir, exist_ok=True)
66 fig, ax = plt.subplots() 75 fig, ax = plt.subplots()
67 if args.group > '': 76 if args.group > '':
68 names = [] 77 names = []
69 times = [] 78 times = []
70 events = [] 79 events = []
71 rmst = []
72 for name, grouped_df in df.groupby(args.group): 80 for name, grouped_df in df.groupby(args.group):
73 T = grouped_df[args.time] 81 T = grouped_df[args.time]
74 E = grouped_df[args.status] 82 E = grouped_df[args.status]
75 gfit = kmf.fit(T, E, label=name) 83 gfit = kmf.fit(T, E, label=name)
76 kmf.plot_survival_function(ax=ax) 84 kmf.plot_survival_function(ax=ax)
77 rst = lifelines.utils.restricted_mean_survival_time(gfit)
78 rmst.append(rst)
79 names.append(str(name)) 85 names.append(str(name))
80 times.append(T) 86 times.append(T)
81 events.append(E) 87 events.append(E)
88 ax.set_title(args.title)
89 fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title))
82 ngroup = len(names) 90 ngroup = len(names)
83 if ngroup == 2: # run logrank test if 2 groups 91 if ngroup == 2: # run logrank test if 2 groups
84 results = lifelines.statistics.logrank_test(times[0], times[1], events[0], events[1], alpha=.99) 92 results = lifelines.statistics.logrank_test(times[0], times[1], events[0], events[1], alpha=.99)
85 print(' vs '.join(names), results) 93 print('Logrank test for %s - %s vs %s\n' % (args.group, names[0], names[1]))
86 results.print_summary() 94 results.print_summary()
87 elif ngroup > 1:
88 fig, ax = plt.subplots(nrows=ngroup, ncols=1, sharex=True)
89 for i, rst in rmst:
90 lifelines.plotting.rmst_plot(rst, ax=ax)
91 fig.savefig(os.path.join(args.image_dir,'RMST_%s.png' % args.title))
92 else: 95 else:
93 kmf.fit(df[args.time], df[args.status]) 96 kmf.fit(df[args.time], df[args.status])
94 kmf.plot_survival_function(ax=ax) 97 kmf.plot_survival_function(ax=ax)
95 fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title)) 98 ax.set_title(args.title)
99 fig.savefig(os.path.join(args.image_dir,'KM_%s.png' % args.title))
100 print('#### No grouping variable, so no log rank or other Kaplan-Meier statistical output is available')
96 if len(args.cphcols) > 0: 101 if len(args.cphcols) > 0:
97 fig, ax = plt.subplots() 102 fig, ax = plt.subplots()
103 ax.set_title('Cox-PH model: %s' % args.title)
98 cphcols = args.cphcols.strip().split(',') 104 cphcols = args.cphcols.strip().split(',')
99 cphcols = [x.strip() for x in cphcols] 105 cphcols = [x.strip() for x in cphcols]
100 notfound = sum([(x not in df.columns) for x in cphcols]) 106 notfound = sum([(x not in df.columns) for x in cphcols])
101 if notfound > 0: 107 if notfound > 0:
102 sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): One or more requested Cox PH columns %s not found in supplied column header %s' % (args.cphcols, df.columns)) 108 sys.stderr.write('## CRITICAL USAGE ERROR (not a bug!): One or more requested Cox PH columns %s not found in supplied column header %s' % (args.cphcols, df.columns))
103 sys.exit(6) 109 sys.exit(6)
110 colsdf = df[cphcols]
104 print('### Lifelines test of Proportional Hazards results with %s as covariates on %s' % (', '.join(cphcols), args.title)) 111 print('### Lifelines test of Proportional Hazards results with %s as covariates on %s' % (', '.join(cphcols), args.title))
105 cphcols += [args.time, args.status] 112 cutcphcols = [args.time, args.status] + cphcols
106 cphdf = df[cphcols] 113 cphdf = df[cutcphcols]
114 ucolcounts = colsdf.nunique(axis=0)
107 cph.fit(cphdf, duration_col=args.time, event_col=args.status) 115 cph.fit(cphdf, duration_col=args.time, event_col=args.status)
108 cph.print_summary() 116 cph.print_summary()
117 for i, cov in enumerate(colsdf.columns):
118 if ucolcounts[i] > 10: # a hack - assume categories are sparse - if not imaginary quintiles will have to do
119 v = pd.Series.tolist(cphdf[cov].quantile(QVALS))
120 vdt = df.dtypes[cov]
121 if vdt == 'int64':
122 v = trimlegend(v)
123 axp = cph.plot_partial_effects_on_outcome(cov, cmap='coolwarm', values=v)
124 axp.set_title('Cox-PH %s quintile partials: %s' % (cov,args.title))
125 figr = axp.get_figure()
126 oname = os.path.join(args.image_dir,'%s_CoxPH_%s.%s' % (args.title, cov, args.image_type))
127 figr.savefig(oname)
128 else:
129 v = pd.unique(cphdf[cov])
130 v = [str(x) for x in v]
131 try:
132 axp = cph.plot_partial_effects_on_outcome(cov, cmap='coolwarm', values=v)
133 axp.set_title('Cox-PH %s partials: %s' % (cov,args.title))
134 figr = axp.get_figure()
135 oname = os.path.join(args.image_dir,'%s_CoxPH_%s.%s' % (args.title, cov, args.image_type))
136 figr.savefig(oname)
137 except:
138 pass
109 cphaxes = cph.check_assumptions(cphdf, p_value_threshold=0.01, show_plots=True) 139 cphaxes = cph.check_assumptions(cphdf, p_value_threshold=0.01, show_plots=True)
110 for i, ax in enumerate(cphaxes): 140 for i, ax in enumerate(cphaxes):
111 figr = ax[0].get_figure() 141 figr = ax[0].get_figure()
112 titl = figr._suptitle.get_text().replace(' ','_').replace("'","") 142 titl = figr._suptitle.get_text().replace(' ','_').replace("'","")
113 oname = os.path.join(args.image_dir,'CPH%s.%s' % (titl, args.image_type)) 143 oname = os.path.join(args.image_dir,'CPH%s.%s' % (titl, args.image_type))
114 figr.savefig(oname) 144 figr.savefig(oname)
115
116
117
118