Visualizing trees with Sklearn
Tree-based models are probably the second easiest ML technique to explain to a non-data scientist. I am a big fan of tree-based models because of their simplicity and interpretability. But visualizing them has always gotten on my nerves: there are so many packages out there for it. Sklearn has finally provided us with an API to visualize trees through matplotlib. In this tutorial, I will show you how to visualize trees using sklearn for both classification and regression.
Importing libraries
The following are the libraries that are required to load datasets, split data, train models and visualize them.
from sklearn.datasets import load_wine, fetch_california_housing
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree, DecisionTreeClassifier, DecisionTreeRegressor
Classification
In this section, our objective is to:
- Load wine dataset
- Split the data into train and test
- Train a decision tree classifier
- Visualize the decision tree
# load wine data set
data = load_wine()
x = data.data
y = data.target

# split into train and test data
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.30, random_state=42)

# create a decision tree classifier
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(x_train, y_train)

# plot classifier tree
plt.figure(figsize=(10, 8))
plot_tree(clf, feature_names=data.feature_names, class_names=data.target_names, filled=True)
Once you execute the above code, you should see a decision tree like the one below for the wine dataset model.
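If you prefer a plain-text summary of the same tree (for example, to paste into a report or a terminal), sklearn.tree also provides export_text. The following is a minimal sketch that assumes the fitted clf and the wine data loaded above; it simply prints the split rules.

# print the split rules of the fitted classifier as plain text
from sklearn.tree import export_text

rules = export_text(clf, feature_names=list(data.feature_names))
print(rules)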
Regression
Similar to classification, in this section we will train and visualize a model for regression. Our objective is to:
- Load the California housing dataset
- Split the data into train and test
- Train a decision tree regressor
- Visualize the decision tree
# load data set
data = fetch_california_housing()
x = data.data
y = data.target

# split into train and test data
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.30, random_state=42)

# create a decision tree regressor
clf = DecisionTreeRegressor(max_depth=2, random_state=0)
clf.fit(x_train, y_train)

# plot tree regressor
plt.figure(figsize=(10, 8))
plot_tree(clf, feature_names=data.feature_names, filled=True)
Once you execute the above code, you should end up with a graph similar to the one below.
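Deeper trees quickly get cramped in a small interactive window, so it can help to write the figure to disk at a larger size and resolution instead. This is a minimal sketch using matplotlib's savefig on the regressor fitted above; the file name regression_tree.png is just a placeholder.

# redraw the regression tree on a larger canvas and save it to disk
plt.figure(figsize=(16, 10))
plot_tree(clf, feature_names=data.feature_names, filled=True)
plt.savefig("regression_tree.png", dpi=300, bbox_inches="tight")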
As you can see, visualizing a decision tree has become a lot simpler with sklearn. In the past, it would take me about 10 to 15 minutes and two different packages to write code for something that can now be done in two lines. I am definitely looking forward to future updates that support random forests and other ensemble models.
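Until that support arrives, one common workaround is to plot individual trees from a fitted ensemble, since a random forest exposes its trees through the estimators_ attribute. Below is a minimal sketch on the wine data, reusing the imports from the top of the post; the model settings (n_estimators=10, max_depth=2) are arbitrary choices for illustration.

# fit a small random forest on the wine data and plot its first tree
from sklearn.ensemble import RandomForestClassifier

wine = load_wine()
rf = RandomForestClassifier(n_estimators=10, max_depth=2, random_state=0)
rf.fit(wine.data, wine.target)

plt.figure(figsize=(10, 8))
plot_tree(rf.estimators_[0], feature_names=wine.feature_names, class_names=wine.target_names, filled=True)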
Thank you for going through this article. Kindly post below if you have any questions or comments.
You can also find code for this on my Github page.