diff validate_fasta.py @ 6:04e95886cf24 draft

"planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 724a7a389c878dded1c0332f3b6e507e0c4cd52a-dirty"
author galaxy-australia
date Mon, 04 Apr 2022 01:46:22 +0000
parents 6c92e000d684
children eb085b3dbaf8
line wrap: on
line diff
--- a/validate_fasta.py	Thu Mar 10 21:53:42 2022 +0000
+++ b/validate_fasta.py	Mon Apr 04 01:46:22 2022 +0000
@@ -1,6 +1,7 @@
 """Validate input FASTA sequence."""
 
 import re
+import sys
 import argparse
 from typing import List, TextIO
 
@@ -16,10 +17,6 @@
         """Initialize from FASTA file."""
         self.fastas = []
         self.load(fasta_path)
-        print("Loaded FASTA sequences:")
-        for f in self.fastas:
-            print(f.header)
-            print(f.aa_seq)
 
     def load(self, fasta_path: str):
         """Load bare or FASTA formatted sequence."""
@@ -29,36 +26,32 @@
         if "__cn__" in self.content:
             # Pasted content with escaped characters
             self.newline = '__cn__'
-            self.caret = '__gt__'
+            self.read_caret = '__gt__'
         else:
             # Uploaded file with normal content
             self.newline = '\n'
-            self.caret = '>'
+            self.read_caret = '>'
 
         self.lines = self.content.split(self.newline)
-        header, sequence = self.interpret_first_line()
+
+        if not self.lines[0].startswith(self.read_caret):
+            # Fasta is headless, load as single sequence
+            self.update_fastas(
+                '', ''.join(self.lines)
+            )
 
-        i = 0
-        while i < len(self.lines):
-            line = self.lines[i]
-            if line.startswith(self.caret):
-                self.update_fastas(header, sequence)
-                header = '>' + self.strip_header(line)
-                sequence = ''
-            else:
-                sequence += line.strip('\n ')
-            i += 1
-
-        # after reading whole file, header & sequence buffers might be full
-        self.update_fastas(header, sequence)
-
-    def interpret_first_line(self):
-        line = self.lines[0]
-        if line.startswith(self.caret):
-            header = '>' + self.strip_header(line)
-            return header, ''
         else:
-            return '', line
+            header = None
+            sequence = None
+            for line in self.lines:
+                if line.startswith(self.read_caret):
+                    if header:
+                        self.update_fastas(header, sequence)
+                    header = '>' + self.strip_header(line)
+                    sequence = ''
+                else:
+                    sequence += line.strip('\n ')
+            self.update_fastas(header, sequence)
 
     def strip_header(self, line):
         """Strip characters escaped with underscores from pasted text."""
@@ -77,10 +70,14 @@
 
 
 class FastaValidator:
-    def __init__(self, fasta_list: List[Fasta]):
+    def __init__(
+            self,
+            fasta_list: List[Fasta],
+            min_length=None,
+            max_length=None):
+        self.min_length = min_length
+        self.max_length = max_length
         self.fasta_list = fasta_list
