comparison prediction.R @ 83:eab9bce19e04 draft

Uploaded
author nicolas
date Fri, 28 Oct 2016 08:47:48 -0400
parents
children
comparison
equal deleted inserted replaced
82:44386547d0f8 83:eab9bce19e04
1 ########################################################
2 #
3 # creation date : 26/01/16
4 # last modification : 02/06/16
5 # author : Dr Nicolas Beaume
6 # owner : IRRI
7 #
8 ########################################################
9
10 library(rrBLUP)
11 suppressWarnings(suppressMessages(library(randomForest)))
12 library(e1071)
13 suppressWarnings(suppressMessages(library(glmnet)))
14 library(methods)
15
16
############################ main #############################
# Classes of model objects this script can predict with. "list" is the plain
# list used by the rrBLUP branch below (it reads model$u and model$beta); the
# other names are matched against class(model) to pick a predict() method.
classifierNames <- c("list", "randomForest", "svm", "glmnet")

# Return the first line of `pointerFile`.
# The `genotype` and `model` values are not the data files themselves but
# files whose first line holds the path of the actual data file.
# readLines() given a path opens and closes the file itself, so no
# connection is left open if reading fails.
readFirstLine <- function(pointerFile) {
  firstLine <- readLines(pointerFile, n = 1L)
  # An empty path would make read.table() read from stdin, so fail loudly.
  if (length(firstLine) == 0L || !nzchar(firstLine)) {
    stop("expected a file path on the first line of '", pointerFile, "'",
         call. = FALSE)
  }
  firstLine
}

# load arguments: the single command-line argument is an R file which, once
# sourced, is expected to define `genotype`, `model` and `out`
cmd <- commandArgs(trailingOnly = TRUE)
if (length(cmd) < 1L) {
  stop("usage: prediction.R <parameter file>", call. = FALSE)
}
source(cmd[1])

# load data: a tab-separated genotype table with a header row, and a model
# object previously saved with saveRDS()
genotype <- read.table(readFirstLine(genotype), sep = "\t", header = TRUE)
model <- readRDS(readFirstLine(model))
# check that the loaded model is one of the supported classifier types;
# class() can return more than one class name, so the model is accepted when
# any of them is in classifierNames
if (!any(class(model) %in% classifierNames)) {
  # Collapse both vectors to single strings. Pasting the vectors directly
  # recycled them into several messages that stop() ran together.
  stop(paste(class(model), collapse = "/"),
       " is not recognized as a valid model. Valid models are : ",
       paste(classifierNames, collapse = ", "),
       call. = FALSE)
}
# Predict every line with the method matching the model's class, collecting
# a two-column data.frame: line name and prediction.
if (inherits(model, "list")) {
  # rrBLUP: genotype matrix times model$u, plus the intercept model$beta
  scores <- as.matrix(genotype) %*% as.matrix(model$u)
  predictions <- data.frame(lines = rownames(genotype),
                            predictions = scores[, 1] + model$beta)
} else if (inherits(model, "glmnet")) {
  # LASSO: glmnet's predict() takes the predictors as a numeric matrix.
  # NOTE(review): without an `s` argument predict() on a glmnet fit returns
  # one column per stored lambda; confirm the saved model holds a single one.
  scores <- predict(model, as.matrix(genotype), type = "response")
  predictions <- data.frame(lines = rownames(genotype), predictions = scores)
} else {
  # SVM or random forest: the generic predict() dispatches on the model
  # class; the names of the returned predictions are used as line names
  scores <- predict(model, genotype)
  predictions <- data.frame(lines = names(scores), predictions = scores)
}

# save results as a tab-separated table without R row names
write.table(predictions, file = out, sep = "\t", row.names = FALSE)
52