Previous changeset 84:4eea5c2313d2 (2016-10-28) Next changeset 86:2212133e6a36 (2016-10-28) |
Commit message:
Uploaded |
added:
randomForest.R |
b |
diff -r 4eea5c2313d2 -r 94aa63659613 randomForest.R --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/randomForest.R Fri Oct 28 08:48:22 2016 -0400 |
[ |
########################################################
#
# creation date : 07/01/16
# last modification : 25/10/16
# author : Dr Nicolas Beaume
#
########################################################

suppressWarnings(suppressMessages(library(randomForest)))

############################ helper functions #######################

# optimize
# Choose the mtry value (number of variables tried at each split) giving the
# best mean accuracy (r2) over several random train/test splits.
#
# genotype  : predictor matrix, one row per individual
# phenotype : numeric response, one value per row of genotype
# ntree     : number of trees grown for each candidate model
# rangeMtry : candidate mtry values to evaluate
# repet     : number of random splits evaluated per candidate
#
# Returns the element of rangeMtry with the highest mean r2.
# NOTE: this definition masks stats::optimize inside this script.
optimize <- function(genotype, phenotype, ntree = 1000,
                     rangeMtry = seq(ceiling(ncol(genotype) / 5),
                                     ceiling(ncol(genotype) / 3),
                                     ceiling(ncol(genotype) / 100)),
                     repet = 3) {
  nInd <- nrow(genotype)
  # 1/3 of the dataset is used as test set in every repetition
  nTest <- ceiling(nInd / 3)
  # mean accuracy over the repetitions, one value per candidate mtry;
  # iteration order (mtry outer, repetition inner) matches the original loops
  # so random draws are consumed in the same sequence
  acc <- vapply(rangeMtry, function(mtry) {
    repAcc <- vapply(seq_len(repet), function(i) {
      indexTest <- sample.int(nInd, size = nTest)
      # drop = FALSE keeps matrices even with a single column or row
      train <- genotype[-indexTest, , drop = FALSE]
      test <- genotype[indexTest, , drop = FALSE]
      model <- randomForest(x = train, y = phenotype[-indexTest],
                            ntree = ntree, mtry = mtry)
      r2(phenotype[indexTest], predict(model, test))
    }, numeric(1))
    mean(repAcc)
  }, numeric(1))
  # mtry giving the best accuracy
  rangeMtry[which.max(acc)]
}

# r2
# Coefficient of determination computed with the classic formula: compares
# the sum of squared differences from target to prediction with the sum of
# squared differences from target to the mean of the target.
# Returns 1 - SSres / SStot (-Inf or NaN when target is constant).
r2 <- function(target, prediction) {
  sst <- sum((target - mean(target))^2)
  ssr <- sum((target - prediction)^2)
  1 - ssr / sst
}

# readFirstLine
# Return the first line of a text file. The input files given to this script
# contain, on their first line, the path of the actual data file.
# readLines() opens and closes the connection itself, so nothing leaks on error.
readFirstLine <- function(path) {
  readLines(path, n = 1L, ok = TRUE)
}

################################## main function ###########################

# rfSelection
# Fit random forest regression models of phenotype on genotype and save the
# result to "<outFile>.rds".
#
# genotype   : predictor matrix, one row per individual
# phenotype  : numeric response, one value per row of genotype
# folds      : list of row-index vectors, each one the test set of a
#              cross-validation fold (only used when evaluation is TRUE)
# outFile    : output path without the ".rds" extension
# evaluation : TRUE  -> save the list of per-fold test-set predictions
#              FALSE -> fit one model on all data and save that model
# mtry       : variables tried at each split; NULL triggers a search with
#              optimize()
# ntree      : number of trees
#
# Called for its side effect; returns NULL invisibly.
rfSelection <- function(genotype, phenotype, folds, outFile, evaluation = TRUE,
                        mtry = NULL, ntree = 1000) {
  # go for optimization
  if (is.null(mtry)) {
    nVar <- ncol(genotype)
    candidates <- seq(ceiling(nVar / 10), ceiling(nVar / 3),
                      ceiling(nVar / 100))
    mtry <- optimize(genotype = genotype, phenotype = phenotype,
                     ntree = ntree, rangeMtry = candidates)
  }
  rdsFile <- paste0(outFile, ".rds")
  if (evaluation) {
    # one unnamed list element of predictions per fold, in fold order
    prediction <- lapply(seq_along(folds), function(i) {
      testIndex <- folds[[i]]
      train <- genotype[-testIndex, , drop = FALSE]
      test <- genotype[testIndex, , drop = FALSE]
      rf <- randomForest(x = train, y = phenotype[-testIndex],
                         mtry = mtry, ntree = ntree)
      predict(rf, test)
    })
    # save predictions for all folds to be used for evaluation
    saveRDS(prediction, file = rdsFile)
  } else {
    # just compute the model and save it
    model <- randomForest(x = genotype, y = phenotype,
                          mtry = mtry, ntree = ntree)
    saveRDS(model, file = rdsFile)
  }
  invisible(NULL)
}

############################ main #############################
# load parameters: the first command-line argument is an R file expected to
# define mtry, ntree, doEvaluation, folds, genotype, phenotype and out
cmd <- commandArgs(trailingOnly = TRUE)
source(cmd[1])
# load classifier parameters; mtry == -1 requests an mtry search
mtry <- as.numeric(mtry)
ntree <- as.numeric(ntree)
if (mtry == -1) {
  mtry <- NULL
}
# check if evaluation is required; folds are only read in that case
evaluation <- as.integer(doEvaluation) == 1L
if (evaluation) {
  folds <- readRDS(readFirstLine(folds))
}
# load genotype (its path is on the first line of the given file)
# and phenotype (first column of the phenotype table)
genotype <- read.table(readFirstLine(genotype), sep = "\t", header = TRUE)
phenotype <- read.table(phenotype, sep = "\t", header = TRUE)[, 1]
# run !
rfSelection(genotype = data.matrix(genotype), phenotype = phenotype,
            evaluation = evaluation, outFile = out, folds = folds,
            mtry = mtry, ntree = ntree)
# send the path containing results to galaxy
cat(paste0(out, ".rds\n"))