view blast_annotations_processor.py @ 0:a3989edf0a4a draft

planemo upload for repository https://github.com/Onnodg/Naturalis_NLOOR/tree/main/NLOOR_scripts/process_annotations_tool commit c944fd5685f295acba06679e85b67973c173b137
author onnodg
date Tue, 14 Oct 2025 09:08:30 +0000
parents
children
line wrap: on
line source

"""Galaxy-compatible BLAST annotation processor.

This script processes a single annotated BLAST file along with a FASTA file
containing unannotated reads. It generates multiple types of outputs for
integration with Galaxy workflows:

- E-value distribution plots (PNG)
- Taxonomic composition reports (text)
- Circular taxonomy diagram data (JSON)
- Header annotations with merged and per-read information (Excel)
- Annotation statistics summary (text)

Main workflow:
1. Parse command-line arguments.
2. Load annotated BLAST results and unannotated FASTA headers.
3. Group BLAST hits per query (q_id), filter by thresholds.
4. Resolve taxonomic conflicts with uncertainty rules.
5. Generate requested outputs.

Notes:
- Headers in BLAST and FASTA should correspond.
- Uses matplotlib, pandas, and openpyxl for visualization and reporting.
"""

import argparse
import io
import json
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import openpyxl
import pandas as pd

# Default taxonomic levels: one-letter rank codes from kingdom down to species
# (K, P, C, O, F, G, S), in hierarchy order.
TAXONOMIC_LEVELS = list("KPCOFGS")

def parse_arguments(arg_list=None):
    """Parse command line arguments for the BLAST annotation processor.

    :param arg_list: Argument strings to parse. ``None`` makes argparse read
        ``sys.argv[1:]``.
    :type arg_list: list[str] | None
    :return: Namespace with the input paths, optional output paths, thresholds
        and the ``use_counts`` switch.
    :rtype: argparse.Namespace
    """
    parser = argparse.ArgumentParser(
        description='Process BLAST annotation results for Galaxy'
    )
    # Required inputs
    parser.add_argument('--input-anno', required=True,
                        help='Annotated BLAST output file (tabular format)')
    parser.add_argument('--input-unanno', required=True,
                        help='Unannotated sequences file (FASTA format)')
    # Optional outputs (each output is only produced when its path is given)
    parser.add_argument('--eval-plot',
                        help='Output path for E-value plot (PNG)')
    parser.add_argument('--taxa-output',
                        help='Output path for taxa report (tabular)')
    parser.add_argument('--circle-data',
                        help='Output path for circular taxonomy data (txt)')
    parser.add_argument('--header-anno',
                        help='Output path for header annotations (tabular/xlsx)')
    parser.add_argument('--anno-stats',
                        help='Output path for annotation statistics (txt)')
    # Parameters
    parser.add_argument('--uncertain-threshold', type=float, default=0.9,
                        help='Threshold for resolving taxonomic conflicts (default: 0.9)')
    parser.add_argument('--eval-threshold', type=float, default=1e-10,
                        help='E-value threshold for filtering results (default: 1e-10)')
    # '--use-counts' is a store_true flag that already defaults to True, so on
    # its own it can never switch the behaviour off. It is kept for backward
    # compatibility; '--no-use-counts' writes False into the same destination.
    parser.add_argument('--use-counts', dest='use_counts', action='store_true', default=True,
                        help='Use read counts in circular data')
    parser.add_argument('--no-use-counts', dest='use_counts', action='store_false',
                        help='Count each unique taxon only once in circular data')

    return parser.parse_args(arg_list)


def list_to_string(x):
    """
    Render ``x`` as text, joining sequence-like values with ``", "``.

    Lists, pandas Series and numpy arrays are joined element by element after
    converting each element with ``str()``. Any other value is returned as
    ``str(x)``.
    """
    if not isinstance(x, (list, pd.Series, np.ndarray)):
        return str(x)
    return ", ".join(str(item) for item in x)

def make_eval_plot(e_val_sets, output_path):
    """
    Create a bar plot of inverse E-values per read.

    Each read gets a group of bars starting at x-position ``i`` (its index
    after sorting), one bar per E-value, spaced 0.29 apart. Bar heights are
    ``1 / e_value`` on a logarithmic y-axis; an E-value of 0 is drawn with
    height 0. A read with a single E-value is padded with an extra 0 entry.
    The colour list ``['blue', 'red']`` is passed to ``plt.bar`` for the whole
    flattened sequence of bars.

    :param e_val_sets: One collection of E-values per read; every value must be
        convertible with ``float()``.
    :type e_val_sets: list[set[str]]
    :param output_path: Path to save the plot (PNG).
    :type output_path: str
    :return: None
    :rtype: None
    """
    if not e_val_sets:
        print("No E-values to plot")
        return None

    # Convert to floats and sort within each read, then order the reads themselves.
    processed_sets = sorted([sorted(float(e_val) for e_val in e_val_set)
                             for e_val_set in e_val_sets])

    bar_positions = []
    bar_heights = []

    for i, e_set in enumerate(processed_sets):
        if len(e_set) == 1:
            e_set.append(0)
        for j, e_val in enumerate(e_set):
            bar_positions.append(i + j * 0.29)
            # 1/0 is undefined, so a zero E-value keeps height 0.
            bar_heights.append(1 / e_val if e_val != 0 else e_val)

    fig = plt.figure(figsize=(21, 9))
    try:
        plt.bar(bar_positions, bar_heights, width=0.29, color=['blue', 'red'])

        plt.yscale('log')
        plt.grid(axis='y', linestyle='--', color='gray', alpha=0.7)
        plt.ylabel("e-values")
        plt.xticks(ticks=range(len(processed_sets)),
                   labels=[f'{i + 1}' for i in range(len(processed_sets))])

        plt.tight_layout()
        plt.savefig(output_path, format='png')
    finally:
        # Release the figure even when drawing or saving fails, so repeated
        # calls do not accumulate open pyplot figures.
        plt.close(fig)
    return None


