classify
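Description
Y = classify(mdl,documents) classifies the specified documents using the BERT document classifier mdl and returns the predicted classes Y.
Y = classify(___,Name=Value) specifies additional options using one or more name-value arguments.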
Examples
Train BERT Document Classifier
Read the training data from the factoryReports CSV file. The file contains factory reports, including a text description and a categorical label for each report.
filename = "factoryReports.csv";
data = readtable(filename,TextType="string");
head(data)
Description                                                              Category                Urgency     Resolution               Cost
_____________________________________________________________________    ____________________    ________    ____________________    _____
"Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
"Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
"There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
"Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
"Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
"Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
"A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
"Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38
Convert the labels in the Category column of the table to categorical values.
data.Category = categorical(data.Category);
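To see what the conversion does, the following sketch applies categorical to a few labels copied from the Category column shown above. It runs in base MATLAB and also shows categories, which the example uses later to get the class names, and countcats, which counts the elements in each category. The variable names labels, c, exampleClasses, and counts are illustrative.

```matlab
% A few labels copied from the Category column of the table above
labels = ["Mechanical Failure"; "Mechanical Failure"; "Electronic Failure"; "Leak"];

% categorical stores each distinct label as a category (sorted alphabetically by default)
c = categorical(labels);

% categories returns the distinct class names as a cell array of character vectors
exampleClasses = categories(c)

% countcats returns the number of elements in each category, in category order
counts = countcats(c)
```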
Partition the data into a training set and a test set. Specify the holdout percentage as 10%.
cvp = cvpartition(data.Category,Holdout=0.1);
dataTrain = data(cvp.training,:);
dataTest = data(cvp.test,:);
Extract the text data and labels from the tables.
textDataTrain = dataTrain.Description;
textDataTest = dataTest.Description;
TTrain = dataTrain.Category;
TTest = dataTest.Category;
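Because cvpartition receives the labels as a grouping variable, the holdout split is stratified by class. A quick check, assuming the data, TTrain, and TTest variables from the previous steps, compares the class proportions in the two sets. The variable names classes, trainFraction, testFraction, and proportions are illustrative.

```matlab
% Assumes data, TTrain, and TTest from the steps above.
% Indexing a categorical array keeps all of its categories, so countcats
% returns one count per class for both sets, in the same order.
classes = categories(data.Category);
trainFraction = countcats(TTrain) / numel(TTrain);
testFraction = countcats(TTest) / numel(TTest);

% With a stratified split, the two fraction columns should be close
proportions = table(classes,trainFraction,testFraction)
```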
Load a pretrained BERT-Base document classifier using the bertDocumentClassifier function.
classNames = categories(data.Category);
mdl = bertDocumentClassifier(ClassNames=classNames)
mdl = 
  bertDocumentClassifier with properties:
       Network: [1×1 dlnetwork]
     Tokenizer: [1×1 bertTokenizer]
    ClassNames: ["Electronic Failure"    "Leak"    "Mechanical Failure"    "Software Failure"]
Specify the training options. Choosing among training options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager (Deep Learning Toolbox) app.
- Train using the Adam optimizer.
- Train for eight epochs.
- For fine-tuning, lower the learning rate. Train using a learning rate of 0.0001.
- Shuffle the data every epoch.
- Monitor the training progress in a plot and track the accuracy metric.
- Disable the verbose output.
options = trainingOptions("adam", ...
    MaxEpochs=8, ...
    InitialLearnRate=1e-4, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);
Train the neural network using the trainBERTDocumentClassifier function. By default, the trainBERTDocumentClassifier function uses a GPU if one is available and otherwise uses the CPU. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). To specify the execution environment, use the ExecutionEnvironment training option.
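For example, to train on the CPU even when a supported GPU is available, you can add ExecutionEnvironment="cpu" to the training options. This sketch repeats the options from the previous step with that one change; the variable name optionsCPU is illustrative. To use it, pass optionsCPU instead of options to trainBERTDocumentClassifier.

```matlab
% Same training options as in the previous step, but force training
% on the CPU even when a supported GPU is available.
optionsCPU = trainingOptions("adam", ...
    MaxEpochs=8, ...
    InitialLearnRate=1e-4, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false, ...
    ExecutionEnvironment="cpu");
```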
mdl = trainBERTDocumentClassifier(textDataTrain,TTrain,mdl,options);
Make predictions using the test data.
YTest = classify(mdl,textDataTest);
Calculate the classification accuracy of the test predictions.
accuracy = mean(TTest == YTest)
accuracy = 0.9375
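The accuracy is the fraction of test documents whose predicted class matches the true class. The following self-contained sketch shows the same computation on four hypothetical labels (not taken from the data set) and extends it to a per-class recall using countcats. The variable names are illustrative, and both arrays are created with the same category list so that they compare element by element.

```matlab
% Hypothetical true and predicted labels for four documents
exampleClasses = ["Electronic Failure" "Leak" "Mechanical Failure"];
TTrue = categorical(["Leak"; "Electronic Failure"; "Mechanical Failure"; "Leak"],exampleClasses);
YPred = categorical(["Leak"; "Mechanical Failure"; "Mechanical Failure"; "Leak"],exampleClasses);

% Elementwise comparison gives a logical vector: true where the prediction is correct
isCorrect = TTrue == YPred;

% The mean of a logical vector is the fraction of true values
accuracy = mean(isCorrect)

% Per-class recall: correct predictions for each true class divided by
% the number of documents in that class (indexing keeps all categories)
perClassRecall = countcats(TTrue(isCorrect)) ./ countcats(TTrue)
```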
Input Arguments
mdl - BERT document classifier model
bertDocumentClassifier object
BERT document classifier model, specified as a bertDocumentClassifier object.
documents - Input documents
string array | cell array of character vectors | tokenizedDocument array
Input documents, specified as a string array, a cell array of character vectors, or a tokenizedDocument array.
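As a sketch, assuming mdl and textDataTest from the example above, you can pass the same documents in each of the three supported forms. cellstr converts the string array to a cell array of character vectors, and tokenizedDocument converts it to a tokenizedDocument array. The variable names docs, Y1, Y2, and Y3 are illustrative.

```matlab
% Assumes mdl and textDataTest from the example above.
docs = textDataTest(1:3);                      % string array

Y1 = classify(mdl,docs);                       % string array input
Y2 = classify(mdl,cellstr(docs));              % cell array of character vectors
Y3 = classify(mdl,tokenizedDocument(docs));    % tokenizedDocument array
```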
Name-Value Arguments
Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.
Example: classify(mdl,documents,MiniBatchSize=64) classifies the specified documents using mini-batches of size 64.
MiniBatchSize - Mini-batch size
32 (default) | positive integer
Mini-batch size to use for prediction, specified as a positive integer. Larger mini-batch sizes require more memory but can lead to faster predictions.
Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
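For example, assuming mdl, textDataTest, and TTest from the example above, this call predicts with mini-batches of 64 documents instead of the default 32 and then recomputes the test accuracy. The variable names YTest64 and accuracy64 are illustrative.

```matlab
% Assumes mdl, textDataTest, and TTest from the example above.
% A larger mini-batch size uses more memory but can speed up prediction.
YTest64 = classify(mdl,textDataTest,MiniBatchSize=64);

% Fraction of test documents classified correctly
accuracy64 = mean(TTest == YTest64)
```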
Acceleration - Performance optimization
"auto" (default) | "mex" | "none"
Performance optimization, specified as one of these values:
- "auto": Automatically apply a number of optimizations that are suitable for the input network and hardware resources.
- "mex": Compile and execute a MEX function. This option is available only when you use a GPU. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). If Parallel Computing Toolbox or a suitable GPU is not available, then the software returns an error.
- "none": Disable all acceleration.
When you use the "auto" or "mex" option, the software can offer performance benefits at the expense of an increased initial run time. Subsequent calls to the function are typically faster. Use performance optimization when you call the function multiple times using different input data.
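One way to see the effect of the Acceleration option is to time repeated calls with tic and toc. This sketch assumes mdl and textDataTest from the example above; the measured times depend on your hardware and are not shown here. The variable names tFirst, tSecond, and tNone are illustrative.

```matlab
% Assumes mdl and textDataTest from the example above.
% The first call with "auto" can include one-time setup cost;
% later calls with the same model are typically faster.
tic
classify(mdl,textDataTest,Acceleration="auto");
tFirst = toc;

tic
classify(mdl,textDataTest,Acceleration="auto");
tSecond = toc;

% Baseline with all acceleration disabled
tic
classify(mdl,textDataTest,Acceleration="none");
tNone = toc;

fprintf("auto (first call): %.2f s\nauto (second call): %.2f s\nnone: %.2f s\n", ...
    tFirst,tSecond,tNone)
```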
When Acceleration is "mex", the software generates and executes a MEX function based on the model and parameters you specify in the function call. A single model can have several associated MEX functions at one time. Clearing the model variable also clears any MEX functions associated with that model.
When Acceleration is "auto", the software does not generate a MEX function.
The "mex" option is available only when you use a GPU. You must also have a C/C++ compiler and the GPU Coder™ Interface for Deep Learning support package installed. Install the support package using the Add-On Explorer in MATLAB®. For setup instructions, see MEX Setup (GPU Coder). GPU Coder is not required.
MATLAB Compiler™ software does not support compiling models when you use the "mex" option.
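Because the "mex" option returns an error when a suitable GPU is not available, you might select it conditionally. This sketch assumes mdl and textDataTest from the example above and uses canUseGPU, which checks only for Parallel Computing Toolbox and a supported GPU, not for the C/C++ compiler or the support package. The variable names accel and YTestAccel are illustrative.

```matlab
% Assumes mdl and textDataTest from the example above.
% canUseGPU covers only the GPU requirement; "mex" also needs a C/C++
% compiler and the GPU Coder Interface for Deep Learning support package.
if canUseGPU
    accel = "mex";
else
    accel = "none";
end

YTestAccel = classify(mdl,textDataTest,Acceleration=accel);
```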
ExecutionEnvironment - Hardware resource
"auto" (default) | "gpu" | "cpu"
Hardware resource, specified as one of these values:
- "auto": Use a GPU if one is available. Otherwise, use the CPU.
- "gpu": Use the GPU. Using a GPU requires a Parallel Computing Toolbox license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). If Parallel Computing Toolbox or a suitable GPU is not available, then the software returns an error.
- "cpu": Use the CPU.
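For example, assuming mdl and textDataTest from the example above, this sketch forces prediction on the CPU and measures how often the CPU predictions agree with the default ("auto") predictions. Small numerical differences between devices could, in principle, change a prediction, so the agreement is computed rather than assumed. The variable names YTestCPU and agreement are illustrative.

```matlab
% Assumes mdl and textDataTest from the example above.
% Run prediction on the CPU even when a GPU is available.
YTestCPU = classify(mdl,textDataTest,ExecutionEnvironment="cpu");

% Fraction of documents where the CPU and default predictions match
agreement = mean(YTestCPU == classify(mdl,textDataTest))
```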
Output Arguments
Y - Predicted classes
categorical array
Predicted classes, returned as a categorical array.
Version History
Introduced in R2023b
See Also
trainBERTDocumentClassifier | bertDocumentClassifier | bert | dlnetwork (Deep Learning Toolbox) | bertTokenizer
Topics
- Train BERT Document Classifier
- Classify Text Data Using Deep Learning
- Create Simple Text Model for Classification
- Analyze Text Data Using Topic Models
- Analyze Text Data Using Multiword Phrases
- Sequence Classification Using Deep Learning (Deep Learning Toolbox)
- Deep Learning in MATLAB (Deep Learning Toolbox)