diff cdhit_analysis.py @ 4:e64af72e1b8f draft default tip
planemo upload for repository https://github.com/Onnodg/Naturalis_NLOOR/tree/main/NLOOR_scripts/process_clusters_tool commit 4017d38cf327c48a6252e488ba792527dae97a70-dirty
| author | onnodg |
|---|---|
| date | Mon, 15 Dec 2025 16:44:40 +0000 |
| parents | c6981ea453ae |
| children | |
--- a/cdhit_analysis.py	Fri Oct 24 09:38:24 2025 +0000
+++ b/cdhit_analysis.py	Mon Dec 15 16:44:40 2025 +0000
@@ -1,13 +1,12 @@
 """
 This script processes cluster output files from cd-hit-est for use in Galaxy.
-It extracts cluster information, associates taxa and e-values from annotation files,
+It extracts cluster information, associates taxa from annotation files,
 performs statistical calculations, and generates text and plot outputs
 summarizing similarity and taxonomic distributions.
-
 Main steps:
-1. Parse cd-hit-est cluster file and (optional) annotation file.
-2. Process each cluster to extract similarity, taxa, and e-value information.
+1. Parse cd-hit-est cluster file and (optional) annotation Excel file.
+2. Process each cluster to extract similarity and taxa information.
 3. Aggregate results across clusters.
 4. Generate requested outputs: text summaries, plots, and Excel reports.
 """
@@ -16,65 +15,48 @@
 from collections import Counter, defaultdict
 import os
 import re
+from math import sqrt
+
 import matplotlib.pyplot as plt
 import pandas as pd
-from math import sqrt
 import openpyxl
 
 
+def log_message(messages, text: str):
+    """Append a message to the log list."""
+    messages.append(text)
+
+
 def parse_arguments(args_list=None):
     """Parse command-line arguments for the script."""
-    parser = argparse.ArgumentParser(
-        description='Create taxa analysis from cd-hit cluster files')
-    parser.add_argument('--input_cluster', type=str, required=True,
-                        help='Input cluster file (.clstr)')
-    parser.add_argument('--input_annotation', type=str, required=False,
-                        help='Input annotation file (.out)')
+    parser = argparse.ArgumentParser(description="Create taxa analysis from cd-hit cluster files")
+    parser.add_argument("--input_cluster", type=str, required=True, help="Input cluster file (.clstr)")
+    parser.add_argument("--input_annotation", type=str, required=False, help="Input annotation Excel file (header annotations)")
+
+    parser.add_argument("--output_similarity_txt", type=str, help="Similarity text output file")
+    parser.add_argument("--output_similarity_plot", type=str, help="Similarity plot output file (PNG)")
+    parser.add_argument("--output_count", type=str, help="Count summary output file")
+    parser.add_argument("--output_excel", type=str, help="Combined taxa Excel output file")
+    parser.add_argument("--output_taxa_clusters", action="store_true", default=False, help="Excel output: include raw taxa per cluster sheet")
+    parser.add_argument("--output_taxa_processed", action="store_true", default=False, help="Excel output: include processed/LCA taxa per cluster sheet")
+    parser.add_argument("--log_file", type=str, help="Optional log file with run statistics")
 
-    # Galaxy output files
-    parser.add_argument('--output_similarity_txt', type=str,
-                        help='Similarity text output file')
-    parser.add_argument('--output_similarity_plot', type=str,
-                        help='Similarity plot output file')
-    parser.add_argument('--output_evalue_txt', type=str,
-                        help='E-value text output file')
-    parser.add_argument('--output_evalue_plot', type=str,
-                        help='E-value plot output file')
-    parser.add_argument('--output_count', type=str,
-                        help='Count summary output file')
-    parser.add_argument('--output_taxa_clusters', type=str,
-                        help='Taxa per cluster output file')
-    parser.add_argument('--output_taxa_processed', type=str,
-                        help='Processed taxa output file')
-    # Plot parameters
-    parser.add_argument('--simi_plot_y_min', type=float, default=95.0,
-                        help='Minimum value of the y-axis in the similarity plot')
-    parser.add_argument('--simi_plot_y_max', type=float, default=100.0,
-                        help='Maximum value of the y-axis in the similarity plot')
+    parser.add_argument("--simi_plot_y_min", type=float, default=95.0, help="Minimum value of the y-axis in the similarity plot")
+    parser.add_argument("--simi_plot_y_max", type=float, default=100.0, help="Maximum value of the y-axis in the similarity plot")
 
-    # Uncertain taxa configuration
-    parser.add_argument('--uncertain_taxa_use_ratio', type=float, default=0.5,
-                        help='Ratio at which uncertain taxa count toward the correct taxa')
-    parser.add_argument('--min_to_split', type=float, default=0.45,
-                        help='Minimum percentage for taxonomic split')
-    parser.add_argument('--min_count_to_split', type=int, default=10,
-                        help='Minimum count for taxonomic split')
+    parser.add_argument("--uncertain_taxa_use_ratio", type=float, default=0.5, help="Ratio at which uncertain taxa count toward the dominant taxa")
+    parser.add_argument("--min_to_split", type=float, default=0.45, help=("Minimum fraction (0–0.5) for splitting a cluster taxonomically. "
+                                                                          "ratio = count(second) / (count(first) + count(second))"))
+    parser.add_argument("--min_count_to_split", type=int, default=10, help=("Minimum combined count of the two top taxa to avoid being forced "
+                                                                            "into an 'uncertain' split at this level."))
 
-    # Processing options
-    parser.add_argument('--show_unannotated_clusters', action='store_true', default=False,
-                        help='Show unannotated clusters in output')
-    parser.add_argument('--make_taxa_in_cluster_split', action='store_true', default=False,
-                        help='Split clusters with multiple taxa')
-    parser.add_argument('--print_empty_files', action='store_true', default=False,
-                        help='Print messages about empty annotation files')
+    parser.add_argument("--min_cluster_support", type=int, default=1, help=("Minimum number of annotated reads required for a cluster to be "
+                                                                            "included in the processed/LCA taxa sheet. Unannotated reads do not count toward this threshold."))
 
     return parser.parse_args(args_list)
 
 
-# Color map for plots
 COLORMAP = [
-# List of RGBA tuples for bar coloring in plots
+    # List of RGBA tuples for bar coloring in plots
    (0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0),
    (1.0, 0.4980392156862745, 0.054901960784313725, 1.0),
    (0.17254901960784313, 0.6274509803921569, 0.17254901960784313, 1.0),
@@ -84,175 +66,216 @@
    (0.8901960784313725, 0.4666666666666667, 0.7607843137254902, 1.0),
    (0.4980392156862745, 0.4980392156862745, 0.4980392156862745, 1.0),
    (0.7372549019607844, 0.7411764705882353, 0.13333333333333333, 1.0),
-    (0.09019607843137255, 0.7450980392156863, 0.8117647058823529, 1.0)
+    (0.09019607843137255, 0.7450980392156863, 0.8117647058823529, 1.0),
 ]
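For context when reading the parser below: cd-hit-est `.clstr` files group one `>Cluster` header line per cluster followed by one line per member read. A minimal illustrative sketch (read names invented; the trailing `(N)` count suffix on headers is an assumption about upstream dereplication, since `parse_cluster_file` splits on `(` to recover a per-read count):

```
>Cluster 0
0	250nt, >read_0001(17)... *
1	248nt, >read_0002(3)... at +/98.40%
>Cluster 1
0	301nt, >read_0003(9)... *
```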
 
 
-def parse_cluster_file(cluster_file, annotation_file=None, print_empty=False, raise_on_error=False):
-    """
-    Parse the cd-hit-est cluster file (.clstr) and (optionally) an Excel annotation file.
+def parse_cluster_file(cluster_file, annotation_file, raise_on_error=False, log_messages=None):
+    """Parse CD-HIT cluster output and attach taxonomic annotations.
 
-    It extracts cluster information (header, read count, similarity) and associates
-    taxonomic information and E-values from the annotation file based on the read header.
+    Cluster entries are parsed for read headers, counts, and similarities.
+    If an annotation Excel file is supplied, taxa, seq_id and source fields
+    are mapped per read from sheet ``Individual_Reads``.
 
-    :param cluster_file: Path to cd-hit cluster file (.clstr).
+    :param cluster_file: Path to the CD-HIT cluster file.
     :type cluster_file: str
-    :param annotation_file: Path to Excel annotation file with taxa and e-values.
-    :type annotation_file: str, optional
-    :param print_empty: Print a message if the annotation file is not found or empty.
-    :type print_empty: bool
-    :param raise_on_error: Raise parsing errors instead of printing warnings.
+    :param annotation_file: Optional Excel annotation file.
+    :type annotation_file: str or None
+    :param raise_on_error: Whether to raise parsing errors directly.
     :type raise_on_error: bool
-    :return: List of clusters, where each cluster is a dict mapping read header to a dict of read info.
-    :rtype: list[dict[str, dict]]
-    :raises ValueError: If similarity cannot be parsed from a cluster line.
-    :raises UnboundLocalError: If an error occurs during annotation processing.
+    :param log_messages: Optional message collector for logging.
+    :type log_messages: list[str] or None
+    :return: List of clusters, where each cluster is a dict keyed by header.
+    :rtype: list[dict]
     """
     clusters = []
     current_cluster = {}
 
-    # Load annotations if provided
     annotations = {}
     if annotation_file and os.path.exists(annotation_file):
-        # Read the Excel file
-        df = pd.read_excel(annotation_file, sheet_name='Individual_Reads', engine='openpyxl')
+        try:
+            df = pd.read_excel(
+                annotation_file,
+                sheet_name="Individual_Reads",
+                engine="openpyxl",
+            )
+            required_cols = {"header", "seq_id", "source", "taxa"}
+            missing = required_cols - set(df.columns)
+            if missing:
+                msg = (
+                    f"Annotation file missing columns {missing}, "
+                    "continuing without annotations."
+                )
+                if log_messages is not None:
+                    log_message(log_messages, msg)
+            else:
+                for _, row in df.iterrows():
+                    header = str(row["header"])
+                    seq_id = str(row["seq_id"])
+                    source = str(row["source"])
+                    taxa = str(row["taxa"])
+                    annotations[header] = {
+                        "seq_id": seq_id,
+                        "source": source,
+                        "taxa": taxa,
+                    }
+                if log_messages is not None:
+                    log_message(
+                        log_messages,
+                        f"Loaded {len(annotations)} annotated headers",
+                    )
+        except Exception as exc:
+            msg = (
+                "Failed to read annotation file; proceeding without annotations. "
+                f"Details: {exc}"
+            )
+            if log_messages is not None:
+                log_message(log_messages, msg)
+    else:
+        if log_messages is not None:
+            if annotation_file:
+                log_message(
+                    log_messages,
+                    "Annotation file not found; proceeding without annotations.",
+                )
+            else:
+                log_message(
+                    log_messages,
+                    "No annotation file provided; proceeding without annotations.",
+                )
 
-        # Iterate over the rows
-        for _, row in df.iterrows():
-            header = row['header']  # column name as in the Excel sheet
-            evalue = row['e_value']  # or whichever column name is in use
-            taxa = row['taxa']  # depends on how taxa are stored
-
-            annotations[header] = {'evalue': evalue, 'taxa': taxa}
-    elif annotation_file and print_empty:
-        print(f"Annotation file {annotation_file} not found or empty")
-
-    with open(cluster_file, 'r') as f:
+    with open(cluster_file, "r") as f:
         for line in f:
             line = line.strip()
-            if line.startswith('>Cluster'):
+            if not line:
+                continue
+
+            if line.startswith(">Cluster"):
                 # Start of new cluster, save previous if exists
                 if current_cluster:
                     clusters.append(current_cluster)
                     current_cluster = {}
             else:
-                # Parse sequence line
                 parts = line.split()
-                if len(parts) >= 2:
-                    # Extract header and count
-                    header_part = parts[2].strip('>.')
-                    header_parts = header_part.split('(')
-                    if len(header_parts) > 1:
-                        last_part = header_parts[-1].strip(')')
-                        header = header_parts[0]
-                        if last_part:
+                if len(parts) < 3:
+                    continue
+
+                header_part = parts[2].strip(">.")
+                header_parts = header_part.split("(")
+                header = header_parts[0]
+                count = 0
+                if len(header_parts) > 1:
+                    last_part = header_parts[-1].strip(")")
+                    if last_part:
+                        try:
                            count = int(last_part)
-                        else:
-                            print('no count')
+                        except ValueError:
                            count = 0
-                            header = ''
-                    # Extract similarity
-                    similarity_part = parts[-1].strip()
-                    if '*' in similarity_part:
-                        similarity = 100.0  # Representative sequence
+
+                similarity_part = parts[-1].strip()
+                if "*" in similarity_part:
+                    similarity = 100.0
+                else:
+                    matches = re.findall(r"[\d.]+", similarity_part)
+                    if matches:
+                        similarity = float(matches[-1])
                    else:
-                        matches = re.findall(r'[\d.]+', similarity_part)
-                        if matches:
-                            similarity = float(matches[-1])
-                        else:
-                            raise ValueError(f"Could not parse similarity from: '{similarity_part}'")
-                    # Get annotation info
-                    try:
-                        if header in annotations:
-                            taxa = annotations[header]['taxa']
-                            evalue = annotations[header]['evalue']
-                        else:
-                            taxa = 'Unannotated read'
-                            evalue = 'Unannotated read'
+                        msg = f"Could not parse similarity from: '{similarity_part}'"
+                        if raise_on_error:
+                            raise ValueError(msg)
+                        if log_messages is not None:
+                            log_message(log_messages, f"WARNING: {msg}")
+                        similarity = 0.0
 
-                        current_cluster[header] = {
-                            'count': count,
-                            'similarity': similarity,
-                            'taxa': taxa,
-                            'evalue': evalue
-                        }
-                    except UnboundLocalError as e:
-                        if raise_on_error:
-                            raise UnboundLocalError(str(e))
-                        print(f"Error: {e}, No annotations found")
+                if header in annotations:
+                    taxa = annotations[header]["taxa"]
+                    seq_id = annotations[header]["seq_id"]
+                    source = annotations[header]["source"]
+                else:
+                    taxa = "Unannotated read"
+                    seq_id = "Unannotated read"
+                    source = "Unannotated read"
 
-    # Add the last cluster
+                current_cluster[header] = {
+                    "count": count,
+                    "similarity": similarity,
+                    "taxa": taxa,
+                    "seq_id": seq_id,
+                    "source": source,
+                }
+
     if current_cluster:
         clusters.append(current_cluster)
 
+    if log_messages is not None:
+        log_message(log_messages, f"Parsed {len(clusters)} clusters")
+
     return clusters
 
 
-def process_cluster_data(cluster):
-    """
-    Process a single cluster to extract E-value, similarity, and taxa data for all reads.
+def process_cluster_data(cluster: dict):
+    """Convert a raw cluster into similarity lists and aggregated taxa counts.
 
-    Aggregates information from all reads in the cluster, storing read counts,
-    E-values, similarities, and taxa in lists and a dictionary.
+    Similarity values are expanded based on read count. Taxa are grouped
+    by unique taxonomic labels, tracking supporting seq_ids and sources.
 
-    :param cluster: Cluster data mapping read headers to read info.
+    :param cluster: Mapping of read header to cluster metadata.
     :type cluster: dict
-    :return: A tuple containing: (list of E-values, list of similarity values, dict of taxa -> counts).
-             The first element of the E-value list ([0]) stores the count of unannotated reads.
-    :rtype: tuple[list[float | int], list[float], dict[str, int]]
+    :return: similarity_list, taxa_map, annotated_count, unannotated_count
+    :rtype: tuple[list[float], dict, int, int]
     """
-    eval_list = [0.0]  # First element for unannotated count
-    simi_list = []
-    taxa_dict = {'Unannotated read': 0}
+    similarity_list: list[float] = []
+    taxa_map: dict[str, dict] = {}
+    annotated_count = 0
+    unannotated_count = 0
 
-    for header, info in cluster.items():
-        count = info['count']
-        similarity = info['similarity']
-        taxa = info['taxa']
-        evalue = info['evalue']
+    for _header, info in cluster.items():
+        count = info["count"]
+        similarity = info["similarity"]
+        taxa = info["taxa"]
+        seq_id = info["seq_id"]
+        source = info["source"]
+
+        similarity_list.extend([similarity] * count)
 
-        if evalue == 'Unannotated read':
-            eval_list[0] += count
-            taxa_dict['Unannotated read'] += count
+        key = taxa if seq_id != "Unannotated read" else "Unannotated read"
+        if key not in taxa_map:
+            taxa_map[key] = {
+                "count": 0,
+                "seq_ids": set(),
+                "sources": set(),
+            }
+
+        taxa_map[key]["count"] += count
+        taxa_map[key]["seq_ids"].add(seq_id)
+        taxa_map[key]["sources"].add(source)
+
+        if seq_id == "Unannotated read":
+            unannotated_count += count
         else:
-            try:
-                eval_val = float(evalue)
-                for _ in range(count):
-                    eval_list.append(eval_val)
-            except ValueError:
-                eval_list[0] += count
+            annotated_count += count
 
-        if taxa not in taxa_dict:
-            taxa_dict[taxa] = 0
-        taxa_dict[taxa] += count
+    return similarity_list, taxa_map, annotated_count, unannotated_count
 
-    # Add similarity values
-    for _ in range(count):
-        simi_list.append(similarity)
-
-    return eval_list, simi_list, taxa_dict
 
 
 def calculate_cluster_taxa(taxa_dict, args):
-    """
-    Calculate the most likely taxa for a cluster based on read counts.
+    """Resolve cluster-level taxa using weighted recursive LCA splitting.
 
-    This function applies the 'uncertain taxa use ratio' for unannotated reads
-    and uses a recursive approach to potentially split a cluster into sub-clusters
-    if taxonomic dominance is not strong enough (based on ``min_to_split``
-    and ``min_count_to_split``).
+    Unannotated reads are treated as fully uncertain taxa but contribute
+    partially via ``uncertain_taxa_use_ratio`` when determining dominant levels.
 
-    :param taxa_dict: Mapping of taxa (full string) -> read counts.
+    :param taxa_dict: Mapping of taxa string to total read count.
     :type taxa_dict: dict[str, int]
-    :param args: Parsed script arguments, including parameters for taxa calculation.
+    :param args: Argument namespace containing resolution parameters.
     :type args: argparse.Namespace
-    :return: A list of refined taxa assignments (dictionaries), where each dictionary
-             represents a potentially split sub-cluster.
-    :rtype: list[dict[str, float | int]]
+    :return: List of taxonomic groups after recursive resolution.
+    :rtype: list[dict[str, int]]
     """
-    # Replace 'Unannotated read' with uncertain taxa format
     processed_dict = {}
     for taxa, count in taxa_dict.items():
-        if taxa == 'Unannotated read':
-            uncertain_taxa = ' / '.join(['Uncertain taxa'] * 7)
+        if taxa == "Unannotated read":
+            uncertain_taxa = " / ".join(["Uncertain taxa"] * 7)
             processed_dict[uncertain_taxa] = count
         else:
             processed_dict[taxa] = count
@@ -261,290 +284,325 @@
 
 
 def _recursive_taxa_calculation(taxa_dict, args, level):
-    """
-    Recursive helper to calculate and potentially split taxa at each taxonomic level.
+    """Recursive helper performing level-by-level weighted LCA resolution.
 
-    :param taxa_dict: Taxa counts at the current level (or sub-group).
-    :type taxa_dict: dict[str, float | int]
-    :param args: Parsed script arguments.
+    At each taxonomic rank, taxa are grouped and compared. If multiple
+    taxa exceed the configured split thresholds, the cluster is divided
+    and recursion continues deeper.
+
+    :param taxa_dict: Mapping of taxa to count for this recursion stage.
+    :type taxa_dict: dict[str, int]
+    :param args: Parameter namespace controlling uncertainty and split rules.
     :type args: argparse.Namespace
-    :param level: Index of the current taxonomic level (0=kingdom, max 6=species).
+    :param level: Current taxonomy depth (0–6).
     :type level: int
-    :return: List of refined taxa dictionaries.
-    :rtype: list[dict[str, float | int]]
+    :return: One or more resolved taxon groups at deeper levels.
+    :rtype: list[dict[str, int]]
     """
-    if level >= 7:  # Max 7 taxonomic levels
+    if level >= 7:
         return [taxa_dict]
 
     level_dict = defaultdict(float)
-
-    # Group by taxonomic level
     for taxa, count in taxa_dict.items():
-        taxa_parts = taxa.split(' / ')
+        taxa_parts = taxa.split(" / ")
         if level < len(taxa_parts):
             level_taxon = taxa_parts[level]
-            if level_taxon == 'Uncertain taxa':
+            if level_taxon == "Uncertain taxa":
                 level_dict[level_taxon] += count * args.uncertain_taxa_use_ratio
             else:
                 level_dict[level_taxon] += count
 
     if len(level_dict) <= 1:
-        # Only one taxon at this level, continue to next level
         return _recursive_taxa_calculation(taxa_dict, args, level + 1)
 
-    # Sort by abundance
     sorted_taxa = sorted(level_dict.items(), key=lambda x: x[1], reverse=True)
-
     result = []
     dominant_taxon = sorted_taxa[0][0]
 
-    # Check if we should split
     for i in range(1, len(sorted_taxa)):
         secondary_taxon = sorted_taxa[i][0]
         total_count = sorted_taxa[0][1] + sorted_taxa[i][1]
         ratio = sorted_taxa[i][1] / total_count
 
         if ratio >= args.min_to_split or total_count <= args.min_count_to_split:
-            # Split off this taxon
-            split_dict = {taxa: count for taxa, count in taxa_dict.items()
-                          if taxa.split(' / ')[level] == secondary_taxon}
+            split_dict = {
+                taxa: count
+                for taxa, count in taxa_dict.items()
+                if taxa.split(" / ")[level] == secondary_taxon
+            }
             result.extend(_recursive_taxa_calculation(split_dict, args, level + 1))
 
-    # Process the dominant group
-    dominant_dict = {taxa: count for taxa, count in taxa_dict.items()
-                     if taxa.split(' / ')[level] == dominant_taxon}
+    dominant_dict = {
+        taxa: count
+        for taxa, count in taxa_dict.items()
+        if taxa.split(" / ")[level] == dominant_taxon
+    }
    result.extend(_recursive_taxa_calculation(dominant_dict, args, level + 1))
 
    return result
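A quick sanity check of the split rule with the default thresholds (taxa strings invented; assumes the functions above are importable, e.g. `from cdhit_analysis import _recursive_taxa_calculation`):

```python
from argparse import Namespace

args = Namespace(uncertain_taxa_use_ratio=0.5, min_to_split=0.45,
                 min_count_to_split=10)
taxa_dict = {
    "Animalia / Chordata / Aves / Passeriformes / Paridae / Parus / Parus major": 30,
    "Animalia / Chordata / Aves / Passeriformes / Paridae / Parus / Parus minor": 28,
}
# The two labels diverge only at species level; there
# ratio = 28 / (30 + 28) ~= 0.483 >= min_to_split, so the cluster
# is split into two groups rather than collapsed to the dominant taxon.
groups = _recursive_taxa_calculation(taxa_dict, args, 0)
print(len(groups))  # 2
```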
 
 
-def write_similarity_output(all_simi_data, output_file):
-    """
-    Write the similarity text output, including the mean and standard deviation,
-    and a count per similarity percentage.
+def write_similarity_output(cluster_data_list, output_file, log_messages=None):
+    """Write similarity statistics and per-cluster distributions.
+
+    Output includes mean, standard deviation, and counts of each
+    similarity value per cluster.
 
-    :param all_simi_data: List of all similarity percentages from all reads.
-    :type all_simi_data: list[float]
-    :param output_file: Path to the output file.
+    :param cluster_data_list: List of processed cluster dictionaries.
+    :type cluster_data_list: list[dict]
+    :param output_file: Destination path for the TSV summary.
     :type output_file: str
+    :param log_messages: Optional log collector.
+    :type log_messages: list[str] or None
     :return: None
     :rtype: None
     """
-    if not all_simi_data:
+    all_simi = []
+    pair_counter = Counter()
+
+    for cluster_index, cluster_data in enumerate(cluster_data_list):
+        simi_list = cluster_data["similarities"]
+        if not simi_list:
+            continue
+        all_simi.extend(simi_list)
+        for s in simi_list:
+            pair_counter[(cluster_index, s)] += 1
+
+    if not all_simi:
+        if log_messages is not None:
+            log_message(
+                log_messages,
+                "No similarity data found; similarity text file not written.",
+            )
         return
 
-    avg = sum(all_simi_data) / len(all_simi_data)
-    variance = sum((s - avg) ** 2 for s in all_simi_data) / len(all_simi_data)
+    avg = sum(all_simi) / len(all_simi)
+    variance = sum((s - avg) ** 2 for s in all_simi) / len(all_simi)
    st_dev = sqrt(variance)
 
-    simi_counter = Counter(all_simi_data)
-    simi_sorted = sorted(simi_counter.items(), key=lambda x: -x[0])
+    with open(output_file, "w") as f:
+        f.write(f"# Average similarity (all reads): {avg:.2f}\n")
+        f.write(f"# Standard deviation (all reads): {st_dev:.2f}\n")
+        f.write("cluster\tsimilarity\tcount\n")
 
-    with open(output_file, 'w') as f:
-        f.write(f"# Average similarity: {avg:.2f}\n")
-        f.write(f"# Standard deviation: {st_dev:.2f}\n")
-        f.write("similarity\tcount\n")
-        for similarity, count in simi_sorted:
-            f.write(f"{similarity}\t{count}\n")
+        for (cluster_index, similarity), count in sorted(
+            pair_counter.items(), key=lambda x: (x[0][0], -x[0][1])
+        ):
+            f.write(f"{cluster_index}\t{similarity}\t{count}\n")
+
+    if log_messages is not None:
+        log_message(log_messages, "Similarity text summary written successfully")
+        log_message(
+            log_messages, f"Similarity mean = {avg:.2f}, std = {st_dev:.2f}"
+        )
 
 
-def write_evalue_output(all_eval_data, output_file):
-    """
-    Write the E-value text output, including the count of unannotated reads
-    and a count per E-value.
-
-    :param all_eval_data: List of E-values from all reads. The first element ([0]) is the count of unannotated reads.
-    :type all_eval_data: list[float | int]
-    :param output_file: Path to the output file.
-    :type output_file: str
-    :return: None
-    :rtype: None
-    """
-    unanno_count = all_eval_data[0]
-    eval_counter = Counter(all_eval_data[1:])
-
-    with open(output_file, 'w') as f:
-        f.write("evalue\tcount\n")
-        if unanno_count > 0:
-            f.write(f"unannotated\t{unanno_count}\n")
-
-        eval_sorted = sorted(eval_counter.items(),
-                             key=lambda x: (-x[1], float(x[0]) if isinstance(x[0], (int, float)) else float('inf')))
-        for value, count in eval_sorted:
-            f.write(f"{value}\t{count}\n")
-
-
-def write_count_output(all_eval_data, cluster_data_list, output_file):
-    """
-    Write a summary of annotated and unannotated read counts per cluster
-    and for the total sample.
+def write_count_output(cluster_data_list, output_file, log_messages=None):
+    """Write counts of annotated and unannotated reads per cluster.
+
+    A summary row containing total and percentage values is appended.
 
-    :param all_eval_data: List of E-values from all reads for the total count summary.
-    :type all_eval_data: list[float | int]
-    :param cluster_data_list: List of tuples (eval_list, simi_list, taxa_dict) per cluster.
-    :type cluster_data_list: list[tuple]
-    :param output_file: Path to the output file.
+    :param cluster_data_list: List of processed clusters.
+    :type cluster_data_list: list[dict]
+    :param output_file: Path for the count summary text file.
    :type output_file: str
+    :param log_messages: Optional logging accumulator.
+    :type log_messages: list[str] or None
    :return: None
    :rtype: None
    """
-    with open(output_file, 'w') as f:
-        f.write("cluster\tunannotated\tannotated\ttotal\tperc_unannotated\tperc_annotated\n")
+    total_annotated = 0
+    total_unannotated = 0
 
-        for cluster_num, (eval_list, _, _) in enumerate(cluster_data_list):
-            unannotated = eval_list[0]
-            annotated = len(eval_list[1:])
-            total = unannotated + annotated
+    with open(output_file, "w") as f:
+        f.write(
+            "cluster\tunannotated\tannotated\ttotal\tperc_unannotated\tperc_annotated\n"
+        )
+
+        for cluster_index, cluster_data in enumerate(cluster_data_list):
+            unannotated = cluster_data["unannotated"]
+            annotated = cluster_data["annotated"]
+            total = unannotated + annotated
+
+            total_annotated += annotated
+            total_unannotated += unannotated
 
             if total > 0:
                 perc_annotated = (annotated / total) * 100
                 perc_unannotated = (unannotated / total) * 100
             else:
-                perc_annotated = perc_unannotated = 0
+                perc_annotated = perc_unannotated = 0.0
 
             f.write(
-                f"{cluster_num}\t{unannotated}\t{annotated}\t{total}\t{perc_unannotated:.2f}\t{perc_annotated:.2f}\n")
+                f"{cluster_index}\t{unannotated}\t{annotated}\t{total}"
+                f"\t{perc_unannotated:.2f}\t{perc_annotated:.2f}\n"
+            )
 
-        # Add full sample summary
-        total_unannotated = all_eval_data[0]
-        total_annotated = len(all_eval_data[1:])
-        grand_total = total_unannotated + total_annotated
-
+        grand_total = total_annotated + total_unannotated
        if grand_total > 0:
            perc_annotated = (total_annotated / grand_total) * 100
            perc_unannotated = (total_unannotated / grand_total) * 100
        else:
-            perc_annotated = perc_unannotated = 0
+            perc_annotated = perc_unannotated = 0.0
 
        f.write(
-            f"TOTAL\t{total_unannotated}\t{total_annotated}\t{grand_total}\t{perc_unannotated:.2f}\t{perc_annotated:.2f}\n")
+            f"TOTAL\t{total_unannotated}\t{total_annotated}\t{grand_total}"
+            f"\t{perc_unannotated:.2f}\t{perc_annotated:.2f}\n"
+        )
+
+    if log_messages is not None:
+        log_message(log_messages, "Count summary written successfully")
+        log_message(
+            log_messages,
+            f"TOTAL annotated={total_annotated}, unannotated={total_unannotated}, total={grand_total}",
+        )
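Illustrative content of the count summary (numbers invented):

```
cluster	unannotated	annotated	total	perc_unannotated	perc_annotated
0	4	96	100	4.00	96.00
1	12	0	12	100.00	0.00
TOTAL	16	96	112	14.29	85.71
```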
 
 
-def write_taxa_clusters_output(cluster_data_list, output_file):
-    """
-    Write raw taxa information per cluster to an Excel file.
-
-    Each row contains the cluster number, read count, the full taxa string,
-    and the separate taxonomic levels (kingdom through species).
-
-    :param cluster_data_list: List of tuples (eval_list, simi_list, taxa_dict) per cluster.
-    :type cluster_data_list: list[tuple]
-    :param output_file: Path to the output Excel file.
-    :type output_file: str
-    :return: None
-    :rtype: None
-    """
-    # Create main dataframe
-    data = []
-    for cluster_num, (_, _, taxa_dict) in enumerate(cluster_data_list):
-        for taxa, count in taxa_dict.items():
-            if count > 0:
-                # Split taxa into taxonomic levels
-                taxa_levels = taxa.split(' / ') if taxa else []
-                taxa_levels += ['Unannotated read'] * (7 - len(taxa_levels))
-                try:
-                    data.append({
-                        'cluster': cluster_num,
-                        'count': count,
-                        'taxa_full': taxa,
-                        'kingdom': taxa_levels[0],
-                        'phylum': taxa_levels[1],
-                        'class': taxa_levels[2],
-                        'order': taxa_levels[3],
-                        'family': taxa_levels[4],
-                        'genus': taxa_levels[5],
-                        'species': taxa_levels[6]
-                    })
-                except IndexError as e:
-                    # Skip entries with incomplete taxonomic data
-                    print(f"Skipped entry in cluster {cluster_num}: incomplete taxonomic data for '{taxa}, error: {e}'")
-                    continue
-
-    df = pd.DataFrame(data)
-    # Write to Excel
-    temp_path = output_file + ".xlsx"
-    os.makedirs(os.path.dirname(temp_path), exist_ok=True)
-    with pd.ExcelWriter(temp_path, engine='openpyxl') as writer:
-        df.to_excel(writer, sheet_name='Raw_Taxa_Clusters', index=False, engine='openpyxl')
-    os.replace(temp_path, output_file)
-
-
-def write_taxa_processed_output(cluster_data_list, args, output_file):
-    """
-    Write processed (potentially split) taxa information to an Excel file.
-
-    This file contains the resulting sub-clusters from the taxonomic dominance
-    analysis and a separate sheet documenting the parameters used.
-
-    :param cluster_data_list: List of tuples (eval_list, simi_list, taxa_dict) per cluster.
-    :type cluster_data_list: list[tuple]
-    :param args: Parsed script arguments, used for taxa calculation and settings documentation.
-    :type args: argparse.Namespace
-    :param output_file: Path to the output Excel file.
-    :type output_file: str
-    :return: None
-    :rtype: None
-    """
-    # Create main dataframe
-    data = []
-    for cluster_num, (_, _, taxa_dict) in enumerate(cluster_data_list):
-        processed_taxa = calculate_cluster_taxa(taxa_dict, args)
-
-        for taxa_group in processed_taxa:
-            for taxa, count in taxa_group.items():
-                if 'Uncertain taxa / Uncertain taxa / Uncertain taxa' in taxa:
-                    if args.show_unannotated_clusters:
-                        data.append({
-                            'cluster': cluster_num,
-                            'count': count,
-                            'taxa_full': 'Unannotated read',
-                            'kingdom': 'Unannotated',
-                            'phylum': 'Unannotated',
-                            'class': 'Unannotated',
-                            'order': 'Unannotated',
-                            'family': 'Unannotated',
-                            'genus': 'Unannotated',
-                            'species': 'Unannotated'
-                        })
-                else:
-                    # Split taxa into taxonomic levels
-                    taxa_levels = taxa.split(' / ') if taxa else []
-
-                    try:
-                        data.append({
-                            'cluster': cluster_num,
-                            'count': count,
-                            'taxa_full': taxa,
-                            'kingdom': taxa_levels[0],
-                            'phylum': taxa_levels[1],
-                            'class': taxa_levels[2],
-                            'order': taxa_levels[3],
-                            'family': taxa_levels[4],
-                            'genus': taxa_levels[5],
-                            'species': taxa_levels[6]
-                        })
-                    except IndexError:
-                        # Skip entries with incomplete taxonomic data
-                        print(f"Skipped entry in cluster {cluster_num}: incomplete taxonomic data for '{taxa}'")
-                        continue
-
-    df = pd.DataFrame(data)
-
-    # Create settings dataframe
+def write_taxa_excel(cluster_data_list, args, output_file, write_raw: bool, write_processed: bool, log_messages=None):
+    """Write raw and processed taxa data into a combined Excel report.
+
+    Generates up to three sheets:
+
+    - ``Raw_Taxa_Clusters``: Full taxa per read cluster (if enabled).
+    - ``Processed_Taxa_Clusters``: LCA-resolved taxa (if enabled).
+    - ``Settings``: Parameters used for taxonomic resolution.
+
+    seq_id and source tracking are kept only in the raw sheet.
+
+    :param cluster_data_list: List containing similarity/taxa data per cluster.
+    :type cluster_data_list: list[dict]
+    :param args: Parsed arguments containing LCA configuration.
+    :type args: argparse.Namespace
+    :param output_file: Path to the combined Excel file.
+    :type output_file: str
+    :param write_raw: Whether to include the raw taxa sheet.
+    :type write_raw: bool
+    :param write_processed: Whether to include the processed/LCA sheet.
+    :type write_processed: bool
+    :param log_messages: Optional log collector.
+    :type log_messages: list[str] or None
+    :return: None
+    :rtype: None
+    """
+    raw_rows = []
+    processed_rows = []
+
+    if write_raw:
+        for cluster_index, cluster_data in enumerate(cluster_data_list):
+            taxa_map = cluster_data["taxa_map"]
+            for taxa, info in taxa_map.items():
+                count = info["count"]
+                if count <= 0:
+                    continue
+
+                taxa_levels = taxa.split(" / ") if taxa else []
+                while len(taxa_levels) < 7:
+                    taxa_levels.append("Unannotated")
+
+                seq_ids_str = ",".join(sorted(info["seq_ids"]))
+                sources_str = ",".join(sorted(info["sources"]))
+
+                raw_rows.append(
+                    {
+                        "cluster": cluster_index,
+                        "count": count,
+                        "seq_id": seq_ids_str,
+                        "source": sources_str,
+                        "taxa_full": taxa,
+                        "kingdom": taxa_levels[0],
+                        "phylum": taxa_levels[1],
+                        "class": taxa_levels[2],
+                        "order": taxa_levels[3],
+                        "family": taxa_levels[4],
+                        "genus": taxa_levels[5],
+                        "species": taxa_levels[6],
+                    }
+                )
+
+    if write_processed:
+        for cluster_index, cluster_data in enumerate(cluster_data_list):
+            taxa_map = cluster_data["taxa_map"]
+
+            annotated_support = sum(
+                info["count"]
+                for taxa, info in taxa_map.items()
+                if taxa != "Unannotated read"
+            )
+
+            if annotated_support < args.min_cluster_support:
+                continue
+
+            taxa_counts = {taxa: info["count"] for taxa, info in taxa_map.items()}
+            processed_groups = calculate_cluster_taxa(taxa_counts, args)
+
+            for group in processed_groups:
+                for taxa, count in group.items():
+                    if count <= 0:
+                        continue
+
+                    if "Uncertain taxa / Uncertain taxa / Uncertain taxa" in taxa:
+                        continue
+                    else:
+                        info = taxa_map.get(taxa)
+                        seq_ids_set = info["seq_ids"]
+                        sources_set = info["sources"]
+                        taxa_full = taxa
+                        taxa_levels = taxa.split(" / ") if taxa else []
+                        while len(taxa_levels) < 7:
+                            taxa_levels.append("Unannotated")
+
+                        seq_ids_str = ",".join(sorted(seq_ids_set))
+                        sources_str = ",".join(sorted(sources_set))
+                        processed_rows.append(
+                            {
+                                "cluster": cluster_index,
+                                "count": count,
+                                "seq_id": seq_ids_str,
+                                "source": sources_str,
+                                "taxa_full": taxa_full,
+                                "kingdom": taxa_levels[0],
+                                "phylum": taxa_levels[1],
+                                "class": taxa_levels[2],
+                                "order": taxa_levels[3],
+                                "family": taxa_levels[4],
+                                "genus": taxa_levels[5],
+                                "species": taxa_levels[6],
+                            }
+                        )
+
+    if not raw_rows and not processed_rows:
+        if log_messages is not None:
+            log_message(
+                log_messages,
+                "No taxa data to write; taxa Excel file not created.",
+            )
+        return
+
+    raw_df = pd.DataFrame(raw_rows) if raw_rows else None
+    processed_df = pd.DataFrame(processed_rows) if processed_rows else None
+
     settings_data = [
-        ['uncertain_taxa_use_ratio', args.uncertain_taxa_use_ratio],
-        ['min_to_split', args.min_to_split],
-        ['min_count_to_split', args.min_count_to_split]
+        ["uncertain_taxa_use_ratio", args.uncertain_taxa_use_ratio],
+        ["min_to_split", args.min_to_split],
+        ["min_count_to_split", args.min_count_to_split],
+        ["min_cluster_support", args.min_cluster_support]
     ]
-    settings_df = pd.DataFrame(settings_data, columns=['Parameter', 'Value'])
+    settings_df = pd.DataFrame(settings_data, columns=["Parameter", "Value"])
 
-    # Write to Excel with multiple sheets
-    temp_path = output_file + ".xlsx"
+    temp_path = output_file + ".tmp.xlsx"
     os.makedirs(os.path.dirname(temp_path), exist_ok=True)
-    with pd.ExcelWriter(temp_path, engine='openpyxl') as writer:
-        df.to_excel(writer, sheet_name='Processed_Taxa_Clusters', index=False, engine='openpyxl')
-        settings_df.to_excel(writer, sheet_name='Settings', index=False, engine='openpyxl')
+    with pd.ExcelWriter(temp_path, engine="openpyxl") as writer:
+        if raw_df is not None:
+            raw_df.to_excel(
+                writer, sheet_name="Raw_Taxa_Clusters", index=False
+            )
+        if processed_df is not None:
+            processed_df.to_excel(
+                writer, sheet_name="Processed_Taxa_Clusters", index=False
+            )
+        settings_df.to_excel(writer, sheet_name="Settings", index=False)
 
-        # Auto-adjust column widths for better readability
         for sheet_name in writer.sheets:
            worksheet = writer.sheets[sheet_name]
            for column in worksheet.columns:
@@ -552,38 +610,52 @@
                column_letter = column[0].column_letter
                for cell in column:
                    try:
-                        if len(str(cell.value)) > max_length:
-                            max_length = len(str(cell.value))
-                    except:
+                        cell_len = len(str(cell.value))
+                        if cell_len > max_length:
+                            max_length = cell_len
+                    except AttributeError:
                        pass
-                adjusted_width = min(max_length + 2, 50)  # Cap at 50 characters
+                adjusted_width = min(max_length + 2, 50)
                worksheet.column_dimensions[column_letter].width = adjusted_width
+
    os.replace(temp_path, output_file)
 
+    if log_messages is not None:
+        if raw_df is not None:
+            log_message(log_messages, "Raw taxa per cluster written successfully")
+        if processed_df is not None:
+            log_message(
+                log_messages,
+                "Processed taxa (split clusters) written successfully",
+            )
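Downstream, the combined workbook can be inspected with pandas; a minimal sketch assuming both optional sheets were written to a hypothetical `taxa_report.xlsx`:

```python
import pandas as pd

# sheet_name=None loads every sheet into a dict keyed by sheet name
sheets = pd.read_excel("taxa_report.xlsx", sheet_name=None, engine="openpyxl")
print(sheets["Settings"])
processed = sheets.get("Processed_Taxa_Clusters")
if processed is not None:
    top = processed.groupby("species")["count"].sum()
    print(top.sort_values(ascending=False).head())
```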
 
 
-def create_similarity_plot(all_simi_data, cluster_simi_lengths, args, output_file):
-    """
-    Create a bar plot showing the distribution of intra-cluster similarity values.
+def create_similarity_plot(all_simi_data, cluster_simi_lengths, args, output_file, log_messages=None):
+    """Create and save a similarity distribution bar plot.
 
-    The plot uses different colors to distinguish reads belonging to different clusters.
+    Bars are colored per cluster using a fixed repeating colormap.
 
-    :param all_simi_data: List of all similarity values, sorted descending.
+    :param all_simi_data: All similarity values across clusters.
     :type all_simi_data: list[float]
-    :param cluster_simi_lengths: List of lengths of similarity data per cluster, used for coloring.
+    :param cluster_simi_lengths: Per-cluster similarity list lengths.
     :type cluster_simi_lengths: list[int]
-    :param args: Parsed script arguments, used for plot y-limits.
+    :param args: Namespace containing y-axis plot limits.
     :type args: argparse.Namespace
-    :param output_file: Path to the output plot file (e.g., .png).
+    :param output_file: Output PNG file path.
     :type output_file: str
+    :param log_messages: Optional log message list.
+    :type log_messages: list[str] or None
     :return: None
     :rtype: None
     """
     if not all_simi_data:
+        if log_messages is not None:
+            log_message(log_messages, "No similarity data; similarity plot not created.")
        return
 
    sorted_simi_list = sorted(all_simi_data, reverse=True)
    bar_positions = list(range(len(sorted_simi_list)))
 
-    # Create colormap for different clusters
    colormap_full = []
    for i, length in enumerate(cluster_simi_lengths):
        color = COLORMAP[i % len(COLORMAP)]
@@ -591,191 +663,196 @@
 
    plt.figure(figsize=(12, 6))
    plt.bar(bar_positions, sorted_simi_list, width=1, color=colormap_full)
-    plt.grid(axis='y', linestyle='--', color='gray', alpha=0.7)
+    plt.grid(axis="y", linestyle="--", color="gray", alpha=0.7)
    plt.ylabel("Similarity (%)")
    plt.xlabel("Reads")
    plt.title("Intra-cluster Similarity Distribution")
    plt.ylim(ymin=args.simi_plot_y_min, ymax=args.simi_plot_y_max)
    plt.tight_layout()
 
-    plt.savefig(output_file, format='png', dpi=300, bbox_inches='tight')
+    plt.savefig(output_file, format="png", dpi=300, bbox_inches="tight")
    plt.close()
 
+    if log_messages is not None:
+        log_message(log_messages, "Similarity plot written successfully")
 
 
-def create_evalue_plot(all_eval_data, cluster_eval_lengths, output_file):
-    """
-    Create a bar plot showing the distribution of E-values.
-
-    The y-axis is log-scaled and displays ``1/E-values``. Reads are ordered
-    by E-value (ascending). The plot uses different colors to distinguish reads
-    belonging to different clusters.
-
-    :param all_eval_data: List of E-values from all reads. Assumes E-values start at index 1.
-    :type all_eval_data: list[float | int]
-    :param cluster_eval_lengths: List of lengths of annotated E-value data per cluster, used for coloring.
-    :type cluster_eval_lengths: list[int]
-    :param output_file: Path to the output plot file (e.g., .png).
-    :type output_file: str
-    :return: None
-    :rtype: None
-    """
-    if len(all_eval_data) <= 1:  # Only unannotated reads
-        return
-
-    sorted_eval_list = sorted(all_eval_data[1:])  # Skip unannotated count
-
-    if not sorted_eval_list:
-        return
-
-    bar_positions = list(range(len(sorted_eval_list)))
-    bar_heights = [1 / e if e > 0 else 0 for e in sorted_eval_list]
-
-    # Create colormap for different clusters
-    colormap_full = []
-    for i, length in enumerate(cluster_eval_lengths):
-        color = COLORMAP[i % len(COLORMAP)]
-        colormap_full.extend([color] * length)
-
-    plt.figure(figsize=(12, 6))
-    plt.bar(bar_positions, bar_heights, width=1, color=colormap_full)
-    plt.yscale('log')
-    plt.grid(axis='y', linestyle='--', color='gray', alpha=0.7)
-    plt.ylabel("1/E-values")
-    plt.xlabel("Reads")
-    plt.title("E-value Distribution")
-    plt.tight_layout()
-
-    plt.savefig(output_file, format='png', dpi=300, bbox_inches='tight')
-    plt.close()
-
-
-def prepare_evalue_histogram(evalue_list, unannotated_list):
-    """
-    Generate histogram data for E-value distributions.
-
-    This function processes a list of E-values from BLAST or similar search
-    results, filters out invalid or zero entries, and computes histogram data
-    suitable for plotting. The histogram represents the frequency distribution
-    of E-values across all annotated hits.
-
-    :param evalue_list: List of E-values from BLAST hits
-    :type evalue_list: list[float | int]
-    :param unannotated_list: List of unannotated E-values
-    :type unannotated_list: list
-    :return: Tuple containing:
-        - **counts** (*numpy.ndarray*): Number of entries per histogram bin.
-        - **bins** (*numpy.ndarray*): Bin edges corresponding to the histogram.
-        Returns ``(None, None)`` if no valid data is available.
-    :rtype: tuple[numpy.ndarray, numpy.ndarray] | tuple[None, None]
-    :note:
-        - Only positive numeric E-values are included in the histogram.
-        - Uses 50 bins in the range (0, 1) for visualization consistency.
-    """
-    data = [ev for ev in evalue_list if isinstance(ev, (int, float)) and ev > 0]
-    if not data:
-        return None, None
-
-    counts, bins, _ = plt.hist(data, bins=50, range=(0, 1))
-    plt.close()
-    return counts, bins
-
-
-def create_evalue_plot_test(evalue_list, unannotated_list, output_file):
-    """
-    Create and save an E-value distribution plot, returning the computed histogram data.
-
-    This function visualizes the frequency distribution of E-values from BLAST or
-    annotation results. It saves the plot to the specified output file and returns
-    the histogram data (counts and bins) for testing with pytests.
-
-    :param evalue_list: List of numeric E-values to plot
-    :type evalue_list: list[float | int]
-    :param unannotated_list: Optional list of E-values for unannotated sequences.
-    :type unannotated_list: list
-    :param output_file: Path where the histogram image will be saved.
-    :type output_file: str
-    :return: Tuple containing:
-        - **counts** (*numpy.ndarray*): Frequency counts per histogram bin.
-        - **bins** (*numpy.ndarray*): Histogram bin edges.
-        Returns ``(None, None)`` if no valid data was available for plotting.
-    :rtype: tuple[numpy.ndarray, numpy.ndarray] | tuple[None, None]
-    """
-    counts, bins = prepare_evalue_histogram(evalue_list, unannotated_list)
-    if counts is None:
-        return None, None
-
-    plt.hist([ev for ev in evalue_list if isinstance(ev, (int, float)) and ev > 0],
-             bins=50, range=(0, 1))
-    plt.xlabel("E-value")
-    plt.ylabel("Frequency")
-    plt.title("E-value Distribution")
-    plt.savefig(output_file)
-    plt.close()
-    return counts, bins
+def summarize_and_log(cluster_data_list, all_simi_data, log_messages, args):
+    """Compute global summary statistics and append them to the log.
+
+    Summary includes:
+    - total clusters
+    - total annotated/unannotated reads
+    - per-cluster annotation presence
+    - top taxa distribution
+    - similarity mean and standard deviation
+
+    :param cluster_data_list: List of processed cluster descriptors.
+    :type cluster_data_list: list[dict]
+    :param all_simi_data: All similarity values from all clusters.
+    :type all_simi_data: list[float]
+    :param log_messages: List collecting log output lines.
+    :type log_messages: list[str]
+    :param args: Argument namespace with configuration parameters.
+    :type args: argparse.Namespace
+    :return: None
+    :rtype: None
+    """
+    total_clusters = len(cluster_data_list)
+    total_annotated = sum(c["annotated"] for c in cluster_data_list)
+    total_unannotated = sum(c["unannotated"] for c in cluster_data_list)
+    grand_total = total_annotated + total_unannotated
+
+    clusters_with_annotations = sum(
+        1 for c in cluster_data_list if c["annotated"] > 0
+    )
+    clusters_unannotated_only = total_clusters - clusters_with_annotations
+
+    log_message(log_messages, "=== SUMMARY ===")
+    log_message(log_messages, f"Clusters parsed: {total_clusters}")
+    log_message(log_messages, f"Total reads: {grand_total}")
+    log_message(log_messages, f"Annotated reads: {total_annotated}")
+    log_message(log_messages, f"Unannotated reads: {total_unannotated}")
+    log_message(log_messages, f"Clusters with annotations: {clusters_with_annotations}")
+    log_message(log_messages, f"Clusters fully unannotated: {clusters_unannotated_only}")
+    log_message(log_messages, f"Minimum cluster support for processed taxa: {args.min_cluster_support}")
+
+    taxa_counter = Counter()
+    for c in cluster_data_list:
+        for taxa, info in c["taxa_map"].items():
+            if taxa == "Unannotated read":
+                continue
+            taxa_counter[taxa] += info["count"]
+
+    if taxa_counter:
+        log_message(log_messages, "=== TAXA SUMMARY (top 10) ===")
+        for taxa, count in taxa_counter.most_common(10):
+            log_message(log_messages, f"{taxa}: {count} reads")
+        log_message(log_messages, f"Total unique taxa: {len(taxa_counter)}")
+    else:
+        log_message(log_messages, "No annotated taxa found.")
+
+    if all_simi_data:
+        avg = sum(all_simi_data) / len(all_simi_data)
+        variance = sum((s - avg) ** 2 for s in all_simi_data) / len(all_simi_data)
+        st_dev = sqrt(variance)
+        log_message(log_messages, "=== SIMILARITY SUMMARY ===")
+        log_message(log_messages, f"Mean similarity: {avg:.2f}")
+        log_message(log_messages, f"Std similarity: {st_dev:.2f}")
+    else:
+        log_message(log_messages, "No similarity values available for summary.")
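When `--log_file` is set, the collected messages are written as plain text; an abbreviated, invented example of the summary section:

```
=== SUMMARY ===
Clusters parsed: 42
Total reads: 112
Annotated reads: 96
Unannotated reads: 16
Clusters with annotations: 37
Clusters fully unannotated: 5
=== TAXA SUMMARY (top 10) ===
Animalia / Chordata / Aves / Passeriformes / Paridae / Parus / Parus major: 30 reads
```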
 
 
 def main(arg_list=None):
-    """
-    Main entry point of the script.
+    """Main entry point for CD-HIT cluster processing.
 
-    Parses arguments, processes cd-hit cluster data, aggregates results,
-    and generates requested outputs (text summaries, plots, and Excel reports).
+    Orchestrates parsing, cluster processing, statistic aggregation,
+    visualization, and generation of all requested outputs.
 
-    :param arg_list: List of arguments for testing purposes.
-    :type arg_list: list, optional
+    :param arg_list: Optional list of arguments (used by tests).
+    :type arg_list: list[str] or None
     :return: None
     :rtype: None
     """
     args = parse_arguments(arg_list)
 
-    # Parse cluster file
+    log_messages: list[str] = []
+    log_message(log_messages, "=== CD-HIT cluster processing started ===")
+
+    log_message(log_messages, "cluster_file: provided")
+    log_message(
+        log_messages,
+        "annotation_file: provided" if args.input_annotation else "annotation_file: none",
+    )
+
+    log_message(log_messages, "=== PARAMETERS ===")
+    skip_keys = {
+        "output_similarity_txt",
+        "output_similarity_plot",
+        "output_count",
+        "output_excel",
+        "input_cluster",
+        "input_annotation",
+        "log_file",
+    }
+    for key, value in vars(args).items():
+        if key in skip_keys:
+            continue
+        log_message(log_messages, f"{key}: {value}")
+    log_message(log_messages, "=== PARAMETERS END ===")
+
     clusters = parse_cluster_file(
         args.input_cluster,
         args.input_annotation,
-        args.print_empty_files
+        raise_on_error=False,
+        log_messages=log_messages,
     )
 
-    # Process each cluster
-    all_eval_data = [0]  # For full sample statistics
-    all_simi_data = []
-    cluster_eval_lengths = []
-    cluster_simi_lengths = []
+    cluster_data_list = []
+    all_simi_data: list[float] = []
+    cluster_simi_lengths: list[int] = []
 
     for cluster in clusters:
-        eval_list, simi_list, taxa_dict = process_cluster_data(cluster)
-        cluster_data_list.append((eval_list, simi_list, taxa_dict))
-
-        # Collect data for full sample plots
-        all_eval_data[0] += eval_list[0]
-        if len(eval_list) > 1:
-            all_eval_data.extend(sorted(eval_list[1:]))
-        cluster_eval_lengths.append(len(eval_list[1:]))
+        simi_list, taxa_map, annotated_count, unannotated_count = process_cluster_data(
+            cluster
+        )
+        cluster_data = {
+            "similarities": simi_list,
+            "taxa_map": taxa_map,
+            "annotated": annotated_count,
+            "unannotated": unannotated_count,
+        }
+        cluster_data_list.append(cluster_data)
 
         if simi_list:
-            all_simi_data.extend(sorted(simi_list, reverse=True))
+            all_simi_data.extend(simi_list)
             cluster_simi_lengths.append(len(simi_list))
 
-    # Generate outputs based on what was requested
     if args.output_similarity_txt:
-        write_similarity_output(all_simi_data, args.output_similarity_txt)
+        write_similarity_output(
+            cluster_data_list, args.output_similarity_txt, log_messages)
 
-    if args.output_similarity_plot and all_simi_data:
-        create_similarity_plot(all_simi_data, cluster_simi_lengths, args, args.output_similarity_plot)
-
-    if args.output_evalue_txt:
-        write_evalue_output(all_eval_data, args.output_evalue_txt)
-
-    if args.output_evalue_plot and len(all_eval_data) > 1:
-        create_evalue_plot(all_eval_data, cluster_eval_lengths, args.output_evalue_plot)
+    if args.output_similarity_plot:
+        create_similarity_plot(
+            all_simi_data,
+            cluster_simi_lengths,
+            args,
+            args.output_similarity_plot,
+            log_messages)
 
     if args.output_count:
-        write_count_output(all_eval_data, cluster_data_list, args.output_count)
+        write_count_output(cluster_data_list, args.output_count, log_messages)
 
-    if args.output_taxa_clusters:
-        write_taxa_clusters_output(cluster_data_list, args.output_taxa_clusters)
+    if args.output_excel:
+        write_raw = bool(args.output_taxa_clusters)
+        write_processed = bool(args.output_taxa_processed)
 
-    if args.output_taxa_processed:
-        write_taxa_processed_output(cluster_data_list, args, args.output_taxa_processed)
+        if not write_raw and not write_processed:
+            if log_messages is not None:
+                log_message(log_messages, "output_excel provided but no taxa sheet flags set; no Excel file written.")
+        else:
+            write_taxa_excel(cluster_data_list,
+                             args,
                             args.output_excel,
                             write_raw=write_raw,
                             write_processed=write_processed,
                             log_messages=log_messages)
+    else:
+        if args.output_taxa_clusters or args.output_taxa_processed:
+            if log_messages is not None:
+                log_message(
+                    log_messages,
+                    "WARNING: Raw/processed taxa output flags set but no --output_excel path provided; skipping Excel output."
+                )
 
-    print(f"Processing complete. Processed {len(clusters)} clusters.")
+    summarize_and_log(cluster_data_list, all_simi_data, log_messages, args)
+    log_message(log_messages, "=== CD-HIT cluster processing finished ===")
+
+    if args.log_file:
+        log_dir = os.path.dirname(args.log_file)
+        if log_dir:
+            os.makedirs(log_dir, exist_ok=True)
+        with open(args.log_file, "w") as f:
+            f.write("\n".join(log_messages))
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
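Because `main` accepts an argument list, a full run can be driven from Python (or a test) without a shell; file names here are hypothetical:

```python
from cdhit_analysis import main

main([
    "--input_cluster", "clusters.clstr",
    "--input_annotation", "annotations.xlsx",
    "--output_similarity_txt", "similarity.tsv",
    "--output_similarity_plot", "similarity.png",
    "--output_count", "counts.tsv",
    "--output_excel", "taxa_report.xlsx",
    "--output_taxa_clusters",
    "--output_taxa_processed",
    "--log_file", "run.log",
])
```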
