Hi there,
I think I see what's happening here.
This is actually expected behavior for stateful LSTMs, but it's subtle and trips up a lot of people.
TL;DR: When you pass all sequences as one big array, MATLAB processes them as independent observations (resets state between them). In your loop, you're carrying over state from sequence i to sequence i + 1, which creates a dependency chain that shouldn't exist if your sequences are independent.
What's happening:
- Array input (predict(dlnet, inp)): MATLAB treats each column as an independent observation. The network state is automatically reset between observations, so sequence 5 doesn't "see" what happened in sequence 4. This is the correct behavior for independent test sequences.
- Your loop: By doing dlnet.State = state after each prediction, you're explicitly carrying over the LSTM's hidden/cell state to the next sequence. This means sequence i+1 starts with the "memory" of sequence i. If your sequences are supposed to be independent (which they usually are in train/test splits), this creates artificial temporal dependencies that degrade performance.
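For contrast, here is a sketch of the pattern you described (assuming `inp` is laid out features-by-observations with 1e5 columns, as in your snippet). The `dlnet.State = state` line is exactly where the memory of sequence i leaks into sequence i+1:

```matlab
% The problematic pattern (sketch of what the question describes):
onet_tst = zeros(1, 1e5);
for i = 1:1e5
    [onet_, state] = predict(dlnet, inp(:,i));
    dlnet.State = state;            % carries sequence i's memory into sequence i+1
    onet_tst(i) = extractdata(onet_);
end
```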
The fix:
If your sequences are independent (which seems to be the case since you mentioned "test performance"), don't update the state in the loop:
onet_tst = zeros(1, 1e5);
for i = 1:1e5
    onet_ = predict(dlnet, inp(:,i));   % dlnet.State is NOT updated between iterations
    onet_tst(i) = extractdata(onet_);
end
Or better yet, just use the array method which is vectorized and faster:
onet_tst = predict(dlnet, inp);
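If you want to convince yourself the two methods agree once the state carry-over is removed, you can compare them directly (a sketch, again assuming `inp` is features-by-observations as in your snippet):

```matlab
onet_arr = extractdata(predict(dlnet, inp));              % vectorized, independent
onet_lp  = zeros(1, 1e5);
for i = 1:1e5
    onet_lp(i) = extractdata(predict(dlnet, inp(:,i)));   % loop, no state update
end
max(abs(onet_arr - onet_lp))   % should be near 0 (up to floating-point noise)
```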
To answer your specific questions:
- Does predict use other sequences to refine output? No, not in the way you think. The array method processes each sequence independently (correct). The loop method creates artificial dependencies by carrying state over (incorrect for independent data).
- Affected by training parameters? Mini-batch size during training affects how gradients are computed, but the key issue here is state management during inference, not training parameters.
- How to make them equivalent? Remove the state update from your loop, or use resetState before each prediction if you truly need loop processing:
for i = 1:1e5
    dlnet = resetState(dlnet);          % clear hidden/cell state before each sequence
    onet_tst(i) = extractdata(predict(dlnet, inp(:,i)));
end
Why the plot shows "loop" lagging behind: the yellow curve in your plot shows the classic signature of state "bleeding" between sequences: the predictions are smoothed and delayed because the LSTM is trying to maintain temporal continuity where none exists.
When would you want state updates? Only if your 100k sequences are actually one long continuous time series that you artificially chopped into columns. In that case, the loop with state update is correct, but then you'd also want state persistence during training (stateful training), which is a whole different setup.
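In that continuous-series scenario, the carry-over loop is the right shape, but you'd reset once at the start rather than never or every iteration (a sketch under that assumption):

```matlab
% Only for one long series that was chopped into columns:
dlnet = resetState(dlnet);              % reset once, at the start of the series
onet_tst = zeros(1, 1e5);
for i = 1:1e5
    [onet_, state] = predict(dlnet, inp(:,i));
    dlnet.State = state;                % intentional: continuity across chunks
    onet_tst(i) = extractdata(onet_);
end
```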
Hope this clears it up! The array method is giving you the correct independent predictions.
Best, Seba