Dirichlet Distribution Output Layers for Uncertainty in Classification
by Peter de Blanc + ChatGPT Deep Research, 3 days ago
Motivation and Concept of Dirichlet Outputs
In a standard classifier, the softmax output gives a single categorical distribution for each input, but it cannot express uncertainty about that distribution itself. If an input is ambiguous or has inherently stochastic outcomes (multiple possible labels), a softmax often yields a high-entropy (flat) prediction. That alone does not reveal whether the model is uncertain due to lack of knowledge (epistemic uncertainty) or because the input truly has a mixed label distribution (aleatoric uncertainty). Using a Dirichlet distribution as the output layer addresses this limitation by predicting a distribution over categorical probability distributions. In practice, the network outputs the concentration parameters α_1, ..., α_K of a Dirichlet over the K classes, instead of a single point estimate of the class probabilities.
The shape of the Dirichlet encodes confidence. For a “confident” input, the Dirichlet is concentrated at one corner of the probability simplex (one α_k much larger than the others), indicating one label with near-100% probability. For an input with intrinsic class ambiguity (high aleatoric uncertainty), the Dirichlet may be sharply peaked around the center of the simplex, indicating the model is sure the label distribution is broad, e.g. 50/50 between two classes. For a completely novel or out-of-distribution input, the model can output a nearly flat Dirichlet (all α_k low and equal), expressing maximal ignorance (Predictive Uncertainty Estimation via Prior Networks). In essence, the Dirichlet’s total concentration (α_0, the sum of the α_k) serves as a measure of confidence or “knowledge”: a large α_0 means the prediction is based on strong evidence (low epistemic uncertainty), whereas a small α_0 (e.g. all α_k ≈ 1) indicates the model has little confidence and is essentially predicting a uniform distribution over labels (high epistemic uncertainty). This two-tier output (a “meta-distribution” over label probabilities) naturally allows one to distinguish epistemic from aleatoric uncertainty, something a single softmax layer cannot do ([1802.10501] Predictive Uncertainty Estimation via Prior Networks).
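To make the three regimes concrete, here is a minimal sketch that summarizes a Dirichlet by its mean, its total concentration α_0 (the sum of the α_k), and the spread of each class probability. The two-class α vectors are made-up illustrations, not outputs of any model from the cited papers.

```python
import numpy as np

def dirichlet_summary(alpha):
    """Return (mean, total concentration, per-class variance) of Dir(alpha)."""
    alpha = np.asarray(alpha, dtype=float)
    alpha0 = alpha.sum()
    mean = alpha / alpha0
    # Standard Dirichlet marginal variance: m_k (1 - m_k) / (alpha0 + 1).
    var = mean * (1.0 - mean) / (alpha0 + 1.0)
    return mean, alpha0, var

# Hypothetical two-class outputs for the three regimes described above.
examples = {
    "confident": [99.0, 1.0],
    "ambiguous (aleatoric)": [50.0, 50.0],
    "novel (epistemic)": [1.0, 1.0],
}

for name, alpha in examples.items():
    mean, alpha0, var = dirichlet_summary(alpha)
    print(f"{name}: mean={mean.round(3)}, alpha0={alpha0:.0f}, "
          f"std={np.sqrt(var).round(3)}")
```

The ambiguous and novel cases share the same 50/50 mean, which is all a softmax could report; they differ only in α_0 and hence in how tightly the class probabilities are pinned down.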
Evidential Deep Learning (Dirichlet Evidential Classifier)
One prominent approach is Evidential Deep Learning (EDL), proposed by Sensoy et al. (2018). EDL replaces the softmax with a non-negative activation (e.g. ReLU), treats the resulting outputs as “evidence” for each class, and places a Dirichlet distribution on the class probabilities as the output. In this framework (inspired by subjective logic), the network’s prediction is interpreted as a subjective Dirichlet opinion over the labels, rather than a single categorical distribution. The network is trained with a loss that penalizes incorrect confident predictions and encourages uncertainty when evidence is insufficient. Intuitively, if an input has conflicting or limited evidence for the labels, the model learns to output a Dirichlet with lower concentration (higher uncertainty). During training, each observed label is treated as a single draw from a categorical distribution that is itself drawn from the predicted Dirichlet; the loss function (an expected loss, such as expected cross-entropy or squared error, plus a regularizer) updates the Dirichlet parameters such that they aggregate the evidence from all training instances. This means that if a given input is seen with multiple different labels across the dataset (due to inherent randomness or annotator disagreement), the network can learn a higher-entropy Dirichlet for that input, reflecting the mixed outcomes.
Sensoy et al. report that this evidential approach yields strong uncertainty quantification performance: the model can detect out-of-distribution inputs and flag low-confidence predictions far better than a standard softmax network. In their experiments, an EDL classifier achieved “unprecedented success” at identifying novel inputs and was more robust to adversarial perturbations than baseline methods. Notably, EDL does this without sampling or ensemble methods: it produces uncertainty estimates in one forward pass by outputting the Dirichlet parameters directly. Subsequent research has applied EDL in various domains; for example, in high-energy physics classification (jet identification), EDL has been used to provide a confidence measure (epistemic uncertainty) for each prediction as an alternative to Bayesian ensembles (Evidential Deep Learning for Uncertainty Quantification and Out-of-Distribution Detection in Jet Identification using Deep Neural Networks). EDL’s interpretation of model outputs as “evidence” has made it appealing for safety-critical tasks where the system should “know when it doesn’t know.”
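As a rough sketch of the evidential head, the snippet below maps raw network outputs to Dirichlet parameters (evidence = ReLU of the outputs, α = evidence + 1) and evaluates the expected squared-error loss from Sensoy et al. The function names are my own, and the KL regularizer that the paper adds to discourage evidence for wrong classes is omitted here for brevity.

```python
import numpy as np

def edl_outputs(logits):
    """Map raw outputs of shape (N, K) to Dirichlet parameters and summaries."""
    evidence = np.maximum(logits, 0.0)            # non-negative evidence (ReLU)
    alpha = evidence + 1.0                        # Dirichlet parameters
    strength = alpha.sum(axis=-1, keepdims=True)  # alpha_0
    prob = alpha / strength                       # expected class probabilities
    num_classes = logits.shape[-1]
    uncertainty = num_classes / strength[..., 0]  # subjective-logic mass u = K / alpha_0
    return alpha, prob, uncertainty

def edl_mse_loss(logits, y_onehot):
    """Expected squared error under Dir(alpha); the KL regularizer is omitted."""
    alpha, prob, _ = edl_outputs(logits)
    strength = alpha.sum(axis=-1, keepdims=True)
    err = (y_onehot - prob) ** 2
    var = prob * (1.0 - prob) / (strength + 1.0)
    return (err + var).sum(axis=-1)

# Illustrative inputs: strong evidence for class 0, then no evidence at all.
logits = np.array([[10.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0]])
alpha, prob, u = edl_outputs(logits)
print(u)  # the second row has uncertainty mass 1.0 (no evidence)
```

With zero evidence, α is all ones and the uncertainty mass is exactly 1; as evidence accumulates, the mass shrinks toward 0.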
Dirichlet Prior Networks (Malinin & Gales)
Another influential work is the Dirichlet Prior Network (DPN) framework by Malinin & Gales (2018) (Predictive Uncertainty Estimation via Prior Networks). A DPN explicitly learns a Dirichlet over the class probabilities as a prior, with the goal of separating data uncertainty (aleatoric) from distributional uncertainty (epistemic uncertainty due to out-of-distribution inputs). During training, the network is encouraged to behave in three distinct ways (see Fig. 2 of Malinin & Gales (Predictive Uncertainty Estimation via Prior Networks)): (a) for normal in-distribution inputs with a clear class, output a Dirichlet sharply peaked at the true class (high α_k for that class, indicating a confident prediction); (b) for inputs that are inherently noisy or ambiguous within the training distribution (e.g. overlapping class clusters or inherently stochastic outcomes), output a high-concentration Dirichlet centered on an interior point of the simplex, so that the expected categorical is broad (a known mixture over classes: high aleatoric uncertainty but low epistemic uncertainty, since the model knows it should predict a distribution); and (c) for out-of-distribution inputs, output a flat Dirichlet (all classes equally likely, low concentration), representing maximum epistemic uncertainty or an “unknown unknown.” To train such behaviors, DPNs often use an auxiliary loss or data augmentation: besides the usual training on labeled in-distribution samples (for cases a and b), the network sees examples of unknown inputs (or synthetic out-of-distribution noise) and is trained to output a flat Dirichlet for those (Predictive Uncertainty Estimation via Prior Networks). This effectively gives the model a sense of what “nothing familiar” looks like.
A key insight from Malinin & Gales is that by analyzing the Dirichlet produced by the DPN, one can quantify different types of uncertainty. For example, the mutual information between the label and the categorical distribution drawn from the Dirichlet (i.e. the difference between the entropy of the mean categorical distribution and the expected entropy) provides a measure of epistemic uncertainty, whereas the entropy of the expected categorical distribution reflects total uncertainty (Predictive Uncertainty Estimation via Prior Networks). In their evaluation, Dirichlet Prior Networks outperformed previous approaches on distinguishing in-domain vs. out-of-domain inputs and on misclassification detection tasks ([1802.10501] Predictive Uncertainty Estimation via Prior Networks). For instance, on MNIST and CIFAR-10, a DPN was able to correctly flag virtually all out-of-distribution examples (such as noise or unrelated images) by producing a low-α_0 Dirichlet, something that a conventional softmax and even Bayesian methods struggled with ([1802.10501] Predictive Uncertainty Estimation via Prior Networks). Perhaps most importantly, DPNs demonstrated the ability to separate aleatoric and epistemic uncertainty: e.g. on a synthetic task where each input had two possible class outcomes with equal probability, the DPN learned to output parameters suggesting an even predictive class mix (capturing the aleatoric part) with high α_0 (indicating the model is confident about this 50/50 split, hence low epistemic uncertainty). By contrast, for a novel input, the DPN would give a similar class mix with very low α_0, indicating “I have no idea which class – it could be either.” This behavior illustrates how Dirichlet outputs facilitate a nuanced understanding of model predictions.
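The decomposition above can be checked numerically. The sketch below estimates the expected entropy by sampling categoricals from the Dirichlet rather than using the closed-form digamma expression, purely to keep it dependency-light. The function names, sample count, and example α vectors are illustrative assumptions.

```python
import numpy as np

def entropy(p, axis=-1):
    """Shannon entropy in nats; clipping avoids log(0)."""
    p = np.clip(p, 1e-12, 1.0)
    return -(p * np.log(p)).sum(axis=axis)

def decompose_uncertainty(alpha, n_samples=20000, seed=0):
    """Monte Carlo estimate of (total, data, epistemic) uncertainty for Dir(alpha)."""
    rng = np.random.default_rng(seed)
    alpha = np.asarray(alpha, dtype=float)
    mean = alpha / alpha.sum()
    total = entropy(mean)                         # entropy of the expected categorical
    samples = rng.dirichlet(alpha, size=n_samples)
    data = entropy(samples).mean()                # expected entropy (aleatoric part)
    return total, data, total - data              # the difference is the mutual information

for name, alpha in {"sure it is 50/50": [50.0, 50.0], "no idea": [1.0, 1.0]}.items():
    total, data, mi = decompose_uncertainty(alpha)
    print(f"{name}: total={total:.3f}, data={data:.3f}, mutual_info={mi:.3f}")
```

Both Dirichlets have the same total uncertainty (ln 2, since the mean is 50/50 in each case), but the mutual information is close to zero for the concentrated one and much larger for the flat one.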
Other Dirichlet-Based Modeling Approaches
The idea of using a Dirichlet output layer has been explored in various other forms. Gast & Roth (2018) introduced a “Lightweight Probabilistic Deep Network” that replaces the softmax with a Dirichlet output by predicting both a mean and a variance for the class probabilities (Lightweight Probabilistic Deep Networks). They pool per-class uncertainty estimates to produce the Dirichlet concentration (scale) along with the class probability mean, and train via maximum likelihood of the observed label under the Dirichlet (Lightweight Probabilistic Deep Networks). One practical trick they use is Laplace smoothing of one-hot labels (e.g. treating a training label as a 0.999/0.001 split rather than a strict 1/0 one-hot) so that the Dirichlet likelihood is well-defined (Lightweight Probabilistic Deep Networks). This approach yielded a classifier that provides its own uncertainty estimates and was tested on CIFAR-10 and MNIST. An important finding was that the Dirichlet-based network’s predicted entropy had a much stronger correlation with true error rates than a standard softmax network’s entropy did (Lightweight Probabilistic Deep Networks). In other words, when the Dirichlet model says it’s uncertain (high predictive entropy), it is far more likely to be wrong on that input, whereas a softmax often produces poorly calibrated confidence scores (Lightweight Probabilistic Deep Networks). The Dirichlet network in Gast & Roth’s work achieved similar (even slightly better) accuracy than a conventional network, but with significantly improved calibration (Lightweight Probabilistic Deep Networks). They also observed improved adversarial robustness: for example, under FGSM adversarial attacks, the Dirichlet-output model’s accuracy degraded much more gracefully compared to a standard softmax model (Lightweight Probabilistic Deep Networks). This robustness is attributed to the model not becoming over-confident on perturbed inputs: it naturally outputs lower α_0 (higher uncertainty) for subtly shifted inputs that it isn’t sure about.
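To see why the label smoothing matters, here is a small sketch of the Dirichlet log-density evaluated at a smoothed one-hot target. It is a generic illustration, not Gast & Roth’s actual implementation; the helper names and the way the ε mass is spread over the other classes are assumptions.

```python
import math

def smooth_one_hot(label, num_classes, eps=1e-3):
    """Move eps of the probability mass from the true class onto the others."""
    off = eps / (num_classes - 1)
    return [1.0 - eps if k == label else off for k in range(num_classes)]

def dirichlet_log_pdf(target, alpha):
    """log Dir(target; alpha) for a target strictly inside the simplex."""
    alpha0 = sum(alpha)
    log_norm = math.lgamma(alpha0) - sum(math.lgamma(a) for a in alpha)
    return log_norm + sum((a - 1.0) * math.log(t) for a, t in zip(alpha, target))

target = smooth_one_hot(0, 2)                    # [0.999, 0.001]
print(dirichlet_log_pdf(target, [20.0, 1.0]))    # concentration on the correct class
print(dirichlet_log_pdf(target, [1.0, 20.0]))    # concentration on the wrong class
```

With a strict one-hot target, the log of a zero entry is undefined (here `math.log(0.0)` would raise a `ValueError`), which is the problem the smoothing sidesteps. Maximizing this log-density with respect to α pushes concentration toward the labeled class.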
Beyond these, researchers have proposed variations like “Posterior Networks” (Charpentier et al., 2020), which also predict Dirichlet parameters but avoid the need for out-of-distribution training data by using density estimates to generate pseudo-counts for the Dirichlet (effectively approximating how many virtual samples of each class the model has seen for a given input) ([PDF] Posterior Network: Uncertainty Estimation without OOD Samples via ...). Another line of work is Ensemble Distribution Distillation, where an ensemble of models (or Monte Carlo dropout samples) is distilled into a single Dirichlet-generating network (Ensemble Distribution Distillation | Request PDF - ResearchGate). Malinin et al. (2019) showed that a Dirichlet prior network can be trained to match the predictive distribution of an ensemble, preserving the ensemble’s uncertainty information in a single model (Ensemble Distribution Distillation | Request PDF - ResearchGate). This is particularly useful in discrete domains where running a large ensemble at inference is infeasible: the distilled DPN mimics the ensemble’s behavior (e.g. broad Dirichlet in regions where ensemble members disagree, sharp Dirichlet where they concur). There have also been applications in areas like federated learning (e.g. modeling personalized uncertainty for each client with Dirichlet outputs ([PDF] Dirichlet-based Uncertainty Quantification for Personalized ... - IJCAI)) and multi-annotator learning (where each data point may have a distribution of labels from different annotators). In the latter case, a Dirichlet output can naturally model the annotator label distribution: the network can be trained on the empirical label frequencies for each input (using a Dirichlet-multinomial likelihood), rather than on a single “ground truth” label. This allows the model to learn a higher uncertainty (more diffuse Dirichlet) for items with annotator disagreement or inherent ambiguity.
Such an approach has been mentioned in the context of partial label learning, where a framework was proposed to use a Dirichlet to represent an instance’s label probabilities rather than committing to one label ([PDF] Label Enhancement via Joint Implicit Representation Clustering) (Papers by Min-Ling Zhang - AIModels.fyi). Overall, the flexibility of Dirichlet outputs has spurred diverse explorations wherever uncertainty quantification is crucial.
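For the multi-annotator setting, the Dirichlet-multinomial likelihood mentioned above has a closed form in terms of log-gamma functions. The sketch below evaluates it for one item’s annotator vote counts; the counts and α vectors are made-up numbers, and in a real model α would come from the network.

```python
import math

def dirichlet_multinomial_log_lik(counts, alpha):
    """Log-probability of per-class label counts under a Dirichlet-multinomial."""
    n = sum(counts)
    alpha0 = sum(alpha)
    # Multinomial coefficient: n! / prod(c_k!)
    log_coef = math.lgamma(n + 1) - sum(math.lgamma(c + 1) for c in counts)
    # Gamma(alpha0) / Gamma(alpha0 + n)
    log_ratio = math.lgamma(alpha0) - math.lgamma(alpha0 + n)
    # prod_k Gamma(alpha_k + c_k) / Gamma(alpha_k)
    log_terms = sum(math.lgamma(a + c) - math.lgamma(a)
                    for a, c in zip(alpha, counts))
    return log_coef + log_ratio + log_terms

votes = [3, 2]  # hypothetical: three annotators chose class 0, two chose class 1
print(dirichlet_multinomial_log_lik(votes, [6.0, 4.0]))   # matches the 60/40 split
print(dirichlet_multinomial_log_lik(votes, [40.0, 1.0]))  # overconfident in class 0
```

Training would minimize the negative of this quantity, so an item with split votes is fit best by a Dirichlet whose mean sits near the observed frequencies rather than at a corner.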
Performance and Implications for Uncertainty Quantification
Dirichlet-based output layers have demonstrated improved uncertainty quantification across multiple metrics and settings. In classification tasks with a moderate number of classes (discrete, low-dimensional output spaces), these models can achieve calibration and detection performance on par with or exceeding Bayesian neural networks or ensembles, but at a fraction of the computational cost. For example, on CIFAR-10, a Dirichlet-output network trained with Gast & Roth’s method had similar accuracy to a standard network, but was much better calibrated (its predictive entropy correlated strongly with actual misclassification rates) (Lightweight Probabilistic Deep Networks). Sensoy et al. showed that their evidential net could detect out-of-distribution inputs with near-perfect accuracy (by thresholding the Dirichlet concentration); this was a notable improvement over approaches like plain softmax confidence or MC dropout. Malinin & Gales reported that on a rotated-MNIST task (where inputs are from a new distribution), the Prior Network reliably identified high distributional uncertainty (flat Dirichlet) and on an ambiguous synthetic task it correctly expressed high data uncertainty for in-domain ambiguity (Predictive Uncertainty Estimation via Prior Networks). In terms of aleatoric vs. epistemic separation, these methods provide quantitative measures: one can compute the entropy of the expected predictive distribution (reflecting total uncertainty) and the Dirichlet’s mutual information (reflecting epistemic uncertainty) to see how much uncertainty comes from lack of knowledge. Empirically, this has enabled more nuanced decisions. For example, a medical diagnosis model using a Dirichlet output could say “this scan is likely Class A with 60% and Class B with 40%, and I’m confident this reflects genuine ambiguity (aleatoric), not model ignorance,” or conversely flag a novel case with “uncertain because we’ve not seen something like this (epistemic).”
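A decision rule of this kind might look like the following sketch, which thresholds the total concentration first and the entropy of the mean prediction second. The function name and both thresholds are hypothetical placeholders; in practice they would need to be tuned on validation data.

```python
import numpy as np

def triage(alpha, min_alpha0=10.0, max_entropy=0.3):
    """Label a Dirichlet prediction as epistemic, aleatoric, or confident.

    min_alpha0 and max_entropy are illustrative thresholds, not published values.
    """
    alpha = np.asarray(alpha, dtype=float)
    alpha0 = alpha.sum()
    mean = alpha / alpha0
    ent = -(mean * np.log(np.clip(mean, 1e-12, 1.0))).sum()
    if alpha0 < min_alpha0:
        return "epistemic: input looks unfamiliar"
    if ent > max_entropy:
        return "aleatoric: classes are genuinely mixed"
    return "confident"

print(triage([60.0, 40.0]))  # the 60/40 scan from the example above
print(triage([1.0, 1.0]))    # a novel case with almost no evidence
print(triage([99.0, 1.0]))   # a clear-cut case
```

Checking concentration before entropy matters: a flat Dirichlet also has a high-entropy mean, so the order of the tests is what keeps the two kinds of uncertainty apart.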
It is worth noting that Dirichlet-output networks often need careful training to avoid degenerate solutions. Regularization is typically used to prevent trivial Dirichlet predictions (e.g. always outputting a flat Dirichlet); the evidential loss functions usually include terms that encourage the epistemic uncertainty to diminish as more data evidence is gathered for a given class. Recent analyses have identified some challenges: for instance, Bengs et al. (2021) pointed out that certain EDL training objectives can lead to residual epistemic uncertainty even in the infinite data limit, meaning the model might not fully “trust itself” even when an input has been seen many times (Improved Evidential Deep Learning via a Mixture of Dirichlet Distributions). In response, Ryu et al. (2024) proposed an improved approach using a mixture of Dirichlet distributions to ensure the learned epistemic uncertainty vanishes appropriately with sufficient data (Improved Evidential Deep Learning via a Mixture of Dirichlet Distributions). Such refinements indicate that the field is actively addressing the theoretical consistency of Dirichlet-based uncertainty models. Another practical consideration is scalability: for very large label spaces (e.g. language models with tens of thousands of tokens), outputting a Dirichlet with one concentration parameter per token is possible but computationally heavier. Thus, most success with Dirichlet output layers has been reported in tasks with a manageable number of classes (vision, small-scale NLP, etc.), or by focusing on subsets of the output space in structured prediction.
In summary, using a Dirichlet distribution as the output of a neural classifier provides a principled way to aggregate multiple label instances per input and to decompose predictive uncertainty into its components. Across numerous studies and domains, this approach has shown strong performance in uncertainty quantification: models are better calibrated (Lightweight Probabilistic Deep Networks), more robust to out-of-distribution and adversarial inputs (Lightweight Probabilistic Deep Networks), and capable of expressing when an input/output relationship is intrinsically stochastic. These benefits come with little to no loss in accuracy (and sometimes even gains, due to the regularization effects of the Dirichlet-based loss (Lightweight Probabilistic Deep Networks)). As research continues, Dirichlet output layers are finding broader applications, providing an important tool for developing trustworthy AI systems that know what they know, and what they don’t.
Sources: Recent papers exploring Dirichlet outputs include Sensoy et al. (2018), Malinin & Gales (2018) (Predictive Uncertainty Estimation via Prior Networks), Gast & Roth (2018) (Lightweight Probabilistic Deep Networks), Charpentier et al. (2020) ([PDF] Posterior Network: Uncertainty Estimation without OOD Samples via ...), and Ryu et al. (2024) (Improved Evidential Deep Learning via a Mixture of Dirichlet Distributions), among others, with applications ranging from image recognition to physics and beyond.
1 Comment
Peter de Blanc, 3 days ago
For Monte Carlo Tree Search, I think this could be useful for estimating how deeply to search a position. Higher meta-uncertainty -> more search.
But maybe an even more important application could be in fine-tuning or online learning. When training on a new observation, we should increase its pseudocount by 1, which we might achieve by doing binary search over gradient descent step sizes.