Classification Trees
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.
Classification trees, a type of decision tree, are used when the response is categorical, so that the data are divided into groups, rather than when we investigate a numerical response and its relationship to a set of descriptor variables. There are various implementations of classification trees in R; two commonly used functions are rpart (from the rpart package) and tree (from the tree package).
To illustrate the tree function we will use a data set from the UCI Machine Learning Repository, collected for a study whose objective was to predict the cellular localization sites of proteins in E. coli.
The first few rows of the data provided on the website are shown here:
> ecoli.df = read.csv("ecoli.txt", stringsAsFactors = TRUE)  # read class as a factor for a classification tree
> head(ecoli.df)
    Sequence  mcv  gvh  lip chg  aac alm1 alm2 class
1  AAT_ECOLI 0.49 0.29 0.48 0.5 0.56 0.24 0.35    cp
2 ACEA_ECOLI 0.07 0.40 0.48 0.5 0.54 0.35 0.44    cp
3 ACEK_ECOLI 0.56 0.40 0.48 0.5 0.49 0.37 0.46    cp
4 ACKA_ECOLI 0.59 0.49 0.48 0.5 0.52 0.45 0.36    cp
5  ADI_ECOLI 0.23 0.32 0.48 0.5 0.55 0.25 0.35    cp
6 ALKH_ECOLI 0.67 0.39 0.48 0.5 0.36 0.38 0.46    cp
We can use the xtabs function to summarise the number of cases in each class.
> xtabs(~ class, data = ecoli.df)
class
 cp  im imL imS imU  om omL  pp
143  77   2   2  35  20   5  52
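The class counts already give two reference points before any tree is fitted. A minimal base-R sketch, with the counts copied by hand from the xtabs output above (the variable names are illustrative): a "tree" with no splits predicts the most common class, cp, for every protein, and the deviance of that single root node, D = -2 Σ n_k log(n_k / n), is the quantity that the splits of a deviance-based classification tree must reduce.

```r
# Class counts copied from the xtabs(~ class, data = ecoli.df) output
counts <- c(cp = 143, im = 77, imL = 2, imS = 2, imU = 35,
            om = 20, omL = 5, pp = 52)
n <- sum(counts)  # total number of proteins

# Class proportions, rounded for display
print(round(counts / n, 3))

# Baseline: predict the majority class for every protein
majority <- names(which.max(counts))
baseline_error <- 1 - max(counts) / n
cat("Majority class:", majority, "\n")
cat("Baseline misclassification rate:", round(baseline_error, 3), "\n")

# Deviance of the root node: -2 * sum(n_k * log(n_k / n))
root_deviance <- -2 * sum(counts * log(counts / n))
cat("Root node deviance:", round(root_deviance, 1), "\n")
```

Any useful tree should do much better than the baseline error of about 57%, since over half of the proteins are not in class cp.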
The package used here is the tree package (available from CRAN), which is loaded with:
> library(tree)
We first fit a classification tree using all seven descriptor variables and then try to prune it to make it smaller.
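Growing the tree is a greedy, recursive process: at each node the tree function (with its default split = "deviance") looks for the single binary split of one variable that most reduces the deviance, and then repeats the search inside each child node. The sketch below shows just one step of that search for one numeric variable, using R's built-in iris data rather than the ecoli file. The helpers node_deviance and best_split are hypothetical names for illustration; the real implementation also handles factor predictors and stopping rules such as a minimum node size.

```r
# Deviance of a node: -2 * sum over classes k of n_k * log(n_k / n)
node_deviance <- function(y) {
  counts <- table(y)
  counts <- counts[counts > 0]  # classes absent from the node contribute 0
  -2 * sum(counts * log(counts / sum(counts)))
}

# Hypothetical helper: try every cut point "x < cutoff" for one numeric
# variable and return the one giving the largest drop in deviance
best_split <- function(x, y) {
  values <- sort(unique(x))
  cutoffs <- (head(values, -1) + tail(values, -1)) / 2  # midpoints between values
  parent <- node_deviance(y)
  gains <- sapply(cutoffs, function(cutoff) {
    left <- x < cutoff
    parent - node_deviance(y[left]) - node_deviance(y[!left])
  })
  list(cutoff = cutoffs[which.max(gains)], gain = max(gains))
}

# One step of the search on the built-in iris data (not the ecoli data)
first_split <- best_split(iris$Petal.Length, iris$Species)
print(first_split)
```

Here the best cut separates setosa completely from the other two species; the fitted tree applies the same search to every variable at every node and keeps the best overall split.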
> ecoli.tree1 = tree(class ~ mcv + gvh + lip + chg + aac + alm1 + alm2, data = ecoli.df)
> summary(ecoli.tree1)
Classification tree:
tree(formula = class ~ mcv + gvh + lip + chg + aac + alm1 + alm2, data = ecoli.df)
Variables actually used in tree construction:
[1] "alm1" "mcv"  "gvh"  "aac"  "alm2"
Number of terminal nodes:  10
Residual mean deviance:  0.7547 = 246 / 326
Misclassification error rate: 0.122 = 41 / 336
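The last two lines of the summary are simple ratios. The residual mean deviance divides the total deviance of the terminal nodes by n minus the number of terminal nodes (336 - 10 = 326), and the misclassification error rate is the number of training cases whose predicted class differs from the observed one, divided by n. A quick base-R check using the figures printed above; because the printed deviance of 246 is rounded, the ratio only matches 0.7547 approximately.

```r
n      <- 336  # proteins in ecoli.df
leaves <- 10   # terminal nodes in ecoli.tree1
dev    <- 246  # total residual deviance, as printed (rounded)
errors <- 41   # misclassified training cases

# Residual mean deviance: deviance per residual degree of freedom
resid_mean_dev <- dev / (n - leaves)
# Misclassification error rate on the training data
error_rate <- errors / n

cat("Residual mean deviance:", round(resid_mean_dev, 4), "\n")
cat("Misclassification error rate:", round(error_rate, 3), "\n")
```

In a live session the error rate should also be reproducible from the fitted tree itself with mean(predict(ecoli.tree1, type = "class") != ecoli.df$class).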
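The tree function is called in the same way as other modelling functions in R, with a model formula and a data frame. Its summary lists the variables actually used in the splits (five of the seven here; lip and chg are not used), the number of terminal nodes, the residual mean deviance and the misclassification error rate on the training data. The tree can be plotted and annotated with these commands: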
> plot(ecoli.tree1)              # draw the branches of the tree
> text(ecoli.tree1, all = TRUE)  # label every node, not only the leaves
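To decide how far to prune the tree we use cross-validation. The cv.tree function (10-fold by default) estimates the deviance of each subtree in the cost-complexity pruning sequence; because the folds are assigned at random, the numbers may differ slightly from run to run.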
> cv.tree(ecoli.tree1)
$size
[1] 10 9 8 7 6 5 4 3 2 1
$dev
[1] 463.6820 457.4463 447.9824 441.8617 455.8318 478.9234 533.5856 586.2820 713.2992 1040.3878
$k
[1] -Inf 12.16500 15.60004 19.21572 34.29868 41.10627 50.57044 64.05494 180.78800 355.67747
$method
[1] "deviance"
attr(,"class")
[1] "prune" "tree.sequence"
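Rather than reading the best size off the printed output by eye, it can be picked programmatically. The sketch below copies the $size and $dev vectors from the output above by hand; in a live session they are the size and dev components of the object returned by cv.tree. With these numbers the minimum cross-validated deviance falls at 7 terminal nodes.

```r
# $size and $dev copied from the cv.tree(ecoli.tree1) output above;
# in a live session use the size and dev components of the cv.tree result
size <- c(10, 9, 8, 7, 6, 5, 4, 3, 2, 1)
dev  <- c(463.6820, 457.4463, 447.9824, 441.8617, 455.8318,
          478.9234, 533.5856, 586.2820, 713.2992, 1040.3878)

# Tree size with the lowest cross-validated deviance
best_size <- size[which.min(dev)]
cat("Size with minimum CV deviance:", best_size, "\n")

# The three best sizes, to judge how flat the curve is near the minimum
cv_table <- data.frame(size = size, dev = dev)
print(head(cv_table[order(cv_table$dev), ], 3))
```

The tree package also provides a plot method for this kind of object, so plot(cv.tree(ecoli.tree1)) should draw the same deviance-versus-size curve.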
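The cross-validated deviance is lowest for a tree with 7 terminal nodes (441.9), but a 6-node tree (455.8) still does better than the full 10-node tree (463.7), so we choose the simpler tree with 6 terminal nodes. (Strictly, this cross-validation used the deviance, whereas prune.misclass prunes on the misclassification rate; cv.tree(ecoli.tree1, FUN = prune.misclass) would cross-validate on that criterion instead.) The pruned tree is obtained from the original tree rather than re-fitted from scratch: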
> ecoli.tree2 = prune.misclass(ecoli.tree1, best = 6)
> summary(ecoli.tree2)
Classification tree:
snip.tree(tree = ecoli.tree1, nodes = c(4, 20, 7))
Variables actually used in tree construction:
[1] "alm1" "mcv"  "aac"  "gvh"
Number of terminal nodes:  6
Residual mean deviance:  0.9918 = 327.3 / 330
Misclassification error rate: 0.1548 = 52 / 336
Pruning from 10 to 6 terminal nodes increases the misclassification rate from 12.2% (41/336) to 15.5% (52/336): a modest increase for a noticeably simpler tree.
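The trade-off can be laid out side by side. A small base-R sketch using only the figures from the two summaries above: pruning removes 4 terminal nodes at the cost of 11 more misclassified training cases. Both rates are computed on the training data, so they tend to understate the error on new proteins; the cross-validation step is a better guide to out-of-sample performance.

```r
# Summary figures for the full and pruned trees, copied from summary()
trees <- data.frame(
  model  = c("ecoli.tree1", "ecoli.tree2"),
  leaves = c(10, 6),
  errors = c(41, 52),
  n      = 336
)
trees$error_rate <- trees$errors / trees$n
print(trees)

# Extra training cases misclassified after pruning, and the change in rate
extra  <- diff(trees$errors)
change <- diff(trees$error_rate)
cat("Extra misclassified cases:", extra, "\n")
cat("Increase in error rate (percentage points):", round(100 * change, 1), "\n")
```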
Other useful resources are provided on the Supplementary Material page.
Data used in this post: Ecoli Data Set.