Aggregate vs. Independent Lasso

Philip Gigliotti on 15 Mar 2019
Hello,
I'm training a lasso regression on the well-known School data set from the Inner London Education Authority (ILEA). I'm following the procedure in this paper:
First, they train and test a single lasso regression on the entire School data set, which contains 139 schools and 15640 student observations. They call this model "aggregate."
Then they train and test 139 independent lasso regressions, one per school. They call these models "independent."
This is the code I followed to train the aggregate regression:
c = cvpartition(n,'HoldOut',0.3);
idxTrain = training(c,1);
idxTest = ~idxTrain;
XTrain = xlong(idxTrain, :);
yTrain = y(idxTrain);
XTest = xlong(idxTest,:);
yTest = y(idxTest);
[B,FitInfo] = lasso(XTrain,yTrain,'Alpha',0.75,'CV',10);
idxLambda1SE = FitInfo.Index1SE;
coef = B(:,idxLambda1SE);
coef0 = FitInfo.Intercept(idxLambda1SE);
yhat = XTest*coef + coef0;
mse = mean((yTest-yhat).^2)
This seems to work fine and gives an MSE that makes sense.
According to the authors, training the independent regressions should improve performance. But when I run the independent routine, I get a higher average mean squared error than with the aggregate model. This is my code, where task_indexes holds the row index of the first observation for each school.
task_indexes(140) = 15362;   % end marker so that school 139 has an upper bound
mseSchool = zeros(139,1);    % per-school test MSE (not named "error", which shadows a built-in)
for j = 1:139
    % Rows belonging to school j
    rows = task_indexes(j):(task_indexes(j+1)-1);
    Xschool = xlong(rows, :);
    yschool = y(rows);
    % 70/30 hold-out split within this school
    n = numel(rows);
    c = cvpartition(n,'HoldOut',0.3);
    idxTrain = training(c,1);
    idxTest = ~idxTrain;
    XTrain = Xschool(idxTrain, :);
    yTrain = yschool(idxTrain);
    XTest = Xschool(idxTest, :);
    yTest = yschool(idxTest);
    % Elastic-net fit with 10-fold CV; use the lambda chosen by the 1-SE rule
    [B,FitInfo] = lasso(XTrain,yTrain,'Alpha',0.75,'CV',10);
    idxLambda1SE = FitInfo.Index1SE;
    coef = B(:,idxLambda1SE);
    coef0 = FitInfo.Intercept(idxLambda1SE);
    % Test-set predictions and MSE for school j
    yhat = XTest*coef + coef0;
    mseSchool(j) = mean((yTest-yhat).^2);
end
I'd like to seek advice on the proper way to run these 139 independent lasso regressions as detailed in the paper, and how to compute aggregate performance statistics for them.
Thanks.
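One possible way to compute an aggregate statistic for the independent models, so that it is comparable with the single aggregate MSE, is to pool the squared errors across schools instead of averaging the per-school MSEs. The two differ whenever schools have different numbers of test students. The sketch below uses made-up numbers and hypothetical variable names (mseSchool, nTestSchool); whether the paper reports the pooled or the unweighted figure is not confirmed here.

```matlab
% Hypothetical per-school results for three schools (illustrative numbers only).
mseSchool   = [4; 9; 1];      % test MSE of each school's own model
nTestSchool = [10; 20; 70];   % number of held-out students in each school

% Unweighted average: every school counts equally, regardless of its size.
mseUnweighted = mean(mseSchool);

% Pooled MSE: total squared error divided by the total number of test students.
% This is the quantity a single model evaluated on all test rows would report.
ssePerSchool = mseSchool .* nTestSchool;
msePooled = sum(ssePerSchool) / sum(nTestSchool);

fprintf('Unweighted mean of school MSEs: %.4f\n', mseUnweighted);
fprintf('Pooled (size-weighted) MSE:     %.4f\n', msePooled);
```

In the per-school loop, this would amount to also storing the number of test rows for each school (for example, sum(idxTest)) alongside its MSE.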

Answers (0)