def calculate_annotation_stats(anno_count, unanno_file_path, unique_anno_count, total_unique_count):
    """
    Summarise how many sequences received an annotation.

    The sequence total is the number of FASTA header lines (lines starting
    with ``'>'``) in ``unanno_file_path``. Both percentages fall back to ``0``
    when their denominator is not positive.

    :param anno_count: Number of annotated sequences.
    :type anno_count: int
    :param unanno_file_path: FASTA file whose header lines are counted.
    :type unanno_file_path: str
    :param unique_anno_count: Number of unique annotated sequences.
    :type unique_anno_count: int
    :param total_unique_count: Total number of unique sequences.
    :type total_unique_count: int
    :return: Mapping with the keys ``'percentage_annotated'``,
        ``'annotated_sequences'``, ``'total_sequences'`` (header lines counted
        in the FASTA file), ``'percentage_unique_annotated'``,
        ``'unique_annotated'`` and ``'total_unique'``.
    :rtype: dict[str, float | int]
    """
    def _percent(part, whole):
        # Avoid division by zero for empty inputs.
        return (part / whole * 100) if whole > 0 else 0

    with open(unanno_file_path, 'r') as fasta_handle:
        n_headers = sum(1 for row in fasta_handle if row.startswith('>'))

    return {
        'percentage_annotated': _percent(anno_count, n_headers),
        'annotated_sequences': anno_count,
        'total_sequences': n_headers,
        'percentage_unique_annotated': _percent(unique_anno_count, total_unique_count),
        'unique_annotated': unique_anno_count,
        'total_unique': total_unique_count,
    }

def _choose_taxon(level_counts: dict, threshold: float):
    """
    Determine the most representative taxonomic level based on counts and a threshold.

    This function evaluates a dictionary of taxonomic level counts and returns the
    level name that holds a strict majority. If no level reaches the threshold, or if
    there is a tie for the highest count, the function returns ``None``.

    :param level_counts: A dictionary mapping taxonomic level names (e.g., "species",
        "genus") to their respective counts.
    :type level_counts: dict[str, int]
    :param threshold: The minimum fraction (between 0 and 1) of the total counts that
        the top taxonomic level must have to be considered dominant.
    :type threshold: float

    :return: The name of the taxonomic level that meets or exceeds the threshold and
        is not tied with another level. Returns ``None`` if no level qualifies or
        if there is a tie.
    :rtype: str | None

    :example:
        _choose_taxon({"species": 6, "genus": 3, "family": 1}, 0.5)
        'species'
        _choose_taxon({"species": 4, "genus": 4}, 0.6)
        None
        _choose_taxon({"species": 2, "genus": 3}, 0.6)
        'genus'
    """
    total = sum(level_counts.values())
    if total == 0:
        return None
    items = sorted(level_counts.items(), key=lambda x: x[1], reverse=True)
    top_name, top_count = items[0]
    # tie -> no clear majority
    if sum(1 for _, c in items if c == top_count) > 1:
        return None
    if top_count / total >= threshold:
        return top_name
    return None


