
generics::augment() method for nested models. augment.nested_model_fit() adds one or more columns of predictions to the given data.
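The idea behind augmenting with nested models can be shown with a small base-R sketch. One model is fitted per group, and each row of the new data is predicted by the model for its own group. The helper `augment_by_group()` and the toy data are hypothetical and exist only for illustration. This is not how nestedmodels implements augment(). The sketch also assumes that rows whose group has no fitted model simply receive `NA` predictions.

```r
# Hypothetical helper, for illustration only: NOT the nestedmodels
# implementation of augment(), just the same idea using base R and lm().
augment_by_group <- function(models, new_data, group, outcome) {
  pieces <- split(new_data, new_data[[group]])
  out <- lapply(names(pieces), function(g) {
    piece <- pieces[[g]]
    model <- models[[g]]  # NULL when no model was fitted for this group
    if (is.null(model)) {
      # Assumption: rows without a matching model get missing predictions.
      piece$.pred <- NA_real_
    } else {
      piece$.pred <- unname(predict(model, newdata = piece))
    }
    # Residuals can only be computed when the outcome column is present.
    if (outcome %in% names(piece)) {
      piece$.resid <- piece[[outcome]] - piece$.pred
    }
    piece
  })
  # Pieces are recombined group by group, so rows come back sorted by group.
  result <- do.call(rbind, out)
  rownames(result) <- NULL
  result
}

# Toy training data with two groups that follow different linear trends.
set.seed(1)
train <- data.frame(id = rep(1:2, each = 20), x = runif(40))
train$z <- ifelse(train$id == 1, 2 * train$x, -3 * train$x) + rnorm(40, sd = 0.1)

# One lm() per group, stored in a list named by group id ("1", "2").
models <- lapply(split(train, train$id), function(d) lm(z ~ x, data = d))

# Group 3 has no fitted model, so its row ends up with NA in .pred and .resid.
new_data <- data.frame(id = c(1, 2, 3), x = c(0.5, 0.5, 0.5), z = c(1, -1.5, 0))
augment_by_group(models, new_data, group = "id", outcome = "z")
```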

Usage

# S3 method for nested_model_fit
augment(x, new_data, ...)

Arguments

x

A nested_model_fit object produced by fit.nested_model().

new_data

A data frame, which can be either nested or non-nested.

...

Passed on to parsnip::augment.model_fit().
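The following sketch uses made-up data to show what the two accepted shapes of `new_data` look like. A non-nested data frame has one row per observation. Its nested equivalent, built with the same tidyr::nest() call style used in the Examples, has one row per group, with the remaining columns stored in a list-column. The sketch does not call augment() itself. It only shows that both forms carry the same rows.

```r
library(tidyr)

# Made-up non-nested data: one row per observation, grouped by `id`.
flat <- data.frame(id = rep(1:3, each = 2), x = 1:6, z = c(2, 4, 5, 7, 9, 11))

# Nested form: one row per id, with x and z packed into a list-column `data`.
nested <- nest(flat, data = -id)

# unnest() recovers the non-nested form from the nested one.
unnested <- unnest(nested, cols = data)
```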

Value

A data frame with one or more added columns for predictions.

Examples


library(dplyr)
#> 
#> Attaching package: ‘dplyr’
#> The following objects are masked from ‘package:stats’:
#> 
#>     filter, lag
#> The following objects are masked from ‘package:base’:
#> 
#>     intersect, setdiff, setequal, union
library(tidyr)
library(parsnip)

data <- filter(example_nested_data, id %in% 1:5)

nested_data <- nest(data, data = -c(id, id2))

model <- linear_reg() %>%
  set_engine("lm") %>%
  nested()

fitted <- fit(model, z ~ x + y + a + b, nested_data)

augment(fitted, example_nested_data)
#> Warning: Some predictions failed.
#> # A tibble: 1,000 × 7
#>    .pred .resid     x     y     z     a     b
#>    <dbl>  <dbl> <int> <dbl> <dbl> <dbl> <dbl>
#>  1  24.5  4.55     49  48.5  29.1  44.7 50.0 
#>  2  23.4  6.22     50  64.2  29.7  40.2 64.9 
#>  3  25.0  1.64     51 -19.4  26.6  43.2 38.0 
#>  4  25.3  3.56     52  41.0  28.8  66.4 61.7 
#>  5  26.4 -2.45     53 -94.2  23.9  18.2 -1.66
#>  6  29.0  1.02     54  72.6  30.0  83.8 38.8 
#>  7  27.3 -3.29     55 -91.5  24.0  91.7 40.7 
#>  8  26.2 -0.651    56 -50.5  25.5  79.8 55.4 
#>  9  28.4  2.25     57  90.3  30.6  50.3 33.8 
#> 10  27.6  0.954    58  32.4  28.6  25.4 20.5 
#> # ℹ 990 more rows