class: ur-title, center, middle, title-slide

.title[
# BST430 Lecture 16
]

.subtitle[
## Multivariate linear models
]

.author[
### Tanzy Love, based on the course by Andrew McDavid
]

.institute[
### U of Rochester
]

.date[
### 2021-11-07 (updated: 2024-11-04 by TL)
]

---

# Agenda

0. Multivariate plots
1. More modeling syntax
2. Diagnostics
3. Interpreting interaction models
4. Many models

Here's the [R code in this lecture](l16/l16-multivariate-lm.R)

---
class: code70

### `mtcars`

- For this lesson, we will use the (infamous) `mtcars` dataset that comes with R by default.

``` r
library(tidyverse)
library(tidymodels)
```

``` r
library(broom)
data("mtcars")
glimpse(mtcars)
```

```
## Rows: 32
## Columns: 11
## $ mpg  <dbl> 21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24.4, 22…
## $ cyl  <dbl> 6, 6, 4, 6, 8, 6, 8, 4, 4, 6, 6, 8, 8, 8, 8, 8, 8,…
## $ disp <dbl> 160.0, 160.0, 108.0, 258.0, 360.0, 225.0, 360.0, 1…
## $ hp   <dbl> 110, 110, 93, 110, 175, 105, 245, 62, 95, 123, 123…
## $ drat <dbl> 3.90, 3.90, 3.85, 3.08, 3.15, 2.76, 3.21, 3.69, 3.…
## $ wt   <dbl> 2.620, 2.875, 2.320, 3.215, 3.440, 3.460, 3.570, 3…
## $ qsec <dbl> 16.46, 17.02, 18.61, 19.44, 17.02, 20.22, 15.84, 2…
## $ vs   <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,…
## $ am   <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ gear <dbl> 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3,…
## $ carb <dbl> 4, 4, 1, 1, 2, 1, 4, 2, 2, 4, 4, 3, 3, 3, 4, 4, 4,…
```

???

The data were extracted from the 1974 Motor Trend US magazine, and comprise fuel consumption and 10 aspects of automobile design and performance for 32 automobiles.

[, 1]  mpg   Miles/(US) gallon
[, 2]  cyl   Number of cylinders
[, 3]  disp  Displacement (cu.in.)
[, 4]  hp    Gross horsepower
[, 5]  drat  Rear axle ratio
[, 6]  wt    Weight (1000 lbs)
[, 7]  qsec  1/4 mile time
[, 8]  vs    Engine (0 = V-shaped, 1 = straight)
[, 9]  am    Transmission (0 = automatic, 1 = manual)
[,10]  gear  Number of forward gears
[,11]  carb  Number of carburetors

---
class: middle

.hand[Visualizing multivariate relationships]

---

### Exploratory data analyses of multivariate data

- Hard
- Necessary
- Often generates quite a few plots before you identify one that you can keep

---

### One simple trick

Pairs plots!
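`GGally::ggpairs()` draws every pairwise scatterplot, along with correlations and marginal densities, in a single call. If the full matrix is too busy, you can subset columns first -- a minimal sketch (the column choice here is purely illustrative):

``` r
# keep the matrix legible by selecting a handful of variables
GGally::ggpairs(dplyr::select(mtcars, mpg, wt, hp, cyl))
```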
.panelset[
.panel[.panel-name[Code]

``` r
library(GGally)
GGally::ggpairs(mtcars)
```
]
.panel[.panel-name[Plot]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-4-1.png" width="70%" style="display: block; margin: auto;" />
]
]

---

* Evidently `vs`, `am`, `gear`, and maybe `cyl` should be cast to factors

.panelset[
.panel[.panel-name[Code]

``` r
mtcars = mtcars %>% mutate(across(c(vs, am, gear, cyl), factor))
GGally::ggpairs(mtcars)
```
]
.panel[.panel-name[Plot]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-5-1.png" width="70%" style="display: block; margin: auto;" />
]
]

---

- Let's suppose we wanted to determine which variables affect fuel consumption (the `mpg` variable in the dataset).
- `cyl`, `disp`, `hp`, `drat`, `wt` all appear to be correlated with `mpg`
- `vs`, `am`, and possibly `gear` could have distributional differences as well

---
class: code70

- To begin, we'll look at the association between log-weight and mpg

``` r
ggplot(mtcars, aes(x = wt, mpg)) +
  geom_point() +
  scale_x_continuous(transform = "log") +
  xlab("Weight") +
  ylab("Miles Per Gallon")
```

<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-6-1.png" width="60%" style="display: block; margin: auto;" />

- Log-weight appears to be negatively associated with mpg.
- The data approximately fall on a line.

---

### OLS fits in R

1. Make sure you have the explanatory variables in the format you want:

``` r
submt = mtcars %>% mutate(logwt = log(wt))
```

2. Use `lm()`:

``` r
lmout = lm(mpg ~ logwt, data = submt)
lmtide = tidy(lmout)
select(lmtide, term, estimate)
```

```
## # A tibble: 2 × 2
##   term        estimate
##   <chr>          <dbl>
## 1 (Intercept)     39.3
## 2 logwt          -17.1
```

---

### `lm` syntax

```r
lm(response ~ pred1 + pred2*pred3, data = data)
```

Finds the OLS estimates of the following model:

> response = `\(\beta_0+ \beta_1\text{pred1} + \beta_2\text{pred2} + \beta_{3}\text{pred3}+ \beta_{23}\text{pred2} * \text{pred3}\)` + error

The `data` argument tells `lm()` where to find the response and explanatory variables.

---

### Formula syntax

* `x:z` forms the interaction between `x` and `z` -- element-by-element multiplication if at least one of `x` or `z` is continuous, otherwise an outer product.
* `x*z` forms the interaction and includes the main effects. Equivalent to `x + z + x:z`.
* `+ 0` or `- 1` excludes the intercept term; `- x` excludes `x` if it would otherwise be included.
* `(x + z + u)^2` forms all two-way interactions and main effects of `x`, `z`, and `u`.
* `I(x - 10)` or `I(x^2)` -- performs arithmetic that uses `+`, `-`, `*`, `^`, etc., on a variable "on-the-fly" in the formula.
  * better than transforming the variable before you model
* `y ~ .` uses every variable in the data, except the response `y`.

---
class: code50

### Example

.alert[But don't do this without thinking!]

``` r
tidy(lm(mpg ~ ., data = mtcars))
```

```
## # A tibble: 13 × 5
##    term        estimate std.error statistic p.value
##    <chr>          <dbl>     <dbl>     <dbl>   <dbl>
##  1 (Intercept)  15.1      17.1        0.881   0.389
##  2 cyl6         -1.20      2.39      -0.502   0.621
##  3 cyl8          3.05      4.83       0.633   0.535
##  4 disp          0.0126    0.0177     0.708   0.487
## # ℹ 9 more rows
```

For a predictive goal, you should probably use a method that does shrinkage and basis expansion. For inference, you need to consider putative causal models.

---

- Use the `broom::glance()` function to get the estimated standard deviation, `\(R^2\)`, and information criteria. The estimated standard deviation is the value in the `sigma` column.
``` r
glance(lmout)
```

```
## # A tibble: 1 × 12
##   r.squared adj.r.squared sigma statistic  p.value    df logLik
##       <dbl>         <dbl> <dbl>     <dbl>    <dbl> <dbl>  <dbl>
## 1     0.810         0.804  2.67      128. 2.39e-12     1  -75.8
## # ℹ 5 more variables: AIC <dbl>, BIC <dbl>, deviance <dbl>,
## #   df.residual <int>, nobs <int>
```

---

# Prediction (Interpolation)

- **Interpolation**: Making estimates/predictions within the range of the data.
- **Extrapolation**: Making estimates/predictions outside the range of the data.
- Interpolation is fine. Extrapolation is dangerous. .question[why?]

---

- Interpolation

<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-11-1.png" width="60%" style="display: block; margin: auto;" />

---

- Extrapolation

<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-12-1.png" width="60%" style="display: block; margin: auto;" />

---

## Why is extrapolation dangerous?

1. We don't know whether the linear relationship holds outside the range of the data (because we don't have data there to see the relationship).
2. We don't know whether the variability is the same outside the range of the data (because we don't have data there to see the variability).

---

### Make a prediction:

1. You need a data frame with exactly the same variable name as the explanatory variable.

``` r
newdf = tribble(~logwt,
                1,
                1.5)
```

2. Then use the `predict()` function to obtain predictions.

``` r
newdf = newdf %>%
  mutate(predictions = predict(object = lmout, newdata = newdf))
```

---

## Assumptions and Violations

- In the linear model, you can trade assumptions for inference.
- Assumptions in *decreasing* order of importance:

1. **Independence** - Knowing the value of one observation gives you no information about the value of another.
2. **Linearity** - The relationship looks like a straight line.
3. **Equal Variance** - The spread is the same for every value of `\(x\)`.
4. **Normality** - The distribution of the errors isn't too skewed and there aren't any *too* extreme points. (Only an issue if you have outliers and a small number of observations, thanks to the [central limit theorem](https://en.wikipedia.org/wiki/Central_limit_theorem).)
5. **Fixed and Known** predictors

---

### What do we lose when violated?

1. **Linearity** violated - The linear regression line does not pick up the actual conditional expectation. As a linear approximation, the results will be sensitive to the particular `\(x\)` sampled.
2. **Independence** violated - The linear regression line is unbiased, but standard errors can be badly off.
3. **Equal Variance** violated - The linear regression line is unbiased, but standard errors are off. Your `\(p\)`-values may be too small, or too large.
4. **Normality** violated - Only an issue if your sample size is "small". Unstable results if outliers are present. Your `\(p\)`-values may be too small, or too large.
5. **Fixed and Known** violated - An issue if there is measurement error in the predictors.

.question[Why are no assumptions made about the distribution of the explanatory variable (the `\(x_i\)`'s)?]

---

## Evaluating Independence

- Think about the problem.
  - Were different responses measured on the same observational/experimental unit?
  - Were data collected in groups?
- Non-independence: The temperature today and the temperature tomorrow. If it is warm today, it is probably warm tomorrow.
- Non-independence: You are measuring properties of 500 single cells isolated from 3 mice. Because the cells within a given mouse are probably similar, the cells are not independent.
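---

### Sketch: what clustering does to inference

A minimal simulation of the mouse/cell scenario (the numbers here are invented for illustration): every cell inherits a mouse-level shift, so these 150 rows carry far less information than 150 independent observations would, and the naive `lm()` standard error is too optimistic.

``` r
set.seed(430)
cells = tibble(
  mouse       = rep(c("m1", "m2", "m3"), each = 50),
  mouse_shift = rep(rnorm(3, sd = 2), each = 50),        # shared within each mouse
  y           = 10 + mouse_shift + rnorm(150, sd = 0.5)  # cell-level noise
)
# lm() treats all 150 cells as independent; its standard error for the
# mean ignores the mouse-level variance component entirely
tidy(lm(y ~ 1, data = cells))
```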
---

### xkcd 2533

<img src="l16/img/slope_hypothesis_testing_2x.png" width="60%" style="display: block; margin: auto;" />

Via [xkcd](https://xkcd.com/2533/)

---
class: code50

### Evaluating other assumptions via residual diagnostics

- Obtain the residuals by using `augment()` from broom. They will be in the `.resid` variable.

``` r
aout = augment(lmout)
glimpse(aout)
```

```
## Rows: 32
## Columns: 9
## $ .rownames  <chr> "Mazda RX4", "Mazda RX4 Wag", "Datsun 710", …
## $ mpg        <dbl> 21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24…
## $ logwt      <dbl> 0.9631743, 1.0560527, 0.8415672, 1.1678274, …
## $ .fitted    <dbl> 22.79984, 21.21293, 24.87761, 19.30316, 18.1…
## $ .resid     <dbl> -1.79984305, -0.21293388, -2.07760886, 2.096…
## $ .hat       <dbl> 0.03929578, 0.03263072, 0.05636910, 0.031929…
## $ .sigma     <dbl> 2.693624, 2.714823, 2.685917, 2.686126, 2.71…
## $ .cooksd    <dbl> 9.677220e-03, 1.109293e-04, 1.917252e-02, 1.…
## $ .std.resid <dbl> -0.68787926, -0.08110004, -0.80118930, 0.798…
```

- For inference, consider plotting:
  - residuals `\(r_i\)` vs fitted values `\(\hat{y}_i\)`
  - residuals `\(r_i\)` vs response `\(y_i\)`
  - residuals `\(r_i\)` vs explanatory variable `\(x_i\)`

---

### Potential remedies

1. Linearity violated: Try a transformation. If the relationship looks curved and monotone (i.e. either always increasing or always decreasing), then try a log transformation.
2. Independence violated: Try a two-step procedure (that collapses over repeated measures) or use longitudinal data analysis.
3. Equal Variance violated: If the relationship is also curved and monotone, try a log transformation on the response variable. Use a bootstrap. Or stay tuned for when you learn about sandwich estimation.
4. Normality violated: Don't trust prediction intervals (confidence intervals are fine).
5. Fixed and Known violated: fit measurement error models

---

### Example 1: A perfect residual plot

.panelset[
.panel[.panel-name[Fit]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-17-1.png" width="60%" style="display: block; margin: auto;" />
]
.panel[.panel-name[Residual]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-18-1.png" width="60%" style="display: block; margin: auto;" />
]
]

---

### Verify

- ✅ Means are straight lines
- ✅ Residuals seem to be centered at 0 for all `\(x\)`
- ✅ Variance looks equal for all `\(x\)`
- ✅ Everything looks perfect (too perfect...🤔)

---

### Example 2: Curved Monotone Relationship, Equal Variances

- Simulate data:

``` r
set.seed(1)
x = rexp(100)
x = x - min(x) + 0.5
y = log(x) * 20 + rnorm(100, sd = 4)
(df_fake = tibble(x, y))
```

```
## # A tibble: 100 × 2
##        x      y
##    <dbl>  <dbl>
##  1 1.22   0.419
##  2 1.64  12.3  
##  3 0.608 -9.46 
##  4 0.603 -11.3 
...
```

---

.panelset[
.panel[.panel-name[Fit]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-20-1.png" width="60%" style="display: block; margin: auto;" />
]
.panel[.panel-name[Residual]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-21-1.png" width="60%" style="display: block; margin: auto;" />
]
]

---

### Check

- ❌ Curved (but always increasing) relationship between `\(x\)` and `\(y\)`.
- ✅ Variance looks equal for all `\(x\)`
- ❌ Residual plot has a parabolic shape.
- **Solution**: These indicate a `\(\log\)` transformation of `\(x\)` could help.
``` r
df_fake = df_fake %>% mutate(logx = log(x))
lm_fake = lm(y ~ logx, data = df_fake)
```

---

### Example 3: Curved Non-monotone Relationship, Equal Variances

- Simulate data:

``` r
set.seed(1)
x = rnorm(100)
y = -x^2 + rnorm(100)
df_fake = tibble(x, y)
```

---

.panelset[
.panel[.panel-name[Fit]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-24-1.png" width="60%" style="display: block; margin: auto;" />
]
.panel[.panel-name[Residual]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-25-1.png" width="60%" style="display: block; margin: auto;" />
]
]

---

### Verify

- ❌ Curved relationship between `\(x\)` and `\(y\)`
- ❌ Sometimes the relationship is increasing, sometimes it is decreasing.
- ✅ Variance looks equal for all `\(x\)`
- ❌ Residual plot has a parabolic form.
- **Solution**: Include a squared term in the model (or use a GAM or spline)

``` r
lmout = lm(y ~ I(x^2), data = df_fake)
```

---

### Example 4: Curved Relationship, Variance Increases with `\(Y\)`

- Simulate data:

``` r
set.seed(1)
x = rnorm(100)
y = exp(x + rnorm(100, sd = 1/2))
df_fake = tibble(x, y)
```

---

.panelset[
.panel[.panel-name[Fit]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-28-1.png" width="60%" style="display: block; margin: auto;" />
]
.panel[.panel-name[Residual]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-29-1.png" width="60%" style="display: block; margin: auto;" />
]
]

---

### Verify

- ❌ Curved relationship between `\(x\)` and `\(y\)`
- ❌ Variance looks like it increases as `\(y\)` increases
- ❌ Residual plot has a parabolic form.
- ❌ Residual plot variance looks larger to the right and smaller to the left.
- **Solution**: Take a log transformation of `\(y\)`.

``` r
df_fake = df_fake %>% mutate(logy = log(y))
lm_fake = lm(logy ~ x, data = df_fake)
```

---

### Example 5: Linear Relationship, Equal Variances, Skewed Distribution

- Simulate data:

``` r
set.seed(1)
x = runif(200)
y = 15 * x + rexp(200, 0.2)
```

---

.panelset[
.panel[.panel-name[Fit]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-32-1.png" width="60%" style="display: block; margin: auto;" />
]
.panel[.panel-name[Residual]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-33-1.png" width="60%" style="display: block; margin: auto;" />
]
]

---

### Verify

- ✅ Straight line relationship between `\(x\)` and `\(y\)`.
- ✅ Variances about equal for all `\(x\)`
- ❌ Skew for all `\(x\)`
- ❌ Residual plots show skew.
- **Solution**: Do nothing, but report the skew; or use a bootstrap or robust standard errors

---

### Example 6: Linear Relationship, Unequal Variances

- Simulate data:

``` r
set.seed(1)
x = runif(100) * 10
y = 0.85 * x + rnorm(100, sd = (x - 5) ^ 2)
df_fake = tibble(x, y)
```

---

.panelset[
.panel[.panel-name[Fit]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-35-1.png" width="60%" style="display: block; margin: auto;" />
]
.panel[.panel-name[Residuals]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-36-1.png" width="60%" style="display: block; margin: auto;" />
]
]

---

### Verify

- ✅ Linear relationship between `\(x\)` and `\(y\)`.
- ❌ Variance is different for different values of `\(x\)`.
- **Solution**: The modern solution is to bootstrap, or to use sandwich estimates of the standard errors

``` r
rob_fit = estimatr::lm_robust(y ~ x, data = df_fake)
tidy(rob_fit)
```

```
##          term  estimate std.error statistic     p.value
## 1 (Intercept) -2.861944 2.8353903 -1.009365 0.315285178
## 2           x  1.368002 0.5167311  2.647416 0.009452483
##     conf.low conf.high df outcome
## 1 -8.4886842  2.764795 98       y
## 2  0.3425659  2.393438 98       y
```

---
class: middle

# Interpreting coefficients when you log

---

### Log `x`

Generally, when you use logs, you interpret associations on a *multiplicative* scale instead of an *additive* scale.

No log:

- Model: `\(E[y_i] = \beta_0 + \beta_1 x_i\)`
- Observations that differ by 1 unit in `\(x\)` tend to differ by `\(\beta_1\)` units in `\(y\)`.

Log `\(x\)`:

- Model: `\(E[y_i] = \beta_0 + \beta_1 \log_2(x_i)\)`
- Observations that are twice as large in `\(x\)` tend to differ by `\(\beta_1\)` units in `\(y\)`.

---

### Log `y`

Log `\(y\)`:

- Model: `\(E[\log_2(y_i)] = \beta_0 + \beta_1 x_i\)`
- Observations that differ by 1 unit in `\(x\)` tend to be `\(2^{\beta_1}\)` times larger in `\(y\)`.

Log both:

- Model: `\(E[\log_2(y_i)] = \beta_0 + \beta_1 \log_2(x_i)\)`
- Observations that are twice as large in `\(x\)` tend to be `\(2^{\beta_1}\)` times larger in `\(y\)`.

.footnote[Note: we commit statistical abuse here, since `\(\exp \left[ \text{E}(\log(Y) | X) \right] \neq \text{E}(Y | X)\)`, i.e., `exp` doesn't commute through the expectation. Though the delta method says this is the first-order approximation.]

---
class: middle

.hand[Interpreting interaction models]

---

## When in doubt, predict it out!

With interaction models, it's easy to trick yourself. But you can also make predictions to check your understanding.

``` r
only_wt = lm(mpg ~ log(wt), mtcars)
only_disp = lm(mpg ~ log(disp), mtcars)
only_cyl = lm(mpg ~ cyl, mtcars)
complicated_model = lm(mpg ~ (log(wt) + log(disp))*cyl, mtcars)
```

---

``` r
GGally::ggcoef_compare(list(only_wt = only_wt, only_disp = only_disp,
                            only_cyl = only_cyl, complicated = complicated_model))
```

<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-39-1.png" width="80%" style="display: block; margin: auto;" />

.question[Why are `log(disp):cyl6` and `log(disp):cyl8` positive?]

---

### Predict it out!

.panelset[
.panel[.panel-name[Code]

``` r
(newdf = expand.grid(wt = mean(mtcars$wt), disp = mean(mtcars$disp),
                     cyl = factor(c(4, 6, 8))))
```

```
##        wt     disp cyl
## 1 3.21725 230.7219   4
## 2 3.21725 230.7219   6
## 3 3.21725 230.7219   8
```

``` r
newdf = newdf %>% mutate(mpg = predict(complicated_model, across()))
ggplot(mtcars, aes(x = cyl, y = mpg)) +
  geom_boxplot() +
  geom_line(data = newdf, aes(group = 1), color = 'red')
```
]
.panel[.panel-name[Plot]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-40-1.png" width="70%" style="display: block; margin: auto;" />
]
]

---

### We are extrapolating!
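A quick numeric check (a sketch; `newdf` is the prediction grid from the previous slide): the overall mean `disp` of about 231 that we plugged in lies outside the observed `disp` range of the 4- and 8-cylinder groups.

``` r
# compare mean(mtcars$disp) to each cylinder group's observed range
mtcars %>%
  group_by(cyl) %>%
  summarize(min_disp = min(disp), max_disp = max(disp))
```

The faceted plot below makes the same point visually: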
``` r
ggplot(mtcars, aes(x = log(disp), y = mpg)) +
  geom_point() +
  facet_wrap(~cyl) +
  geom_smooth(method = 'lm') +
  geom_point(data = newdf, aes(group = 1), color = 'red')
```

<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-41-1.png" width="60%" style="display: block; margin: auto;" />

---

# Summary of R commands

- `augment()`:
  - Residuals `\(r_i = y_i - \hat{y}_i\)`: `$.resid`
  - Fitted values `\(\hat{y}_i\)`: `$.fitted`
- `tidy()`:
  - Name of variables: `$term`
  - Coefficient estimates: `$estimate`
  - Standard error (standard deviation of the sampling distribution of the coefficient estimates): `$std.error`
  - t-statistic: `$statistic`
  - p-value: `$p.value`
- `glance()`:
  - R-squared value (proportion of variance explained by the regression line, higher is better): `$r.squared`
  - AIC (lower is better): `$AIC`
  - BIC (lower is better): `$BIC`

---
class: middle

.hand[Many linear models with list columns]

---

# Motivation

- The gapminder data

``` r
library(tidyverse)
library(tidymodels)
library(gapminder)
data("gapminder")
glimpse(gapminder)
```

```
## Rows: 1,704
## Columns: 6
## $ country   <fct> "Afghanistan", "Afghanistan", "Afghanistan", …
## $ continent <fct> Asia, Asia, Asia, Asia, Asia, Asia, Asia, Asi…
## $ year      <int> 1952, 1957, 1962, 1967, 1972, 1977, 1982, 198…
## $ lifeExp   <dbl> 28.801, 30.332, 31.997, 34.020, 36.088, 38.43…
## $ pop       <int> 8425333, 9240934, 10267083, 11537966, 1307946…
## $ gdpPercap <dbl> 779.4453, 820.8530, 853.1007, 836.1971, 739.9…
```

---

- Suppose we want to look at how life expectancy has changed over time in each country:

``` r
gapminder %>%
  ggplot(aes(x = year, y = lifeExp, group = country)) +
  geom_line(alpha = 1/3) +
  xlab("Year") +
  ylab("Life Expectancy")
```

<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-43-1.png" width="60%" style="display: block; margin: auto;" />

---
class: code70

- The general trend is up, but there are some countries where this doesn't happen. We can quantify the trend for one country using a linear model:

``` r
usdf = gapminder %>% filter(country == "United States")
```

<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-45-1.png" width="50%" style="display: block; margin: auto;" />

```
## # A tibble: 2 × 5
##   term        estimate std.error statistic  p.value
##   <chr>          <dbl>     <dbl>     <dbl>    <dbl>
## 1 (Intercept) -291.      13.8        -21.1 1.25e- 9
## 2 year           0.184    0.00696     26.5 1.37e-10
```

---

- So each year, the US has been increasing its life expectancy by about 0.184 years.
- How can we get these coefficient estimates for each country?

--

- We could fit an interaction between year and country, but:
  - We'd have to pick a reference country, or faff about with sum-to-zero contrasts, then recover the "missing level".
  - It would assume identical error terms across all countries -- undesirable.

---

### Many linear models to the rescue!

Also known as `group_by` with list columns.

This is a common and very useful pattern for me, but that might be because I am a degenerate bioinformatician and am always dealing with multiplicity.

The concept is:

1. `group_by` lets us `summarize` data by subsets.
2. Model fits can go into a list.
3. Data frames can contain lists: a data frame is a list that quacks like a matrix, and lists can contain more lists...
4. The `tidyr::unnest()` function will let us extract the goods.
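---

### Aside: the `nest()` spelling

The same idea is often written with `tidyr::nest()` plus `purrr::map()`. A sketch that should be equivalent to the recipe on the next slide (using base `lm()` fits rather than parsnip, for brevity):

``` r
many_fits_nested = gapminder %>%
  nest(data = -c(country, continent)) %>%  # one row per country, data in a list column
  mutate(fit = map(data, ~ lm(lifeExp ~ I(year - 1990), data = .x)))
```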
---

## General recipe

``` r
many_fits = gapminder %>%
  group_by(country) %>%
* summarize(fit = fit_model(across()))

tidied_output = many_fits %>%
* rowwise() %>%
* mutate(df = list(
*   post_process(fit)
  ))

unnested_output = tidied_output %>%
* tidyr::unnest(cols = c(df))
```

* `fit_model`: data.frame in, length-1 list out
* `post_process`: unboxed list element in, a data frame wrapped as a list out
* extract and combine with `unnest`

---

### Specific example

``` r
fit_model = function(data) {
  linear_reg() %>%
    set_engine("lm") %>%
    fit(lifeExp ~ I(year-1990), data = data) %>%
    list()
  # Base R version for lm
  # fit = lm(lifeExp ~ I(year-1990), data)
  # list(fit)
}

many_fits = gapminder %>%
  group_by(country, continent) %>%
  summarize(fit = fit_model(across(.col = 1:4))) %>%
  ungroup()
many_fits
```

```
## # A tibble: 142 × 3
##    country     continent fit     
##    <fct>       <fct>     <list>  
##  1 Afghanistan Asia      <fit[+]>
##  2 Albania     Europe    <fit[+]>
##  3 Algeria     Africa    <fit[+]>
##  4 Angola      Africa    <fit[+]>
...
```

---

### Output

Now we can see the model fit in Japan:

``` r
dplyr::filter(many_fits, country == 'Japan')$fit[[1]] %>% tidy()
```

```
## # A tibble: 2 × 5
##   term           estimate std.error statistic  p.value
##   <chr>             <dbl>     <dbl>     <dbl>    <dbl>
## 1 (Intercept)      78.5      0.463      170.  1.24e-18
## 2 I(year - 1990)    0.353    0.0229      15.4 2.70e- 8
```

or Senegal:

``` r
dplyr::filter(many_fits, country == 'Senegal')$fit[[1]] %>% tidy()
```

```
## # A tibble: 2 × 5
##   term           estimate std.error statistic  p.value
##   <chr>             <dbl>     <dbl>     <dbl>    <dbl>
## 1 (Intercept)      55.9      0.315      178.  7.91e-19
## 2 I(year - 1990)    0.505    0.0156      32.4 1.87e-11
```

---

### Notes

1. Use `across()` to pass the data frame, subset to the current group, to our fit function.
2. We have to do an unlisting operation to extract the lm object itself. Without it, we get a length-1 list:

``` r
dplyr::filter(many_fits, country == 'Senegal')$fit
```

```
## [[1]]
## parsnip model object
## 
## 
## Call:
## stats::lm(formula = lifeExp ~ I(year - 1990), data = data)
## 
## Coefficients:
##    (Intercept)  I(year - 1990)  
##        55.9253          0.5047
```

---

3. We have to wrap the `lm` output as a list:

``` r
fit_not_list = function(data) {
  linear_reg() %>%
    set_engine("lm") %>%
    fit(lifeExp ~ I(year-1990), data = data)
  # Base R version for lm also doesn't work
  # fit = lm(lifeExp ~ I(year-1990), data)
  # fit
}

many_fits = gapminder %>%
  group_by(country, continent) %>%
  summarize(fit = fit_not_list(across()))
```

```
## Error in `summarize()`:
## ℹ In argument: `fit = fit_not_list(across())`.
## ℹ In group 1: `country = "Afghanistan"` and `continent = Asia`.
## Caused by error:
## ! `fit` must be a vector, not a <_lm/model_fit> object.
```

---

### Add model summaries

To really cook with this pattern, we want to extract various components from the model into a long data frame, suitable for further dplyr-ing or ggplot-ing.

The essential rule here is that we can mutate a new list column with operations that are

> `n-list` in and `n-list` out

---

### `n-list` in and `n-list` out!

By `n-list` I mean a length-n list.

`purrr::map` (and `lapply`) support this sort of implicit iteration, so we can definitely fall back on them.

We can also use the `rowwise()` operator in dplyr, which lets us write a function that takes an "unboxed" element of the list in and writes an `n-list` out.
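---

### `rowwise()` sketch

With `rowwise()`, each `mutate()` expression sees one unboxed element of the list column at a time; wrapping the result in `list()` keeps the output a list column. A minimal sketch, equivalent to the `purrr::map()` version coming up:

``` r
many_fits %>%
  rowwise() %>%
  mutate(tidyout = list(tidy(fit))) %>%  # tidy() one fit at a time
  ungroup()
```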
---

### Manipulating list columns

``` r
(df = tibble(
  x = 1:3,
  y = list('a', c('bb', 'cc'), 'ddd')))
```

```
## # A tibble: 3 × 2
##       x y        
##   <int> <list>   
## 1     1 <chr [1]>
## 2     2 <chr [2]>
## 3     3 <chr [1]>
```

``` r
str(df$y)
```

```
## List of 3
##  $ : chr "a"
##  $ : chr [1:2] "bb" "cc"
##  $ : chr "ddd"
```

Verify: `y` is a list.

---

### Manipulating list columns

Suppose we want to count the number of characters in each element of `y`. `nchar` doesn't quite work:

``` r
nchar(df$y)
```

```
## [1]  1 13  3
```

We need `lapply` or `purrr::map`:

``` r
purrr::map(df$y, nchar)
```

```
## [[1]]
## [1] 1
## 
## [[2]]
## [1] 2 2
## 
## [[3]]
## [1] 3
```

---

### Finally, using mutate

``` r
(df = df %>% mutate(nchar = purrr::map(y, nchar)))
```

```
## # A tibble: 3 × 3
##       x y         nchar    
##   <int> <list>    <list>   
## 1     1 <chr [1]> <int [1]>
## 2     2 <chr [2]> <int [2]>
## 3     3 <chr [1]> <int [1]>
```

---

### With purrr

Ok, back to our set of `lm` fits:

``` r
many_fits = many_fits %>% mutate(tidyout = map(fit, tidy))
many_fits
```

```
## # A tibble: 142 × 4
##    country     continent fit      tidyout         
##    <fct>       <fct>     <list>   <list>          
##  1 Afghanistan Asia      <fit[+]> <tibble [2 × 5]>
##  2 Albania     Europe    <fit[+]> <tibble [2 × 5]>
##  3 Algeria     Africa    <fit[+]> <tibble [2 × 5]>
##  4 Angola      Africa    <fit[+]> <tibble [2 × 5]>
## # ℹ 138 more rows
```

---

### Unnesting

`tidyr::unnest()` takes data frame list columns, extracts and rbinds them, and preserves the remaining columns, repeating them as necessary so the dimensions match.

``` r
(df %>% unnest(cols = c(y, nchar)))
```

```
## # A tibble: 4 × 3
##       x y     nchar
##   <int> <chr> <int>
## 1     1 a         1
## 2     2 bb        2
## 3     2 cc        2
## 4     3 ddd       3
```

---

### Unnest our fits

``` r
many_fits_coef = many_fits %>% unnest(cols = c(tidyout))
many_fits_coef
```

```
## # A tibble: 284 × 8
##    country   continent fit      term  estimate std.error statistic
##    <fct>     <fct>     <list>   <chr>    <dbl>     <dbl>     <dbl>
##  1 Afghanis… Asia      <fit[+]> (Int…   40.4      0.413       97.7
##  2 Afghanis… Asia      <fit[+]> I(ye…    0.275    0.0205      13.5
##  3 Albania   Europe    <fit[+]> (Int…   71.9      0.670      107. 
##  4 Albania   Europe    <fit[+]> I(ye…    0.335    0.0332      10.1
## # ℹ 280 more rows
## # ℹ 1 more variable: p.value <dbl>
```

---

### Finally

What is the distribution of 1990-estimated life expectancy?

.panelset[
.panel[.panel-name[Code]

``` r
intercept = many_fits_coef %>%
  dplyr::filter(term == '(Intercept)') %>%
  arrange(desc(estimate)) %>%
  mutate(country = fct_inorder(country))

to_show = intercept %>%
  group_by(continent) %>%
  reframe(country = country[seq(from = 1, to = length(country), by = 5)])

intercept %>%
  ggplot(aes(y = country, x = estimate,
             xmin = estimate - std.error, xmax = estimate + std.error)) +
  geom_pointrange() +
  scale_y_discrete(breaks = to_show$country) +
  facet_grid(continent ~ ., scales = 'free_y', space = 'free')
```
]
.panel[.panel-name[Plot]
<img src="l16-multivariate-lm_files/figure-html/unnamed-chunk-58-1.png" width="70%" style="display: block; margin: auto;" />
]]

---

# Acknowledgments

Adapted from David Gerard's [Stat 512](https://data-science-master.github.io/lectures/05_linear_regression/05_simple_linear_regression.html)

## Resources

- Chapter 25 from [RDS](https://r4ds.had.co.nz/).