Implementing the soft-margin SVM algorithm with a linear kernel for classifying handwritten digits (3 and 8), with detailed MATLAB code.
Contents
- Parameters
- Datasets
- Loading the data
- Z-score normalization
- Initialization of vectors
- Calculate alphas
- Compute the near-boundary coefficients
- Calculate Support Vectors
- Compute bias parameter b
- Compute the decision function on the training set
- Compute the decision function on the test set
- Threshold the decisions
- Calculate the misclassification errors
- Plot the graphs of misclassification errors
- Plot the support vector count percentage
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% SVM Linear Kernel                                                       %
% Step-by-step detailed implementation of soft-margin SVM                 %
% Santosh Tirunagari                                                      %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Parameters
The linear kernel and the soft-margin penalty values used are
$$K(x_i, x_j) = x_i^\top x_j \;\text{and}\; C = \{0.0001, 0.001, 0.01, 0.1, 1\}$$
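These parameters enter the soft-margin SVM dual problem, which the code below solves as a quadratic program:

$$\max_{\alpha}\; \sum_{i=1}^{n} \alpha_i - \frac{1}{2}\sum_{i=1}^{n}\sum_{j=1}^{n} \alpha_i \alpha_j y_i y_j K(x_i, x_j) \quad \text{s.t.} \quad 0 \le \alpha_i \le C,\;\; \sum_{i=1}^{n} \alpha_i y_i = 0$$

Negating the objective turns this into the minimization $\min_{\alpha} \frac{1}{2}\alpha^\top H \alpha + c^\top \alpha$ that quadprog expects, with $H_{ij} = y_i y_j K(x_i, x_j)$ and $c = -\mathbf{1}$, exactly as constructed in the code below.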
Datasets
The handwritten-digit datasets used are a training set and a test set, each with 500 samples and 256 features. Class labels for the training and test sets are denoted as $y = +1$ for digit 3 and $y = -1$ for digit 8. There are 250 samples each of handwritten 3s and 8s in both the training and test sets. Download the dataset.
Loading the data
Load the data, then z-score normalize both the training and test sets using MATLAB's zscore command.
load usps_3_vs_8_train_X.txt;
load usps_3_vs_8_train_y.txt;
load usps_3_vs_8_test_X.txt;
load usps_3_vs_8_test_y.txt;
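As an optional sanity check (assuming the files load into workspace variables of the same names, as load does for plain text files), the expected shapes can be verified before training:

% Each set should contain 500 samples with 256 features each
assert(isequal(size(usps_3_vs_8_train_X), [500 256]));
assert(isequal(size(usps_3_vs_8_train_y), [500 1]));
assert(isequal(size(usps_3_vs_8_test_X), [500 256]));
assert(isequal(size(usps_3_vs_8_test_y), [500 1]));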
Z-score normalization
Y = usps_3_vs_8_train_y;
X = zscore(usps_3_vs_8_train_X);
X_t = zscore(usps_3_vs_8_test_X);
Y_t = usps_3_vs_8_test_y;
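For clarity, zscore standardizes each feature (column) to zero mean and unit standard deviation. A minimal hand-rolled equivalent, written with bsxfun in the same style as the rest of this script, would be:

% Equivalent of zscore(A): center each column, then divide by its standard deviation
A = usps_3_vs_8_train_X;
An = bsxfun(@rdivide, bsxfun(@minus, A, mean(A)), std(A));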
Initialization of vectors
result_train = []; result_test = []; support = [];
warning off;

C = [0.0001, 0.001, 0.01, 0.1, 1];
for k = 1:1:5
K = X*X';          % calculate the linear kernel matrix
H = (Y*Y').*K;     % quadratic term of the dual QP
A = []; b = [];    % no linear inequality constraints
Aeq = Y'; beq = 0; % equality constraint: sum(alpha.*Y) = 0
l = zeros(500,1);      % lower bound: alpha >= 0
c = -1*ones(500,1);    % linear term of the dual QP
u = C(k)*ones(500,1);  % upper bound: alpha <= C
options = optimset('Algorithm','interior-point-convex');
Calculate alphas
% Solve the SVM dual optimization problem in standard QP form using the quadprog function.
alpha = quadprog(H, c, A, b, Aeq, beq, l, u,[],options);
Minimum found that satisfies the constraints.
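As an optional cross-check (not part of this pipeline, and assuming the Statistics and Machine Learning Toolbox is available), the same soft-margin linear SVM can be fit with the built-in fitcsvm; its SMO solver may yield slightly different alphas than quadprog:

% Hypothetical sanity check against MATLAB's built-in SVM trainer
Mdl = fitcsvm(X, Y, 'KernelFunction', 'linear', 'BoxConstraint', C(k));
pred_train = predict(Mdl, X);  % predicted labels in {-1,+1}
fprintf('fitcsvm training error: %.2f%%\n', 100*mean(pred_train ~= Y));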
Compute the near-boundary coefficients
alpha(alpha < C(k)*0.001) = 0;              % zero out numerically negligible alphas
alpha(alpha > C(k)*0.999999999999) = C(k);  % snap near-boundary alpha values to the C boundary
Calculate Support Vectors
sv = find(alpha > 0 & alpha < C(k));  % indices of the free (margin) support vectors
sv_one = zeros(500,1);
sv_one(sv,1) = 1;                     % indicator vector of the support vectors
Compute bias parameter b
b = sv_one'*(Y - ((alpha.*Y)'*K')')/sum(sv_one);  % average the bias over the free support vectors
s = length(sv);
s = (s/500)*100;         % support vectors as a percentage of the training set
support = [support; s];
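In formula form, the one-liner above averages the bias estimate over the set $S$ of free support vectors:

$$b = \frac{1}{|S|} \sum_{i \in S} \Big( y_i - \sum_{j=1}^{n} \alpha_j y_j K(x_j, x_i) \Big)$$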
Compute the decision function on the training set
Ki = X(sv,:)*X';                                      % kernel between support vectors and training points
temp = bsxfun(@plus, Ki'*(alpha(sv,:).*Y(sv,:)), b);  % decision values on the training set
res = temp;
res(res>=0) = 1; res(res<0) = -1;                     % threshold to class labels
r = sum(res~=Y);
r = (r/500)*100;                                      % misclassification error in percent
result_train = [result_train; r];
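Written out, the decision rule applied here (and again on the test set below) sums over the support vector set $S$ found above:

$$f(x) = \sum_{i \in S} \alpha_i y_i K(x_i, x) + b, \qquad \hat{y} = \operatorname{sign}\big(f(x)\big)$$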
Compute the decision function on the test set
Ki = X(sv,:)*X_t';                                    % kernel between support vectors and test points
temp = bsxfun(@plus, Ki'*(alpha(sv,:).*Y(sv,:)), b);  % decision values on the test set
res = temp;
Threshold the decisions
res(res>=0) = 1; res(res<0) = -1;
Calculate the misclassification errors
r = sum(res~=Y_t);
r = (r/500)*100;  % misclassification error in percent
result_test = [result_test; r];
end
Plot the graphs of misclassification errors
set(0,'DefaultAxesFontWeight','bold')
set(0,'DefaultAxesFontSize',13)
set(0,'DefaultTextFontSize',18)

h = figure;
hut = log10(C);
plot(hut, result_train, 'k-*', 'LineWidth', 5, 'MarkerSize', 10);
hold on
plot(hut, result_test, 'r-o', 'LineWidth', 5, 'MarkerSize', 10);
set(gca,'YTick',[0 5 10])
set(gca,'XTick',log10(C))
xlabel('log10(C)');
ylabel('Misclassification Error (%)');
hleg1 = legend('Train','Test');
title('Linear Kernel');
set(hleg1,'Location','NorthEast')
set(hleg1,'Interpreter','none')
saveas(h,'lk_mce','png')
Plot the support vector count percentage
h = figure;
hut = log10(C);
plot(hut, support, 'k-*', 'LineWidth', 5, 'MarkerSize', 10);
set(gca,'YTick',[0 50 100])
set(gca,'XTick',log10(C))
xlabel('log10(C)');
ylabel('Support Vectors (%)');
title('Linear Kernel');
saveas(h,'lk_sv','png')
The results for each value of C are summarized below.

C | Misclassification Error on Training Set (%) | Misclassification Error on Test Set (%) | Support Vectors (%)
---|---|---|---
0.0001 | 5.2 | 8 | 100
0.001 | 0.8 | 2 | 57
0.01 | 0.2 | 3.6 | 22.6
0.1 | 0 | 3.6 | 14.8
1 | 0 | 3.6 | 12.6

As C increases, the margin hardens: training error drops to zero and the support vector fraction shrinks, while the test error is lowest at C = 0.001 and rises slightly for larger C.