Multilayer perceptron (MLP) is the simplest feed-forward neural network. It mitigates the constraints of the original perceptron, which was able to learn only linearly separable patterns from the data. It achieves this by introducing at least one hidden layer in order to learn a representation of the data that enables linear separation.
In the first layer the MLP applies linear transformations to the data point $x$:

$$f_j(x) = w_j^\top x + b_j \qquad \text{for } j = 1, \dots, J,$$
where $J$, the number of transformations, is the number of hidden nodes in the first hidden layer.
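As a minimal sketch in plain R (the weight matrix W and bias vector b below are hypothetical placeholders, not values from this post), the first layer is ordinary matrix algebra:

########## first layer sketch ###########
# Affine map of one data point: f_j(x) = w_j'x + b_j for j = 1, ..., J.
# W is a J x 2 matrix (one row of weights per hidden node), b has length J.
first_layer <- function(x, W, b) {
  as.vector(W %*% x + b)
}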
Next it applies a non-linear transformation to the outputs using a so-called activation function. Using a linear function as the activation function would defeat the purpose of the MLP, as a composition of linear transformations is still a linear transformation.
The most commonly used activation function is the rectifier:
$$\phi(x) = \max(0, x)$$
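In R the rectifier is a one-liner; a quick sketch:

############## relu sketch ##############
relu <- function(x) pmax(0, x)  # vectorised over its input
relu(c(-1.5, 0, 2))
## [1] 0 0 2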
Finally, the outputs of the activation function are again combined using a linear transformation:
$$h_k(x) = \sum_j \phi(f_j(x)) \cdot v_j^k + c_k$$
At this point one can either repeat the activation step, extending the network with another hidden layer, or apply a final transformation to the outputs to fit the algorithm's objective. For classification problems the most commonly used transformation is the softmax function:
$$\phi_k(h_1, \dots, h_K) = \frac{\exp(h_k)}{\sum_s \exp(h_s)}$$
which maps a real-valued vector to a vector of probabilities.
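A minimal R sketch of the softmax; subtracting the maximum before exponentiating is a standard numerical-stability trick, not something the formula above requires:

############ softmax sketch #############
softmax <- function(h) {
  e <- exp(h - max(h))  # shift for numerical stability
  e / sum(e)
}
sum(softmax(c(1, 2, 3)))  # probabilities sum to one
## [1] 1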
For classification problems the most commonly used loss function is the cross-entropy between the class label $y \in \{1, \dots, K\}$ and the probability returned by the softmax function:

$$L(\phi, y) = -\sum_{k=1}^{K} \mathbf{1}_{\{y = k\}} \log(\phi_k)$$
which is averaged over all training observations.
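For a single observation the sum collapses to the negative log-probability assigned to the true class. A sketch, assuming y is a 1-based class index and p the probability vector returned by the softmax:

######### cross-entropy sketch ##########
cross_entropy <- function(p, y) -log(p[y])
# averaged over a training set: rows of P are softmax outputs, y the labels
mean_loss <- function(P, y) mean(-log(P[cbind(seq_along(y), y)]))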
Universal Approximation Theorem
According to the theorem, first proved by George Cybenko for the sigmoid activation function, a “feed-forward network with a single hidden layer containing a finite number of neurons (i.e., a multilayer perceptron) can approximate continuous functions on compact subsets of $\mathbb{R}^n$, under mild assumptions on the activation function.”
Let's put the MLP to the test then. For this purpose I will use the spirals dataset from the mlbench package.
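This excerpt does not show the data-generation step, so the snippet below is a plausible reconstruction; the sample size, number of cycles, and noise level are assumptions, not the post's original settings:

############ spirals dataset ############
library(mlbench)

set.seed(2014)
spirals <- mlbench.spirals(n = 1000, cycles = 1.5, sd = 0.05)

# features plus 0/1 labels, in the layout used by the training calls below
dta.train <- cbind(x = spirals$x[, 1],
                   y = spirals$x[, 2],
                   label = as.numeric(spirals$classes) - 1)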
MXNet
MXNet is an open-source deep learning framework that allows you to define, train, and deploy deep neural networks on a wide array of devices, from cloud infrastructure to mobile devices. It also allows you to mix symbolic and imperative programming flavors, for example to define custom loss functions and accuracy measures.
The MXNet package exposes a so-called symbolic API to R users. Its purpose is to offer a user-friendly way of building neural networks while abstracting the computational details away to the specialized MXNet engine.
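The mlp symbol used in the training calls below is built with this API. Its exact definition is not shown in this excerpt, so the following is a plausible sketch; the layer sizes are assumptions:

############## mlp symbol ###############
library(mxnet)

data <- mx.symbol.Variable("data")
fc1  <- mx.symbol.FullyConnected(data, num_hidden = 10)  # first hidden layer
act1 <- mx.symbol.Activation(fc1, act_type = "relu")     # rectifier activation
fc2  <- mx.symbol.FullyConnected(act1, num_hidden = 2)   # one output per class
mlp  <- mx.symbol.SoftmaxOutput(fc2)                     # softmax + cross-entropy loss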
Feed-forward networks are trained using an iterative, gradient-descent type of algorithm. During a single forward pass only a subset of the data, called a batch, is used; the process is repeated until all training examples have been used, which is called an epoch. For example, with 1,000 training observations and a batch size of 100, one epoch consists of ten batches.
After every epoch MXNet reports the training accuracy:
############# basic training #############
mx.set.seed(2014)
model <- mx.model.FeedForward.create(
  symbol = mlp,                      # network symbol defined above
  X = dta.train[, c("x", "y")],      # features, one observation per row
  y = dta.train[, c("label")],       # class labels
  num.round = 5,                     # number of epochs
  array.layout = "rowmajor",
  learning.rate = 1,
  eval.metric = mx.metric.accuracy)
In order to stop the training process when the improvement in accuracy falls below a certain tolerance, we need to add a custom callback to the feed-forward procedure. It is called after every epoch to check whether the algorithm is still progressing. If not, it terminates the optimization procedure and returns the results.
######## custom stopping criterion #######
mx.callback.train.stop <- function(tol = 1e-3,
                                   mean.n = 1e2,
                                   period = 100,
                                   min.iter = 100) {
  function(iteration, nbatch, env, verbose = TRUE) {
    continue <- TRUE  # keep training unless a stopping rule fires
    if (nbatch == 0 & !is.null(env$metric)) {
      # training accuracy after the current epoch
      acc.train <- env$metric$get(env$train.metric)$value
      if (is.null(env$acc.log)) {
        env$acc.log <- acc.train
      } else {
        # stop once accuracy has plateaued (or is perfect)
        if ((abs(acc.train - mean(tail(env$acc.log, mean.n))) < tol &
             abs(acc.train - max(env$acc.log)) < tol &
             iteration > min.iter) |
            acc.train == 1) {
          cat("Training finished with final accuracy: ",
              round(acc.train * 100, 2), " %\n", sep = "")
          continue <- FALSE
        }
        env$acc.log <- c(env$acc.log, acc.train)
      }
      # periodic progress report
      if (iteration %% period == 0) {
        cat("[", iteration, "]", " training accuracy: ",
            round(acc.train * 100, 2), " %\n", sep = "")
      }
    }
    return(continue)
  }
}
###### training with custom stopping #####
mx.set.seed(2014)
model <- mx.model.FeedForward.create(
  symbol = mlp,
  X = dta.train[, c("x", "y")],
  y = dta.train[, c("label")],
  num.round = 2000,                  # upper bound; the callback stops earlier
  array.layout = "rowmajor",
  learning.rate = 1,
  epoch.end.callback = mx.callback.train.stop(),
  eval.metric = mx.metric.accuracy,
  verbose = FALSE
)
## [100] training accuracy: 90.07 %
## [200] training accuracy: 98.88 %
## [300] training accuracy: 99.33 %
## Training finished with final accuracy: 99.44 %
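Once training stops, the fitted model can be used for prediction in the usual way. A short sketch; note that the R package returns class probabilities with one column per observation:

############## prediction ###############
preds <- predict(model, dta.train[, c("x", "y")], array.layout = "rowmajor")
pred.label <- max.col(t(preds)) - 1          # back to 0/1 labels
mean(pred.label == dta.train[, "label"])     # training accuracy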