Out-of-Distribution Detection for BERT Document Classifier
This example shows how to detect out-of-distribution (OOD) data in a BERT document classifier.
OOD data detection is the process of identifying inputs to a deep neural network that might yield unreliable predictions. OOD data refers to data that is different from the data used to train the model, for example, data collected in a different way, under different conditions, or for a different task than the data on which the model was originally trained.
You can classify data as in-distribution (ID) or OOD by assigning confidence scores to the predictions of a network. You can then choose how you treat OOD data. For example, you can choose to reject the prediction of a neural network if the network detects OOD data.
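For example, given a vector of confidence scores and a chosen threshold (both hypothetical in this sketch), you can flag OOD inputs and reject their predictions.

scores = [0.91 0.88 0.12 0.95];   % hypothetical confidence scores, one per input
threshold = 0.5;                  % hypothetical threshold
isOOD = scores < threshold;       % flag inputs that look out-of-distribution
% Reject, or route for review, the predictions where isOOD is true.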
In this example, you fine-tune a pretrained BERT classification model to predict the type of maintenance work done on traffic signals using text descriptions. You then construct a discriminator to classify the text descriptions as ID or OOD.
In this example, you fine-tune and use a pretrained BERT document classifier in five steps:
Import and preprocess the data.
Separate the ID and OOD data.
Fine-tune a pretrained BERT model using the ID data.
Create a BERT mini-batch queue.
Construct and calibrate a distribution discriminator and compare the distribution scores of the ID and OOD data.
Import and Preprocess Data
This example uses a large data set that contains records of work completed by traffic signal technicians in the city of Austin, TX, United States [1]. This data set is a table containing approximately 36,000 reports with various attributes, including a plain text description in the JobDescription variable and a categorical label in the WorkNeeded variable.
Load the example data.
zipFile = matlab.internal.examples.downloadSupportFile("textanalytics","data/Traffic_Signal_Work_Orders.zip");
filepath = fileparts(zipFile);
dataFolder = fullfile(filepath,"Traffic_Signal_Work_Orders");
unzip(zipFile,dataFolder);

filename = "Traffic_Signal_Work_Orders.csv";
data = readtable(fullfile(dataFolder,filename),TextType="string",VariableNamingRule="preserve");
data.Properties.VariableNames = matlab.lang.makeValidName(data.Properties.VariableNames);
head(data)
The output shows the first rows of the work order table, with variables including WorkOrderID, Status, AssetType, CreatedDate, WorkType, WorkNeeded, JobDescription, ProblemFound, ActionTaken, and location details. (Table preview not shown.)
The goal of this example is to classify maintenance visits by the label in the WorkNeeded column. To divide the data into classes, convert these labels to categorical.
data.WorkNeeded = categorical(data.WorkNeeded);
Remove data belonging to rare categories using the removeRareCategories function, defined at the end of this example.
data = removeRareCategories(data);
Remove rows with missing job descriptions.
data(ismissing(data.JobDescription),:) = [];
Separate ID and OOD Data
The data set includes two types of work, scheduled work and trouble calls.
data.WorkType = categorical(data.WorkType);

figure
histogram(data.WorkType)
In this example, you train a document classifier on the JobDescription fields of the reports from work resulting from trouble calls. This data comprises the ID data. The reports resulting from scheduled work comprise the OOD data.
dataID = data(data.WorkType=="Trouble Call",:);
dataOOD = data(data.WorkType=="Scheduled Work",:);
Remove any now-unused categories from both the ID and OOD data.
dataID.WorkNeeded = removecats(dataID.WorkNeeded);
dataOOD.WorkNeeded = removecats(dataOOD.WorkNeeded);
Compare the JobDescription fields of the ID and OOD data using word clouds.
figure
tiledlayout("horizontal")

nexttile
wordcloud(dataID.JobDescription);
title("In-distribution")

nexttile
wordcloud(dataOOD.JobDescription);
title("Out-of-distribution")
Prepare Data for Training
Next, partition the ID data into sets for training, validation, and testing: a training set containing 80% of the ID data, a validation set containing 10%, and a test set containing the remaining 10%. To partition the data, use the trainingPartitions function, attached to this example as a supporting file. To access this file, open the example as a live script. A hypothetical sketch of such a function appears after the partitioning code below.
numReports = size(dataID,1);
[idxTrain,idxValidation,idxTest] = trainingPartitions(numReports,[0.8 0.1 0.1]); % attached to this example as a supporting file
dataTrain = dataID(idxTrain,:);
dataValidation = dataID(idxValidation,:);
dataTest = dataID(idxTest,:);
classNames = categories(dataID.WorkNeeded);
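The trainingPartitions function itself is not listed on this page. The following hypothetical sketch shows one possible implementation that returns mutually exclusive random index partitions with the requested proportions; it is for illustration only.

function varargout = trainingPartitions(numObservations,splits)
% Hypothetical sketch: split 1:numObservations into random, mutually
% exclusive index sets with proportions given by splits (which sums to 1).
idx = randperm(numObservations);
edges = [0 round(cumsum(splits)*numObservations)];
edges(end) = numObservations;   % ensure every observation is assigned
varargout = cell(1,numel(splits));
for k = 1:numel(splits)
    varargout{k} = idx(edges(k)+1:edges(k+1));
end
end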
To avoid having two copies of the ID data in memory, remove dataID.
clear("dataID");
Extract the text data and labels from the partitioned tables and the OOD data.
documentsTrain = dataTrain.JobDescription;
documentsValidation = dataValidation.JobDescription;
documentsTest = dataTest.JobDescription;
documentsOOD = dataOOD.JobDescription;

YTrain = dataTrain.WorkNeeded;
YValidation = dataValidation.WorkNeeded;
YTest = dataTest.WorkNeeded;
YOOD = dataOOD.WorkNeeded;
Load Pretrained BERT Document Classifier
Load a pretrained BERT-Tiny document classifier using the bertDocumentClassifier function. If the Text Analytics Toolbox™ Model for BERT-Tiny Network support package is not installed, then the function provides a link to the required support package in the Add-On Explorer. To install the support package, click the link, and then click Install.
mdl = bertDocumentClassifier(Model="tiny",ClassNames=classNames)
mdl = 
  bertDocumentClassifier with properties:

       Network: [1×1 dlnetwork]
     Tokenizer: [1×1 bertTokenizer]
    ClassNames: ["Communication Failure"    "Detection Failure"    "Knockdown"    "LED Out"    "Push Button Not Working"    "Signal Out or on Flash"    "Timing Issue"    "Visibility Issue"]
Specify Training Options
Specify the training options. Choosing among training options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager app.
Train using the Adam optimizer.
Train for eight epochs.
For fine-tuning, lower the learning rate. Train using a learning rate of 0.0001.
Validate the network using the validation data.
Shuffle the data every epoch.
Monitor the training progress in a plot and monitor the accuracy metric.
Disable the verbose output.
options = trainingOptions("adam", ...
    MaxEpochs=8, ...
    InitialLearnRate=1e-4, ...
    ValidationData={documentsValidation,YValidation}, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);
Train Neural Network
Train the neural network using the trainBERTDocumentClassifier function. By default, the trainBERTDocumentClassifier function uses a GPU if one is available. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the trainBERTDocumentClassifier function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.
mdl = trainBERTDocumentClassifier(documentsTrain,YTrain,mdl,options);
Test Neural Network
Make predictions using the test data.
YPred = classify(mdl,documentsTest);
Compare the true and predicted labels.
figure
confusionchart(YTest,YPred)
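As a quick sanity check alongside the confusion chart, you can also compute the overall test accuracy (a minimal sketch):

accuracy = mean(YPred == YTest)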
Detect OOD Data
You can assign confidence scores to network predictions by computing a distribution confidence score for each observation. ID data usually has a higher confidence score than OOD data. You can then apply a threshold to the scores to determine whether an input is ID or OOD.
Create a discriminator that separates ID and OOD data by using the networkDistributionDiscriminator function. The function returns a discriminator containing a threshold for separating data into ID and OOD using their distribution scores.

The networkDistributionDiscriminator function requires at least three input arguments:
A dlnetwork object.
Input data, in the form of either a dlarray object or a minibatchqueue object (see the sketch after this list for the basic dlarray call pattern).
The algorithm used by the function. Depending on the algorithm, the returned discriminator is a BaselineDistributionDiscriminator, ODINDistributionDiscriminator, EnergyDistributionDiscriminator, or HBOSDistributionDiscriminator object.
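For reference, the following is a minimal sketch of the dlarray call pattern, using a toy network and random data. Everything in this sketch is hypothetical and only illustrates the function signature; this example instead passes minibatchqueue objects.

% Toy classification network ending in a softmax layer (hypothetical).
layers = [
    featureInputLayer(4)
    fullyConnectedLayer(3)
    softmaxLayer];
toyNet = dlnetwork(layers);

% Hypothetical ID and OOD data as formatted dlarray objects.
XID = dlarray(randn(4,200),"CB");
XOOD = dlarray(randn(4,200) + 3,"CB");

toyDiscriminator = networkDistributionDiscriminator(toyNet,XID,XOOD,"energy");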
Create BERT Mini-Batch Queue
First, create a mini-batch queue for BERT using the bertMiniBatchQueue function, defined in this example.
miniBatchSize = 128;

mbqTrain = bertMiniBatchQueue(mdl,documentsTrain,miniBatchSize);
mbqValidation = bertMiniBatchQueue(mdl,documentsValidation,miniBatchSize);
mbqTest = bertMiniBatchQueue(mdl,documentsTest,miniBatchSize);
mbqOOD = bertMiniBatchQueue(mdl,documentsOOD,miniBatchSize);
To create a mini-batch queue, first create a datastore that holds the input data for BERT. Then create a minibatchqueue object that uses the preprocessPredictors function to preprocess the data. The preprocessPredictors function is attached to this example as a supporting file; it truncates and pads the sequences to the same length, equal to the context size of the BERT tokenizer, and ensures that the sequences always end with an end-of-sentence token. A hypothetical sketch of this helper appears after the bertMiniBatchQueue function below.
function mbq = bertMiniBatchQueue(mdl,documents,miniBatchSize)

tokenizer = mdl.Tokenizer;
[inputID,segmentID] = encode(tokenizer,documents);

inputIDDS = arrayDatastore(inputID,OutputType="same");
segmentIDDS = arrayDatastore(segmentID,OutputType="same");
combinedDS = combine(inputIDDS,segmentIDDS);

mbq = minibatchqueue(combinedDS,3, ... % 3 outputs: inputID, mask, segmentID
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@(inputID,segmentID) preprocessPredictors(inputID,segmentID,tokenizer), ...
    MiniBatchFormat=["CTB" "CTB" "CTB"], ...
    OutputEnvironment="auto");

end
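The preprocessPredictors supporting file is not listed on this page. The following is a hypothetical sketch of one possible implementation, assuming that the mini-batch function receives cell arrays of encoded sequences and that the bertTokenizer object exposes ContextSize, SeparatorCode, and PaddingCode properties; treat it as an illustration, not the shipped implementation.

function [inputID,mask,segmentID] = preprocessPredictors(inputID,segmentID,tokenizer)
% Hypothetical sketch: truncate each sequence to the tokenizer context size,
% make sure it ends with the separator (end-of-sentence) token, then pad.
maxLength = tokenizer.ContextSize;
for n = 1:numel(inputID)
    ids = inputID{n};
    seg = segmentID{n};
    if numel(ids) > maxLength
        ids = ids(1:maxLength);
        seg = seg(1:maxLength);
        ids(end) = tokenizer.SeparatorCode;   % keep an end-of-sentence token
    end
    inputID{n} = ids;
    segmentID{n} = seg;
end
% Pad the 1-by-T sequences to a common length. The padded arrays are
% 1-by-maxT-by-numObservations, matching the "CTB" mini-batch format.
[inputID,mask] = padsequences(inputID,2,PaddingValue=tokenizer.PaddingCode);
mask = single(mask);
segmentID = padsequences(segmentID,2,PaddingValue=tokenizer.PaddingCode);
end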
Compute Confidence Scores to Detect OOD Data
Extract the dlnetwork object from the BERT model mdl to pass to the networkDistributionDiscriminator function.
net = mdl.Network;
Calibrate Discriminator
Create a distribution discriminator using the energy OOD discrimination algorithm. The energy method computes distribution confidence scores based on softmax scores. For more information, see Distribution Confidence Scores. Set the Temperature name-value argument to 1.
discriminator = networkDistributionDiscriminator(net,mbqTrain,mbqOOD,"energy",Temperature=1);
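For intuition, the energy confidence score for a single observation can be computed from its class logits using the standard energy-based formulation (a sketch with hypothetical values; higher scores indicate more ID-like data):

logits = [2.5 0.3 -1.2 0.8];        % hypothetical class logits for one observation
T = 1;                              % temperature, matching the value set above
score = T*log(sum(exp(logits/T)))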
To ensure that the discriminator is well calibrated, calculate the distribution confidence scores of the training and validation data by passing the discriminator object to the distributionScores function. Plot a histogram of the distribution scores using the plotDistributionScores function, defined at the end of this example.

If the distribution discriminator is well calibrated, then the histograms of the two data sets are similar. If the histograms do not look similar, increase or decrease the value of the Temperature hyperparameter.
scoresTrain = distributionScores(discriminator,mbqTrain);
scoresValidation = distributionScores(discriminator,mbqValidation);

figure
plotDistributionScores(discriminator,scoresTrain,scoresValidation,"Training Data","Validation Data")
Detect OOD Data
Once you are satisfied with your distribution discriminator, pass the discriminator object to the isInNetworkDistribution function along with the test data (which is ID data) and the OOD data. To assess the performance of the discriminator on this set of OOD data, calculate the true positive rate (TPR) and false positive rate (FPR). The TPR is the fraction of ID observations correctly identified as ID, and the FPR is the fraction of OOD observations incorrectly identified as ID.
tfOOD = isInNetworkDistribution(discriminator,mbqOOD);
tfID = isInNetworkDistribution(discriminator,mbqTest);

tpr = nnz(tfID)/numel(tfID)
tpr = 0.7906
fpr = nnz(tfOOD)/numel(tfOOD)
fpr = 0.0355
To calculate the distribution scores of the ID and OOD data according to the discriminator, pass the discriminator object to the distributionScores function. Plot a histogram of the distribution scores, together with the distribution threshold, using the plotDistributionScores function, defined at the end of this example.
scoresID = distributionScores(discriminator,mbqTest);
scoresOOD = distributionScores(discriminator,mbqOOD);

figure
plotDistributionScores(discriminator,scoresID,scoresOOD,"In-distribution scores","Out-of-distribution scores")
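The logical output of isInNetworkDistribution corresponds to comparing these scores against discriminator.Threshold. A quick check (a sketch, assuming ID corresponds to scores at or above the threshold):

tfIDFromScores = scoresID >= discriminator.Threshold;
nnz(tfIDFromScores)/numel(tfIDFromScores)   % should be close to the TPR computed above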
Helper Functions
The plotDistributionScores function takes as input a distribution discriminator object and the distribution confidence scores for the ID and OOD data. The function plots histograms of the two sets of confidence scores and overlays the distribution threshold.
function plotDistributionScores(discriminator,scoresID,scoresOOD,labelID,labelOOD)

hID = histogram(scoresID,Normalization="percentage");
hold on
hOOD = histogram(scoresOOD,Normalization="percentage");

xl = xlim;
hID.BinWidth = (xl(2)-xl(1))/25;
hOOD.BinWidth = (xl(2)-xl(1))/25;

xline(discriminator.Threshold)

l = legend([labelID labelOOD "Threshold"],Location="best");
title(l,discriminator.Method+" distribution discriminator")
xlabel("Distribution Confidence Scores")
ylabel("Frequency")
hold off

end
The removeRareCategories function removes data belonging to rarely used WorkNeeded categories (those with 500 or fewer occurrences), as well as miscellaneous "Other" categories that do not share many features.
function commonData = removeRareCategories(data)

workNeededCategories = categories(data.WorkNeeded);
categoryFrequencies = countcats(data.WorkNeeded);

% Keep only categories with more than 500 occurrences.
commonCategories = workNeededCategories(categoryFrequencies>500);
commonData = data(ismember(data.WorkNeeded,commonCategories),:);

% Remove miscellaneous categories containing "Other".
otherCategories = commonCategories(contains(commonCategories,"Other"));
commonData = commonData(~ismember(commonData.WorkNeeded,otherCategories),:);

commonData.WorkNeeded = removecats(commonData.WorkNeeded);

end
References
[1] Traffic Signal Work Orders. City of Austin Open Data. Retrieved April 30, 2023, from https://data.austintexas.gov/Transportation-and-Mobility/Traffic-Signal-Work-Orders/hst3-hxcz.
See Also
bertDocumentClassifier (Text Analytics Toolbox) | trainBERTDocumentClassifier (Text Analytics Toolbox) | classify (Text Analytics Toolbox) | networkDistributionDiscriminator | distributionScores | isInNetworkDistribution | minibatchqueue