What’s Artificial Intelligence all about?
Setup
library(MASS)      # Sampling from a multivariate normal distribution
                   # (loaded before the tidyverse so MASS::select
                   # doesn't mask dplyr::select, which we use later)
library(tidyverse) # As always
library(plotly)    # For 3D plots
OK, but what is AI actually about? Over the past two summers, I taught a statistics and research methods course to psychology students. Generally speaking, these students tend to be a little intimidated by this field, and as always, I tried to both empower them and spark their interest in the beauty and ‘coolness’ of statistics.
One idea they liked was when I told them: if you understand linear regression, you understand ChatGPT. While this is obviously a simplification, I managed to convince them it is not far from the truth. If simple linear modeling is about finding the best two parameters (a slope and an intercept) to approximate a function, AI is about finding the best millions to billions of parameters to approximate a function. In this blog post I will explain one of the most basic and important concepts of modern AI – Gradient Descent – in terms as simple as possible, with some R code.
What is a function?
A function is simply a rule that maps inputs to outputs – nothing too interesting on its own. Things get interesting when we use functions to describe relationships in data. Consider the following simulated data:
set.seed(14)

nice_colors <- c("#8DD3C7", "#FFFFB3", "#BEBADA", "#FB8072",
                 "#80B1D3", "#FDB462", "#B3DE69", "#FCCDE5",
                 "#D9D9D9", "#BC80BD", "#CCEBC5", "#FFED6F")

# Simulate 21 observations from a bivariate normal with correlation 0.4
d <- data.frame(mvrnorm(21, mu = c(0, 1.4),
                        Sigma = matrix(c(1, 0.4, 0.4, 1), 2)))

ggplot(d, aes(X1, X2)) +
  geom_point() +
  labs(x = "Time Spent Studying", y = "Final Grade") +
  theme_classic()
What is the relationship between the variables Time Spent Studying and Final Grade?
The mission of every statistical model is to provide the best function to estimate this relationship. The function can be linear:
It can also be quadratic:
And even a fifth degree polynomial:
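For illustration, here is a minimal sketch of how such fits could be drawn (it assumes the data frame d simulated above; the styling is my own choice and may differ from the original figures):

# Draw the three candidate fits on the same scatter plot
ggplot(d, aes(X1, X2)) +
  geom_point() +
  geom_smooth(aes(color = "Linear"), method = "lm", formula = y ~ x, se = FALSE) +
  geom_smooth(aes(color = "Quadratic"), method = "lm", formula = y ~ poly(x, 2), se = FALSE) +
  geom_smooth(aes(color = "5th degree"), method = "lm", formula = y ~ poly(x, 5), se = FALSE) +
  labs(x = "Time Spent Studying", y = "Final Grade", color = "Model") +
  theme_classic()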
Which is the best? In order to answer this question, we need to define what we want our model to maximize (or minimize). Usually, two main things are considered when evaluating our fitted lines:

1. Distance from truth, or how close the line (the predictions) is to the points.
2. Complexity, or how complex the function is.
Together, these two components create the Loss, which will always look conceptually like this:

$$\text{Loss} = \text{Distance from truth} + \text{Complexity}$$
For the sake of simplicity (no pun intended), and because both simple linear regression and deep neural networks share this feature, we will ignore the complexity component of the Loss function. So what is the distance between the line and the points for the linear function?
Each gray dotted line represents the difference between the actual outcome (Final Grade) and the model’s prediction based on Time Spent Studying. In a more formal way, we say that each line is:

$$e_i = y_i - \hat{y}_i$$

Where:

- $y_i$ is the real outcome of point $i$
- $\hat{y}_i$ is the predicted outcome of point $i$

And the total loss can be the mean of these lines:

$$\text{Loss} = \frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)$$
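In R, these differences are exactly the residuals of a fitted model. A quick sketch (assuming d from above):

# Fit the linear model and inspect its residuals (the gray dotted lines)
fit <- lm(X2 ~ X1, data = d)
residuals(fit)        # y_i minus its prediction, for each point
mean(residuals(fit))  # zero (up to rounding) for OLS – a hint at the problem below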
But this definition is problematic, as I will show in the next section.
Optimal = Minimum loss (intro to loss functions)
Now we can calculate the loss for every model we fit; the loss is therefore itself a function of the model, and is called the Loss Function.
Models can differ in two main ways: in their structure (for example, a linear function versus a 5th-degree polynomial), or in the values of their parameters (for example, two linear functions with different intercepts and slopes). Since the model structure is usually chosen a priori (or through a separate experiment), we will focus on the loss as a function of the parameters’ values.
So our task is to find the set of parameters that minimizes the loss function. Let’s formalize it for the linear model case. Given a linear model:

$$\hat{y}_i = \beta_0 + \beta_1 x_i$$

our goal is to find the minimum of:

$$\text{Loss}(\beta_0, \beta_1) = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - (\beta_0 + \beta_1 x_i)\right)$$

But here is the catch – this function does not have a minimum point! This is because the differences we saw before can be as negative as we want: just make the predictions overshoot the points by choosing extreme enough values of $\beta_0$ and $\beta_1$. We can also notice that this function is linear with respect to $\beta_0$ and $\beta_1$, and therefore has no minimum or maximum points.
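To see the problem concretely, here is a small sketch (an illustrative helper of my own, assuming d from above) showing that this “loss” rewards absurd lines:

# Mean (signed) error as a function of the two parameters
# (illustrative helper; assumes `d` from the setup above)
mean_error <- function(beta0, beta1) {
  mean(d$X2 - (beta0 + beta1 * d$X1))
}

mean_error(0, 1)    # a reasonable line
mean_error(100, 1)  # a line far above every point, yet a far "lower" loss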
In order to solve this, we simply define the loss function to be the mean squared distance from the line – the Mean Squared Error (MSE), essentially the variance of the points around the line:

$$\text{MSE}(\beta_0, \beta_1) = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - (\beta_0 + \beta_1 x_i)\right)^2$$
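As a small sketch (again my own illustrative helper, assuming d from above), squaring the distances makes wildly wrong lines expensive instead of “optimal”:

# Mean squared error as a function of the two parameters
# (illustrative helper; assumes `d` from the setup above)
mse <- function(beta0, beta1) {
  mean((d$X2 - (beta0 + beta1 * d$X1))^2)
}

mse(0, 1)    # a reasonable line: small loss
mse(100, 1)  # the absurd line from before: now a huge loss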
Finding a function’s minimum point
What does this function look like?
# A grid of candidate (beta0, beta1) pairs
parameter_values <- expand_grid(beta0 = seq(-1.5, 2.5, 0.05),
                                beta1 = seq(-1.5, 2.5, 0.05)) |>
  mutate(id = factor(c(1:n())))

# For every candidate pair, compute the MSE across all data points
df <- parameter_values |>
  expand_grid(d) |>
  select(id, beta0, beta1, x = X1, y = X2) |>
  mutate(y_pred = beta0 + beta1 * x) |>
  mutate(loss = (y - y_pred)^2) |>
  group_by(id) |>
  reframe(beta0 = beta0, beta1 = beta1, mse = mean(loss)) |>
  distinct()
plot_ly(df, x = ~beta0, y = ~beta1, z = ~mse, type = "scatter3d", mode = "markers")
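As a quick sanity check (my own addition, not part of the original analysis), the grid point with the lowest MSE should land close to the coefficients that lm() finds analytically:

# Compare the brute-force grid minimum with the OLS solution
df |> slice_min(mse, n = 1)
coef(lm(X2 ~ X1, data = d))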