WhatIf returns the n_counterfactuals observations from predictor$data$X that are most similar to x_interest and whose prediction for the desired_class lies in the desired_prob interval.

Details

By default, the dissimilarities are computed using Gower's dissimilarity measure (Gower 1971).
Only observations whose feature values lie between the corresponding values in lower and upper are considered counterfactual candidates.
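
The two steps described above (range filtering, then ranking by Gower's dissimilarity) can be sketched in base R. This is a minimal illustration and not the package's implementation: the helper gower_dissim, the example bounds, the inclusive comparison, and the choice to scale by ranges from the full data are all assumptions made for this sketch. It also omits the filter on predicted probabilities.

```r
# Hypothetical helper (not part of the package): Gower dissimilarity between
# a single observation `x` and every row of `data`. Numeric features are
# scaled by their range in `ref`; other features contribute 0 (equal) or 1.
gower_dissim = function(x, data, ref = data) {
  per_feature = lapply(names(data), function(j) {
    if (is.numeric(ref[[j]])) {
      rng = diff(range(ref[[j]]))
      if (rng == 0) rep(0, nrow(data)) else abs(data[[j]] - x[[j]]) / rng
    } else {
      as.numeric(as.character(data[[j]]) != as.character(x[[j]]))
    }
  })
  # Average the per-feature contributions for each row
  rowMeans(do.call(cbind, per_feature))
}

X = iris[, 1:4]
x_interest = X[150L, ]

# Assumed example bounds; features without a bound fall back to the
# observed minimum/maximum, mirroring the documented NULL default.
lower = c(Sepal.Length = 5.5)
upper = c(Petal.Width = 2.0)
lower_full = vapply(X, min, numeric(1L)); lower_full[names(lower)] = lower
upper_full = vapply(X, max, numeric(1L)); upper_full[names(upper)] = upper

# Keep only rows whose values lie within [lower, upper] for every feature
in_range = Reduce(`&`, lapply(names(X), function(j) {
  X[[j]] >= lower_full[[j]] & X[[j]] <= upper_full[[j]]
}))
candidates = X[in_range, , drop = FALSE]

# Rank the remaining candidates by dissimilarity to x_interest
d = gower_dissim(x_interest, candidates, ref = X)
nearest = candidates[head(order(d), 5L), ]
nearest
```

Because x_interest is itself a row of X and lies within the bounds, it appears among the candidates with dissimilarity 0 in this sketch.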

References

Gower, J. C. (1971). A general coefficient of similarity and some of its properties. Biometrics, 27(4), 857–871.

Wexler, J., Pushkarna, M., Bolukbasi, T., Wattenberg, M., Viégas, F., & Wilson, J. (2019). The What-If Tool: Interactive probing of machine learning models. IEEE Transactions on Visualization and Computer Graphics, 26(1), 56–65.

Methods


Method new()

Create a new WhatIfClassif object.

Usage

WhatIfClassif$new(
  predictor,
  n_counterfactuals = 1L,
  lower = NULL,
  upper = NULL,
  distance_function = "gower"
)

Arguments

predictor

(Predictor)
The object (created with iml::Predictor$new()) holding the machine learning model and the data.

n_counterfactuals

(integerish(1))
The number of counterfactuals to return. Default is 1L.

lower

(numeric() | NULL)
Vector of minimum values for numeric features. If NULL (default), the element for each numeric feature in lower is taken as its minimum value in predictor$data$X. If not NULL, it should be named with the corresponding feature names.

upper

(numeric() | NULL)
Vector of maximum values for numeric features. If NULL (default), the element for each numeric feature in upper is taken as its maximum value in predictor$data$X. If not NULL, it should be named with the corresponding feature names.

distance_function

(function() | 'gower' | 'gower_c')
The distance function used to compute the distances between x_interest and the training data points for finding the nearest neighbors (x_nn). Either the name of an implemented distance function ('gower' or 'gower_c') or a function. If set to 'gower' (default), Gower's distance (Gower 1971) is used; if set to 'gower_c', a more efficient C-based implementation of Gower's distance is used. A custom function must have the three arguments x, y, and data and should return a double matrix with nrow(x) rows and at most nrow(y) columns.
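
As an illustration of the custom-function contract described for distance_function, the following sketch defines a range-scaled Manhattan distance. The name manhattan_scaled is hypothetical, and the sketch assumes that x, y, and data are data-frame-like, that data holds the training features used for scaling, and that only numeric features should contribute. None of these assumptions is confirmed by the documentation above.

```r
# Hypothetical custom distance: range-scaled Manhattan distance over the
# numeric features. Takes the three arguments x, y, and data and returns a
# double matrix with nrow(x) rows and nrow(y) columns.
manhattan_scaled = function(x, y, data) {
  x = as.data.frame(x)
  y = as.data.frame(y)
  data = as.data.frame(data)

  # Numeric features present in both `data` and `x`
  is_num = vapply(data, is.numeric, logical(1L))
  num = intersect(names(data)[is_num], names(x))

  # Scale each feature by its observed range in `data` (assumed training data)
  rng = vapply(data[num], function(col) diff(range(col)), numeric(1L))
  rng[rng == 0] = 1  # avoid division by zero for constant features
  xs = sweep(as.matrix(x[num]), 2L, rng, "/")
  ys = sweep(as.matrix(y[num]), 2L, rng, "/")

  out = matrix(0, nrow = nrow(xs), ncol = nrow(ys))
  for (i in seq_len(nrow(xs))) {
    # Sum of absolute scaled differences between row i of x and each row of y
    out[i, ] = rowSums(abs(sweep(ys, 2L, xs[i, ], "-")))
  }
  out
}

# Stand-alone check of the returned shape on the iris features
m = manhattan_scaled(iris[150L, 1:4], iris[, 1:4], iris[, 1:4])
dim(m)
```

Under these assumptions, such a function could be passed as distance_function = manhattan_scaled when calling WhatIfClassif$new().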


Method clone()

The objects of this class are cloneable with this method.

Usage

WhatIfClassif$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

if (require("randomForest")) {
  # Train a model
  rf = randomForest(Species ~ ., data = iris)
  # Create a predictor object
  predictor = iml::Predictor$new(rf, type = "prob")
  # Find counterfactuals for x_interest
  wi_classif = WhatIfClassif$new(predictor, n_counterfactuals = 5L)
  cfactuals = wi_classif$find_counterfactuals(
    x_interest = iris[150L, ], desired_class = "versicolor", desired_prob = c(0.5, 1)
  )
  # Print the results
  cfactuals$data
}
#>    Sepal.Length Sepal.Width Petal.Length Petal.Width
#>           <num>       <num>        <num>       <num>
#> 1:          5.9         3.2          4.8         1.8
#> 2:          6.0         2.7          5.1         1.6
#> 3:          5.9         3.0          4.2         1.5
#> 4:          6.7         3.0          5.0         1.7
#> 5:          6.0         2.9          4.5         1.5