When I consulted researchers on which statistical analysis to use for their data, a common first step was to think about the distribution of the target variable:
Is it a count, like the number of emails received within an hour? Poisson distribution it is.
Is the target variable a binary category, like whether or not a customer bought a product? Use the Bernoulli distribution.
There’s a long list of such distributions and many ways to modify them: truncated versions, mixtures, and so on. It’s a bit like a game where you can customize your character: an easy way to get stuck before you even get started.
Anyway, based on the presumed distribution, the next step was to fit a GLM or GAM, which ultimately meant maximizing the likelihood of the chosen distribution.
In contrast to classic statistical modeling, supervised machine learning seems pretty dumb. Sophisticated discussion of distribution assumptions? Not on my watch. Just tell me whether it’s regression or classification.
But, of course, there’s more to machine learning. First of all, there are more tasks than just classification and regression, such as survival analysis, multi-label, and cost-sensitive classification. And for a given task like regression, you have options regarding what you optimize for.
That’s because you have to pick a loss function that the model optimizes the predictions against.¹ For regression, for example, you can optimize the squared loss (y - f(x))^2 or the absolute loss |y - f(x)|.
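To make the difference concrete, here is a minimal sketch (plain NumPy on made-up numbers) showing that the constant prediction minimizing the squared loss is the sample mean, while the one minimizing the absolute loss is the sample median:

```python
import numpy as np

# Made-up target values with one outlier
y = np.array([1.0, 2.0, 2.5, 3.0, 10.0])

# Candidate constant predictions on a fine grid
candidates = np.linspace(0, 10, 1001)

# Average squared and absolute loss for each candidate prediction
squared_loss = np.array([np.mean((y - c) ** 2) for c in candidates])
absolute_loss = np.array([np.mean(np.abs(y - c)) for c in candidates])

print(candidates[np.argmin(squared_loss)], np.mean(y))     # 3.7  3.7
print(candidates[np.argmin(absolute_loss)], np.median(y))  # 2.5  2.5
```

The squared loss gets pulled toward the outlier (the mean), the absolute loss does not (the median), which already hints at the distributional story below.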
But there’s more to picking a loss function than just identifying the task and then going through the rummage table of loss functions. There’s a beautiful connection to statistical modeling, bridging the gap between these two modeling mindsets.
Bridge to Statistical Modeling
Certain loss functions have the same optimization goal as maximum likelihood estimation for certain distributions.
For example, minimizing the L2 loss leads to the same optimization problem as maximizing the likelihood of a Gaussian distribution of the target conditional on the features (Y|X).
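A quick sketch of why, assuming n i.i.d. observations and a Gaussian with fixed variance σ², where f(x) plays the role of the conditional mean:

```latex
% Negative log-likelihood of Y|X ~ N(f(x), sigma^2)
\begin{aligned}
-\log L &= -\sum_{i=1}^{n} \log\left( \frac{1}{\sqrt{2\pi\sigma^2}}
            \exp\left( -\frac{(y_i - f(x_i))^2}{2\sigma^2} \right) \right) \\
        &= \frac{n}{2}\log(2\pi\sigma^2)
            + \frac{1}{2\sigma^2} \sum_{i=1}^{n} \left( y_i - f(x_i) \right)^2 .
\end{aligned}
```

The first term and the factor 1/(2σ²) do not depend on f, so minimizing the negative log-likelihood over f is exactly minimizing the sum of squared errors, i.e. the L2 loss.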
There are many examples of distribution assumptions where maximum likelihood estimation corresponds to certain loss functions:
Bernoulli distribution => Binary cross-entropy loss
Multinomial distribution => Categorical cross-entropy loss
Exponential distribution => Exponential loss
Poisson distribution => Poisson loss
Gaussian distribution => L2 loss
Quantiles of arbitrary distribution => Pinball loss
Median of arbitrary distribution => L1 loss (special case of pinball loss)
Whenever you can identify a likelihood function, you can turn it into a loss function by taking the negative log-likelihood.
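As a small numeric check (a sketch on made-up labels and predicted probabilities), the mean negative log-likelihood of a Bernoulli distribution matches scikit-learn’s binary cross-entropy, also known as log loss:

```python
import numpy as np
from scipy.stats import bernoulli
from sklearn.metrics import log_loss

# Made-up binary labels and predicted probabilities for the positive class
y = np.array([1, 0, 1, 1, 0])
p = np.array([0.8, 0.3, 0.6, 0.9, 0.2])

# Mean negative log-likelihood under Bernoulli(p_i) for each observation
nll = -np.mean(bernoulli.logpmf(y, p))

# Binary cross-entropy as implemented in scikit-learn
bce = log_loss(y, p)

print(nll, bce)  # both ~0.284, identical up to floating point
```

The same recipe works for the other rows in the list above: write down the log-likelihood, flip the sign, and you have your loss.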
This link between distributions and loss functions is beautiful for multiple reasons:
It bridges supervised machine learning and classic statistical modeling
For me as a statistician, it has made machine learning more attractive.
The link gives you a principled reason for picking one loss function over another.
Depending on the loss and the connected distribution, you may interpret your prediction in a certain way (a short sketch follows the list):
Pinball loss: the prediction is a quantile
L2 loss: the prediction is the mean of a conditional distribution
Cross-entropy loss: the prediction can be interpreted as a probability for the positive class (at least if your model is well-calibrated)
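Here is that sketch, assuming a recent version of scikit-learn and some synthetic data I made up: the same model class estimates a conditional mean when trained with the squared loss and a conditional 90% quantile when trained with the pinball loss.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

# Synthetic heteroscedastic data: noise grows with x
rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(500, 1))
y = X.ravel() + rng.normal(scale=0.5 + 0.3 * X.ravel())

# Squared loss -> the prediction estimates the conditional mean
mean_model = GradientBoostingRegressor(loss="squared_error").fit(X, y)

# Pinball loss with alpha=0.9 -> the prediction estimates the conditional 90% quantile
q90_model = GradientBoostingRegressor(loss="quantile", alpha=0.9).fit(X, y)

X_new = np.array([[2.0], [8.0]])
print(mean_model.predict(X_new))  # roughly the conditional mean at x = 2 and x = 8
print(q90_model.predict(X_new))   # lies above the mean, as a 90% quantile should
```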
With machine learning, however, you are not tied to picking a loss function that corresponds to a distribution. You are free to mix, adapt, and customize. Using this bridge to distributions can be a great starting point, but the real power lies in the flexibility of designing your custom loss functions.
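As an illustration of that flexibility, here is a toy sketch of my own (not from the article): an asymmetric squared loss that penalizes under-predictions five times more than over-predictions, minimized directly with scipy for a linear model.

```python
import numpy as np
from scipy.optimize import minimize

# Made-up regression data
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.5, -2.0, 0.5]) + rng.normal(scale=0.5, size=200)

def asymmetric_loss(params, X, y, under_weight=5.0):
    """Squared loss that penalizes under-predictions (y > f(x)) more heavily."""
    intercept, coefs = params[0], params[1:]
    residuals = y - (X @ coefs + intercept)
    weights = np.where(residuals > 0, under_weight, 1.0)  # residual > 0 means we predicted too low
    return np.mean(weights * residuals ** 2)

# Fit the linear model by minimizing the custom loss directly
result = minimize(asymmetric_loss, x0=np.zeros(X.shape[1] + 1), args=(X, y))
print(result.x)  # [intercept, coef_1, coef_2, coef_3]; the intercept is shifted upward vs. a plain L2 fit
```

Because this loss no longer corresponds to a standard likelihood, the prediction is no longer a conditional mean; it is deliberately shifted upward to avoid costly under-predictions, which is exactly the point of customizing.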
¹ This depends on the model class and implementation: some come with fixed loss functions (e.g., sklearn.linear_model.LinearRegression is tied to the squared loss), and sometimes the loss is implicit and algorithmically defined, like for decision trees with greedy splitting.
You can also mix, adapt, and customize your distribution.