Multiclass projects in DataRobot predict a target with more than two classes, unlike binary classification projects, which predict exactly two. Currently, DataRobot supports up to 10 classes in a multiclass project.
To explore multiclass projects, let’s first connect to DataRobot. Start by loading the `datarobot` R package.
If you have set up a credentials file, `library(datarobot)` will initialize a connection to DataRobot automatically. Otherwise, you can specify your `endpoint` and `apiToken` as in this example to connect to DataRobot directly. For more information on connecting to DataRobot, see the “Introduction to DataRobot” vignette.
```r
library(datarobot)
endpoint <- "https://<YOUR DATAROBOT URL GOES HERE>/api/v2"
apiToken <- "<YOUR API TOKEN GOES HERE>"
ConnectToDataRobot(endpoint = endpoint, token = apiToken)
```
Let’s make predictions for the `iris` dataset:

```r
library(knitr)
data(iris)  # Load the `iris` dataset that ships with R.
kable(iris)
```
If your target is categorical and has a cardinality of up to 10, we will automatically select a Multiclass `targetType`, and that argument is not needed when calling `StartProject`. However, if the target is numerical and you would like to force DataRobot to treat it as a multiclass project, you can specify the `targetType` as shown below (here it is set explicitly for illustration, even though `Species` is categorical):
```r
project <- StartProject(iris,
                        projectName = "multiclassExample",
                        target = "Species",
                        targetType = TargetType$Multiclass,
                        maxWait = 600)
```
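The 10-class limit and the automatic target-type detection described above can be checked locally before uploading anything. The base-R sketch below counts the distinct non-missing values in a candidate target column. The helper names `countTargetClasses` and `fitsMulticlassLimit`, and the rule "more than two and at most 10 classes", are illustrative assumptions based on the limits stated in this vignette, not functions from the `datarobot` package.

```r
# Hypothetical helper (not part of the datarobot package): count the distinct,
# non-missing values in a candidate target column.
countTargetClasses <- function(target) {
  length(unique(target[!is.na(target)]))
}

# Hypothetical check: multiclass means more than two classes, and this
# vignette states a limit of 10 classes.
fitsMulticlassLimit <- function(target, maxClasses = 10) {
  n <- countTargetClasses(target)
  n > 2 && n <= maxClasses
}

countTargetClasses(iris$Species)   # categorical target: 3 classes
fitsMulticlassLimit(iris$Species)  # TRUE

# A numeric column with few distinct values (4, 6 and 8 cylinders) is the kind
# of target you might force to multiclass with targetType = TargetType$Multiclass.
countTargetClasses(mtcars$cyl)     # 3
fitsMulticlassLimit(mtcars$cyl)    # TRUE
```

A column such as `1:20` fails the check because it has 20 distinct values, and a 0/1 column fails because it is a binary target.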
Now we can build a model:
```r
blueprint <- ListBlueprints(project)[[1]]
RequestNewModel(project, blueprint)
```
And then we can get predictions:
```r
model <- ListModels(project)[[1]]
predictions <- Predict(model, iris)
print(table(predictions))
```

```
## request issued, waiting for predictions
## Multiclass with labels setosa, versicolor, virginica
##     setosa versicolor  virginica
##         50         47         53
```
You can also get a data frame with the probabilities of each class using `type = "probability"`:
```r
predictions <- Predict(model, iris, type = "probability")
kable(head(predictions))
```

```
## request issued, waiting for predictions
## Multiclass with labels setosa, versicolor, virginica
```

| class_setosa | class_versicolor | class_virginica |
|---:|---:|---:|
| 0.9987500 | 0.0000000 | 0.0012500 |
| 0.9344544 | 0.0491984 | 0.0163472 |
| 0.9854799 | 0.0080586 | 0.0064615 |
| 0.9931519 | 0.0054731 | 0.0013750 |
| 0.9954167 | 0.0022222 | 0.0023611 |
| 0.9883673 | 0.0017766 | 0.0098561 |
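The probability columns and the label predictions are two views of the same model output. The base-R sketch below re-enters the six rows shown above, checks that each row sums to 1, and picks the most probable class in each row. It assumes, without confirmation from the output above, that the label returned by the default `Predict` call is the highest-probability class; the `class_<label>` column naming is taken from the table.

```r
# First six rows of the probability table above, re-entered by hand.
probs <- data.frame(
  class_setosa     = c(0.9987500, 0.9344544, 0.9854799, 0.9931519, 0.9954167, 0.9883673),
  class_versicolor = c(0.0000000, 0.0491984, 0.0080586, 0.0054731, 0.0022222, 0.0017766),
  class_virginica  = c(0.0012500, 0.0163472, 0.0064615, 0.0013750, 0.0023611, 0.0098561)
)

# Each row is a probability distribution over the three classes, so it should
# sum to 1 up to the rounding in the printed table.
all(abs(rowSums(probs) - 1) < 1e-6)   # TRUE

# Pick the most probable column in each row, then drop the "class_" prefix
# to recover a plain label.
bestColumn <- max.col(as.matrix(probs), ties.method = "first")
predictedLabel <- sub("^class_", "", names(probs)[bestColumn])
predictedLabel   # "setosa" for all six rows
```

Passing `ties.method = "first"` makes the choice deterministic when two classes share the top probability; the default for `max.col` breaks ties at random.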
The confusion chart helps you understand how the multiclass model performs on each class:
```r
confusionChart <- GetConfusionChart(model, source = DataPartition$VALIDATION)
print(confusionChart)
```

```
$source
[1] "validation"

$data
$data$classes
[1] "setosa"     "versicolor" "virginica"

$data$classMetrics
$data$classMetrics$wasActualPercentages
$data$classMetrics$wasActualPercentages[[1]]
  percentage otherClassName
1          1         setosa
2          0     versicolor
3          0      virginica

$data$classMetrics$wasActualPercentages[[2]]
  percentage otherClassName
1        0.0         setosa
2        0.8     versicolor
3        0.2      virginica

$data$classMetrics$wasActualPercentages[[3]]
  percentage otherClassName
1          0         setosa
2          0     versicolor
3          1      virginica

$data$classMetrics$f1
[1] 1.0000000 0.8888889 0.9523810

$data$classMetrics$confusionMatrixOneVsAll
$data$classMetrics$confusionMatrixOneVsAll[[1]]
     [,1] [,2]
[1,]   15    0
[2,]    0    9

$data$classMetrics$confusionMatrixOneVsAll[[2]]
     [,1] [,2]
[1,]   19    0
[2,]    1    4

$data$classMetrics$confusionMatrixOneVsAll[[3]]
     [,1] [,2]
[1,]   13    1
[2,]    0   10

$data$classMetrics$recall
[1] 1.0 0.8 1.0

$data$classMetrics$actualCount
[1]  9  5 10

$data$classMetrics$precision
[1] 1.0000000 1.0000000 0.9090909

$data$classMetrics$wasPredictedPercentages
$data$classMetrics$wasPredictedPercentages[[1]]
  percentage otherClassName
1          1         setosa
2          0     versicolor
3          0      virginica

$data$classMetrics$wasPredictedPercentages[[2]]
  percentage otherClassName
1          0         setosa
2          1     versicolor
3          0      virginica

$data$classMetrics$wasPredictedPercentages[[3]]
  percentage otherClassName
1 0.00000000         setosa
2 0.09090909     versicolor
3 0.90909091      virginica

$data$classMetrics$className
[1] "setosa"     "versicolor" "virginica"

$data$classMetrics$predictedCount
[1]  9  4 11

$data$confusionMatrix
     [,1] [,2] [,3]
[1,]    9    0    0
[2,]    0    4    1
[3,]    0    0   10
```
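Here, we can see that the chart was computed on the "validation" partition (the available partitions are listed in the `DataPartition` object), and that `data$classMetrics` reports, for each class in the order given by `className`:

- `actualCount` and `predictedCount`: how many validation rows actually belong to the class, and how many were predicted as it.
- `precision`, `recall`, and `f1`: the usual scores, treating the class as positive and all other classes as negative.
- `confusionMatrixOneVsAll`: a 2 x 2 matrix per class, with rows for the actual values and columns for the predicted values, "all other classes" first and the class itself second.
- `wasActualPercentages`: among rows that actually belong to the class, the fraction predicted as each class (for example, 80% of actual versicolor rows were predicted as versicolor and 20% as virginica).
- `wasPredictedPercentages`: among rows predicted as the class, the fraction that actually belong to each class.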
The confusion chart also shows a full confusion matrix with one row and one column for each class, showing how each class was predicted or mispredicted. The columns represent the predicted classes and the rows represent the actual classes.
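To see how the per-class metrics relate to the full confusion matrix, the base-R sketch below re-enters `data$confusionMatrix` from the output above and recomputes precision, recall, F1, the row- and column-normalized percentages, and one one-vs-all matrix. The formulas are the standard definitions; that DataRobot computes its values exactly this way is an assumption, although the results match the printed output. The `oneVsAll` helper is hypothetical.

```r
# Full confusion matrix from the output above: rows are actual classes,
# columns are predicted classes.
classes <- c("setosa", "versicolor", "virginica")
cm <- matrix(c(9, 0,  0,
               0, 4,  1,
               0, 0, 10),
             nrow = 3, byrow = TRUE,
             dimnames = list(actual = classes, predicted = classes))

actualCount    <- rowSums(cm)   # 9 5 10
predictedCount <- colSums(cm)   # 9 4 11
truePositives  <- diag(cm)

# Standard per-class scores, each class treated as positive versus the rest.
precision <- truePositives / predictedCount
recall    <- truePositives / actualCount
f1        <- 2 * precision * recall / (precision + recall)
round(rbind(precision, recall, f1), 7)

# Row-normalized (share of each actual class) and column-normalized
# (share of each predicted class) views of the matrix.
wasActual    <- prop.table(cm, margin = 1)
wasPredicted <- prop.table(cm, margin = 2)

# Hypothetical helper: collapse the matrix to a 2 x 2 one-vs-all matrix for
# class k, rows = actual (other, k), columns = predicted (other, k).
oneVsAll <- function(cm, k) {
  tp <- cm[k, k]
  fn <- sum(cm[k, ]) - tp
  fp <- sum(cm[, k]) - tp
  tn <- sum(cm) - tp - fn - fp
  matrix(c(tn, fp, fn, tp), nrow = 2, byrow = TRUE)
}
oneVsAll(cm, "versicolor")   # rows: 19 0 / 1 4
```

Rows of `wasActual` correspond to `wasActualPercentages`, and columns of `wasPredicted` correspond to `wasPredictedPercentages`.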