def _resolve_conflicts_recursive(conflicting: list, level_idx: int, threshold: float, max_levels: int, prefix=None):
    """
    Recursively resolve taxonomic classification conflicts across multiple levels.

    This function attempts to resolve conflicts between multiple taxonomic paths by
    recursively evaluating each taxonomic level. At each level, it determines whether
    there is a clear majority (based on a specified threshold) for a taxon name. If
    a majority is found, the function continues to the next level using only the
    taxa that match that name. If no majority exists, the function constructs
    "uncertain" outputs to indicate ambiguous classification.

    :param conflicting: List of tuples containing full taxonomic paths and their
        associated counts, e.g. ``[("Bacteria / Proteobacteria / Gammaproteobacteria", 10),
        ("Bacteria / Proteobacteria / Alphaproteobacteria", 5)]``.
    :type conflicting: list[tuple[str, int]]
    :param level_idx: The current taxonomic level index being evaluated (0-based).
    :type level_idx: int
    :param threshold: The minimum fraction (between 0 and 1) required for a strict
        majority decision at a given taxonomic level.
    :type threshold: float
    :param max_levels: The maximum number of taxonomic levels to evaluate before
        stopping recursion.
    :type max_levels: int
    :param prefix: A list of already-resolved taxon names up to the previous level.
        Defaults to ``None``, in which case an empty list is used.
    :type prefix: list[str] | None

    :return: A tuple containing:
        - **short_output** (str): A simplified taxonomic path with "Uncertain taxa"
          appended if no consensus was reached at the current level.
        - **full_output** (str): The full taxonomic path, filled with additional
          "Uncertain taxa" entries up to the minimal conflicting level depth.
    :rtype: tuple[str, str]

    :example:
        conflicts = [
        ...     ("Bacteria / Proteobacteria / Gammaproteobacteria", 10),
        ...     ("Bacteria / Proteobacteria / Alphaproteobacteria", 5)
        ... ]
        _resolve_conflicts_recursive(conflicts, level_idx=0, threshold=0.6, max_levels=5)
        ('Bacteria / Proteobacteria / Uncertain taxa',
         'Bacteria / Proteobacteria / Uncertain taxa / Uncertain taxa')
    """
    if prefix is None:
        prefix = []

    if not conflicting:
        return "", ""

    # split paths (keeps sync)
    conflicting_levels = [t.split(" / ") for t, _ in conflicting]
    min_conflicting_level = min(len(p) for p in conflicting_levels)

    # If only one full path remains -> return it (both short & full)
    if len(conflicting) == 1:
        return conflicting[0][0], conflicting[0][0]

    # If we've reached deepest comparable level or max_levels -> pick most supported full path
    if level_idx >= min_conflicting_level or level_idx >= max_levels:
        full_counts = defaultdict(int)
        for t, c in conflicting:
            full_counts[t] += c
        chosen_full = max(full_counts.items(), key=lambda x: x[1])[0]
        return chosen_full, chosen_full

    # Count names at this level
    level_counts = defaultdict(int)
    for taxon, count in conflicting:
        parts = taxon.split(" / ")
        if level_idx < len(parts):
            level_counts[parts[level_idx]] += count

    # If everyone agrees at this level, append that name and go deeper
    if len(level_counts) == 1:
        name = next(iter(level_counts))
        return _resolve_conflicts_recursive(conflicting, level_idx + 1, threshold, max_levels, prefix + [name])

    # Check for a strict majority at this level
    chosen_level_name = _choose_taxon(level_counts, threshold)
    if chosen_level_name is not None:
        # Keep only taxa that match the chosen name at this level and go deeper
        filtered = [(t, c) for (t, c) in conflicting
                    if level_idx < len(t.split(" / ")) and t.split(" / ")[level_idx] == chosen_level_name]
        return _resolve_conflicts_recursive(filtered, level_idx + 1, threshold, max_levels, prefix + [chosen_level_name])

    # No majority at this level -> construct uncertain outputs.
    # short: prefix + 'Uncertain taxa' (stop here)
    short_path = prefix + ['Uncertain taxa']
    short_output = " / ".join(short_path)

    # full: prefix + 'Uncertain taxa' + fill until min_conflicting_level
    full_path = short_path + ['Uncertain taxa'] * (min_conflicting_level - (level_idx + 1))
    full_output = " / ".join(full_path)

    return short_output, full_output


def resolve_taxon_majority(taxon_counts: dict, threshold: float = 0.90, max_levels: int = 7):
    """
    Resolve the consensus taxonomic path for one read.

    This is a thin entry point: the ``(path, count)`` pairs of
    ``taxon_counts`` are handed to ``_resolve_conflicts_recursive``, starting
    at level 0 with an empty prefix.

    :param taxon_counts: Full taxonomic paths (names joined by ``" / "``)
        mapped to their occurrence counts, e.g.
        ``{"Bacteria / Proteobacteria / Gammaproteobacteria": 10,
           "Bacteria / Proteobacteria / Alphaproteobacteria": 5}``.
    :type taxon_counts: dict[str, int]
    :param threshold: Minimum fraction (0-1) needed for a majority at each level.
        Defaults to ``0.90``.
    :type threshold: float, optional
    :param max_levels: Level index at which resolution stops and the most
        supported full path is chosen. Defaults to ``7``.
    :type max_levels: int, optional
    :return: ``(short_output, full_output)``. The short form may end in one
        ``"Uncertain taxa"``; the full form is padded with ``"Uncertain taxa"``
        up to the shortest conflicting depth. ``("", "")`` for an empty mapping.
    :rtype: tuple[str, str]
    """
    # Keep each path paired with its own count for the recursive resolution.
    path_count_pairs = [(path, count) for path, count in taxon_counts.items()]
    return _resolve_conflicts_recursive(path_count_pairs, 0, threshold, max_levels)


