generics::augment() method for nested models. augment.nested_model_fit() adds column(s) containing predictions to the given data.
Usage
# S3 method for nested_model_fit
augment(x, new_data, ...)
Arguments
- x
A nested_model_fit object produced by fit.nested_model().
- new_data
A data frame, which can be nested or non-nested.
- ...
Passed on to parsnip::augment.model_fit().
Examples
library(dplyr)
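library(tidyr)
library(parsnip)
library(nestedmodels)
# Use only the first five groups of the example data
data <- filter(example_nested_data, id %in% 1:5)
# Nest every column except the grouping columns id and id2
nested_data <- nest(data, data = -c(id, id2))
# A linear regression, fitted separately to each nested data frame
model <- linear_reg() %>%
  set_engine("lm") %>%
  nested()
fitted <- fit(model, z ~ x + y + a + b, nested_data)
# Augment the full, non-nested example data with predictions
augment(fitted, example_nested_data)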
#> Warning: Some predictions failed.
#> # A tibble: 1,000 × 7
#> .pred .resid x y z a b
#> <dbl> <dbl> <int> <dbl> <dbl> <dbl> <dbl>
#> 1 24.5 4.55 49 48.5 29.1 44.7 50.0
#> 2 23.4 6.22 50 64.2 29.7 40.2 64.9
#> 3 25.0 1.64 51 -19.4 26.6 43.2 38.0
#> 4 25.3 3.56 52 41.0 28.8 66.4 61.7
#> 5 26.4 -2.45 53 -94.2 23.9 18.2 -1.66
#> 6 29.0 1.02 54 72.6 30.0 83.8 38.8
#> 7 27.3 -3.29 55 -91.5 24.0 91.7 40.7
#> 8 26.2 -0.651 56 -50.5 25.5 79.8 55.4
#> 9 28.4 2.25 57 90.3 30.6 50.3 33.8
#> 10 27.6 0.954 58 32.4 28.6 25.4 20.5
#> # ℹ 990 more rows
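The example above fits models for groups 1 to 5 only but augments all of example_nested_data, which presumably explains the "Some predictions failed" warning. The following is a minimal sketch, assuming the nestedmodels package and its example_nested_data are available, that passes only the fitted groups as new_data: once non-nested and once nested, since both forms are accepted. The layout returned for nested input is not shown in this page, so inspect aug_nested rather than relying on a particular shape.

```r
library(dplyr)
library(tidyr)
library(parsnip)
library(nestedmodels)

# Fit one linear model per group, using only groups 1 to 5
data <- filter(example_nested_data, id %in% 1:5)
nested_data <- nest(data, data = -c(id, id2))

fitted <- linear_reg() %>%
  set_engine("lm") %>%
  nested() %>%
  fit(z ~ x + y + a + b, nested_data)

# Non-nested new_data containing only groups that have a fitted model
aug_flat <- augment(fitted, data)

# Nested new_data, in the same shape that was used for fitting
aug_nested <- augment(fitted, nested_data)
```

Restricting new_data to the fitted groups means every row has a matching model, so no predictions should fail for lack of one.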