Description of demo_multiclass_KNN.m
Demonstrates multiclass KNN and multiclass logistic regression
clear all
close all
generateData_5grid
usage of k-nearest neighbours classification (5grid data)
options_knn = [];
options_knn.k = 5;
model_knn = ml_multiclass_KNN(Xtrain, ytrain, options_knn);
yhat_knn = model_knn.predict(model_knn, Xtest);
testError_st = mean(yhat_knn ~= ytest);
fprintf('Averaged misclassification test error with %s is: %.3f\n', ...
    model_knn.name, testError_st);
Averaged misclassification test error with k-Nearest Neighbours Classification is: 0.102
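The internals of ml_multiclass_KNN are not shown in this demo. As a rough illustration of what a k-nearest-neighbours prediction step generally involves, the sketch below classifies each test point by majority vote among its k closest training points. The toy data and all variable names here are assumptions made for illustration; this is not the code behind ml_multiclass_KNN and it does not use the 5grid data.

```matlab
% Toy data: two well-separated clusters (illustrative only, not the 5grid data).
Xtrain = [0 0; 0 1; 1 0; 5 5; 5 6; 6 5];
ytrain = [1; 1; 1; 2; 2; 2];
Xtest  = [0.5 0.5; 5.5 5.5];
k = 3;                                  % number of neighbours that vote

nTest = size(Xtest, 1);
yhat  = zeros(nTest, 1);
for i = 1:nTest
    % Squared Euclidean distance from test point i to every training point.
    diffs = Xtrain - repmat(Xtest(i, :), size(Xtrain, 1), 1);
    d2    = sum(diffs .^ 2, 2);
    % Indices of the k closest training points.
    [~, order] = sort(d2);
    nearest    = order(1:k);
    % Majority vote over the neighbours' labels
    % (mode returns the smallest label when there is a tie).
    yhat(i) = mode(ytrain(nearest));
end
% For this toy data, yhat is [1; 2].
```

A larger k smooths the decision boundary at the cost of blurring small class regions, which is one reason to try different values of options_knn.k on different datasets.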
usage of multi-class logistic classification (5grid data)
options_lg = [];
options_lg.addBias = 1;
model_lg = ml_multiclass_logistic(Xtrain, ytrain, options_lg);
yhat_lg = model_lg.predict(model_lg, Xtest);
testError_lg = mean(yhat_lg ~= ytest);
fprintf('Averaged misclassification test error with %s is: %.3f\n', ...
    model_lg.name, testError_lg);
Averaged misclassification test error with Multiclass Logistic Classification is: 0.089
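For readers unfamiliar with multiclass logistic (softmax) classification, the sketch below shows one common way a prediction is formed once weights are available: append a constant column to the features, compute one linear score per class, turn the scores into probabilities with a softmax, and pick the most probable class. The weight matrix W is made up by hand, and treating addBias = 1 as "prepend a column of ones" is an assumption; neither is taken from ml_multiclass_logistic.

```matlab
% Hypothetical weights for bias + 2 features and 3 classes (one column per class).
W = [ 0  0  0;     % bias row
      1 -1  0;     % weights for feature 1
      0  0  1];    % weights for feature 2

X  = [2 0; -2 0; 0 3];               % three toy test points
Xb = [ones(size(X, 1), 1), X];       % assumed effect of addBias = 1

% Linear class scores, shifted by the row maximum for numerical stability.
scores = Xb * W;
scores = scores - repmat(max(scores, [], 2), 1, size(scores, 2));

% Softmax: each row of P holds class probabilities that sum to 1.
expS = exp(scores);
P    = expS ./ repmat(sum(expS, 2), 1, size(expS, 2));

% Predicted label is the class with the highest probability.
[~, yhat] = max(P, [], 2);
% For these toy weights, yhat is [1; 2; 3].
```

Because every class score is linear in the features, the boundaries between classes in this kind of model are straight lines in two dimensions, unlike the locally adaptive boundaries of KNN.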
figure;
plotClassifier(Xtrain, ytrain, model_knn);
figure;
plotClassifier(Xtrain, ytrain, model_lg);
generateData_gridMulti


usage of k-nearest neighbours classification (gridMulti data)
options_knn = [];
options_knn.k = 10;
model_knn = ml_multiclass_KNN(Xtrain, ytrain, options_knn);
yhat_knn = model_knn.predict(model_knn, Xtest);
testError_st = mean(yhat_knn ~= ytest);
fprintf('Averaged misclassification test error with %s is: %.3f\n', ...
    model_knn.name, testError_st);
Averaged misclassification test error with k-Nearest Neighbours Classification is: 0.324
usage of multi-class logistic classification (gridMulti data)
options_lg = [];
options_lg.addBias = 1;
model_lg = ml_multiclass_logistic(Xtrain, ytrain, options_lg);
yhat_lg = model_lg.predict(model_lg, Xtest);
testError_lg = mean(yhat_lg ~= ytest);
fprintf('Averaged misclassification test error with %s is: %.3f\n', ...
    model_lg.name, testError_lg);
Averaged misclassification test error with Multiclass Logistic Classification is: 0.302
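The reported test error is the fraction of test points whose predicted label differs from the true label, computed as mean(yhat ~= ytest). A single number hides which classes are being confused, so a confusion matrix can be a useful companion. The sketch below uses made-up labels (not results from this demo) and assumes the classes are numbered 1..K.

```matlab
% Made-up true and predicted labels for five test points, classes 1..3.
ytest = [1; 1; 2; 2; 3];
yhat  = [1; 2; 2; 2; 1];
K     = 3;

% Misclassification rate: fraction of mismatched labels (2 of 5 here).
testError = mean(yhat ~= ytest);

% Confusion matrix: C(i, j) counts points of true class i predicted as class j.
C = accumarray([ytest, yhat], 1, [K, K]);

fprintf('Misclassification rate: %.3f\n', testError);
disp(C);
```

The off-diagonal entries of C sum to the number of errors, so sum(C(:)) - trace(C) divided by the number of test points recovers the same misclassification rate.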
figure;
plotClassifier(Xtrain, ytrain, model_knn);
figure;
plotClassifier(Xtrain, ytrain, model_lg);

