diff blast_annotations_processor.py @ 0:a3989edf0a4a draft
planemo upload for repository https://github.com/Onnodg/Naturalis_NLOOR/tree/main/NLOOR_scripts/process_annotations_tool commit c944fd5685f295acba06679e85b67973c173b137
| author | onnodg |
|---|---|
| date | Tue, 14 Oct 2025 09:08:30 +0000 |
| parents | |
| children | |
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/blast_annotations_processor.py	Tue Oct 14 09:08:30 2025 +0000
@@ -0,0 +1,913 @@
+"""Galaxy-compatible BLAST annotation processor.
+
+This script processes a single annotated BLAST file along with a FASTA file
+containing unannotated reads. It generates multiple types of outputs for
+integration with Galaxy workflows:
+
+- E-value distribution plots (PNG)
+- Taxonomic composition reports (text)
+- Circular taxonomy diagram data (JSON)
+- Header annotations with merged and per-read information (Excel)
+- Annotation statistics summary (text)
+
+Main workflow:
+1. Parse command-line arguments.
+2. Load annotated BLAST results and unannotated FASTA headers.
+3. Group BLAST hits per query (q_id), filter by thresholds.
+4. Resolve taxonomic conflicts with uncertainty rules.
+5. Generate requested outputs.
+
+Notes:
+- Headers in BLAST and FASTA should correspond.
+- Uses matplotlib, pandas, numpy, and openpyxl for visualization and reporting.
+"""
+
+import argparse
+from collections import defaultdict, OrderedDict
+import json
+import os
+import matplotlib.pyplot as plt
+import pandas as pd
+import numpy as np
+import openpyxl
+
+# Default taxonomic levels: Kingdom, Phylum, Class, Order, Family, Genus, Species
+TAXONOMIC_LEVELS = ["K", "P", "C", "O", "F", "G", "S"]
+
+def parse_arguments(arg_list=None):
+    """Parse command-line arguments for BLAST annotation processing."""
+    parser = argparse.ArgumentParser(
+        description='Process BLAST annotation results for Galaxy'
+    )
+    # Required inputs
+    parser.add_argument('--input-anno', required=True,
+                        help='Annotated BLAST output file (tabular format)')
+    parser.add_argument('--input-unanno', required=True,
+                        help='Unannotated sequences file (FASTA format)')
+    # Optional outputs
+    parser.add_argument('--eval-plot',
+                        help='Output path for E-value plot (PNG)')
+    parser.add_argument('--taxa-output',
+                        help='Output path for taxa report (tabular)')
+    parser.add_argument('--circle-data',
+                        help='Output path for circular taxonomy data (txt)')
+    parser.add_argument('--header-anno',
+                        help='Output path for header annotations (tabular/xlsx)')
+    parser.add_argument('--anno-stats',
+                        help='Output path for annotation statistics (txt)')
+    # Parameters
+    parser.add_argument('--uncertain-threshold', type=float, default=0.9,
+                        help='Threshold for resolving taxonomic conflicts (default: 0.9)')
+    parser.add_argument('--eval-threshold', type=float, default=1e-10,
+                        help='E-value threshold for filtering results (default: 1e-10)')
+    # 'store_true' combined with default=True made this flag a no-op;
+    # BooleanOptionalAction (Python >= 3.9) keeps the old default while adding
+    # --no-use-counts to actually disable it.
+    parser.add_argument('--use-counts', action=argparse.BooleanOptionalAction, default=True,
+                        help='Use read counts in circular data (default: enabled)')
+
+    return parser.parse_args(arg_list)
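+
+# Illustrative invocation (hedged: file names are placeholders, not part of
+# this repository):
+#
+#   python blast_annotations_processor.py \
+#       --input-anno blast_hits.tab --input-unanno reads.fasta \
+#       --taxa-output taxa_report.txt --anno-stats stats.txt \
+#       --eval-threshold 1e-20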
+
+
+def list_to_string(x):
+    """Convert a list, pandas Series, or numpy array to a comma-separated string."""
+    if isinstance(x, (list, pd.Series, np.ndarray)):
+        return ", ".join(map(str, x))
+    return str(x)
+
+
+def make_eval_plot(e_val_sets, output_path):
+    """
+    Create a bar plot of E-values per read.
+
+    :param e_val_sets: List of sets containing E-values per read.
+    :type e_val_sets: list[set[str]]
+    :param output_path: Path to save the plot (PNG).
+    :type output_path: str
+    :return: None
+    :rtype: None
+    """
+    if not e_val_sets:
+        print("No E-values to plot")
+        return None
+
+    # Convert string E-values to floats and sort
+    processed_sets = sorted([sorted(float(e_val) for e_val in e_val_set)
+                             for e_val_set in e_val_sets])
+
+    bar_positions = []
+    bar_heights = []
+
+    for i, e_set in enumerate(processed_sets):
+        if len(e_set) == 1:
+            e_set.append(0)
+        for j, e_val in enumerate(e_set):
+            bar_positions.append(i + j * 0.29)
+            # Plot reciprocal E-values so smaller (better) E-values give taller
+            # bars; the 0 padding added above yields a zero-height placeholder.
+            if e_val != 0:
+                bar_heights.append(1 / e_val)
+            else:
+                bar_heights.append(e_val)
+
+    plt.figure(figsize=(21, 9))
+    plt.bar(bar_positions, bar_heights, width=0.29, color=['blue', 'red'])
+
+    plt.yscale('log')
+    plt.grid(axis='y', linestyle='--', color='gray', alpha=0.7)
+    plt.ylabel("e-values")
+    plt.xticks(ticks=range(len(processed_sets)),
+               labels=[f'{i + 1}' for i in range(len(processed_sets))])
+
+    plt.tight_layout()
+    plt.savefig(output_path, format='png')
+    plt.close()
+    return None
+
+
+def calculate_annotation_stats(anno_count, unanno_file_path, unique_anno_count, total_unique_count):
+    """
+    Compute annotation statistics for a dataset.
+
+    This function calculates summary statistics for annotated and unannotated
+    sequences in a dataset. It counts the total number of sequences in the
+    unannotated file (FASTA-style, based on lines starting with '>'), and
+    computes the percentage of annotated sequences and unique annotated
+    sequences relative to their totals.
+
+    :param anno_count: Total number of annotated sequences.
+    :type anno_count: int
+    :param unanno_file_path: Path to a FASTA file containing unannotated sequences.
+    :type unanno_file_path: str
+    :param unique_anno_count: Number of unique annotated sequences.
+    :type unique_anno_count: int
+    :param total_unique_count: Total number of unique sequences in the dataset.
+    :type total_unique_count: int
+    :return: Dictionary containing:
+        - 'percentage_annotated' (float): Percentage of annotated sequences.
+        - 'annotated_sequences' (int): Number of annotated sequences.
+        - 'total_sequences' (int): Total number of sequences (annotated + unannotated).
+        - 'percentage_unique_annotated' (float): Percentage of unique annotated sequences.
+        - 'unique_annotated' (int): Number of unique annotated sequences.
+        - 'total_unique' (int): Total number of unique sequences.
+    :rtype: dict[str, float | int]
+    """
+    # Count total sequences in unannotated file
+    with open(unanno_file_path, 'r') as f:
+        total_sequences = sum(1 for line in f if line.startswith('>'))
+
+    percentage_annotated = (anno_count / total_sequences * 100) if total_sequences > 0 else 0
+    percentage_unique_annotated = (unique_anno_count / total_unique_count * 100) if total_unique_count > 0 else 0
+
+    return {
+        'percentage_annotated': percentage_annotated,
+        'annotated_sequences': anno_count,
+        'total_sequences': total_sequences,
+        'percentage_unique_annotated': percentage_unique_annotated,
+        'unique_annotated': unique_anno_count,
+        'total_unique': total_unique_count
+    }
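+
+# Worked example (hedged: the numbers are invented for illustration). For a
+# FASTA with 200 '>' headers, 80 annotated reads, and 50 of 120 unique
+# sequences annotated, the returned dictionary would contain:
+#     percentage_annotated        -> 40.0   (80 / 200 * 100)
+#     percentage_unique_annotated -> ~41.67 (50 / 120 * 100)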
+
+
+def _choose_taxon(level_counts: dict, threshold: float):
+    """
+    Determine the most representative taxon at a given level based on counts and a threshold.
+
+    This function evaluates a dictionary of taxon counts observed at a single
+    taxonomic level and returns the taxon name that holds a strict majority. If
+    no taxon reaches the threshold, or if there is a tie for the highest count,
+    the function returns ``None``.
+
+    :param level_counts: A dictionary mapping taxon names observed at the current
+        taxonomic level (e.g., "Gammaproteobacteria") to their respective counts.
+    :type level_counts: dict[str, int]
+    :param threshold: The minimum fraction (between 0 and 1) of the total counts that
+        the top taxon must have to be considered dominant.
+    :type threshold: float
+
+    :return: The name of the taxon that meets or exceeds the threshold and
+        is not tied with another taxon. Returns ``None`` if no taxon qualifies or
+        if there is a tie.
+    :rtype: str | None
+
+    :example:
+        _choose_taxon({"Gammaproteobacteria": 6, "Alphaproteobacteria": 3, "Betaproteobacteria": 1}, 0.5)
+        'Gammaproteobacteria'
+        _choose_taxon({"Gammaproteobacteria": 4, "Alphaproteobacteria": 4}, 0.6)
+        None
+        _choose_taxon({"Alphaproteobacteria": 2, "Gammaproteobacteria": 3}, 0.6)
+        'Gammaproteobacteria'
+    """
+    total = sum(level_counts.values())
+    if total == 0:
+        return None
+    items = sorted(level_counts.items(), key=lambda x: x[1], reverse=True)
+    top_name, top_count = items[0]
+    # tie -> no clear majority
+    if sum(1 for _, c in items if c == top_count) > 1:
+        return None
+    if top_count / total >= threshold:
+        return top_name
+    return None
+
+
+def _resolve_conflicts_recursive(conflicting: list, level_idx: int, threshold: float, max_levels: int, prefix=None):
+    """
+    Recursively resolve taxonomic classification conflicts across multiple levels.
+
+    This function attempts to resolve conflicts between multiple taxonomic paths by
+    recursively evaluating each taxonomic level. At each level, it determines whether
+    there is a clear majority (based on a specified threshold) for a taxon name. If
+    a majority is found, the function continues to the next level using only the
+    taxa that match that name. If no majority exists, the function constructs
+    "uncertain" outputs to indicate ambiguous classification.
+
+    :param conflicting: List of tuples containing full taxonomic paths and their
+        associated counts, e.g. ``[("Bacteria / Proteobacteria / Gammaproteobacteria", 10),
+        ("Bacteria / Proteobacteria / Alphaproteobacteria", 5)]``.
+    :type conflicting: list[tuple[str, int]]
+    :param level_idx: The current taxonomic level index being evaluated (0-based).
+    :type level_idx: int
+    :param threshold: The minimum fraction (between 0 and 1) required for a strict
+        majority decision at a given taxonomic level.
+    :type threshold: float
+    :param max_levels: The maximum number of taxonomic levels to evaluate before
+        stopping recursion.
+    :type max_levels: int
+    :param prefix: A list of already-resolved taxon names up to the previous level.
+        Defaults to ``None``, in which case an empty list is used.
+    :type prefix: list[str] | None
+
+    :return: A tuple containing:
+        - **short_output** (str): A simplified taxonomic path with "Uncertain taxa"
+          appended if no consensus was reached at the current level.
+        - **full_output** (str): The full taxonomic path, filled with additional
+          "Uncertain taxa" entries up to the minimal conflicting level depth.
+    :rtype: tuple[str, str]
+
+    :example:
+        conflicts = [
+        ...     ("Bacteria / Proteobacteria / Gammaproteobacteria", 10),
+        ...     ("Bacteria / Proteobacteria / Alphaproteobacteria", 5)
+        ... ]
+        _resolve_conflicts_recursive(conflicts, level_idx=0, threshold=0.8, max_levels=5)
+        ('Bacteria / Proteobacteria / Uncertain taxa',
+         'Bacteria / Proteobacteria / Uncertain taxa')
+
+        At threshold 0.8 the 10/15 (~0.67) majority for Gammaproteobacteria is
+        insufficient, so the class level is marked uncertain. At threshold 0.6
+        the same input would resolve to the full Gammaproteobacteria path.
+    """
+    if prefix is None:
+        prefix = []
+
+    if not conflicting:
+        return "", ""
+
+    # split paths (keeps sync)
+    conflicting_levels = [t.split(" / ") for t, _ in conflicting]
+    min_conflicting_level = min(len(p) for p in conflicting_levels)
+
+    # If only one full path remains -> return it (both short & full)
+    if len(conflicting) == 1:
+        return conflicting[0][0], conflicting[0][0]
+
+    # If we've reached deepest comparable level or max_levels -> pick most supported full path
+    if level_idx >= min_conflicting_level or level_idx >= max_levels:
+        full_counts = defaultdict(int)
+        for t, c in conflicting:
+            full_counts[t] += c
+        chosen_full = max(full_counts.items(), key=lambda x: x[1])[0]
+        return chosen_full, chosen_full
+
+    # Count names at this level
+    level_counts = defaultdict(int)
+    for taxon, count in conflicting:
+        parts = taxon.split(" / ")
+        if level_idx < len(parts):
+            level_counts[parts[level_idx]] += count
+
+    # If everyone agrees at this level, append that name and go deeper
+    if len(level_counts) == 1:
+        name = next(iter(level_counts))
+        return _resolve_conflicts_recursive(conflicting, level_idx + 1, threshold, max_levels, prefix + [name])
+
+    # Check for a strict majority at this level
+    chosen_level_name = _choose_taxon(level_counts, threshold)
+    if chosen_level_name is not None:
+        # Keep only taxa that match the chosen name at this level and go deeper
+        filtered = [(t, c) for (t, c) in conflicting
+                    if level_idx < len(t.split(" / ")) and t.split(" / ")[level_idx] == chosen_level_name]
+        return _resolve_conflicts_recursive(filtered, level_idx + 1, threshold, max_levels, prefix + [chosen_level_name])
+
+    # No majority at this level -> construct uncertain outputs.
+    # short: prefix + 'Uncertain taxa' (stop here)
+    short_path = prefix + ['Uncertain taxa']
+    short_output = " / ".join(short_path)
+
+    # full: prefix + 'Uncertain taxa' + fill until min_conflicting_level
+    full_path = short_path + ['Uncertain taxa'] * (min_conflicting_level - (level_idx + 1))
+    full_output = " / ".join(full_path)
+
+    return short_output, full_output
+
+
+def resolve_taxon_majority(taxon_counts: dict, threshold: float = 0.90, max_levels: int = 7):
+    """
+    Resolve the majority consensus taxonomic classification from a set of annotated paths.
+
+    :param taxon_counts: A dictionary mapping full taxonomic paths to their respective
+        occurrence counts. For example:
+        ``{"Bacteria / Proteobacteria / Gammaproteobacteria": 10,
+        "Bacteria / Proteobacteria / Alphaproteobacteria": 5}``.
+    :type taxon_counts: dict[str, int]
+    :param threshold: Minimum fraction (between 0 and 1) required for a strict majority
+        at each taxonomic level. Defaults to ``0.90``.
+    :type threshold: float, optional
+    :param max_levels: Maximum number of taxonomic levels to evaluate before stopping
+        recursion. Defaults to ``7``.
+    :type max_levels: int, optional
+
+    :return: A tuple containing:
+        - **short_output** (str): The consensus taxonomic path, possibly ending with
+          "Uncertain taxa" if ambiguity remains.
+        - **full_output** (str): The expanded taxonomic path including placeholder
+          "Uncertain taxa" labels up to the minimal conflicting depth.
+    :rtype: tuple[str, str]
+    """
+    conflicting = list(taxon_counts.items())  # list of (path, count) keeps taxa<->count synced
+    return _resolve_conflicts_recursive(conflicting, level_idx=0, threshold=threshold, max_levels=max_levels)
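+
+# Illustrative behaviour (hedged: derived by tracing the recursion above):
+#     >>> resolve_taxon_majority({"Bacteria / Firmicutes": 9,
+#     ...                         "Bacteria / Proteobacteria": 1}, threshold=0.9)
+#     ('Bacteria / Firmicutes', 'Bacteria / Firmicutes')
+#     >>> resolve_taxon_majority({"Bacteria / Firmicutes": 9,
+#     ...                         "Bacteria / Proteobacteria": 1}, threshold=0.95)
+#     ('Bacteria / Uncertain taxa', 'Bacteria / Uncertain taxa')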
+
+
+def process_taxa_output(taxa_dicts, output_path, uncertain_threshold):
+    """
+    Generate a tabular taxa report from a list of read taxonomies.
+
+    This function processes taxonomic assignments for each read in a dataset
+    and produces a summary report similar in style to Kraken2 output. The
+    report includes:
+
+    - The number of reads that were marked as "Uncertain taxa" at each
+      taxonomic rank.
+    - For all other taxa, the number of reads assigned and their percentage
+      relative to the sample total.
+    - The taxonomic rank for each assignment, indicated in the report with
+      indentation proportional to the rank depth.
+
+    :param taxa_dicts: List of taxon count dictionaries, one per read, where
+        keys are taxon names and values are counts of how often
+        they were assigned to the read.
+    :type taxa_dicts: list[dict[str, int]]
+    :param output_path: Path to save the tabular taxa report.
+    :type output_path: str
+    :param uncertain_threshold: Threshold for resolving conflicting taxonomic
+        assignments (values below the threshold are
+        marked as "Uncertain taxa").
+    :type uncertain_threshold: float
+    :return: None. The function writes the taxa report to ``output_path``.
+    :rtype: None
+    """
+    # Include "U" so paths deeper than the seven canonical ranks cannot raise KeyError
+    uncertain_dict = {level: 0 for level in TAXONOMIC_LEVELS + ["U"]}
+    aggregated_counts = defaultdict(int)
+
+    for read_taxa in taxa_dicts:
+        resolved_taxon, _ = resolve_taxon_majority(read_taxa, uncertain_threshold)
+
+        # Add counts for resolved taxon
+        levels = resolved_taxon.split(" / ")
+        for i in range(1, len(levels) + 1):
+            aggregated_counts[" / ".join(levels[:i])] += 1
+
+    # Calculate total count
+    total_count = sum(value for key, value in aggregated_counts.items()
+                      if len(key.split(" / ")) == 1)
+
+    report_lines = []
+
+    for taxonomy, count in sorted(aggregated_counts.items(), key=lambda x: x[0]):
+        levels = taxonomy.split(" / ")
+        indent = " " * (len(levels) - 1) * 2
+        taxon_name = levels[-1]
+        taxon_level_code = TAXONOMIC_LEVELS[len(levels) - 1] if len(levels) <= len(TAXONOMIC_LEVELS) else "U"
+        percentage = (count / total_count) * 100 if total_count > 0 else 0
+
+        report_lines.append(
+            f"{percentage:.2f}\t{count}\t{total_count}\t{taxon_level_code}\t{indent}{taxon_name}"
+        )
+
+        if taxon_name == 'Uncertain taxa':
+            uncertain_dict[taxon_level_code] += count
+
+    # Write output
+    with open(output_path, 'w') as f:
+        f.write("Uncertain count per taxonomic level: " + str(uncertain_dict) + '\n')
+        f.write('percentage_rooted\tnumber_rooted\ttotal_num\ttaxon_level\tidentification\n')
+        f.write("\n".join(report_lines))
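+
+# Illustrative report rows (hedged: values invented; columns are tab-separated
+# in the real file, shown space-aligned here):
+#     percentage_rooted  number_rooted  total_num  taxon_level  identification
+#     100.00             42             42         K            Bacteria
+#     95.24              40             42         P              Proteobacteria
+# The two-space indent before "Proteobacteria" encodes rank depth.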
+
+
+def process_header_annotations(taxa_dicts, headers, e_values, output_path, uncertain_threshold,
+                               identity_list, coverage_list, source_list, bitscore_list):
+    """
+    Generate an Excel report with per-read and aggregated taxonomic annotations.
+
+    This function processes taxonomic assignments per read, along with associated
+    metrics (e-value, identity, coverage, bitscore, source), and generates an Excel
+    file with two sheets:
+
+    1. **Individual_Reads** – One row per sequence, including all metrics and
+       the resolved taxonomic path.
+    2. **Merged_by_Taxa** – Aggregated results per unique taxonomic path, including
+       merged metrics and counts.
+
+    Internally, the function creates a temporary TSV file, processes it into
+    a pandas DataFrame, expands the taxonomic path into separate columns
+    (e.g., kingdom, phylum, class, …), and then writes the results into a
+    multi-sheet Excel file. Column widths are automatically adjusted for readability.
+
+    :param taxa_dicts: List of taxon count dictionaries, one per read, where keys
+        are taxa and values are their counts in the read.
+    :type taxa_dicts: list[dict[str, int]]
+    :param headers: Sequence headers corresponding to each read. Used to extract
+        read identifiers and counts.
+    :type headers: list[str]
+    :param e_values: List of e-values per read. Each element is a list of values
+        associated with that read.
+    :type e_values: list[list[float]]
+    :param output_path: Path to the Excel file to be created.
+    :type output_path: str
+    :param uncertain_threshold: Threshold for resolving conflicting taxonomic
+        assignments. Values below the threshold are
+        replaced with "Uncertain taxa".
+    :type uncertain_threshold: float
+    :param identity_list: Percent identity values per read.
+    :type identity_list: list[list[str]]
+    :param coverage_list: Coverage percentage values per read.
+    :type coverage_list: list[list[str]]
+    :param source_list: Source identifiers per read (e.g., database names).
+    :type source_list: list[list[str]]
+    :param bitscore_list: Bitscore values per read.
+    :type bitscore_list: list[list[str]]
+    :return: None. The function writes results to an Excel file at ``output_path``.
+    :rtype: None
+    :raises PermissionError: If the temporary TSV or output Excel file cannot be written
+        (e.g., file is open in another program).
+    :raises IndexError: If the input lists (headers, e_values, etc.) are not aligned
+        with ``taxa_dicts``.
+    :raises ValueError: If sequence headers are not formatted as expected for extracting counts.
+    """
+    report_lines = []
+    for i, read_taxa in enumerate(taxa_dicts):
+        _, resolved_taxon_long = resolve_taxon_majority(read_taxa, uncertain_threshold)
+        try:
+            e_val = e_values[i][0] if e_values[i] else "N/A"
+            identity = identity_list[i][0] if identity_list[i] else "N/A"
+            coverage = coverage_list[i][0] if coverage_list[i] else "N/A"
+            source = source_list[i][0] if source_list[i] else "N/A"
+            bitscore = bitscore_list[i][0] if bitscore_list[i] else "N/A"
+        except (IndexError, TypeError) as e:
+            print(f"Mismatch while extracting values: {e}. "
+                  f"Check that your annotated and unannotated input files correspond correctly.")
+            continue
+
+        header = headers[i] if i < len(headers) else "Header missing"
+
+        # Extract count from a trailing "(count)" suffix; default to 1
+        try:
+            header_base, count_str = header.rsplit("(", 1)
+            count = int(count_str.rstrip(")"))
+        except ValueError as e:
+            print(f'Failed extracting count: {e}')
+            header_base = header
+            count = 1
+
+        report_lines.append(
+            f'{header_base}\t{e_val}\t{identity}\t{coverage}\t{bitscore}\t{count}\t{source}\t{resolved_taxon_long}')
+    # Create temporary TSV for processing
+    temp_tsv_path = 'temp.tsv'
+    try:
+        with open(temp_tsv_path, 'w') as f:
+            f.write('header\te_value\tidentity percentage\tcoverage\tbitscore\tcount\tsource\ttaxa\n')
+            f.write("\n".join(report_lines))
+    except PermissionError as e:
+        print(f"Unable to write to file ({e}); it might be open in another program")
+
+    # Process the data
+    df = pd.read_csv(temp_tsv_path, sep='\t', encoding="latin1")
+    if not df.empty:
+        # Split taxa into separate columns, capped at the seven canonical ranks
+        # so the column-name list below cannot mismatch deeper splits
+        taxa_split = df["taxa"].str.split(" / ", expand=True)
+        max_levels = min(taxa_split.shape[1], 7)
+        taxa_split = taxa_split.iloc[:, :max_levels]
+        level_names = ["kingdom", "phylum", "class", "order", "family", "genus", "species"][:max_levels]
+        taxa_split.columns = level_names
+
+        # Add taxa columns to main dataframe; sort only on columns that exist
+        df_individual = pd.concat([df, taxa_split], axis=1)
+        sort_columns = [c for c in ['species', 'genus', 'family'] if c in df_individual.columns]
+        if sort_columns:
+            df_individual = df_individual.sort_values(sort_columns, ascending=[True] * len(sort_columns))
+        # Create merged dataframe - first add taxa columns, then aggregate
+        df_for_merge = df_individual.copy()
+
+        group_cols = ["taxa"]
+        agg_dict = {
+            "e_value": lambda x: f"{x.min()}–{x.max()}",
+            "identity percentage": lambda x: list_to_string(list(x.unique())),
+            "coverage": lambda x: list_to_string(list(x.unique())),
+            "bitscore": lambda x: list_to_string(list(x.unique())),
+            "count": "sum",
+            "source": "first"
+        }
+
+        # Add aggregation for taxa levels that actually exist
+        for level in level_names:
+            if level in df_for_merge.columns:
+                agg_dict[level] = "first"
+
+        df_merged = df_for_merge.groupby(group_cols, as_index=False).agg(agg_dict)
+
+        # Sort merged output on species, genus, family where present
+        merge_sort_columns = [c for c in ['species', 'genus', 'family'] if c in df_merged.columns]
+        if merge_sort_columns:
+            df_merged = df_merged.sort_values(merge_sort_columns, ascending=[True] * len(merge_sort_columns))
+
+        # Write both datasets to single Excel file with multiple sheets
+        temp_path = output_path + ".xlsx"
+        out_dir = os.path.dirname(temp_path)
+        if out_dir:
+            os.makedirs(out_dir, exist_ok=True)
+        with pd.ExcelWriter(temp_path, engine='openpyxl', mode='w') as writer:
+            df_individual.to_excel(writer, sheet_name='Individual_Reads', index=False)
+            df_merged.to_excel(writer, sheet_name='Merged_by_Taxa', index=False)
+
+            # Auto-adjust column widths for both sheets
+            for sheet_name in writer.sheets:
+                worksheet = writer.sheets[sheet_name]
+                for column in worksheet.columns:
+                    max_length = 0
+                    column_letter = column[0].column_letter
+                    for cell in column:
+                        try:
+                            if len(str(cell.value)) > max_length:
+                                max_length = len(str(cell.value))
+                        except Exception:
+                            pass
+                    adjusted_width = min(max_length + 2, 50)  # Cap at 50 characters
+                    worksheet.column_dimensions[column_letter].width = adjusted_width
+        os.replace(temp_path, output_path)
+        # Clean up temporary file
+        os.remove(temp_tsv_path)
+    else:
+        print("Dataframe empty, no annotation results")
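+
+# Header convention assumed above (hedged: the example value is invented): a
+# read header carries its duplicate count as a trailing "(count)" suffix, e.g.
+#     "read_17(42)"  ->  header_base "read_17", count 42
+# Headers without a parseable "(count)" suffix fall back to count = 1 here.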
+
+
+def create_circle_diagram(taxa_dicts, output_path, use_counts, uncertain_threshold):
+    """
+    Generate hierarchical JSON data for a circular taxonomy diagram.
+
+    This function aggregates taxonomic classification results from multiple reads
+    and prepares hierarchical count data suitable for circular visualization.
+
+    :param taxa_dicts: A list of dictionaries, where each dictionary represents a
+        mapping from full taxonomic paths to their corresponding counts for a single
+        read. Each key should be a taxonomic path in the format
+        ``"Kingdom / Phylum / ... / Species"``.
+    :type taxa_dicts: list[dict[str, int]]
+    :param output_path: File path to write the generated JSON output containing
+        circle diagram data.
+    :type output_path: str
+    :param use_counts: Whether to aggregate counts by total read frequency
+        (``True``) or to count only unique taxa once (``False``).
+    :type use_counts: bool
+    :param uncertain_threshold: Fraction (between 0 and 1) representing the minimum
+        majority threshold used for resolving conflicting taxonomic assignments.
+    :type uncertain_threshold: float
+
+    :return: None. The resulting JSON file will be written to ``output_path``. The
+        file contains a list of dictionaries, one per taxonomic level, with:
+        - **labels** (*list[str]*): Names of taxa at that level.
+        - **sizes** (*list[int]*): Corresponding aggregated counts.
+    :rtype: None
+
+    :note:
+        - Ambiguous classifications are labeled as "Uncertain taxa".
+        - The number of hierarchical levels is determined by the length of
+          ``TAXONOMIC_LEVELS``.
+    """
+    aggregated_counts = defaultdict(int)
+    seen_taxa = set()
+
+    for read_taxa in taxa_dicts:
+        _, resolved_taxon_long = resolve_taxon_majority(read_taxa, uncertain_threshold)
+
+        levels = resolved_taxon_long.split(" / ")
+
+        if use_counts:
+            for i in range(1, len(levels) + 1):
+                aggregated_counts[" / ".join(levels[:i])] += 1
+        else:
+            if resolved_taxon_long not in seen_taxa:
+                for i in range(1, len(levels) + 1):
+                    aggregated_counts[" / ".join(levels[:i])] += 1
+                seen_taxa.add(resolved_taxon_long)
+
+    # Prepare circle data
+    circle_data = []
+    for level in range(len(TAXONOMIC_LEVELS)):
+        circle_data.append({"labels": [], "sizes": []})
+
+    for taxonomy, count in sorted(aggregated_counts.items(), key=lambda x: x[0]):
+        levels = taxonomy.split(" / ")
+        if len(levels) <= len(TAXONOMIC_LEVELS):
+            circle_data[len(levels) - 1]["labels"].append(levels[-1].strip())
+            circle_data[len(levels) - 1]["sizes"].append(count)
+    with open(output_path, "w") as f:
+        json.dump(circle_data, f, indent=2)
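+
+# Shape of the JSON written above (hedged: labels and sizes are invented; the
+# list has one entry per rank in TAXONOMIC_LEVELS, outermost ring first):
+#     [{"labels": ["Bacteria"], "sizes": [42]},
+#      {"labels": ["Firmicutes", "Proteobacteria"], "sizes": [12, 30]},
+#      ...]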
+
+
+def process_single_file(anno_file_path, unanno_file_path, args):
+    """
+    Process a single annotated BLAST file and generate all requested analytical outputs.
+
+    This function integrates multiple steps for analyzing annotated BLAST output
+    files and corresponding unannotated FASTA sequences. It organizes BLAST hits
+    per query, extracts the best-scoring matches, and produces various types of
+    output files (plots, tables, diagrams, and statistics) according to user
+    arguments.
+
+    :param anno_file_path: Path to the annotated BLAST results file.
+    :type anno_file_path: str
+    :param unanno_file_path: Path to a FASTA file containing unannotated sequences.
+        Used to determine sequence order and total unique sequence counts.
+    :type unanno_file_path: str
+    :param args: Parsed command-line arguments containing the following attributes:
+        - **eval_threshold** (*float*): Maximum allowed E-value for valid hits.
+        - **eval_plot** (*str | None*): Output path for E-value distribution plot.
+        - **taxa_output** (*str | None*): Output path for taxonomic summary table.
+        - **header_anno** (*str | None*): Output path for annotated header table.
+        - **circle_data** (*str | None*): Output path for circular diagram JSON data.
+        - **anno_stats** (*str | None*): Output path for annotation statistics file.
+        - **use_counts** (*bool*): Whether to use read counts or unique taxa only.
+        - **uncertain_threshold** (*float*): Majority threshold for resolving taxonomic ambiguity.
+    :type args: argparse.Namespace
+
+    :return: None. This function writes multiple output files to the paths
+        specified in ``args``.
+    :rtype: None
+
+    :raises FileNotFoundError: If the annotated BLAST or unannotated FASTA file cannot be found.
+    :raises ValueError: If file content is malformed or missing required columns.
+
+    :notes:
+        - The function first groups BLAST hits by query ID and filters out invalid or
+          low-confidence taxa.
+        - For each query, the best BLAST hit is determined based on bitscore and E-value.
+        - Total unique sequence counts are read from ``count=`` fields in FASTA
+          headers; per-read counts come from trailing ``(count)`` suffixes.
+        - Progress and warnings are printed to standard output for traceability.
+    """
+    # Core data structures - one per read
+    read_data = []  # List of dicts, one per read
+    headers = []
+    seen_reads = set()
+    annotated_unique_count = 0
+    total_unique_count = 0
+
+    # Check if file exists and has content
+    if not os.path.exists(anno_file_path):
+        print(f"Error: Input file {anno_file_path} not found")
+        return
+
+    # Single pass over the FASTA: collect headers in order and sum 'count=' fields
+    unanno_headers_ordered = []
+    if os.path.exists(unanno_file_path):
+        with open(unanno_file_path, 'r') as f:
+            for line in f:
+                if line.startswith('>'):
+                    header = line.split()[0].strip('>')
+                    unanno_headers_ordered.append(header)
+                if "count=" in line:
+                    total_unique_count += int(line.split("count=")[1].split(";")[0])
+        print(f"Found {len(unanno_headers_ordered)} headers in unannotated file")
+    else:
+        print(f"Warning: Unannotated file {unanno_file_path} not found")
+
+    unanno_set = set(unanno_headers_ordered)
+
+    # Process annotated file
+    blast_groups = OrderedDict()  # q_id -> list of hit dicts
+
+    # Group BLAST hits per q_id
+    try:
+        with open(anno_file_path, 'r') as f:
+            for line in f:
+                if line.startswith("#"):
+                    continue
+                parts = line.strip().split('\t')
+                if len(parts) < 7:
+                    continue
+
+                q_id = parts[0]
+                try:
+                    e_val = float(parts[6])
+                except Exception:
+                    continue
+                if e_val > args.eval_threshold:
+                    continue
+                valid_taxa = True
+                taxon = parts[-1].strip()
+                if is_invalid_taxon(taxon):
+                    print(f"Skipping {q_id}: invalid taxon: {taxon}")
+                    valid_taxa = False
+
+                identity = parts[4] if len(parts) > 4 else ''
+                cov = parts[5] if len(parts) > 5 else ''
+                # float('') would raise; fall back to 0.0 for short/malformed rows
+                try:
+                    bitscore = float(parts[7]) if len(parts) > 7 else 0.0
+                except ValueError:
+                    bitscore = 0.0
+                source = parts[8] if len(parts) > 8 else ''
+                if valid_taxa:
+                    hit = {
+                        'e_val': e_val,
+                        'identity': identity,
+                        'cov': cov,
+                        'bitscore': bitscore,
+                        'source': source,
+                        'taxon': taxon
+                    }
+                    blast_groups.setdefault(q_id, []).append(hit)
+    except Exception as e:
+        print(f"Error reading BLAST file {anno_file_path}: {e}")
+        return
+
+    # BLAST reads not present in the FASTA
+    extra_blast_qids = [q for q in blast_groups.keys() if q not in unanno_set]
+    if extra_blast_qids:
+        print(f"Note: {len(extra_blast_qids)} BLAST q_ids not in FASTA: {extra_blast_qids[:10]}...")
+
+    # Process reads in FASTA order
+    for header in unanno_headers_ordered:
+        if header not in blast_groups:
+            print(f"Skipping {header}: no BLAST hits")
+            continue
+        current_read_info = {
+            'taxa_dict': {},
+            'header': header,
+            'e_values': set(),
+            'identities': set(),
+            'coverages': set(),
+            'sources': set(),
+            'bitscores': set(),
+            'best_eval': float('inf'),
+            'best_identity': '',
+            'best_coverage': '',
+            'best_source': '',
+            'best_bitscore': -1,
+            'best_taxa_dict': {}
+        }
+
+        # process all hits for this header
+        for hit in blast_groups[header]:
+            e_val = hit['e_val']
+            identity = hit['identity']
+            cov = hit['cov']
+            bitscore = hit['bitscore']
+            source = hit['source']
+            taxon = hit['taxon']
+
+            current_read_info['e_values'].add(str(e_val))
+            current_read_info['identities'].add(identity)
+            current_read_info['coverages'].add(cov)
+            current_read_info['sources'].add(source)
+            current_read_info['bitscores'].add(bitscore)
+
+            # Keep track of best hit
+            if (bitscore > current_read_info['best_bitscore'] or
+                    (bitscore == current_read_info['best_bitscore'] and e_val < current_read_info['best_eval'])):
+                # New best set → reset
+                current_read_info['best_bitscore'] = bitscore
+                current_read_info['best_eval'] = e_val
+                current_read_info['best_identity'] = identity
+                current_read_info['best_coverage'] = cov
+                current_read_info['best_source'] = source
+                current_read_info['best_taxa_dict'] = {taxon: 1}
+
+            # If hit is of same quality: add to existing dictionary
+            elif bitscore == current_read_info['best_bitscore'] and e_val == current_read_info['best_eval']:
+                td = current_read_info['best_taxa_dict']
+                td[taxon] = td.get(taxon, 0) + 1
+            # Might be unnecessary, best_taxa_dict seems to be the same as taxa_dict
+            # in output. Too scared to remove.
+            if len(current_read_info['e_values']) == 1:
+                taxa_dict = current_read_info['taxa_dict']
+                taxa_dict[taxon] = taxa_dict.get(taxon, 0) + 1
+        read_data.append(current_read_info)
+
+        if header not in seen_reads:
+            seen_reads.add(header)
+            headers.append(header)
+            try:
+                count = int(header.split('(')[1].split(')')[0])
+            except (IndexError, ValueError, AttributeError):
+                count = 0
+            annotated_unique_count += count
+
+    # Extract data for different functions from unified structure
+    taxa_dicts = [read['taxa_dict'] for read in read_data]
+    td = [read['best_taxa_dict'] for read in read_data]
+    # Generate outputs based on arguments
+    if args.eval_plot:
+        e_val_sets = [read['e_values'] for read in read_data]
+        make_eval_plot(e_val_sets, args.eval_plot)
+
+    if args.taxa_output:
+        process_taxa_output(td, args.taxa_output, args.uncertain_threshold)
+
+    if args.header_anno:
+        # Extract best values efficiently - single list comprehension per metric
+        e_values_for_headers = [[read['best_eval']] for read in read_data]
+        identity_for_headers = [[read['best_identity']] for read in read_data]
+        coverage_for_headers = [[read['best_coverage']] for read in read_data]
+        source_for_headers = [[read['best_source']] for read in read_data]
+        bitscore_for_headers = [[read['best_bitscore']] for read in read_data]
+        process_header_annotations(td, headers, e_values_for_headers,
+                                   args.header_anno, args.uncertain_threshold,
+                                   identity_for_headers, coverage_for_headers,
+                                   source_for_headers, bitscore_for_headers)
+
+    if args.circle_data:
+        create_circle_diagram(td, args.circle_data, args.use_counts,
+                              args.uncertain_threshold)
+
+    if args.anno_stats:
+        stats = calculate_annotation_stats(len(taxa_dicts), unanno_file_path,
+                                           annotated_unique_count, total_unique_count)
+        with open(args.anno_stats, 'w') as f:
+            f.write('metric\tvalue\n')
+            for key, value in stats.items():
+                f.write(f'{key}\t{value}\n')
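+
+# Tabular layout assumed by the parser in process_single_file (hedged: inferred
+# from the parts[...] indexing above, not from a documented spec):
+#     parts[0]=q_id  parts[4]=identity  parts[5]=coverage  parts[6]=e-value
+#     parts[7]=bitscore  parts[8]=source  parts[-1]=taxonomic path
+# e.g. a tab-separated line (unused columns elided as "..."):
+#     read_17(42)  ...  ...  ...  99.2  100  1e-30  250  GenBank  Bacteria / Firmicutes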
+
+
+def is_invalid_taxon(taxon: str) -> bool:
+    """
+    Determine whether a given taxonomic path should be considered invalid and excluded from analysis.
+
+    This function identifies and filters out taxonomic strings that contain unreliable,
+    incomplete, or placeholder taxonomic information. It is used to clean BLAST or
+    classification outputs before aggregation and visualization.
+
+    :param taxon: Full taxonomic path as a string
+    :type taxon: str
+    :return: ``True`` if the taxon is considered invalid according to the criteria below;
+        otherwise ``False``.
+    :rtype: bool
+
+    :criteria:
+        A taxon is considered invalid if:
+        - It contains the substring ``"invalid taxid"`` or ``"GenBank invalid taxon"``.
+        - It includes the term ``"environmental sample"`` (non-specific environmental sequence).
+        - It has any hierarchical level labeled as ``"unknown"`` followed by a more
+          specific level (e.g., *unknown genus / Homo sapiens*).
+
+    :note:
+        This function performs a **case-insensitive** check and assumes the
+        input taxonomic path uses ``" / "`` as a level delimiter.
+    """
+    taxon_lower = taxon.lower()
+
+    if "invalid taxid" in taxon_lower or "genbank invalid taxon" in taxon_lower:
+        return True
+
+    if "environmental sample" in taxon_lower:
+        return True
+
+    # check for "unknown" followed by something deeper
+    parts = taxon.split(" / ")
+    for i, part in enumerate(parts):
+        if "unknown" in part.lower() and i < len(parts) - 1:
+            # there is something more specific after an unknown level
+            return True
+    return False
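+
+# Illustrative checks (hedged: derived from the rules above):
+#     >>> is_invalid_taxon("Bacteria / environmental sample")
+#     True
+#     >>> is_invalid_taxon("Bacteria / unknown family / Homo sapiens")
+#     True
+#     >>> is_invalid_taxon("Bacteria / Proteobacteria")
+#     False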
+
+
+def main(arg_list=None):
+    """
+    Entry point for Galaxy-compatible BLAST annotation processing.
+
+    :param arg_list: Optional list of command-line arguments to override
+        ``sys.argv``. Primarily used for testing or programmatic execution.
+        If ``None``, arguments are read directly from the command line.
+    :type arg_list: list[str] | None
+    :return: None
+    :rtype: None
+
+    Notes:
+        Calls `process_single_file` with parsed arguments.
+    """
+    args = parse_arguments(arg_list)
+    for output_file in [args.eval_plot, args.header_anno, args.taxa_output, args.circle_data, args.anno_stats]:
+        if output_file:
+            out_dir = os.path.dirname(output_file)
+            # Guard: dirname is '' for bare filenames, and makedirs('') raises
+            if out_dir:
+                os.makedirs(out_dir, exist_ok=True)
+    process_single_file(args.input_anno, args.input_unanno, args)
+    print("Processing completed successfully")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
