Are you looking for a simple, robust, and efficient method for classification and dimensionality reduction? Do you want to learn how to implement and evaluate LDA in R, one of the most popular and powerful techniques for these tasks? If so, you have come to the right place. Linear discriminant analysis (LDA) is a supervised machine-learning technique that can be used for two main purposes: classification and dimensionality reduction.
Key takeaways
- Linear discriminant analysis is a supervised machine-learning technique that can be used for classification and dimensionality reduction.
- It is based on finding the linear combinations of features that best separate the classes in the data set.
- It can be implemented in R using the lda function from the MASS package.
- It can be evaluated and compared with other methods using various metrics, such as accuracy, confusion matrix, ROC curve, or cross-validation.
- It has advantages and disadvantages that should be considered before applying it to a real-world problem.
Hello, my name is Zubair Goraya, a data analyst, a PhD scholar, and a freelancer with five years of experience. I am passionate about data science and machine learning, and I love sharing my knowledge and skills with others. In this blog post, I will explain everything you need to know about linear discriminant analysis, from the theory and assumptions to the implementation and evaluation of the method. Whether you are a student, a researcher, or a practitioner, you will find this guide useful and informative. So, let’s get started!
Function | Description |
---|---|
createDataPartition() | Divides the data into training and testing sets while preserving the proportion of classes. It is useful for creating balanced data sets for training machine-learning models. |
lda() | Fits a linear discriminant analysis (LDA) model to the data. LDA is a dimensionality-reduction technique commonly used for classification; it finds the linear combinations of variables that best separate the classes. |
predict() | Predicts class labels or probabilities for new data based on a fitted model. It can be used with many model types, including LDA, to make predictions on unseen data. |
as.character() | Converts data to character type. It is often used to convert factor variables to character for specific operations. |
factor() | Converts data to a factor variable, specifying levels if needed. Factors represent categorical data in R. |
table() | Generates contingency tables showing the frequency distribution of variables. It is commonly used to summarize and compare categorical variables. |
sum() | Calculates the sum of the elements of a vector or array. |
diag() | Extracts the diagonal elements of a matrix. It is frequently used to pull the true-positive counts out of a confusion matrix when calculating accuracy. |
confusionMatrix() | Computes a confusion matrix to evaluate the performance of a classification model, summarizing the agreement between the actual and predicted classes. |
roc() | Computes the receiver operating characteristic (ROC) curve and related metrics for evaluating binary classifiers. It plots the trade-off between the true positive rate and the false positive rate across thresholds. |
kappa2() | Calculates Cohen's kappa coefficient, which measures inter-rater agreement for categorical items. It is commonly used to assess the agreement between predicted and true class labels. |
kappam.fleiss() | Computes Fleiss' kappa coefficient, a measure of inter-rater reliability for multiple raters. It extends Cohen's kappa to more than two raters. |
What is Linear Discriminant Analysis in R?
Linear discriminant analysis (LDA) is a supervised machine-learning technique that can be used for two main purposes:
- Classification
- Dimensionality reduction.
Real-World Example of LDA
Suppose you have a data set of flowers with four features: sepal length, sepal width, petal length, and petal width. The data set also has three classes: setosa, versicolor, and virginica. You want to build a classifier that can predict the class of a new flower based on its features.
You can use LDA to find the linear discriminants that best separate the three classes and then use them as new features for the classifier. Alternatively, you can use LDA to reduce the dimensionality of the data set from four features to two or one and then visualize the data set or apply another method for classification.
How Does Linear Discriminant Analysis Work?
Linear discriminant analysis works by assuming that the features of each class are normally distributed and that the covariance matrices of each class are equal.
The steps of linear discriminant analysis are as follows:
- Compute the mean vector and the covariance matrix for each class.
- Compute the pooled within-class covariance matrix, the weighted average of the individual class covariance matrices.
- Compute the between-class covariance matrix by summing, over the classes, the number of observations in each class times the outer product of the difference between the class mean vector and the overall mean vector.
- Compute the eigenvalues and eigenvectors of the inverse of the pooled within-class covariance matrix multiplied by the between-class covariance matrix.
- Select the k largest eigenvalues and their corresponding eigenvectors, where k is the number of linear discriminants to be extracted.
- Transform the original features into the new linear discriminants by multiplying the feature matrix by the eigenvector matrix (a worked sketch follows this list).
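To make these steps concrete, here is a minimal sketch that carries them out by hand on the iris data. The names X, Sw, Sb, and W are ours, not from any package, and up to sign and scaling the resulting directions should match what lda() reports.

```r
# Minimal sketch: linear discriminants computed by hand on iris
X  <- as.matrix(iris[, 1:4])  # feature matrix
y  <- iris$Species            # class labels
mu <- colMeans(X)             # overall mean vector

# Accumulate the within-class and between-class scatter matrices
Sw <- matrix(0, 4, 4)
Sb <- matrix(0, 4, 4)
for (k in levels(y)) {
  Xk   <- X[y == k, ]
  mu_k <- colMeans(Xk)
  Sw   <- Sw + (nrow(Xk) - 1) * cov(Xk)          # pooled within-class
  Sb   <- Sb + nrow(Xk) * tcrossprod(mu_k - mu)  # between-class
}

# Eigen-decompose solve(Sw) %*% Sb and keep the top k eigenvectors
eig <- eigen(solve(Sw) %*% Sb)
W   <- Re(eig$vectors[, 1:2])  # k = 2 discriminants for 3 classes

# Project the original features onto the linear discriminants
Z <- X %*% W
head(Z)
```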
Assumption of LDA in R
One of the key assumptions of LDA is that the predictor variables have the same variance-covariance matrix in each class; that is, the variables have the same variability and correlation structure across the different groups.
If this assumption is violated, we can transform the predictor variables to make them more homogeneous, or use a method that does not require it, such as quadratic discriminant analysis (QDA). QDA is similar to LDA but allows each class to have its own variance-covariance matrix, making it more flexible at the cost of more parameters and a greater risk of overfitting.
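For instance, here is a minimal sketch of fitting QDA instead; the qda function in the MASS package mirrors the lda interface (the qda_model name is ours):

```r
library(MASS)

# Fit QDA, which estimates a separate covariance matrix for each class
qda_model <- qda(Species ~ ., data = iris)

# In-sample accuracy, just to illustrate the interface
qda_pred <- predict(qda_model, iris)$class
mean(qda_pred == iris$Species)
```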
How to Implement LDA in R?
To implement linear discriminant analysis in R, we follow these steps:
- Loading the data,
- Splitting the data into training and testing sets,
- Fitting the LDA model,
- Predicting the class labels or posterior probabilities,
- Evaluating the model performance.
Required R Packages
We will use four packages: MASS, caret, pROC, and irr, which provide the functions and tools we need for fitting and evaluating the model.
```r
# Required packages
library(MASS)   # lda()
library(caret)  # createDataPartition(), confusionMatrix()
library(pROC)   # roc()
library(irr)    # kappa2(), kappam.fleiss()
```
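If any of these packages are not yet installed, they can be installed once beforehand:

```r
# One-time installation of any missing packages
install.packages(c("MASS", "caret", "pROC", "irr"))
```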
Load the Dataset
For this tutorial, we will use the Iris data set, a famous data set containing 150 observations of three species of iris flowers: setosa, versicolor, and virginica. The data set has four features, all measured in centimeters:
- Sepal length,
- Sepal width,
- Petal length,
- Petal width.
The data set also has a class variable that indicates the species of each flower. The Iris data set is available in the datasets package, a base package in R. We can load the data set using the data function and then check the structure and summary using the str and summary functions.
```r
# Load the data
data(iris)

# Check the structure of the data
str(iris)

# Check the summary of the data
summary(iris)
```
Split the Data into Training and Testing Sets
To split the data into training and testing sets, we will use the createDataPartition function from the caret package, which creates stratified random samples based on the class variable. We will use 80% of the data for training and 20% for testing.
```r
# Set the seed for reproducibility
set.seed(123)

# Split the data into training (80%) and testing (20%) sets,
# stratified by species
train_index <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
train <- iris[train_index, ]
test  <- iris[-train_index, ]
```
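As a quick sanity check, we can confirm that the split preserved the class proportions; with 50 flowers per species, a stratified 80/20 split should place 40 of each species in the training set and 10 in the testing set:

```r
# Verify that each species is represented proportionally
table(train$Species)
table(test$Species)
```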
Fitting the LDA Model using the lda Function from the MASS Package
We will use the lda function from the MASS package to fit the model. The lda function takes a formula argument, which specifies the class variable and the features, and a data argument, which specifies the data frame. An optional prior argument specifies the prior probabilities of the classes, and a tol argument specifies the tolerance for the eigenvalues.
We will use the default values for the prior and tol arguments and the training set as the data argument. We will assign the output of the lda function to a variable called model and then print and plot the model using the print and plot functions.
```r
# Fit the LDA model
model <- lda(Species ~ ., data = train)

# Print the model
print(model)
```
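For illustration, here is a small variant that overrides the default prior (by default, lda uses the class proportions observed in the training set); the model_equal name is ours:

```r
# The same model, but with equal prior probabilities for the three classes
model_equal <- lda(Species ~ ., data = train, prior = c(1, 1, 1) / 3)
```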
The output shows that the LDA model has two linear discriminants, LD1 and LD2, which explain about 99% and 1% of the between-class variance, respectively (the "proportion of trace" in the output). We can also see the coefficients of the linear discriminants, which indicate how much each feature contributes to each discriminant.
For example, a positive coefficient for sepal length on LD1 means that larger sepal lengths push observations toward higher values of LD1. We can also see the group means, which give the average value of each feature within each class. For example, the group mean of petal length for setosa is about 1.46, which means that setosa flowers have the shortest petals on average.
LDA Model Visualization
```r
# Define a color for each species
species_colors <- c("setosa" = "red", "versicolor" = "blue",
                    "virginica" = "green")

# Plot the model with custom colors for the species
plot(model, col = species_colors,
     main = "Visualization of Linear Discriminant Analysis")
```
From the plot, we can see the distribution of the observations along the linear discriminants and the separation of the classes. We can see that LD1 separates setosa from the other two classes, while LD2 separates versicolor from virginica. We can also see some overlap between versicolor and virginica, which means that the linear discriminants do not perfectly separate them.
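As an alternative view, here is a minimal sketch (reusing species_colors from above; the scores name is ours) that plots the training observations on the discriminant scores returned by predict:

```r
# Project the training data onto LD1 and LD2 and color by species
scores <- predict(model)$x
plot(scores, col = species_colors[as.character(train$Species)], pch = 19,
     xlab = "LD1", ylab = "LD2",
     main = "Training data on the linear discriminants")
legend("topright", legend = names(species_colors),
       col = species_colors, pch = 19)
```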
Predicting the Class Labels or Posterior Probabilities using the predict Function
We will use the predict function to generate predictions for the testing set; for an LDA fit, the method that is dispatched is predict.lda from the MASS package. The predict function takes an object argument, which specifies the fitted model, and a newdata argument, which specifies the new data.
Note that predict.lda returns a list containing the class labels (class), the posterior probabilities (posterior), and the discriminant scores (x), so no type argument is needed; both kinds of prediction come from the same call. We will assign the output of the predict function to a variable and then check the prediction results using the table and head functions.
```r
# Predict for the testing set; the result contains $class, $posterior, and $x
pred_class <- predict(model, newdata = test)

# Convert test$Species to character
test_species <- as.character(test$Species)

# Convert it back to a factor with the same levels as the predictions
test_species <- factor(test_species, levels = levels(pred_class$class))

# Cross-tabulate predicted vs. actual class labels
table(pred_class$class, test_species)

# The posterior probabilities come from the same predict() call
pred_prob <- predict(model, newdata = test)

# Check the first five rows of posterior probabilities
head(pred_prob$posterior, 5)
```
The output shows that the class labels are stored in a vector (pred_class$class) and the posterior probabilities in a matrix (pred_prob$posterior). We can also see that the prediction accuracy is 96.7%, as 29 of the 30 test observations are correctly classified.
We can also inspect the predicted class labels and posterior probabilities for the first five observations in the testing set. For example, the first observation is predicted to be setosa, with a posterior probability of essentially 1 for setosa and 0 for versicolor and virginica.
Evaluating the Model Performance Using Various Metrics
We will evaluate the model performance using several metrics: accuracy, the confusion matrix, the ROC curve, and the kappa coefficient.
Accuracy
Accuracy is the proportion of correct predictions: the number of correctly classified observations (the diagonal of the confusion matrix) divided by the total number of observations.
```r
# Accuracy: correctly classified observations / total test observations
accuracy <- sum(diag(table(pred_class$class, test_species))) / nrow(test)
accuracy
```
Confusion matrix
A confusion matrix is a table that shows the distribution of predicted and actual classes, which can be computed using the confusionMatrix function from the caret package.
```r
# Compute the confusion matrix
confusionMatrix(pred_class$class, test$Species)
```
ROC curve
The ROC curve plots the trade-off between the true positive rate and the false positive rate across classification thresholds; it can be computed and plotted using the roc function from the pROC package. Since ROC analysis is defined for binary outcomes, we examine one class at a time.
```r
# Extract the matrix of posterior probabilities
pred_prob1 <- pred_prob$posterior

# ROC curve based on the "virginica" posterior probability,
# contrasting the virginica and setosa observations
roc_curve <- roc(test$Species, pred_prob1[, "virginica"],
                 levels = c("virginica", "setosa"))
roc_curve
plot(roc_curve)
```
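The area under the curve can also be read off directly with the auc helper from pROC:

```r
# Area under the ROC curve (1 = perfect separation, 0.5 = chance)
auc(roc_curve)
```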
Kappa coefficient
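A minimal sketch using the kappa2 function from the irr package (as described in the function table above); it expects a data frame with one column per rater, here the predicted and the actual labels:

```r
# Cohen's kappa between predicted and actual class labels
kappa2(data.frame(pred = pred_class$class, actual = test_species))
```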