Marginal effects are very useful for model interpretation. If one wants to know the effect of a predictor variable x on the dependent variable y, marginal effects are an easy way to get the answer. In this short blog post, I will demonstrate how to extract and plot the marginal effects of a generalized linear model and of a multinomial logit model.
library(ggplot2)
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(tibble)
## Warning: package 'tibble' was built under R version 3.4.4
library(broom)
library(margins)
## Warning: package 'margins' was built under R version 3.4.4
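First, fit a generalized linear model (a logistic regression) on the iris dataset, restricting the dependent variable Species to a binary variable with only two values: “virginica” and “versicolor”. 80% of the rows are used for training and the remaining 20% are held out as a test set.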
# Keep only the two species to discriminate. droplevels() removes the
# now-empty "setosa" level so Species is an explicit two-level factor
# (versicolor, virginica) in both the training and the test set. glm() would
# drop the unused level on its own, but doing it here makes the coding visible.
iris.small <- droplevels(
  filter(iris, Species %in% c("virginica", "versicolor"))
)

# Reproducible 80/20 train/test split. seq_len() avoids the 1:0 pitfall and
# floor() makes the integer sample size explicit.
set.seed(140)
indexes <- sample(seq_len(nrow(iris.small)),
                  size = floor(0.8 * nrow(iris.small)))
iris.train <- iris.small[indexes, ]
iris.test <- iris.small[-indexes, ]

# Logistic regression. For a factor response the binomial family treats the
# first level (versicolor) as failure and the second (virginica) as success,
# so positive coefficients push the prediction towards virginica.
logit.iris <- glm(
  Species ~ Sepal.Width + Sepal.Length + Petal.Width + Petal.Length,
  data = iris.train,
  family = binomial
)
tidy(logit.iris)
## term estimate std.error statistic p.value
## 1 (Intercept) -44.515607 25.261885 -1.7621649 0.07804144
## 2 Sepal.Width -5.459623 4.783834 -1.1412650 0.25375967
## 3 Sepal.Length -1.454655 2.351741 -0.6185438 0.53621694
## 4 Petal.Width 17.315117 9.748273 1.7762242 0.07569598
## 5 Petal.Length 8.224098 4.483069 1.8344796 0.06658284
# Average marginal effects (AMEs): margins() computes each predictor's effect
# on the predicted probability for every row of the model data (iris.train)
# and averages them.
(effects_logit <- margins(logit.iris))
## Warning in warn_for_weights(model): 'weights' used in model estimation are
## currently ignored!
## Average marginal effects
## glm(formula = Species ~ Sepal.Width + Sepal.Length + Petal.Width + Petal.Length, family = binomial, data = iris.train)
## Sepal.Width Sepal.Length Petal.Width Petal.Length
## -0.1121 -0.02986 0.3554 0.1688
cat("\n\n")
# summary() adds standard errors, z statistics, p-values and confidence
# bounds (lower/upper, 95% by default). Compute it once and reuse it below.
(sum_logit.iris <- summary(effects_logit))
## factor AME SE z p lower upper
## Petal.Length 0.1688 0.0650 2.5977 0.0094 0.0414 0.2962
## Petal.Width 0.3554 0.1353 2.6278 0.0086 0.0903 0.6206
## Sepal.Length -0.0299 0.0456 -0.6551 0.5124 -0.1192 0.0595
## Sepal.Width -0.1121 0.0890 -1.2588 0.2081 -0.2866 0.0624
plot(effects_logit)
#or
# The same information drawn with ggplot2: point = AME, bar = confidence
# interval, horizontal line at 0 = no effect.
ggplot(data = sum_logit.iris) +
  geom_point(aes(factor, AME)) +
  geom_errorbar(aes(x = factor, ymin = lower, ymax = upper), width = 0.2) +
  geom_hline(yintercept = 0) +
  labs(x = NULL, y = "Average marginal effect") +
  theme_minimal() +
  # hjust = 1 right-aligns the rotated labels so they don't overlap the axis.
  theme(axis.text.x = element_text(angle = 45, hjust = 1))
# Unit-level marginal effects: differentiate the model's prediction with
# respect to Sepal.Width at every row of the held-out test set.
dydx(data = iris.test, model = logit.iris, variable = "Sepal.Width")
## dydx_Sepal.Width
## 1 -2.958938e-02
## 2 -3.456669e-04
## 3 -1.029195e-03
## 4 -1.494851e-04
## 5 -6.989586e-04
## 6 -9.719082e-03
## 7 -3.176616e-07
## 8 -1.362470e+00
## 9 -1.695241e-02
## 10 -7.689105e-06
## 11 -5.187446e-06
## 12 -5.563316e-05
## 13 -2.447134e-03
## 14 -2.986186e-05
## 15 -1.442733e-11
## 16 -6.464212e-01
## 17 -2.303155e-06
## 18 -5.018934e-04
## 19 -1.854360e-04
## 20 -3.844258e-07
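Using dydx(), we can get the marginal effect of Sepal.Width for each observation in the test set, i.e. the derivative of the model's prediction with respect to Sepal.Width evaluated at that row. Next, we turn to a multinomial logit model fitted with the mlogit package on its Fishing data.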
library(mlogit)
## Warning: package 'mlogit' was built under R version 3.4.4
## Loading required package: Formula
## Loading required package: maxLik
## Warning: package 'maxLik' was built under R version 3.4.4
## Loading required package: miscTools
## Warning: package 'miscTools' was built under R version 3.4.4
##
## Please cite the 'maxLik' package as:
## Henningsen, Arne and Toomet, Ott (2011). maxLik: A package for maximum likelihood estimation in R. Computational Statistics 26(3), 443-458. DOI 10.1007/s00180-010-0217-1.
##
## If you have questions, suggestions, or comments regarding the 'maxLik' package, please use a forum or 'tracker' at maxLik's R-Forge site:
## https://r-forge.r-project.org/projects/maxlik/
# Fishing data in wide format (one row per choice situation); columns 2-9
# hold the alternative-specific price and catch variables.
data("Fishing", package = "mlogit")
# Reshape to the long format mlogit() expects (one row per alternative).
Fish <- mlogit.data(Fishing, varying = c(2:9), shape = "wide", choice = "mode")
# Three-part formula: price gets one generic coefficient, income is an
# individual-specific variable, and catch gets alternative-specific
# coefficients.
m <- mlogit(mode ~ price | income | catch, data = Fish)

# Evaluate the effects at a "representative" observation: the sample mean of
# price and catch for each alternative and the overall mean income.
z <- with(Fish, data.frame(price = tapply(price, index(m)$alt, mean),
                           catch = tapply(catch, index(m)$alt, mean),
                           income = mean(income)))

# effects() `type` is two letters: the first describes the change in the
# probability, the second the change in the covariate ("a" = absolute,
# "r" = relative). Hence:
#   income: default "aa" -> marginal effects (dP / dx)
#   price:  "rr"         -> elasticities ((dP / P) / (dx / x))
#   catch:  "ar"         -> semi-elasticities (dP / (dx / x))
effects_mlogit_income <- effects(m, covariate = "income", data = z)
effects_mlogit_price <- effects(m, covariate = "price", type = "rr", data = z)
effects_mlogit_catch <- effects(m, covariate = "catch", type = "ar", data = z)

# Show the extracted effects; they are otherwise computed but never displayed.
effects_mlogit_income
effects_mlogit_price
effects_mlogit_catch
# NOTE: this plots the multinomial logit *coefficients* from summary(m), not
# the marginal effects stored in effects_mlogit_*.
# Error bars are 95% Wald intervals (Estimate +/- 1.96 * SE), matching the
# 95% lower/upper bounds used in the logit AME plot, rather than +/- one SE.
coef_table <- summary(m)$CoefTable
# data.frame() mangles "Std. Error" into the column name Std..Error.
coe_df <- data.frame(coef_table)
coe_df$factor <- row.names(coef_table)
z_crit <- qnorm(0.975)
coe_df$lower <- coe_df$Estimate - z_crit * coe_df$Std..Error
coe_df$upper <- coe_df$Estimate + z_crit * coe_df$Std..Error
ggplot(data = coe_df) +
  geom_point(aes(factor, Estimate)) +
  geom_errorbar(aes(x = factor, ymin = lower, ymax = upper), width = 0.2) +
  geom_hline(yintercept = 0) +
  labs(x = NULL, y = "Coefficient (95% CI)") +
  theme_minimal() +
  # hjust = 1 right-aligns the rotated labels so they don't overlap the axis.
  theme(axis.text.x = element_text(angle = 45, hjust = 1))