How can I apply the built-in MATLAB function mnrfit to the MNIST data?
I have been given the task of applying my own implementation of softmax regression (multinomial logistic regression) to the MNIST dataset of handwritten digits.
My own implementation works: it reaches about 92% accuracy on the test data and runs in under 60 seconds.
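For anyone reproducing this kind of comparison, here is a minimal batch-gradient-descent softmax regression in base MATLAB. It is a sketch on synthetic data, not the implementation described above; the sizes, learning rate and iteration count are arbitrary assumptions, and it needs R2016b or later for implicit expansion.

```matlab
% Minimal softmax regression trained by batch gradient descent (synthetic data).
rng(0);
n = 300; p = 20; k = 3;
X = rand(n, p);                              % features in [0, 1], like scaled pixels
y = randi(k, n, 1);                          % class indices 1..k
X(:, 1:k) = X(:, 1:k) + 0.5 * (y == 1:k);    % give each class a learnable signal

Xb = [ones(n, 1) X];                         % prepend a bias column
Y = double(y == 1:k);                        % one-hot targets, n-by-k
W = zeros(p + 1, k);                         % one weight column per class
lr = 0.05;                                   % assumed step size

% Row-wise softmax with the row maximum subtracted for numerical stability
softmaxRows = @(Z) exp(Z - max(Z, [], 2)) ./ sum(exp(Z - max(Z, [], 2)), 2);
% Mean cross-entropy loss
lossFn = @(P) -mean(sum(Y .* log(P), 2));

initialLoss = lossFn(softmaxRows(Xb * W));   % equals log(k) when W = 0
for iter = 1:500
    P = softmaxRows(Xb * W);
    grad = Xb' * (P - Y) / n;                % gradient of the mean cross-entropy
    W = W - lr * grad;
end
finalLoss = lossFn(softmaxRows(Xb * W))

[~, pred] = max(Xb * W, [], 2);              % predicted class = largest score
trainAccuracy = mean(pred == y)
```

Each iteration costs only two matrix products, which is one reason a hand-written first-order solver can be much faster than a second-order fit on the same data.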
However, I was also asked to check my implementation against a built-in MATLAB function, and that function must also perform multinomial logistic (softmax) regression; no other kind of classifier is allowed.
I found the mnrfit and mnrval functions for this, but I am not sure whether they are appropriate for this specific problem.
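When checking a from-scratch softmax against mnrfit, note that the two are parameterized differently. A softmax model usually keeps one weight vector per class (k columns), while mnrfit's default nominal model estimates k-1 log-odds against the last category, so its B is (p+1)-by-(k-1). With the double(categorical(...)) label mapping in the code, the reference category is digit 9. The coefficients will therefore not match column for column, but the probabilities should. The sketch below uses random, hypothetical weights W (first row = bias) to show the conversion B = W(:,1:k-1) - W(:,k) and checks that both forms give the same probabilities; it does not call mnrfit itself.

```matlab
% Hypothetical softmax weights: (p+1)-by-k, first row is the bias term.
rng(0);
p = 4; k = 3; n = 5;
X = rand(n, p);
W = randn(p + 1, k);
Xb = [ones(n, 1) X];

% Probabilities from the full k-column softmax parameterization
Z = Xb * W;
Z = Z - max(Z, [], 2);                 % subtract row max for numerical stability
P_softmax = exp(Z) ./ sum(exp(Z), 2);

% Equivalent coefficients in mnrfit's nominal form (last class = reference)
B = W(:, 1:k-1) - W(:, k);             % (p+1)-by-(k-1), same shape as mnrfit's B

% Rebuild probabilities from the log-odds eta_j = log(pi_j / pi_k)
eta = Xb * B;
E = [exp(eta) ones(n, 1)];             % the reference class has eta = 0
P_ref = E ./ sum(E, 2);

maxDiff = max(abs(P_softmax(:) - P_ref(:)))
```

In practice you could apply the same subtraction to your trained weights, or simply compare your class probabilities with mnrval's output; some difference is expected because the optimizers and stopping criteria differ.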
When I ran mnrfit on the MNIST data to get the coefficients, the run time became unreasonable: after more than 15 minutes I stopped it.
The images (X matrix) are 60,000 × 784 for training and 10,000 × 784 for testing; the labels (Y column vector) are 60,000 × 1 and 10,000 × 1.
Each row of X is a flattened 28 × 28 (= 784 pixel) image, and its label is the handwritten digit it shows. There are 10 classes in total (digits 0-9).
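A back-of-envelope count suggests why the full fit is slow. According to the mnrfit documentation, the nominal model's B is (p+1)-by-(k-1). Assuming mnrfit uses a Newton-type (iteratively reweighted least squares) solver, which I believe it does, each iteration works with a square Hessian over all of those coefficients. The numbers below are arithmetic only, not measurements.

```matlab
p = 784;      % predictors: one per pixel
k = 10;       % response categories: digits 0-9
numCoef = (p + 1) * (k - 1)            % elements of B (intercept row + one row per pixel)
hessianEntries = numCoef^2             % entries of a dense Hessian over all coefficients
hessianMB = hessianEntries * 8 / 1e6   % memory for that Hessian in double precision (MB)
```

That gives 7065 coefficients and a Hessian of roughly 400 MB, which would have to be accumulated over 60,000 rows and factored on every iteration. In addition, many MNIST border pixels are zero in every training image, which makes the design matrix rank-deficient and can slow down or destabilize such a fit.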
The dataset can be downloaded from the website below (the files were too large to attach):
I have attached the functions that read the files (loadMNISTImages and loadMNISTLabels); my code is below:
% Read training images (transposed so each row is one image: 60,000 x 784)
training_images = loadMNISTImages('train-images.idx3-ubyte')';
% Read training labels (digits 0-9)
training_labels = loadMNISTLabels('train-labels.idx1-ubyte');
% Read test images (10,000 x 784)
test_images = loadMNISTImages('t10k-images.idx3-ubyte')';
% Read test labels (digits 0-9)
test_labels = loadMNISTLabels('t10k-labels.idx1-ubyte');
training_labels = categorical(training_labels); % Categories are sorted, so digit d becomes category d+1
training_labels = double(training_labels);      % Class indices 1-10; a numeric response vector for mnrfit must hold integers 1..k
[B,dev,stats] = mnrfit(training_images,training_labels); % This is where the problem seems to be (long run time)
pihat_train = mnrval(B, training_images);        % 60,000 x 10 class probabilities; column j corresponds to digit j-1
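One way to keep the check within multinomial logistic regression while making mnrfit tractable is to shrink the predictor set first: drop pixels that never vary, then project onto a few principal components with pca (Statistics and Machine Learning Toolbox). This is a sketch under stated assumptions: the synthetic data below are a hypothetical stand-in for the MNIST matrices (with the real data, use training_images and test_images, and convert both label vectors to 1..10, e.g. test_labels + 1), and 50 components is an arbitrary choice. Because the model is fit on PCA scores, its coefficients are not comparable to a full-pixel softmax; compare predicted classes or test accuracy instead.

```matlab
% Synthetic stand-in data (hypothetical sizes) so the sketch runs on its own.
rng(1);
k = 10; p = 784; nTrain = 2000; nTest = 500;
centers = 0.2 * rand(k, p);              % one weak random "template" per class
yTrain = randi(k, nTrain, 1);            % class indices 1..k
yTest  = randi(k, nTest, 1);
Xtrain = centers(yTrain, :) + randn(nTrain, p);
Xtest  = centers(yTest, :)  + randn(nTest, p);
Xtrain(:, 1:50) = 0;                     % mimic pixels that are zero in every image
Xtest(:, 1:50)  = 0;

% 1) Drop predictors with zero variance in the training set
keep = var(Xtrain, 0, 1) > 0;
Xtr = Xtrain(:, keep);
Xte = Xtest(:, keep);

% 2) Project onto the leading principal components of the training data
numComp = 50;                            % assumption: tune this number
[coeff, scoreTrain, ~, ~, ~, mu] = pca(Xtr, 'NumComponents', numComp);
scoreTest = (Xte - mu) * coeff;          % reuse the training mean and basis

% 3) Fit the nominal multinomial model on the scores and predict
B = mnrfit(scoreTrain, yTrain);          % (numComp+1)-by-(k-1)
probs = mnrval(B, scoreTest);            % nTest-by-k class probabilities
[~, predClass] = max(probs, [], 2);
testAccuracy = mean(predClass == yTest)
```

The same predClass / accuracy comparison can be run with your own softmax on the same reduced features, so both methods are checked on identical inputs.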
Any help is appreciated.