# view tests/test_cdhit_analysis.py @ 4:e64af72e1b8f draft default tip
#
# planemo upload for repository https://github.com/Onnodg/Naturalis_NLOOR/tree/main/NLOOR_scripts/process_clusters_tool commit 4017d38cf327c48a6252e488ba792527dae97a70-dirty
# author onnodg
# date Mon, 15 Dec 2025 16:44:40 +0000
# parents ff68835adb2b
# children
# line wrap: on
# line source

"""
Test suite for CD-HIT cluster analysis processor.
"""
import pytest
from pathlib import Path
import pandas as pd
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Stage_1_translated.NLOOR_scripts.process_clusters_tool.cdhit_analysis import (
    parse_cluster_file,
    process_cluster_data,
    calculate_cluster_taxa,
    write_similarity_output,
    write_count_output,
    write_taxa_excel,
)


class TestCDHitAnalysis:
    """End-to-end and unit tests for the CD-HIT cluster analysis module."""

    # Test fixture directory, relative to the repository root.
    _TEST_DATA_REL = Path("Stage_1_translated/NLOOR_scripts/process_clusters_tool/test-data")

    class _ExcelArgs:
        """Minimal argument namespace accepted by ``write_taxa_excel`` in these tests."""

        uncertain_taxa_use_ratio = 0.5
        min_to_split = 0.45
        min_count_to_split = 10
        min_cluster_support = 1
        make_taxa_in_cluster_split = False

    @staticmethod
    def _build_cluster_data_list(clusters):
        """Run ``process_cluster_data`` on every cluster.

        Returns one dict per cluster with the keys ``similarities``,
        ``taxa_map``, ``annotated`` and ``unannotated``, which is the layout
        the writer functions under test are given.
        """
        cluster_data_list = []
        for cluster in clusters:
            sim, taxa_map, annotated, unannotated = process_cluster_data(cluster)
            cluster_data_list.append(
                {
                    "similarities": sim,
                    "taxa_map": taxa_map,
                    "annotated": annotated,
                    "unannotated": unannotated,
                }
            )
        return cluster_data_list

    @pytest.fixture(scope="class")
    def test_data_dir(self):
        """Return the test-data directory.

        The original CWD-relative location is tried first. If it is missing,
        the location under the repository root derived from this file's path
        is tried, so the suite also works when pytest is launched from
        another directory.
        """
        repo_root = Path(__file__).resolve().parent.parent
        candidates = [self._TEST_DATA_REL, repo_root / self._TEST_DATA_REL]
        for base in candidates:
            if base.exists():
                return base
        pytest.fail(f"test-data directory not found; tried: {candidates}")

    @pytest.fixture(scope="class")
    def sample_cluster_file(self, test_data_dir):
        f = test_data_dir / "prev_anno.txt"
        assert f.exists(), f"missing fixture file: {f}"
        return str(f)

    @pytest.fixture(scope="class")
    def sample_annotation_file(self, test_data_dir):
        f = test_data_dir / "prev4.xlsx"
        assert f.exists(), f"missing fixture file: {f}"
        return str(f)

    @pytest.fixture(scope="class")
    def parsed_clusters(self, sample_cluster_file, sample_annotation_file):
        """Parse the sample cluster file once per class, with annotations attached."""
        return parse_cluster_file(sample_cluster_file, sample_annotation_file)

    def test_cluster_parsing_structure(self, parsed_clusters):
        assert len(parsed_clusters) == 514
        cluster_0 = parsed_clusters[0]
        assert len(cluster_0) == 430

        read = cluster_0["M01687:460:000000000-LGY9G:1:1101:8356:6156_CONS"]
        assert read["count"] == 19
        assert isinstance(read["similarity"], float)

    def test_annotation_integration_basic(self, parsed_clusters):
        cluster_0 = parsed_clusters[0]

        annotated_found = any(
            data["taxa"] != "Unannotated read" for data in cluster_0.values()
        )
        assert annotated_found, "At least one annotated read expected"

    def test_process_cluster_data_counts_and_taxa_map(self, parsed_clusters):
        sim, taxa_map, annotated, unannotated = process_cluster_data(parsed_clusters[0])

        assert isinstance(sim, list)
        assert annotated + unannotated == sum(d["count"] for d in parsed_clusters[0].values())
        assert isinstance(taxa_map, dict)
        assert annotated == 47004 and unannotated == 9

    def test_weighted_lca_splitting_on_uncertain_taxa(self):
        taxa_dict = {
            "K / P / C / O / F / G1 / S1": 60,
            "K / P / C / O / F / Uncertain taxa / Uncertain taxa": 60,
        }

        class ArgsLow:
            uncertain_taxa_use_ratio = 0.5
            min_to_split = 0.45
            min_count_to_split = 10

        class ArgsHigh:
            uncertain_taxa_use_ratio = 1.0
            min_to_split = 0.45
            min_count_to_split = 10

        # LOW weight → uncertain counts half → G1 wins → no split
        res_low = calculate_cluster_taxa(taxa_dict, ArgsLow())
        assert len(res_low) == 1
        # The single resolved group holds 60 reads, not the 120-read input total.
        assert sum(res_low[0].values()) == 60

        # HIGH weight → uncertain = full weight → equal → split
        res_high = calculate_cluster_taxa(taxa_dict, ArgsHigh())
        assert len(res_high) == 2
        total = sum(sum(g.values()) for g in res_high)
        assert total == 120

    def test_calculate_cluster_taxa_preserves_counts_real_cluster(self, parsed_clusters):
        sim, taxa_map, annotated, unannotated = process_cluster_data(parsed_clusters[3])

        raw_total = annotated + unannotated
        taxa_map_total = sum(info["count"] for info in taxa_map.values())
        assert raw_total == taxa_map_total

        class Args:
            uncertain_taxa_use_ratio = 0.5
            min_to_split = 0.3
            min_count_to_split = 5

        results = calculate_cluster_taxa({t: i["count"] for t, i in taxa_map.items()}, Args())

        resolved_total = sum(sum(group.values()) for group in results)
        assert resolved_total <= raw_total
        assert resolved_total > 0

    def test_write_similarity_and_count_outputs(self, tmp_path, parsed_clusters):
        out_simi = tmp_path / "simi.txt"
        out_count = tmp_path / "count.txt"

        cluster_data_list = self._build_cluster_data_list(parsed_clusters)

        write_similarity_output(cluster_data_list, str(out_simi))
        assert out_simi.exists()

        write_count_output(cluster_data_list, str(out_count))
        assert out_count.exists()

    def test_write_taxa_excel_raw_and_processed(self, tmp_path, parsed_clusters):
        cluster_data_list = self._build_cluster_data_list(parsed_clusters)

        out = tmp_path / "taxa.xlsx"
        write_taxa_excel(
            cluster_data_list, self._ExcelArgs(), str(out), write_raw=True, write_processed=True
        )

        # Close the workbook handle explicitly so tmp_path cleanup never hits a locked file.
        with pd.ExcelFile(out) as xl:
            sheet_names = xl.sheet_names
        assert "Raw_Taxa_Clusters" in sheet_names
        assert "Processed_Taxa_Clusters" in sheet_names
        assert "Settings" in sheet_names

    def test_write_taxa_excel_only_raw_or_only_processed(self, tmp_path, parsed_clusters):
        cluster_data_list = self._build_cluster_data_list(parsed_clusters)

        out_raw = tmp_path / "raw.xlsx"
        write_taxa_excel(
            cluster_data_list, self._ExcelArgs(), str(out_raw), write_raw=True, write_processed=False
        )
        with pd.ExcelFile(out_raw) as xl_raw:
            raw_sheets = xl_raw.sheet_names
        assert "Raw_Taxa_Clusters" in raw_sheets
        assert "Processed_Taxa_Clusters" not in raw_sheets

        out_proc = tmp_path / "proc.xlsx"
        write_taxa_excel(
            cluster_data_list, self._ExcelArgs(), str(out_proc), write_raw=False, write_processed=True
        )
        with pd.ExcelFile(out_proc) as xl_proc:
            proc_sheets = xl_proc.sheet_names
        assert "Processed_Taxa_Clusters" in proc_sheets
        assert "Raw_Taxa_Clusters" not in proc_sheets

    def test_parse_arguments_all_flags(self, tmp_path):
        from Stage_1_translated.NLOOR_scripts.process_clusters_tool import cdhit_analysis as ca
        args = ca.parse_arguments([
            "--input_cluster", str(tmp_path / "dummy.clstr"),
            "--simi_plot_y_min", "90",
            "--simi_plot_y_max", "99",
            "--uncertain_taxa_use_ratio", "0.3",
            "--min_to_split", "0.2",
            "--min_count_to_split", "5",
            "--output_excel", str(tmp_path / "report.xlsx"),
        ])
        assert args.simi_plot_y_min == 90
        assert args.simi_plot_y_max == 99

    def test_main_runs_and_creates_outputs(self, tmp_path):
        from Stage_1_translated.NLOOR_scripts.process_clusters_tool import cdhit_analysis as ca

        clstr = tmp_path / "simple.clstr"
        clstr.write_text(">Cluster 0\n0\t88nt, >read1_CONS(3)... *\n")

        anno = tmp_path / "anno.xlsx"
        df = pd.DataFrame([
            {
                "header": "read1_CONS",
                "seq_id": "SEQ001",
                "source": "Genbank",
                "taxa": "K / P / C / O / F / G / S",
            }
        ])
        with pd.ExcelWriter(anno) as w:
            df.to_excel(w, sheet_name="Individual_Reads", index=False)

        sim_file = tmp_path / "sim.txt"
        excel_file = tmp_path / "taxa.xlsx"
        # Keep the log inside tmp_path: a CWD-relative path would pollute the
        # checkout and break when the directory does not exist.
        log_file = tmp_path / "new_logs.txt"
        args = [
            "--input_cluster", str(clstr),
            "--input_annotation", str(anno),
            "--output_similarity_txt", str(sim_file),
            "--output_excel", str(excel_file),
            '--output_taxa_clusters',
            '--output_taxa_processed',
            '--log_file', str(log_file),
            '--simi_plot_y_min', '95',
            '--simi_plot_y_max', '100',
            '--uncertain_taxa_use_ratio', '0.5',
            '--min_to_split', '0.45',
            '--min_count_to_split', '10',
            '--min_cluster_support', '1'
        ]

        ca.main(args)
        assert sim_file.exists()
        assert excel_file.exists()

    def test_parse_cluster_file_empty_and_no_annotation(self, tmp_path):
        # Use the same module as the rest of the suite (previously pointed at
        # a divergent ``cdhit_analysis2`` module).
        from Stage_1_translated.NLOOR_scripts.process_clusters_tool import cdhit_analysis as ca

        empty = tmp_path / "empty.clstr"
        empty.write_text("")

        clusters = ca.parse_cluster_file(str(empty), annotation_file=None, log_messages=[])
        assert clusters == []

    def test_create_similarity_plot_creates_file(self, tmp_path, parsed_clusters):
        from Stage_1_translated.NLOOR_scripts.process_clusters_tool import cdhit_analysis as ca

        all_simi = []
        lengths = []
        for c in parsed_clusters[:5]:
            sim, _taxa_map, _annotated, _unannotated = process_cluster_data(c)
            if sim:
                all_simi.extend(sim)
                lengths.append(len(sim))

        # Fail loudly instead of passing vacuously when no similarities were collected.
        assert all_simi, "expected similarity values in the first five clusters"

        class Args:
            simi_plot_y_min = 95.0
            simi_plot_y_max = 100.0

        out_png = tmp_path / "sim.png"
        ca.create_similarity_plot(all_simi, lengths, Args(), str(out_png))
        assert out_png.exists()