Weighted Linear Support Vector Machine
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.
Consider the SMS spam vs. ham data from this site. Let us do some basic analysis of the data with R version 3.3.3 (64-bit) on a Windows machine.
setwd("your directory")
sms_data <- read.csv("sms_spam.csv", stringsAsFactors = FALSE)
Next, check the proportions of spam and ham in the data set:
prop.table(table(sms_data$type))
      ham      spam
0.8659849 0.1340151
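The "times higher" figure is simply the ratio of the two class proportions. Here is a quick base-R check. It assumes the full data set contains 4827 ham and 747 spam messages, which are the counts implied by the weights `1/table(y)` printed later in the post. The variable names are illustrative.

```r
# Hypothetical label vector rebuilt from the class counts implied by the
# weights printed later in the post (4827 ham + 747 spam = 5574 messages).
labels <- factor(c(rep("ham", 4827), rep("spam", 747)))

# Class proportions, as in prop.table(table(sms_data$type)).
props <- prop.table(table(labels))
print(round(props, 7))

# Imbalance ratio: number of ham messages per spam message.
imbalance_ratio <- as.numeric(props["ham"] / props["spam"])
print(round(imbalance_ratio, 2))
```

The ratio comes out at about 6.46.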
As you can see, the data set contains roughly 6.5 times as much ham as spam. Next, we clean the text and build a document-term matrix. We then use the first 3000 messages to train the classifier and the remaining 2574 to test it.
library(tm)
# Build a corpus from the message texts and clean it
simply_text <- Corpus(VectorSource(sms_data$text))
cleaned_corpus <- tm_map(simply_text, content_transformer(tolower))
cleaned_corpus <- tm_map(cleaned_corpus, removeNumbers)
cleaned_corpus <- tm_map(cleaned_corpus, removeWords, stopwords())
sms_dtm <- DocumentTermMatrix(cleaned_corpus)
# First 3000 messages for training, the remaining 2574 for testing
sms_train <- cleaned_corpus[1:3000]
sms_test <- cleaned_corpus[3001:5574]
# Keep only terms that occur at least 10 times in the corpus
freq_term <- findFreqTerms(sms_dtm, lowfreq = 10, highfreq = Inf)
sms_freq_train <- DocumentTermMatrix(sms_train, list(dictionary = freq_term))
sms_freq_test <- DocumentTermMatrix(sms_test, list(dictionary = freq_term))
print(NCOL(sms_freq_train))
print(NCOL(sms_freq_test))
y_train <- factor(sms_data$type[1:3000])
y_test <- factor(sms_data$type[3001:5574])
prop.table(table(y_train))
[1] 856
[1] 856
      ham      spam
0.8636667 0.1363333
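The toy base-R analogue below shows what the cleaning steps and the `dictionary` option do. It is not how tm implements them. The four documents, the stopword list, the helper names and the frequency threshold of 2 are all hypothetical. Both the train and the test matrix are built against the same dictionary, which is why their column counts agree (856 in the post).

```r
# Toy base-R analogue of the tm steps above; NOT tm's implementation.
# The documents and the stopword list are hypothetical examples.
docs <- c("Call 0800 now to WIN a prize",
          "are you coming to the party",
          "WIN a free prize now",
          "see you at the party")
stop_words <- c("a", "to", "the", "are", "you", "at", "now")

clean_tokens <- function(text) {
  text <- tolower(text)                             # cf. content_transformer(tolower)
  text <- gsub("[0-9]+", "", text)                  # cf. removeNumbers
  tokens <- strsplit(text, "[^a-z]+")[[1]]          # split on non-letters
  tokens[tokens != "" & !(tokens %in% stop_words)]  # cf. removeWords
}
tokens <- lapply(docs, clean_tokens)

# Document-term matrix whose columns are exactly the terms in `dictionary`;
# terms outside the dictionary are dropped.
dtm_with_dictionary <- function(token_list, dictionary) {
  rows <- lapply(token_list, function(tk)
    as.integer(table(factor(tk, levels = dictionary))))
  m <- do.call(rbind, rows)
  colnames(m) <- dictionary
  m
}

# Dictionary of frequent terms from the whole corpus (cf. findFreqTerms).
all_terms <- sort(unique(unlist(tokens)))
full_dtm <- dtm_with_dictionary(tokens, all_terms)
freq_terms <- all_terms[colSums(full_dtm) >= 2]

# Positional train/test split; both matrices share the same columns.
train_dtm <- dtm_with_dictionary(tokens[1:2], freq_terms)
test_dtm  <- dtm_with_dictionary(tokens[3:4], freq_terms)
print(freq_terms)
print(train_dtm)
print(test_dtm)
```

With these inputs the dictionary is party, prize and win, and both matrices have exactly those three columns.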
Notice that the proportions of spam and ham in the training set are similar to those of the entire data set. One widely used classifier is the linear Support Vector Machine (SVM). As described in my last post on linear SVMs, a linear SVM solves the following optimization problem:
$$\min_{w,b}\frac{\vert\vert w\vert\vert^{2}}{2}$$
subject to the constraints
$$y_{i}(w^{T}x_{i}+b)\geq 1~\forall ~x_{i}~\text{in the training data set}.$$
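Why does minimizing \(\frac{\vert\vert w\vert\vert^{2}}{2}\) help? Here is a short worked derivation. It uses the standard form of the constraint, \(y_{i}(w^{T}x_{i}+b)\geq 1\) with labels \(y_{i}\in\{-1,+1\}\), weight vector \(w\) and bias \(b\). The distance from a correctly classified point \(x_{i}\) to the hyperplane \(w^{T}x+b=0\) is

$$\frac{\vert w^{T}x_{i}+b\vert}{\vert\vert w\vert\vert}=\frac{y_{i}(w^{T}x_{i}+b)}{\vert\vert w\vert\vert}\geq\frac{1}{\vert\vert w\vert\vert}.$$

At the optimum, the closest points (the support vectors) satisfy the constraint with equality. They therefore lie at distance \(\frac{1}{\vert\vert w\vert\vert}\) on either side of the hyperplane, and the margin has width

$$\frac{2}{\vert\vert w\vert\vert}.$$

Maximizing this width is the same as minimizing \(\vert\vert w\vert\vert\), or equivalently the smooth convex function \(\frac{\vert\vert w\vert\vert^{2}}{2}\).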
Here \(w\) is a weight vector, while \(b\) is a scalar known as the bias. For the linearly inseparable case, we introduce a slack variable \(\xi_{i}\geq 0\) for each training point. The slack relaxes that point's constraint and is penalized in the objective function:
$$\min_{w,b}\left(\frac{\vert\vert w\vert\vert^{2}}{2}+ C\sum_{i=1}^{n}\xi_{i}\right)$$
subject to the constraints
$$(1)~y_{i}(w^{T}x_{i}+b)\geq 1-\xi_{i}~\forall ~x_{i}~\text{in the training data set of size}~n$$ and
$$(2)~\xi_{i}\geq 0~\text{for all}~n~\text{points in the training data set}.$$
\(C\) is a user-defined parameter known as the regularization parameter. It tells the optimizer how to trade off minimizing \(\frac{\vert\vert w\vert\vert^{2}}{2}\) against minimizing \(\sum_{i=1}^{n}\xi_{i}\). For example, if \(C=0\), only \(\frac{\vert\vert w\vert\vert^{2}}{2}\) is minimized and margin violations are not penalized at all. If \(C\) is large, the optimizer concentrates on making \(\sum_{i=1}^{n}\xi_{i}\) small.
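For a fixed \(w\) and \(b\), the smallest slacks that satisfy constraints (1) and (2) are \(\xi_{i}=\max(0, 1-y_{i}(w^{T}x_{i}+b))\), which is the hinge loss. This means the effect of \(C\) can be seen by evaluating the objective directly. The minimal base-R sketch below does this; the function name and the toy data are hypothetical.

```r
# Soft-margin linear SVM objective ||w||^2 / 2 + C * sum(xi) for a given (w, b).
# Labels must be coded as -1 / +1; X has one row per training point.
soft_margin_objective <- function(w, b, X, y, C) {
  margins <- as.vector(y * (X %*% w + b))   # y_i (w^T x_i + b)
  xi <- pmax(0, 1 - margins)                # smallest slacks satisfying (1) and (2)
  sum(w^2) / 2 + C * sum(xi)
}

# Hypothetical toy data: two +1 points and two -1 points in 2-D.
X_toy <- matrix(c( 2,  2,
                   1,  3,
                  -1, -1,
                   1,  1), ncol = 2, byrow = TRUE)
y_toy <- c(1, 1, -1, -1)
w_toy <- c(0.5, 0.5)
b_toy <- -0.5

# Slacks: 0 for points beyond the margin, > 1 for the misclassified last point.
print(pmax(0, 1 - as.vector(y_toy * (X_toy %*% w_toy + b_toy))))

for (C in c(0, 1, 10)) {
  cat("C =", C, " objective =",
      soft_margin_objective(w_toy, b_toy, X_toy, y_toy, C), "\n")
}
```

The last point lies on the wrong side of the hyperplane and gets \(\xi_{4}=1.5\). The objective is therefore \(0.25+1.5C\): 0.25 for \(C=0\), 1.75 for \(C=1\) and 15.25 for \(C=10\). This is also why the slack variables are bounded only from below by 0, since a misclassified point always has a slack greater than 1.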
In this example the class proportions are heavily skewed, so the choice of the regularization parameter matters. Note that \(C\) is a single scalar weight on the misclassification penalty, shared by all training points. The training data contain roughly 6.3 times as many ham as spam samples. Intuitively, the linear SVM may therefore become biased towards classifying ham correctly, because misclassifying many ham messages adds up to a larger total penalty.
Thus, instead of applying a linear SVM directly to such a data set, it is better to use a weighted linear SVM. Instead of one regularization parameter, it uses two, \(C_{1}\) and \(C_{2}\). \(C_{1}\) (respectively \(C_{2}\)) is the weight on the penalty for misclassifying a ham sample (respectively a spam sample). The objective function becomes
$$\min_{w,b}\left(\frac{\vert\vert w\vert\vert^{2}}{2}+ C_{1}\sum_{i=1}^{n_{1}}\xi_{i} +C_{2}\sum_{j=1}^{n_{2}}\xi_{j}\right)$$
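where \(n_{1}\) (respectively \(n_{2}\)) is the number of ham (respectively spam) messages in the training data. The first sum runs over the ham samples and the second over the spam samples. Constraints (1) and (2) stay the same.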
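The weighted objective differs from the unweighted one only in which \(C\) multiplies each slack. In the minimal base-R sketch below, the class of each point selects its \(C\). The function name and the toy data are hypothetical.

```r
# Weighted soft-margin objective: each point's slack is multiplied by the
# regularization parameter of its own class (C1 for ham, C2 for spam).
weighted_objective <- function(w, b, X, labels, C_by_class, positive = "spam") {
  y <- ifelse(labels == positive, 1, -1)            # code classes as +1 / -1
  xi <- pmax(0, 1 - as.vector(y * (X %*% w + b)))   # hinge-loss slacks
  C_i <- C_by_class[as.character(labels)]           # per-point C looked up by class
  sum(w^2) / 2 + sum(C_i * xi)
}

# Hypothetical toy data: two spam points and two ham points in 2-D.
X_toy <- matrix(c( 2,  2,
                   1,  3,
                  -1, -1,
                   1,  1), ncol = 2, byrow = TRUE)
labels_toy <- c("spam", "spam", "ham", "ham")
w_toy <- c(0.5, 0.5)
b_toy <- -0.5

# Only the last (ham) point has a non-zero slack (1.5), so only C1 matters here.
print(weighted_objective(w_toy, b_toy, X_toy, labels_toy, c(ham = 0.5, spam = 2)))
print(weighted_objective(w_toy, b_toy, X_toy, labels_toy, c(ham = 2, spam = 0.5)))
```

Raising the ham weight from 0.5 to 2 raises the objective from 1 to 3.25. With \(C_{1}=C_{2}=C\), the function reduces to the ordinary soft-margin objective with a single \(C\).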
How to Fix the Values for \(C_{1},C_{2}\)
One simple way is to set
$$C_{1}=\frac{1}{n_{1}}$$ and
$$C_{2}=\frac{1}{n_{2}}$$
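The R code for doing this is as follows. Note that it computes the weights from the class counts of the entire data set (4827 ham and 747 spam messages) rather than from the training set. Using y_train instead of y would follow the definition above exactly.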
y <- sms_data$type
wts <- 1/table(y)
print(wts)
        ham        spam
0.000207168 0.001338688
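Choosing \(C_{k}=1/n_{k}\) makes \(C_{1}n_{1}=C_{2}n_{2}=1\). Each class then carries the same total weight in the penalty regardless of how many samples it has, and \(C_{2}/C_{1}=n_{1}/n_{2}\). The small sketch below computes the weights from training labels and checks this property. The helper name is hypothetical, and the labels are reconstructed from the training proportions printed earlier.

```r
# Inverse-frequency class weights C_k = 1 / n_k computed from class labels.
inverse_class_weights <- function(labels) {
  counts <- table(labels)
  weights <- 1 / counts
  # With these weights every class has the same total weight: C_k * n_k = 1.
  stopifnot(isTRUE(all.equal(as.numeric(weights * counts),
                             rep(1, length(counts)))))
  weights
}

# Hypothetical training labels with the counts implied by the training
# proportions: 3000 * 0.8636667 = 2591 ham, 3000 * 0.1363333 = 409 spam.
train_labels <- factor(c(rep("ham", 2591), rep("spam", 409)))
train_wts <- inverse_class_weights(train_labels)
print(train_wts)
print(as.numeric(train_wts["spam"] / train_wts["ham"]))  # C2 / C1 = n1 / n2
```

For the training split, the spam weight comes out at about 6.3 times the ham weight.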
Now train the weighted linear SVM. An easy way of doing so is to use svm() from the e1071 package and pass the weights through its class.weights argument:
library(e1071)
# svm() takes a matrix or data frame, so convert the document-term matrices
sms_freq_matrx <- as.matrix(sms_freq_train)
sms_freq_dtm <- as.data.frame(sms_freq_matrx)
sms_freq_matrx_test <- as.matrix(sms_freq_test)
sms_freq_dtm_test <- as.data.frame(sms_freq_matrx_test)
# Linear SVM whose per-class penalties are scaled by wts
trained_model <- svm(sms_freq_dtm, y_train, type = "C-classification",
                     kernel = "linear", class.weights = wts)
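It helps to know how class.weights enters the optimization. In libsvm, which e1071's svm() wraps, the weight of class k multiplies the cost parameter. The effective regularization parameter of that class is therefore cost × weight. The short sketch below computes the effective \(C_{1}\) and \(C_{2}\) for the weight scalings used in this post; the variable names are illustrative.

```r
# libsvm (which e1071::svm wraps) sets the penalty of class k to cost * weight_k.
# svm() uses cost = 1 by default, so here the effective C1, C2 equal the weights.
class_counts <- c(ham = 4827, spam = 747)   # counts behind print(wts) above
cost <- 1                                    # default value of svm()'s cost argument

for (scale in c(1L, 10L, 100L)) {            # wts = scale / table(y)
  effective_C <- cost * scale / class_counts
  cat(sprintf("scale %3d: C_ham = %.7f, C_spam = %.7f\n",
              scale, effective_C[["ham"]], effective_C[["spam"]]))
}
```

For comparison, a call to svm() without class.weights uses C = 1 for both classes, which is larger than all of these values.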
Let us use the model to make predictions on the test set:
y_predict <- predict(trained_model, sms_freq_dtm_test)
library(gmodels)
CrossTable(y_predict, y_test, prop.chisq = FALSE)
Each cell of the table shows N, N / Row Total, N / Col Total and N / Table Total.
             | y_test
   y_predict |       ham |      spam | Row Total |
-------------|-----------|-----------|-----------|
         ham |      2002 |       233 |      2235 |
             |     0.896 |     0.104 |     0.868 |
             |     0.895 |     0.689 |           |
             |     0.778 |     0.091 |           |
-------------|-----------|-----------|-----------|
        spam |       234 |       105 |       339 |
             |     0.690 |     0.310 |     0.132 |
             |     0.105 |     0.311 |           |
             |     0.091 |     0.041 |           |
-------------|-----------|-----------|-----------|
Column Total |      2236 |       338 |      2574 |
             |     0.869 |     0.131 |           |
-------------|-----------|-----------|-----------|
Of the 2236 ham messages, 2002 were correctly classified as ham, while 234 were misclassified as spam. The picture is much less rosy for spam. Of the 338 spam messages, only 105 were correctly classified as spam, while 233 were misclassified as ham.
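Overall accuracy hides this imbalance. Per-class recall (the share of each actual class that is classified correctly) and balanced accuracy (the mean of the per-class recalls) make it visible. The base-R sketch below uses only the counts from the table above; the variable names are illustrative.

```r
# Confusion matrix of the weighted model, copied from the CrossTable output
# above (rows = predicted class, columns = actual class).
conf <- matrix(c(2002, 234,    # actual ham:  predicted ham, predicted spam
                 233,  105),   # actual spam: predicted ham, predicted spam
               nrow = 2,
               dimnames = list(predicted = c("ham", "spam"),
                               actual    = c("ham", "spam")))

per_class_recall  <- diag(conf) / colSums(conf)   # correct / actual class size
overall_accuracy  <- sum(diag(conf)) / sum(conf)
balanced_accuracy <- mean(per_class_recall)

print(round(per_class_recall, 3))
cat("overall accuracy:", round(overall_accuracy, 3),
    " balanced accuracy:", round(balanced_accuracy, 3), "\n")
```

Overall accuracy is about 0.819, while only about 31% of the spam is caught.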
Now let us see what happens without using the weights.
trained_model2 <- svm(sms_freq_dtm, y_train, type = "C-classification", kernel = "linear")
y_predict <- predict(trained_model2, sms_freq_dtm_test)
library(gmodels)
CrossTable(y_predict, y_test, prop.chisq = FALSE)
             | y_test
   y_predict |       ham |      spam | Row Total |
-------------|-----------|-----------|-----------|
         ham |      1922 |       224 |      2146 |
             |     0.896 |     0.104 |     0.834 |
             |     0.860 |     0.663 |           |
             |     0.747 |     0.087 |           |
-------------|-----------|-----------|-----------|
        spam |       314 |       114 |       428 |
             |     0.734 |     0.266 |     0.166 |
             |     0.140 |     0.337 |           |
             |     0.122 |     0.044 |           |
-------------|-----------|-----------|-----------|
Column Total |      2236 |       338 |      2574 |
             |     0.869 |     0.131 |           |
-------------|-----------|-----------|-----------|
Without the weights, the number of correctly classified spam messages increased (from 105 to 114), but the number of correctly classified ham messages decreased (from 2002 to 1922).
Let us repeat the experiment after increasing both \(C_{1}\) and \(C_{2}\) by a factor of 100.
wts <- 100/table(y)
print(wts)
      ham      spam
0.0207168 0.1338688
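The results presented below are for the model retrained with these weights. Compared with the first weighted model, fewer spam messages were misclassified (219 instead of 233), and more were correctly classified (119 instead of 105). However, fewer ham messages were correctly classified (1908 instead of 2002).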
             | y_test
   y_predict |       ham |      spam | Row Total |
-------------|-----------|-----------|-----------|
         ham |      1908 |       219 |      2127 |
             |     0.897 |     0.103 |     0.826 |
             |     0.853 |     0.648 |           |
             |     0.741 |     0.085 |           |
-------------|-----------|-----------|-----------|
        spam |       328 |       119 |       447 |
             |     0.734 |     0.266 |     0.174 |
             |     0.147 |     0.352 |           |
             |     0.127 |     0.046 |           |
-------------|-----------|-----------|-----------|
Column Total |      2236 |       338 |      2574 |
             |     0.869 |     0.131 |           |
-------------|-----------|-----------|-----------|
We continued the experiments by setting
wts <- 10/table(y)
print(wts)
       ham       spam
0.00207168 0.01338688
The results are:
             | y_test
   y_predict |       ham |      spam | Row Total |
-------------|-----------|-----------|-----------|
         ham |      1931 |       230 |      2161 |
             |     0.894 |     0.106 |     0.840 |
             |     0.864 |     0.680 |           |
             |     0.750 |     0.089 |           |
-------------|-----------|-----------|-----------|
        spam |       305 |       108 |       413 |
             |     0.738 |     0.262 |     0.160 |
             |     0.136 |     0.320 |           |
             |     0.118 |     0.042 |           |
-------------|-----------|-----------|-----------|
Column Total |      2236 |       338 |      2574 |
             |     0.869 |     0.131 |           |
-------------|-----------|-----------|-----------|
One can repeat the experiment with different values of wts and see the results.
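When trying other values of wts, a small summary table makes the runs easier to compare. The sketch below tabulates per-class recall and balanced accuracy for the four runs reported in this post, using only the counts from the tables above. The column names are illustrative.

```r
# Per-class results of the four runs reported above (test set: 2236 ham, 338 spam).
runs <- data.frame(
  setting      = c("wts = 1/table(y)", "no weights",
                   "wts = 100/table(y)", "wts = 10/table(y)"),
  ham_correct  = c(2002, 1922, 1908, 1931),
  spam_correct = c(105, 114, 119, 108)
)
n_ham  <- 2236
n_spam <- 338

runs$ham_recall        <- runs$ham_correct / n_ham
runs$spam_recall       <- runs$spam_correct / n_spam
runs$balanced_accuracy <- (runs$ham_recall + runs$spam_recall) / 2
print(runs, digits = 3)
```

For these four runs, the balanced accuracy stays between roughly 0.59 and 0.60. The settings mostly trade ham recall against spam recall.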