Mercurial > repos > rnateam > infer_rnaformer
changeset 0:02b0ecc34d9a draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/rna_tools/rnaformer commit ee837e8d27a53baa3d4881412d4fbc566ae06499
author | rnateam |
---|---|
date | Thu, 11 Jul 2024 20:56:23 +0000 |
parents | |
children | |
files | infer_rnaformer.xml macros.xml test-data/fasta_input1.fa test-data/rna_2d_pred_FASTA.txt test-data/rna_2d_pred_text.txt |
diffstat | 5 files changed, 238 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/infer_rnaformer.xml Thu Jul 11 20:56:23 2024 +0000
@@ -0,0 +1,212 @@
+<tool id="infer_rnaformer" name="@EXECUTABLE@" version="@TOOL_VERSION@" profile="22.05">
+    <description>Predict the secondary structure of an RNA with RNAformer</description>
+    <macros>
+        <import>macros.xml</import>
+    </macros>
+    <expand macro="requirements">
+        <requirement type="package" version="1.83">biopython</requirement>
+    </expand>
+    <!-- Download the fine-tuned model weights and config into ./model, then run the generated script. -->
+    <command detect_errors="exit_code"><![CDATA[
+        mkdir -p './model' &&
+        wget -O './model/RNAformer_32M_state_dict_intra_family_finetuned.pth' 'https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_state_dict_intra_family_finetuned.pth'
+        &&
+        wget -O './model/RNAformer_32M_config_intra_family_finetuned.yml' 'https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_config_intra_family_finetuned.yml'
+        &&
+        python '$script_file' > '$output'
+]]></command>
+    <configfiles>
+        <!-- The script below is a Cheetah template: the input-type branch is resolved by Galaxy before Python runs. -->
+        <configfile name="script_file"><![CDATA[import sys
+
+import loralib as lora
+import torch
+from Bio import SeqIO
+from RNAformer.model.RNAformer import RiboFormer
+from RNAformer.utils.configuration import Config
+
+# Vocabulary used to turn bases into the integer ids fed to the model.
+SEQ_VOCAB = ['A', 'C', 'G', 'U', 'N']
+SEQ_STOI = {base: index for index, base in enumerate(SEQ_VOCAB)}
+
+
+def is_valid_rna_sequence(sequence):
+    """Check if the sequence contains only RNA bases (A, C, G, U, N), ignoring case."""
+    return all(base in SEQ_STOI for base in sequence.upper())
+
+
+def pairs_to_dot_bracket(pos1_id, pos2_id, length):
+    """Greedily mark index pairs as '(' and ')' in a dot string of the given length.
+
+    A pair is only marked when both positions are in range and still '.',
+    so the mirrored (j, i) entry of a marked pair and any later pair that
+    reuses a marked position are skipped.
+    """
+    dot_bracket = ['.'] * length
+    for open_index, close_index in zip(pos1_id, pos2_id):
+        if 0 <= open_index < length and 0 <= close_index < length:
+            if dot_bracket[open_index] == '.' and dot_bracket[close_index] == '.':
+                dot_bracket[open_index] = '('
+                dot_bracket[close_index] = ')'
+    return ''.join(dot_bracket)
+
+
+#if str($input_type.input_type) == 'True'
+fasta_path = '$input_type.fasta_input'
+raw_sequences = [str(record.seq) for record in SeqIO.parse(fasta_path, 'fasta')]
+#else
+sequence_string = "$input_type.rna_input_string"
+raw_sequences = sequence_string.split(',')
+#end if
+
+# Upper-case so that lower-case input maps onto the vocabulary, and drop empty
+# entries such as the one left behind by a trailing comma.
+sequences = [seq.strip().upper() for seq in raw_sequences if seq.strip()]
+
+if not sequences:
+    print("No RNA sequence was provided.", file=sys.stderr)
+    sys.exit(1)
+
+for seq in sequences:
+    if not is_valid_rna_sequence(seq):
+        print(f"Invalid RNA sequence detected: {seq}. Please ensure only RNA sequences are used as input.", file=sys.stderr)
+        sys.exit(1)
+
+config_file_path = 'model/RNAformer_32M_config_intra_family_finetuned.yml'
+model_file_path = 'model/RNAformer_32M_state_dict_intra_family_finetuned.pth'
+
+config = Config(config_file=config_file_path)
+config.RNAformer.cycling = 6
+model = RiboFormer(config.RNAformer)
+
+lora_config = {
+    "r": config.r,
+    "lora_alpha": config.lora_alpha,
+    "lora_dropout": config.lora_dropout,
+}
+
+# Replace every layer whose name contains a key from config.replace_layer with its
+# LoRA counterpart, copying the existing weights; presumably this makes the module
+# tree match the fine-tuned state dict that is loaded strictly further down.
+with torch.no_grad():
+    for name, module in model.named_modules():
+        if any(replace_key in name for replace_key in config.replace_layer):
+            parent = model.get_submodule(".".join(name.split(".")[:-1]))
+            target_name = name.split(".")[-1]
+            target = model.get_submodule(name)
+            if isinstance(target, torch.nn.Linear) and "qkv" in name:
+                new_module = lora.MergedLinear(target.in_features, target.out_features,
+                                               bias=target.bias is not None,
+                                               enable_lora=[True, True, True], **lora_config)
+                new_module.weight.copy_(target.weight)
+                if target.bias is not None:
+                    new_module.bias.copy_(target.bias)
+            elif isinstance(target, torch.nn.Linear):
+                new_module = lora.Linear(target.in_features, target.out_features,
+                                         bias=target.bias is not None, **lora_config)
+                new_module.weight.copy_(target.weight)
+                if target.bias is not None:
+                    new_module.bias.copy_(target.bias)
+            elif isinstance(target, torch.nn.Conv2d):
+                kernel_size = target.kernel_size[0]
+                new_module = lora.Conv2d(target.in_channels, target.out_channels, kernel_size,
+                                         padding=(kernel_size - 1) // 2, bias=target.bias is not None,
+                                         **lora_config)
+                new_module.conv.weight.copy_(target.weight)
+                if target.bias is not None:
+                    new_module.conv.bias.copy_(target.bias)
+            # NOTE(review): a matching module that is neither Linear nor Conv2d would be
+            # replaced by the previous new_module; kept as is so the module tree is unchanged.
+            setattr(parent, target_name, new_module)
+
+# NOTE(review): torch.load unpickles the downloaded file, so it is only as trustworthy
+# as the download URL used in the command section.
+state_dict = torch.load(model_file_path, map_location=torch.device('cpu'))
+model.load_state_dict(state_dict, strict=True)
+
+if torch.cuda.is_available():
+    model = model.cuda()
+    # check GPU can do bf16
+    if torch.cuda.is_bf16_supported():
+        model = model.bfloat16()
+    else:
+        model = model.half()
+
+model.eval()
+
+# Build the inputs on the device the model actually lives on; a hard-coded "cpu"
+# does not match a model that was moved to the GPU above.
+device = next(model.parameters()).device
+
+with torch.no_grad():
+    for sequence in sequences:
+        src_seq = torch.LongTensor([SEQ_STOI[base] for base in sequence]).unsqueeze(0).to(device)
+        src_len = torch.LongTensor([src_seq.shape[-1]]).to(device)
+        pdb_sample = torch.FloatTensor([[1]]).to(device)
+
+        logits, pair_mask = model(src_seq, src_len, pdb_sample)
+
+        pred_mat = torch.sigmoid(logits[0, :, :, -1]) > 0.5
+        pos_id = torch.where(pred_mat)
+        pos1_id = pos_id[0].cpu().tolist()
+        pos2_id = pos_id[1].cpu().tolist()
+        print(f"Pairing index 1: {pos1_id} \nPairing index 2: {pos2_id}")
+
+        # NOTE(review): the dot-bracket string is built but not written to the output;
+        # the expected outputs under test-data contain only the two index lines.
+        dot_bracket_str_pred = pairs_to_dot_bracket(pos1_id, pos2_id, len(sequence))
+]]></configfile>
+    </configfiles>
+    <inputs>
+        <conditional name="input_type">
+            <param name="input_type" type="select" label="Input from FASTA file">
+                <option value="False">Provide one or more RNA sequences as text</option>
+                <option value="True">Provide a FASTA file</option>
+            </param>
+            <when value="False">
+                <param name="rna_input_string" label="Sequence(s) to fold" type="text" value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA" help="Enter RNA sequences. Separate multiple RNA sequences by commas.">
+                    <sanitizer>
+                        <valid>
+                            <add value="ACGUacgu,"/>
+                        </valid>
+                    </sanitizer>
+                </param>
+            </when>
+            <when value="True">
+                <param format="fasta" name="fasta_input" type="data" label="Sequence to fold (FASTA file)"/>
+            </when>
+        </conditional>
+    </inputs>
+    <outputs>
+        <data name="output" format="txt" label="output"/>
+    </outputs>
+    <tests>
+        <test>
+            <param name="input_type" value="False"/>
+            <param name="rna_input_string" value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA"/>
+            <output name="output" file="rna_2d_pred_text.txt"/>
+        </test>
+        <test>
+            <param name="input_type" value="True"/>
+            <param name="fasta_input" value="fasta_input1.fa"/>
+            <output name="output" file="rna_2d_pred_FASTA.txt"/>
+        </test>
+    </tests>
+    <help><![CDATA[
+**RNAformer**
+
+This tool reads RNA sequences and predicts their secondary structure using RNAformer.
+
+**Input format**
+
+RNAformer requires one or more RNA sequences either as a single FASTA file or as plain text.
+Only the bases A, C, G, U and N are accepted (upper or lower case).
+
+**Outputs**
+
+- Predicted secondary structure as a text file with the base pair positions:
+  for each input sequence, a ``Pairing index 1`` and a ``Pairing index 2`` line whose
+  entries at the same list position are the 0-based indices of two paired bases.
+    ]]></help>
+    <expand macro="citations" />
+</tool>
\ No newline at end of file
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/macros.xml Thu Jul 11 20:56:23 2024 +0000
@@ -0,0 +1,16 @@
+<macros> <!-- Shared tokens and XML macros imported by infer_rnaformer.xml -->
+    <token name="@EXECUTABLE@">RNAformer</token> <!-- used as the tool's display name -->
+    <token name="@TOOL_VERSION@">1.0.0</token> <!-- used as the tool's version attribute -->
+    <token name="@profile@">22.05</token> <!-- NOTE(review): not referenced by infer_rnaformer.xml, which hard-codes profile="22.05" -->
+    <xml name="requirements">
+        <requirements>
+            <requirement type="package" version="0.0.1">rnaformer</requirement>
+            <yield/> <!-- requirements nested inside the expand element (biopython in the tool) are inserted here -->
+        </requirements>
+    </xml>
+    <xml name="citations">
+        <citations>
+            <citation type="doi">10.1101/2024.02.12.579881</citation>
+        </citations>
+    </xml>
+</macros>
\ No newline at end of file
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test-data/fasta_input1.fa Thu Jul 11 20:56:23 2024 +0000 @@ -0,0 +1,4 @@ +>Anolis_caro_chrUn_GL343590.trna2_AlaAGC (218800-218872) Ala (AGC) 73 bp Sc: 49.55 +UGGGAAUUAGCUCAAAUGGUAGAGCGCUCGCUUAGCAUGUGAGAGGUAGUGGGAUCGAUGCCCACAUUCUCCA +>Anolis_caro_chrUn_GL343207.trna3_AlaAGC (1513626-1513698) Ala (AGC) 73 bp Sc: 56.15 +GGGGAAUUAGCUCAAAUGGUAGAGCGCUCGCUUAGCAUGCGAGAGGUAGCGGGAUUGAUGCCCGCAUUCUCCA
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test-data/rna_2d_pred_FASTA.txt Thu Jul 11 20:56:23 2024 +0000 @@ -0,0 +1,4 @@ +Pairing index 1: [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 8, 9, 9, 10, 11, 12, 13, 13, 17, 18, 20, 21, 21, 22, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 48, 49, 50, 51, 52, 53, 54, 55, 57, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71] +Pairing index 2: [71, 70, 69, 68, 67, 66, 65, 13, 14, 20, 47, 22, 24, 44, 23, 22, 21, 7, 20, 54, 55, 7, 12, 45, 8, 11, 10, 9, 43, 42, 41, 40, 39, 38, 37, 36, 32, 31, 30, 29, 28, 27, 26, 25, 9, 21, 64, 63, 62, 61, 60, 57, 17, 18, 53, 52, 51, 50, 49, 48, 6, 5, 4, 3, 2, 1, 0] +Pairing index 1: [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 8, 9, 9, 10, 11, 12, 13, 13, 14, 20, 20, 21, 21, 22, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 48, 49, 50, 51, 52, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71] +Pairing index 2: [71, 70, 69, 68, 67, 66, 65, 13, 14, 20, 22, 24, 44, 23, 22, 21, 7, 20, 7, 7, 13, 12, 45, 8, 11, 10, 9, 43, 42, 41, 40, 39, 38, 37, 32, 31, 30, 29, 28, 27, 26, 25, 9, 21, 64, 63, 62, 61, 60, 52, 51, 50, 49, 48, 6, 5, 4, 3, 2, 1, 0]
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test-data/rna_2d_pred_text.txt Thu Jul 11 20:56:23 2024 +0000 @@ -0,0 +1,2 @@ +Pairing index 1: [0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9, 10, 11, 12, 13, 17, 18, 21, 21, 22, 23, 23, 24, 25, 39, 40, 42, 43, 44, 48, 49, 50, 53, 54, 55, 56, 57, 58, 59, 60, 62, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76] +Pairing index 2: [76, 75, 74, 73, 72, 71, 70, 13, 21, 23, 25, 24, 23, 22, 7, 59, 60, 7, 13, 12, 8, 11, 10, 9, 30, 29, 50, 49, 48, 44, 43, 42, 69, 68, 67, 66, 65, 62, 17, 18, 58, 57, 56, 55, 54, 53, 6, 5, 4, 3, 2, 1, 0]