def process_taxa_output(taxa_dicts, output_path, uncertain_threshold):
    """
    Generate a tabular taxa report from a list of read taxonomies.

    Each read is resolved with ``resolve_taxon_majority`` (short output), and
    every prefix of the resolved path is counted once for that read. The
    report, similar in style to Kraken2 output, contains:

    - A first line with the number of reads whose path ends in
      ``"Uncertain taxa"``, per taxonomic rank code.
    - A column header line.
    - One line per taxon path with the percentage of the sample total, the read
      count, the sample total, the rank code (``"U"`` when deeper than
      ``TAXONOMIC_LEVELS``) and the taxon name indented two spaces per level.

    Rows are emitted depth-first: each taxon is directly followed by its
    descendants, so the indentation reads as a tree.

    :param taxa_dicts: One dict per read mapping full taxonomic paths (names
        joined by ``" / "``) to how often they were assigned to that read.
    :type taxa_dicts: list[dict[str, int]]
    :param output_path: Path to save the tabular taxa report.
    :type output_path: str
    :param uncertain_threshold: Majority threshold used when resolving
        conflicting assignments. Levels without a majority become
        ``"Uncertain taxa"``.
    :type uncertain_threshold: float
    :return: None. The report is written to ``output_path``.
    :rtype: None
    """
    uncertain_dict = {level: 0 for level in TAXONOMIC_LEVELS}
    aggregated_counts = defaultdict(int)

    for read_taxa in taxa_dicts:
        resolved_taxon, _ = resolve_taxon_majority(read_taxa, uncertain_threshold)

        # Count the read once for every prefix of its resolved path.
        levels = resolved_taxon.split(" / ")
        for depth in range(1, len(levels) + 1):
            aggregated_counts[" / ".join(levels[:depth])] += 1

    # Sample total = reads counted at the top (single-name) level.
    total_count = sum(count for taxonomy, count in aggregated_counts.items()
                      if " / " not in taxonomy)

    report_lines = []

    # Sort on the split path rather than the raw string: plain string order
    # places a sibling such as "Foo (x)" ('(' sorts before '/') between "Foo"
    # and "Foo / ...", which breaks the indented tree layout.
    for taxonomy, count in sorted(aggregated_counts.items(), key=lambda item: item[0].split(" / ")):
        levels = taxonomy.split(" / ")
        depth = len(levels)
        indent = " " * (depth - 1) * 2
        taxon_name = levels[-1]
        taxon_level_code = TAXONOMIC_LEVELS[depth - 1] if depth <= len(TAXONOMIC_LEVELS) else "U"
        percentage = (count / total_count) * 100 if total_count > 0 else 0

        report_lines.append(
            f"{percentage:.2f}\t{count}\t{total_count}\t{taxon_level_code}\t{indent}{taxon_name}"
        )

        if taxon_name == 'Uncertain taxa':
            # "U" (deeper than TAXONOMIC_LEVELS) is not pre-seeded; count it
            # rather than raising KeyError.
            uncertain_dict[taxon_level_code] = uncertain_dict.get(taxon_level_code, 0) + count

    # Write output
    with open(output_path, 'w') as f:
        f.write("Uncertain count per taxonomie level" + str(uncertain_dict) + '\n')
        f.write('percentage_rooted\tnumber_rooted\ttotal_num\ttaxon_level\tindentificatie\n')
        f.write("\n".join(report_lines))


def _header_report_lines(taxa_dicts, headers, e_values, uncertain_threshold,
                         identity_list, coverage_list, source_list, bitscore_list):
    """Build one tab-separated report row per read; reads with misaligned metric lists are skipped."""
    report_lines = []
    for i, read_taxa in enumerate(taxa_dicts):
        _, resolved_taxon_long = resolve_taxon_majority(read_taxa, uncertain_threshold)
        try:
            # Only the first value of each per-read metric list is reported.
            e_val = e_values[i][0] if e_values[i] else "N/A"
            identity = identity_list[i][0] if identity_list[i] else "N/A"
            coverage = coverage_list[i][0] if coverage_list[i] else "N/A"
            source = source_list[i][0] if source_list[i] else "N/A"
            bitscore = bitscore_list[i][0] if bitscore_list[i] else "N/A"
        except (IndexError, TypeError) as e:
            print(f"Mismatch while extracting values: {e}. Check that your annotated and unannotated input files correspond correctly.")
            continue

        header = headers[i] if i < len(headers) else "Header missing"

        # Headers are expected to end in "(<count>)"; otherwise keep the whole
        # header and fall back to a count of 1.
        try:
            header_base, count_str = header.rsplit("(", 1)
            count = int(count_str.rstrip(")"))
        except ValueError as e:
            print(f'Failed extracting count: {e}')
            header_base = header
            count = 1

        report_lines.append(
            f'{header_base}\t{e_val}\t{identity}\t{coverage}\t{bitscore}\t{count}\t{source}\t{resolved_taxon_long}')
    return report_lines


def _expand_taxa_columns(df):
    """Append one column per rank split from the ``taxa`` path and sort by species/genus/family."""
    taxa_split = df["taxa"].str.split(" / ", expand=True)
    rank_names = ["kingdom", "phylum", "class", "order", "family", "genus", "species"]
    n_levels = taxa_split.shape[1]
    # Paths deeper than seven ranks previously crashed the column assignment
    # (too few names); label any extra columns generically.
    level_names = rank_names[:n_levels] + [f"level_{i + 1}" for i in range(len(rank_names), n_levels)]
    taxa_split.columns = level_names

    df_individual = pd.concat([df, taxa_split], axis=1)
    # Short paths yield fewer rank columns, so only sort on the ones present.
    sort_columns = [level for level in ("species", "genus", "family") if level in level_names]
    if sort_columns:
        df_individual = df_individual.sort_values(sort_columns, ascending=True)
    return df_individual, level_names


