{caret} custom function implementation

Problem: Optimal Probability Threshold

It’s been a while! I’m writing this post for a couple of reasons.

  1. I don’t want to only be writing my dissertation…
  2. Someone found me on github and asked me to help them.

I am happy to be writing this post for the two reasons listed above, so let’s jump straight into the problem.

Recently, a stranger e-mailed me about a problem they were having at work: identifying the optimal probability threshold in a classification problem. Essentially, when you fit a logistic regression model to your data, you’re likely interested in obtaining the predicted probabilities (via predict()). The predicted probabilities then allow you to classify an observation by checking whether its probability exceeds a certain threshold. Most of the time, people don’t really think about this threshold. I for one believe the threshold is domain/context specific, so there isn’t going to be a universal threshold that you can hammer onto all problems. How, then, do we find the optimal probability threshold for the problem you’re working on?
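
To make the idea concrete, here’s a minimal sketch with simulated data (the variable names and the 0.5 cutoff are just placeholders, not anything specific to the problem above):

set.seed(1)
dat <- data.frame(x = rnorm(100))
dat$y <- rbinom(100, 1, plogis(2 * dat$x))
fit <- glm(y ~ x, data = dat, family = binomial)
p <- predict(fit, newdata = dat, type = "response")  # predicted probabilities
# the 0.5 cutoff is only a convention; the "right" cutoff is problem-specific
table(predicted = ifelse(p >= 0.5, 1, 0), observed = dat$y)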

Solution: Specify it in your cross-validation procedure and select the optimal threshold

This is already covered in the {caret} documentation, so I won’t go into all the technical details here. I’ll walk through a scenario, which will hopefully make it easier for people to work through the process with their own data.

Let’s start with the Sonar data (data(Sonar) from {mlbench}). People may be familiar with the data set already, but I’m just going to show a tidbit; if you’re interested in learning more, see help(Sonar).

library("pacman")
p_load(caret, mlbench, knitr)
data("Sonar")

Let’s split the data into training and test sets.

inTrain <- createDataPartition(Sonar$Class, p=0.75, list = FALSE)
training <- Sonar[inTrain,]
testing <- Sonar[-inTrain,]

OK. From here we can use the default train() function to run a k-fold CV and extract the performance of the model being trained. Let’s use a simple elastic-net GLM (glmnet) and identify the optimal alpha and lambda.

tuneGrid <- expand.grid(alpha = 0:1, lambda = seq(0.0001, 1, length = 10))

myControl <- trainControl(
  method = "cv", number = 10
)

glmtrain <- train(
  Class ~.,
  training,
  method="glmnet",
  tuneGrid = tuneGrid, trControl = myControl
)

glmtrain
## glmnet 
## 
## 157 samples
##  60 predictor
##   2 classes: 'M', 'R' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 141, 141, 142, 142, 141, 142, ... 
## Resampling results across tuning parameters:
## 
##   alpha  lambda  Accuracy   Kappa    
##   0      0.0001  0.7764706  0.5446293
##   0      0.1112  0.8026716  0.6013982
##   0      0.2223  0.7893382  0.5755607
##   0      0.3334  0.7826716  0.5627521
##   0      0.4445  0.7823039  0.5622975
##   0      0.5556  0.7756373  0.5485573
##   0      0.6667  0.7560539  0.5077482
##   0      0.7778  0.7560539  0.5077482
##   0      0.8889  0.7498039  0.4960175
##   0      1.0000  0.7498039  0.4939777
##   1      0.0001  0.7123039  0.4123037
##   1      0.1112  0.7639706  0.5197876
##   1      0.2223  0.5350245  0.0000000
##   1      0.3334  0.5350245  0.0000000
##   1      0.4445  0.5350245  0.0000000
##   1      0.5556  0.5350245  0.0000000
##   1      0.6667  0.5350245  0.0000000
##   1      0.7778  0.5350245  0.0000000
##   1      0.8889  0.5350245  0.0000000
##   1      1.0000  0.5350245  0.0000000
## 
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were alpha = 0 and lambda = 0.1112.

So… we see the alpha and lambda values that were chosen. We can also use the fitted model to generate predictions on the test set:

pred <- predict(glmtrain, newdata = testing)
confusionMatrix(testing$Class, pred)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  M  R
##          M 17 10
##          R  3 21
##                                           
##                Accuracy : 0.7451          
##                  95% CI : (0.6037, 0.8567)
##     No Information Rate : 0.6078          
##     P-Value [Acc > NIR] : 0.02870         
##                                           
##                   Kappa : 0.4966          
##                                           
##  Mcnemar's Test P-Value : 0.09609         
##                                           
##             Sensitivity : 0.8500          
##             Specificity : 0.6774          
##          Pos Pred Value : 0.6296          
##          Neg Pred Value : 0.8750          
##              Prevalence : 0.3922          
##          Detection Rate : 0.3333          
##    Detection Prevalence : 0.5294          
##       Balanced Accuracy : 0.7637          
##                                           
##        'Positive' Class : M               
## 

So this is a short procedure for obtaining predictions on the test data. One thing we didn’t discuss above is how the predict() function extracted the classes. Basically, the predict() method for objects of class train outputs the raw classes if the type= argument isn’t specified.

probs <- predict(glmtrain, newdata = testing, type = "prob")
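
If you want a quick peek at what comes back, each class gets its own column of probabilities (one for M and one for R):

head(probs)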

Once we specify type="prob", we get the probabilities for both classes. From there, we can manually derive the class predictions like so:

preds2 <- factor(ifelse(probs$M >= 0.5, "M","R"), levels=c("M","R"))
confusionMatrix(testing$Class, preds2)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  M  R
##          M 17 10
##          R  3 21
##                                           
##                Accuracy : 0.7451          
##                  95% CI : (0.6037, 0.8567)
##     No Information Rate : 0.6078          
##     P-Value [Acc > NIR] : 0.02870         
##                                           
##                   Kappa : 0.4966          
##                                           
##  Mcnemar's Test P-Value : 0.09609         
##                                           
##             Sensitivity : 0.8500          
##             Specificity : 0.6774          
##          Pos Pred Value : 0.6296          
##          Neg Pred Value : 0.8750          
##              Prevalence : 0.3922          
##          Detection Rate : 0.3333          
##    Detection Prevalence : 0.5294          
##       Balanced Accuracy : 0.7637          
##                                           
##        'Positive' Class : M               
## 

So that’s how things are done internally with the generic predict() function. Now let’s get to how we would optimize the probability threshold… essentially, we want the resampling results to contain an extra tuning column, the probability threshold, so that the performance values (here just Accuracy, but we can modify that later) are reported for each threshold.

We need to get a bit closer to the internals of {caret} to specify the probability threshold values. First we need the information {caret} stores about the model being used in train(); the function to use is getModelInfo():

thresh_code <- getModelInfo("glmnet", regex=FALSE)[[1]]
str(thresh_code)
## List of 15
##  $ label     : chr "glmnet"
##  $ library   : chr [1:2] "glmnet" "Matrix"
##  $ type      : chr [1:2] "Regression" "Classification"
##  $ parameters:'data.frame':  2 obs. of  3 variables:
##   ..$ parameter: chr [1:2] "alpha" "lambda"
##   ..$ class    : chr [1:2] "numeric" "numeric"
##   ..$ label    : chr [1:2] "Mixing Percentage" "Regularization Parameter"
##  $ grid      :function (x, y, len = NULL, search = "grid")  
##   ..- attr(*, "srcref")= 'srcref' int [1:8] 7 26 27 19 26 19 7 27
##   .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x55bece9300d0> 
##  $ loop      :function (grid)  
##   ..- attr(*, "srcref")= 'srcref' int [1:8] 28 26 39 19 26 19 28 39
##   .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x55bece9300d0> 
##  $ fit       :function (x, y, wts, param, lev, last, classProbs, ...)  
##   ..- attr(*, "srcref")= 'srcref' int [1:8] 40 25 66 19 25 19 40 66
##   .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x55bece9300d0> 
##  $ predict   :function (modelFit, newdata, submodels = NULL)  
##   ..- attr(*, "srcref")= 'srcref' int [1:8] 67 29 88 19 29 19 67 88
##   .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x55bece9300d0> 
##  $ prob      :function (modelFit, newdata, submodels = NULL)  
##   ..- attr(*, "srcref")= 'srcref' int [1:8] 89 26 122 19 26 19 89 122
##   .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x55bece9300d0> 
##  $ predictors:function (x, lambda = NULL, ...)  
##   ..- attr(*, "srcref")= 'srcref' int [1:8] 123 32 139 19 32 19 123 139
##   .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x55bece9300d0> 
##  $ varImp    :function (object, lambda = NULL, ...)  
##   ..- attr(*, "srcref")= 'srcref' int [1:8] 140 28 154 19 28 19 140 154
##   .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x55bece9300d0> 
##  $ levels    :function (x)  
##   ..- attr(*, "srcref")= 'srcref' int [1:8] 155 28 155 93 28 93 155 155
##   .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x55bece9300d0> 
##  $ tags      : chr [1:6] "Generalized Linear Model" "Implicit Feature Selection" "L1 Regularization" "L2 Regularization" ...
##  $ sort      :function (x)  
##   ..- attr(*, "srcref")= 'srcref' int [1:8] 159 26 159 66 26 66 159 159
##   .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x55bece9300d0> 
##  $ trim      :function (x)  
##   ..- attr(*, "srcref")= 'srcref' int [1:8] 160 26 165 19 26 19 160 165
##   .. ..- attr(*, "srcfile")=Classes 'srcfilecopy', 'srcfile' <environment: 0x55bece9300d0>

You see a bunch of things that will be useful. First, we need to check that the model supports the Classification task type:

thresh_code$type
## [1] "Regression"     "Classification"

Next, we need to check the tuning parameters:

thresh_code$parameters
##   parameter   class                    label
## 1     alpha numeric        Mixing Percentage
## 2    lambda numeric Regularization Parameter

Since we want to optimize the probability threshold, let’s add it as an extra tuning parameter here:

thresh_code$parameters <- rbind(
  thresh_code$parameters, 
  data.frame(
    parameter="threshold",
    class="numeric",
    label="Probability Cutoff"
  )
)
thresh_code$parameters
##   parameter   class                    label
## 1     alpha numeric        Mixing Percentage
## 2    lambda numeric Regularization Parameter
## 3 threshold numeric       Probability Cutoff

OK, now that that’s done, we need to make sure the tuning grid iterates through the probability threshold values according to how we specify them in train():

thresh_code$grid <- function(x, y, len=NULL, search="grid"){
  if(search =="grid") {
    numLev <- if(is.character(y) | is.factor(y)) length(levels(y)) else NA
    if(!is.na(numLev)){
      fam <- ifelse(numLev>2, "multinomial", "binomial")
    } else fam <- "gaussian"
    init <- glmnet::glmnet(Matrix::as.matrix(x),y,
                           family=fam,
                           nlambda=len+2,
                           alpha=0.5)
    lambda <- unique(init$lambda)
    lambda <- lambda[-c(1, length(lambda))]
    lambda <- lambda[1:min(length(lambda), len)]
    out <- expand.grid(
      alpha = seq(0.1, 1, length = len),
      lambda = lambda,
      threshold = seq(0.01,0.99, length=len)
    )
  } else {
    out <- data.frame(alpha = runif(len, min = 0, max = 1),
                      lambda = 2^runif(len, min = -10, max = 3),
                      threshold = runif(len, min = 0, max = 1)
                      )
  }
  out
}

Essentially, this grid function is what train() calls to generate a tuning grid when we don’t supply one ourselves: the search = "grid" branch builds the default grid, and the else branch handles random search (i.e., search = "random" in trainControl()).
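
As a sanity check, you can call the new grid function directly and preview what train() would generate on its own (this snippet is just an illustration; it drops the Class column to form the predictor matrix):

head(thresh_code$grid(x = training[, names(training) != "Class"], y = training$Class, len = 3))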

Now that we’ve specified the tuning grid, we need to set up the code that will carry out the performance calculations within each fold for each threshold value: a loop.

thresh_code$loop <- function(grid){
  loop <- plyr::ddply(
    grid, c("alpha","lambda"),
    function(x) c(threshold = max(x$threshold))
  )
  submodels <- vector(mode="list", length = nrow(loop))
  
  for(i in seq(along = loop$threshold)) {
    index <- which(
      grid$alpha == loop$alpha[i] &
        grid$lambda == loop$lambda[i] 
    )
    
    cuts <- grid[index, "threshold"]
    submodels[[i]] <- data.frame(threshold = cuts[cuts != loop$threshold[i]])
  }
  list(loop = loop, submodels = submodels)
}

OK, that was convoluted… You can debug the code and inspect the objects you get as you step through each iteration, but what’s being done is this: for each alpha & lambda combination, the model only needs to be fit once (the loop keeps the largest threshold for that fit), and the remaining threshold values are stored as submodels. Since changing the threshold doesn’t change the fitted model, the predictions for all of the submodel thresholds can be derived from that single fit, so the performance for every alpha, lambda, and threshold combination still gets computed.
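
To see what this produces, you can run the loop function on a tiny hand-made grid (small_grid below is just an illustration):

small_grid <- expand.grid(alpha = 0, lambda = 0.1, threshold = c(0.3, 0.5, 0.7))
thresh_code$loop(small_grid)
# $loop keeps one row per alpha/lambda combination, at the largest threshold (0.7 here);
# $submodels holds the remaining thresholds (0.3 and 0.5) for that same fit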

Now let’s move on to the prediction function. Here we want it to work from the predicted probabilities rather than the raw class predictions, so that the classes are assigned using the tuned threshold:

thresh_code$predict <- function(modelFit, newdata, submodels = NULL) {
  if(!is.matrix(newdata)) newdata <- Matrix::as.matrix(newdata)
  # For a binomial glmnet fit, type = "response" returns the probability of the
  # *second* factor level, so 1 minus that value is the probability of the first
  # level -- the level {caret} treats as the "positive" class (here "M")
  class1Prob <- 1 - predict(modelFit, newdata, s = modelFit$lambdaOpt, type = "response")
  if(is.matrix(class1Prob)) class1Prob <- class1Prob[, 1]
  
  if(modelFit$problemType == "Classification") {
    if(length(modelFit$obsLevels) == 2) {
      out <- ifelse(class1Prob >= modelFit$tuneValue$threshold,
                    modelFit$obsLevels[1],
                    modelFit$obsLevels[2]
                    )
    } else {
      # a single probability cutoff only makes sense for a two-class problem
      stop("This custom threshold model only supports two classes")
    }
  }
  
  if(!is.null(submodels)) {
    # the fitted model doesn't change across thresholds, so the same probabilities
    # are simply re-cut at each submodel threshold
    tmp2 <- out
    out <- vector(mode = "list", length = length(submodels$threshold) + 1)
    out[[1]] <- tmp2
    for(j in seq(along = submodels$threshold)) {
      out[[j + 1]] <- ifelse(class1Prob >= submodels$threshold[[j]],
                             modelFit$obsLevels[1],
                             modelFit$obsLevels[2]
                             )
    }
  }
  out
}

That was a mess as well. But the gist is that we’re getting class predictions for each threshold value in each of the held-out folds. The list out captures the predicted classes for each threshold value.

We’re almost there now! The class probabilities themselves don’t change with the threshold, but we have to return multiple copies of them, one per threshold, so that {caret} can evaluate every submodel:

thresh_code$prob <- function(modelFit, newdata, submodels = NULL) {
  out <- as.data.frame(
    predict(modelFit, as.matrix(newdata), s = modelFit$lambdaOpt, type = "response")
  )
  # glmnet returns the probability of the second class level, so bind 1 - p and p
  out <- cbind(1 - out, out)
  colnames(out) <- modelFit$obsLevels
  
  if(!is.null(submodels)) {
    # the probabilities are identical for every threshold; replicate them per submodel
    tmp <- vector(mode = "list", length = length(submodels$threshold) + 1)
    tmp[[1]] <- out
    for(j in seq(along = submodels$threshold)) {
      tmp[[j + 1]] <- out
    }
    out <- tmp
  }
  out
}

Alright, we’ve basically specified everything except the final performance calculation (which we could customize as well… more on that later). Let’s just keep the default performance calculation for now and see how we do.

tuneGridnew <- expand.grid(
  alpha = 0:1, 
  lambda = seq(0.0001, 1, length = 5),
  threshold = seq(0, 1, 0.2)
  )

myControl <- trainControl(
  method = "cv", number = 5,
  classProbs = TRUE
)

glmtrain2 <- train(
  Class ~.,
  training,
  method=thresh_code,
  tuneGrid = tuneGridnew, 
  trControl = myControl
)

kable(head(glmtrain2$results))

| alpha | lambda | threshold |  Accuracy |      Kappa | AccuracySD |   KappaSD |
|------:|-------:|----------:|----------:|-----------:|-----------:|----------:|
|     0 |  1e-04 |       0.0 | 0.5350806 |  0.0000000 |  0.0136257 | 0.0000000 |
|     0 |  1e-04 |       0.2 | 0.2739919 | -0.4773996 |  0.0767136 | 0.1473699 |
|     0 |  1e-04 |       0.4 | 0.2350806 | -0.5348954 |  0.0988042 | 0.1982347 |
|     0 |  1e-04 |       0.6 | 0.2096774 | -0.5632851 |  0.0672463 | 0.1373809 |
|     0 |  1e-04 |       0.8 | 0.2090726 | -0.5411881 |  0.0819749 | 0.1707124 |
|     0 |  1e-04 |       1.0 | 0.4649194 |  0.0000000 |  0.0136257 | 0.0000000 |

kable(glmtrain2$bestTune)

|    | alpha | lambda | threshold |
|:---|------:|-------:|----------:|
| 25 |     0 |      1 |         0 |

Now the CV is iterating through the probability thresholds!
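
If you want to eyeball how the resampled performance moves across the thresholds (and across alpha and lambda), one quick option is the plot() method for train objects:

plot(glmtrain2)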

Specification of Custom Performance Metrics

I’m going to keep this section relatively short and make it as practical as possible. Basically, this has been dealt with on Stack Overflow. If you want all the usual performance metrics for a classification problem, use summaryFunction = MySummary; otherwise use whatever suits your needs. I’m also swapping the AUC calculation from the {pROC} version to the {ModelMetrics} version. Note that in the thresh_code$prob portion above, we did a bit of manipulation to make sure the correct probabilities are supplied for the held-out folds being evaluated:

twoClassSummarya <- function (data, lev = NULL, model = NULL) 
{
  if (length(lev) > 2) {
    stop(paste("Your outcome has", length(lev), "levels. The twoClassSummarya() function isn't appropriate."))
  }
  if (!all(levels(data[, "pred"]) == lev)) {
    stop("levels of observed and predicted data do not match")
  }
  
  # AUC via {ModelMetrics} instead of pROC::roc()
  rocAUC <- ModelMetrics::auc(
    as.numeric(as.character(ifelse(data$obs == lev[1], 1, 0))),
    data[, lev[1]]
  )
  out <- c(rocAUC,
           sensitivity(data[, "pred"], data[, "obs"], lev[1]),
           specificity(data[, "pred"], data[, "obs"], lev[2]))
  names(out) <- c("ROC", "Sens", "Spec")
  out
}

MySummary <- function(data, lev = NULL, model = NULL){
  a1 <- defaultSummary(data, lev, model)
  b1 <- twoClassSummarya(data, lev, model)  # the {ModelMetrics}-based version above
  c1 <- prSummary(data, lev, model)
  out <- c(a1, b1, c1)
  out
}
myControl <- trainControl(method = "cv", 
                     number = 5,
                     savePredictions = TRUE,
                     summaryFunction = MySummary,
                     classProbs = TRUE,
                     allowParallel = FALSE
                     )

glmtrain3 <- train(
  Class ~.,
  training,
  method=thresh_code,
  tuneGrid = tuneGridnew, 
  trControl = myControl
)
## Warning in nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo, :
## There were missing values in resampled performance measures.
kable(head(glmtrain3$results))

| alpha | lambda | threshold | Accuracy | Kappa | ROC | Sens | Spec | AUC | Precision | Recall | F | AccuracySD | KappaSD | ROCSD | SensSD | SpecSD | AUCSD | PrecisionSD | RecallSD | FSD |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1e-04 | 0.0 | 0.5350806 | 0.0000000 | 0.8672199 | 1.0000000 | 0.0000000 | 0.8184141 | 0.5350806 | 1.0000000 | 0.6970546 | 0.0136257 | 0.0000000 | 0.0426628 | 0.0000000 | 0.0000000 | 0.0417762 | 0.0136257 | 0.0000000 | 0.0115884 |
| 0 | 1e-04 | 0.2 | 0.2743952 | -0.4826478 | 0.8672199 | 0.4161765 | 0.1085714 | 0.8184141 | 0.3444652 | 0.4161765 | 0.3755743 | 0.0559199 | 0.1052449 | 0.0426628 | 0.1225539 | 0.1001926 | 0.0417762 | 0.0654402 | 0.1225539 | 0.0883808 |
| 0 | 1e-04 | 0.4 | 0.2622984 | -0.4841204 | 0.8672199 | 0.3095588 | 0.2066667 | 0.8184141 | 0.3095760 | 0.3095588 | 0.3091626 | 0.0832810 | 0.1668428 | 0.0426628 | 0.0870167 | 0.1087759 | 0.0417762 | 0.0839530 | 0.0870167 | 0.0845584 |
| 0 | 1e-04 | 0.6 | 0.1979839 | -0.5906412 | 0.8672199 | 0.1676471 | 0.2333333 | 0.8184141 | 0.1994048 | 0.1676471 | 0.1817993 | 0.0548845 | 0.1089436 | 0.0426628 | 0.0546299 | 0.0778102 | 0.0417762 | 0.0549589 | 0.0546299 | 0.0548846 |
| 0 | 1e-04 | 0.8 | 0.2235887 | -0.5068518 | 0.8672199 | 0.0352941 | 0.4409524 | 0.8184141 | 0.0687179 | 0.0352941 | 0.0758991 | 0.0735417 | 0.1624222 | 0.0426628 | 0.0322190 | 0.1652182 | 0.0417762 | 0.0708593 | 0.0322190 | 0.0102673 |
| 0 | 1e-04 | 1.0 | 0.4649194 | 0.0000000 | 0.8672199 | 0.0000000 | 1.0000000 | 0.8184141 | NaN | 0.0000000 | NaN | 0.0136257 | 0.0000000 | 0.0426628 | 0.0000000 | 0.0000000 | 0.0417762 | NA | 0.0000000 | NA |

kable(glmtrain2$bestTune)

|    | alpha | lambda | threshold |
|:---|------:|-------:|----------:|
| 25 |     0 |      1 |         0 |

There you have it! I hope this helped :) This is by no means super easy to implement on the first try, so I suggest taking enough time to debug the areas of the code you don’t understand by inserting print() statements that spit out the objects of interest along the way. Happy coding!
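
For example, while experimenting you could temporarily wrap one of the pieces, say the loop function, to print what it receives (debug_loop is just a throwaway name):

debug_loop <- thresh_code$loop        # keep the original around
thresh_code$loop <- function(grid){
  print(head(grid))                   # peek at the tuning grid {caret} passes in
  debug_loop(grid)
}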
