NICE (Brughmans and Martens 2021) searches for counterfactuals by iteratively replacing feature values of x_interest with the corresponding value of its most similar (optionally correctly predicted) instance x_nn.
While the original method is only applicable to classification tasks (see NICEClassif), this implementation extends it to regression tasks.
NICE starts the counterfactual search for x_interest by finding its most similar (optionally correctly predicted) neighbor x_nn whose prediction lies in the interval desired_outcome. Correctly predicted means that the prediction of x_nn is less than a user-specified margin_correct away from the true outcome of x_nn.
This is designed to mimic the search for a correctly classified x_nn in classification tasks.
If no x_nn satisfies this constraint, a warning is returned that no counterfactual could be found.
In the first iteration, NICE creates new instances by replacing a different feature value of x_interest with the corresponding value of x_nn in each new instance. Thus, if x_nn differs from x_interest in d features, d new instances are created.
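The candidate generation step can be sketched in base R as follows. make_candidates is a hypothetical helper, and features are assumed to be comparable with !=. For illustration, x_interest is the first mtcars row (Mazda RX4, as in the example below) and x_nn is assumed to be the Merc 230 row, whose feature values match those swapped in throughout the archive printed below; this identification is inferred from that output.

```r
# Hypothetical helper: one candidate per feature in which x_current differs
# from x_nn; each candidate copies exactly that one feature value from x_nn.
make_candidates = function(x_current, x_nn) {
  feats = names(x_current)
  differs = vapply(feats, function(f) x_current[[f]] != x_nn[[f]], logical(1))
  cands = lapply(feats[differs], function(f) {
    cand = x_current
    cand[[f]] = x_nn[[f]]
    cand
  })
  do.call(rbind, cands)
}

x_interest = mtcars[1, -1]      # Mazda RX4 without the target mpg
x_nn = mtcars["Merc 230", -1]   # assumed x_nn (see above)
iter1 = make_candidates(x_interest, x_nn)
nrow(iter1)  # 9: the two rows differ in every feature except gear
```

Calling make_candidates again on the highest reward row of iter1 yields d-1 = 8 candidates, matching the second iteration of the archive.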
Then, the reward values for the created instances are computed with the chosen reward function. Available reward functions are sparsity, proximity, and plausibility.
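The sparsity rewards in the archive printed at the end of this page appear consistent with the rule below, which is inferred from those numbers rather than taken from the package source: a candidate's reward is how much closer its prediction is to the interval desired_outcome than the prediction of the instance it was derived from (the distance is 0 inside the interval). For example, in iteration 3 every candidate predicted at or above 22 receives the same reward, 0.02428, which is the remaining gap 22 - 21.97572 of the previous best instance. The sketch does not cover proximity or plausibility.

```r
# Distance from a prediction to the interval desired_outcome (0 inside it)
dist_to_interval = function(pred, desired_outcome) {
  pmax(desired_outcome[1] - pred, 0, pred - desired_outcome[2])
}

# Inferred sparsity reward: reduction in that distance from the previous best
# instance to the candidate (each candidate changes exactly one feature)
sparsity_reward = function(pred_candidate, pred_previous, desired_outcome) {
  dist_to_interval(pred_previous, desired_outcome) -
    dist_to_interval(pred_candidate, desired_outcome)
}

# Predictions copied from iterations 2 and 3 of the archive printed below
sparsity_reward(21.97572, 21.61431, c(22, 26))  # about 0.36141
sparsity_reward(22.25882, 21.97572, c(22, 26))  # about 0.02428
```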
In the second iteration, NICE creates d-1 new instances by replacing a different feature value of the highest reward instance of the previous iteration with the corresponding value of x_nn, and so on.
If finish_early = TRUE, the algorithm terminates when the predicted outcome for the highest reward instance is in the interval desired_outcome; if finish_early = FALSE, the algorithm continues until x_nn is recreated.
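The greedy iteration and the finish_early stopping rule can be sketched in base R as follows. This is a simplified illustration, not the package's implementation: nice_sketch is a hypothetical name, features are assumed numeric, predict_fun is any function returning one prediction per row, and the reward is assumed to be the reduction in distance between the prediction and desired_outcome.

```r
# Hypothetical sketch of the NICE loop for regression (sparsity-style reward)
nice_sketch = function(x_interest, x_nn, predict_fun, desired_outcome,
                       finish_early = TRUE) {
  in_range = function(p) p >= desired_outcome[1] & p <= desired_outcome[2]
  dist_int = function(p) pmax(desired_outcome[1] - p, 0, p - desired_outcome[2])
  archive = list()
  x_best = x_interest
  repeat {
    # Features in which the current best instance still differs from x_nn
    feats = names(x_best)[vapply(names(x_best),
      function(f) x_best[[f]] != x_nn[[f]], logical(1))]
    if (length(feats) == 0) break  # x_nn has been recreated
    # One candidate per differing feature, copying that value from x_nn
    cands = do.call(rbind, lapply(feats, function(f) {
      cand = x_best
      cand[[f]] = x_nn[[f]]
      cand
    }))
    cands$pred = predict_fun(cands)
    cands$reward = dist_int(predict_fun(x_best)) - dist_int(cands$pred)
    archive[[length(archive) + 1]] = cands
    # Greedy step: continue from the highest reward candidate
    best = which.max(cands$reward)
    x_best = cands[best, names(x_interest), drop = FALSE]
    if (finish_early && in_range(cands$pred[best])) break
  }
  list(archive = archive, x_best = x_best)
}

# Toy model in which feature c has no effect on the prediction
f = function(d) d$a + 2 * d$b
xi = data.frame(a = 0, b = 0, c = 0)
xn = data.frame(a = 1, b = 1, c = 1)
res = nice_sketch(xi, xn, f, desired_outcome = c(2, 10))
length(res$archive)  # 1: changing b alone already reaches the interval
```

With finish_early = FALSE the same call runs three iterations and ends with x_best equal to xn. With the example's random forest, predict_fun could be function(d) predict(rf, newdata = d).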
Once the algorithm has terminated, return_multiple determines which instances are returned as counterfactuals: if return_multiple = FALSE, only the highest reward instance in the last iteration is returned as counterfactual; if return_multiple = TRUE, all instances (of all iterations) whose predicted outcome is in the interval desired_outcome are returned as counterfactuals.
If finish_early = FALSE and return_multiple = FALSE, x_nn is returned as the single counterfactual.
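The selection rule above can be sketched as a filter over the archive. select_counterfactuals is a hypothetical helper that assumes each archive element has pred and reward columns, as in the archive printed at the end of this page; the toy archive reuses a few pred and reward values from that output. Ties are broken by which.max (first maximum), which is an assumption, since the tie-breaking rule of the package is not documented here.

```r
# Hypothetical helper: pick counterfactuals from a list of per-iteration
# data frames that each carry pred and reward columns
select_counterfactuals = function(archive, desired_outcome, return_multiple) {
  if (return_multiple) {
    # All instances of all iterations whose prediction is in the interval
    all_cands = do.call(rbind, archive)
    keep = all_cands$pred >= desired_outcome[1] &
      all_cands$pred <= desired_outcome[2]
    all_cands[keep, , drop = FALSE]
  } else {
    # Only the highest reward instance of the last iteration
    last = archive[[length(archive)]]
    last[which.max(last$reward), , drop = FALSE]
  }
}

arch = list(
  data.frame(pred = c(21.61431, 20.91425), reward = c(0.99778, 0.29772)),
  data.frame(pred = c(22.25882, 21.97262, 22.06255),
             reward = c(0.02428, -0.003097, 0.02428))
)
select_counterfactuals(arch, c(22, 26), return_multiple = TRUE)   # 2 rows
select_counterfactuals(arch, c(22, 26), return_multiple = FALSE)  # 1 row
```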
By default (distance_function = 'gower'), dissimilarities are computed with Gower's dissimilarity measure (Gower 1971).
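Gower's dissimilarity averages per-feature dissimilarities: the range-scaled absolute difference for numeric features and a 0/1 mismatch for categorical ones. The minimal sketch below (gower_dist is a hypothetical name, not the package's or the gower package's function) ignores missing values and feature weights, which Gower's original coefficient allows. It takes the ranges from data and uses the (x, y, data) signature described for custom distance functions under distance_function.

```r
# Hypothetical Gower dissimilarity: returns an nrow(x) by nrow(y) matrix
gower_dist = function(x, y, data) {
  out = matrix(0, nrow(x), nrow(y))
  feats = names(data)
  for (f in feats) {
    if (is.numeric(data[[f]])) {
      # Numeric feature: absolute difference scaled by the range in data
      rng = diff(range(data[[f]]))
      if (rng == 0) rng = 1
      part = abs(outer(x[[f]], y[[f]], "-")) / rng
    } else {
      # Categorical feature: 0 if equal, 1 otherwise
      part = outer(as.character(x[[f]]), as.character(y[[f]]), "!=") * 1
    }
    out = out + part
  }
  out / length(feats)  # average over features
}

d = data.frame(num = c(0, 10), cat = c("a", "b"))
gower_dist(data.frame(num = 5, cat = "b"), d, d)  # 0.75 and 0.25
```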
Brughmans, D., & Martens, D. (2021). NICE: An algorithm for nearest instance counterfactual explanations. arXiv:2104.07411v2.
Gower, J. C. (1971). A general coefficient of similarity and some of its properties. Biometrics, 27, 623–637.
counterfactuals::CounterfactualMethod -> counterfactuals::CounterfactualMethodRegr -> NICERegr
x_nn (data.table)
The most similar (optionally correctly predicted) instance of x_interest.
archive (list())
A list that stores the history of the algorithm run. For each algorithm iteration, it has one element containing a data.table, which stores all created instances of this iteration together with their reward values and their predictions.
new()
Create a new NICERegr object.
NICERegr$new(
predictor,
optimization = "sparsity",
x_nn_correct = TRUE,
margin_correct = NULL,
return_multiple = FALSE,
finish_early = TRUE,
distance_function = "gower"
)
predictor (Predictor)
The object (created with iml::Predictor$new()) holding the machine learning model and the data.
optimization (character(1))
The reward function to optimize. Can be sparsity (default), proximity, or plausibility.
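x_nn_correct (logical(1))
Should only correctly predicted data points in predictor$data$X be considered for the most similar instance search? A data point counts as correctly predicted if its prediction is less than margin_correct away from its true outcome. Default is TRUE.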
margin_correct (numeric(1) | NULL)
The accepted margin for considering a prediction as "correct". Ignored if x_nn_correct = FALSE. If NULL, the accepted margin is set to half the median absolute distance between the true and predicted outcomes in the data (predictor$data).
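The default margin can be sketched in base R as follows. A linear model on mtcars stands in for the random forest of the example below; this substitution is an assumption made only to keep the sketch dependency-free, and the package computes the margin from predictor$data rather than from these variables.

```r
# Stand-in model (assumption): linear regression instead of a random forest
fit = lm(mpg ~ ., data = mtcars)
abs_err = abs(mtcars$mpg - predict(fit, newdata = mtcars))

# Default margin_correct: half the median absolute distance between
# true and predicted outcomes
margin = 0.5 * median(abs_err)

# Points that count as correctly predicted under this margin
correct = abs_err < margin
sum(correct)
```

Because the margin is half the median absolute error, at most half of the data points can count as correctly predicted, so x_nn_correct = TRUE always shrinks the candidate pool noticeably.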
return_multiple (logical(1))
Should multiple counterfactuals be returned? If TRUE, the algorithm returns all created instances whose prediction is in the interval desired_outcome. For more information, see the Details section.
finish_early (logical(1))
Should the algorithm terminate after an iteration in which the prediction for the highest reward instance is in the interval desired_outcome? If FALSE, the algorithm continues until x_nn is recreated.
distance_function (function() | 'gower' | 'gower_c')
The distance function used to compute the distances between x_interest and the training data points for finding x_nn. If optimization is set to proximity, the distance function is also used for calculating the distance between candidates and x_interest.
Either the name of an already implemented distance function ('gower' or 'gower_c') or a function is allowed as input.
If set to 'gower' (default), Gower's distance (Gower 1971) is used; if set to 'gower_c', a more efficient C-based implementation of Gower's distance is used.
A function must have three arguments, x, y, and data, and should return a double matrix with nrow(x) rows and at most nrow(y) columns.
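A custom distance function following the required (x, y, data) signature might look like the sketch below. scaled_euclidean is a hypothetical example, not a function shipped with the package: it uses only numeric features and scales each by its standard deviation in data.

```r
# Hypothetical custom distance: Euclidean distance on numeric features, each
# standardised by its standard deviation in `data`; returns a double matrix
# with nrow(x) rows and nrow(y) columns
scaled_euclidean = function(x, y, data) {
  num = names(data)[vapply(data, is.numeric, logical(1))]
  s = vapply(data[num], sd, numeric(1))
  s[is.na(s) | s == 0] = 1  # guard against constant features
  xs = sweep(as.matrix(x[num]), 2, s, "/")
  ys = sweep(as.matrix(y[num]), 2, s, "/")
  # Squared distances via |a|^2 + |b|^2 - 2 a.b, clipped at 0 for rounding
  d2 = outer(rowSums(xs^2), rowSums(ys^2), "+") - 2 * xs %*% t(ys)
  d2[d2 < 0] = 0
  sqrt(d2)
}

dat = data.frame(a = c(0, 2, 4), b = c(1, 1, 1))
scaled_euclidean(dat[1, ], dat, dat)  # 1 x 3 matrix: 0, 1, 2
```

Following the argument description above, it could then be passed as NICERegr$new(predictor, distance_function = scaled_euclidean).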
library(counterfactuals)
if (require("randomForest")) {
set.seed(123456)
# Train a model
rf = randomForest(mpg ~ ., data = mtcars)
# Create a predictor object
predictor = iml::Predictor$new(rf)
# Find counterfactuals
nice_regr = NICERegr$new(predictor)
cfactuals = nice_regr$find_counterfactuals(
x_interest = mtcars[1L, ], desired_outcome = c(22, 26)
)
# Print the results
cfactuals$data
# Print archive
nice_regr$archive
}
#> [[1]]
#> cyl disp hp drat wt qsec vs am gear carb reward
#> <num> <num> <num> <num> <num> <num> <num> <num> <num> <num> <num>
#> 1: 4 160.0 110 3.90 2.62 16.46 0 1 4 4 0.99778000
#> 2: 6 140.8 110 3.90 2.62 16.46 0 1 4 4 0.29772333
#> 3: 6 160.0 95 3.90 2.62 16.46 0 1 4 4 0.32473333
#> 4: 6 160.0 110 3.92 2.62 16.46 0 1 4 4 -0.02246333
#> 5: 6 160.0 110 3.90 3.15 16.46 0 1 4 4 -0.08577333
#> 6: 6 160.0 110 3.90 2.62 22.90 0 1 4 4 0.09450190
#> 7: 6 160.0 110 3.90 2.62 16.46 1 1 4 4 0.13947333
#> 8: 6 160.0 110 3.90 2.62 16.46 0 0 4 4 -0.12199833
#> 9: 6 160.0 110 3.90 2.62 16.46 0 1 4 2 0.34425000
#> pred
#> <num>
#> 1: 21.61431
#> 2: 20.91425
#> 3: 20.94126
#> 4: 20.59406
#> 5: 20.53075
#> 6: 20.71103
#> 7: 20.75600
#> 8: 20.49453
#> 9: 20.96078
#>
#> [[2]]
#> cyl disp hp drat wt qsec vs am gear carb reward
#> <num> <num> <num> <num> <num> <num> <num> <num> <num> <num> <num>
#> 1: 4 140.8 110 3.90 2.62 16.46 0 1 4 4 0.276123333
#> 2: 4 160.0 95 3.90 2.62 16.46 0 1 4 4 0.361413333
#> 3: 4 160.0 110 3.92 2.62 16.46 0 1 4 4 -0.006596667
#> 4: 4 160.0 110 3.90 3.15 16.46 0 1 4 4 -0.044490000
#> 5: 4 160.0 110 3.90 2.62 22.90 0 1 4 4 0.078001905
#> 6: 4 160.0 110 3.90 2.62 16.46 1 1 4 4 0.167953333
#> 7: 4 160.0 110 3.90 2.62 16.46 0 0 4 4 -0.097881667
#> 8: 4 160.0 110 3.90 2.62 16.46 0 1 4 2 0.328353333
#> pred
#> <num>
#> 1: 21.89043
#> 2: 21.97572
#> 3: 21.60771
#> 4: 21.56982
#> 5: 21.69231
#> 6: 21.78226
#> 7: 21.51642
#> 8: 21.94266
#>
#> [[3]]
#> cyl disp hp drat wt qsec vs am gear carb reward
#> <num> <num> <num> <num> <num> <num> <num> <num> <num> <num> <num>
#> 1: 4 140.8 95 3.90 2.62 16.46 0 1 4 4 0.024280000
#> 2: 4 160.0 95 3.92 2.62 16.46 0 1 4 4 -0.003096667
#> 3: 4 160.0 95 3.90 3.15 16.46 0 1 4 4 -0.048013333
#> 4: 4 160.0 95 3.90 2.62 22.90 0 1 4 4 0.024280000
#> 5: 4 160.0 95 3.90 2.62 16.46 1 1 4 4 0.024280000
#> 6: 4 160.0 95 3.90 2.62 16.46 0 0 4 4 -0.090591667
#> 7: 4 160.0 95 3.90 2.62 16.46 0 1 4 2 0.024280000
#> pred
#> <num>
#> 1: 22.25882
#> 2: 21.97262
#> 3: 21.92771
#> 4: 22.06255
#> 5: 22.15771
#> 6: 21.88513
#> 7: 22.31211
#>