def _merge_by_taxa(df_individual, level_names):
    """Aggregate per-read rows into one row per unique ``taxa`` path."""
    agg_dict = {
        "e_value": lambda x: f"{x.min()}–{x.max()}",
        "identity percentage": lambda x: list_to_string(list(x.unique())),
        "coverage": lambda x: list_to_string(list(x.unique())),
        "bitscore": lambda x: list_to_string(list(x.unique())),
        "count": "sum",
        "source": "first",
    }
    # Rank columns are identical within one taxa path, so any row's value works.
    for level in level_names:
        agg_dict[level] = "first"

    df_merged = df_individual.groupby(["taxa"], as_index=False).agg(agg_dict)

    sort_columns = [level for level in ("species", "genus", "family") if level in df_merged.columns]
    if sort_columns:
        df_merged = df_merged.sort_values(sort_columns, ascending=True)
    return df_merged


def _autofit_column_widths(worksheet, max_width=50):
    """Size each worksheet column to its longest cell text plus 2, capped at ``max_width``."""
    for column in worksheet.columns:
        # Empty cells (value None) count as zero width, not as the text "None".
        longest = max((len(str(cell.value)) for cell in column if cell.value is not None), default=0)
        worksheet.column_dimensions[column[0].column_letter].width = min(longest + 2, max_width)


def _write_header_workbook(df_individual, df_merged, output_path):
    """Write both report sheets to ``output_path`` via an intermediate ``<output_path>.xlsx`` file."""
    # The workbook is written under an .xlsx name and then renamed onto the
    # requested path (presumably because the Galaxy output path lacks an
    # .xlsx extension).
    temp_path = output_path + ".xlsx"
    out_dir = os.path.dirname(temp_path)
    if out_dir:  # os.makedirs("") raises for bare file names
        os.makedirs(out_dir, exist_ok=True)
    with pd.ExcelWriter(temp_path, engine='openpyxl', mode='w') as writer:
        df_individual.to_excel(writer, sheet_name='Individual_Reads', index=False)
        df_merged.to_excel(writer, sheet_name='Merged_by_Taxa', index=False)
        for worksheet in writer.sheets.values():
            _autofit_column_widths(worksheet)
    os.replace(temp_path, output_path)


def process_header_annotations(taxa_dicts, headers, e_values, output_path, uncertain_threshold,
                               identity_list, coverage_list, source_list, bitscore_list):
    """
    Generate an Excel report with per-read and aggregated taxonomic annotations.

    Each read's taxa are resolved with ``resolve_taxon_majority`` (full
    output). The first e-value, identity, coverage, bitscore and source of the
    read are reported together with the count parsed from the header's
    trailing ``"(<count>)"``. The workbook has two sheets:

    1. **Individual_Reads** – one row per read, with the resolved path also
       split into rank columns (kingdom, phylum, class, ...).
    2. **Merged_by_Taxa** – one row per unique path: e-value range, unique
       identity/coverage/bitscore values, summed counts and the first source.

    Rows are parsed through an in-memory TSV, so pandas infers numeric columns
    and turns ``"N/A"`` into NaN. No temporary file is written, and non-ASCII
    header text is no longer mis-decoded. Column widths are auto-adjusted
    (capped at 50).

    :param taxa_dicts: One dict per read mapping full taxonomic paths to counts.
    :type taxa_dicts: list[dict[str, int]]
    :param headers: Sequence headers, one per read, ideally ending in ``"(<count>)"``.
    :type headers: list[str]
    :param e_values: E-values per read; only the first value is used.
    :type e_values: list[list[float]]
    :param output_path: Path of the Excel file to create.
    :type output_path: str
    :param uncertain_threshold: Majority threshold used when resolving
        conflicting assignments. Unresolved levels become ``"Uncertain taxa"``.
    :type uncertain_threshold: float
    :param identity_list: Percent identity values per read.
    :type identity_list: list[list[str]]
    :param coverage_list: Coverage values per read.
    :type coverage_list: list[list[str]]
    :param source_list: Source identifiers per read.
    :type source_list: list[list[str]]
    :param bitscore_list: Bitscore values per read.
    :type bitscore_list: list[list[str]]
    :return: None. Results are written to ``output_path``. When no rows are
        produced, a message is printed and nothing is written.
    :rtype: None
    :raises PermissionError: If the output Excel file cannot be written.

    Reads whose metric lists are missing or misaligned are skipped with a
    printed message. Headers without a parsable count get a count of 1.
    """
    report_lines = _header_report_lines(taxa_dicts, headers, e_values, uncertain_threshold,
                                        identity_list, coverage_list, source_list, bitscore_list)

    tsv_text = ('header\te_value\tidentity percentage\tcoverage\tbitscore\tcount\tsource\ttaxa\n'
                + "\n".join(report_lines))
    df = pd.read_csv(io.StringIO(tsv_text), sep='\t')
    if df.empty:
        print("Dataframe empty, no annotation results")
        return None

    df_individual, level_names = _expand_taxa_columns(df)
    df_merged = _merge_by_taxa(df_individual, level_names)
    _write_header_workbook(df_individual, df_merged, output_path)
    return None


