WhatIf returns the n_counterfactuals observations in predictor$data$X that are most similar to x_interest, among those observations whose prediction lies in the desired_outcome interval.
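The selection rule described above can be sketched in a few lines of base R. This is a hypothetical illustration, not the package's implementation. `whatif_sketch`, `predict_fun`, `dist_fun` and `euclid` are names introduced here. The sketch assumes a closed interval, and it uses plain Euclidean distance in place of the default Gower distance.

```r
# Hypothetical sketch of the WhatIf selection rule for regression
# (not the package's code). Assumes predict_fun returns one numeric
# prediction per row of X and dist_fun(x, y) returns a
# nrow(x) x nrow(y) distance matrix.
whatif_sketch = function(x_interest, X, predict_fun, desired_outcome,
                         n_counterfactuals = 1L, dist_fun) {
  pred = predict_fun(X)
  # Keep observations whose prediction lies in the (closed) desired interval
  in_range = pred >= desired_outcome[1L] & pred <= desired_outcome[2L]
  candidates = X[in_range, , drop = FALSE]
  if (nrow(candidates) == 0L) return(candidates)
  # Distances from the single row x_interest to every candidate
  d = dist_fun(x_interest, candidates)[1L, ]
  # Return the n closest candidates, nearest first
  candidates[head(order(d), n_counterfactuals), , drop = FALSE]
}

# Plain Euclidean distance, used here only as a stand-in
euclid = function(x, y) {
  outer(seq_len(nrow(x)), seq_len(nrow(y)), Vectorize(function(i, j) {
    sqrt(sum((unlist(x[i, ]) - unlist(y[j, ]))^2))
  }))
}

# Toy data and a toy linear "model"
X = data.frame(a = c(1, 2, 3, 4, 5), b = c(0, 1, 0, 1, 0))
predict_fun = function(newdata) 2 * newdata$a + newdata$b
res = whatif_sketch(X[1L, ], X, predict_fun, desired_outcome = c(5, 9),
                    n_counterfactuals = 2L, dist_fun = euclid)
res
```

Rows 2, 3 and 4 have predictions 5, 6 and 9, so they fall inside [5, 9]; of these, rows 2 and 3 are closest to the first row and are returned.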

Details

Only observations whose feature values lie between the corresponding values in lower and upper are considered counterfactual candidates.
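A minimal sketch of this candidate filter in base R follows. It assumes inclusive bounds and treats features without an entry in `lower` or `upper` as unconstrained. `within_bounds` is a hypothetical helper, not part of the package.

```r
# Hypothetical helper (not the package API): keep rows whose numeric
# feature values lie within [lower, upper]. Bounds are named vectors;
# comparisons are assumed to be inclusive.
within_bounds = function(X, lower = NULL, upper = NULL) {
  keep = rep(TRUE, nrow(X))
  for (f in names(lower)) keep = keep & X[[f]] >= lower[[f]]
  for (f in names(upper)) keep = keep & X[[f]] <= upper[[f]]
  X[keep, , drop = FALSE]
}

X = data.frame(hp = c(62, 93, 109, 245), wt = c(3.19, 2.32, 2.78, 3.57))
res = within_bounds(X, lower = c(hp = 90), upper = c(wt = 3))
res
```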

References

Gower, J. C. (1971). A general coefficient of similarity and some of its properties. Biometrics, 27, 623–637.

Wexler, J., Pushkarna, M., Bolukbasi, T., Wattenberg, M., Viégas, F., & Wilson, J. (2019). The what-if tool: Interactive probing of machine learning models. IEEE Transactions on Visualization and Computer Graphics, 26(1), 56–65.

Methods


Method new()

Create a new WhatIfRegr object.

Usage

WhatIfRegr$new(
  predictor,
  n_counterfactuals = 1L,
  lower = NULL,
  upper = NULL,
  distance_function = "gower"
)

Arguments

predictor

(Predictor)
The object (created with iml::Predictor$new()) holding the machine learning model and the data.

n_counterfactuals

(integerish(1))
The number of counterfactuals to return. Default is 1L.

lower

(numeric() | NULL)
Vector of minimum values for numeric features. If NULL (default), the lower bound for each numeric feature is taken as its minimum value in predictor$data$X. If not NULL, the vector should be named with the corresponding feature names.

upper

(numeric() | NULL)
Vector of maximum values for numeric features. If NULL (default), the upper bound for each numeric feature is taken as its maximum value in predictor$data$X. If not NULL, the vector should be named with the corresponding feature names.
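The NULL defaults for lower and upper amount to per-feature minima and maxima over the numeric columns of the data. The sketch below makes two assumptions: only numeric columns receive bounds, and a non-NULL vector is used as given. `default_bounds` is a hypothetical name, not a package function.

```r
# Hypothetical sketch: derive default bounds from the numeric columns
# of X when lower/upper are NULL; non-NULL bounds are returned unchanged.
default_bounds = function(X, lower = NULL, upper = NULL) {
  num = names(X)[vapply(X, is.numeric, logical(1L))]
  if (is.null(lower)) lower = vapply(X[num], min, numeric(1L))
  if (is.null(upper)) upper = vapply(X[num], max, numeric(1L))
  list(lower = lower, upper = upper)
}

# The character column am is not numeric, so it gets no bounds
X = data.frame(hp = c(62, 93, 109), wt = c(3.19, 2.32, 2.78),
               am = c("0", "1", "1"))
b = default_bounds(X)
b
```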

distance_function

(function() | 'gower' | 'gower_c')
The distance function used to compute the distances between x_interest and the training data points for finding x_nn. Either the name of an already implemented distance function ('gower' or 'gower_c') or a function. If set to 'gower' (default), Gower's distance (Gower 1971) is used; if set to 'gower_c', a more efficient C-based implementation of Gower's distance is used. A user-supplied function must take three arguments, x, y, and data, and should return a double matrix with nrow(x) rows and at most nrow(y) columns.
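To illustrate the required signature, the hypothetical `gower_dist` below computes Gower's distance from scratch in base R. It is not the code behind the 'gower' or 'gower_c' options. It assumes no missing values and equal feature weights. Each numeric feature contributes |x_j - y_j| divided by that feature's range in `data`. Every other feature contributes 0 when the values match and 1 otherwise. The per-feature terms are then averaged.

```r
# Hypothetical user-supplied distance function with the documented
# signature (x, y, data). Simplified Gower distance: equal weights,
# no NA handling. Returns a nrow(x) x nrow(y) double matrix.
gower_dist = function(x, y, data) {
  d = matrix(0, nrow = nrow(x), ncol = nrow(y))
  for (f in names(x)) {
    xf = x[[f]]
    yf = y[[f]]
    if (is.numeric(xf)) {
      # Range-normalised absolute difference; range taken from data
      r = diff(range(data[[f]]))
      term = if (r > 0) abs(outer(xf, yf, "-")) / r
             else matrix(0, length(xf), length(yf))
    } else {
      # Simple mismatch for non-numeric features
      term = 1 * outer(as.character(xf), as.character(yf), "!=")
    }
    d = d + term
  }
  d / ncol(x)
}

X = data.frame(hp = c(100, 150, 200), am = c("0", "1", "0"))
D = gower_dist(X[1L, ], X, data = X)
D
```

Under the documented interface, a function like this could be passed as `distance_function = gower_dist` to `WhatIfRegr$new()`.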


Method clone()

The objects of this class are cloneable with this method.

Usage

WhatIfRegr$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

if (require("randomForest")) {
  set.seed(123456)
  # Train a model
  rf = randomForest(mpg ~ ., data = mtcars)
  # Create a predictor object
  predictor = iml::Predictor$new(rf)
  # Find counterfactuals for x_interest
  wi_regr = WhatIfRegr$new(predictor, n_counterfactuals = 5L)
  cfactuals = wi_regr$find_counterfactuals(
    x_interest = mtcars[1L, ], desired_outcome = c(22, 26)
  )
  # Print the results
  cfactuals
}
#> 5 Counterfactual(s) 
#>  
#> Desired outcome range: [22, 26] 
#>  
#> Head: 
#>      cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <num> <num> <num> <num> <num> <num> <num> <num> <num> <num>
#> 1:     4 121.0   109  4.11  2.78 18.60     1     1     4     2
#> 2:     4 108.0    93  3.85  2.32 18.61     1     1     4     1
#> 3:     4 146.7    62  3.69  3.19 20.00     1     0     4     2