R TensorFlow Deep Neural Network
[This article was first published on Data Science, Machine Learning and Predictive Analytics, and kindly contributed to R-bloggers].
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.
In the previous post I fitted a neural network to the cars_19 dataset using the neuralnet package. In this post I am going to use TensorFlow to fit a deep neural network using the same data.
The main difference between the neuralnet package and TensorFlow is that TensorFlow's dnn_regressor() uses the Adagrad optimizer by default, whereas neuralnet uses rprop+ (resilient backpropagation with weight backtracking). Adagrad is a modified stochastic gradient descent optimizer with a per-parameter learning rate.
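To make the "per-parameter learning rate" concrete, here is a minimal R sketch of the textbook Adagrad update on a toy quadratic. It is an illustration of the general rule, not TensorFlow's internal code; the function name adagrad_step(), the learning rate and the toy objective are all hypothetical.

```r
# Minimal Adagrad sketch (illustrative only, not TensorFlow's implementation).
# Each parameter keeps its own running sum of squared gradients, so its
# effective step size is lr / sqrt(accumulated squared gradients).
adagrad_step <- function(theta, grad, accum, lr = 0.05, eps = 1e-8) {
  accum <- accum + grad^2                           # per-parameter history
  theta <- theta - lr * grad / (sqrt(accum) + eps)  # per-parameter step size
  list(theta = theta, accum = accum)
}

# Toy problem: minimise f(theta) = sum((theta - target)^2)
target <- c(3, -2)
theta  <- c(0, 0)
accum  <- c(0, 0)
for (i in 1:5000) {
  grad  <- 2 * (theta - target)  # gradient of the toy objective
  state <- adagrad_step(theta, grad, accum, lr = 0.5)
  theta <- state$theta
  accum <- state$accum
}
print(round(theta, 3))
```

The two parameters start at different distances from their targets, so they accumulate different gradient histories and end up with different effective learning rates, which is the property that distinguishes Adagrad from plain SGD.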
The data, which cover all 2019 vehicles that are not pure electric (1,253 vehicles), are summarized in my previous posts.
str(cars_19)
'data.frame': 1253 obs. of 12 variables:
 $ fuel_economy_combined: int 21 28 21 26 28 11 15 18 17 15 ...
 $ eng_disp             : num 3.5 1.8 4 2 2 8 6.2 6.2 6.2 6.2 ...
 $ num_cyl              : int 6 4 8 4 4 16 8 8 8 8 ...
 $ transmission         : Factor w/ 7 levels "A","AM","AMS",..: 3 2 6 3 6 3 6 6 6 5 ...
 $ num_gears            : int 9 6 8 7 8 7 8 8 8 7 ...
 $ air_aspired_method   : Factor w/ 5 levels "Naturally Aspirated",..: 4 4 4 4 4 4 3 1 3 3 ...
 $ regen_brake          : Factor w/ 3 levels "","Electrical Regen Brake",..: 2 1 1 1 1 1 1 1 1 1 ...
 $ batt_capacity_ah     : num 4.25 0 0 0 0 0 0 0 0 0 ...
 $ drive                : Factor w/ 5 levels "2-Wheel Drive, Front",..: 4 2 2 4 2 4 2 2 2 2 ...
 $ fuel_type            : Factor w/ 5 levels "Diesel, ultra low sulfur (15 ppm, maximum)",..: 4 3 3 5 3 4 4 4 4 4 ...
 $ cyl_deactivate       : Factor w/ 2 levels "N","Y": 1 1 1 1 1 2 1 2 2 1 ...
 $ variable_valve       : Factor w/ 2 levels "N","Y": 2 2 2 2 2 2 2 2 2 2 ...
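To prepare the data for the neural network, the TensorFlow DNN estimator requires categorical variables to be converted into a dense representation. Here I do this by wrapping each categorical column in the column_embedding() function (column_indicator() would give a one-hot representation instead).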
library(tfestimators)

# Numeric columns 2, 3, 5 and 8: eng_disp, num_cyl, num_gears, batt_capacity_ah
cols <- feature_columns(
  column_numeric(colnames(cars_19[c(2, 3, 5, 8)])),
  # Each categorical column is mapped to a 1-dimensional learned embedding
  column_embedding(column_categorical_with_identity("transmission", num_buckets = 7), dimension = 1),
  column_embedding(column_categorical_with_identity("air_aspired_method", num_buckets = 5), dimension = 1),
  column_embedding(column_categorical_with_identity("regen_brake", num_buckets = 3), dimension = 1),
  column_embedding(column_categorical_with_identity("drive", num_buckets = 5), dimension = 1),
  column_embedding(column_categorical_with_identity("fuel_type", num_buckets = 5), dimension = 1),
  column_embedding(column_categorical_with_identity("cyl_deactivate", num_buckets = 2), dimension = 1),
  column_embedding(column_categorical_with_identity("variable_valve", num_buckets = 2), dimension = 1)
)
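The plain-R sketch below illustrates roughly what an embedding column with dimension = 1 computes: each integer category ID selects one row of a trainable lookup table, so every category ends up represented by a single learned number. The table here is random rather than learned, and the conversion to 0-based IDs is an assumption about how the IDs are prepared (TensorFlow's identity columns expect integer IDs in the range [0, num_buckets), while R factor codes start at 1).

```r
# Hypothetical illustration of an embedding lookup for one categorical column
transmission <- factor(c("A", "AM", "AMS", "A"))

# Assumed preparation: shift R's 1-based factor codes to 0-based IDs
ids <- as.integer(transmission) - 1L

num_buckets <- 7
dimension   <- 1
set.seed(1)
# Random values stand in for the weights the estimator would learn
embedding_table <- matrix(rnorm(num_buckets * dimension),
                          nrow = num_buckets, ncol = dimension)

# R matrices are 1-indexed, so add 1 back when looking up rows
dense <- embedding_table[ids + 1L, , drop = FALSE]
print(dense)  # rows 1 and 4 are identical because both vehicles are "A"
```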
As with the neural network I fitted using neuralnet(), I am going to use two hidden layers with seven and three neurons, respectively.
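To see how small this network is: with the feature columns above, the input layer has 11 values (4 numeric columns plus 7 one-dimensional embeddings). The R sketch below counts the dense-layer weights and biases and runs a forward pass with random weights. It assumes ReLU hidden activations (assumed to match dnn_regressor()'s default) and a single linear output unit, and it does not count the 29 embedding weights (one per bucket: 7 + 5 + 3 + 5 + 5 + 2 + 2). It is a sketch of the architecture, not of how the estimator is implemented.

```r
# Assumed layer sizes: 11 inputs -> 7 -> 3 -> 1 linear output
layer_sizes <- c(11, 7, 3, 1)

# Weights (fan_in * fan_out) plus biases (fan_out) for each layer
n_params <- sum(layer_sizes[-length(layer_sizes)] * layer_sizes[-1] + layer_sizes[-1])
print(n_params)  # 84 + 24 + 4 = 112

relu <- function(z) pmax(z, 0)  # pmax keeps the matrix dimensions

set.seed(42)
weights <- lapply(seq_len(length(layer_sizes) - 1), function(i)
  matrix(rnorm(layer_sizes[i] * layer_sizes[i + 1], sd = 0.1),
         nrow = layer_sizes[i], ncol = layer_sizes[i + 1]))
biases <- lapply(layer_sizes[-1], function(n) rep(0, n))

forward <- function(x) {
  h <- x
  for (i in seq_along(weights)) {
    # Add the bias vector to every row of the layer's pre-activation
    z <- h %*% weights[[i]] +
      matrix(biases[[i]], nrow(h), length(biases[[i]]), byrow = TRUE)
    h <- if (i < length(weights)) relu(z) else z  # linear output layer
  }
  h
}

x <- matrix(rnorm(5 * 11), nrow = 5)  # 5 made-up rows of 11 features
print(dim(forward(x)))  # 5 1
```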
Train, evaluate, and predict:
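library(caret)     # postResample()
library(magrittr)  # %>%

# Assumed helper: the post does not show its definition. This version wraps
# tfestimators::input_fn() with the response and all remaining columns.
cars_19_input_fn <- function(data, num_epochs = 1) {
  input_fn(data,
           features = setdiff(names(data), "fuel_economy_combined"),
           response = "fuel_economy_combined",
           num_epochs = num_epochs)
}

# Create a deep neural network (DNN) estimator
model <- dnn_regressor(hidden_units = c(7, 3), feature_columns = cols)

# 75/25 train/test split
set.seed(123)
indices <- sample(1:nrow(cars_19), size = 0.75 * nrow(cars_19))
train <- cars_19[indices, ]
test <- cars_19[-indices, ]

# Train the model
model %>% train(cars_19_input_fn(train, num_epochs = 1000))

# Evaluate the model on the test set
model %>% evaluate(cars_19_input_fn(test))

# Predict on the test set
yhat <- model %>% predict(cars_19_input_fn(test))
yhat <- unlist(yhat)
y <- test$fuel_economy_combined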
postResample(yhat, y)
     RMSE  Rsquared       MAE
1.9640173 0.8700275 1.4838347
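postResample() comes from the caret package. The base-R sketch below computes the same three metrics on toy values so it is clear what is being reported. regression_metrics() is a hypothetical helper, the predictions are made up, and its equivalence to caret's output is assumed for inputs without missing values.

```r
# Base-R versions of RMSE, R-squared (squared correlation) and MAE
regression_metrics <- function(pred, obs) {
  c(RMSE     = sqrt(mean((pred - obs)^2)),
    Rsquared = cor(pred, obs)^2,
    MAE      = mean(abs(pred - obs)))
}

# Toy values, only to show the calculation
obs  <- c(21, 28, 21, 26, 28)
pred <- c(22, 27, 20, 27, 28)
m <- regression_metrics(pred, obs)
print(m)
```

Note that Rsquared here is the squared correlation between predictions and observations, not 1 - SSE/SST, so it measures how well the predictions track the observed values rather than how small the errors are.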
The results are similar to those of the other models and of neuralnet().
I am going to look at the training error in TensorBoard, a visualization tool that is great for visualizing TensorFlow graphs and for plotting quantitative metrics about the execution of the graph. Below is the mean squared error at each iteration; it stabilizes fairly quickly. In the next post I will go into TensorBoard in a lot more depth.