def create_circle_diagram(taxa_dicts, output_path, use_counts, uncertain_threshold):
    """
    Generate hierarchical JSON data for a circular taxonomy diagram.

    Each read is resolved with ``resolve_taxon_majority`` (full output, which
    may contain ``"Uncertain taxa"`` placeholders), and every prefix of the
    resolved path is tallied.

    :param taxa_dicts: One dict per read mapping full taxonomic paths (taxon
        names joined by ``" / "``, e.g. ``"Bacteria / Proteobacteria"``) to
        their counts.
    :type taxa_dicts: list[dict[str, int]]
    :param output_path: File path for the JSON output.
    :type output_path: str
    :param use_counts: If ``True``, every read adds 1 to each prefix of its
        resolved path. If ``False``, a resolved path is only counted the first
        time it is seen.
    :type use_counts: bool
    :param uncertain_threshold: Majority threshold (0-1) for resolving
        conflicting assignments.
    :type uncertain_threshold: float
    :return: None. ``output_path`` receives a JSON list with one entry per
        level in ``TAXONOMIC_LEVELS``, each ``{"labels": [...], "sizes": [...]}``,
        ordered by path. Paths deeper than ``TAXONOMIC_LEVELS`` are left out.
    :rtype: None
    """
    prefix_totals = defaultdict(int)
    already_counted = set()

    for read_taxa in taxa_dicts:
        _, full_path = resolve_taxon_majority(read_taxa, uncertain_threshold)

        # Unique mode: skip paths that were already tallied once.
        if not use_counts:
            if full_path in already_counted:
                continue
            already_counted.add(full_path)

        names = full_path.split(" / ")
        for depth in range(1, len(names) + 1):
            prefix_totals[" / ".join(names[:depth])] += 1

    n_levels = len(TAXONOMIC_LEVELS)
    circle_data = [{"labels": [], "sizes": []} for _ in range(n_levels)]

    for taxonomy in sorted(prefix_totals):
        names = taxonomy.split(" / ")
        depth = len(names)
        if depth > n_levels:
            continue  # deeper than the configured ranks: not drawn
        ring = circle_data[depth - 1]
        ring["labels"].append(names[-1].strip())
        ring["sizes"].append(prefix_totals[taxonomy])

    with open(output_path, "w") as f:
        json.dump(circle_data, f, indent=2)


def _read_unannotated_fasta(unanno_file_path):
    """Read FASTA header IDs (in file order) and the summed ``count=`` values.

    :param unanno_file_path: Path to the FASTA file with unannotated reads.
    :type unanno_file_path: str
    :return: Tuple ``(headers, total_count)``. ``headers`` holds the first
        whitespace-delimited token of every ``>`` line with ``>`` stripped;
        ``total_count`` is the sum of every integer ``count=<n>`` value found
        in the file. A missing file yields ``([], 0)`` plus a warning.
    :rtype: tuple[list[str], int]
    """
    headers = []
    total_count = 0
    if not os.path.exists(unanno_file_path):
        print(f"Warning: Unannotated file {unanno_file_path} not found")
        return headers, total_count

    # Single pass: collect header IDs and sum the ``count=`` annotations.
    with open(unanno_file_path, 'r') as f:
        for line in f:
            if line.startswith('>'):
                headers.append(line.split()[0].strip('>'))
            if "count=" in line:
                try:
                    total_count += int(line.split("count=")[1].split(";")[0])
                except ValueError:
                    # Previously an unhandled crash; now reported and skipped.
                    print(f"Warning: could not parse count in line: {line.strip()}")
    print(f"Found {len(headers)} headers in unannotated file")
    return headers, total_count


def _group_blast_hits_by_query(anno_file_path, eval_threshold):
    """Group valid BLAST hits by query ID, preserving first-seen order.

    Skipped lines: comments (``#``), lines with fewer than 7 tab-separated
    columns, lines whose E-value (column 7) or bitscore (column 8) is not
    numeric, hits with E-value above ``eval_threshold``, and hits whose taxon
    (last column) is rejected by ``is_invalid_taxon``.

    :param anno_file_path: Path to the annotated BLAST tabular file.
    :type anno_file_path: str
    :param eval_threshold: Maximum accepted E-value.
    :type eval_threshold: float
    :return: Insertion-ordered mapping ``q_id -> list of hit dicts``, or
        ``None`` if the file could not be read (an error is printed).
    :rtype: dict[str, list[dict]] | None
    """
    blast_groups = {}
    try:
        with open(anno_file_path, 'r') as f:
            for line in f:
                if line.startswith("#"):
                    continue
                parts = line.strip().split('\t')
                if len(parts) < 7:
                    continue

                q_id = parts[0]
                try:
                    e_val = float(parts[6])
                except ValueError:
                    continue
                if e_val > eval_threshold:
                    continue

                taxon = parts[-1].strip()
                if is_invalid_taxon(taxon):
                    print(f"Skipping {q_id}: invalid taxon: {taxon}")
                    continue

                # A missing/non-numeric bitscore used to raise here and abort
                # the whole file; skip only the offending hit instead.
                try:
                    bitscore = float(parts[7])
                except (IndexError, ValueError):
                    print(f"Skipping {q_id}: missing or invalid bitscore")
                    continue

                blast_groups.setdefault(q_id, []).append({
                    'e_val': e_val,
                    'identity': parts[4],
                    'cov': parts[5],
                    'bitscore': bitscore,
                    'source': parts[8] if len(parts) > 8 else '',
                    'taxon': taxon,
                })
    except (OSError, ValueError) as e:  # ValueError covers UnicodeDecodeError
        print(f"Error reading BLAST file {anno_file_path}: {e}")
        return None
    return blast_groups


