Mercurial > repos > rnateam > infer_rnaformer
view infer_rnaformer.xml @ 0:02b0ecc34d9a draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/rna_tools/rnaformer commit ee837e8d27a53baa3d4881412d4fbc566ae06499
author | rnateam |
---|---|
date | Thu, 11 Jul 2024 20:56:23 +0000 |
parents | |
children |
line wrap: on
line source
<tool id="infer_rnaformer" name="@EXECUTABLE@" version="@TOOL_VERSION@" profile="22.05"> <description>Predict the secondary structure of an RNA with RNAformer</description> <macros> <import>macros.xml</import> </macros> <expand macro="requirements"> <requirement type="package" version="1.83">biopython</requirement> </expand> <command detect_errors="exit_code"><![CDATA[ mkdir -p './model' && wget -O './model/RNAformer_32M_state_dict_intra_family_finetuned.pth' 'https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_state_dict_intra_family_finetuned.pth' && wget -O './model/RNAformer_32M_config_intra_family_finetuned.yml' 'https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_config_intra_family_finetuned.yml' && python '$script_file' > '$output' ]]></command> <configfiles> <configfile name="script_file"><![CDATA[import RNAformer import os import argparse import torch import urllib.request import logging from collections import defaultdict import torch.cuda import loralib as lora from RNAformer.model.RNAformer import RiboFormer from RNAformer.utils.configuration import Config from Bio import SeqIO import logging import sys def is_valid_rna_sequence(sequence): """Check if the sequence contains only RNA bases.""" valid_bases = {'A', 'C', 'G', 'U', 'N'} # Include 'N' if unknown bases are allowed return all(base in valid_bases for base in sequence.upper()) config_file_path = 'model/RNAformer_32M_config_intra_family_finetuned.yml' model_file_path = 'model/RNAformer_32M_state_dict_intra_family_finetuned.pth' config = Config(config_file=config_file_path) config.RNAformer.cycling = 6 model = RiboFormer(config.RNAformer) state_dict_file = model_file_path #if str($input_type.input_type) == 'True' fasta_path = '$input_type.fasta_input' sequences = [str(record.seq) for record in SeqIO.parse(fasta_path, 'fasta')] #else sequence_string = "$input_type.rna_input_string" sequences = [seq.strip() for seq in sequence_string.split(',')] #end if for seq in sequences: if not is_valid_rna_sequence(seq): print(f"Invalid RNA sequence detected: {seq}. Please ensure only RNA sequences are used as input.", file=sys.stderr) sys.exit(1) lora_config = { "r": config.r, "lora_alpha": config.lora_alpha, "lora_dropout": config.lora_dropout, } with torch.no_grad(): for name, module in model.named_modules(): if any(replace_key in name for replace_key in config.replace_layer): parent = model.get_submodule(".".join(name.split(".")[:-1])) target_name = name.split(".")[-1] target = model.get_submodule(name) if isinstance(target, torch.nn.Linear) and "qkv" in name: new_module = lora.MergedLinear(target.in_features, target.out_features, bias=target.bias is not None, enable_lora=[True, True, True], **lora_config) new_module.weight.copy_(target.weight) if target.bias is not None: new_module.bias.copy_(target.bias) elif isinstance(target, torch.nn.Linear): new_module = lora.Linear(target.in_features, target.out_features, bias=target.bias is not None, **lora_config) new_module.weight.copy_(target.weight) if target.bias is not None: new_module.bias.copy_(target.bias) elif isinstance(target, torch.nn.Conv2d): kernel_size = target.kernel_size[0] new_module = lora.Conv2d(target.in_channels, target.out_channels, kernel_size, padding=(kernel_size - 1) // 2, bias=target.bias is not None, **lora_config) new_module.conv.weight.copy_(target.weight) if target.bias is not None: new_module.conv.bias.copy_(target.bias) setattr(parent, target_name, new_module) state_dict = torch.load(state_dict_file, map_location=torch.device('cpu')) model.load_state_dict(state_dict, strict=True) model_name = state_dict_file.split(".pth")[0] if torch.cuda.is_available(): model = model.cuda() # check GPU can do bf16 if torch.cuda.is_bf16_supported(): model = model.bfloat16() else: model = model.half() model.eval() predicted_structures = [] for sequence in sequences: with torch.no_grad(): device = "cpu" seq_vocab = ['A', 'C', 'G', 'U', 'N'] seq_stoi = dict(zip(seq_vocab, range(len(seq_vocab)))) pdb_sample = 1 length = len(sequence) src_seq = torch.LongTensor(list(map(seq_stoi.get, sequence))) sample = {} sample['src_seq'] = src_seq.clone() sample['length'] = torch.LongTensor([length])[0] sample['pdb_sample'] = torch.LongTensor([pdb_sample])[0] sequence = sample['src_seq'].unsqueeze(0).to(device) src_len = torch.LongTensor([sequence.shape[-1]]).to(device) pdb_sample = torch.FloatTensor([[1]]).to(device) logits, pair_mask = model(sequence, src_len, pdb_sample) pred_mat = torch.sigmoid(logits[0, :, :, -1]) > 0.5 pos_id = torch.where(pred_mat == True) pos1_id = pos_id[0].cpu().tolist() pos2_id = pos_id[1].cpu().tolist() predicted_structure = f"Pairing index 1: {pos1_id} \nPairing index 2: {pos2_id}" print(predicted_structure) seqlen = len(sample['src_seq']) dot_bracket =['.'] * seqlen for i in range(len(pos1_id)): open_index = pos1_id[i] close_index = pos2_id[i] if 0 <= open_index < len(dot_bracket) and 0 <= close_index < len(dot_bracket): if dot_bracket[open_index] == '.' and dot_bracket[close_index] == '.': dot_bracket[open_index] = '(' dot_bracket[close_index] = ')' dot_bracket_str_pred = ''.join(dot_bracket) ]]></configfile> </configfiles> <inputs> <conditional name="input_type"> <param name="input_type" type="select" label="Input from FASTA file"> <option value="False">Provide a single RNA sequence string as text</option> <option value="True">Provide a FASTA file</option> </param> <when value="False"> <param name="rna_input_string" label="Sequence(s) to fold" type="text" value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA" help="Enter RNA sequences. Separate multiple RNA sequences by commas."> <sanitizer> <valid> <add value="ACGUacgu,"/> </valid> </sanitizer> </param> </when> <when value="True"> <param format="fasta" name="fasta_input" type="data" label="Sequence to fold (FASTA file)"/> </when> </conditional> </inputs> <outputs> <data name="output" format="txt" label="output"/> </outputs> <tests> <test> <param name="input_type" value="False"/> <param name="rna_input_string" value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA"/> <output name="output" file="rna_2d_pred_text.txt"/> </test> <test> <param name="input_type" value="True"/> <param name="fasta_input" value="fasta_input1.fa"/> <output name="output" file="rna_2d_pred_FASTA.txt"/> </test> </tests> <help><![CDATA[ **RNAformer** This tool reads RNA sequences and predicts their secondary structure using RNAformer. **Input format** RNAformer requires one or more RNA sequences either as a single FASTA file or as plain text. **Outputs** - Predicted secondary structure as a text file in the following formats: - base pair positions - dot-bracket notation ]]></help> <expand macro="citations" /> </tool>