parsnip::multi_predict() method for nested models. It makes predictions for several sub-models (for example, several penalty values) from a single fitted model object.

Usage

# S3 method for nested_model_fit
multi_predict(object, new_data, ...)

Arguments

object

A nested_model_fit object produced by fit.nested_model().

new_data

A data frame. It can be nested or non-nested.

...

Passed on to parsnip::multi_predict().

Value

A tibble with one row for each row of new_data, counted after new_data has been unnested.

Examples


library(dplyr)
library(tidyr)
library(parsnip)
library(glmnet)
library(nestedmodels) # assumed source of nested() and example_nested_data

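# Fit on a subset of the example data
data <- filter(example_nested_data, id %in% 16:20)

# Nest every column except id2, giving one row per value of id2
nested_data <- nest(data, data = -id2)

# Wrap a glmnet linear regression so it is fitted once per nested row
model <- linear_reg(penalty = 1) %>%
  set_engine("glmnet") %>%
  nested()

fitted <- fit(model, z ~ x + y + a + b, nested_data)

# Predict on the full, non-nested data at three penalty values
multi_predict(fitted, example_nested_data,
  penalty = c(0.1, 0.2, 0.3)
)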
#> Warning: Some predictions failed.
#> # A tibble: 1,000 × 1
#>    .pred           
#>    <list>          
#>  1 <tibble [3 × 2]>
#>  2 <tibble [3 × 2]>
#>  3 <tibble [3 × 2]>
#>  4 <tibble [3 × 2]>
#>  5 <tibble [3 × 2]>
#>  6 <tibble [3 × 2]>
#>  7 <tibble [3 × 2]>
#>  8 <tibble [3 × 2]>
#>  9 <tibble [3 × 2]>
#> 10 <tibble [3 × 2]>
#> # ℹ 990 more rows
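
The result above keeps one row per row of new_data, with the per-penalty predictions stored in the .pred list-column. To work with them as ordinary columns, the list-column can be expanded with tidyr::unnest(). The sketch below uses a small hand-built stand-in rather than real model output: the inner column names (penalty and .pred) are an assumption based on the 3 × 2 shape shown above, and the numbers and the .row helper column are made up for illustration.

```r
library(tibble)
library(tidyr)

# Hypothetical stand-in for multi_predict() output: one row per row of
# new_data, each holding a tibble with one row per penalty value.
# The inner column names and all values are assumptions, not real output.
preds <- tibble(
  .pred = list(
    tibble(penalty = c(0.1, 0.2, 0.3), .pred = c(1.0, 1.1, 1.2)),
    tibble(penalty = c(0.1, 0.2, 0.3), .pred = c(2.0, 2.1, 2.2))
  )
)

# Record which new_data row each prediction belongs to (illustrative
# helper column), so the link survives the unnesting step.
preds$.row <- seq_len(nrow(preds))

# Expand the list-column: one row per (new_data row, penalty) pair.
long <- unnest(preds, cols = .pred)

long
```

With two input rows and three penalty values, the unnested tibble has six rows, and each row carries its penalty, its prediction, and the index of the new_data row it came from.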