def _summarize_read_hits(hits):
    """Collapse all BLAST hits of one read into a best-hit summary.

    The best hit has the highest bitscore; ties are broken by the lowest
    E-value. Taxa of every hit tied with the best (same bitscore and same
    E-value) are counted in ``best_taxa_dict``.

    :param hits: Non-empty list of hit dicts built by
        ``_group_blast_hits_by_query``.
    :type hits: list[dict]
    :return: Dict with ``e_values`` (set of the str E-values of all hits) and
        ``best_eval``, ``best_identity``, ``best_coverage``, ``best_source``,
        ``best_bitscore`` and ``best_taxa_dict`` taken from the best hit(s).
    :rtype: dict
    """
    summary = {
        'e_values': set(),
        # +/-inf guarantee the first hit always becomes the initial best.
        'best_eval': float('inf'),
        'best_identity': '',
        'best_coverage': '',
        'best_source': '',
        'best_bitscore': float('-inf'),
        'best_taxa_dict': {},
    }
    for hit in hits:
        e_val = hit['e_val']
        bitscore = hit['bitscore']
        taxon = hit['taxon']
        summary['e_values'].add(str(e_val))

        if (bitscore > summary['best_bitscore'] or
                (bitscore == summary['best_bitscore'] and e_val < summary['best_eval'])):
            # Strictly better hit: becomes the new best and resets the tie set.
            summary['best_bitscore'] = bitscore
            summary['best_eval'] = e_val
            summary['best_identity'] = hit['identity']
            summary['best_coverage'] = hit['cov']
            summary['best_source'] = hit['source']
            summary['best_taxa_dict'] = {taxon: 1}
        elif bitscore == summary['best_bitscore'] and e_val == summary['best_eval']:
            # Equally good hit: record its taxon alongside the current best.
            best_taxa = summary['best_taxa_dict']
            best_taxa[taxon] = best_taxa.get(taxon, 0) + 1
    return summary


def _parse_header_count(header):
    """Return the integer between the first ``(`` and the next ``)`` in *header*, or 0."""
    try:
        return int(header.split('(')[1].split(')')[0])
    except (IndexError, ValueError):
        return 0


