---
title: "An introduction to `balnet`"
output:
  rmarkdown::html_vignette:
    math_method: katex
vignette: >
  %\VignetteIndexEntry{An introduction to balnet}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.width = 7,
  fig.height = 6
)
set.seed(42)
library(balnet)
```

This vignette provides a brief introduction to the `balnet` package. For additional details, see [this paper](https://arxiv.org/abs/2602.18577).

A commonly used approach for estimating propensity scores in observational studies is logistic regression, often combined with regularization when the number of covariates is large or when overfitting is a concern. The `balnet` package also fits regularized logistic regression models, but replaces the traditional maximum likelihood loss with *covariate balancing loss functions* paired with a logistic link.

A key feature of this approach is that it directly produces *balancing weights*: weights that approximately equalize covariate distributions between treatment arms. These weights can be used as plug-in components in inverse probability weighting (IPW) estimators, or combined with outcome models in doubly robust procedures such as AIPW and debiased machine learning. By targeting covariate balance directly, the fitted propensity models are explicitly tailored to the causal estimand of interest.

The example below illustrates these ideas in a simple simulated setting.

## A toy example

We begin by simulating a small example in which treatment assignment depends on a single pre-treatment covariate. In particular, units with smaller values of $X_1$ are less likely to receive treatment.

```{r}
n <- 100
p <- 25
X <- matrix(rnorm(n * p), n, p)
# Treatment probability follows a logistic model in the first covariate only.
W <- rbinom(n, 1, 1 / (1 + exp(1 - X[, 1])))
```

Suppose we are interested in estimating an average treatment effect (ATE). We can fit a `balnet` object using the default options.

```{r}
fit <- balnet(X, W)
```

By default, this fits a lasso-regularized path of logistic models, with tuning parameters and path construction chosen to mirror common `glmnet` usage. A few details are worth highlighting.

When propensity scores are estimated using covariate balancing loss functions, the fitted models depend on the target estimand. For the ATE, `balnet` fits two propensity score models: one for the control arm and one for the treated arm. The control-arm model is used to estimate $E[Y_i(0)]$, while the treated-arm model is used to estimate $E[Y_i(1)]$.

Printing the fitted object shows summary information for both arms. By default, the output is truncated to display only the beginning and end of the regularization path; the full path can be displayed by increasing the `max` argument in `print`.

```{r}
print(fit)
```

The first column reports the number of nonzero coefficients and is analogous to the output of `glmnet`. As in `glmnet`, the regularization path starts at a value of $\lambda$ corresponding to an intercept-only model and proceeds in `nlambda` logarithmically spaced steps down to a minimum value determined by `lambda.min.ratio`.

The next column reports the absolute standardized mean difference (SMD), averaged across covariates. Importantly, `balnet` always computes and reports balance metrics on the standardized scale.

In this simulated example, it is not possible to find weights that *exactly* balance the treated and control covariate means to the overall sample means of $X$. As a result, for both treatment arms the regularization path is truncated before reaching the default path length of `nlambda = 100`. The treated arm, in particular, is more difficult to balance.

## The role of λ

For lasso-regularized generalized linear models, $\lambda$ is often interpreted as a budget on the overall magnitude of the coefficients. In the covariate balancing framework, the interpretation is different.
Covariate balancing loss functions arise as the dual formulation of an optimization problem that constrains imbalance. In the lasso case, $\lambda$ directly controls balance: it equals the _maximum allowable absolute standardized mean difference_ (SMD) across covariates (since `balnet` standardizes covariates by default).

To illustrate this, consider $\lambda^{\max} \approx 0.62$ for the treated arm in the printed output. This value corresponds to the imbalance in the unweighted treatment arm data and can be verified directly:

```{r}
# Unweighted SMD of treated means relative to the full sample;
# the sqrt((n - 1) / n) factor converts sd() to the 1/n convention.
smd.baseline <- (colMeans(X[W == 1, ]) - colMeans(X)) /
  (apply(X, 2, sd) * sqrt((n - 1) / n))
max(abs(smd.baseline))
```

Since the smallest value of $\lambda$ attained for the treated arm is approximately $\lambda^{\min} \approx 0.21$, the closest we can bring the standardized treated covariate means to the overall means is an absolute SMD of about 0.21.

This interpretation of $\lambda$ provides a convenient way to target a desired level of imbalance, available through the option `max.imbalance`. For lasso penalization, `balnet` adjusts the generated $\lambda$ sequence so that it terminates at this value. The algorithm then attempts to compute the full regularization path, stopping gracefully if further reductions in imbalance are not achievable.

Alternatively, users may compute $\lambda^{\max}$ (i.e., the maximum absolute unweighted SMD) for their dataset and then choose `lambda.min.ratio` to reflect an acceptable fraction of this maximum imbalance. For example, if $\lambda^{\max} = 10$, the default setting `lambda.min.ratio = 0.01` corresponds to a target maximum absolute SMD of $10 \times 0.01 = 0.1$.

> *Note*: Setting `lambda = 0` to try to achieve exact balance is not recommended, just as `glmnet` advises against it. `balnet` works best by using warm starts and gradually decreasing regularization, a strategy similar to barrier methods in convex optimization.
This approach helps the algorithm converge reliably and improves performance on real-world datasets where achieving covariate balance can be difficult.

## Plotting path diagnostics

`balnet` provides default plotting methods for visualizing regularization path diagnostics. Calling `plot` without additional arguments produces a summary of metrics along the path, indexed by $\lambda$ on the log scale.

```{r}
plot(fit)
```

Two quantities are shown, both normalized to percentages. The first is the percent bias reduction (PBR), which measures the reduction in absolute SMD after weighting relative to the unweighted data. The second is the effective sample size (ESS), defined as the squared sum of weights divided by the sum of squared weights, normalized to sum to 100.

Recall that $\lambda^{\max}$ corresponds to the intercept-only (unweighted) fit. As $\lambda$ decreases, covariate imbalance is reduced, but at the cost of a smaller effective sample size, reflecting increased concentration of weights on a subset of units.

Individual covariate SMDs can be visualized at specific values of $\lambda$ by supplying the `lambda` argument. If the requested value is not exactly on the fitted $\lambda$ sequence, the closest value is used; in particular, setting `lambda = 0` selects the smallest value along the path. `balnet` then predicts the corresponding propensity scores, constructs inverse probability weights for the chosen estimand, and computes the resulting SMDs.

```{r}
plot(fit, lambda = 0)
```

The unweighted SMDs are shown at $\lambda^{\max}$, while colored points correspond to the weighted SMDs at the selected $\lambda$. Separate panels are displayed for the treated and control arms, reflecting the fact that distinct propensity score models are fit for each arm in the ATE case. In `balnet`, SMDs take the form $(\text{weighted covariate mean} - \text{target mean}) ~/~ \text{sd(target)}$.

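
As a rough illustration of the two kinds of diagnostics described above, the sketch below computes SMDs of this form and an ESS percentage in base R. The helpers `weighted.smd` and `ess.percent` are hypothetical names, not `balnet` functions. The weights are built from the known simulation propensity rather than from a fitted model, and expressing ESS relative to the number of weighted units is our assumption about the normalization.

```{r}
# Hypothetical helpers for illustration only; not part of balnet.
# SMD of the form (weighted covariate mean - target mean) / sd(target).
weighted.smd <- function(X.arm, w, X.target) {
  n.target <- nrow(X.target)
  # sd() divides by n - 1; rescale to the 1/n convention used in the
  # baseline check earlier in this vignette.
  sd.target <- apply(X.target, 2, sd) * sqrt((n.target - 1) / n.target)
  # w has one entry per row of X.arm, so recycling scales row i by w[i].
  mean.w <- colSums(X.arm * w) / sum(w)
  (mean.w - colMeans(X.target)) / sd.target
}

# ESS as a percentage of the number of weighted units (assumed normalization).
ess.percent <- function(w) 100 * sum(w)^2 / sum(w^2) / length(w)

# Small self-contained simulation with separate names, so the objects
# used elsewhere in this vignette are left untouched.
set.seed(1)
n.demo <- 200
X.demo <- matrix(rnorm(n.demo * 3), n.demo, 3)
e.demo <- 1 / (1 + exp(1 - X.demo[, 1]))  # true propensity, known only in simulation
W.demo <- rbinom(n.demo, 1, e.demo)
treated <- W.demo == 1

w.unit <- rep(1, sum(treated))  # unweighted treated arm
w.ipw <- 1 / e.demo[treated]    # stand-in inverse probability weights

smd.unweighted <- weighted.smd(X.demo[treated, ], w.unit, X.demo)
smd.ipw <- weighted.smd(X.demo[treated, ], w.ipw, X.demo)

max(abs(smd.unweighted))
max(abs(smd.ipw))
ess.percent(w.unit)
ess.percent(w.ipw)
```

With unit weights the SMD reduces to the unweighted baseline formula and the ESS percentage is 100. Any non-constant weight vector gives a smaller ESS, which is the trade-off the path plot displays.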
In this example, the plots suggest limited overlap for the treated arm, indicating that the ATE may not be an appropriate target estimand. Instead, we can target the average treatment effect on the treated (ATT) by setting `target = "ATT"`. In this case, `balnet` fits a model that aims to balance control covariate means toward those of the treated group.

```{r}
fit.att <- balnet(X, W, target = "ATT")
plot(fit.att, lambda = 0)
```

Here, the resulting weights achieve substantially improved balance.

For additional functionality, users are encouraged to consult the documentation for the standard S3 methods provided by `balnet`, including `predict` for propensity score prediction and `balweights` for extracting balancing weights. On large datasets, we recommend calling `balnet` with `verbose = TRUE` to interactively print balance metrics during fitting.
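
As noted in the introduction, balancing weights can serve as plug-in components in IPW estimators. The sketch below shows one way that step might look for the ATT, using a normalized (Hajek-style) weighted mean. It makes several assumptions: `att.hajek` is a hypothetical helper, not a `balnet` function; the outcome `Y.sim` is simulated here because this vignette does not define one; and the control weights are odds weights from the known simulation propensity, standing in for weights that would in practice be extracted from a fitted model (see the `balweights` documentation for its actual arguments).

```{r}
# Hypothetical helper for illustration only; not part of balnet.
att.hajek <- function(Y, W, w.control) {
  # Treated outcomes are averaged as-is. Control outcomes are averaged with
  # normalized weights that shift the control sample toward the treated group.
  mean(Y[W == 1]) - sum(w.control * Y[W == 0]) / sum(w.control)
}

set.seed(2)
n.sim <- 500
x.sim <- rnorm(n.sim)
e.sim <- 1 / (1 + exp(1 - x.sim))            # true propensity (simulation only)
W.sim <- rbinom(n.sim, 1, e.sim)
Y.sim <- x.sim + 0.5 * W.sim + rnorm(n.sim)  # constant treatment effect of 0.5

# Stand-in ATT weights for control units: propensity odds e / (1 - e).
w.odds <- (e.sim / (1 - e.sim))[W.sim == 0]

att.hajek(Y.sim, W.sim, w.odds)
# Unadjusted difference in means, for comparison.
mean(Y.sim[W.sim == 1]) - mean(Y.sim[W.sim == 0])
```

Normalizing by the sum of the weights keeps the control-arm average on the scale of the outcome even when the weights do not sum to the number of control units.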