Mercurial > repos > muon-spectroscopy-computational-project > larch_criteria_report
diff larch_criteria_report.py @ 0:aa9cb2b42741 draft default tip
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_criteria_report commit 5be486890442dedfb327289d597e1c8110240735
author | muon-spectroscopy-computational-project |
---|---|
date | Tue, 14 Nov 2023 15:34:55 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/larch_criteria_report.py Tue Nov 14 15:34:55 2023 +0000 @@ -0,0 +1,103 @@ +import csv +import json +import os +import sys +from typing import Iterable + +import matplotlib.pyplot as plt + +import numpy as np + + +def plot(variable: str, column: Iterable[float]): + variable_stripped = variable.strip() + path = f"plots/{variable_stripped}.png" + plt.figure(figsize=(8, 4)) + plt.plot(column) + plt.xlim((0, len(column))) + ticks, _ = plt.xticks() + plt.xticks(np.array(ticks).astype("int")) + plt.xlabel("Dataset number") + plt.ylabel(variable_stripped) + plt.savefig(path, format="png") + + +def load(filepath: str) -> "list[list[str|float]]": + with open(filepath) as f: + reader = csv.reader(f) + header = next(reader) + columns = [[h] for h in header] + for row in reader: + for i, value in enumerate(row): + columns[i].append(float(value)) + + return columns + + +def parse_reports(input_data: str) -> "dict[str, list[float]]": + # Need to scrape variables from individual files + report_criteria = input_values["format"]["report_criteria"] + data = {c["variable"]: [] for c in report_criteria} + headers = list(data.keys()) + with open("criteria_report.csv", "w") as f_out: + writer = csv.writer(f_out) + writer.writerow([f"{h:>12s}" for h in headers]) + + if os.path.isdir(input_data): + input_files = [ + os.path.join(input_data, f) for f in os.listdir(input_data) + ] + input_files.sort() + else: + input_files = input_data.split(",") + + for input_file in input_files: + row = parse_row(data, headers, input_file) + writer.writerow(row) + + return data + + +def parse_row( + data: "dict[str, list[float]]", headers: "list[str]", input_file: str +) -> "list[str]": + row = [None] * len(headers) + with open(input_file) as f_in: + line = f_in.readline() + while line: + words = line.split() + try: + variable = words[0] + value = words[2] + if variable in headers: + row[headers.index(variable)] = f"{value:>12s}" + data[variable].append(float(value)) + if all(row): + return row + except IndexError: + # Not all lines will have potential variables/values + # so just pass + pass + + line = f_in.readline() + + # Only reach here if we run out of lines without finding a value for each + # variable + raise RuntimeError( + f"One or more criteria missing, was looking for {headers} but found " + f"{row}" + ) + + +if __name__ == "__main__": + input_data = sys.argv[1] + input_values = json.load(open(sys.argv[2], "r", encoding="utf-8")) + + if "report_criteria" in input_values["format"]: + data = parse_reports(input_data) + for variable, column in data.items(): + plot(variable, column) + else: + columns = load(input_data) + for column in columns: + plot(column[0], column[1:])