def process_single_file(anno_file_path, unanno_file_path, args):
    """
    Process a single annotated BLAST file and generate all requested analytical outputs.

    BLAST hits are grouped per query, and the reads listed in the unannotated
    FASTA file are walked in order. For every read with hits, the best-scoring
    match(es) are selected. Plots, tables, diagram data and statistics are then
    written according to ``args``.

    :param anno_file_path: Path to the annotated BLAST results file.
    :type anno_file_path: str
    :param unanno_file_path: Path to a FASTA file containing unannotated sequences.
        It determines the read order and supplies the total ``count=`` sum.
    :type unanno_file_path: str
    :param args: Parsed command-line arguments. The following attributes are read:
        - **eval_threshold** (*float*): Maximum allowed E-value for valid hits.
        - **eval_plot** (*str | None*): Output path for E-value distribution plot.
        - **taxa_output** (*str | None*): Output path for taxonomic summary table.
        - **header_anno** (*str | None*): Output path for annotated header table.
        - **circle_data** (*str | None*): Output path for circular diagram JSON data.
        - **anno_stats** (*str | None*): Output path for annotation statistics file.
        - **use_counts** (*bool*): Whether to use read counts or unique taxa only.
        - **uncertain_threshold** (*float*): Majority threshold for resolving taxonomic ambiguity.
    :type args: argparse.Namespace

    :return: None. Output files are written to the paths specified in ``args``.
    :rtype: None

    :notes:
        - A missing annotated file prints an error and returns without writing
          outputs. A missing FASTA file prints a warning, and the outputs are
          then produced from zero reads.
        - Unreadable BLAST files print an error and return without outputs.
        - The best hit per read is the one with the highest bitscore. Ties are
          broken by the lowest E-value.
        - The total count is the sum of the ``count=`` fields in the FASTA
          file. The annotated count is the sum of the integers in parentheses
          in the annotated read headers.
        - Duplicate FASTA headers are processed once, so the per-read lists
          passed to the output writers stay aligned.
        - Progress and warnings are printed to standard output.
    """
    if not os.path.exists(anno_file_path):
        print(f"Error: Input file {anno_file_path} not found")
        return

    unanno_headers_ordered, total_unique_count = _read_unannotated_fasta(unanno_file_path)

    blast_groups = _group_blast_hits_by_query(anno_file_path, args.eval_threshold)
    if blast_groups is None:
        return

    # BLAST q_ids that have no matching FASTA header
    unanno_set = set(unanno_headers_ordered)
    extra_blast_qids = [q for q in blast_groups if q not in unanno_set]
    if extra_blast_qids:
        print(f"Note: {len(extra_blast_qids)} BLAST q_ids not in FASTA: {extra_blast_qids[:10]}...")

    # One summary per read, in FASTA order. ``headers`` must stay aligned
    # index-for-index with ``read_data`` for the header annotation output.
    read_data = []
    headers = []
    seen_reads = set()
    annotated_unique_count = 0
    for header in unanno_headers_ordered:
        if header not in blast_groups:
            print(f"Skipping {header}: no BLAST hits")
            continue
        if header in seen_reads:
            print(f"Skipping {header}: duplicate header in FASTA")
            continue
        seen_reads.add(header)
        headers.append(header)
        read_data.append(_summarize_read_hits(blast_groups[header]))
        annotated_unique_count += _parse_header_count(header)

    best_taxa_dicts = [read['best_taxa_dict'] for read in read_data]

    # Generate outputs based on arguments
    if args.eval_plot:
        e_val_sets = [read['e_values'] for read in read_data]
        make_eval_plot(e_val_sets, args.eval_plot)

    if args.taxa_output:
        process_taxa_output(best_taxa_dicts, args.taxa_output, args.uncertain_threshold)

    if args.header_anno:
        # Each metric is passed as a list of one-element lists (one per read).
        e_values_for_headers = [[read['best_eval']] for read in read_data]
        identity_for_headers = [[read['best_identity']] for read in read_data]
        coverage_for_headers = [[read['best_coverage']] for read in read_data]
        source_for_headers = [[read['best_source']] for read in read_data]
        bitscore_for_headers = [[read['best_bitscore']] for read in read_data]
        process_header_annotations(best_taxa_dicts, headers, e_values_for_headers,
                                   args.header_anno, args.uncertain_threshold,
                                   identity_for_headers, coverage_for_headers,
                                   source_for_headers, bitscore_for_headers)

    if args.circle_data:
        create_circle_diagram(best_taxa_dicts, args.circle_data, args.use_counts,
                              args.uncertain_threshold)

    if args.anno_stats:
        stats = calculate_annotation_stats(len(read_data), unanno_file_path,
                                           annotated_unique_count, total_unique_count)
        with open(args.anno_stats, 'w') as f:
            f.write('metric\tvalue\n')
            for key, value in stats.items():
                f.write(f'{key}\t{value}\n')

def is_invalid_taxon(taxon: str) -> bool:
    """
    Decide whether a taxonomic path is unreliable and should be dropped.

    Used to clean BLAST/classification output before aggregation and
    visualization.

    :param taxon: Full taxonomic path, levels separated by ``" / "``.
    :type taxon: str
    :return: ``True`` when the path matches any rejection rule below,
        ``False`` otherwise.
    :rtype: bool

    :criteria:
        The path is rejected when (case-insensitively):
        - it contains ``"invalid taxid"`` or ``"genbank invalid taxon"``;
        - it contains ``"environmental sample"``;
        - any level other than the last one contains ``"unknown"``
          (e.g. *unknown genus / Homo sapiens*); an ``unknown`` final
          level is accepted.
    """
    lowered = taxon.lower()
    rejected_markers = ("invalid taxid", "genbank invalid taxon", "environmental sample")
    if any(marker in lowered for marker in rejected_markers):
        return True

    # "unknown" is tolerated only as the deepest level; anything more
    # specific below an unknown rank makes the path untrustworthy.
    *ancestor_levels, _deepest = taxon.split(" / ")
    return any("unknown" in level.lower() for level in ancestor_levels)


def main(arg_list=None):
    """
    Entry point for Galaxy-compatible BLAST annotation processing.

    Parses the arguments and creates the parent directory of every requested
    output file. It then runs ``process_single_file``.

    :param arg_list: Optional list of command-line arguments to override
            ``sys.argv``. Primarily used for testing or programmatic execution.
            If ``None``, arguments are read directly from the command line.
    :type arg_list: list[str] | None
    :return: None
    :rtype: None
    """
    args = parse_arguments(arg_list)
    output_files = [args.eval_plot, args.header_anno, args.taxa_output,
                    args.circle_data, args.anno_stats]
    for output_file in output_files:
        if not output_file:
            continue
        parent_dir = os.path.dirname(output_file)
        # A bare file name has no directory component; os.makedirs('') would
        # raise FileNotFoundError, so only create real parent directories.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    process_single_file(args.input_anno, args.input_unanno, args)
    print("Processing completed successfully")


if __name__ == "__main__":
    # Script entry point: with no explicit list, arguments come from the command line.
    main(arg_list=None)