Introducing ClustImpute: A new approach for k-means clustering with built-in missing data imputation
We are happy to introduce a new k-means clustering algorithm that includes a powerful multiple missing data imputation at the computational cost of a few extra random imputations (benchmarks will follow in a separate article). More precisely, the algorithm draws the missing values iteratively based on the current cluster assignment, so that correlations are considered at this level (we assume a more granular dependence structure is not relevant if we are “only” interested in k partitions). Subsequently, penalizing weights are imposed on the imputed values and successively decreased (to zero) as the missing data imputation improves. The hope is that at some point each observation lies near a cluster that provides a suitable neighborhood to draw its missing values from. The algorithm is computationally efficient since the imputation is only as accurate as the clustering, and it will be much faster than any approach that derives the full conditional distribution of the missing values independently of the clustering, e.g., as implemented in the (awesome) MICE package.
ClustImpute can currently be installed from GitHub only:
# devtools::install_github("o1iv3r/ClustImpute")
library(ClustImpute)
We’ll provide an example based on simulated data to emphasize the benefits of ClustImpute.
Simulated data with missings
First we create a random dataset with some structure and a few uncorrelated variables:
### Random Dataset
set.seed(739)
n <- 7500 # number of points
nr_other_vars <- 4
mat <- matrix(rnorm(nr_other_vars*n),n,nr_other_vars)
me <- 4 # mean
x <- c(rnorm(n/3,me/2,1),rnorm(2*n/3,-me/2,1))
y <- c(rnorm(n/3,0,1),rnorm(n/3,me,1),rnorm(n/3,-me,1))
true_clust <- c(rep(1,n/3),rep(2,n/3),rep(3,n/3)) # true clusters
dat <- cbind(mat,x,y)
dat <- as.data.frame(scale(dat)) # scaling
summary(dat)
#>        V1                  V2                   V3          
#>  Min.   :-3.40352   Min.   :-4.273673   Min.   :-3.82710  
#>  1st Qu.:-0.67607   1st Qu.:-0.670061   1st Qu.:-0.66962  
#>  Median : 0.01295   Median :-0.006559   Median :-0.01179  
#>  Mean   : 0.00000   Mean   : 0.000000   Mean   : 0.00000  
#>  3rd Qu.: 0.67798   3rd Qu.: 0.684672   3rd Qu.: 0.67221  
#>  Max.   : 3.35535   Max.   : 3.423416   Max.   : 3.80557  
#>        V4                    x                   y            
#>  Min.   :-3.652267   Min.   :-2.1994   Min.   :-2.151001  
#>  1st Qu.:-0.684359   1st Qu.:-0.7738   1st Qu.:-0.975136  
#>  Median : 0.001737   Median :-0.2901   Median : 0.009932  
#>  Mean   : 0.000000   Mean   : 0.0000   Mean   : 0.000000  
#>  3rd Qu.: 0.687404   3rd Qu.: 0.9420   3rd Qu.: 0.975788  
#>  Max.   : 3.621530   Max.   : 2.8954   Max.   : 2.265420
One can clearly see the three clusters:
plot(dat$x,dat$y)
We create 20% missing values using the custom function miss_sim():
dat_with_miss <- miss_sim(dat,p=.2,seed_nr=120)
summary(dat_with_miss)
#>        V1                V2                V3                V4         
#>  Min.   :-3.4035   Min.   :-4.2737   Min.   :-3.8271   Min.   :-3.5844  
#>  1st Qu.:-0.6756   1st Qu.:-0.6757   1st Qu.:-0.6634   1st Qu.:-0.6742  
#>  Median : 0.0163   Median :-0.0104   Median :-0.0092   Median : 0.0194  
#>  Mean   : 0.0024   Mean   :-0.0063   Mean   : 0.0027   Mean   : 0.0117  
#>  3rd Qu.: 0.6886   3rd Qu.: 0.6683   3rd Qu.: 0.6774   3rd Qu.: 0.7010  
#>  Max.   : 3.2431   Max.   : 3.4234   Max.   : 3.8056   Max.   : 3.6215  
#>  NA's   :1513      NA's   :1499      NA's   :1470      NA's   :1486     
#>        x                 y           
#>  Min.   :-2.1994   Min.   :-2.1510  
#>  1st Qu.:-0.7636   1st Qu.:-0.9745  
#>  Median :-0.2955   Median : 0.0065  
#>  Mean   : 0.0022   Mean   :-0.0019  
#>  3rd Qu.: 0.9473   3rd Qu.: 0.9689  
#>  Max.   : 2.8954   Max.   : 2.2654  
#>  NA's   :1580      NA's   :1516
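As a quick check (not part of the original workflow), the share of missing values per column can be computed with base R; given the NA counts above and n = 7500 it is close to the requested 20% for each variable:

colMeans(is.na(dat_with_miss)) # roughly 0.20 for each column (e.g. 1513/7500 for V1)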
The correlation matrix of the missing indicator shows that the missings are correlated – thus we are not in a missing completely at random (MCAR) setting:
library(corrplot)
mis_ind <- is.na(dat_with_miss) # missing indicator
corrplot(cor(mis_ind),method="number")
Typical approach: median or random imputation
Clearly, an imputation with the median value does a pretty bad job here:
library(ggplot2)
dat_median_imp <- dat_with_miss
for (j in 1:dim(dat)[2]) {
  dat_median_imp[,j] <- Hmisc::impute(dat_median_imp[,j],fun=median)
}
imp <- factor(pmax(mis_ind[,5],mis_ind[,6]),labels=c("Original","Imputed")) # point is imputed if x or y is imputed
ggplot(dat_median_imp) + geom_point(aes(x=x,y=y,color=imp))
#> Don't know how to automatically pick scale for object of type impute. Defaulting to continuous.
#> Don't know how to automatically pick scale for object of type impute. Defaulting to continuous.
But random imputation is not much better either: it creates plenty of points in areas where there is no actual data.
dat_random_imp <- dat_with_miss
for (j in 1:dim(dat)[2]) {
  dat_random_imp[,j] <- Hmisc::impute(dat_random_imp[,j],fun="random")
}
imp <- factor(pmax(mis_ind[,5],mis_ind[,6]),labels=c("Original","Imputed")) # point is imputed if x or y is imputed
ggplot(dat_random_imp) + geom_point(aes(x=x,y=y,color=imp))
#> Don't know how to automatically pick scale for object of type impute. Defaulting to continuous.
#> Don't know how to automatically pick scale for object of type impute. Defaulting to continuous.
A clustering based on random imputation will thus not provide good results (even if we “know” the number of clusters, as in this example):
library(ClusterR) # for KMeans_arma, predict_KMeans and external_validation
library(tictoc) # for tic/toc timing
tic("Clustering based on random imputation")
cl_compare <- KMeans_arma(data=dat_random_imp,clusters=3,n_iter=100,seed=751)
toc()
#> Clustering based on random imputation: 0.01 sec elapsed
dat_random_imp$pred <- predict_KMeans(dat_random_imp,cl_compare)
ggplot(dat_random_imp) + geom_point(aes(x=x,y=y,color=factor(pred)))
#> Don't know how to automatically pick scale for object of type impute. Defaulting to continuous.
#> Don't know how to automatically pick scale for object of type impute. Defaulting to continuous.
Better approach: ClustImpute
We’ll now use ClustImpute and also measure the run-time. In short, the algorithm follows these steps:
- It replaces all NAs by random imputation, i.e., for each variable with missings it draws from the marginal distribution of this variable, not taking into account any correlations with other variables.
- Weights <1 are used to adjust the scale of an observation that was generated in step 1. The weights are calculated by a (linear) weight function that starts near zero and converges to 1 at n_end (see the sketch after this list).
- A k-means clustering is performed for a number of c_steps steps, starting from a random initialization.
- The values from step 2 are replaced by new draws conditionally on the assigned cluster from step 3.
- Steps 2-4 are repeated nr_iter times in total. The k-means clustering in step 3 uses the previous cluster centroids for initialization.
- After the last draw, a final k-means clustering is performed.
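To make the weighting step more concrete, here is a minimal sketch of what such a linear weight function could look like. This is only an illustration, not necessarily identical to the package's default_wf; the signature (n, n_end) is taken from the attributes shown by str(res) further below.

# Hypothetical linear weight function: 0 at iteration 0, 1 from iteration n_end onwards
my_linear_wf <- function(n, n_end = 10) {
  pmin(n / n_end, 1)
}
# ClustImpute accepts a weight function via its wf argument (default: default_wf), e.g.
# ClustImpute(dat_with_miss, nr_cluster = 3, wf = my_linear_wf)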
The intuition is that points should be clustered with other points mainly based on their observed values, while the resulting clusters provide donors for the missing value imputation, so that step by step all variables can be used for the clustering.
nr_iter <- 10 # iterations of procedure
n_end <- 10 # step until convergence of weight function to 1
nr_cluster <- 3 # number of clusters
c_steps <- 50 # number of cluster steps per iteration
tic("Run ClustImpute")
res <- ClustImpute(dat_with_miss,nr_cluster=nr_cluster, nr_iter=nr_iter, c_steps=c_steps, n_end=n_end)
toc()
#> Run ClustImpute: 0.39 sec elapsed
ClustImpute provides several results:
str(res)
#> List of 5
#>  $ complete_data  :'data.frame': 7500 obs. of 6 variables:
#>   ..$ V1: num [1:7500] 1.403 -0.309 -0.214 -1.286 -0.202 ...
#>   ..$ V2: num [1:7500] -1.4579 -0.7899 -0.9775 -0.1607 -0.0413 ...
#>   ..$ V3: num [1:7500] 0.7836 -0.3234 -2.149 -0.0461 0.3609 ...
#>   ..$ V4: num [1:7500] 0.604 -0.427 -0.122 -1.287 -0.155 ...
#>   ..$ x : num [1:7500] 2.57 1.2 1.48 1.15 1.65 ...
#>   ..$ y : num [1:7500] -0.5077 -0.2453 0.022 0.0522 -0.0454 ...
#>  $ clusters       : int [1:7500] 2 2 2 2 2 2 2 2 2 2 ...
#>  $ centroids      : num [1:3, 1:6] 0.10554 -0.02642 -0.06682 0.14785 -0.00818 ...
#>  $ imp_values_mean: num [1:11, 1:7] -0.0305 0.0214 -0.005 0.0558 -0.0297 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : chr [1:11] "mean_imp" "" "" "" ...
#>   .. ..$ : chr [1:7] "V1" "V2" "V3" "V4" ...
#>  $ imp_values_sd  : num [1:11, 1:7] 0.992 0.954 0.978 0.997 0.983 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : chr [1:11] "sd_imp" "" "" "" ...
#>   .. ..$ : chr [1:7] "V1" "V2" "V3" "V4" ...
#>  - attr(*, "class")= chr "kmeans_ClustImpute"
#>  - attr(*, "nr_iter")= num 10
#>  - attr(*, "c_steps")= num 50
#>  - attr(*, "wf")=function (n, n_end = 10)
#>  - attr(*, "n_end")= num 10
#>  - attr(*, "seed_nr")= num 150519
We’ll first look at the complete data and the clustering results. Quite obviously, this gives better results than median or random imputation.
ggplot(res$complete_data,aes(x,y,color=factor(res$clusters))) + geom_point()
Packages like MICE compute a traceplot of the mean and variance of the imputed variables for the various chains. This diagnostic helps to show whether the Markov chains converge to a stationary distribution.
Here we only have a single realization, so we re-run ClustImpute with various seeds to obtain different realizations.
res2 <- ClustImpute(dat_with_miss,nr_cluster=nr_cluster, nr_iter=nr_iter, c_steps=c_steps, n_end=n_end,seed_nr = 2)
res3 <- ClustImpute(dat_with_miss,nr_cluster=nr_cluster, nr_iter=nr_iter, c_steps=c_steps, n_end=n_end,seed_nr = 3)
mean_all <- rbind(res$imp_values_mean,res2$imp_values_mean,res3$imp_values_mean)
sd_all <- rbind(res$imp_values_sd,res2$imp_values_sd,res3$imp_values_sd)
mean_all <- cbind(mean_all,seed=rep(c(150519,2,3),each=11))
sd_all <- cbind(sd_all,seed=rep(c(150519,2,3),each=11))
The realizations mix nicely with each other, as the following plots show. Thus it seems we obtain a similar missing value distribution independently of the seed.
ggplot(as.data.frame(mean_all)) + geom_line(aes(x=iter,y=V1,color=factor(seed))) + ggtitle("Mean")
ggplot(as.data.frame(sd_all)) + geom_line(aes(x=iter,y=V1,color=factor(seed))) + ggtitle("Std. dev.")
Quality of imputation and cluster results
Marginal distributions
Now we compare the marginal distributions using a violin plot of x and y. In particular for y, the distribution by cluster is quite far away from the original distribution for the random imputation based clustering, but quite close for ClustImpute.
library(psych) # for violinBy
dat4plot <- dat
dat4plot$true_clust <- true_clust
Xfinal <- res$complete_data
Xfinal$pred <- res$clusters
par(mfrow=c(1,2))
violinBy(dat4plot,"x","true_clust",main="Original data")
violinBy(dat4plot,"y","true_clust",main="Original data")
violinBy(Xfinal,"x","pred",main="imputed data")
violinBy(Xfinal,"y","pred",main="imputed data")
violinBy(dat_random_imp,"x","pred",main="random imputation")
violinBy(dat_random_imp,"y","pred",main="random imputation")
External validation: Rand index
Below we compare the (adjusted) Rand index between the true and the fitted cluster assignment. For ClustImpute we obtain:
external_validation(true_clust, res$clusters)
#> [1] 0.6353923
This is a much higher value than for a clustering based on random imputation:
class(dat_random_imp$pred) <- "numeric"
external_validation(true_clust, dat_random_imp$pred)
#> [1] 0.4284653
Not surprisingly, the Rand index for ClustImpute is much higher if we consider complete cases only (and throw away a considerable amount of our data).
## complete cases
idx <- which(complete.cases(dat_with_miss)==TRUE)
sprintf("Number of complete cases is %s",length(idx))
#> [1] "Number of complete cases is 2181"
sprintf("Rand index for this case %s", external_validation(true_clust[idx], res$clusters[idx]))
#> [1] "Rand index for this case 0.975286775813841"
Aside from the Rand index, this function also computes a variety of other statistics:
external_validation(true_clust, res$clusters,summary_stats = TRUE)
#> 
#> ----------------------------------------
#> purity                         : 0.8647
#> entropy                        : 0.4397
#> normalized mutual information  : 0.5595
#> variation of information       : 1.3956
#> normalized var. of information : 0.6116
#> ----------------------------------------
#> specificity                    : 0.8778
#> sensitivity                    : 0.7579
#> precision                      : 0.7561
#> recall                         : 0.7579
#> F-measure                      : 0.757
#> ----------------------------------------
#> accuracy OR rand-index         : 0.8379
#> adjusted-rand-index            : 0.6354
#> jaccard-index                  : 0.6091
#> fowlkes-mallows-index          : 0.757
#> mirkin-metric                  : 9118230
#> ----------------------------------------
#> [1] 0.6353923
Variance reduction
To assess the quality of our cluster results, we compute the sum of squares within each cluster, sum up these values and compare the result with the total sum of squares:
res_var <- var_reduction(res)
res_var$Variance_reduction
#> [1] 0.2795489
res_var$Variance_by_cluster
#> # A tibble: 1 x 6
#>      V1    V2    V3    V4     x     y
#>   <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0.987 0.983  1.01 0.976 0.218 0.117
We see a variance reduction of about 28% using only 3 clusters, most strikingly for x and y because these variables define the subspace of the true clusters.
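For intuition, the variance reduction can be approximated by hand along the following lines. This is only a sketch of the idea stated above (within-cluster sum of squares relative to the total sum of squares), not necessarily the exact implementation of var_reduction:

# Sketch: compare within-cluster sum of squares with the total sum of squares
X <- res$complete_data
cl <- res$clusters
ss_total <- sum(scale(X, scale = FALSE)^2) # total sum of squares around the global means
ss_within <- sum(sapply(split(X, cl), function(Xk) sum(scale(Xk, scale = FALSE)^2))) # within-cluster SS
1 - ss_within/ss_total # should be in the same ballpark as res_var$Variance_reduction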
Using more clusters will also capture some of the random variation of the other variables:
res <- ClustImpute(dat_with_miss,nr_cluster=10, nr_iter=nr_iter, c_steps=c_steps, n_end=n_end)
res_var <- var_reduction(res)
res_var$Variance_reduction
#> [1] 0.5209119
res_var$Variance_by_cluster
#> # A tibble: 1 x 6
#>      V1    V2    V3    V4     x     y
#>   <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0.568 0.601 0.609 0.616 0.267 0.187
Can we use this function systematically to find the optimal number of clusters? We’ll do the exercise above for several values of nr_cluster, and we need a small helper function for that since the data argument of ClustImpute is called X, which clashes with the X argument of lapply.
ClustImpute2 <- function(dataFrame,nr_cluster, nr_iter=10, c_steps=1, wf=default_wf, n_end=10, seed_nr=150519) {
  return(ClustImpute(dataFrame,nr_cluster, nr_iter, c_steps, wf, n_end, seed_nr))
}
res_list <- lapply(X=1:10,FUN=ClustImpute2,dataFrame=dat_with_miss, nr_iter=nr_iter, c_steps=c_steps, n_end=n_end)
Next we put the variances by cluster in a table:
tmp <- var_reduction(res_list[[1]])
var_by_clust <- tmp$Variance_by_cluster
for (k in 2:10) {
  tmp <- var_reduction(res_list[[k]])
  var_by_clust <- rbind(var_by_clust,tmp$Variance_by_cluster)
}
var_by_clust$nr_clusters <- 1:10
While there is a rather gradual improvement for the other variables, x and y have a minimum at 3 clusters, showing that this is the optimal number for these variables. Such a plot clearly indicates that 3 clusters are a good choice for this data set (which, of course, we knew in advance).
data2plot <- tidyr::gather(var_by_clust,key = "variable", value = "variance", -dplyr::one_of("nr_clusters"))
ggplot(data2plot,aes(x=nr_clusters,y=variance,color=variable)) + geom_line() + scale_x_continuous(breaks=1:10)
Now it’s time for you to try this algorithm on real problems! We are looking forward to feedback via Twitter or on the GitHub page.