WhatIf returns the n_counterfactuals most similar observations to x_interest from the observations in predictor$data$X whose prediction lies in the desired_outcome interval. Only observations whose feature values lie between the corresponding values in lower and upper are considered counterfactual candidates.
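The selection rule above can be sketched in base R. This is a minimal illustration under stated assumptions, not the counterfactuals package's implementation. The names whatif_sketch() and gower_to_rows() are hypothetical, predictions come from a user-supplied predict_fun, and the distance is a simplified Gower distance: range-normalised absolute differences for numeric features, simple matching for the others, averaged over features.

```r
# Hypothetical helper (not part of counterfactuals): simplified Gower distance
# from the one-row data frame x to every row of y; numeric ranges come from `data`.
gower_to_rows <- function(x, y, data) {
  d <- matrix(0, nrow = nrow(y), ncol = ncol(y))
  for (j in seq_along(y)) {
    col <- names(y)[j]
    if (is.numeric(data[[col]])) {
      rng <- diff(range(data[[col]]))
      # Range-normalised absolute difference; constant features contribute 0
      d[, j] <- if (rng > 0) abs(y[[col]] - x[[col]]) / rng else 0
    } else {
      # Non-numeric features: 0 if equal, 1 otherwise
      d[, j] <- as.numeric(as.character(y[[col]]) != as.character(x[[col]]))
    }
  }
  rowMeans(d)
}

# Hypothetical sketch of the documented WhatIf selection rule
whatif_sketch <- function(predict_fun, X, x_interest, desired_outcome,
                          n_counterfactuals = 1L, lower = NULL, upper = NULL) {
  pred <- predict_fun(X)
  # 1. Keep observations whose prediction lies in the desired interval
  keep <- pred >= desired_outcome[1] & pred <= desired_outcome[2]
  # 2. Keep observations whose numeric features lie within [lower, upper];
  #    bounds that are not supplied default to the observed min / max in X
  num_cols <- names(X)[vapply(X, is.numeric, logical(1))]
  for (col in num_cols) {
    lo <- if (col %in% names(lower)) lower[[col]] else min(X[[col]])
    hi <- if (col %in% names(upper)) upper[[col]] else max(X[[col]])
    keep <- keep & X[[col]] >= lo & X[[col]] <= hi
  }
  cand <- X[keep, , drop = FALSE]
  # 3. Return the n_counterfactuals candidates closest to x_interest
  dist <- gower_to_rows(x_interest, cand, X)
  cand[head(order(dist), n_counterfactuals), , drop = FALSE]
}

# Usage with a linear model on mtcars (mpg is the target)
fit <- lm(mpg ~ wt + hp, data = mtcars)
X <- mtcars[, -1]
res <- whatif_sketch(function(d) predict(fit, newdata = d), X,
                     x_interest = X["Hornet Sportabout", ],
                     desired_outcome = c(22, 26), n_counterfactuals = 3L,
                     lower = c(hp = 60), upper = c(hp = 120))
res
```

Because candidates are drawn from existing observations, every returned counterfactual is a real data point, which is the main design trade-off of WhatIf compared with methods that synthesise new points.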
Gower, J. C. (1971). A general coefficient of similarity and some of its properties. Biometrics, 27(4), 857–871.
Wexler, J., Pushkarna, M., Bolukbasi, T., Wattenberg, M., Viégas, F., & Wilson, J. (2019). The What-If Tool: Interactive probing of machine learning models. IEEE Transactions on Visualization and Computer Graphics, 26(1), 56–65.
Super classes: counterfactuals::CounterfactualMethod -> counterfactuals::CounterfactualMethodRegr -> WhatIfRegr
new()
Create a new WhatIfRegr object.
WhatIfRegr$new(
  predictor,
  n_counterfactuals = 1L,
  lower = NULL,
  upper = NULL,
  distance_function = "gower"
)
predictor (Predictor)
The object (created with iml::Predictor$new()) holding the machine learning model and the data.
n_counterfactuals (integerish(1))
The number of counterfactuals to return. Default is 1L.
lower (numeric() | NULL)
Vector of minimum values for numeric features. If NULL (default), the element for each numeric feature in lower is taken as its minimum value in predictor$data$X. If not NULL, it should be named with the corresponding feature names.
upper (numeric() | NULL)
Vector of maximum values for numeric features. If NULL (default), the element for each numeric feature in upper is taken as its maximum value in predictor$data$X. If not NULL, it should be named with the corresponding feature names.
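The documented defaults for lower and upper can be reproduced with a few lines of base R. This is a sketch, not the package's code: the default bounds follow the description above, while the way a partially specified, named vector is combined with those defaults (the replace() step) is an assumption made for illustration only.

```r
# Features of the mtcars example (mpg is the target)
X <- mtcars[, -1]
num_cols <- names(X)[vapply(X, is.numeric, logical(1))]

# Documented defaults: per-feature minimum / maximum in the data
default_lower <- vapply(X[num_cols], min, numeric(1))
default_upper <- vapply(X[num_cols], max, numeric(1))

# User-supplied bounds are named by feature and may cover only some features
lower <- c(hp = 80, wt = 2)
upper <- c(hp = 150)

# Assumption: supplied values override the data-derived defaults, feature by feature
eff_lower <- replace(default_lower, names(lower), lower)
eff_upper <- replace(default_upper, names(upper), upper)
eff_lower[c("hp", "wt", "cyl")]
```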
distance_function (function() | 'gower' | 'gower_c')
The distance function used to compute the distances between x_interest and the training data points for finding x_nn. Either the name of an already implemented distance function ('gower' or 'gower_c') or a function. If set to 'gower' (default), Gower's distance (Gower 1971) is used; if set to 'gower_c', a more efficient C-based implementation of Gower's distance is used. A function must have three arguments, x, y, and data, and should return a double matrix with nrow(x) rows and at most nrow(y) columns.
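A custom distance function only has to respect the documented contract: arguments x, y, and data, returning a double matrix with nrow(x) rows. The sketch below uses a range-scaled L1 (Manhattan) distance on numeric features; the name scaled_l1 is hypothetical, non-numeric features are simply ignored, and it assumes data is reference data suitable for scaling. It could then be passed as distance_function = scaled_l1.

```r
# Hypothetical custom distance with the documented signature function(x, y, data).
# Assumption: `data` is reference data, used here only to scale numeric features.
scaled_l1 <- function(x, y, data) {
  num <- names(data)[vapply(data, is.numeric, logical(1))]
  rng <- vapply(num, function(col) diff(range(data[[col]])), numeric(1))
  rng[rng == 0] <- 1  # constant features: avoid division by zero
  out <- matrix(0, nrow = nrow(x), ncol = nrow(y))
  for (col in num) {
    # outer() gives an nrow(x) x nrow(y) matrix of pairwise differences
    out <- out + abs(outer(x[[col]], y[[col]], "-")) / rng[[col]]
  }
  out
}

# Distances from the first 2 cars to the first 5 cars: a 2 x 5 double matrix
d <- scaled_l1(mtcars[1:2, -1], mtcars[1:5, -1], mtcars[, -1])
dim(d)
```

Building the result with outer() keeps the row/column layout aligned with x and y, which is exactly the shape the distance_function contract asks for.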
if (require("randomForest")) {
  set.seed(123456)
  # Train a model
  rf = randomForest(mpg ~ ., data = mtcars)
  # Create a predictor object
  predictor = iml::Predictor$new(rf)
  # Find counterfactuals for x_interest
  wi_regr = WhatIfRegr$new(predictor, n_counterfactuals = 5L)
  cfactuals = wi_regr$find_counterfactuals(
    x_interest = mtcars[1L, ], desired_outcome = c(22, 26)
  )
  # Print the results
  cfactuals
}
#> 5 Counterfactual(s)
#>
#> Desired outcome range: [22, 26]
#>
#> Head:
#> cyl disp hp drat wt qsec vs am gear carb
#> <num> <num> <num> <num> <num> <num> <num> <num> <num> <num>
#> 1: 4 121.0 109 4.11 2.78 18.60 1 1 4 2
#> 2: 4 108.0 93 3.85 2.32 18.61 1 1 4 1
#> 3: 4 146.7 62 3.69 3.19 20.00 1 0 4 2