diff association_rules.py @ 7:c16818ce0424 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:22:12 +0000
parents 24c1cc2dd4a4
children
line wrap: on
line diff
--- a/association_rules.py	Thu Aug 11 09:59:00 2022 +0000
+++ b/association_rules.py	Wed Aug 09 13:22:12 2023 +0000
@@ -7,7 +7,16 @@
 from mlxtend.preprocessing import TransactionEncoder
 
 
-def main(inputs, infile, outfile, min_support=0.5, min_confidence=0.5, min_lift=1.0, min_conviction=1.0, max_length=None):
+def main(
+    inputs,
+    infile,
+    outfile,
+    min_support=0.5,
+    min_confidence=0.5,
+    min_lift=1.0,
+    min_conviction=1.0,
+    max_length=None,
+):
     """
     Parameter
     ---------
@@ -36,13 +45,13 @@
         Maximum length
 
     """
-    warnings.simplefilter('ignore')
+    warnings.simplefilter("ignore")
 
-    with open(inputs, 'r') as param_handler:
+    with open(inputs, "r") as param_handler:
         params = json.load(param_handler)
 
-    input_header = params['header0']
-    header = 'infer' if input_header else None
+    input_header = params["header0"]
+    header = "infer" if input_header else None
 
     with open(infile) as fp:
         lines = fp.read().splitlines()
@@ -65,41 +74,45 @@
 
     # Extract frequent itemsets for association rule mining
     # use_colnames: Use DataFrames' column names in the returned DataFrame instead of column indices
-    frequent_itemsets = fpgrowth(df, min_support=min_support, use_colnames=True, max_len=max_length)
+    frequent_itemsets = fpgrowth(
+        df, min_support=min_support, use_colnames=True, max_len=max_length
+    )
 
     # Get association rules, with confidence larger than min_confidence
-    rules = association_rules(frequent_itemsets, metric="confidence", min_threshold=min_confidence)
+    rules = association_rules(
+        frequent_itemsets, metric="confidence", min_threshold=min_confidence
+    )
 
     # Filter association rules, keeping rules with lift and conviction larger than min_liftand and min_conviction
-    rules = rules[(rules['lift'] >= min_lift) & (rules['conviction'] >= min_conviction)]
+    rules = rules[(rules["lift"] >= min_lift) & (rules["conviction"] >= min_conviction)]
 
     # Convert columns from frozenset to list (more readable)
-    rules['antecedents'] = rules['antecedents'].apply(list)
-    rules['consequents'] = rules['consequents'].apply(list)
+    rules["antecedents"] = rules["antecedents"].apply(list)
+    rules["consequents"] = rules["consequents"].apply(list)
 
     # The next 3 steps are intended to fix the order of the association
     # rules generated, so tests that rely on diff'ing a desired output
     # with an expected output can pass
 
     # 1) Sort entry in every row/column for columns 'antecedents' and 'consequents'
-    rules['antecedents'] = rules['antecedents'].apply(lambda row: sorted(row))
-    rules['consequents'] = rules['consequents'].apply(lambda row: sorted(row))
+    rules["antecedents"] = rules["antecedents"].apply(lambda row: sorted(row))
+    rules["consequents"] = rules["consequents"].apply(lambda row: sorted(row))
 
     # 2) Create two temporary string columns to sort on
-    rules['ant_str'] = rules['antecedents'].apply(lambda row: " ".join(row))
-    rules['con_str'] = rules['consequents'].apply(lambda row: " ".join(row))
+    rules["ant_str"] = rules["antecedents"].apply(lambda row: " ".join(row))
+    rules["con_str"] = rules["consequents"].apply(lambda row: " ".join(row))
 
     # 3) Sort results so they are re-producable
-    rules.sort_values(by=['ant_str', 'con_str'], inplace=True)
-    del rules['ant_str']
-    del rules['con_str']
+    rules.sort_values(by=["ant_str", "con_str"], inplace=True)
+    del rules["ant_str"]
+    del rules["con_str"]
     rules.reset_index(drop=True, inplace=True)
 
     # Write association rules and metrics to file
     rules.to_csv(outfile, sep="\t", index=False)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     aparser = argparse.ArgumentParser()
     aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
     aparser.add_argument("-y", "--infile", dest="infile", required=True)
@@ -111,6 +124,13 @@
     aparser.add_argument("-t", "--length", dest="length", default=5)
     args = aparser.parse_args()
 
-    main(args.inputs, args.infile, args.outfile,
-         min_support=float(args.support), min_confidence=float(args.confidence),
-         min_lift=float(args.lift), min_conviction=float(args.conviction), max_length=int(args.length))
+    main(
+        args.inputs,
+        args.infile,
+        args.outfile,
+        min_support=float(args.support),
+        min_confidence=float(args.confidence),
+        min_lift=float(args.lift),
+        min_conviction=float(args.conviction),
+        max_length=int(args.length),
+    )