Mercurial > repos > nicolas > oghma
comparison prediction.R @ 83:eab9bce19e04 draft
Uploaded
author | nicolas |
---|---|
date | Fri, 28 Oct 2016 08:47:48 -0400 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
82:44386547d0f8 | 83:eab9bce19e04 |
---|---|
1 ######################################################## | |
2 # | |
3 # creation date : 26/01/16 | |
4 # last modification : 02/06/16 | |
5 # author : Dr Nicolas Beaume | |
6 # owner : IRRI | |
7 # | |
8 ######################################################## | |
9 | |
10 library(rrBLUP) | |
11 suppressWarnings(suppressMessages(library(randomForest))) | |
12 library(e1071) | |
13 suppressWarnings(suppressMessages(library(glmnet))) | |
14 library(methods) | |
15 | |
16 | |
############################ main #############################
# Model classes this script knows how to score. A plain "list" is treated
# as an rrBLUP fit (its $u and $beta components are used further down); the
# other names are matched against class(model) with inherits().
classifierNames <- c("list", "randomForest", "svm", "glmnet")

# load argument
# The first trailing argument is an R file that is sourced here.
# NOTE(review): it is expected to define `genotype`, `model` and `out`
# (inferred from their use in this script) - confirm against the tool wrapper.
cmd <- commandArgs(trailingOnly = TRUE)
if (length(cmd) < 1L) {
  stop("missing argument: path to the configuration file to source",
       call. = FALSE)
}
source(cmd[1])

# load data
# `genotype` and `model` first name small files whose first line holds the
# path of the real data file. readLines() given a path opens and closes the
# connection itself, so nothing is leaked if reading fails.
genotype <- readLines(genotype, n = 1L, ok = TRUE)
genotype <- read.table(genotype, sep = "\t", header = TRUE)
model <- readLines(model, n = 1L, ok = TRUE)
model <- readRDS(model)

# check if the classifier name is valid
# inherits() is TRUE when any class of `model` is among classifierNames.
# Collapse both vectors so the error is a single readable sentence instead
# of one repeated fragment per classifier name.
if (!inherits(model, classifierNames)) {
  stop(paste0(paste(class(model), collapse = "/"),
              " is not recognized as a valid model. Valid models are : ",
              paste(classifierNames, collapse = ", ")),
       call. = FALSE)
}
# run prediction according to the classifier
# Every branch produces a two-column data.frame: `lines` (row identifiers)
# and `predictions` (predicted values), which is written out as a TSV.
if (inherits(model, "list")) {
  # rrBLUP prediction: genotype matrix times marker effects plus beta.
  # NOTE(review): assumes model[["beta"]] is a single intercept; a longer
  # beta would be recycled silently - confirm how the model was fitted.
  predictions <- as.matrix(genotype) %*% as.matrix(model[["u"]])
  predictions <- predictions[, 1] + model[["beta"]]
  predictions <- data.frame(lines = rownames(genotype),
                            predictions = predictions)
} else if (inherits(model, "glmnet")) {
  # LASSO prediction: glmnet needs a numeric matrix and returns a matrix
  # (presumably one column per lambda value).
  predictions <- predict(model, as.matrix(genotype), type = "response")
  # A single-column matrix would otherwise give the output column the
  # matrix's own column name (e.g. "s0") instead of "predictions".
  if (ncol(predictions) == 1L) {
    predictions <- predictions[, 1]
  }
  predictions <- data.frame(lines = rownames(genotype),
                            predictions = predictions)
} else {
  # SVM or RandomForest prediction (predict is a generic dispatching on the
  # model class); the returned vector is expected to be named by line.
  predictions <- predict(model, genotype)
  predictions <- data.frame(lines = names(predictions),
                            predictions = predictions)
}

# save results
write.table(predictions, file = out, sep = "\t", row.names = FALSE)
52 |