view infer_rnaformer.xml @ 0:02b0ecc34d9a draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/rna_tools/rnaformer commit ee837e8d27a53baa3d4881412d4fbc566ae06499
author rnateam
date Thu, 11 Jul 2024 20:56:23 +0000
parents
children
line wrap: on
line source

<tool id="infer_rnaformer"
      name="@EXECUTABLE@"
      version="@TOOL_VERSION@"
      profile="22.05">
    <description>Predict the secondary structure of an RNA with RNAformer</description>
    <macros>
        <import>macros.xml</import>
    </macros>
    <!-- Base requirements come from the "requirements" macro; biopython is added for FASTA parsing (Bio.SeqIO). -->
    <expand macro="requirements">
        <requirement type="package" version="1.83">biopython</requirement>
    </expand>
    <command detect_errors="exit_code"><![CDATA[
## Fetch the fine-tuned RNAformer checkpoint and its config into ./model (the job
## needs outbound network access for this), then run the generated inference
## script; its stdout becomes the tool output.
mkdir -p './model' &&
wget -O './model/RNAformer_32M_state_dict_intra_family_finetuned.pth' 'https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_state_dict_intra_family_finetuned.pth' &&
wget -O './model/RNAformer_32M_config_intra_family_finetuned.yml' 'https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_config_intra_family_finetuned.yml' &&
python '$script_file' > '$output'
    ]]></command>
<configfiles>
        <configfile name="script_file"><![CDATA[import RNAformer
import os
import argparse
import torch
import urllib.request
import logging
from collections import defaultdict
import torch.cuda
import loralib as lora
from RNAformer.model.RNAformer import RiboFormer
from RNAformer.utils.configuration import Config

from Bio import SeqIO

import logging
import sys

def is_valid_rna_sequence(sequence):
    """Tell whether ``sequence`` is made only of RNA bases.

    Accepted characters are A, C, G, U and the ambiguity code N, compared
    case-insensitively. An empty string is considered valid.
    """
    allowed_bases = frozenset('ACGUN')
    return set(sequence.upper()) <= allowed_bases

# Files written into ./model by the wget calls of the tool command.
model_dir = 'model'
config_file_path = model_dir + '/RNAformer_32M_config_intra_family_finetuned.yml'
model_file_path = model_dir + '/RNAformer_32M_state_dict_intra_family_finetuned.pth'
state_dict_file = model_file_path

# Build the network from the downloaded config, overriding RNAformer.cycling to 6
# before the model is instantiated.
config = Config(config_file=config_file_path)
config.RNAformer.cycling = 6
model = RiboFormer(config.RNAformer)

#if str($input_type.input_type) == 'True'
fasta_path = '$input_type.fasta_input'
# Upper-case so lower-case records map onto the model vocabulary (A, C, G, U, N);
# records with an empty sequence are skipped.
sequences = [str(record.seq).upper() for record in SeqIO.parse(fasta_path, 'fasta') if len(record.seq) > 0]
#else
sequence_string = "$input_type.rna_input_string"
# Comma-separated list: surrounding whitespace and empty entries (e.g. from a
# trailing comma) are dropped, and bases are upper-cased like the FASTA branch.
sequences = [seq.strip().upper() for seq in sequence_string.split(',') if seq.strip()]
#end if

# Fail loudly instead of producing an empty output dataset.
if not sequences:
    print("No RNA sequences found in the input.", file=sys.stderr)
    sys.exit(1)

# Abort the whole job on the first sequence that contains anything other than A/C/G/U/N.
seq = next((candidate for candidate in sequences if not is_valid_rna_sequence(candidate)), None)
if seq is not None:
    print(f"Invalid RNA sequence detected: {seq}. Please ensure only RNA sequences are used as input.", file=sys.stderr)
    sys.exit(1)

lora_config = {
    "r": config.r,
    "lora_alpha": config.lora_alpha,
    "lora_dropout": config.lora_dropout,
}


def _lora_replacement(target, name):
    """Return a loralib counterpart of ``target`` carrying its weights, or None.

    ``torch.nn.Linear`` layers whose name contains "qkv" become ``lora.MergedLinear``
    with enable_lora=[True, True, True], other ``Linear`` layers become ``lora.Linear``,
    and ``Conv2d`` layers become ``lora.Conv2d`` with padding (k - 1) // 2 taken from
    the first kernel dimension. Any other module type has no replacement and yields None.
    Call under ``torch.no_grad()``: the original parameters are copied in place.
    """
    if isinstance(target, torch.nn.Linear):
        has_bias = target.bias is not None
        if "qkv" in name:
            new_module = lora.MergedLinear(target.in_features, target.out_features,
                                           bias=has_bias,
                                           enable_lora=[True, True, True], **lora_config)
        else:
            new_module = lora.Linear(target.in_features, target.out_features,
                                     bias=has_bias, **lora_config)
        new_module.weight.copy_(target.weight)
        if has_bias:
            new_module.bias.copy_(target.bias)
        return new_module

    if isinstance(target, torch.nn.Conv2d):
        kernel_size = target.kernel_size[0]
        new_module = lora.Conv2d(target.in_channels, target.out_channels, kernel_size,
                                 padding=(kernel_size - 1) // 2, bias=target.bias is not None,
                                 **lora_config)
        new_module.conv.weight.copy_(target.weight)
        if target.bias is not None:
            new_module.conv.bias.copy_(target.bias)
        return new_module

    return None


with torch.no_grad():
    # Snapshot the module list so replacing children does not mutate the tree being walked.
    for name, module in list(model.named_modules()):
        if not any(replace_key in name for replace_key in config.replace_layer):
            continue
        new_module = _lora_replacement(module, name)
        if new_module is None:
            # Name matched but the type is unsupported (e.g. a container): keep it as is
            # instead of re-attaching a module built on a previous iteration.
            continue
        parent = model.get_submodule(".".join(name.split(".")[:-1]))
        setattr(parent, name.split(".")[-1], new_module)

# NOTE(review): torch.load unpickles the checkpoint that the tool command downloads
# over the network; consider weights_only=True if the pinned torch version supports it.
state_dict = torch.load(state_dict_file, map_location=torch.device('cpu'))

# strict=True: the checkpoint keys must match the LoRA-wrapped module tree exactly.
model.load_state_dict(state_dict, strict=True)

# Use the GPU when present, in bf16 if the device supports it and fp16 otherwise.
if torch.cuda.is_available():
    model = model.cuda()
    model = model.bfloat16() if torch.cuda.is_bf16_supported() else model.half()

model.eval()
predicted_structures = []

# Nucleotide vocabulary; the list order defines the token ids fed to the model.
seq_vocab = ['A', 'C', 'G', 'U', 'N']
seq_stoi = dict(zip(seq_vocab, range(len(seq_vocab))))

# Inputs must match the weights' device and dtype: when CUDA is available the model
# was moved there (and cast to bf16/fp16) above, so a hard-coded "cpu" would break.
model_param = next(model.parameters())
device = model_param.device


def _to_dot_bracket(pos1_id, pos2_id, length):
    """Render predicted pairs as a dot-bracket string of ``length`` characters.

    Pairs are visited in the given order. A pair is skipped when either index is out
    of range or either base is already paired, so each base keeps at most one partner.
    """
    dot_bracket = ['.'] * length
    for open_index, close_index in zip(pos1_id, pos2_id):
        if 0 <= open_index < length and 0 <= close_index < length:
            if dot_bracket[open_index] == '.' and dot_bracket[close_index] == '.':
                dot_bracket[open_index] = '('
                dot_bracket[close_index] = ')'
    return ''.join(dot_bracket)


for sequence in sequences:
    with torch.no_grad():
        # Upper-case so lower-case input maps onto the vocabulary instead of None.
        token_ids = [seq_stoi[base] for base in sequence.upper()]
        seq_tensor = torch.LongTensor(token_ids).unsqueeze(0).to(device)
        src_len = torch.LongTensor([seq_tensor.shape[-1]]).to(device)
        # NOTE(review): the pdb flag used to be a float32 tensor; it is created in the
        # model's dtype so a bf16/fp16 model is not fed float32 - confirm against RiboFormer.
        pdb_sample = torch.ones((1, 1), dtype=model_param.dtype, device=device)

        logits, pair_mask = model(seq_tensor, src_len, pdb_sample)

        # Assumes logits is (batch, L, L, channels) with the pairing logit last - TODO confirm.
        pred_mat = torch.sigmoid(logits[0, :, :, -1]) > 0.5

        # 0-based (row, column) indices of every predicted pair.
        pos1_tensor, pos2_tensor = torch.where(pred_mat)
        pos1_id = pos1_tensor.cpu().tolist()
        pos2_id = pos2_tensor.cpu().tolist()
        print(f"Pairing index 1: {pos1_id} \nPairing index 2: {pos2_id}")

        # The dot-bracket string is part of the documented output.
        # NOTE: expected files in test-data must contain this extra line.
        dot_bracket_str_pred = _to_dot_bracket(pos1_id, pos2_id, len(sequence))
        predicted_structures.append(dot_bracket_str_pred)
        print(f"Dot-bracket: {dot_bracket_str_pred}")
]]></configfile>
    </configfiles>
    <inputs>
        <!-- Sequences come either from a comma-separated text field or from a FASTA dataset. -->
        <conditional name="input_type">
            <param name="input_type" type="select" label="Input from FASTA file">
                <option value="False">Provide a single RNA sequence string as text</option>
                <option value="True">Provide a FASTA file</option>
            </param>
            <when value="False">
                <!-- This value is interpolated into a double-quoted Python string in the script. -->
                <param name="rna_input_string"
                       type="text"
                       label="Sequence(s) to fold"
                       value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA"
                       help="Enter RNA sequences. Separate multiple RNA sequences by commas.">
                    <sanitizer>
                        <valid>
                            <add value="ACGUacgu,"/>
                        </valid>
                    </sanitizer>
                </param>
            </when>
            <when value="True">
                <param name="fasta_input" type="data" format="fasta" label="Sequence to fold (FASTA file)"/>
            </when>
        </conditional>
    </inputs>
    <outputs>
        <!-- One text dataset holding the predictions for every input sequence. -->
        <data name="output" format="txt" label="${tool.name} on ${on_string}: predicted structures"/>
    </outputs>
    <tests>
        <!-- Sequence supplied through the text field. -->
        <test>
            <conditional name="input_type">
                <param name="input_type" value="False"/>
                <param name="rna_input_string" value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA"/>
            </conditional>
            <output name="output" file="rna_2d_pred_text.txt"/>
        </test>
        <!-- Sequences read from a FASTA dataset. -->
        <test>
            <conditional name="input_type">
                <param name="input_type" value="True"/>
                <param name="fasta_input" value="fasta_input1.fa"/>
            </conditional>
            <output name="output" file="rna_2d_pred_FASTA.txt"/>
        </test>
    </tests>
    <help><![CDATA[
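**RNAformer**

This tool reads one or more RNA sequences and predicts their secondary structure using RNAformer.

**Input format**

RNAformer requires one or more RNA sequences, provided either as a single FASTA file or as plain text. In the text field, separate multiple sequences with commas. Sequences may contain only the bases A, C, G, U and N; any other character makes the job fail.

**Outputs**

A text file with the predicted secondary structure of each sequence, in the following formats:

- base pair positions (two lists of 0-based indices; the i-th entries of both lists form one pair)
- dot-bracket notation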
    ]]></help>
    <!-- Citations are provided by the "citations" macro from the imported macros.xml. -->
    <expand macro="citations"/>
</tool>