This family of functions revolves around selecting a column of data to use
for case weights. This column must be one of the allowed case weight types,
such as hardhat::frequency_weights() or hardhat::importance_weights().
Specifically, it must return TRUE from hardhat::is_case_weights(). The
underlying model decides whether the type of case weights you have supplied
is applicable.
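As a minimal sketch of that check, the snippet below builds both weight types from made-up vectors (the values are arbitrary and only for illustration) and confirms that classed weights pass hardhat::is_case_weights() while a bare numeric vector does not.

```r
library(hardhat)

# Illustrative weights; the numbers themselves are arbitrary
freq <- frequency_weights(c(1L, 2L, 3L))
imp <- importance_weights(c(0.5, 1.5, 2))

is_case_weights(freq)
#> [1] TRUE
is_case_weights(imp)
#> [1] TRUE

# A bare numeric vector is not a classed case weights column
is_case_weights(c(1, 2, 3))
#> [1] FALSE
```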
- add_case_weights() specifies the column that will be interpreted as case weights in the model. This column must be present in the data supplied to fit().
- remove_case_weights() removes the case weights. Additionally, if the model has already been fit, then the fit is removed.
- update_case_weights() first removes the case weights, then replaces them with the new ones.
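The three verbs can be sketched together as follows. This is an illustration rather than part of the package's own examples: the data frame `cars_w` and the choice of `gear` and `carb` as weight columns are assumptions made only to have two classed columns to switch between.

```r
library(workflows)
library(hardhat)

# Hypothetical data with two candidate case weights columns
cars_w <- mtcars
cars_w$gear <- frequency_weights(cars_w$gear)
cars_w$carb <- frequency_weights(cars_w$carb)

# Start with `gear` as the case weights column
wf <- workflow()
wf <- add_case_weights(wf, gear)

# Swap to `carb`: the old specification is removed, then the new one is added
wf <- update_case_weights(wf, carb)

# Drop the case weights specification entirely
wf_none <- remove_case_weights(wf)
```

Only the column name is recorded at this stage; the column itself is looked up in the data when fit() is called.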
Arguments
- x
A workflow
- col
A single unquoted column name specifying the case weights for the model. This must be a classed case weights column, as determined by hardhat::is_case_weights().
Details
For formula and variable preprocessors, the case weights col is removed
from the data before the preprocessor is evaluated. This allows you to use
formulas like y ~ . or tidyselection like everything() without fear of
accidentally selecting the case weights column.
For recipe preprocessors, the case weights col is not removed and is
passed along to the recipe. Typically, your recipe will include steps that
can utilize case weights.
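To illustrate what the recipe sees, the sketch below builds a recipe directly on data that contains a classed weights column and inspects the variable roles. It assumes a version of the recipes package with case weights support, where such a column is expected to be given a dedicated "case_weights" role instead of being treated as a predictor; treat that role name as an assumption to verify against your installed version.

```r
library(recipes)
library(hardhat)

mtcars2 <- mtcars
mtcars2$gear <- frequency_weights(mtcars2$gear)

# `gear` stays in the data that the recipe receives
rec <- recipe(mpg ~ ., data = mtcars2)

# Inspect the role assigned to the weights column
info <- summary(rec)
info[info$variable == "gear", c("variable", "role")]
```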
Examples
library(workflows)
library(parsnip)
library(magrittr)
library(hardhat)

# Convert `gear` into a classed frequency weights column
mtcars2 <- mtcars
mtcars2$gear <- frequency_weights(mtcars2$gear)

spec <- linear_reg() %>%
  set_engine("lm")

wf <- workflow() %>%
  add_case_weights(gear) %>%
  add_formula(mpg ~ .) %>%
  add_model(spec)

wf <- fit(wf, mtcars2)

# Notice that the case weights (gear) aren't included in the predictors
extract_mold(wf)$predictors
#> # A tibble: 32 × 10
#>    `(Intercept)`   cyl  disp    hp  drat    wt  qsec    vs    am  carb
#>            <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1             1     6  160    110  3.9   2.62  16.5     0     1     4
#>  2             1     6  160    110  3.9   2.88  17.0     0     1     4
#>  3             1     4  108     93  3.85  2.32  18.6     1     1     1
#>  4             1     6  258    110  3.08  3.22  19.4     1     0     1
#>  5             1     8  360    175  3.15  3.44  17.0     0     0     2
#>  6             1     6  225    105  2.76  3.46  20.2     1     0     1
#>  7             1     8  360    245  3.21  3.57  15.8     0     0     4
#>  8             1     4  147.    62  3.69  3.19  20       1     0     2
#>  9             1     4  141.    95  3.92  3.15  22.9     1     0     2
#> 10             1     6  168.   123  3.92  3.44  18.3     1     0     4
#> # ℹ 22 more rows

# Strip them out of the workflow, which also resets the model
remove_case_weights(wf)
#> ══ Workflow ══════════════════════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: linear_reg()
#>
#> ── Preprocessor ──────────────────────────────────────────────────────────
#> mpg ~ .
#>
#> ── Model ─────────────────────────────────────────────────────────────────
#> Linear Regression Model Specification (regression)
#>
#> Computational engine: lm
#>