WhatIf returns the n_counterfactuals observations most similar to x_interest among those observations in predictor$data$X whose prediction for the desired_class lies in the desired_prob interval.
By default, the dissimilarities are computed using Gower's dissimilarity measure (Gower 1971).
Only observations whose feature values lie between the corresponding values in lower and upper are considered counterfactual candidates.
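The selection logic described above can be sketched in a few lines of base R. This is an illustration under stated assumptions, not the package's implementation: the helpers gower_numeric() and whatif_sketch() are hypothetical names, only numeric features are handled (in which case Gower's dissimilarity reduces to the mean of range-scaled absolute differences), and the predicted probabilities are supplied as a made-up vector rather than taken from a fitted model.

```r
# Hypothetical helper: Gower dissimilarity between one row `x` and every row
# of `y`, for numeric features only. Ranges are taken from the full data.
gower_numeric = function(x, y, data) {
  feats = names(data)
  ranges = vapply(data, function(col) diff(range(col)), numeric(1))
  ranges[ranges == 0] = 1  # avoid division by zero for constant features
  diffs = abs(sweep(as.matrix(y[, feats, drop = FALSE]), 2, unlist(x[1, feats])))
  rowMeans(sweep(diffs, 2, ranges, "/"))
}

# Hypothetical sketch of the WhatIf idea:
# 1. keep rows whose predicted probability lies in desired_prob,
# 2. keep rows inside the optional lower/upper bounds,
# 3. return the n rows closest to x_interest.
whatif_sketch = function(x_interest, X, prob, desired_prob, n = 1L,
                         lower = NULL, upper = NULL) {
  keep = prob >= desired_prob[1] & prob <= desired_prob[2]
  for (f in names(lower)) keep = keep & X[[f]] >= lower[[f]]
  for (f in names(upper)) keep = keep & X[[f]] <= upper[[f]]
  candidates = X[keep, , drop = FALSE]
  d = gower_numeric(x_interest, candidates, X)
  candidates[order(d)[seq_len(min(n, nrow(candidates)))], , drop = FALSE]
}

# Toy data; `prob` stands in for the model's predicted probability of the
# desired class for each row of X (made-up values).
X = data.frame(a = c(0, 1, 2, 3, 4), b = c(10, 10, 20, 20, 30))
prob = c(0.1, 0.6, 0.9, 0.7, 0.4)
x_interest = data.frame(a = 2.2, b = 20)

res = whatif_sketch(x_interest, X, prob, desired_prob = c(0.5, 1), n = 2L)
res  # rows 3 and 4 of X

# With an upper bound on `a`, row 4 (a = 3) is no longer a candidate.
res_bounded = whatif_sketch(x_interest, X, prob, desired_prob = c(0.5, 1),
                            n = 2L, upper = c(a = 2.5))
res_bounded  # rows 3 and 2 of X
```

Passing the bounds as named vectors mirrors the way lower and upper are documented for this class; how the package applies them internally may differ from this sketch.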
Gower, J. C. (1971). A general coefficient of similarity and some of its properties. Biometrics, 27, 857–871.
Wexler, J., Pushkarna, M., Bolukbasi, T., Wattenberg, M., Viégas, F., & Wilson, J. (2019). The what-if tool: Interactive probing of machine learning models. IEEE transactions on visualization and computer graphics, 26(1), 56–65.
counterfactuals::CounterfactualMethod -> counterfactuals::CounterfactualMethodClassif -> WhatIfClassif
new()
Create a new WhatIfClassif object.
WhatIfClassif$new(
  predictor,
  n_counterfactuals = 1L,
  lower = NULL,
  upper = NULL,
  distance_function = "gower"
)
predictor
(Predictor)
The object (created with iml::Predictor$new()) holding the machine learning model and the data.
n_counterfactuals
(integerish(1))
The number of counterfactuals to return. Default is 1L.
lower
(numeric() | NULL)
Vector of minimum values for numeric features.
If NULL (default), the element for each numeric feature in lower is taken as its minimum value in predictor$data$X.
If not NULL, it should be named with the corresponding feature names.
upper
(numeric() | NULL)
Vector of maximum values for numeric features.
If NULL (default), the element for each numeric feature in upper is taken as its maximum value in predictor$data$X.
If not NULL, it should be named with the corresponding feature names.
distance_function
(function() | 'gower' | 'gower_c')
The distance function used to compute the distances between x_interest and the training data points for finding x_nn.
Either the name of an already implemented distance function ('gower' or 'gower_c') or a function.
If set to 'gower' (default), Gower's distance (Gower 1971) is used; if set to 'gower_c', a more efficient C-based version of Gower's distance is used.
A function must have three arguments x, y, and data and should return a double matrix with nrow(x) rows and at most nrow(y) columns.
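As an illustration of the contract for a user-supplied distance_function, the following sketch defines a range-scaled Manhattan (L1) distance. The function name l1_scaled is hypothetical, and the sketch assumes that x, y, and data are data frames (or objects coercible with as.data.frame()) and that only numeric columns should contribute; neither assumption is stated in the documentation above.

```r
# Hypothetical custom distance: range-scaled Manhattan (L1) distance over the
# numeric columns. Takes the three documented arguments x, y, and data and
# returns a double matrix with nrow(x) rows and nrow(y) columns.
l1_scaled = function(x, y, data) {
  # Coerce defensively so that data.frame-style column selection works.
  x = as.data.frame(x)
  y = as.data.frame(y)
  data = as.data.frame(data)

  num_cols = names(data)[vapply(data, is.numeric, logical(1))]
  ranges = vapply(data[, num_cols, drop = FALSE],
                  function(col) diff(range(col)), numeric(1))
  ranges[ranges == 0] = 1  # avoid division by zero for constant features

  xs = sweep(as.matrix(x[, num_cols, drop = FALSE]), 2, ranges, "/")
  ys = sweep(as.matrix(y[, num_cols, drop = FALSE]), 2, ranges, "/")

  # Pairwise distances: row i holds the distances from x[i, ] to every row of y.
  d = matrix(0, nrow(xs), nrow(ys))
  for (i in seq_len(nrow(xs))) {
    d[i, ] = rowSums(abs(sweep(ys, 2, xs[i, ])))
  }
  d
}

# Quick check on iris: one query row against three data rows.
feats = iris[, 1:4]
d = l1_scaled(feats[150, ], feats[1:3, ], feats)
dim(d)
```

Such a function would then be passed as the distance_function argument of WhatIfClassif$new() in place of "gower".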
if (require("randomForest")) {
  # Load the package that provides WhatIfClassif
  library(counterfactuals)
  # Train a model
  rf = randomForest(Species ~ ., data = iris)
  # Create a predictor object
  predictor = iml::Predictor$new(rf, type = "prob")
  # Find counterfactuals for x_interest
  wi_classif = WhatIfClassif$new(predictor, n_counterfactuals = 5L)
  cfactuals = wi_classif$find_counterfactuals(
    x_interest = iris[150L, ], desired_class = "versicolor", desired_prob = c(0.5, 1)
  )
  # Print the results
  cfactuals$data
}
#>    Sepal.Length Sepal.Width Petal.Length Petal.Width
#>           <num>       <num>        <num>       <num>
#> 1:          5.9         3.2          4.8         1.8
#> 2:          6.0         2.7          5.1         1.6
#> 3:          5.9         3.0          4.2         1.5
#> 4:          6.7         3.0          5.0         1.7
#> 5:          6.0         2.9          4.5         1.5