
Linear Discriminant Analysis (LDA) in R

[This article was first published on RStudioDataLab, and kindly contributed to R-bloggers.]

Are you looking for a simple, robust, and efficient method for classification and dimensionality reduction? Do you want to learn how to implement and evaluate linear discriminant analysis (LDA), one of the most popular and powerful techniques for these tasks, in R? If yes, then you have come to the right place. LDA is a supervised machine-learning technique that can be used for two main purposes: classification and dimensionality reduction.

# Fit the LDA model (`train` is the training split created later in this tutorial)
library(MASS)
model <- lda(Species ~ ., data = train)
# Print the model
print(model)


Key takeaways

  • Linear discriminant analysis is a supervised machine-learning technique that can be used for classification and dimensionality reduction.
  • It is based on finding the linear combinations of features that best separate the classes in the data set.
  • It can be implemented in R using the lda function from the MASS package.
  • It can be evaluated and compared with other methods using various metrics, such as accuracy, confusion matrix, ROC curve, or cross-validation.
  • It has advantages and disadvantages that should be considered before applying it to a real-world problem.

Hello, my name is Zubair Goraya, a data analyst, a PhD scholar, and a freelancer with five years of experience. I am passionate about data science and machine learning, and I love sharing my knowledge and skills with others. In this blog post, I will explain everything you need to know about linear discriminant analysis, from the theory and assumptions to the implementation and evaluation of the method. Whether you are a student, a researcher, or a practitioner, you will find this guide useful and informative. So, let’s get started!

Key functions used in this tutorial:

  • createDataPartition() (caret): Splits the data into training and testing sets while preserving the class proportions (stratified sampling), so that both sets stay balanced.
  • lda() (MASS): Fits a linear discriminant analysis model. It finds the linear combinations of variables that best separate the classes and can be used for both classification and dimensionality reduction.
  • predict(): Generic function from stats that returns predictions from a fitted model. For an lda model, it returns the predicted classes, the posterior probabilities, and the discriminant scores for new data.
  • as.character(): Converts values to character type. It is often used to turn factor variables into plain strings.
  • factor(): Converts values to a factor, optionally with specified levels. Factors represent categorical data in R.
  • table(): Builds contingency tables of counts, for example predicted versus actual classes.
  • sum(): Adds up the elements of a vector or array.
  • diag(): Extracts the diagonal of a matrix. On a confusion table, the diagonal holds the correctly classified counts used to compute accuracy.
  • confusionMatrix() (caret): Computes a confusion matrix and related statistics (accuracy, kappa, per-class sensitivity and specificity) to evaluate a classifier.
  • roc() (pROC): Computes a Receiver Operating Characteristic (ROC) curve and its AUC for a binary outcome, tracing the trade-off between the true positive rate and the false positive rate across thresholds.
  • kappa2() (irr): Calculates Cohen's kappa, the chance-corrected agreement between two raters (here, the predicted and true class labels).
  • kappam.fleiss() (irr): Computes Fleiss' kappa, which extends chance-corrected agreement to more than two raters.

What is Linear Discriminant Analysis in R?

Linear discriminant analysis (LDA) is a supervised machine-learning technique that can be used for two main purposes:

  • Classification
  • Dimensionality reduction

Classification is the task of assigning a class label to a new observation based on its features, whereas dimensionality reduction reduces the number of features while preserving the essential information. LDA finds the linear combinations of features that best separate the classes in the data set. These linear combinations are called linear discriminants, and they can be used as new features for classification or dimensionality reduction.

Real-World Example of LDA

Suppose you have a data set of flowers with four features: sepal length, sepal width, petal length, and petal width. The data set also has three classes: setosa, versicolor, and virginica. You want to build a classifier that can predict the class of a new flower based on its features. 

You can use LDA to find the linear discriminants that best separate the three classes and then use them as new features for the classifier. Alternatively, you can use LDA to reduce the dimensionality of the data set from four features to two or one and then visualize the data set or apply another method for classification.
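Here is a minimal sketch of the dimensionality-reduction use. It assumes the full iris data and the lda() function from MASS, and the names fit, ld_scores and ld1_only are only illustrative. The discriminant scores returned by predict() serve as the new, lower-dimensional features:

```r
library(MASS)

# Fit LDA on all 150 flowers (4 features, 3 species)
fit <- lda(Species ~ ., data = iris)

# Discriminant scores: with 3 classes there are at most 3 - 1 = 2 discriminants
ld_scores <- predict(fit)$x
dim(ld_scores)  # 150 rows, 2 columns (LD1, LD2)

# Keep only LD1 to reduce the data to a single feature
ld1_only <- ld_scores[, "LD1", drop = FALSE]
head(ld1_only)
```

These scores can be plotted directly or passed to another classifier in place of the original four measurements.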

How Does Linear Discriminant Analysis Work?

Linear discriminant analysis assumes that, within each class, the features follow a multivariate normal distribution and that all classes share the same covariance matrix.
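Under these two assumptions, the Bayes classification rule becomes linear. The notation here is introduced only for illustration: μ_k is the mean vector of class k, Σ is the shared covariance matrix, and π_k is the prior probability of class k. A new observation x is assigned to the class with the largest discriminant score:

```latex
\delta_k(x) = x^\top \Sigma^{-1} \mu_k - \tfrac{1}{2}\, \mu_k^\top \Sigma^{-1} \mu_k + \log \pi_k,
\qquad
\hat{y}(x) = \arg\max_k \, \delta_k(x)
```

The quadratic term x^T Σ^{-1} x is the same for every class, so it cancels when classes are compared. That is why the decision boundaries are linear. If each class had its own covariance matrix, the term would not cancel, which leads to quadratic discriminant analysis.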

The steps of linear discriminant analysis are as follows:

  • Compute the mean vector and the covariance matrix for each class.
  • Compute the pooled within-class covariance matrix, the weighted average of the individual class covariance matrices.
  • Compute the between-class covariance (scatter) matrix. For each class, take the outer product of the difference between the class mean vector and the overall mean vector with itself, multiply it by the number of observations in that class, and sum over the classes.
  • Compute the eigenvalues and eigenvectors of the inverse of the pooled within-class covariance matrix multiplied by the between-class covariance matrix.
  • Select the k largest eigenvalues and their corresponding eigenvectors, where k is the number of linear discriminants to be extracted (at most the number of classes minus one).
  • Transform the original features into the new linear discriminants by multiplying the (centered) feature matrix by the matrix of selected eigenvectors.
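The steps above can be sketched directly in base R. This is an illustrative reimplementation on the full iris data, not how MASS computes LDA internally (lda() works through a singular value decomposition). The sketch divides the within-class matrix by n − g and the between-class matrix by g − 1. These constants do not change the eigenvectors or the eigenvalue proportions, so the result can be cross-checked against lda():

```r
library(MASS)  # only used to cross-check the manual result against lda()

X <- as.matrix(iris[, 1:4])
y <- iris$Species
classes <- levels(y)
n <- nrow(X)
g <- length(classes)

# Step 1: mean vector of each class (rows = classes) and the overall mean
means <- t(sapply(classes, function(cl) colMeans(X[y == cl, , drop = FALSE])))
overall <- colMeans(X)

# Step 2: pooled within-class covariance (summed class scatter / (n - g))
Sw <- Reduce(`+`, lapply(classes, function(cl) {
  Xc <- scale(X[y == cl, , drop = FALSE], center = means[cl, ], scale = FALSE)
  crossprod(Xc)
})) / (n - g)

# Step 3: between-class matrix: sum over classes of n_k (mu_k - mu)(mu_k - mu)^T, / (g - 1)
Sb <- Reduce(`+`, lapply(classes, function(cl) {
  d <- means[cl, ] - overall
  sum(y == cl) * tcrossprod(d)
})) / (g - 1)

# Step 4: eigen-decomposition of Sw^{-1} Sb (eigenvalues are real; Re() drops rounding noise)
ev <- eigen(solve(Sw) %*% Sb)

# Step 5: keep the n_disc largest; at most g - 1 eigenvalues are non-zero
n_disc <- g - 1
lambda <- Re(ev$values[seq_len(n_disc)])
W <- Re(ev$vectors[, seq_len(n_disc)])

# Step 6: project the centered features onto the discriminant directions
Z <- scale(X, center = overall, scale = FALSE) %*% W

# Cross-check: eigenvalue proportions vs. lda()'s "Proportion of trace"
fit <- lda(Species ~ ., data = iris)
prop_manual <- lambda / sum(lambda)
prop_lda <- fit$svd^2 / sum(fit$svd^2)
rbind(manual = prop_manual, lda = unname(prop_lda))
```

The manual scores in Z match lda()'s discriminant scores up to sign and scale, because both use the same eigenvector directions.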

Assumption of LDA in R

One of the key assumptions of LDA is that the predictor variables have the same variance-covariance matrix for each class. That is, the variables have the same variability and correlation structure across the different groups.
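One common formal check of this assumption is Box's M test. The sketch below codes the textbook chi-square approximation in base R. The function name box_m is a hypothetical helper; packages such as heplots provide a tested implementation. Box's M is very sensitive to departures from normality, so read a small p-value alongside plots of the data:

```r
# Box's M test (chi-square approximation); box_m is a hypothetical helper name
box_m <- function(X, group) {
  X <- as.matrix(X)
  group <- factor(group)
  p <- ncol(X)
  g <- nlevels(group)
  n <- nrow(X)
  n_k <- as.vector(table(group))
  # Per-class covariance matrices and the pooled covariance
  S_k <- lapply(levels(group), function(cl) cov(X[group == cl, , drop = FALSE]))
  S_p <- Reduce(`+`, Map(function(S, nk) (nk - 1) * S, S_k, n_k)) / (n - g)
  # M compares the pooled log-determinant with the class log-determinants
  M <- (n - g) * log(det(S_p)) -
    sum(mapply(function(S, nk) (nk - 1) * log(det(S)), S_k, n_k))
  # Small-sample correction factor and degrees of freedom
  c1 <- (sum(1 / (n_k - 1)) - 1 / (n - g)) *
    (2 * p^2 + 3 * p - 1) / (6 * (p + 1) * (g - 1))
  stat <- M * (1 - c1)
  df <- p * (p + 1) * (g - 1) / 2
  list(statistic = stat, df = df,
       p.value = pchisq(stat, df, lower.tail = FALSE))
}

res <- box_m(iris[, 1:4], iris$Species)
res
```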

If this assumption is violated, we can transform the predictor variables to make their covariance structures more homogeneous, or use a method that does not require it, such as quadratic discriminant analysis (QDA). QDA is similar to LDA, but it allows each class to have its own variance-covariance matrix. This makes it more flexible, and potentially more accurate when the class covariances really differ, but it has more parameters to estimate and is more prone to overfitting.
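A quick way to see whether the extra flexibility of QDA pays off is to compare leave-one-out accuracy. The sketch below relies on the CV = TRUE option of MASS's lda() and qda(), which returns leave-one-out class predictions instead of a fitted model. It uses the full iris data and is only a rough comparison:

```r
library(MASS)

# Leave-one-out predictions from each method
lda_cv <- lda(Species ~ ., data = iris, CV = TRUE)
qda_cv <- qda(Species ~ ., data = iris, CV = TRUE)

# Share of flowers whose held-out prediction matches the true species
acc <- c(LDA = mean(lda_cv$class == iris$Species),
         QDA = mean(qda_cv$class == iris$Species))
round(acc, 3)
```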


How to Implement LDA in R?

To implement linear discriminant analysis in R, we follow these steps:

  • Load the data.
  • Split the data into training and testing sets.
  • Fit the LDA model.
  • Predict the class labels or posterior probabilities.
  • Evaluate the model performance.

Required R Packages

We also need to use some packages, such as MASS, caret, pROC, and irr, which provide various functions and tools for linear discriminant analysis.

Before we start, make sure the following packages are installed and loaded:
# Required Packages
library(MASS)
library(caret)
library(pROC)
library(irr)

Load the Dataset

For this tutorial, we will use the Iris data set, a famous data set containing 150 observations of three species of iris flowers: setosa, versicolor, and virginica. The data set has four features:

  • Sepal length
  • Sepal width
  • Petal length
  • Petal width

All four are measured in centimeters. The data set also has a class variable, Species, that indicates the species of each flower. The Iris data set ships with R in the datasets package, so we can load it with the data function and then check its structure and summary with the str and summary functions.

# Load the data
data(iris)
# Check the structure of the data
str(iris)
# Check the summary of the data
summary(iris)

Split the data into testing and training data sets

To split the data into training and testing sets, we will use the createDataPartition function from the caret package, which creates stratified random samples based on the class variable. We will use 80% of the data for training and 20% for testing.

# Set the seed for reproducibility
set.seed(123)
# Split the data into training and testing sets
train_index <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
train <- iris[train_index, ]
test <- iris[-train_index, ]
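If you prefer to avoid the caret dependency, you can sketch a stratified split in base R by sampling 80% of the row indices within each species. This is an alternative, not what createDataPartition() does internally, and it will not reproduce the same rows. The names train_b and test_b are hypothetical and were chosen so that train and test are not overwritten:

```r
set.seed(123)
# Row indices grouped by species
idx_by_class <- split(seq_len(nrow(iris)), iris$Species)
# Sample 80% of each species' rows so class proportions are preserved
train_idx_b <- unlist(lapply(idx_by_class, function(i) sample(i, size = round(0.8 * length(i)))),
                      use.names = FALSE)
train_b <- iris[train_idx_b, ]
test_b  <- iris[-train_idx_b, ]
# Each species should contribute 40 rows to training and 10 to testing
table(train_b$Species)
table(test_b$Species)
```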

Fitting the LDA Model using the lda Function from the MASS Package

We will use the lda function from the MASS package to fit the LDA model. The lda function takes a formula argument, which specifies the class variable and the features, and a data argument, which specifies the data frame. It also accepts a prior argument, which sets the prior probabilities of the classes, and a tol argument, a tolerance used to decide whether the within-class covariance matrix is singular (variables whose variance falls below it are rejected).

We will use the default values for the prior and tol arguments and the training set as the data argument. We will assign the output of the lda function to a variable called model and then print and plot the model using the print and plot functions.

# Fit the LDA model
model <- lda(Species ~ ., data = train)
# Print the model
print(model)
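The "Proportion of trace" figures in the printout can be recomputed from the fitted object: print.lda() reports the squared singular values stored in svd, divided by their sum. The sketch below fits on the full iris data so that it runs on its own. Replace fit with model to reproduce the percentages printed for the training split:

```r
library(MASS)
fit <- lda(Species ~ ., data = iris)

# Squared singular values: ratios of between- to within-class variance per discriminant
prop_trace <- fit$svd^2 / sum(fit$svd^2)
names(prop_trace) <- colnames(fit$scaling)
round(prop_trace, 4)

# Other components of the fitted object
fit$prior    # prior probabilities (default: class proportions)
fit$means    # group means of each feature
fit$scaling  # coefficients of the linear discriminants
```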

The output shows that the LDA model has two linear discriminants, LD1 and LD2, which account for 99.02% and 0.98% of the between-class variance, respectively (the "Proportion of trace" line in the printout). We can also see the coefficients of the linear discriminants, which indicate how each feature contributes to the linear discriminants.

For example, the sign of a feature's coefficient on LD1 tells us whether larger values of that feature raise or lower an observation's LD1 score. Because the features are on their original scales, the magnitudes of the coefficients are not directly comparable as importance measures. We can also see the group means, which indicate the average value of each feature for each class. For example, the group mean of petal length for setosa is 1.465, which means that setosa flowers have the shortest petals on average.

LDA Model Visualization

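# Define colors for each species
species_colors <- c("setosa" = "red", "versicolor" = "blue", "virginica" = "green")
# plot.lda() labels each training observation, so map one color per observation
# (a 3-element vector would otherwise be recycled point by point, not by species)
plot(model, col = species_colors[as.character(train$Species)],
     main = "Visualization of Linear Discriminant Analysis")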

From the plot, we can see the distribution of the observations along the linear discriminants and the separation of the classes. LD1 carries almost all of the separation: setosa lies far from the other two classes, and versicolor and virginica are also ordered along LD1, while LD2 adds comparatively little. There is some overlap between versicolor and virginica, which means that the linear discriminants do not separate them perfectly.

Predicting the Class Labels or Posterior Probabilities using the predict Function 

We will use the predict function to obtain class labels and posterior probabilities for the testing set. predict() is a generic function from the stats package. For an lda model, it dispatches to the predict.lda method from MASS, which takes the fitted model as its object argument and the new data as its newdata argument. It has no type argument. Instead, it returns a list whose class element holds the predicted class labels, whose posterior element holds the posterior probabilities, and whose x element holds the discriminant scores.

We will use the model variable as the object argument and the testing set as the newdata argument, and then check the results using the table and head functions.

# Predict on the testing set. predict() dispatches to predict.lda() from MASS,
# which returns a list with $class, $posterior and $x (it has no `type` argument)
pred_class <- predict(model, newdata = test)
# Convert test$Species to character
test_species <- as.character(test$Species)
# Convert it back to a factor with the same levels as the predictions
test_species <- factor(test_species, levels = levels(pred_class$class))
# Cross-tabulate predicted (rows) vs. actual (columns) classes
table(pred_class$class, test_species)
# The same predict() output also holds the posterior probabilities
pred_prob <- predict(model, newdata = test)
# Posterior probabilities for the first five test observations
head(pred_prob$posterior, 5)

The output shows that the predicted class labels are stored as a factor (pred_class$class) and the posterior probabilities as a matrix with one column per class (pred_prob$posterior). The cross-tabulation shows a prediction accuracy of 96.7%, because 29 of 30 observations are correctly classified.

We can also see the posterior probabilities for the first five observations in the testing set. For example, the first observation is predicted to be setosa, with a posterior probability of essentially 1 for setosa and essentially 0 for versicolor and virginica.

Evaluating the Model Performance Using Various Metrics

We will evaluate the model performance with several metrics: accuracy, the confusion matrix, the ROC curve, and the kappa coefficient.

Accuracy 

Accuracy is the proportion of correct predictions. For a multi-class problem, it is the number of correctly classified observations (the diagonal of the confusion table) divided by the total number of observations.

# Calculate the accuracy
accuracy <- sum(diag(table(pred_class$class, test_species))) / nrow(test)
accuracy


Confusion matrix

A confusion matrix is a table that shows the distribution of predicted and actual classes, which can be computed using the confusionMatrix function from the caret package. 

# Compute the confusion matrix
confusionMatrix(pred_class$class, test$Species)
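For reference, the main per-class statistics that confusionMatrix() reports can be computed by hand from a cross-table. So that it runs on its own, the sketch below uses leave-one-out predictions on the full iris data (via CV = TRUE in lda()). Sensitivity is the share of each actual class that is predicted correctly. Precision (caret's "Pos Pred Value") is the share of each predicted class that is correct:

```r
library(MASS)

# Leave-one-out predictions so every flower is predicted by a model that did not see it
cv_fit <- lda(Species ~ ., data = iris, CV = TRUE)
tab <- table(Predicted = cv_fit$class, Actual = iris$Species)
tab

sensitivity <- diag(tab) / colSums(tab)  # correct / actual count per species
precision   <- diag(tab) / rowSums(tab)  # correct / predicted count per species
accuracy    <- sum(diag(tab)) / sum(tab)

round(rbind(sensitivity, precision), 3)
accuracy
```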

ROC curve 

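An ROC curve shows the trade-off between the true positive rate and the false positive rate across classification thresholds. It can be computed and plotted using the roc function from the pROC package. Because ROC curves are defined for binary outcomes, we compute a one-vs-rest curve for virginica: virginica is the positive class, and the other two species are the negatives.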

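# Access the predicted posterior probabilities
pred_prob1 <- pred_prob$posterior
# One-vs-rest outcome: virginica (cases) vs. the other two species (controls)
virginica_vs_rest <- factor(ifelse(test$Species == "virginica", "virginica", "other"),
                            levels = c("other", "virginica"))
# ROC curve for P(virginica); higher probabilities should indicate cases
roc_virginica <- roc(virginica_vs_rest, pred_prob1[, "virginica"],
                     levels = c("other", "virginica"), direction = "<")
roc_virginica
plot(roc_virginica)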
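The area under the ROC curve (AUC), which roc() reports, has a simple interpretation. It is the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative case, with ties counting one half. Below is a minimal base-R sketch of that pairwise definition. The helper auc_pairwise is hypothetical, and the scores are made-up toy values:

```r
# AUC as the share of (case, control) pairs ranked correctly; ties count 0.5
auc_pairwise <- function(scores, is_case) {
  cases    <- scores[is_case]
  controls <- scores[!is_case]
  diffs <- outer(cases, controls, "-")  # every case score minus every control score
  (sum(diffs > 0) + 0.5 * sum(diffs == 0)) / length(diffs)
}

# Toy example: 3 positives and 3 negatives
scores  <- c(0.9, 0.8, 0.4, 0.3, 0.5, 0.1)
is_case <- c(TRUE, TRUE, TRUE, FALSE, FALSE, FALSE)
auc_pairwise(scores, is_case)  # 8 of 9 pairs ordered correctly: 0.8888889
```

Applied to the test-set posterior probabilities for virginica, this pairwise count gives the same AUC as the one-vs-rest ROC curve.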

Kappa coefficient

The kappa coefficient measures agreement between two sets of ratings beyond what chance alone would produce. If we treat the predicted and true class labels as two "raters", we can compute Cohen's kappa with the kappa2 function and Fleiss' kappa with the kappam.fleiss function, both from the irr package.

# Rows = test observations; columns = predicted and actual labels ("raters")
ratings <- data.frame(predicted = pred_class$class, actual = test$Species)
# Fleiss' kappa (kappa2(ratings) gives Cohen's kappa for two raters)
kappam.fleiss(ratings)
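Cohen's kappa compares the observed agreement p_o with the agreement p_e expected by chance from the row and column totals: kappa = (p_o − p_e) / (1 − p_e). Here is a base-R sketch. The helper cohen_kappa is hypothetical, and the labels are toy values:

```r
# Cohen's kappa from two label vectors that share the same factor levels
cohen_kappa <- function(pred, actual) {
  tab <- table(pred, actual)
  n   <- sum(tab)
  p_o <- sum(diag(tab)) / n                       # observed agreement (accuracy)
  p_e <- sum(rowSums(tab) * colSums(tab)) / n^2   # agreement expected by chance
  (p_o - p_e) / (1 - p_e)
}

lv     <- c("a", "b")
actual <- factor(c("a", "a", "b", "b"), levels = lv)
pred   <- factor(c("a", "b", "b", "b"), levels = lv)
cohen_kappa(pred, actual)  # p_o = 0.75, p_e = 0.5, so kappa = 0.5
```

For the same pair of label columns, kappa2() from irr should return the same unweighted value.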

Conclusion

Linear Discriminant Analysis (LDA) in R offers a robust approach for classification and dimensionality reduction tasks. Key takeaways include understanding LDA’s theoretical foundations, implementing it using the `lda` function from the `MASS` package, and evaluating model performance. Future directions for learning include exploring advanced topics such as Quadratic Discriminant Analysis (QDA) and further validation techniques for model assumptions. Mastering LDA empowers data analysts to tackle diverse real-world problems effectively, advancing their machine learning and data science expertise.

Frequently Asked Questions (FAQS)

What is Linear Discriminant Analysis (LDA) in R programming?

Linear Discriminant Analysis is a method used for dimensionality reduction and classification in machine learning. It projects the data into a lower-dimensional space by finding the linear combinations of predictor variables that best separate different groups.

How can I fit the LDA model in R?

You can fit the LDA model using the `lda()` function from the MASS package in R. This function takes predictor variables and group labels as input and returns the coefficients of linear discriminants and prior probabilities of group membership.

What is the role of prior probabilities in LDA?

By default, lda() sets the prior probabilities of group membership to the proportions of each group in the training data, but you can supply your own with the prior argument. The priors enter the discriminant function, so they shift the group assignment of new observations toward the groups considered more probable.
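A short sketch of both ways to set priors with MASS, using the iris data. The prior vector c(0.2, 0.4, 0.4) is an arbitrary illustrative choice. Its entries follow the order of levels(iris$Species):

```r
library(MASS)

# Default: priors equal the class proportions in the data (1/3 each for iris)
fit_default <- lda(Species ~ ., data = iris)
fit_default$prior

# Custom priors at fitting time, in the order of levels(iris$Species)
fit_custom <- lda(Species ~ ., data = iris, prior = c(0.2, 0.4, 0.4))
fit_custom$prior

# Priors can also be changed at prediction time without refitting
post <- predict(fit_default, newdata = iris, prior = c(0.2, 0.4, 0.4))$posterior
head(round(post, 3))
```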

How is LDA different from Logistic Regression?

Both LDA and logistic regression are used for classification, but LDA assumes that the predictor variables are normally distributed within each group and that the variance-covariance matrices are equal across groups. Logistic regression, in contrast, makes no assumptions about the distribution of the predictor variables; it models the probability of group membership directly.

What is the role of the first linear discriminant in LDA?

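The first linear discriminant in LDA is the linear combination of predictor variables that best separates the groups, in the sense that it maximizes the ratio of between-group to within-group variance. It is the most powerful dimension for distinguishing between groups, and later discriminants capture the remaining separation.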

How do you evaluate the performance of an LDA model?

The performance of an LDA model can be evaluated using techniques such as confusion matrix, cross-validation, and misclassification rate. These methods help assess the accuracy and reliability of the model in predicting group membership.

What is Quadratic Discriminant Analysis (QDA)?

Quadratic Discriminant Analysis is a variation of LDA that relaxes the assumption of equal variance-covariance matrices across groups. It allows for different covariance structures for each group, making it more flexible but requiring more parameters to estimate.

How does LDA relate to Principal Component Analysis (PCA)?

Both LDA and PCA are used for dimensionality reduction, but LDA uses the group labels and aims to maximize the separation between groups, whereas PCA captures the overall variance in the data without considering group membership.
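The difference can be made concrete on iris. The sketch below compares the first principal component (from prcomp() on standardized features) with LD1. It measures how much of each one-dimensional score's variance is explained by species. The helper between_share is hypothetical. By construction, LD1 maximizes this ratio among all linear combinations of the features, so it cannot score lower than PC1:

```r
library(MASS)

# PCA ignores the labels; LDA uses them
pca     <- prcomp(iris[, 1:4], scale. = TRUE)
lda_fit <- lda(Species ~ ., data = iris)

pc1 <- pca$x[, 1]
ld1 <- predict(lda_fit)$x[, 1]

# Share of a 1-D score's variance explained by species (between SS / total SS)
between_share <- function(z, g) {
  group_means <- tapply(z, g, mean)
  group_sizes <- tapply(z, g, length)
  sum(group_sizes * (group_means - mean(z))^2) / sum((z - mean(z))^2)
}

shares <- c(PC1 = between_share(pc1, iris$Species),
            LD1 = between_share(ld1, iris$Species))
round(shares, 3)
```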

What is leave-one-out cross-validation in LDA?

Leave-one-out cross-validation (LOOCV) for LDA holds out each observation in turn as the validation set and trains the model on the remaining data points. This is repeated for every observation, and the model's overall performance is assessed from the aggregated predictions.
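An explicit loop makes the procedure visible. The sketch below refits lda() from MASS 150 times on the iris data, each time leaving out one flower and predicting it. MASS can also return leave-one-out predictions in a single call with lda(..., CV = TRUE):

```r
library(MASS)

n <- nrow(iris)
loo_pred <- character(n)

# Refit on n - 1 rows and predict the held-out row, once per observation
for (i in seq_len(n)) {
  fit_i <- lda(Species ~ ., data = iris[-i, ])
  loo_pred[i] <- as.character(predict(fit_i, newdata = iris[i, , drop = FALSE])$class)
}

loo_accuracy <- mean(loo_pred == as.character(iris$Species))
loo_accuracy
table(Predicted = loo_pred, Actual = iris$Species)
```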

What are the key assumptions of LDA in R programming?

The key assumptions of LDA in R include:

  • Normality of predictor variable distributions within each group.
  • Equality of covariance matrices across groups.
  • Linearity of the discriminant functions.
It is important to check these assumptions before applying LDA to a dataset.
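As a rough first screen for the normality assumption, you can run a Shapiro-Wilk test (shapiro.test() from base R's stats package) on each feature within each species. Univariate normality of every variable is necessary but not sufficient for multivariate normality. Also, with many tests, some small p-values are expected by chance, so treat this as a diagnostic rather than a verdict:

```r
# p-values of shapiro.test() for each feature (rows) within each species (columns)
pvals <- sapply(split(iris[, 1:4], iris$Species),
                function(d) sapply(d, function(v) shapiro.test(v)$p.value))
round(pvals, 3)
```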

Need a customized solution for your data analysis projects? Are you interested in learning through Zoom? Hire me as your data analyst. I have five years of experience and a PhD. I can help you with data analysis projects and problems using R and other tools. To hire me, you can visit this link and fill out the order form. You can also contact me at info@data03.online for any questions or inquiries. I will be happy to work with you and provide you with high-quality data analysis services.