-        self.min_length = 30
-        self.max_length = 2000
         self.iupac_characters = {
             'A', 'B', 'C', 'D', 'E', 'F', 'G',
             'H', 'I', 'K', 'L', 'M', 'N', 'P',
@@ -93,68 +90,89 @@
         self.validate_num_seqs()
         self.validate_length()
         self.validate_alphabet()
+
         # not checking for 'X' nucleotides at the moment.
         # alphafold can throw an error if it doesn't like it.
-        #self.validate_x()
+        # self.validate_x()
 
     def validate_num_seqs(self) -> None:
+        """Assert that only one sequence has been provided."""
         if len(self.fasta_list) > 1:
-            raise Exception(f'Error encountered validating fasta: More than 1 sequence detected ({len(self.fasta_list)}). Please use single fasta sequence as input')
+            raise Exception(
+                'Error encountered validating fasta:'
+                f' More than 1 sequence detected ({len(self.fasta_list)}).'
+                ' Please use single fasta sequence as input.')
         elif len(self.fasta_list) == 0:
-            raise Exception(f'Error encountered validating fasta: input file has no fasta sequences')
+            raise Exception(
+                'Error encountered validating fasta:'
+                ' input file has no fasta sequences')
 
     def validate_length(self):
-        """Confirms whether sequence length is valid. """
+        """Confirm whether sequence length is valid."""
         fasta = self.fasta_list[0]
-        if len(fasta.aa_seq) < self.min_length:
-            raise Exception(f'Error encountered validating fasta: Sequence too short ({len(fasta.aa_seq)}aa). Must be > 30aa')
-        if len(fasta.aa_seq) > self.max_length:
-            raise Exception(f'Error encountered validating fasta: Sequence too long ({len(fasta.aa_seq)}aa). Must be < 2000aa')
+        if self.min_length:
+            if len(fasta.aa_seq) < self.min_length:
+                raise Exception(
+                    'Error encountered validating fasta: Sequence too short'
+                    f' ({len(fasta.aa_seq)}AA).'
+                    f' Minimum length is {self.min_length}AA.')
+        if self.max_length:
+            if len(fasta.aa_seq) > self.max_length:
+                raise Exception(
+                    'Error encountered validating fasta:'
+                    f' Sequence too long ({len(fasta.aa_seq)}AA).'
+                    f' Maximum length is {self.max_length}AA.')
 
     def validate_alphabet(self):
         """
-        Confirms whether the sequence conforms to IUPAC codes.
-        If not, reports the offending character and its position.
+        Confirm whether the sequence conforms to IUPAC codes.
+        If not, report the offending character and its position.
         """
         fasta = self.fasta_list[0]
         for i, char in enumerate(fasta.aa_seq.upper()):
             if char not in self.iupac_characters:
-                raise Exception(f'Error encountered validating fasta: Invalid amino acid found at pos {i}: "{char}"')
+                raise Exception(
+                    'Error encountered validating fasta: Invalid amino acid'
+                    f' found at pos {i}: "{char}"')
 
     def validate_x(self):
-        """checks if any bases are X. TODO check whether alphafold accepts X bases. """
+        """Check for X bases."""
         fasta = self.fasta_list[0]
         for i, char in enumerate(fasta.aa_seq.upper()):
             if char == 'X':
-                raise Exception(f'Error encountered validating fasta: Unsupported aa code "X" found at pos {i}')
+                raise Exception(
+                    'Error encountered validating fasta: Unsupported AA code'
+                    f' "X" found at pos {i}')
 
 
 class FastaWriter:
     def __init__(self) -> None:
-        self.outfile = 'alphafold.fasta'
-        self.formatted_line_len = 60
+        self.line_wrap = 60
 
     def write(self, fasta: Fasta):
-        with open(self.outfile, 'w') as fp:
-            header = fasta.header
-            seq = self.format_sequence(fasta.aa_seq)
-            fp.write(header + '\n')
-            fp.write(seq + '\n')
+        header = fasta.header
+        seq = self.format_sequence(fasta.aa_seq)
+        sys.stdout.write(header + '\n')
+        sys.stdout.write(seq)
 
     def format_sequence(self, aa_seq: str):
         formatted_seq = ''
-        for i in range(0, len(aa_seq), self.formatted_line_len):
-            formatted_seq += aa_seq[i: i + self.formatted_line_len] + '\n'
+        for i in range(0, len(aa_seq), self.line_wrap):
+            formatted_seq += aa_seq[i: i + self.line_wrap] + '\n'
         return formatted_seq
 
 
 def main():
     # load fasta file
     args = parse_args()
-    fas = FastaLoader(args.input_fasta)
+    fas = FastaLoader(args.input)
 
     # validate
-    fv = FastaValidator(fas.fastas)
+    fv = FastaValidator(
+        fas.fastas,
+        min_length=args.min_length,
+        max_length=args.max_length,
+    )
     fv.validate()
 
     # write cleaned version
@@ -165,13 +183,26 @@
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "input_fasta",
+        "input",
         help="input fasta file",
         type=str
     )
+    parser.add_argument(
+        "--min_length",
+        dest='min_length',
+        help="Minimum length of input protein sequence (AA)",
+        default=None,
+        type=int,
+    )
+    parser.add_argument(
+        "--max_length",
+        dest='max_length',
+        help="Maximum length of input protein sequence (AA)",
+        default=None,
+        type=int,
+    )
     return parser.parse_args()
 
 
-
 if __name__ == '__main__':
     main()