A Grid Search for The Optimal Setting in Feed-Forward Neural Networks
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.
The feed-forward neural network is a very powerful classification model in the machine learning context. Since the goodness-of-fit of a neural network is largely determined by the model complexity, it is very tempting for a modeler to over-parameterize the network by using too many hidden layers and/or hidden units.
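To make "model complexity" concrete, here is a small sketch that counts the weights in a network with a single hidden layer, one output and optional skip-layer (direct input-to-output) connections, which is the architecture nnet::nnet fits when skip = TRUE. The helper n_weights is hypothetical, the counting layout is my assumption, and p = 10 inputs is assumed to match the ten predictors in the credit example.

```r
# Hypothetical helper: number of weights in a single-hidden-layer network
# with one output. Assumed layout:
#   input -> hidden: p weights + 1 bias for each hidden unit
#   hidden -> output: one weight per hidden unit + 1 output bias
#   skip layer (optional): one direct weight per input
n_weights <- function(p, size, skip = TRUE) {
  input_hidden  <- (p + 1) * size
  hidden_output <- size + 1
  skip_layer    <- if (skip) p else 0
  input_hidden + hidden_output + skip_layer
}

sizes <- c(1, 5, 10, 20)
sapply(sizes, n_weights, p = 10)   # 23 71 131 251
```

Going from 1 to 20 hidden units multiplies the number of free parameters by more than ten, which is why some form of regularization becomes important.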
As pointed out by Brian Ripley in his famous book “Modern Applied Statistics with S”, the complexity of a neural network can be regulated by a hyper-parameter called “weight decay”, which penalizes large network weights. Per Ripley, weight decay can both help the optimization process and avoid over-fitting.
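For reference, the penalized entropy criterion can be sketched as below. This assumes it matches what nnet minimizes with entropy = TRUE and decay = λ (the fit criterion plus λ times the sum of squared weights): t_i is the observed 0/1 outcome, y_i(w) is the fitted probability for case i, and the penalty sums over the network weights w_j.

```latex
E_\lambda(w) = -\sum_{i=1}^{N}\Big[\, t_i \log y_i(w) + (1 - t_i)\log\big(1 - y_i(w)\big) \Big] + \lambda \sum_{j} w_j^{2}
```

A larger λ pulls the weights toward zero and gives a smoother, less complex fit; λ = 0 recovers the unpenalized entropy fit.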
At this point, it is clear that the optimal setting for a neural network should balance the network complexity against the size of the weight decay. The only remaining question is how to identify such a combination. In practice, practitioners usually use either v-fold cross-validation or cross-sample validation, in which a single hold-out sample is set aside for testing. Given the high computing cost of fitting neural networks, cross-sample validation is more efficient than v-fold cross-validation, which must refit every network v times. In addition, because the fitting can get stuck in local minima, the validation result from an average of several models trained from different random starts is considered more reliable than the result from a single model.
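The cost argument can be made concrete with a little arithmetic. The sketch below counts network fits for the grid used in the example (3 decay values, 4 sizes, 10 nets per setting) under a single hold-out split versus 10-fold cross-validation, and shows one common way to assign v-fold labels. Both n_fits and assign_folds are hypothetical helpers for illustration only.

```r
# Hypothetical helper: how many network fits a grid search requires.
n_fits <- function(n_decay, n_size, n_nets, v = 1) {
  # v = 1: a single hold-out (cross-sample) split, each net fitted once.
  # v > 1: v-fold cross-validation, each net refitted once per fold.
  n_decay * n_size * n_nets * v
}

n_fits(3, 4, 10)          # hold-out: 120 fits
n_fits(3, 4, 10, v = 10)  # 10-fold cross-validation: 1200 fits

# Sketch of v-fold assignment: every row receives a fold label in 1..v,
# and each fold is held out once while the nets train on the others.
assign_folds <- function(n, v) sample(rep(seq_len(v), length.out = n))

set.seed(1)
folds <- assign_folds(25, 5)
table(folds)              # 5 rows in each of the 5 folds
```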
The example below shows a grid search strategy for the optimal setting in a neural network using cross-sample validation. As suggested by Ripley, a weight decay roughly between 0.01 and 0.1 is appropriate for the entropy fit. For simplicity, only a few numbers of hidden units are tried. However, given enough computing power, a finer grid search over combinations of weight decay and the number of hidden units is highly recommended.
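A compact way to enumerate such a grid in base R is expand.grid. The sketch below builds the coarse 3 x 4 grid used in the example and a finer grid over the same ranges; the finer step sizes are illustrative assumptions, not values recommended by Ripley.

```r
# The coarse grid used in the example: 3 decay values x 4 hidden-layer sizes.
coarse <- expand.grid(decay = c(0.01, 0.05, 0.1), size = c(1, 5, 10, 20))
nrow(coarse)   # 12 settings

# A finer grid over the same ranges (step sizes are illustrative assumptions).
fine <- expand.grid(decay = seq(0.01, 0.1, length.out = 10), size = 1:20)
nrow(fine)     # 200 settings
```

Each row of such a grid can then drive one iteration of the search, e.g. with `for (k in seq_len(nrow(coarse)))`, reading `coarse$decay[k]` and `coarse$size[k]`, instead of two nested loops.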
```r
# Requires the nnet and verification packages.
# DATA PREPARATIONS
df1 <- read.csv('credit_count.csv')
df2 <- df1[df1$CARDHLDR == 1, 2:12]
X <- I(as.matrix(df2[-1]))
st.X <- scale(X)
Y <- I(as.matrix(df2[1]))
# I() keeps the standardized predictors as a single matrix column named X
df3 <- data.frame(X = I(st.X), Y = Y)

# DIVIDE DATA INTO TESTING AND TRAINING SETS (1000 ROWS HELD OUT FOR TESTING)
set.seed(2013)
rows <- sample(1:nrow(df3), nrow(df3) - 1000)
set1 <- df3[rows, ]
set2 <- df3[-rows, ]

result <- NULL
n_nets <- 10
# SEARCH FOR OPTIMAL WEIGHT DECAY
for (w in c(0.01, 0.05, 0.1)) {
  # SEARCH FOR OPTIMAL NUMBER OF HIDDEN UNITS
  for (n in c(1, 5, 10, 20)) {
    # CREATE A VECTOR OF RANDOM SEEDS
    rv <- round(runif(n_nets) * 100)
    # FOR EACH SETTING, RUN NEURAL NET MULTIPLE TIMES
    for (i in 1:n_nets) {
      # INITIATE THE RANDOM STATE FOR EACH NET
      set.seed(rv[i])
      # TRAIN NEURAL NETS
      net <- nnet::nnet(Y ~ X, size = n, data = set1, entropy = TRUE, maxit = 1000,
                        decay = w, skip = TRUE, trace = FALSE)
      # COLLECT PREDICTIONS TO DO MODEL AVERAGING
      if (i == 1) prob <- predict(net, set2) else prob <- prob + predict(net, set2)
    }
    # CALCULATE AREA UNDER CURVE OF THE MODEL AVERAGING PREDICTION
    roc <- verification::roc.area(set2$Y, prob / n_nets)[1]
    # COLLECT RESULTS
    result <- rbind(result, c(w, n, roc, round(mean(prob / n_nets), 4), round(mean(set2$Y), 4)))
  }
}
result2 <- data.frame(wt_decay = unlist(result[, 1]), n_units = unlist(result[, 2]),
                      auc = unlist(result[, 3]), pred_rate = unlist(result[, 4]),
                      obsv_rate = unlist(result[, 5]))
result2[order(result2$auc, decreasing = TRUE), ]
```

```
   wt_decay n_units       auc pred_rate obsv_rate
1      0.01       1 0.6638209    0.0923     0.095
9      0.10       1 0.6625414    0.0923     0.095
5      0.05       1 0.6557022    0.0922     0.095
3      0.01      10 0.6530154    0.0938     0.095
8      0.05      20 0.6528293    0.0944     0.095
6      0.05       5 0.6516662    0.0917     0.095
2      0.01       5 0.6498284    0.0928     0.095
7      0.05      10 0.6456063    0.0934     0.095
4      0.01      20 0.6446176    0.0940     0.095
10     0.10       5 0.6434545    0.0927     0.095
12     0.10      20 0.6415935    0.0938     0.095
11     0.10      10 0.6348822    0.0928     0.095
```
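The AUC used to rank the settings can also be computed in base R from the ranks of the predictions (the Mann-Whitney formulation), which removes the dependency on the verification package. The sketch below is a swapped-in technique, not the roc.area implementation; for a binary outcome it should estimate the same area, although tie handling may differ. auc_rank is a hypothetical helper name.

```r
# Rank-based (Mann-Whitney) estimate of the area under the ROC curve.
# obs: 0/1 outcomes; pred: predicted probabilities.
auc_rank <- function(obs, pred) {
  obs  <- as.vector(obs)    # accepts 1-column matrices such as set2$Y
  pred <- as.vector(pred)
  n1 <- sum(obs == 1)       # number of events
  n0 <- sum(obs == 0)       # number of non-events
  r  <- rank(pred)          # tied predictions receive average ranks
  (sum(r[obs == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

auc_rank(c(0, 0, 1, 1), c(0.1, 0.4, 0.35, 0.8))   # 0.75
```

Inside the search loop, `auc_rank(set2$Y, prob / n_nets)` could then stand in for the roc.area call.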