Simplex Clustering via sBeta with Applications to Online Adjustment of Black-Box Predictions
Abstract
We explore clustering the softmax predictions of deep neural networks and introduce a novel probabilistic clustering method, referred to as k-sBetas. In the general context of clustering discrete distributions, the existing methods focused on exploring distortion measures tailored to simplex data, such as the KL divergence, as alternatives to the standard Euclidean distance. We provide a general maximum a posteriori (MAP) perspective of clustering distributions, emphasizing that the statistical models underlying the existing distortion-based methods may not be descriptive enough. Instead, we optimize a mixed-variable objective measuring data conformity within each cluster to the introduced density function, whose parameters are constrained and estimated jointly with binary assignment variables. Our versatile formulation approximates various parametric densities for modeling simplex data and enables the control of the cluster-balance bias. This yields highly competitive performances for the unsupervised adjustment of black-box model predictions in various scenarios. Our code and comparisons with the existing simplex-clustering approaches and our introduced softmax-prediction benchmarks are publicly available: https://github.com/fchiaroni/Clustering_Softmax_Predictions.
Index Terms:
Probability simplex clustering, softmax predictions, deep black-box models, pre-trained, unsupervised adaptation.
1 Introduction
Over the last decade, deep neural networks have continuously gained wide interest for the semantic analysis of real-world data. However, under real-world conditions, potential shifts in the feature or class distributions may affect the prediction performances of the pre-trained source models. To address this issue, several recent adaptation studies investigated fine-tuning all or a part of the model’s parameters using standard gradient-descent and back-propagation procedures [1, 2, 3]. However, in a breadth of practical applications, e.g., real-time predictions, it might be cumbersome to perform multiple forward and backward passes. This is particularly true when dealing with large (and continuously growing) pre-trained networks, such as the recently emerging foundational vision-language models [4]. In such scenarios, even fine-tuning only parts of the pre-trained model, such as the parameters of the normalization layers [1], might be computationally intensive. In addition, the existing adaptation techniques assume knowledge of the source model and its training procedures. Such assumptions may not always hold in practice due to data-privacy constraints [5] and the fact that many large-scale pre-trained models are only available through APIs [6], i.e., their pre-trained weights are not shared (this is the case, for instance, of large-scale foundation models in NLP such as the GPT family, Anthropic’s Claude or Google’s PaLM). Therefore, model-agnostic solutions, which enable computationally efficient adaptation of black-box pre-trained models, would be of significant interest in practice.
In this work, we propose to cluster the softmax predictions of deep neural networks. Solely based on the model’s probabilistic outputs, this strategy yields computationally efficient adaptation while preserving the privacy of both the model and data. Fig. 1 depicts an application example for road segmentation under domain shifts.
Figure 1: Road segmentation under a domain shift: (a) argmax, (b) k-means, (c) k-sBetas (ours), (d) ground truth.
Clustering the softmax output predictions of deep networks can be formulated as a clustering problem on the probability simplex domain, i.e., clustering discrete distributions. Probability simplex clustering is not a new problem and has been the subject of several early studies in other application domains such as text analysis [7, 8, 9, 10]. However, in the context of deep learning, clustering methods have been typically applied on deep feature maps, as in self-training frameworks [11]. This requires access to the internal feature maps of the source models, thereby violating the black-box assumption. In contrast, simplex-clustering-based adjustment of the probabilistic outputs of a deep pre-trained model complies with the black-box assumption but has remained largely under-explored, to the best of our knowledge.
Figure 2: Panels (a) to (d): histograms of the marginal (per-coordinate) distributions of real softmax predictions, overlaid with the parametric density functions estimated by the compared clustering methods.
The simplex clustering literature is often based on optimizing distortion-based objectives. The goal is to minimize, within each cluster, some distortion measure evaluating the discrepancy between a cluster representative and each sample within the cluster. Besides standard objectives like k-means, which corresponds to an $L_2$ distortion, several simplex-clustering methods motivated and used distortion measures that are specific to simplex data. This includes, for instance, the Kullback-Leibler (KL) divergence [10, 9] and the Hilbert geometry distance [12]. In this work, we explore a general maximum a posteriori (MAP) perspective of clustering discrete distributions. Using the Gibbs model, we emphasize that the density functions underlying current distortion-based objectives (Table I) may not properly approximate the empirical marginal distributions of real-world softmax predictions.
Fig. 2 illustrates our empirical observations through histograms representing the marginal (per-coordinate) distributions of real softmax predictions. These predictions were obtained using a black-box source model applied to data from an unseen target domain. We juxtapose these empirical histograms with parametric density functions estimated on the same predictions via various clustering methods. For instance, one of the curves represents the parametric Gibbs model corresponding to the Euclidean distance used in k-means. A close match between the shape of a parametric density function and that of the empirical histogram indicates a good fit of the simplex predictions, whereas a significant disparity implies that the assumed parametric density function may not effectively model these predictions. For example, in Fig. 2 (c), the modes of three of the parametric functions do not correspond to that of the empirical distribution depicted with the orange histogram. Furthermore, Fig. 2 (d) depicts a particular scenario with an approximately uniform empirical distribution. In such cases, it is crucial for clustering models to generate unimodal density functions in order to preserve the assumption that the distribution of each class is unimodal. This avoids erroneously fitting multiple unimodal distributions to the data within a single cluster. As depicted in Fig. 2 (d), the standard $\mathtt{Beta}$ density may generate a bimodal distribution, thereby violating the unimodality assumption.
Contributions. Driven by the above observations, we introduce a novel probabilistic clustering objective integrating a generalization of the $\mathtt{Beta}$ density, which we coin $\mathtt{sBeta}$. We derive several properties of $\mathtt{sBeta}$, which enable us to impose different constraints on our clustering model, referred to as k-sBetas, enforcing uni-modality of the densities within each cluster while avoiding degenerate solutions. We proceed with a block-coordinate descent approach, alternating optimization w.r.t assignment variables, and inner Newton iterations for solving the non-linear optimality conditions w.r.t the sBeta parameters. Furthermore, using the density moments, we derive a closed-form alternative to Newton iterations for parameter estimation. Our versatile formulation approximates various parametric densities for modeling simplex data, including highly peaked distributions at the vertices of the simplex, as observed empirically in the case of deep-network predictions. It also enables the control of the cluster balance. We report comprehensive experiments, comparisons, and ablation studies, which show highly competitive performances of our simplex clustering for unsupervised adjustment of black-box predictions in a variety of scenarios. Fig. 2 illustrates the ability of our method to consistently fit the empirical marginal distributions of real-world softmax predictions. For instance, Fig. 2 (c) shows that the density peak of $\mathtt{sBeta}$ (orange curve) matches the highest bar of the histogram, enabling a more accurate fitting than the other methods. In addition, unlike the standard $\mathtt{Beta}$ density, our approach deliberately avoids fitting bimodal distributions (Fig. 2 (d)), thereby preventing degenerate solutions. This refinement not only ensures a more accurate representation of real data, but also maintains the assumption that the distribution of each class is unimodal.
To reproduce our comparative experiments, we have made the code publicly available, including the proposed k-sBetas model, the explored state-of-the-art approaches, as well as the introduced softmax prediction benchmarks (https://github.com/fchiaroni/Clustering_Softmax_Predictions).
2 Problem formulation
We start by introducing the basic notations, which will be used throughout the article:
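- $\Delta^{D-1} = \{\mathbf{x} \in [0,1]^D \mid \mathbf{x}^\top \mathbf{1} = 1\}$ stands for the ($D-1$)-probability simplex.
- $\mathcal{X} = \{\mathbf{x}_i\}_{i=1}^{N} \sim X$, with $\mathcal{X}$ a given data set and $X$ denoting the random simplex vector in $\Delta^{D-1}$.
- $\mathcal{X}_k = \{\mathbf{x}_i \in \mathcal{X} \mid u_{i,k} = 1\}$ denotes cluster $k$, where $\mathbf{u}_i = (u_{i,k})_{1 \le k \le K} \in \Delta^{K-1} \cap \{0,1\}^K$ is a latent binary vector assigning point $\mathbf{x}_i$ to cluster $k$: $u_{i,k} = 1$ if $\mathbf{x}_i$ belongs to cluster $k$ and $u_{i,k} = 0$ otherwise. Let $\mathbf{U} \in \{0,1\}^{K \times N}$ denote the latent assignment matrix whose column vectors are given by $\mathbf{u}_i$. Note that $\mathcal{X} = \bigcup_{k=1}^{K} \mathcal{X}_k$.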
Clustering consists of partitioning a given set $\mathcal{X}$ into $K$ different subsets $\mathcal{X}_k$ referred to as clusters. In general, this is done by optimizing some objective function with respect to the latent assignment variables $\mathbf{U}$. Often, the basic assumption underlying objective functions for clustering is that data points belonging to the same cluster should be relatively close to each other, according to some pre-defined notion of closeness (e.g., via some distance or distortion measures).
In this study, we tackle the clustering of probability simplex data points (or distributions), with a particular focus on the output predictions from deep-learning models. Softmax (also referred to as the normalized exponential function, and defined as $\sigma(\mathbf{z})_j = \exp(z_j) / \sum_{l=1}^{D} \exp(z_l)$, with $\mathbf{z}$ denoting a vector of logits) is commonly used as the last activation function of neural-network classifiers to produce such output probability distributions. Each resulting softmax data point is a $D$-dimensional probability vector of continuous random variables, which are in $[0,1]$ and sum to one. Thus, the target clustering challenge is defined on probability simplex domain $\Delta^{D-1}$. Fig. 3 depicts mixtures of three class distributions on $\Delta^{2}$, simulating the impact of a domain shift on the softmax predictions, as well as the benefits of simplex clustering over the standard argmax prediction in such a scenario. Specifically, Fig. 3 (a) shows the softmax predictions from a model trained on and applied to the same source domain. In this case, the standard argmax function, commonly used in deep networks, effectively separates the three classes. In contrast, Fig. 3 (b) reveals that, in the case of a domain shift, the argmax function may lead to erroneous predictions. Fig. 3 (c) illustrates how a simplex clustering successfully maintains class cohesion in this case. This synthetic example points to the potential of simplex clustering in dealing with domain-shift challenges while operating solely on the outputs (i.e., in a black-box setting).
Figure 3: Mixtures of three class distributions on the simplex: (a) argmax or probability simplex clustering (PSC), no shift; (b) argmax, shift; (c) PSC, shift.
3 Related work
This section presents related work in the general context of clustering, focusing on the probability simplex domain.
3.1 Distortion-based clustering objectives
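The most widely used form of clustering objectives is based on some distortion measures, for instance, a distance in general-purpose clustering or a divergence measure in the case of probability simplex data. This amounts to minimizing w.r.t both assignment variables $\mathbf{U}$ and cluster representatives $\boldsymbol{\Theta} = (\boldsymbol{\theta}_k)_{1 \le k \le K}$ a mixed-variable function of the following general form:
$$\mathcal{L}_{\text{dist}}(\mathbf{U}, \boldsymbol{\Theta}) = \sum_{k=1}^{K} \sum_{i=1}^{N} u_{i,k} \, \|\mathbf{x}_i - \boldsymbol{\theta}_k\|_d \tag{1}$$
The goal is to minimize, within each cluster $\mathcal{X}_k$, some distortion measure $\|\cdot\|_d$ evaluating the discrepancy between cluster representative $\boldsymbol{\theta}_k$ and each point $\mathbf{x}_i$ belonging to the cluster. When $\|\cdot\|_d$ is the Euclidean distance, general form (1) becomes the standard and widely used k-means objective [13]. In general, optimizing distortion-based objective (1) is NP-hard (for example, a proof of the NP-hardness of the standard k-means objective can be found in [14]). One standard iterative solution to tackle the problem in (1) is to follow a block-coordinate descent approach, which alternates two steps: one optimizes the objective w.r.t the cluster prototypes and the other w.r.t the assignment variables:
- $\mathbf{U}$-update: $u_{i,k} = 1$ if $k = \arg\min_{k'} \|\mathbf{x}_i - \boldsymbol{\theta}_{k'}\|_d$, and $u_{i,k} = 0$ otherwise.
- $\boldsymbol{\Theta}$-update: Find $\arg\min_{\boldsymbol{\Theta}} \mathcal{L}_{\text{dist}}(\mathbf{U}, \boldsymbol{\Theta})$.
In the specific case of k-means (i.e., using the $L_2$ distance as distortion measure), optimization w.r.t the parameters in the $\boldsymbol{\Theta}$-update step yields closed-form solutions, which correspond to the means of features within the clusters:
$$\boldsymbol{\theta}_k = \frac{\sum_{i=1}^{N} u_{i,k} \, \mathbf{x}_i}{\sum_{i=1}^{N} u_{i,k}}$$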
One way to generalize k-means is to replace the Euclidean distance with other distortion measures. In this case, the optimal value of the parameters may no longer correspond to cluster means. For instance, when using the $L_1$ distance as distortion measure in (1), the optimal $\boldsymbol{\theta}_k$ corresponds to the cluster median, and the ensuing algorithm is the well-known k-median clustering [15]. For exponential distortions, the optimal parameters may correspond to cluster modes [16, 17]. Such exponential distortions enable the model to be more robust to outliers, but this comes at the price of additional computations (inner iterations), as cluster modes cannot be obtained in closed form.
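The alternating scheme described above can be sketched in a few lines. This is only an illustrative sketch, not the authors' implementation: the function names are hypothetical, and the default distortion is assumed to be the squared Euclidean distance, for which the cluster mean is the closed-form prototype. Swapping in another distortion and prototype update (e.g., the $L_1$ distance with a coordinate-wise median) gives the other variants mentioned in the text.

```python
def squared_euclidean(x, theta):
    """Squared L2 distortion between a point and a prototype."""
    return sum((a - b) ** 2 for a, b in zip(x, theta))


def mean_prototype(members):
    """Closed-form prototype for the squared L2 distortion: the cluster mean."""
    return [sum(column) / len(members) for column in zip(*members)]


def block_coordinate_descent(points, prototypes, distortion=squared_euclidean,
                             update=mean_prototype, n_iter=20):
    """Alternate assignment (U) and prototype (Theta) updates.

    Returns hard labels (one cluster index per point) and the final prototypes.
    """
    prototypes = [list(p) for p in prototypes]
    labels = []
    for _ in range(n_iter):
        # U-update: assign each point to the prototype with the lowest distortion.
        labels = [
            min(range(len(prototypes)), key=lambda k: distortion(x, prototypes[k]))
            for x in points
        ]
        # Theta-update: re-estimate each prototype from its current members.
        for k in range(len(prototypes)):
            members = [x for x, label in zip(points, labels) if label == k]
            if members:  # an empty cluster keeps its previous prototype
                prototypes[k] = update(members)
    return labels, prototypes
```

A fixed number of iterations is used for brevity; a practical version would stop once the labels no longer change.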
3.2 Distortion measures for simplex data
In our case, data points are probability vectors within simplex domain $\Delta^{D-1}$, e.g. the softmax predictions of deep networks. These points are $D$-dimensional vectors of continuous random variables bounded between $0$ and $1$, and summing to one. To our knowledge, the simplex clustering literature is often based on distortion objectives of the general form in (1). Besides standard objectives like k-means, several simplex-clustering works motivated and used distortion measures that are specific to simplex data. This includes the Kullback-Leibler (KL) divergence [10, 9] and Hilbert geometry distance [12].
Information-theoretic k-means. The works in [10, 9] discussed KL k-means, a distortion-based clustering tailored to simplex data and whose objective fits the general form in (1). KL k-means uses the Kullback-Leibler (KL) divergence as a distortion measure instead of the Euclidean distance in the standard k-means:
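$$\|\mathbf{x}_i - \boldsymbol{\theta}_k\|_{\text{KL}} = \text{KL}(\mathbf{x}_i \,\|\, \boldsymbol{\theta}_k) = \sum_{n=1}^{D} x_{i,n} \log \frac{x_{i,n}}{\theta_{k,n}} \tag{2}$$
where $x_{i,n}$ and $\theta_{k,n}$ stand, respectively, for the $n$-th component of simplex vectors $\mathbf{x}_i$ and $\boldsymbol{\theta}_k$. $\text{KL}$ is a Bregman divergence [8] measuring the dissimilarity between two distributions $\mathbf{x}_i$ and $\boldsymbol{\theta}_k$. Despite being asymmetric [10], the machine learning community widely uses the KL divergence in a breadth of problems where one has to deal with probability simplex vectors [18, 19, 20, 21].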
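As a small illustration of the KL distortion and of its asymmetry, the following sketch computes $\text{KL}(\mathbf{x} \,\|\, \boldsymbol{\theta}) = \sum_n x_n \log(x_n / \theta_n)$ for two simplex vectors. The function name is hypothetical, and the handling of zero coordinates (the convention $0 \log 0 = 0$, and an infinite value when the prototype has a zero where the point does not) is an assumption of this sketch rather than something specified in the text.

```python
import math


def kl_divergence(x, theta):
    """KL(x || theta) between two probability vectors of equal length."""
    total = 0.0
    for a, b in zip(x, theta):
        if a == 0.0:
            continue  # convention: 0 * log(0 / b) = 0
        if b == 0.0:
            return math.inf  # x puts mass where theta has none
        total += a * math.log(a / b)
    return total


# Asymmetry: swapping the arguments generally changes the value.
p, q = [0.9, 0.1], [0.5, 0.5]
forward, backward = kl_divergence(p, q), kl_divergence(q, p)
```

Because the divergence is used with the data point as first argument and the prototype as second, a prototype sitting exactly on a vertex yields an infinite distortion for any point that is not on the same face.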
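Hilbert Simplex Clustering (HSC). More recently, the study in [12] investigated the Hilbert Geometry (HG) distortion with minimum enclosing ball (MEB) centroids [22, 23] for clustering probability simplex vectors. For a given cluster set $\mathcal{X}_k$, the MEB center represents the midpoint of the two farthest points within the set. Given two points $\mathbf{x}_i$ and $\boldsymbol{\theta}_k$ within simplex domain $\Delta^{D-1}$, let $l(t) = (1-t)\,\mathbf{x}_i + t\,\boldsymbol{\theta}_k$, $t \in \mathbb{R}$, denote the line passing through points $\mathbf{x}_i$ and $\boldsymbol{\theta}_k$, with $l(t_0)$ and $l(t_1)$, $t_0 \le 0$ and $t_1 \ge 1$, the two intersection points of this line with the simplex domain boundary $\partial\Delta^{D-1}$. Then, HG simplex distortion is given by:
$$\|\mathbf{x}_i - \boldsymbol{\theta}_k\|_{\text{HG}} = \begin{cases} \left| \log \dfrac{(1 - t_0)\, t_1}{(-t_0)\,(t_1 - 1)} \right| & \text{if } \mathbf{x}_i \neq \boldsymbol{\theta}_k, \\ 0 & \text{otherwise.} \end{cases} \tag{3}$$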
Note that if one deals with centroids close or equal to the vertices of the simplex, then HG distortion would inconsistently output large or infinite distance values near such centroids.
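The behaviour near the vertices can be checked numerically. The sketch below relies on the known closed form of the Hilbert distance on the open probability simplex, $\log\big(\max_n (p_n/q_n) / \min_n (p_n/q_n)\big)$, rather than on the line-intersection construction; the function name is hypothetical, and treating any point with a zero coordinate as infinitely far away is a simplifying assumption of this sketch.

```python
import math


def hilbert_simplex_distance(p, q):
    """Hilbert distance between two points of the open probability simplex."""
    if list(p) == list(q):
        return 0.0
    if min(p) <= 0.0 or min(q) <= 0.0:
        # Simplification: boundary points are treated as infinitely far away.
        return math.inf
    ratios = [a / b for a, b in zip(p, q)]
    return math.log(max(ratios) / min(ratios))


center = [1 / 3, 1 / 3, 1 / 3]
# The distance grows without bound as the second point approaches a vertex.
d_near = hilbert_simplex_distance(center, [0.98, 0.01, 0.01])
d_nearer = hilbert_simplex_distance(center, [0.998, 0.001, 0.001])
d_vertex = hilbert_simplex_distance(center, [1.0, 0.0, 0.0])
```

Since softmax predictions of confident models concentrate near the vertices, centroids in that region produce very large distortions for almost every sample.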
3.3 From distortions to probabilistic clustering
Beyond distortions, a more versatile approach is to minimize the following well-known maximum a posteriori (MAP) generalization of K-means [24, 25, 26], coined by [26] as probabilistic K-means:
$$\mathcal{L}_{\text{prob}}(\mathbf{U}, \boldsymbol{\pi}, \boldsymbol{\Theta}) = -\sum_{k=1}^{K} \sum_{i=1}^{N} u_{i,k} \log\big(\pi_k \, f(\mathbf{x}_i; \boldsymbol{\theta}_k)\big) \tag{4}$$
where $f(\cdot\,; \boldsymbol{\theta}_k)$ is a parametric density function modeling the likelihood probabilities of the samples within cluster $k$, and $\pi_k$ the prior probability of cluster $k$, with $\boldsymbol{\pi} = (\pi_k)_{1 \le k \le K} \in \Delta^{K-1}$. Hence, instead of minimizing a distortion in (1), we maximize a more general quantity measuring the posterior probability of cluster $k$ given sample $\mathbf{x}_i$ and parameters $\boldsymbol{\theta}_k$. It is easy to see that, for any choice of distortion measure, the objective in (1) corresponds, up to a constant, to a particular case of probabilistic K-means in (4), with $\pi_k = 1/K, \forall k$, and parametric density $f$ chosen to be the Gibbs distribution:
$$f(\mathbf{x}_i; \boldsymbol{\theta}_k) \propto \exp\big(-\|\mathbf{x}_i - \boldsymbol{\theta}_k\|_d\big) \tag{5}$$
For instance, the K-means objective corresponds, up to additive and multiplicative constants, to choosing likelihood probability $f(\mathbf{x}_i; \boldsymbol{\theta}_k)$ to be a Gaussian density with mean equal to $\boldsymbol{\theta}_k$, identity covariance matrix and prior probabilities verifying $\pi_k = 1/K, \forall k$. This uncovers hidden assumptions in K-means: The samples are assumed to follow a Gaussian distribution within each cluster, and the clusters are assumed to be balanced (in fact, it is well known that K-means has a strong bias towards balanced partitions [24]). Therefore, the general probabilistic K-means formulation in (4) is more versatile, as it enables one to make more appropriate assumptions about the parametric statistical models of the samples within the clusters (e.g., via prior knowledge or validation data), and to manage the cluster-balance bias [24].
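The link between distortion-based and probabilistic assignments can be illustrated with a short sketch. It assumes a Gibbs likelihood whose normalizing constant does not depend on the cluster, so that the MAP rule reduces to maximizing $\log \pi_k$ minus the distortion; the function names and the toy numbers are hypothetical.

```python
import math


def squared_euclidean(x, theta):
    return sum((a - b) ** 2 for a, b in zip(x, theta))


def map_assign(x, prototypes, priors, distortion=squared_euclidean):
    """MAP cluster index under a Gibbs likelihood exp(-distortion).

    Assumes the Gibbs normalizing constant is the same for every cluster,
    so it can be dropped from the comparison.
    """
    scores = [math.log(pi) - distortion(x, theta)
              for pi, theta in zip(priors, prototypes)]
    return max(range(len(scores)), key=scores.__getitem__)


x = [0.52, 0.48]
prototypes = [[1.0, 0.0], [0.0, 1.0]]
# Uniform priors: identical to the nearest-prototype (distortion-only) rule.
balanced = map_assign(x, prototypes, [0.5, 0.5])
# Imbalanced priors: the same point can move to the more probable cluster.
imbalanced = map_assign(x, prototypes, [0.1, 0.9])
```

With uniform priors the prior term is a constant and drops out, which is the balanced-partition bias mentioned above; non-uniform priors shift the decision boundary towards the less frequent cluster.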
3.4 Do current simplex clustering methods model softmax predictions properly?
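Table I: Distortion measures and the corresponding density functions (Gibbs models, Eq. (5)) underlying the compared clustering methods, grouped into metric-based methods (including the HG distortion of Eq. (3)) and probabilistic methods.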
Eq. (5) emphasized that there are implicit density functions underlying distortion-based clustering objectives via Gibbs models; see Table I. Fig. 2 illustrates how the Gibbs models corresponding to the existing simplex clustering algorithms may not properly approximate the empirical marginal density functions of real-world softmax predictions:
Standard k-means: Fast, but not descriptive enough and biased. The well-known k-means clustering falls under the probabilistic clustering framework, as defined in Sec. 3.3. In particular, it instantiates the data likelihoods as Gaussian densities with unit covariance matrices. An important advantage of k-means is that the associated $\boldsymbol{\Theta}$-update step can be analytically solved by simply updating each $\boldsymbol{\theta}_k$ as the mean of samples within cluster $k$, according to the current latent assignments. Such closed-form solutions enable k-means to be among the most efficient clustering algorithms. However, the Gaussian assumption yielding such closed-form solutions might be limiting when dealing with asymmetric data densities, such as exponential densities, with the particularity that the mode and the mean are distinct. As depicted by Fig. 2 (a) and (b), the softmax outputs from actual deep-learning models commonly follow such asymmetric distributions. Note that k-means and other distortion-based objectives correspond to setting $\pi_k = 1/K, \forall k$, in the general probabilistic formulation in Eq. (4). Therefore, they encode the implicit strong bias towards balanced partitions [24], and might be sub-optimal when dealing with imbalanced clusters, as is often the case of realistic deep-network predictions.
Elliptic k-means. Elliptic k-means is a generalization of the standard K-means: It uses the Mahalanobis distance instead of the Euclidean distance, which assumes the data within each cluster follows a Gaussian distribution (i.e. Gibbs model for the Mahalanobis distance) whose covariance matrix is also a variable of the model (as opposed to unit covariance in k-means); see Table I. Fig. 2 (b) and (c) confirm that Gaussian density function is not appropriate for skewed distributions, as is the case of network predictions.
HSC yields poor approximations of highly peaked distributions located at the vertices of the simplex. Fig. 2 shows that HSC, which combines MEB centroids and HG metric, is not relevant to asymmetric distributions whose modes are close to the vertices of the simplex. Specifically, in Fig. 2 (a) and (b), the corresponding Gibbs probability is small near the vertices of the simplex and high in the middle of the simplex. This yields a poor approximation of the target empirical histograms (depicted in orange). In addition, the overall HSC algorithm is computationally demanding (see Table III).
KL k-means. While KL could yield an asymmetric density, it may poorly model highly peaked distributions at the vertices of the simplex (e.g. Fig. 2 (a)), as is the case of softmax predictions.
4 Proposed approach
4.1 Background
Multivariate Dirichlet density function: Commonly used for modeling categorical events, the Dirichlet density operates on $D$-dimensional discrete distributions within the simplex, $\Delta^{D-1}$. Its expression, parameterized by vector $\boldsymbol{\alpha}$, is given by:
$$f_{\mathtt{Dir}}(\mathbf{x}; \boldsymbol{\alpha}) = \frac{1}{B(\boldsymbol{\alpha})} \prod_{n=1}^{D} x_n^{\alpha_n - 1} \tag{6}$$
with $\boldsymbol{\alpha} = (\alpha_n)_{1 \le n \le D}$, $\alpha_n > 0$, and $B(\boldsymbol{\alpha}) = \prod_{n=1}^{D} \Gamma(\alpha_n) / \Gamma\big(\sum_{n=1}^{D} \alpha_n\big)$ the multivariate Beta function expressed using the Gamma function (the Gamma function is $\Gamma(\alpha) = \int_0^\infty t^{\alpha-1} e^{-t} \, dt$ for $\alpha > 0$, and $\Gamma(\alpha) = (\alpha-1)!$ when $\alpha$ is a strictly positive integer). While the Dirichlet density is designed for simplex vectors, its complex nature leads to computationally intensive optimization, often involving iterative and potentially sub-optimal parameter estimation [27]. To address this computational challenge, one could adopt the mean-field approximation principle [28]. This uses a relaxed product density, which is based on Dirichlet’s marginal density, i.e. $\mathtt{Beta}$ with parameters $\alpha_n$ and $\beta_n$:
$$f(\mathbf{x}) = \prod_{n=1}^{D} f_{\mathtt{Beta}}(x_n; \alpha_n, \beta_n), \qquad f_{\mathtt{Beta}}(x_n; \alpha_n, \beta_n) = \frac{x_n^{\alpha_n - 1} (1 - x_n)^{\beta_n - 1}}{B(\alpha_n, \beta_n)} \tag{7}$$
This approximation simplifies parameter estimation by treating each simplex coordinate independently. However, apart from the computational load, $\mathtt{Dir}$ and its marginal $\mathtt{Beta}$ may yield poor approximations of the empirical marginal distributions of real-world softmax predictions. For example, Fig. 2 (a), (b), and (c) show how $\mathtt{Beta}$ has difficulty capturing empirical marginal distributions. In addition, $\mathtt{Dir}$ allows multimodal distributions, as shown with its marginal $\mathtt{Beta}$ in Fig. 2 (d), opposing the common assumption that each semantic class distribution is unimodal. Indeed, discriminative deep learning models are optimized to output probability simplex vertices (i.e. one-hot vectors). For each class, the model is optimized to output softmax predictions following a unimodal distribution. In this context, $\mathtt{Dir}$ may erroneously represent two unimodal distributions as a single bimodal distribution.
In the next section, we propose $\mathtt{sBeta}$, an alternative capable of effectively addressing these limitations.
4.2 Proposed density function: A generalization of Beta constrained to be unimodal
Figure 4: (a) and (b): density functions for different parameter values; (c): concentration around a given mode.
Unlike the densities discussed above, we seek a density function that can approximate highly peaked densities near the vertices of the simplex while satisfying a uni-modality constraint, thereby avoiding degenerate solutions and accounting for the statistical properties of real-world softmax predictions.
4.2.1 Generalization of Beta
We propose the following generalization of $\mathtt{Beta}$, which we refer to as $\mathtt{sBeta}$ (scaled $\mathtt{Beta}$) in the sequel:
$$f_{\mathtt{sBeta}}(x; \alpha, \beta) = \frac{(x + \delta)^{\alpha - 1} (1 + \delta - x)^{\beta - 1}}{B(\alpha, \beta) \, (1 + 2\delta)^{\alpha + \beta - 2}} \tag{8}$$
with $\delta \in \mathbb{R}^{+}$. Clearly, when parameter $\delta$ is set equal to $0$, the generalization in (8) reduces to the $\mathtt{Beta}$ density in (7). Figures 4 (a) and (b) depict $\mathtt{Beta}$ and $\mathtt{sBeta}$ density functions for different values of parameters $\alpha$ and $\beta$, showing the difference between the two densities. For instance, for some of the depicted parameter values, one may observe that $\mathtt{sBeta}$ is characterized by a higher density than $\mathtt{Beta}$ in the proximity of the simplex vertex. As illustrated in Figs. 2 (a), (b) and (c), $\mathtt{sBeta}$ approximations of the empirical distributions (orange histograms) of the softmax predictions are better than those obtained with the $\mathtt{Beta}$ density. Overall, $\mathtt{sBeta}$ can be viewed as a scaled variant of $\mathtt{Beta}$. It is relatively more permissive, enabling it to fit a wider range of unimodal simplex distributions.
In the following, we provide expressions of the moments (first and second central) and mode of $\mathtt{sBeta}$ as functions of density parameters $\alpha$ and $\beta$. As we will see shortly, these properties will enable us to integrate different constraints with our $\mathtt{sBeta}$-based probabilistic clustering objective, enforcing uni-modality of the distribution within each cluster while avoiding degenerate solutions. Furthermore, they will enable us to derive a computationally efficient, moment-based estimation of the density parameters.
4.2.2 Mean and Variance
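Property 1.
The mean of $\mathtt{sBeta}$ could be expressed as a function of the density parameters as follows:
$$E_{sB}[X] = \frac{\alpha}{\alpha + \beta} (1 + 2\delta) - \delta \tag{9}$$
Proof.
The mean $E_{sB}[X]$ of $\mathtt{sBeta}$ could be found by integrating $x \, f_{\mathtt{sBeta}}(x; \alpha, \beta)$. A detailed derivation is provided in Appendix A. ∎
Property 2.
The variance of $\mathtt{sBeta}$ could be expressed as a function of the density parameters as follows:
$$V_{sB}[X] = \frac{\alpha \beta}{(\alpha + \beta)^2 (\alpha + \beta + 1)} (1 + 2\delta)^2 \tag{10}$$
Proof.
We use $V_{sB}[X] = E_{sB}[X^2] - E_{sB}[X]^2$, with $E_{sB}[X^2]$ denoting the second moment of $\mathtt{sBeta}$, which could be found by integrating $x^2 f_{\mathtt{sBeta}}(x; \alpha, \beta)$. The details of the computations of $E_{sB}[X]$ and $E_{sB}[X^2]$ are provided in Appendix A. ∎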
4.2.3 Mode and Concentration
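The mode of a density function $f(x)$ corresponds to the value of $x$ at which $f(x)$ achieves its maximum.
Property 3.
The mode of $\mathtt{sBeta}$ could be expressed as a function of the density parameters as follows:
$$m_{sB} = \frac{\alpha - 1 + \delta (\alpha - \beta)}{\alpha + \beta - 2} \tag{11}$$
Proof.
The mode of $\mathtt{sBeta}$ can be found by determining the value of $x$ at which the derivative of $f_{\mathtt{sBeta}}$ is equal to $0$. A detailed derivation can be found in Appendix A. ∎
Note that, when $\alpha = \beta$, we have $E_{sB}[X] = m_{sB}$. Otherwise, $E_{sB}[X] \neq m_{sB}$. Thus, $\mathtt{sBeta}$ is asymmetric when $\alpha \neq \beta$, and the variance is consequently inadequate to measure the dispersion around the mode.
Concentration parameter. Motivated by the previous observation, we present the concentration parameter $\lambda$ to measure how much a sample set is condensed around the mode value. The concept of concentration parameter has been previously discussed in [29]. With respect to the mode equation (11), we express $\lambda = \alpha + \beta - 2$ such that we have
$$m_{sB} = \frac{\alpha - 1 + \delta (\alpha - \beta)}{\lambda} \tag{12}$$
Using Eq. (12), we can express $\alpha$ and $\beta$ as functions of the mode and concentration parameter:
$$\alpha = 1 + \lambda \, \frac{m_{sB} + \delta}{1 + 2\delta}, \qquad \beta = 1 + \lambda \, \frac{1 + \delta - m_{sB}}{1 + 2\delta} \tag{13}$$
Fig. 4 (c) shows different concentrations around a given mode when changing $\lambda$.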
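The two parameterizations can be converted into one another with a few lines of code. The sketch below assumes that sBeta behaves as a Beta density rescaled from $[0, 1]$ to $[-\delta, 1+\delta]$, with concentration $\lambda = \alpha + \beta - 2$; under that assumption the mode is $(\alpha - 1 + \delta(\alpha - \beta)) / \lambda$, which reduces to the usual Beta mode for $\delta = 0$. Function names and the value of $\delta$ used in the example are hypothetical.

```python
def mode_and_concentration(alpha, beta, delta):
    """Map (alpha, beta) to (mode, concentration); requires alpha + beta != 2."""
    lam = alpha + beta - 2.0
    mode = (alpha - 1.0 + delta * (alpha - beta)) / lam
    return mode, lam


def alpha_beta_from_mode(mode, lam, delta):
    """Inverse map: recover (alpha, beta) from a mode and a concentration."""
    scale = 1.0 + 2.0 * delta
    alpha = 1.0 + lam * (mode + delta) / scale
    beta = 1.0 + lam * (1.0 + delta - mode) / scale
    return alpha, beta


# Round trip with an illustrative delta.
mode, lam = mode_and_concentration(5.0, 2.0, delta=0.15)
alpha, beta = alpha_beta_from_mode(mode, lam, delta=0.15)
```

Working in the (mode, concentration) parameterization makes it possible to change how peaked a density is without moving its peak.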
4.3 Proposed clustering model: k-sBetas
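Probabilistic clustering. We cast clustering distributions as minimizing the following probabilistic objective (k-sBetas), which measures the conformity of data within each cluster to a multi-variate density function, subject to constraints that enforce uni-modality and discourage degenerate solutions:
$$\mathcal{L}_{\mathtt{sBeta}}(\mathbf{U}, \boldsymbol{\pi}, \boldsymbol{\Theta}) = -\sum_{i=1}^{N} \sum_{k=1}^{K} u_{i,k} \log\big(\pi_k \, f(\mathbf{x}_i; \boldsymbol{\theta}_k)\big) \tag{14}$$
$$\text{s.t.} \quad \tau^{-} \le \lambda_{k,n} \le \tau^{+} \quad \forall k, n \tag{15}$$
where $f$ is a multivariate extension of $\mathtt{sBeta}$:
$$f(\mathbf{x}_i; \boldsymbol{\theta}_k) = \prod_{n=1}^{D} f_{\mathtt{sBeta}}(x_{i,n}; \alpha_{k,n}, \beta_{k,n})$$
and $\lambda_{k,n}$ denotes the concentration parameter of univariate $f_{\mathtt{sBeta}}(\cdot\,; \alpha_{k,n}, \beta_{k,n})$. In our model in (14), $\boldsymbol{\theta}_k = (\boldsymbol{\alpha}_k, \boldsymbol{\beta}_k)$, with $\boldsymbol{\alpha}_k = (\alpha_{k,n})_{1 \le n \le D}$ and $\boldsymbol{\beta}_k = (\beta_{k,n})_{1 \le n \le D}$, and, following the general notation we introduced in Sec. 2, $\mathbf{U}$ denotes binary point-to-cluster assignments.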
4.3.1 Block-coordinate descent optimization
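Our objective in (14) depends on three different sets of variables: $\mathbf{U}$; $\boldsymbol{\pi}$; and sBeta parameters $\boldsymbol{\Theta}$. Therefore, we proceed with a block-coordinate descent approach, alternating three steps. Each step optimizes the objective w.r.t a set of variables while keeping the rest of the variables fixed.
- $\boldsymbol{\Theta}$ updates, with $\mathbf{U}$ and $\boldsymbol{\pi}$ fixed: This section presents two different strategies for updating the parameters:
(a) Solving the necessary conditions for the minimum of $\mathcal{L}_{\mathtt{sBeta}}$ w.r.t $\alpha_{k,n}$ and $\beta_{k,n}$:
Our objective in (14) is convex in each $\alpha_{k,n}$ and $\beta_{k,n}$. The global optima could be obtained by setting the gradient to zero. This could be viewed as a maximum likelihood estimation (MLE) approach. Unfortunately, this yields a non-linear system of equations that cannot be solved in closed form. Therefore, we proceed with inner iterations. In the appendix, we derive the following iterative and alternating updates to solve the non-linear system of equations:
$$\alpha_{k,n}^{(t+1)} = \psi^{-1}\left( \psi\big(\alpha_{k,n}^{(t)} + \beta_{k,n}^{(t)}\big) + \frac{1}{\sum_{i} u_{i,k}} \sum_{i=1}^{N} u_{i,k} \log \frac{x_{i,n} + \delta}{1 + 2\delta} \right) \tag{16}$$
$$\beta_{k,n}^{(t+1)} = \psi^{-1}\left( \psi\big(\alpha_{k,n}^{(t+1)} + \beta_{k,n}^{(t)}\big) + \frac{1}{\sum_{i} u_{i,k}} \sum_{i=1}^{N} u_{i,k} \log \frac{1 + \delta - x_{i,n}}{1 + 2\delta} \right) \tag{17}$$
where $t$ refers to iteration number and $\psi$ corresponds to the digamma function. Note that neither $\psi$ nor $\psi^{-1}$ admit analytic expressions. Instead, we follow [27] to approximate those functions, at the cost of additional Newton iterations. The full derivation of the updates and all required details could be found in the Appendix.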
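(b) Method of moments (MoM):
As a computationally efficient alternative to solving a non-linear system within each outer iteration, we introduce an approximate estimation of sBeta parameters $\alpha_{k,n}$ and $\beta_{k,n}$, which we denote $\tilde{\alpha}_{k,n}$ and $\tilde{\beta}_{k,n}$, as the solutions to the first and second central moment equations, following Prop. 1 and Prop. 2:
$$E_{sB}[X_{k,n}] = \mu_{k,n}, \qquad V_{sB}[X_{k,n}] = v_{k,n} \tag{18}$$
where $\mu_{k,n}$ and $v_{k,n}$ denote, respectively, the empirical means and variances of cluster $k$:
$$\mu_{k,n} = \frac{\sum_{i=1}^{N} u_{i,k} \, x_{i,n}}{\sum_{i=1}^{N} u_{i,k}}, \qquad v_{k,n} = \frac{\sum_{i=1}^{N} u_{i,k} \, (x_{i,n} - \mu_{k,n})^2}{\sum_{i=1}^{N} u_{i,k}} \tag{19}$$
The system of two equations and two unknowns in (18) could be solved efficiently in closed form, which yields the following estimates of parameters:
$$\tilde{\alpha}_{k,n} = \left( \frac{\mu_{k,n}^{\delta} (1 - \mu_{k,n}^{\delta}) (1 + 2\delta)^2}{v_{k,n}} - 1 \right) \mu_{k,n}^{\delta}, \qquad \tilde{\beta}_{k,n} = \left( \frac{\mu_{k,n}^{\delta} (1 - \mu_{k,n}^{\delta}) (1 + 2\delta)^2}{v_{k,n}} - 1 \right) (1 - \mu_{k,n}^{\delta}) \tag{20}$$
with $\mu_{k,n}^{\delta} = \dfrac{\mu_{k,n} + \delta}{1 + 2\delta}$.
- $\mathbf{U}$ updates, with $\boldsymbol{\Theta}$ and $\boldsymbol{\pi}$ fixed: With variables $\boldsymbol{\Theta}$ and $\boldsymbol{\pi}$ fixed, the global optimum of our objective in (14) with respect to assignment variables $\mathbf{U}$, subject to constraints $\mathbf{u}_i \in \Delta^{K-1} \cap \{0,1\}^K$, corresponds to the following closed-form solution for each assignment variable $u_{i,k}$:
$$u_{i,k} = \begin{cases} 1 & \text{if } k = \arg\max_{k'} \, \pi_{k'} \, f(\mathbf{x}_i; \boldsymbol{\theta}_{k'}), \\ 0 & \text{otherwise.} \end{cases} \tag{21}$$
- $\boldsymbol{\pi}$ updates, with $\mathbf{U}$ and $\boldsymbol{\Theta}$ fixed: Solving the Karush–Kuhn–Tucker (KKT) conditions for the minimum of (14) with respect to $\boldsymbol{\pi}$, s.t. $\boldsymbol{\pi} \in \Delta^{K-1}$, yields the following closed-form solution for each $\pi_k$:
$$\pi_k = \frac{\sum_{i=1}^{N} u_{i,k}}{N} \tag{22}$$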
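One outer iteration of the three block updates, using the method-of-moments variant, could look as follows. This is a minimal sketch under stated assumptions, not the released implementation: sBeta is taken to be a Beta density rescaled to $[-\delta, 1+\delta]$ with the normalization written in `log_sbeta`; `DELTA = 0.15` and `var_floor` are assumed illustrative values; the concentration constraints are deliberately left out; and every cluster is assumed to be non-empty.

```python
import math

DELTA = 0.15  # assumed value of the sBeta scale parameter delta


def log_sbeta(x, a, b, d=DELTA):
    """Log-density of the univariate scaled Beta (assumed normalization)."""
    log_norm = (math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
                + (a + b - 2.0) * math.log(1.0 + 2.0 * d))
    return (a - 1.0) * math.log(x + d) + (b - 1.0) * math.log(1.0 + d - x) - log_norm


def mom_params(values, d=DELTA, var_floor=1e-6):
    """Method-of-moments estimate of (alpha, beta) for one coordinate."""
    mu = sum(values) / len(values)
    var = max(sum((v - mu) ** 2 for v in values) / len(values), var_floor)
    mu_d = (mu + d) / (1.0 + 2.0 * d)  # mean mapped back to [0, 1]
    common = mu_d * (1.0 - mu_d) * (1.0 + 2.0 * d) ** 2 / var - 1.0
    return common * mu_d, common * (1.0 - mu_d)


def k_sbetas_step(points, labels, n_clusters):
    """One outer iteration: parameter, prior, then assignment updates."""
    dim = len(points[0])
    params, priors = [], []
    for k in range(n_clusters):
        members = [x for x, label in zip(points, labels) if label == k]
        priors.append(len(members) / len(points))             # pi update
        params.append([mom_params([x[n] for x in members])    # Theta update
                       for n in range(dim)])
    new_labels = []
    for x in points:                                           # U update
        scores = [math.log(priors[k])
                  + sum(log_sbeta(x[n], *params[k][n]) for n in range(dim))
                  for k in range(n_clusters)]
        new_labels.append(max(range(n_clusters), key=scores.__getitem__))
    return new_labels, priors, params
```

Calling `k_sbetas_step` repeatedly, feeding the returned labels back in until they stop changing, gives the overall alternating procedure; the concentration constraints would be applied to the parameters right after the method-of-moments estimate.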
4.3.2 Handling parameter constraints
Our model in (14) integrates two constraints on concentration parameters $\lambda_{k,n}$; one discourages multi-modal solutions, and the other avoids degenerate, highly peaked (Dirac) densities.
Notice that the density is bimodal when $\alpha < 1$ and $\beta < 1$, which corresponds to $\lambda < 0$ (as $\lambda = \alpha + \beta - 2$). Moreover, setting $\alpha = \beta = 1$, which yields $\lambda = 0$, corresponds to the uniform density. Thus, to constrain the densities to be strictly unimodal, we must restrict $\lambda$ to positive values. In practice, we constrain $\lambda$ to be greater than or equal to a fixed, strictly positive threshold $\tau^{-}$.
A high density concentration around the mode corresponds to high values of $\lambda$. Thus, to avoid degenerate, highly peaked (Dirac) densities, we constrain $\lambda$ to stay smaller than or equal to a fixed, strictly positive threshold $\tau^{+}$.
Overall, we enforce $\tau^{-} \le \lambda \le \tau^{+}$ on parameters $\alpha$ and $\beta$ using Eq. (13), as detailed in Algorithm 1. This procedure allows for maintaining the same mode. Fig. 5 (a) shows $\mathtt{sBeta}$ density estimation, following the constraint $\lambda \ge \tau^{-}$, on a bimodal distribution sample. Fig. 5 (b) shows $\mathtt{sBeta}$ estimation, following the constraint $\lambda \le \tau^{+}$, on a Dirac distribution sample.
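A possible way to enforce the two thresholds while keeping the peak location is sketched below. It is not Algorithm 1 itself: it assumes $\lambda = \alpha + \beta - 2$ and a mode of $(\alpha - 1 + \delta(\alpha - \beta))/\lambda$, and it adds two edge-case choices of its own, namely clipping the rescaled stationary point to $[0, 1]$ and falling back to the midpoint when $\lambda = 0$. The threshold values used in the example are placeholders.

```python
def sbeta_mode(alpha, beta, delta):
    """Stationary point of the scaled Beta; requires alpha + beta != 2."""
    return (alpha - 1.0 + delta * (alpha - beta)) / (alpha + beta - 2.0)


def clamp_concentration(alpha, beta, tau_min, tau_max):
    """Project (alpha, beta) so that tau_min <= alpha + beta - 2 <= tau_max.

    The stationary point is preserved whenever it lies inside the support.
    """
    lam = alpha + beta - 2.0
    lam_new = min(max(lam, tau_min), tau_max)
    if lam_new == lam:
        return alpha, beta
    if lam == 0.0:
        y = 0.5  # uniform density has no unique mode: arbitrary fallback
    else:
        # Stationary point expressed in the unscaled [0, 1] coordinate.
        y = min(max((alpha - 1.0) / lam, 0.0), 1.0)
    return 1.0 + lam_new * y, 1.0 + lam_new * (1.0 - y)


# Too peaked: concentration 248 is reduced to the (placeholder) upper bound.
a_hi, b_hi = clamp_concentration(200.0, 50.0, tau_min=1.0, tau_max=100.0)
# Bimodal (alpha, beta < 1): concentration -1 is raised to the lower bound.
a_lo, b_lo = clamp_concentration(0.5, 0.5, tau_min=1.0, tau_max=100.0)
```

In the unscaled coordinate the projection amounts to rescaling $(\alpha - 1, \beta - 1)$ by the ratio of the new to the old concentration, which is why $\delta$ does not appear in `clamp_concentration`.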
Figure 5: Effect of the concentration constraints: (a) avoiding bimodal solutions, (b) avoiding Dirac solutions.
4.3.3 Centroid initialization
Clustering algorithms are known to be sensitive to their parameter initialization. They often converge to a local optimum. To reduce this limitation, the seeding initialization strategy k-means++, proposed in [30] and then thoroughly studied in [31], is broadly used for centroid initialization. However, in the context of the clustering of softmax predictions, we can assume that softmax predictions generated by deep learning models were optimized upstream to be sets of one-hot vectors (a one-hot vector refers both to a semantic class and to a vertex on the probability simplex). Thus, we propose initializing the centroids as vertices of the target probability simplex domain. More specifically, we initialize all $\alpha_{k,n}$ and $\beta_{k,n}$ parameters such that they model exponential densities on $\Delta^{D-1}$ at the start. In other words, each initial mode is set as a vertex among all possible vertices on the probability simplex domain. Beyond improving k-sBetas, this simple initialization strategy consistently improves the scores of every tested clustering method (Table XIV in the Appendix empirically supports this statement).
4.3.4 From cluster-label to class-label assignment
After clustering, we must align the obtained clusters with the target classes. One can align each cluster centroid with the closest one-hot vector by using the argmax function. However, in this case, a one-hot vector corresponding to a given class could be matched to several cluster centroids, thereby assigning them to the same class. To prevent this problem on closed-set challenges, i.e., when the number of clusters is known in advance to match the number of classes, we use the optimal transport Hungarian algorithm [32]. Specifically, we compute the Euclidean distances between all the centroids and one-hot vectors. Then, we apply the Hungarian algorithm to this matrix of distances to find the lowest-cost assignment of a separate class to each cluster.
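The difference between argmax alignment and a one-to-one matching can be seen on a toy example. To stay dependency-free, the sketch below replaces the Hungarian algorithm with a brute-force search over permutations, which returns the same optimal assignment but is only practical for a handful of classes; for realistic sizes, `scipy.optimize.linear_sum_assignment` solves the same problem efficiently. The function name and the centroids are hypothetical, and the number of clusters is assumed to equal the number of classes.

```python
from itertools import permutations


def match_clusters_to_classes(centroids):
    """Return a list `classes` where classes[i] is the class of cluster i.

    Minimizes the total Euclidean distance between centroids and one-hot
    vectors over all one-to-one assignments (brute force, small K only).
    """
    k = len(centroids)

    def cost(centroid, j):
        # Euclidean distance between a centroid and the one-hot vector e_j.
        return sum((v - (1.0 if n == j else 0.0)) ** 2
                   for n, v in enumerate(centroid)) ** 0.5

    best = min(permutations(range(k)),
               key=lambda perm: sum(cost(centroids[i], perm[i]) for i in range(k)))
    return list(best)


centroids = [[0.6, 0.3, 0.1], [0.5, 0.4, 0.1], [0.1, 0.2, 0.7]]
# Argmax sends the first two clusters to the same class.
argmax_classes = [max(range(3), key=c.__getitem__) for c in centroids]
# The one-to-one matching gives each cluster its own class.
matched_classes = match_clusters_to_classes(centroids)
```

Here argmax maps the first two centroids to class 0, whereas the matching assigns the second centroid to the otherwise unused class 1.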
5 Experiments
Throughout this section, we show different applications of the proposed method and compare it with the state-of-the-art clustering methods discussed in Sec. 3. We first validate our implementations on synthetic datasets and on real softmax predictions from unsupervised domain adaptation (UDA) source models (Sec. 5.1). Then, in Sec. 5.2, we demonstrate the usefulness of our simplex-clustering method in the context of transductive zero- and one-shot predictions, using contrastive language-image pre-training (CLIP) [4] as the source model. Finally, we provide an application example for a dense prediction task: real-time UDA for road segmentation (Sec. 5.3). In each of these settings, our method is used as a plug-in on top of the output predictions of a black-box deep learning model. Along with the experiments and comparisons that follow, we also provide ablation studies, which highlight the effect of each component of the proposed framework.
Comparisons. We compare our method to several state-of-the-art clustering methods. Specifically, we include KL k-means [9, 10] and HSC [12], which are specifically designed to deal with probability simplex vectors. We additionally compare our method to standard clustering methods: k-means, GMM [33], k-medians [15], k-medoids [34] and k-modes [16]. Both k-means and KL k-means use the mean as the cluster’s parameter vector, with the former being based on the Euclidean distortion and the latter on the Kullback-Leibler one. Extensively studied in [15], k-medians uses the median as the cluster representative, together with the Manhattan distance. Clustering with k-medoids uses the Euclidean distance while estimating medoid representations of the clusters with the standard PAM algorithm [34]. Finally, k-modes [16] uses a kernel-induced distortion and computes cluster modes via the mean-shift algorithm [35].
k-Dirs. We also implemented a simplex-clustering strategy, which estimates a multivariate density per cluster using the Dirichlet density function. We used the iterative parameter estimation proposed in [27]. This algorithm, which we refer to as k-Dirs, represents a baseline for the artificial datasets built from the Dirichlet distributions.
Proposed k-sBetas. Our clustering algorithm uses as density function and the method of moments (MoM) detailed in Sec. 4.3.1 for parameter estimation. We empirically consolidate this choice in the ablation study in section 5.4. Note that k-Betas corresponds to a non-scaled variant of k-sBetas, i.e., when we set .
Hyper-parameters. Across all the experiments, we set the scaling hyper-parameter used in k-sBetas to the same value . We set and for parameter constraints. The maximum number of clustering iterations is set to for all the clustering algorithms. For each method, we use the parameter initialization detailed in Sec. 4.3.3. This technique is also applied to the methods based on the feature maps, for consistency and fairness: For each probability simplex vertex (i.e., one-hot vector), we select the feature-map point whose softmax prediction most closely aligns with this one-hot vector. We provide comparisons and justifications for our design choices in the ablation studies in Sec. 5.4.
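The vertex initialization of Sec. 4.3.3 and its feature-map counterpart described above can be sketched as follows. This is an illustrative sketch only: the function names are hypothetical, and "most closely aligns" is interpreted here as the smallest Euclidean distance between a softmax prediction and the one-hot vector, which is an assumption.

```python
import numpy as np


def vertex_init(num_classes):
    """One initial centroid per simplex vertex (i.e., per one-hot vector)."""
    return np.eye(num_classes)


def feature_map_init(features, softmax_preds):
    """For each vertex, pick the feature point whose softmax prediction is
    closest (assumed: Euclidean distance) to the corresponding one-hot vector."""
    num_classes = softmax_preds.shape[1]
    vertices = np.eye(num_classes)
    # dists[k, i] = squared distance between vertex k and prediction i
    dists = ((vertices[:, None, :] - softmax_preds[None, :, :]) ** 2).sum(axis=2)
    selected = dists.argmin(axis=1)
    return features[selected]


# Toy data: three samples, two classes, three-dimensional features.
softmax_preds = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
features = np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
simplex_centroids = vertex_init(2)
feature_centroids = feature_map_init(features, softmax_preds)
```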
Reproducibility. We used the scikit-learn implementation of GMM (available at: https://github.com/scikit-learn/scikit-learn/blob/7e1e6d09b/sklearn/mixture/_gaussian_mixture.py) and the implementation provided by the authors for HSC (available at: https://franknielsen.github.io/HSG/). We have implemented all the other clustering algorithms; the code for KL k-means, k-means, k-medians, k-medoids, k-modes, k-Dirs, k-Betas and k-sBetas is available at our repository: https://github.com/fchiaroni/Clustering_Softmax_Predictions.
Evaluation metrics. We use the standard Normalized Mutual Information (NMI) to evaluate the clustering task for each dataset. To evaluate the class-prediction scores, we use the classification accuracy (Acc) for the balanced datasets, or the mean Intersection over Union (mIoU) for the imbalanced ones. The process of aligning the clusters with the class labels, which is necessary for computing Acc and mIoU, is detailed in Sec. 4.3.4.
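For concreteness, the three metrics can all be computed from a confusion (contingency) matrix, as in the NumPy sketch below. The helper names are hypothetical, and the NMI normalization used here (arithmetic mean of the two entropies) is an assumption, since the text does not state which normalization is used.

```python
import numpy as np


def confusion(y_true, y_pred, num_classes):
    """cm[i, j] = number of points with true class i and predicted label j."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def accuracy(cm):
    return np.trace(cm) / cm.sum()


def mean_iou(cm):
    tp = np.diag(cm).astype(float)
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp
    valid = union > 0  # ignore classes that never appear
    return (tp[valid] / union[valid]).mean()


def nmi(cm):
    """NMI with arithmetic-mean normalization (an assumption of this sketch)."""
    p = cm / cm.sum()
    pu, pv = p.sum(axis=1), p.sum(axis=0)
    nz = p > 0
    mi = (p[nz] * np.log(p[nz] / np.outer(pu, pv)[nz])).sum()
    hu = -(pu[pu > 0] * np.log(pu[pu > 0])).sum()
    hv = -(pv[pv > 0] * np.log(pv[pv > 0])).sum()
    return 2.0 * mi / (hu + hv) if (hu + hv) > 0 else 1.0


cm = confusion([0, 0, 1, 1, 2, 2], [0, 0, 1, 1, 2, 1], num_classes=3)
acc, miou = accuracy(cm), mean_iou(cm)
# NMI is invariant to a permutation of the cluster labels:
nmi_permuted = nmi(confusion([0, 0, 1, 1], [1, 1, 0, 0], num_classes=2))
```

Note that NMI does not depend on the cluster-to-class alignment, whereas Acc and mIoU do.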
5.1 Unsupervised Domain Adaptation
5.1.1 Synthetic experiments
The purpose of the following synthetic experiments is to benchmark the different clustering methods on a simple, artificially generated task.
Simu dataset. We generate balanced mixtures composed of three Dirichlet distributions, defined on the 3-dimensional probability simplex, . The corresponding Dirichlet parameters are , , . Each component is biased towards a different vertex of the simplex, thereby simulating the softmax predictions of a deep model for a certain class. Additionally, each component captures a different type of distribution (we provide 2D visualizations of these simulated distributions in Fig. 7 of the Appendix): corresponds to an exponential distribution on , to an off-centered distribution with small variance, and to a more centered distribution, with a wider variance. All the components contribute equally to the final mixture. In total, examples are sampled. We perform 5 random runs, each using new examples.
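A balanced Dirichlet mixture of this kind can be sampled in a few lines of NumPy. The sketch below is only illustrative: the parameter vectors in `alphas` and the sample count are hypothetical placeholders, not the values used for the Simu dataset.

```python
import numpy as np


def sample_dirichlet_mixture(alphas, n_per_component, seed=0):
    """Draw the same number of points from each Dirichlet component."""
    rng = np.random.default_rng(seed)
    points = [rng.dirichlet(alpha, size=n_per_component) for alpha in alphas]
    labels = np.repeat(np.arange(len(alphas)), n_per_component)
    return np.concatenate(points, axis=0), labels


# Hypothetical parameters (NOT the paper's): each component leans towards
# a different vertex of the 3-dimensional probability simplex.
alphas = [(5.0, 1.0, 1.0), (5.0, 25.0, 5.0), (5.0, 5.0, 7.0)]
X, y = sample_dirichlet_mixture(alphas, n_per_component=100, seed=0)
```

Each row of `X` lies on the simplex, and `y` holds the index of the generating component, which plays the role of the ground-truth class.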
Results. The third column of Table II reports the NMI and Acc scores of different approaches on the Simu dataset. As expected, k-Dirs and k-Betas yielded close scores as they both model the marginal densities. k-sBetas yields lower scores due to the scaling factor, . However, it behaves better than all the other clustering methods thanks to a proper statistical modeling of simplex data.
5.1.2 Clustering the softmax predictions of deep models
We now compare the different clustering approaches on real-world distributions, where data points correspond to the softmax predictions of deep-learning models.
Setup. First, we employ the SVHN→MNIST challenge, where a source model is trained on SVHN [36] and applied to the MNIST [37] test set, which is composed of images labelled with 10 different semantic classes. Additionally, we experiment with the more difficult VISDA-C challenge [38], which contains 55388 samples split into 12 different semantic classes. For both SVHN→MNIST and VISDA-C, we use the common network architectures and training procedures detailed in [39].
For these closed-set UDA experiments, we use the optimal-transport Hungarian algorithm to obtain the cluster-to-class label assignment, as detailed in Sec. 4.3.4.
Results. Table II compares the methods on real-world softmax predictions. In the SVHN→MNIST setting, the simplex-tailored methods KL k-means and k-sBetas clearly stand out in terms of both NMI and accuracy, with the proposed k-sBetas achieving the best scores. On the VISDA-C challenge, however, KL k-means yields an accuracy below the argmax baseline, similarly to the other standard clustering methods, and k-Dirs fails to converge. In contrast, k-sBetas outperforms the baseline (56.0 vs. 53.1 in accuracy), achieving the best scores.
Table III reports the running times of all the methods. When using MoM for parameter estimation, k-sBetas has a running time close to, or even lower than, that of the standard GMM approach. Furthermore, the GPU-based implementation of k-sBetas is faster still, which makes it appealing for large-scale datasets such as VISDA-C.
Table II: NMI and Acc scores of the compared methods on Simu, SVHN→MNIST and VISDA-C.

Method | Layer used | Simu (NMI) | SVHN→MNIST (NMI) | SVHN→MNIST (Acc) | VISDA-C (NMI) | VISDA-C (Acc) |
---|---|---|---|---|---|---|
k-means | feature map | - | 67.3 | 74.3 | 30.0 | 22.7 |
k-means | logit | - | 67.1 | 74.5 | 32.7 | 33.1 |
argmax | simplex (Black-Box) | 60.1 | 58.6 | 69.8 | 36.5 | 53.1 |
k-means | simplex (Black-Box) | 76.6 | 58.4 | 68.9 | 37.6 | 47.9 |
GMM | simplex (Black-Box) | 75.8 | 61.9 | 69.2 | 36.3 | 49.4 |
k-medians | simplex (Black-Box) | 76.8 | 58.6 | 68.8 | 35.9 | 40.0 |
k-medoids | simplex (Black-Box) | 60.8 | 58.9 | 71.3 | 36.5 | 46.8 |
k-modes | simplex (Black-Box) | 76.2 | 59.5 | 71.3 | 34.2 | 51.8 |
KL k-means | simplex (Black-Box) | 76.2 | 63.3 | 75.5 | 39.8 | 51.2 |
HSC | simplex (Black-Box) | 9.3 | 59.2 | 68.9 | 28.7 | 18.1 |
k-Dirs | simplex (Black-Box) | 81.3 | 57.7 | 68.8 | fails | fails |
k-Betas | simplex (Black-Box) | 81.1 | 53.3 | 51.5 | 36.7 | 19.8 |
k-sBetas | simplex (Black-Box) | 79.2 | 65.0 | 76.6 | 40.3 | 56.0 |
Table III: Running times of the compared methods.

Method | SVHN→MNIST | VISDA-C |
---|---|---|
 | (, ) | (, ) |
k-means - feature map | 2.36 | 18.01 |
k-means - logit/simplex | 0.06 | 1.27 |
GMM | 0.43 | 10.59 |
k-medians | 5.94 | 48.39 |
k-medoids | 8.83 | 195.00 |
k-modes | 0.08 | 4.71 |
KL k-means | 0.27 | 2.79 |
HSC | 8494.54 | One day |
k-sBetas (MLE-1000) | 64.32 | 107.43 |
k-sBetas (MLE-500) | 34.22 | 58.68 |
k-sBetas (MoM) | 0.48 | 3.61 |
k-sBetas (MoM, GPU-based) | 0.13 | 0.49 |
Feature maps vs. simplex predictions. Interestingly, Table II shows that k-means, when applied to the logit points, yields scores that are similar to or even better than those obtained when it is applied to the feature maps. This suggests that the classifier head, which predicts the logit points, does not deteriorate the semantic information of the bottleneck layer. Table III shows that logit and simplex clustering are computationally less demanding than clustering the feature maps. This is because the dimension of the simplex points equals the number of classes, which is considerably smaller than the dimension of the feature maps. In addition, Table II shows that the proposed simplex-clustering method substantially outperforms the k-means feature-map clustering on the most difficult/realistic dataset, VISDA-C.
It is worth noting that the explored softmax-prediction datasets, SVHN→MNIST and VISDA-C, are challenging in this particular black-box setting. This enables interesting comparisons with model-agnostic clustering methods. For reproducibility and future improvements, we made these black-box softmax predictions publicly available: https://github.com/fchiaroni/Clustering_Softmax_Predictions.
5.2 Transductive zero-shot and one-shot learning
We now demonstrate the general applicability and effectiveness of our method in the practical and realistic few-shot problem, which involves only an unknown fraction of the original set of classes in a small, unlabeled query (test) set [40]. We tackle transductive zero- and one-shot inference using the foundational CLIP model [4] by clustering its zero-shot softmax predictions. CLIP is based on two deep encoders, one for image representation and the other for text inputs. Coupled with projections, this dual structure yields image and text embeddings within the same low-dimensional space. Trained on a large-scale dataset of text-image pairs, CLIP maximizes the cosine similarity between the embedding of an image and its corresponding text description, which makes it well-suited for zero-shot prediction. At inference time, the classification of a given image into one of classes is given by the zero-shot softmax prediction of CLIP: , where logit evaluates the cosine similarity between the image embedding and the representation of a text prompt describing the class, typically ‘’a photo of a [name of class k]” [4].
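The zero-shot softmax prediction described above can be sketched as a softmax over cosine similarities between one image embedding and the class text embeddings. The sketch below uses random vectors in place of real CLIP embeddings; the function name and the `temperature` value are assumptions made for illustration (CLIP scales its logits before the softmax, but the exact scaling used in these experiments is not specified here).

```python
import numpy as np


def zero_shot_softmax(image_emb, text_embs, temperature=0.01):
    """Softmax over cosine similarities between one image and K class prompts.

    `temperature` is a hypothetical scaling chosen for this sketch.
    """
    img = image_emb / np.linalg.norm(image_emb)
    txt = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    logits = txt @ img / temperature   # one scaled cosine similarity per class
    logits = logits - logits.max()     # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum()


rng = np.random.default_rng(0)
text_embs = rng.normal(size=(5, 16))   # stand-in for K = 5 encoded prompts
image_emb = 3.0 * text_embs[2]         # an image perfectly aligned with class 2
probs = zero_shot_softmax(image_emb, text_embs)
```

In the one-shot variant described next, `text_embs` would be replaced by the visual embeddings of the labeled shots, with the rest of the computation unchanged.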
In the one-shot setting, we cluster softmax predictions of the same form, except that, for each class, the text embedding is replaced by the visual embedding of the labeled shot of the class, i.e., the single labeled image of the class.
Experimental Setup: We assessed our method using four distinct architectures of the CLIP visual encoder: ResNet-50, ResNet-101 [41], ViT-B/32, and ViT-B/16 [42]. Our experiments were performed across eight datasets: CIFAR-100 [43], Stanford-Cars [44], FGVC Aircraft [45], Caltech101 [46], Food101 [47], Flowers102 [48], Sun397-100 [49] and ImageNetv2-100 [50]. For Sun397-100 and ImageNetv2-100, we restrict the evaluation to the first 100 classes. For each dataset, we generated 100 query sets, each containing 64 unlabeled examples sampled from a random selection of 2 to 10 classes.
We compare our method with the inductive baseline, i.e., CLIP zero-shot prediction, and with transductive state-of-the-art few-shot learning methods, which leverage the statistics of the unlabeled query set, and operate on the feature maps: BD-CSPN [51], Lap-Shot [52] and TIM [53]. Additional comparisons include k-means, KL k-means, and our k-sBetas method, which we apply directly to the softmax predictions of CLIP.
In the one-shot setting, the initialization prototype vector for each class is the CLIP-encoded image from the labeled support set. In the zero-shot setting, the initialization prototype vector for each class is the CLIP-encoded text prompt. Post-convergence, we directly employ the argmax function for cluster-to-class assignments. This function assigns each estimated cluster to the class whose one-hot vector is nearest to the cluster centroid.
Results. Tables IV and V show the results for the one-shot and zero-shot settings, respectively. A first key observation is the enhanced effectiveness of the probability-simplex methods over those based on the feature maps. In addition, one may observe that the feature-map methods TIM and k-means show similar results in these settings. This observation is consistent with a previous work [54], which showed that, under certain mild conditions, along with common posterior models and parameter regularization, such models become equivalent. Overall, the proposed k-sBetas method demonstrates a clear and consistent improvement over the inductive CLIP baseline and the other transductive few-shot methods across a variety of datasets and network encoders. This indicates its potential as a robust and versatile approach for enhancing the performance of foundation models, like CLIP, in practical applications with limited labeled data and unknown class distributions.
Table IV: One-shot results (NMI and Acc).

Method | Layer used | Network | CIFAR-100 (NMI) | CIFAR-100 (Acc) | Stanford-Cars (NMI) | Stanford-Cars (Acc) | FGVC Aircraft (NMI) | FGVC Aircraft (Acc) | Caltech101 (NMI) | Caltech101 (Acc) | Food101 (NMI) | Food101 (Acc) | Flowers102 (NMI) | Flowers102 (Acc) | Sun397-100 (NMI) | Sun397-100 (Acc) | ImageNetv2-100 (NMI) | ImageNetv2-100 (Acc) |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
CLIP | feature map | RN50 | 54.1 | 17.2 | 61.6 | 26.9 | 55.4 | 15.8 | 77.7 | 64.5 | 58.8 | 29.4 | 74.0 | 58.5 | 65.3 | 43.8 | 63.2 | 31.3 |
BD-CSPN | feature map | RN50 | 54.1 | 13.8 | 60.3 | 19.7 | 56.1 | 13.7 | 70.0 | 51.1 | 57.4 | 22.7 | 70.5 | 46.8 | 60.4 | 32.6 | 61.8 | 25.2 |
Lap-Shot | feature map | RN50 | 54.1 | 13.9 | 60.3 | 19.7 | 56.0 | 13.7 | 70.2 | 51.3 | 57.4 | 22.8 | 70.5 | 46.8 | 60.5 | 32.8 | 61.8 | 25.2 |
TIM | feature map | RN50 | 56.3 | 14.3 | 59.0 | 14.5 | 57.2 | 11.6 | 63.9 | 35.4 | 57.0 | 17.8 | 68.7 | 34.3 | 55.9 | 22.9 | 60.8 | 18.0 |
k-means | feature map | RN50 | 56.3 | 14.3 | 59.0 | 14.5 | 57.2 | 11.6 | 63.9 | 35.4 | 57.0 | 17.8 | 68.7 | 34.3 | 55.9 | 22.9 | 60.8 | 18.0 |
k-means | simplex | RN50 | 52.1 | 17.5 | 65.6 | 29.3 | 55.5 | 16.0 | 79.6 | 63.9 | 60.2 | 31.4 | 80.1 | 65.3 | 69.9 | 45.7 | 65.6 | 31.5 |
KL k-means | simplex | RN50 | 52.1 | 17.6 | 65.5 | 29.3 | 55.4 | 16.1 | 79.4 | 64.1 | 60.1 | 31.2 | 80.0 | 65.0 | 69.7 | 45.8 | 65.7 | 31.5 |
k-sBetas | simplex | RN50 | 51.9 | 18.6 | 66.6 | 32.6 | 55.4 | 17.9 | 82.6 | 69.5 | 61.0 | 34.0 | 80.6 | 67.7 | 71.9 | 48.7 | 65.8 | 37.7 |
CLIP | feature map | RN101 | 56.8 | 21.3 | 66.2 | 31.5 | 55.9 | 19.8 | 78.2 | 67.2 | 64.7 | 37.2 | 76.7 | 61.4 | 68.7 | 51.4 | 66.3 | 37.2 |
BD-CSPN | feature map | RN101 | 56.6 | 17.4 | 63.5 | 23.6 | 55.8 | 16.3 | 69.3 | 53.1 | 63.4 | 30.2 | 71.9 | 48.6 | 62.8 | 39.1 | 65.1 | 31.1 |
Lap-Shot | feature map | RN101 | 56.6 | 17.4 | 63.5 | 23.6 | 55.8 | 16.3 | 69.4 | 53.2 | 63.4 | 30.3 | 71.9 | 48.7 | 62.8 | 39.3 | 65.1 | 31.2 |
TIM | feature map | RN101 | 59.6 | 19.9 | 61.8 | 16.7 | 56.7 | 14.3 | 64.9 | 40.4 | 63.9 | 27.3 | 73.2 | 42.0 | 59.5 | 28.6 | 64.2 | 23.3 |
k-means | feature map | RN101 | 59.6 | 19.9 | 61.8 | 16.7 | 56.7 | 14.3 | 64.9 | 40.4 | 63.9 | 27.3 | 73.2 | 42.0 | 59.5 | 28.6 | 64.2 | 23.3 |
k-means | simplex | RN101 | 57.1 | 21.8 | 72.6 | 35.0 | 58.5 | 20.1 | 79.9 | 66.5 | 67.3 | 40.6 | 84.9 | 67.6 | 72.3 | 52.7 | 71.0 | 41.5 |
KL k-means | simplex | RN101 | 57.1 | 21.9 | 72.4 | 34.8 | 58.4 | 20.0 | 80.0 | 66.6 | 67.3 | 40.6 | 84.9 | 67.6 | 72.4 | 52.5 | 70.8 | 41.7 |
k-sBetas | simplex | RN101 | 57.4 | 23.2 | 74.1 | 37.2 | 58.3 | 22.3 | 81.5 | 70.2 | 68.1 | 42.0 | 85.2 | 70.1 | 74.9 | 56.7 | 71.2 | 46.8 |
CLIP | feature map | ViT-B/32 | 59.0 | 26.2 | 63.0 | 28.0 | 55.8 | 18.8 | 82.0 | 69.9 | 64.8 | 38.4 | 77.2 | 61.8 | 69.1 | 44.4 | 66.4 | 36.6 |
BD-CSPN | feature map | ViT-B/32 | 58.5 | 21.7 | 60.8 | 21.5 | 55.9 | 15.2 | 73.1 | 57.7 | 62.4 | 30.6 | 71.7 | 49.9 | 63.8 | 36.4 | 65.0 | 30.6 |
Lap-Shot | feature map | ViT-B/32 | 58.6 | 21.8 | 60.8 | 21.5 | 55.9 | 15.3 | 73.4 | 57.8 | 62.5 | 30.7 | 71.7 | 49.9 | 64.0 | 36.6 | 65.0 | 30.6 |
TIM | feature map | ViT-B/32 | 60.9 | 21.1 | 58.0 | 15.4 | 56.7 | 12.8 | 65.5 | 41.2 | 61.5 | 25.4 | 69.0 | 36.7 | 58.7 | 26.1 | 63.3 | 22.7 |
k-means | feature map | ViT-B/32 | 60.9 | 21.1 | 58.0 | 15.4 | 56.7 | 12.8 | 65.5 | 41.2 | 61.5 | 25.4 | 69.0 | 36.7 | 58.7 | 26.1 | 63.3 | 22.7 |
k-means | simplex | ViT-B/32 | 60.2 | 26.4 | 68.0 | 29.9 | 58.1 | 19.3 | 84.4 | 69.7 | 66.1 | 39.8 | 85.3 | 67.2 | 72.7 | 46.0 | 70.9 | 38.2 |
KL k-means | simplex | ViT-B/32 | 60.0 | 26.2 | 67.9 | 29.9 | 58.1 | 19.3 | 84.3 | 69.6 | 66.0 | 39.8 | 85.3 | 67.3 | 72.6 | 45.9 | 70.8 | 38.6 |
k-sBetas | simplex | ViT-B/32 | 60.1 | 27.2 | 69.6 | 32.8 | 58.5 | 22.3 | 85.6 | 72.3 | 67.4 | 42.8 | 85.9 | 70.3 | 75.4 | 48.9 | 71.1 | 43.6 |
CLIP | feature map | ViT-B/16 | 60.6 | 31.2 | 66.8 | 33.8 | 56.7 | 23.2 | 81.6 | 75.5 | 68.2 | 48.2 | 80.8 | 68.9 | 66.8 | 46.3 | 69.2 | 43.0 |
BD-CSPN | feature map | ViT-B/16 | 59.9 | 26.2 | 64.7 | 26.6 | 56.0 | 18.9 | 71.3 | 60.2 | 65.2 | 38.5 | 76.7 | 57.1 | 62.7 | 35.5 | 66.6 | 35.1 |
Lap-Shot | feature map | ViT-B/16 | 60.0 | 26.3 | 64.8 | 26.7 | 56.0 | 18.9 | 71.5 | 60.2 | 65.2 | 38.7 | 76.8 | 57.1 | 62.7 | 35.4 | 66.8 | 35.2 |
TIM | feature map | ViT-B/16 | 63.3 | 26.0 | 61.4 | 18.3 | 56.9 | 15.0 | 63.8 | 44.0 | 62.7 | 28.1 | 75.2 | 45.7 | 58.0 | 26.4 | 63.8 | 25.2 |
k-means | feature map | ViT-B/16 | 63.3 | 26.0 | 61.4 | 18.3 | 56.9 | 15.0 | 63.8 | 44.0 | 62.7 | 28.1 | 75.2 | 45.7 | 58.0 | 26.4 | 63.8 | 25.2 |
k-means | simplex | ViT-B/16 | 61.3 | 30.2 | 71.3 | 34.3 | 61.6 | 23.7 | 82.7 | 73.3 | 70.3 | 48.9 | 87.9 | 74.5 | 70.4 | 47.4 | 76.8 | 46.7 |
KL k-means | simplex | ViT-B/16 | 61.1 | 30.2 | 71.1 | 34.4 | 61.5 | 23.8 | 82.5 | 73.2 | 70.2 | 48.6 | 87.8 | 74.5 | 70.3 | 47.3 | 76.8 | 47.1 |
k-sBetas | simplex | ViT-B/16 | 61.2 | 32.3 | 73.3 | 38.3 | 62.1 | 25.2 | 85.9 | 79.1 | 72.4 | 54.7 | 88.6 | 78.9 | 72.7 | 50.2 | 77.1 | 50.5 |
Table V: Zero-shot results (NMI and Acc).

Method | Layer used | Network | CIFAR-100 (NMI) | CIFAR-100 (Acc) | Stanford-Cars (NMI) | Stanford-Cars (Acc) | FGVC Aircraft (NMI) | FGVC Aircraft (Acc) | Caltech101 (NMI) | Caltech101 (Acc) | Food101 (NMI) | Food101 (Acc) | Flowers102 (NMI) | Flowers102 (Acc) | Sun397-100 (NMI) | Sun397-100 (Acc) | ImageNetv2-100 (NMI) | ImageNetv2-100 (Acc) |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
CLIP | feature map | RN50 | 61.0 | 38.9 | 74.3 | 55.3 | 57.3 | 17.6 | 87.3 | 81.3 | 77.8 | 71.0 | 81.8 | 63.2 | 76.4 | 71.9 | 74.1 | 60.8 |
BD-CSPN | feature map | RN50 | 58.7 | 23.4 | 68.6 | 25.8 | 56.1 | 6.5 | 77.2 | 47.4 | 69.8 | 37.8 | 75.0 | 29.5 | 66.6 | 41.2 | 68.1 | 31.8 |
Lap-Shot | feature map | RN50 | 56.3 | 19.6 | 69.6 | 25.3 | 55.1 | 5.5 | 77.4 | 46.3 | 66.3 | 29.8 | 61.5 | 12.7 | 66.9 | 41.5 | 66.2 | 27.1 |
TIM | feature map | RN50 | 60.5 | 28.3 | 61.9 | 19.6 | 57.9 | 11.7 | 66.4 | 42.6 | 66.2 | 37.9 | 72.0 | 37.1 | 57.1 | 25.7 | 64.0 | 29.4 |
k-means | feature map | RN50 | 60.5 | 28.3 | 61.9 | 19.6 | 57.9 | 11.7 | 66.4 | 42.6 | 66.2 | 37.9 | 72.0 | 37.1 | 57.1 | 25.7 | 64.0 | 29.4 |
k-means | simplex | RN50 | 63.4 | 40.4 | 80.2 | 59.0 | 60.1 | 18.8 | 90.7 | 84.6 | 81.8 | 79.2 | 90.6 | 69.0 | 81.3 | 74.8 | 78.2 | 69.2 |
KL k-means | simplex | RN50 | 63.4 | 40.5 | 80.2 | 59.0 | 60.0 | 18.9 | 90.8 | 84.6 | 81.8 | 79.2 | 90.5 | 68.8 | 81.4 | 75.1 | 78.2 | 69.3 |
k-sBetas | simplex | RN50 | 64.5 | 43.2 | 83.6 | 64.2 | 61.1 | 20.4 | 91.7 | 85.4 | 83.3 | 81.6 | 91.0 | 69.1 | 84.8 | 80.7 | 79.3 | 72.2 |
CLIP | feature map | RN101 | 59.5 | 42.7 | 77.9 | 62.8 | 57.4 | 18.1 | 86.1 | 83.3 | 80.0 | 74.5 | 80.1 | 61.0 | 78.4 | 71.5 | 77.3 | 67.7 |
BD-CSPN | feature map | RN101 | 56.9 | 24.4 | 70.7 | 34.0 | 57.5 | 7.7 | 77.3 | 45.9 | 70.6 | 38.5 | 78.1 | 28.9 | 69.8 | 44.4 | 71.0 | 34.1 |
Lap-Shot | feature map | RN101 | 53.7 | 22.6 | 72.7 | 35.3 | 56.3 | 7.1 | 78.7 | 45.0 | 70.4 | 37.4 | 78.1 | 28.0 | 70.6 | 45.3 | 72.0 | 35.1 |
TIM | feature map | RN101 | 63.3 | 36.8 | 63.6 | 25.4 | 59.4 | 14.1 | 69.4 | 51.9 | 70.2 | 49.4 | 74.5 | 38.8 | 64.1 | 36.5 | 67.7 | 38.6 |
k-means | feature map | RN101 | 63.3 | 36.8 | 63.6 | 25.4 | 59.4 | 14.1 | 69.4 | 51.9 | 70.2 | 49.4 | 74.5 | 38.8 | 64.1 | 36.5 | 67.7 | 38.6 |
k-means | simplex | RN101 | 65.2 | 46.3 | 83.2 | 66.0 | 62.5 | 20.0 | 89.6 | 85.2 | 84.1 | 82.1 | 90.0 | 66.3 | 83.0 | 74.5 | 83.3 | 77.0 |
KL k-means | simplex | RN101 | 65.2 | 46.4 | 83.2 | 66.1 | 62.5 | 20.0 | 89.6 | 85.1 | 84.0 | 82.2 | 89.9 | 66.3 | 82.8 | 74.3 | 83.3 | 77.0 |
k-sBetas | simplex | RN101 | 64.2 | 47.6 | 86.9 | 70.5 | 63.9 | 21.7 | 91.3 | 86.8 | 84.9 | 83.3 | 90.6 | 66.7 | 85.0 | 78.0 | 84.9 | 79.6 |
CLIP | feature map | ViT-B/32 | 70.4 | 58.9 | 75.8 | 59.1 | 57.8 | 17.6 | 89.8 | 85.4 | 78.9 | 76.2 | 81.7 | 62.0 | 79.0 | 72.9 | 77.1 | 63.6 |
BD-CSPN | feature map | ViT-B/32 | 65.7 | 36.8 | 68.5 | 30.2 | 58.0 | 10.5 | 78.9 | 51.0 | 69.1 | 46.5 | 77.9 | 28.3 | 68.9 | 46.0 | 70.7 | 36.3 |
Lap-Shot | feature map | ViT-B/32 | 64.9 | 35.2 | 69.1 | 29.5 | 57.7 | 9.9 | 79.6 | 50.2 | 69.0 | 43.4 | 75.2 | 21.6 | 69.5 | 46.3 | 69.7 | 32.3 |
TIM | feature map | ViT-B/32 | 65.0 | 38.9 | 62.2 | 22.0 | 58.1 | 12.7 | 66.8 | 44.5 | 65.5 | 42.1 | 73.6 | 38.7 | 59.1 | 30.4 | 66.2 | 33.7 |
k-means | feature map | ViT-B/32 | 65.0 | 38.9 | 62.2 | 22.0 | 58.1 | 12.7 | 66.8 | 44.5 | 65.5 | 42.1 | 73.6 | 38.7 | 59.1 | 30.4 | 66.2 | 33.7 |
k-means | simplex | ViT-B/32 | 74.7 | 65.3 | 81.1 | 61.2 | 60.9 | 17.8 | 92.0 | 88.0 | 84.4 | 83.6 | 91.1 | 66.4 | 83.3 | 75.9 | 83.0 | 73.0 |
KL k-means | simplex | ViT-B/32 | 74.8 | 65.4 | 81.1 | 61.3 | 60.8 | 17.4 | 91.9 | 88.0 | 84.4 | 83.7 | 91.1 | 66.3 | 83.4 | 76.2 | 83.0 | 73.0 |
k-sBetas | simplex | ViT-B/32 | 75.9 | 67.7 | 85.2 | 68.2 | 61.9 | 18.9 | 94.5 | 88.9 | 86.1 | 86.0 | 91.2 | 67.1 | 85.7 | 79.5 | 83.5 | 74.5 |
CLIP | feature map | ViT-B/16 | 71.0 | 63.8 | 79.1 | 63.7 | 61.1 | 26.0 | 91.0 | 86.2 | 85.1 | 82.9 | 83.1 | 66.4 | 79.7 | 75.6 | 79.6 | 70.5 |
BD-CSPN | feature map | ViT-B/16 | 66.3 | 43.6 | 71.5 | 34.5 | 61.6 | 13.9 | 80.7 | 62.3 | 74.8 | 56.2 | 80.7 | 34.6 | 69.5 | 50.8 | 71.9 | 38.2 |
Lap-Shot | feature map | ViT-B/16 | 65.7 | 40.7 | 72.9 | 33.6 | 61.6 | 12.8 | 81.4 | 61.1 | 74.9 | 54.2 | 81.5 | 31.3 | 70.0 | 51.8 | 72.7 | 36.1 |
TIM | feature map | ViT-B/16 | 65.8 | 44.4 | 64.2 | 25.0 | 60.6 | 15.8 | 70.4 | 49.5 | 70.1 | 48.9 | 75.1 | 43.8 | 60.0 | 31.3 | 66.3 | 38.8 |
k-means | feature map | ViT-B/16 | 65.8 | 44.4 | 64.2 | 25.0 | 60.6 | 15.8 | 70.4 | 49.5 | 70.1 | 48.9 | 75.1 | 43.8 | 60.0 | 31.3 | 66.3 | 38.8 |
k-means | simplex | ViT-B/16 | 75.6 | 68.0 | 84.1 | 65.7 | 66.3 | 26.4 | 93.0 | 87.8 | 89.9 | 88.9 | 91.3 | 70.1 | 83.7 | 77.4 | 85.6 | 77.2 |
KL k-means | simplex | ViT-B/16 | 75.6 | 68.1 | 84.1 | 65.8 | 66.3 | 26.6 | 92.9 | 87.8 | 89.9 | 88.9 | 91.3 | 70.1 | 83.7 | 77.4 | 85.8 | 77.2 |
k-sBetas | simplex | ViT-B/16 | 77.3 | 71.6 | 87.3 | 70.4 | 68.2 | 28.7 | 93.9 | 88.1 | 90.8 | 90.3 | 91.2 | 70.6 | 86.9 | 82.8 | 86.5 | 79.7 |
5.3 Real-Time UDA for road segmentation
We now consider a question that was, to the best of our knowledge, hitherto unaddressed in UDA segmentation: Can we adapt predictions from a black-box source model on a target set in real-time?
Setup. We address this question on the GTA5→Cityscapes road-segmentation challenge. A source model is trained on GTA5 [55] and applied to the Cityscapes validation set [56]. Following [57], we use Deeplab-V2 [58] as the semantic segmentation network. In our case, the model is exclusively trained for the binary task of road segmentation. The validation set contains 500 images with a 1024×2048 resolution. In order to simulate a real-time application, we treat each image independently, i.e., as a new clustering task in which each pixel represents a point to cluster.
Towards real-time running time. To maximize the speed of execution, we downsample the model’s output probability map prior to fitting the clustering methods and then use the obtained densities to perform inference at the original resolution. Downsampling by a factor of 8 increases the frame rate by almost two orders of magnitude without incurring any loss in mIoU (see Table XIII in the Appendix for more details). Specifically, k-sBetas takes on average 0.0165 seconds for clustering on each 128×256 subset of pixels and 0.0056 seconds for predictions on the original 1024×2048 images, which represents a processing frequency of 45 images per second. Hardware used: CPU 11th Gen Intel(R) Core(TM) i7-11700K 3.60GHz, and GPU NVIDIA GeForce RTX 2070 SUPER.
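The fit-at-low-resolution, predict-at-full-resolution scheme can be sketched as follows. This is not the k-sBetas implementation: a plain k-means with vertex initialization stands in for the clustering model, strided subsampling stands in for the downsampling operator, and all function names are hypothetical.

```python
import numpy as np


def fit_kmeans(points, init_centroids, n_iter=10):
    """Stand-in clustering model (plain k-means), NOT k-sBetas."""
    centroids = init_centroids.astype(float).copy()
    for _ in range(n_iter):
        dists = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        assign = dists.argmin(axis=1)
        for k in range(len(centroids)):
            if np.any(assign == k):
                centroids[k] = points[assign == k].mean(axis=0)
    return centroids


def predict(points, centroids):
    dists = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return dists.argmin(axis=1)


def adjust_frame(prob_map, factor=8):
    """Fit on a subsampled probability map, then label every pixel."""
    height, width, num_classes = prob_map.shape
    # Assumption: strided subsampling is used as the downsampling operator.
    small = prob_map[::factor, ::factor].reshape(-1, num_classes)
    centroids = fit_kmeans(small, np.eye(num_classes))
    labels = predict(prob_map.reshape(-1, num_classes), centroids)
    return labels.reshape(height, width)


# Toy binary probability map: left half confident in class 0, right half
# leaning towards class 1.
prob_map = np.zeros((64, 128, 2))
prob_map[:, :64] = [0.9, 0.1]
prob_map[:, 64:] = [0.2, 0.8]
label_map = adjust_frame(prob_map, factor=8)
```

Fitting touches only 1/64 of the pixels for a factor of 8, while the final per-pixel prediction remains a single vectorized pass over the full-resolution map.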
Results. The corresponding NMI and mean IoU scores are reported in Table VI. They show k-sBetas outperforming the computationally inexpensive competitors k-means and KL k-means by a large margin, with an improvement of 14 mIoU points over the pixel-wise inductive baseline. Fig. 1 provides qualitative results, in which k-sBetas yields a definite visual improvement over the baseline. In addition, and consistently with the previous results observed in Table II for the UDA classification challenge, Table VI also shows that clustering the probability simplex points may be more effective than clustering the logits.
Table VI: NMI and mIoU scores on GTA5→Cityscapes road segmentation.

Approach | GTA5→Cityscapes (NMI) | GTA5→Cityscapes (mIoU) |
---|---|---|
k-means - logits | 21.4 | 39.4 |
argmax | 19.7 | 49.2 |
k-means | 23.6 | 52.4 |
KL k-means | 24.9 | 52.2 |
k-sBetas | 35.8 | 65.7 |
5.4 Ablation studies
Joint vs. factorized density. Our proposed method draws inspiration from a factorized density, i.e., the components of the simplex vector are treated as mutually independent. While such a simplifying assumption has significant analytical and computational benefits, it may also fail to capture the complexity of the target distribution. We explore this trade-off by comparing k-Betas with k-Dirs, in which the full joint density, corresponding to a Dirichlet density, is fitted at each iteration. Because there exists no closed-form solution to this problem, we resort to the approximate estimation procedure described in [27]. The results in Table II show that k-Dirs and k-Betas produce similar scores on mixtures of Dirichlet distributions. Meanwhile, k-Dirs outperforms k-Betas on SVHN→MNIST but fails to converge on the more challenging VISDA-C dataset, making it ill-suited to real-world applications.
Effect of class imbalance. To investigate the problem of class imbalance, we generate heavily imbalanced datasets. First, we create an imbalanced version of our synthetic Simu dataset, which we refer to as iSimu. Specifically, we weighted the 3 components of the Simu mixture with six different combinations using the imbalanced class proportions . Second, we consider a variant of VISDA-C, where class-proportions are sampled from a Dirichlet distribution with . We refer to this variant as iVISDA-Cs. We perform 10 random re-runs.
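One simple way to build such an imbalanced variant is to draw class proportions from a symmetric Dirichlet distribution and subsample each class accordingly. The sketch below only illustrates this idea: the function name, the `concentration` value and the sample counts are hypothetical and are not the settings used for iVISDA-Cs.

```python
import numpy as np


def imbalanced_subset(labels, concentration, n_samples, seed=0):
    """Subsample indices so that class proportions follow a Dirichlet draw."""
    rng = np.random.default_rng(seed)
    classes = np.unique(labels)
    # Symmetric Dirichlet: small concentration values give heavier imbalance.
    props = rng.dirichlet(np.full(len(classes), concentration))
    counts = rng.multinomial(n_samples, props)
    chosen = []
    for cls, count in zip(classes, counts):
        pool = np.flatnonzero(labels == cls)
        count = min(int(count), len(pool))  # cannot take more than available
        chosen.append(rng.choice(pool, size=count, replace=False))
    return np.concatenate(chosen), props


labels = np.repeat(np.arange(4), 100)   # toy balanced dataset, 4 classes
subset_idx, proportions = imbalanced_subset(labels, concentration=0.5,
                                            n_samples=200, seed=1)
```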
Table VII compares the clustering models on these two datasets. The NMI scores displayed for iSimus correspond to the average NMI scores obtained across the six different mixture proportions. The NMI and IoU scores displayed for iVISDA-Cs correspond to the average scores obtained across ten different highly imbalanced subset variants of VISDA-C. k-sBetas (biased) refers to the proposed approach without the marginal probability term in Eq. 14. These comparative results clearly highlight the benefit of the proposed k-sBetas formulation. Note that our unbiased formulation could theoretically apply to metric-based approaches, but it was systematically found to produce degenerate solutions in which all examples are assigned to a single cluster.
Table VII: Clustering scores on the imbalanced iSimus and iVISDA-Cs datasets.

Approach | iSimus (NMI) | iVISDA-Cs (NMI) | iVISDA-Cs (mIoU) |
---|---|---|---|
argmax | 55.5 | 31.6 | 22.7 |
k-means | 62.3 | 32.8 | 24.2 |
GMM | 64.5 | 34.4 | 21.1 |
k-medians | 60.4 | 33.1 | 22.4 |
k-medoids | 62.6 | 32.1 | 22.5 |
k-modes | 55.1 | 32.0 | 22.8 |
KL k-means | 59.9 | 35.2 | 24.9 |
HSC | 17.7 | 28.9 | 16.3 |
k-sBetas (biased) | 55.3 | 35.4 | 25.6 |
k-sBetas | 72.4 | 36.4 | 27.1 |
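Table VIII: Effect of the unimodal constraint and of the parameter estimation method.

Approach | SVHN→MNIST | SVHN→MNIST | VISDA-C | VISDA-C |
---|---|---|---|---|
Unimodal constraint | | | | |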
k-Dirs | 65.3 | 68.8 | fails | fails |
k-Betas | 19.5 | 51.5 | 11.9 | 19.8 |
k-sBetas (MLE-500) | 76.2 | 76.2 | 55.0 | 55.0 |
k-sBetas | 76.6 | 76.6 | 56.0 | 56.0 |
Parameter estimation. For k-sBetas, Table III and Table VIII show that the method of moments (MoM) is preferable to the iterative MLE in practice, both in terms of computational cost and prediction performance.
Effect of the unimodal constraint. Table VIII empirically confirms that constraining a density-based clustering model to only consider mixtures of unimodal distributions is appropriate with softmax predictions from pre-trained source models. Yet, we can observe that disabling this constraint has no impact on k-sBetas results. The following paragraph explains such observations by the presence of .
Effect of . As a matter of fact, Table VIII results suggest that interest is finally twofold in our experiments: It enables a more flexible clustering. Meanwhile, it also prevents estimating bimodal density functions.
The latter point can be explained as follows. Bimodal distributions have a higher variance than uniform or unimodal distributions because they partition the major portion of the density at the vicinity of two opposite interval boundaries simultaneously. Thus, when is sufficiently large (e.g. with in these experiments), the scaled variance of the projection of coordinates into the interval becomes sufficiently small to inhibit the modeling of bimodal distributions. Complementary, Fig. 6 shows that setting provides consistent outperforming results on different datasets, and it allows fast convergence in terms of clustering iterations. Thus, we set along all the other presented experiments.
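The variance argument can be checked numerically on the standard Beta density on [0, 1], as in the sketch below. It uses the textbook Beta variance and the textbook method-of-moments inversion, which are not the scaled estimator of Sec. 4.3.1; the function names are hypothetical.

```python
def beta_variance(a, b):
    """Variance of a standard Beta(a, b) distribution on [0, 1]."""
    return a * b / ((a + b) ** 2 * (a + b + 1.0))


def beta_mom(mean, var):
    """Textbook method-of-moments estimate of (a, b) for a standard Beta."""
    common = mean * (1.0 - mean) / var - 1.0
    return mean * common, (1.0 - mean) * common


var_bimodal = beta_variance(0.5, 0.5)    # U-shaped (bimodal) density
var_uniform = beta_variance(1.0, 1.0)    # uniform density
var_unimodal = beta_variance(2.0, 2.0)   # unimodal density

# Same mean, two different variances: the large one maps back to a bimodal
# Beta (a, b < 1), the small one to a unimodal Beta (a, b > 1).
a_large_var, b_large_var = beta_mom(0.5, 0.125)
a_small_var, b_small_var = beta_mom(0.5, 0.03)
```

In this symmetric example, the variances are ordered as bimodal (0.125) > uniform (1/12) > unimodal (0.05), and shrinking the variance at a fixed mean moves the moment-matched parameters from the bimodal regime to the unimodal one.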
![Refer to caption](x13.png)
Pre-clustering: Centroid initialization. Parameter initialization plays a crucial role in unsupervised clustering. In Table IX, we compare the widely-used k-means++ initialization, designed for generic clustering, with our simplex-tailored vertex initialization. Regardless of the assignment method used, our initialization demonstrates up to a improvement in accuracy over k-means++ (Table XIV in the Appendix further extends this comparison to other clustering approaches, showing similar benefits across every tested approach).
Post-clustering: Cluster to class assignment. Table IX also shows the benefit of the Hungarian algorithm, referred to as Hung, for aligning cluster labels with class labels. In contrast to the argmax assignment of centroids, we recall that, in such closed-set challenges, Hung ensures that each cluster is assigned to a separate class. We can observe that the proposed Hung strategy is particularly beneficial when using the vertex init. This suggests that optimal transport assignment is naturally relevant in situations where each cluster is likely to represent a separate class.
In order to provide a fair comparison with the state-of-the-art, we jointly applied these pre- and post-clustering strategies on every compared approach.
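Table IX: Accuracy of k-sBetas for different centroid initializations and cluster-to-class assignments.

Approach | Init | Assignment | SVHN→MNIST (Acc) | VISDA-C (Acc) |
---|---|---|---|---|
k-sBetas | k-means++ | argmax | 69.0 ± 6.1 | 50.0 ± 5.0 |
k-sBetas | k-means++ | Hung | 69.8 ± 8.1 | 47.2 ± 5.2 |
k-sBetas | vertex init | argmax | 73.5 ± 0.0 | 53.8 ± 0.0 |
k-sBetas | vertex init | Hung | 76.6 ± 0.0 | 56.0 ± 0.0 |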
6 Discussion
This section discusses the potential limitations and extensions of the proposed approach.
Maximum performance may be upper-bounded. We have shown that clustering softmax predictions with the proposed approach can efficiently improve source model prediction performances at a reasonable computational footprint. However, it is worth noting that we could not reach the prediction performances of the recent state-of-the-art self-training [3], which exploits and updates the model parameters end-to-end through an iterative process. This is because, as for every clustering algorithm, the maximum performance of the proposed approach is upper-bounded by the quality of the output probability simplex domain representation, over which we have no control.
Spatial analysis could be complementary.
The clustering model does not consider global or local spatial information. In challenges such as pixel-wise semantic segmentation, it would be worthwhile to complement it with spatial post-processing techniques [59].
Clustering softmax predictions for self-training.
Throughout this article, we have motivated the use of simplex clustering for efficient prediction adjustment of black-box source models. Nevertheless, simplex clustering could also serve as a pseudo-labeling step within self-training strategies. This could be a simple, lightweight, and generic alternative to feature map clustering [11], which would not require cumbersome manipulation of high-dimensional and hidden feature maps.
If the number of classes is large. In some contexts, the number of classes may be large. For example, K may correspond to millions of words in neural language modeling. It would then be interesting to envision strategies that reorganize the softmax layer for more efficient calculation [60]. In addition, distortion-based and probabilistic clustering methods usually require a sufficient number of points per class to produce consistent partitioning. Thus, if K is too large with respect to the target set size, then the clustering could produce degenerate solutions. It would be interesting to explore hierarchical clustering strategies [61, 62] on the probability simplex domain, as these models can deal with a small number of data points per class.
Low-powered devices. Systems such as mobile robots and embedded systems may not have the resources to perform costly neural network parameter updates at application time. The presented framework can be viewed as an efficient plug-in solution for prediction adjustment in the wild when using such low-powered devices.
7 Conclusion
In this paper, we have tackled the simplex clustering of softmax predictions from black-box deep networks. We found that existing distortion-based objectives of the state-of-the-art do not adequately approximate real-world simplex distributions, which can, for example, be highly peaked near the simplex vertices. Thus, we have introduced a novel probabilistic clustering objective that integrates a generalization of the Beta density, which we coin sBeta. This scaled variant of Beta is relatively more permissive; meanwhile, we constrain it to fit only strictly unimodal distributions and to avoid degenerate solutions. In order to optimize our clustering objective, we proceed with a block-coordinate descent approach, which alternates optimization with respect to the assignment variables and parameter estimation of the cluster distributions, subject to the introduced parameter constraints. The resulting clustering model, which we refer to as k-sBetas, is both highly competitive and efficient on the proposed simulated and real-world softmax-prediction datasets, including class-imbalance scenarios, on which we performed our comparisons. Additionally, we emphasized the utility of probability simplex clustering through several practical applications: zero-shot and one-shot image classification, and real-time adjustment for road segmentation on full-size images.
Our future perspectives include extending the insights of probability simplex clustering to other challenges, such as pseudo-labeling for self-training procedures and other applications that require rapid correction of source models. Moreover, it may be interesting to complement the clustering with spatial and temporal information depending on the target data properties.
Overall, we hope that the presented simplex clustering framework and softmax-prediction benchmarks will encourage the community to give more attention to such convenient, computationally light, model-agnostic, plug-in solutions.
Acknowledgments
This research was supported by computational resources provided by Compute Canada.
References
- [1] D. Wang, E. Shelhamer, S. Liu, B. Olshausen, and T. Darrell, “Tent: Fully test-time adaptation by entropy minimization,” in International Conference on Learning Representations, 2021.
- [2] M. Wang and W. Deng, “Deep visual domain adaptation: A survey,” Neurocomputing, vol. 312, pp. 135–153, 2018.
- [3] J. Liang, D. Hu, Y. Wang, R. He, and J. Feng, “Source data-absent unsupervised domain adaptation through hypothesis transfer and labeling transfer,” IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021.
- [4] A. Radford, J. W. Kim, C. Hallacy, A. Ramesh, G. Goh, S. Agarwal, G. Sastry, A. Askell, P. Mishkin, J. Clark et al., “Learning transferable visual models from natural language supervision,” in International Conference on Machine Learning (ICML), 2021, pp. 8748–8763.
- [5] H. Xia, H. Zhao, and Z. Ding, “Adaptive adversarial network for source-free domain adaptation,” in Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), October 2021, pp. 9010–9019.
- [6] P. Colombo, V. Pellegrain, M. Boudiaf, V. Storchan, M. Tami, I. B. Ayed, C. Hudelot, and P. Piantanida, “Transductive learning for textual few-shot classification in api-based embedding models,” in Empirical Methods in Natural Language Processing (EMNLP), 2023.
- [7] F. Pereira, N. Tishby, and L. Lee, “Distributional clustering of english words,” arXiv preprint cmp-lg/9408011, 1994.
- [8] A. Banerjee, S. Merugu, I. S. Dhillon, J. Ghosh, and J. Lafferty, “Clustering with bregman divergences.” Journal of machine learning research, vol. 6, no. 10, 2005.
- [9] J. Wu, H. Xiong, and J. Chen, “Sail: Summation-based incremental learning for information-theoretic clustering,” in Proceedings of the 14th ACM SIGKDD international conference on knowledge discovery and data mining, 2008, pp. 740–748.
- [10] K. Chaudhuri and A. McGregor, “Finding metric structure in information theoretic clustering.” in COLT, vol. 8. Citeseer, 2008, p. 10.
- [11] M. Caron, P. Bojanowski, A. Joulin, and M. Douze, “Deep clustering for unsupervised learning of visual features,” in Proceedings of the European Conference on Computer Vision (ECCV), 2018, pp. 132–149.
- [12] F. Nielsen and K. Sun, “Clustering in hilbert’s projective geometry: The case studies of the probability simplex and the elliptope of correlation matrices,” in Geometric Structures of Information. Springer, 2019, pp. 297–331.
- [13] R. O. Duda, P. E. Hart, and D. G. Stork, Pattern Classification, 2nd ed. John Wiley and Sons, 2000.
- [14] D. Aloise, A. Deshpande, P. Hansen, and P. Popat, “Np-hardness of euclidean sum-of-squares clustering,” Machine Learning, vol. 75, no. 2, pp. 245–248, 2009.
- [15] P. S. Bradley, O. L. Mangasarian, and W. N. Street, “Clustering via concave minimization,” Advances in neural information processing systems, pp. 368–374, 1997.
- [16] M. A. Carreira-Perpinán and W. Wang, “The k-modes algorithm for clustering,” arXiv preprint arXiv:1304.6478, 2013.
- [17] I. M. Ziko, E. Granger, and I. B. Ayed, “Scalable laplacian k-modes,” in Neural Information Processing Systems (NeurIPS), 2018, pp. 10 062–10 072.
- [18] A. Krause, P. Perona, and R. Gomes, “Discriminative clustering by regularized information maximization,” Advances in neural information processing systems, vol. 23, 2010.
- [19] D. J. Rezende, S. Mohamed, and D. Wierstra, “Stochastic backpropagation and approximate inference in deep generative models,” in International conference on machine learning. PMLR, 2014, pp. 1278–1286.
- [20] W. Hu, T. Miyato, S. Tokui, E. Matsumoto, and M. Sugiyama, “Learning discrete representations via information maximizing self-augmented training,” in International conference on machine learning. PMLR, 2017, pp. 1558–1567.
- [21] S. Claici, M. Yurochkin, S. Ghosh, and J. Solomon, “Model fusion with kullback-leibler divergence,” in International Conference on Machine Learning. PMLR, 2020, pp. 2038–2047.
- [22] T. F. Gonzalez, “Clustering to minimize the maximum intercluster distance,” Theoretical computer science, vol. 38, pp. 293–306, 1985.
- [23] R. Panigrahy and S. Vishwanathan, “An O(log* n) approximation algorithm for the asymmetric p-center problem,” Journal of Algorithms, vol. 27, no. 2, pp. 259–268, 1998.
- [24] Y. Boykov, H. Isack, C. Olsson, and I. Ben Ayed, “Volumetric bias in segmentation and reconstruction: Secrets and solutions,” in IEEE International Conference on Computer Vision (ICCV), 2015, pp. 1769–1777.
- [25] M. Tang, D. Marin, I. B. Ayed, and Y. Boykov, “Kernel cuts: Kernel and spectral clustering meet regularization,” International Journal of Computer Vision, vol. 127, no. 5, pp. 477–511, 2019.
- [26] M. Kearns, Y. Mansour, and A. Y. Ng, “An information-theoretic analysis of hard and soft assignment methods for clustering,” in Learning in graphical models. Springer, 1998, pp. 495–520.
- [27] T. Minka, “Estimating a dirichlet distribution,” 2000.
- [28] M. J. Wainwright, M. I. Jordan et al., “Graphical models, exponential families, and variational inference,” Foundations and Trends® in Machine Learning, vol. 1, no. 1–2, pp. 1–305, 2008.
- [29] J. Kruschke, Doing Bayesian data analysis: A tutorial with R, JAGS, and Stan. Academic Press, 2014.
- [30] D. Arthur and S. Vassilvitskii, “k-means++: The advantages of careful seeding,” Stanford, Tech. Rep., 2006.
- [31] O. Bachem, M. Lucic, S. H. Hassani, and A. Krause, “Approximate k-means++ in sublinear time,” in Thirtieth AAAI conference on artificial intelligence, 2016.
- [32] H. W. Kuhn, “The hungarian method for the assignment problem,” Naval research logistics quarterly, vol. 2, no. 1-2, pp. 83–97, 1955.
- [33] C. Biernacki, G. Celeux, and G. Govaert, “Assessing a mixture model for clustering with the integrated completed likelihood,” IEEE transactions on pattern analysis and machine intelligence, vol. 22, no. 7, pp. 719–725, 2000.
- [34] L. Kaufman and P. J. Rousseeuw, “Partitioning around medoids (program pam),” Finding groups in data: an introduction to cluster analysis, vol. 344, pp. 68–125, 1990.
- [35] Y. Cheng, “Mean shift, mode seeking, and clustering,” IEEE transactions on pattern analysis and machine intelligence, vol. 17, no. 8, pp. 790–799, 1995.
- [36] Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, and A. Y. Ng, “Reading digits in natural images with unsupervised feature learning,” in NIPS Workshop, 2011.
- [37] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner, “Gradient-based learning applied to document recognition,” Proceedings of the IEEE, vol. 86, no. 11, pp. 2278–2324, 1998.
- [38] X. Peng, B. Usman, N. Kaushik, D. Wang, J. Hoffman, and K. Saenko, “Visda: A synthetic-to-real benchmark for visual domain adaptation,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops, 2018, pp. 2021–2026.
- [39] J. Liang, D. Hu, and J. Feng, “Do we really need to access the source data? source hypothesis transfer for unsupervised domain adaptation,” in International Conference on Machine Learning. PMLR, 2020, pp. 6028–6039.
- [40] S. Martin, M. Boudiaf, É. Chouzenoux, J. Pesquet, and I. B. Ayed, “Towards practical few-shot query sets: Transductive minimum description length inference,” in Neural Information Processing Systems (NeurIPS), 2022.
- [41] K. He, X. Zhang, S. Ren, and J. Sun, “Deep residual learning for image recognition,” in Proceedings of the IEEE conference on computer vision and pattern recognition, 2016, pp. 770–778.
- [42] A. Dosovitskiy, L. Beyer, A. Kolesnikov, D. Weissenborn, X. Zhai, T. Unterthiner, M. Dehghani, M. Minderer, G. Heigold, S. Gelly, J. Uszkoreit, and N. Houlsby, “An image is worth 16x16 words: Transformers for image recognition at scale,” in International Conference on Learning Representations, 2021. [Online]. Available: https://openreview.net/forum?id=YicbFdNTTy
- [43] A. Krizhevsky, G. Hinton et al., “Learning multiple layers of features from tiny images,” 2009.
- [44] J. Krause, M. Stark, J. Deng, and L. Fei-Fei, “3d object representations for fine-grained categorization,” in Proceedings of the IEEE international conference on computer vision workshops, 2013, pp. 554–561.
- [45] S. Maji, E. Rahtu, J. Kannala, M. Blaschko, and A. Vedaldi, “Fine-grained visual classification of aircraft,” arXiv preprint arXiv:1306.5151, 2013.
- [46] L. Fei-Fei, R. Fergus, and P. Perona, “Learning generative visual models from few training examples: An incremental bayesian approach tested on 101 object categories,” in Conference on Computer Vision and Pattern Recognition Workshop, 2004, pp. 178–178.
- [47] L. Bossard, M. Guillaumin, and L. Van Gool, “Food-101–mining discriminative components with random forests,” in Computer Vision–ECCV 2014: 13th European Conference, Zurich, Switzerland, September 6-12, 2014, Proceedings, Part VI 13. Springer, 2014, pp. 446–461.
- [48] M.-E. Nilsback and A. Zisserman, “Automated flower classification over a large number of classes,” in 2008 Sixth Indian conference on computer vision, graphics & image processing. IEEE, 2008, pp. 722–729.
- [49] J. Xiao, J. Hays, K. A. Ehinger, A. Oliva, and A. Torralba, “Sun database: Large-scale scene recognition from abbey to zoo,” in 2010 IEEE computer society conference on computer vision and pattern recognition. IEEE, 2010, pp. 3485–3492.
- [50] B. Recht, R. Roelofs, L. Schmidt, and V. Shankar, “Do imagenet classifiers generalize to imagenet?” in International conference on machine learning. PMLR, 2019, pp. 5389–5400.
- [51] J. Liu, L. Song, and Y. Qin, “Prototype rectification for few-shot learning,” in Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part I 16. Springer, 2020, pp. 741–756.
- [52] I. Ziko, J. Dolz, E. Granger, and I. B. Ayed, “Laplacian regularized few-shot learning,” in International conference on machine learning. PMLR, 2020, pp. 11 660–11 670.
- [53] M. Boudiaf, I. Ziko, J. Rony, J. Dolz, P. Piantanida, and I. Ben Ayed, “Information maximization for few-shot learning,” Advances in Neural Information Processing Systems, vol. 33, pp. 2445–2457, 2020.
- [54] M. Jabi, M. Pedersoli, A. Mitiche, and I. B. Ayed, “Deep clustering: On the link between discriminative models and k-means,” IEEE transactions on pattern analysis and machine intelligence, vol. 43, no. 6, pp. 1887–1896, 2019.
- [55] S. R. Richter, V. Vineet, S. Roth, and V. Koltun, “Playing for data: Ground truth from computer games,” in European conference on computer vision. Springer, 2016, pp. 102–118.
- [56] M. Cordts, M. Omran, S. Ramos, T. Rehfeld, M. Enzweiler, R. Benenson, U. Franke, S. Roth, and B. Schiele, “The cityscapes dataset for semantic urban scene understanding,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2016.
- [57] T.-H. Vu, H. Jain, M. Bucher, M. Cord, and P. Pérez, “Advent: Adversarial entropy minimization for domain adaptation in semantic segmentation,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019, pp. 2517–2526.
- [58] L.-C. Chen, G. Papandreou, I. Kokkinos, K. Murphy, and A. L. Yuille, “Deeplab: Semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected crfs,” IEEE transactions on pattern analysis and machine intelligence, vol. 40, no. 4, pp. 834–848, 2017.
- [59] V. Badrinarayanan, A. Kendall, and R. Cipolla, “Segnet: A deep convolutional encoder-decoder architecture for image segmentation,” IEEE transactions on pattern analysis and machine intelligence, vol. 39, no. 12, pp. 2481–2495, 2017.
- [60] W. Chen, D. Grangier, and M. Auli, “Strategies for training large vocabulary neural language models,” in Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). Berlin, Germany: Association for Computational Linguistics, Aug. 2016, pp. 1975–1985.
- [61] F. Murtagh and P. Contreras, “Algorithms for hierarchical clustering: an overview,” Wiley Interdisciplinary Reviews: Data Mining and Knowledge Discovery, vol. 2, no. 1, pp. 86–97, 2012.
- [62] G. Ahalya and H. M. Pandey, “Data clustering approaches survey and analysis,” in 2015 International Conference on Futuristic Trends on Computational Analysis and Knowledge Management (ABLAZE). IEEE, 2015, pp. 532–537.
- [63] O. Vinyals, C. Blundell, T. Lillicrap, D. Wierstra et al., “Matching networks for one shot learning,” Advances in neural information processing systems, vol. 29, pp. 3630–3638, 2016.
- [64] M. Ren, E. Triantafillou, S. Ravi, J. Snell, K. Swersky, J. B. Tenenbaum, H. Larochelle, and R. S. Zemel, “Meta-learning for semi-supervised few-shot classification,” in International Conference on Learning Representations ICLR, 2018.
- [65] S. Ravi and H. Larochelle, “Optimization as a model for few-shot learning,” in International Conference on Learning Representations ICLR, 2017.
- [66] S. Zagoruyko and N. Komodakis, “Wide residual networks,” arXiv preprint arXiv:1605.07146, 2016.
- [67] Y. Wang, W.-L. Chao, K. Q. Weinberger, and L. van der Maaten, “Simpleshot: Revisiting nearest-neighbor classification for few-shot learning,” arXiv preprint arXiv:1911.04623, 2019.
Florent Chiaroni received his Dipl-Ing degree in computer science and electronics from ESTIA, France, in 2016, and his MSc in robotics and embedded systems from the University of Salford Manchester, United Kingdom, in the same year. He earned his Ph.D. in signal and image processing from the University of Paris Saclay in 2020, with VEDECOM Institute, Versailles, France, and Université Paris-Saclay, Centre national de la recherche scientifique (CNRS), CentraleSupélec, Gif-Sur-Yvette, France. During the course of this research, he was a Postdoctoral Fellow at ÉTS Montreal, Canada and Institut National de la Recherche Scientifique (INRS), Montreal, Canada. He is currently a Research Scientist at Thales Research and Technology (TRT) Canada, Thales Digital Solutions (TDS), CortAIx, Montreal, Canada. His current research interests include civil unmanned aerial vehicles (UAVs) and efficient weakly-supervised learning for visual pattern analysis.
Malik Boudiaf obtained his MSc in Aeronautics & Astronautics from Stanford University in 2019, his M.Eng in Aerospace Engineering from ISAE-Supaero, France, in 2017, and his PhD degree from ÉTS Montréal, Canada, supervised by Prof. Ismail Ben Ayed and Prof. Pablo Piantanida. His research lies between Computer Vision, Information Theory, Optimization, and their application to few-shot/unsupervised learning. He is currently a Senior Machine Learning Engineer in Computer Vision and Natural Language Processing at Stealth Startup, Montreal, Canada.
Amar Mitiche holds the Licence Es Sciences degree in mathematics from the University of Algiers and the Ph.D. degree in computer science from the University of Texas at Austin. He is currently a Professor with the Department of Telecommunications (INRS-EMT), Institut National de la Recherche Scientifique (INRS), Montreal, QC, Canada. His research is in computer vision and pattern recognition. He has written several articles on these subjects, as well as three books: Computational Analysis of Visual Motion (Plenum Press, 1994), Variational and Level Set Methods in Image Segmentation (Springer, 2011), with Ismail Ben Ayed, and Computer Vision Analysis of Image Motion by Variational Methods (Springer, 2014), with J. K. Aggarwal. His current interests include image segmentation, image motion analysis, and pattern classification by neural networks.
Ismail Ben Ayed is currently a Full Professor at ÉTS Montreal. He is also affiliated with the CRCHUM. His interests are in computer vision, optimization, machine learning, and medical image analysis algorithms. Ismail authored over 100 fully peer-reviewed papers, mostly published in the top venues of those areas, along with 2 books and 7 US patents. In recent years, he gave over 30 invited talks, including 4 tutorials at flagship conferences (MICCAI’14, ISBI’16, MICCAI’19 and MICCAI’20). His research has been covered in several visible media outlets, such as Radio Canada (CBC), Quebec Science Magazine and Canal du Savoir. His research team received several recent distinctions, such as the MIDL’19 best paper runner-up award and several top-ranking positions in internationally visible contests. Ismail served on the Program Committee for MICCAI’15, MICCAI’17 and MICCAI’19, and as Program Chair for MIDL’20. Also, he regularly serves as a reviewer for the main scientific journals of his field and was selected several times among the top reviewers of prestigious conferences (such as CVPR’15 and NeurIPS’20).
Appendix A sBeta complementary details
In this section, we provide the derivations of the mean, the variance, the mode, and the parameter estimation of the presented sBeta density function. Table X summarizes the resulting properties.
A.1 Mean and Variance
The mean of a probability density function $f$ is defined as $E[X] = \int x f(x)\,dx$. Meanwhile, the variance of $X$ can be defined as $V[X] = E[X^2] - E[X]^2$ with $E[X^2] = \int x^2 f(x)\,dx$. Based on these two statements, we propose to estimate the mean and the variance of the proposed sBeta density function.
sBeta mean: To express the first moment of , as a function its parameters, we propose to replace with as follow:
(23) | ||||
with . Then, with respect to the property , we have:
(24) | ||||
sBeta variance: We first estimate the second moment by replacing with as follow:
(25) | ||||
The term from Eq. (25) can be developed as:
(26) | ||||
Similarly, the term from Eq. (25) can be developed as:
(27) | ||||
Finally, we can estimate the variance (i.e., the second central moment) as a function of the density parameters, using Eq. (23) and Eq. (25), as follows:
(28) | ||||
Linear property: With respect to the expectation linearity, it is worth noting that we have and
(29) | ||||
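The linearity argument above can be checked numerically. The sketch below is a sanity check under a stated assumption, not a reproduction of the derivation: it assumes that an sBeta variable is an affine rescaling $X = (1 + 2\delta)U - \delta$ of a standard Beta variable $U$ with parameters $\alpha$ and $\beta$, where $\delta \geq 0$ denotes the scaling parameter. Under that assumption, the mean is $(1 + 2\delta)\,\alpha/(\alpha+\beta) - \delta$ and the variance is $(1 + 2\delta)^2\,\alpha\beta / ((\alpha+\beta)^2(\alpha+\beta+1))$. All function names are hypothetical.

```python
import random


def sbeta_mean(alpha, beta, delta):
    """Mean of X = (1 + 2*delta) * U - delta with U ~ Beta(alpha, beta)."""
    return (1 + 2 * delta) * alpha / (alpha + beta) - delta


def sbeta_variance(alpha, beta, delta):
    """Variance of the same affine transform: the Beta variance scaled by (1 + 2*delta)^2."""
    s = alpha + beta
    return (1 + 2 * delta) ** 2 * alpha * beta / (s ** 2 * (s + 1))


def sample_sbeta(alpha, beta, delta, n, seed=0):
    """Draw n samples by rescaling standard Beta draws (assumed construction)."""
    rng = random.Random(seed)
    return [(1 + 2 * delta) * rng.betavariate(alpha, beta) - delta for _ in range(n)]


alpha, beta, delta = 3.0, 5.0, 0.15
xs = sample_sbeta(alpha, beta, delta, 100_000)
emp_mean = sum(xs) / len(xs)
emp_var = sum((x - emp_mean) ** 2 for x in xs) / len(xs)

print("closed-form mean:", sbeta_mean(alpha, beta, delta), "empirical:", emp_mean)
print("closed-form variance:", sbeta_variance(alpha, beta, delta), "empirical:", emp_var)
```

With $\delta = 0$, both expressions reduce to the usual Beta mean and variance.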
A.2 Mode
The mode of a density function $f$ corresponds to the value of $x$ at which $f(x)$ achieves its maximum.
sBeta mode: To find the mode of sBeta as a function of its parameters $\alpha$ and $\beta$, we look for the value of $x$ at which the derivative of the density is equal to $0$. We first compute the derivative as follows:
(30) | ||||
Then, we can solve as follows:
(31) | ||||
Thus, the mode of with is defined as .
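A quick numerical cross-check of the mode is sketched below. It rests on an assumed form of the density, namely $(x+\delta)^{\alpha-1}(1+\delta-x)^{\beta-1} / (B(\alpha,\beta)(1+2\delta)^{\alpha+\beta-1})$ on $[-\delta, 1+\delta]$, with $\delta \geq 0$ the scaling parameter; under that assumption and for $\alpha, \beta > 1$, setting the derivative to zero gives $(\alpha - 1 + \delta(\alpha - \beta)) / (\alpha + \beta - 2)$. The code compares this closed form with a grid search over the log-density. Function names are hypothetical.

```python
import math


def sbeta_log_density(x, alpha, beta, delta):
    """Log of the assumed sBeta density on the open interval (-delta, 1 + delta)."""
    log_norm = (
        math.lgamma(alpha) + math.lgamma(beta) - math.lgamma(alpha + beta)
        + (alpha + beta - 1) * math.log(1 + 2 * delta)
    )
    return (
        (alpha - 1) * math.log(x + delta)
        + (beta - 1) * math.log(1 + delta - x)
        - log_norm
    )


def sbeta_mode(alpha, beta, delta):
    """Closed-form mode under the assumed density, for alpha > 1 and beta > 1."""
    return (alpha - 1 + delta * (alpha - beta)) / (alpha + beta - 2)


alpha, beta, delta = 4.0, 2.5, 0.1
lo, hi, steps = -delta, 1 + delta, 100_000
# Cell midpoints keep the grid strictly inside the support, so the logs are finite.
grid = [lo + (hi - lo) * (i + 0.5) / steps for i in range(steps)]
numeric_mode = max(grid, key=lambda x: sbeta_log_density(x, alpha, beta, delta))

print("closed-form mode:", sbeta_mode(alpha, beta, delta))
print("grid-search mode:", numeric_mode)
```

The same grid can be reused to verify that the assumed density integrates to one, which is a convenient check on the normalization constant.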
A.3 Solving the necessary conditions for minimizing the objective w.r.t sBeta parameters
Let us compute the partial derivatives of our objective (14) with respect to $\alpha$ and $\beta$:
(32) |
where $\psi$ stands for the digamma function. Setting the partial derivatives in (32) to 0 leads to the following coupled system:
(33) |
While (33) does not have any analytical solution, we approximate the updates in (33) by fixing the right-hand side of each equation, leading to the following vectorized updates of the parameters $\alpha_k$ and $\beta_k$ for each cluster $k$:
(34) | ||||
(35) | ||||
(36) | ||||
We empirically validated that this optimization procedure converges to the maximum likelihood estimate (MLE) on the synthetically generated dataset.
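The idea of freezing the right-hand side of a coupled digamma system can be illustrated on the simpler case of a standard Beta density. The sketch below is not the vectorized update of Eqs. (34) to (36); it is the classical fixed-point scheme for Beta and Dirichlet maximum likelihood described in [27], in which $\psi(\alpha) = \psi(\alpha+\beta) + \overline{\log x}$ and $\psi(\beta) = \psi(\alpha+\beta) + \overline{\log(1-x)}$ are solved repeatedly with $\psi(\alpha+\beta)$ held at its current value. The digamma, trigamma, and inverse-digamma helpers are small stdlib-only approximations written for this example, and all names are hypothetical.

```python
import math
import random


def digamma(x):
    """Digamma for x > 0: upward recurrence, then an asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return result + math.log(x) - 0.5 / x - inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 / 252))


def trigamma(x):
    """Derivative of digamma for x > 0, computed with the same strategy."""
    result = 0.0
    while x < 6.0:
        result += 1.0 / (x * x)
        x += 1.0
    inv2 = 1.0 / (x * x)
    return result + 1.0 / x + 0.5 * inv2 + (1 / 6 - inv2 * (1 / 30 - inv2 / 42)) / x ** 3


def inverse_digamma(y, iterations=5):
    """Solve digamma(x) = y by Newton's method from the starting point used in [27]."""
    x = math.exp(y) + 0.5 if y >= -2.22 else -1.0 / (y - digamma(1.0))
    for _ in range(iterations):
        x -= (digamma(x) - y) / trigamma(x)
    return x


def fit_beta_fixed_point(xs, iterations=200):
    """Fixed-point maximum likelihood estimate of (alpha, beta) for Beta data in (0, 1)."""
    n = len(xs)
    mean_log_x = sum(math.log(x) for x in xs) / n
    mean_log_1mx = sum(math.log(1.0 - x) for x in xs) / n
    # Method-of-moments starting point.
    m = sum(xs) / n
    v = sum((x - m) ** 2 for x in xs) / n
    common = m * (1.0 - m) / v - 1.0
    alpha, beta = m * common, (1.0 - m) * common
    for _ in range(iterations):
        psi_sum = digamma(alpha + beta)  # right-hand side frozen at the current values
        alpha = inverse_digamma(psi_sum + mean_log_x)
        beta = inverse_digamma(psi_sum + mean_log_1mx)
    return alpha, beta


rng = random.Random(0)
data = [rng.betavariate(2.0, 5.0) for _ in range(20_000)]
alpha_hat, beta_hat = fit_beta_fixed_point(data)
print("estimated alpha, beta:", alpha_hat, beta_hat)
```

Each pass only needs the two sufficient statistics computed once at the start, which is what makes this kind of update cheap to repeat inside a clustering loop.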
Appendix B Supplementary experiments
![Refer to caption](extracted/5701331/illustrations/only_pdf_dirichlet_plots.png)
Dataset size | 100 000 | 10 000 | 1 000 | 100 |
---|---|---|---|---|
Approach | (NMI) | (NMI) | (NMI) | (NMI) |
argmax | 60.1 ± 0.5 | 59.8 ± 0.7 | 60.4 ± 3.9 | 62.0 ± 17.9 |
k-means | 76.6 ± 0.1 | 76.7 ± 0.5 | 76.9 ± 6.2 | 78.2 ± 21.8 |
KL k-means | 76.2 ± 0.2 | 76.3 ± 0.7 | 76.5 ± 6.2 | 78.1 ± 22.8 |
GMM | 75.8 ± 0.3 | 75.8 ± 1.2 | 75.9 ± 6.8 | 77.3 ± 28.4 |
k-medians | 76.8 ± 0.1 | 77.1 ± 0.7 | 77.1 ± 5.6 | 78.7 ± 22.2 |
k-medoids | 60.8 ± 14.4 | 64.2 ± 12.1 | 65.3 ± 19.4 | 71.8 ± 24.2 |
k-modes | 76.2 ± 0.2 | 76.3 ± 1.0 | 76.5 ± 5.3 | 77.5 ± 24.9 |
HSC | 9.3 ± 1.9 | 15.0 ± 2.7 | 30.6 ± 16.3 | 41.2 ± 27.9 |
k-sBetas | 79.5 ± 0.1 | 79.5 ± 0.8 | 79.8 ± 6.4 | 80.0 ± 25.4 |
Dataset size | 100 000 | 10 000 | 1 000 |
---|---|---|---|
Approach | (NMI) | (NMI) | (NMI) |
argmax | 55.5 ± 25.5 | 55.6 ± 26.8 | 55.6 ± 31.5 |
k-means | 62.3 ± 22.6 | 62.2 ± 23.4 | 62.4 ± 28.9 |
KL k-means | 59.9 ± 25.2 | 59.9 ± 26.0 | 60.2 ± 31.5 |
GMM | 60.6 ± 23.9 | 63.8 ± 29.3 | 63.9 ± 35.0 |
k-medians | 60.4 ± 24.8 | 60.3 ± 25.6 | 60.3 ± 32.2 |
k-medoids | 47.2 ± 33.0 | 55.4 ± 30.0 | 57.8 ± 35.9 |
k-modes | 55.1 ± 30.9 | 54.9 ± 32.4 | 54.8 ± 36.5 |
HSC | 17.7 ± 12.1 | 13.6 ± 11.7 | 29.1 ± 31.8 |
k-sBetas | 72.4 ± 17.2 | 72.2 ± 20.1 | 73.3 ± 29.2 |
Subset size | Running time, clustering (s) | Running time, prediction (s) | Full-size score (NMI) | Full-size score (mIoU) |
---|---|---|---|---|
Full size (1024×2048) | 0.2784 | 0.0056 | 35.8 | 65.7 |
Modulo 2 (512×1024) | 0.0812 | 0.0019 | 35.8 | 65.7 |
Modulo 4 (256×512) | 0.0282 | 0.0009 | 35.8 | 65.7 |
Modulo 8 (128×256) | 0.0165 | 0.0004 | 35.8 | 65.7 |
Approach | SVHN→MNIST, k-means++ (Acc) | SVHN→MNIST, vertex init (Acc) | VISDA-C, k-means++ (Acc) | VISDA-C, vertex init (Acc) |
---|---|---|---|---|
k-means | 66.6 ± 5.4 | 68.9 ± 0.0 | 44.5 ± 3.3 | 47.9 ± 0.0 |
KL k-means | 72.6 ± 7.8 | 75.5 ± 0.0 | 50.0 ± 2.7 | 51.2 ± 0.0 |
GMM | 60.6 ± 9.6 | 69.2 ± 0.0 | 43.8 ± 5.2 | 49.4 ± 0.0 |
k-medians | 67.4 ± 9.2 | 68.8 ± 0.0 | 38.5 ± 4.0 | 40.0 ± 0.0 |
k-medoids | 51.9 ± 4.9 | 71.3 ± 0.0 | 40.6 ± 8.6 | 46.8 ± 0.0 |
k-modes | 60.6 ± 13.0 | 71.3 ± 0.0 | 30.2 ± 6.3 | 31.1 ± 0.0 |
k-sBetas | 69.8 ± 8.1 | 76.6 ± 0.0 | 47.2 ± 5.2 | 56.0 ± 0.0 |
B.0.1 One-Shot Learning
Method | Network | miniImageNet (NMI) | miniImageNet (Acc) | tieredImageNet (NMI) | tieredImageNet (Acc) |
---|---|---|---|---|---|
SimpleSHOT | RN-18 | 49.1 | 62.7 | 57.5 | 69.2 |
SimpleSHOT + k-sBetas | RN-18 | 52.2 | 64.4 | 60.5 | 71.0 |
BD-CSPN | RN-18 | 58.2 | 68.9 | 67.2 | 76.0 |
BD-CSPN + k-sBetas | RN-18 | 60.5 | 69.8 | 68.9 | 76.6 |
SimpleSHOT | WRN-28-10 | 53.6 | 65.7 | 59.1 | 70.4 |
SimpleSHOT + k-sBetas | WRN-28-10 | 56.8 | 67.3 | 61.8 | 72.4 |
BD-CSPN | WRN-28-10 | 62.4 | 72.1 | 69.0 | 77.5 |
BD-CSPN + k-sBetas | WRN-28-10 | 64.0 | 72.4 | 70.7 | 78.3 |
We consider the One-Shot classification problem, in which a model is evaluated based on its ability to generalize to new classes from a single labeled example per class. Typically, One-Shot methods use the labeled support set of each task to build a classifier and obtain soft predictions for unlabeled query samples. Once these soft predictions have been obtained, the proposed k-sBetas can be used to further refine them by clustering the soft predictions of the entire query set.
Setup. We use two standard benchmarks for One-Shot classification: mini-Imagenet [63] and tiered-Imagenet [64]. The mini-Imagenet benchmark is composed of 60,000 color images [63] equally split among 100 classes, themselves split between train, val, and test following [65]. The tiered-Imagenet benchmark is a larger dataset with 779,165 images and 608 classes, split following [64]. All images are resized to . Regarding the networks, we use the pre-trained RN-18 [41] and WRN28-10 [66] provided by [53]. Only 15 unlabeled query data points per class are available during each separate task, so we use the biased version of k-sBetas. As for the methods, we select one inductive method: SimpleSHOT [67] and one transductive method: BD-CSPN [51]. Each One-Shot method is reproduced using the set of hyperparameters suggested in the original papers.
Results. Table XV shows that, across these experiments, k-sBetas consistently improves the NMI and accuracy scores of the output predictions of SimpleSHOT and BD-CSPN.
Classes | plane | bcycl | bus | car | horse | knife | mcycl | person | plant | sktbrd | train | truck |
---|---|---|---|---|---|---|---|---|---|---|---|---|
VISDA-C proportions | 0.0658 | 0.0627 | 0.0847 | 0.1878 | 0.0847 | 0.0375 | 0.1046 | 0.0722 | 0.0821 | 0.0412 | 0.0765 | 0.1002 |
Number of examples | 3646 | 3475 | 4690 | 10401 | 4691 | 2075 | 5796 | 4000 | 4549 | 2281 | 4236 | 5548 |
Proportions set 1 | 0.0406 | 0.2315 | 0.0953 | 0.0257 | 0.2367 | 0.0443 | 0.0930 | 0.0252 | 0.1347 | 0.0601 | 0.0012 | 0.0116 |
Number of examples | 275 | 1568 | 645 | 173 | 1603 | 300 | 630 | 170 | 912 | 407 | 8 | 78 |
Proportions set 2 | 0.0996 | 0.0234 | 0.0086 | 0.0112 | 0.1490 | 0.0064 | 0.1900 | 0.1223 | 0.1287 | 0.0387 | 0.0876 | 0.1346 |
Number of examples | 675 | 158 | 58 | 75 | 1009 | 43 | 1286 | 828 | 871 | 262 | 593 | 912 |
Proportions set 3 | 0.0511 | 0.1170 | 0.0779 | 0.0661 | 0.0127 | 0.0575 | 0.0166 | 0.1958 | 0.0047 | 0.0795 | 0.2147 | 0.1064 |
Number of examples | 346 | 792 | 527 | 448 | 85 | 389 | 112 | 1326 | 31 | 538 | 1454 | 720 |
Proportions set 4 | 0.1382 | 0.0014 | 0.1726 | 0.0017 | 0.0030 | 0.1600 | 0.2470 | 0.0726 | 0.0419 | 0.0524 | 0.0558 | 0.0534 |
Number of examples | 936 | 9 | 1169 | 11 | 20 | 1084 | 1673 | 492 | 284 | 354 | 378 | 361 |
Proportions set 5 | 0.0013 | 0.0227 | 0.1451 | 0.1049 | 0.2725 | 0.0511 | 0.0196 | 0.0500 | 0.1401 | 0.0193 | 0.0322 | 0.1411 |
Number of examples | 8 | 154 | 983 | 710 | 1846 | 346 | 132 | 338 | 948 | 130 | 218 | 956 |
Proportions set 6 | 0.1450 | 0.0202 | 0.0343 | 0.1604 | 0.0574 | 0.0183 | 0.0318 | 0.0368 | 0.2668 | 0.0356 | 0.1361 | 0.0570 |
Number of examples | 982 | 137 | 232 | 1086 | 388 | 123 | 215 | 249 | 1807 | 241 | 922 | 386 |
Proportions set 7 | 0.1200 | 0.1304 | 0.0310 | 0.0451 | 0.0308 | 0.0071 | 0.1194 | 0.2701 | 0.0122 | 0.1311 | 0.0911 | 0.0116 |
Number of examples | 812 | 883 | 209 | 305 | 208 | 48 | 809 | 1830 | 82 | 888 | 617 | 78 |
Proportions set 8 | 0.0166 | 0.0521 | 0.0457 | 0.0358 | 0.1175 | 0.1917 | 0.0258 | 0.2303 | 0.1503 | 0.0379 | 0.0600 | 0.0362 |
Number of examples | 112 | 352 | 309 | 242 | 796 | 1298 | 174 | 1560 | 1018 | 256 | 406 | 245 |
Proportions set 9 | 0.1523 | 0.1316 | 0.0196 | 0.0427 | 0.0380 | 0.0806 | 0.0423 | 0.2075 | 0.0383 | 0.0644 | 0.1074 | 0.0754 |
Number of examples | 1031 | 891 | 132 | 289 | 257 | 546 | 286 | 1405 | 259 | 436 | 727 | 510 |
Proportions set 10 | 0.0102 | 0.1336 | 0.0063 | 0.0131 | 0.1291 | 0.0812 | 0.0019 | 0.0967 | 0.3063 | 0.0770 | 0.1107 | 0.0337 |
Number of examples | 69 | 905 | 42 | 88 | 874 | 549 | 12 | 655 | 2074 | 521 | 750 | 228 |
![Refer to caption](x14.png)
![Refer to caption](x15.png)