changeset 0:02b0ecc34d9a draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/rna_tools/rnaformer commit ee837e8d27a53baa3d4881412d4fbc566ae06499
author rnateam
date Thu, 11 Jul 2024 20:56:23 +0000
parents
children
files infer_rnaformer.xml macros.xml test-data/fasta_input1.fa test-data/rna_2d_pred_FASTA.txt test-data/rna_2d_pred_text.txt
diffstat 5 files changed, 238 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/infer_rnaformer.xml	Thu Jul 11 20:56:23 2024 +0000
@@ -0,0 +1,212 @@
+<tool id="infer_rnaformer" name="@EXECUTABLE@" version="@TOOL_VERSION@" profile="22.05">
+    <description>Predict the secondary structure of an RNA with RNAformer</description>
+    <macros>
+        <import>macros.xml</import>
+    </macros>
+    <expand macro="requirements">
+        <requirement type="package" version="1.83">biopython</requirement>
+    </expand>
+    <command detect_errors="exit_code"><![CDATA[
+    mkdir -p './model' &&
+    wget -O './model/RNAformer_32M_state_dict_intra_family_finetuned.pth' 'https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_state_dict_intra_family_finetuned.pth'
+    &&
+    wget -O './model/RNAformer_32M_config_intra_family_finetuned.yml' 'https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_config_intra_family_finetuned.yml'
+    &&
+    python '$script_file' > '$output'
+]]></command>
+<configfiles>
+        <configfile name="script_file"><![CDATA[import RNAformer
+import os
+import argparse
+import torch
+import urllib.request
+import logging
+from collections import defaultdict
+import torch.cuda
+import loralib as lora
+from RNAformer.model.RNAformer import RiboFormer
+from RNAformer.utils.configuration import Config
+
+from Bio import SeqIO
+
+import logging
+import sys
+
+def is_valid_rna_sequence(sequence):
+    """Check if the sequence contains only RNA bases."""
+    valid_bases = {'A', 'C', 'G', 'U', 'N'}  # Include 'N' if unknown bases are allowed
+    return all(base in valid_bases for base in sequence.upper())
+
+config_file_path = 'model/RNAformer_32M_config_intra_family_finetuned.yml'
+model_file_path = 'model/RNAformer_32M_state_dict_intra_family_finetuned.pth'
+
+config = Config(config_file=config_file_path)
+config.RNAformer.cycling = 6
+model = RiboFormer(config.RNAformer)
+state_dict_file = model_file_path
+
+#if str($input_type.input_type) == 'True'
+fasta_path = '$input_type.fasta_input'
+sequences = [str(record.seq) for record in SeqIO.parse(fasta_path, 'fasta')]
+#else
+sequence_string = "$input_type.rna_input_string"
+sequences = [seq.strip() for seq in sequence_string.split(',')]
+#end if
+
+for seq in sequences:
+    if not is_valid_rna_sequence(seq):
+        print(f"Invalid RNA sequence detected: {seq}. Please ensure only RNA sequences are used as input.", file=sys.stderr)
+        sys.exit(1)
+
+lora_config = {
+    "r": config.r,
+    "lora_alpha": config.lora_alpha,
+    "lora_dropout": config.lora_dropout,
+}
+
+with torch.no_grad():
+    for name, module in model.named_modules():
+        if any(replace_key in name for replace_key in config.replace_layer):
+            parent = model.get_submodule(".".join(name.split(".")[:-1]))
+            target_name = name.split(".")[-1]
+            target = model.get_submodule(name)
+            if isinstance(target, torch.nn.Linear) and "qkv" in name:
+                new_module = lora.MergedLinear(target.in_features, target.out_features,
+                                            bias=target.bias is not None,
+                                            enable_lora=[True, True, True], **lora_config)
+                new_module.weight.copy_(target.weight)
+                if target.bias is not None:
+                    new_module.bias.copy_(target.bias)
+            elif isinstance(target, torch.nn.Linear):
+                new_module = lora.Linear(target.in_features, target.out_features,
+                                        bias=target.bias is not None, **lora_config)
+                new_module.weight.copy_(target.weight)
+                if target.bias is not None:
+                    new_module.bias.copy_(target.bias)
+
+            elif isinstance(target, torch.nn.Conv2d):
+                kernel_size = target.kernel_size[0]
+                new_module = lora.Conv2d(target.in_channels, target.out_channels, kernel_size,
+                                        padding=(kernel_size - 1) // 2, bias=target.bias is not None,
+                                        **lora_config)
+
+                new_module.conv.weight.copy_(target.weight)
+                if target.bias is not None:
+                    new_module.conv.bias.copy_(target.bias)
+            setattr(parent, target_name, new_module)
+
+state_dict = torch.load(state_dict_file, map_location=torch.device('cpu'))
+
+model.load_state_dict(state_dict, strict=True)
+model_name = state_dict_file.split(".pth")[0]
+
+if torch.cuda.is_available():
+        model = model.cuda()
+
+        # check GPU can do bf16
+        if torch.cuda.is_bf16_supported():
+            model = model.bfloat16()
+        else:
+            model = model.half()
+
+model.eval()
+predicted_structures = []
+
+for sequence in sequences:
+    with torch.no_grad():
+        device = "cpu"
+
+        seq_vocab = ['A', 'C', 'G', 'U', 'N']
+        seq_stoi = dict(zip(seq_vocab, range(len(seq_vocab))))
+
+        pdb_sample = 1
+
+        length = len(sequence)
+        src_seq = torch.LongTensor(list(map(seq_stoi.get, sequence)))
+
+        sample = {}
+        sample['src_seq'] = src_seq.clone()
+        sample['length'] = torch.LongTensor([length])[0]
+        sample['pdb_sample'] = torch.LongTensor([pdb_sample])[0]
+
+        sequence = sample['src_seq'].unsqueeze(0).to(device)
+        src_len = torch.LongTensor([sequence.shape[-1]]).to(device)
+        pdb_sample = torch.FloatTensor([[1]]).to(device)
+
+        logits, pair_mask = model(sequence, src_len, pdb_sample)
+
+        pred_mat = torch.sigmoid(logits[0, :, :, -1]) > 0.5
+
+        pos_id = torch.where(pred_mat == True)
+        pos1_id = pos_id[0].cpu().tolist()
+        pos2_id = pos_id[1].cpu().tolist()
+        predicted_structure = f"Pairing index 1: {pos1_id} \nPairing index 2: {pos2_id}"
+        print(predicted_structure)
+
+        seqlen = len(sample['src_seq'])
+        dot_bracket =['.'] * seqlen
+        for i in range(len(pos1_id)):
+            open_index = pos1_id[i]
+            close_index = pos2_id[i]
+
+            if 0 <= open_index < len(dot_bracket) and 0 <= close_index < len(dot_bracket):
+                if dot_bracket[open_index] == '.' and dot_bracket[close_index] == '.':
+                    dot_bracket[open_index] = '('
+                    dot_bracket[close_index] = ')'
+        dot_bracket_str_pred = ''.join(dot_bracket)
+
+
+
+]]></configfile>
+    </configfiles>
+    <inputs>
+        <conditional name="input_type">
+            <param name="input_type" type="select" label="Input from FASTA file">
+                <option value="False">Provide a single RNA sequence string as text</option>
+                <option value="True">Provide a FASTA file</option>
+            </param>
+            <when value="False">
+                <param name="rna_input_string" label="Sequence(s) to fold" type="text" value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA" help="Enter RNA sequences. Separate multiple RNA sequences by commas.">
+                    <sanitizer>
+                        <valid>
+                            <add value="ACGUacgu,"/>
+                        </valid>
+                    </sanitizer>
+                </param>
+            </when>
+            <when value="True">
+                <param format="fasta" name="fasta_input" type="data" label="Sequence to fold (FASTA file)"/>
+            </when>
+        </conditional>
+    </inputs>
+    <outputs>
+        <data name="output" format="txt" label="output"/>
+    </outputs>
+    <tests>
+        <test>
+            <param name="input_type" value="False"/>
+            <param name="rna_input_string" value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA"/>
+            <output name="output" file="rna_2d_pred_text.txt"/>
+        </test>
+        <test>
+            <param name="input_type" value="True"/>
+            <param name="fasta_input" value="fasta_input1.fa"/>
+            <output name="output" file="rna_2d_pred_FASTA.txt"/>
+        </test>
+    </tests>
+    <help><![CDATA[
+    **RNAformer**
+        This tool reads RNA sequences and predicts their secondary structure using RNAformer.
+
+    **Input format**
+
+    RNAformer requires one or more RNA sequences either as a single FASTA file or as plain text.
+
+    **Outputs**
+
+    - Predicted secondary structure as a text file in the following formats:
+        - base pair positions
+        - dot-bracket notation
+    ]]></help>
+    <expand macro="citations" />
+</tool>
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/macros.xml	Thu Jul 11 20:56:23 2024 +0000
@@ -0,0 +1,16 @@
+<macros>
+    <token name="@EXECUTABLE@">RNAformer</token>
+    <token name="@TOOL_VERSION@">1.0.0</token>
+    <token name="@profile@">22.05</token>
+    <xml name="requirements">
+        <requirements>
+            <requirement type="package" version="0.0.1">rnaformer</requirement>
+            <yield/>
+        </requirements>
+    </xml>
+    <xml name="citations">
+        <citations>
+            <citation type="doi">10.1101/2024.02.12.579881</citation>
+        </citations>
+    </xml>
+</macros>
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/fasta_input1.fa	Thu Jul 11 20:56:23 2024 +0000
@@ -0,0 +1,4 @@
+>Anolis_caro_chrUn_GL343590.trna2_AlaAGC (218800-218872)  Ala (AGC) 73 bp  Sc: 49.55
+UGGGAAUUAGCUCAAAUGGUAGAGCGCUCGCUUAGCAUGUGAGAGGUAGUGGGAUCGAUGCCCACAUUCUCCA
+>Anolis_caro_chrUn_GL343207.trna3_AlaAGC (1513626-1513698)  Ala (AGC) 73 bp  Sc: 56.15
+GGGGAAUUAGCUCAAAUGGUAGAGCGCUCGCUUAGCAUGCGAGAGGUAGCGGGAUUGAUGCCCGCAUUCUCCA
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/rna_2d_pred_FASTA.txt	Thu Jul 11 20:56:23 2024 +0000
@@ -0,0 +1,4 @@
+Pairing index 1: [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 8, 9, 9, 10, 11, 12, 13, 13, 17, 18, 20, 21, 21, 22, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 48, 49, 50, 51, 52, 53, 54, 55, 57, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71] 
+Pairing index 2: [71, 70, 69, 68, 67, 66, 65, 13, 14, 20, 47, 22, 24, 44, 23, 22, 21, 7, 20, 54, 55, 7, 12, 45, 8, 11, 10, 9, 43, 42, 41, 40, 39, 38, 37, 36, 32, 31, 30, 29, 28, 27, 26, 25, 9, 21, 64, 63, 62, 61, 60, 57, 17, 18, 53, 52, 51, 50, 49, 48, 6, 5, 4, 3, 2, 1, 0]
+Pairing index 1: [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 8, 9, 9, 10, 11, 12, 13, 13, 14, 20, 20, 21, 21, 22, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 48, 49, 50, 51, 52, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71] 
+Pairing index 2: [71, 70, 69, 68, 67, 66, 65, 13, 14, 20, 22, 24, 44, 23, 22, 21, 7, 20, 7, 7, 13, 12, 45, 8, 11, 10, 9, 43, 42, 41, 40, 39, 38, 37, 32, 31, 30, 29, 28, 27, 26, 25, 9, 21, 64, 63, 62, 61, 60, 52, 51, 50, 49, 48, 6, 5, 4, 3, 2, 1, 0]
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/rna_2d_pred_text.txt	Thu Jul 11 20:56:23 2024 +0000
@@ -0,0 +1,2 @@
+Pairing index 1: [0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9, 10, 11, 12, 13, 17, 18, 21, 21, 22, 23, 23, 24, 25, 39, 40, 42, 43, 44, 48, 49, 50, 53, 54, 55, 56, 57, 58, 59, 60, 62, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76] 
+Pairing index 2: [76, 75, 74, 73, 72, 71, 70, 13, 21, 23, 25, 24, 23, 22, 7, 59, 60, 7, 13, 12, 8, 11, 10, 9, 30, 29, 50, 49, 48, 44, 43, 42, 69, 68, 67, 66, 65, 62, 17, 18, 58, 57, 56, 55, 54, 53, 6, 5, 4, 3, 2, 1, 0]