-
Model Predictive Simulation Using Structured Graphical Models and Transformers
Authors:
Xinghua Lou,
Meet Dave,
Shrinu Kushagra,
Miguel Lazaro-Gredilla,
Kevin Murphy
Abstract:
We propose an approach to simulating trajectories of multiple interacting agents (road users) based on transformers and probabilistic graphical models (PGMs), and apply it to the Waymo SimAgents challenge. The transformer baseline is based on the MTR model, which predicts multiple future trajectories conditioned on the past trajectories and static road layout features. We then improve upon these generated trajectories using a PGM whose factors encode prior knowledge, such as a preference for smooth trajectories and avoidance of collisions with static obstacles and other moving agents. We perform (approximate) MAP inference in this PGM using the Gauss-Newton method. Finally, we sample $K=32$ trajectories for each of the $N \sim 100$ agents for the next $T=8 Δ$ time steps, where $Δ=10$ is the sampling rate per second. Following the Model Predictive Control (MPC) paradigm, we only return the first element of our forecasted trajectories at each step, and then we replan, so that the simulation can constantly adapt to its changing environment. We therefore call our approach "Model Predictive Simulation" or MPS. We show that MPS improves upon the MTR baseline, especially in safety-critical metrics such as collision rate. Furthermore, our approach is compatible with any underlying forecasting model, and does not require extra training, so we believe it is a valuable contribution to the community.
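The replanning loop described in the abstract can be illustrated with a minimal sketch. It is not the authors' implementation: the 1-D state, the function names, and the constant-velocity forecaster are hypothetical stand-ins for the transformer and PGM stages. Only the horizon of $8 \times 10 = 80$ steps and the 0.1 s step follow from the abstract.

```python
from typing import Callable, List, Tuple

# Hypothetical 1-D agent state: (position, velocity).
State = Tuple[float, float]


def constant_velocity_forecast(state: State, horizon: int, dt: float) -> List[float]:
    """Stand-in forecaster (an assumption): predicts `horizon` future positions."""
    pos, vel = state
    return [pos + vel * dt * (k + 1) for k in range(horizon)]


def simulate(
    state: State,
    num_steps: int,
    horizon: int = 80,   # T = 8 * 10 steps, as in the abstract
    dt: float = 0.1,     # 10 samples per second
    forecast: Callable[[State, int, float], List[float]] = constant_velocity_forecast,
) -> List[float]:
    """Receding-horizon rollout: forecast `horizon` steps, keep only the first, replan."""
    executed = []
    for _ in range(num_steps):
        plan = forecast(state, horizon, dt)  # full forecast over the horizon
        next_pos = plan[0]                   # commit only the first element
        state = (next_pos, state[1])         # advance, then replan from the new state
        executed.append(next_pos)
    return executed
```

Any forecaster with the same signature can be passed as `forecast`, which mirrors the abstract's point that the scheme does not depend on the underlying forecasting model.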
Submitted 27 June, 2024;
originally announced June 2024.
-
Graph schemas as abstractions for transfer learning, inference, and planning
Authors:
J. Swaroop Guntupalli,
Rajkumar Vasudeva Raju,
Shrinu Kushagra,
Carter Wendelken,
Danny Sawyer,
Ishan Deshpande,
Guangyao Zhou,
Miguel Lázaro-Gredilla,
Dileep George
Abstract:
Transferring latent structure from one environment or problem to another is a mechanism by which humans and animals generalize with very little data. Inspired by cognitive and neurobiological insights, we propose graph schemas as a mechanism of abstraction for transfer learning. Graph schemas start with latent graph learning where perceptually aliased observations are disambiguated in the latent space using contextual information. Latent graph learning is also emerging as a new computational model of the hippocampus to explain map learning and transitive inference. Our insight is that a latent graph can be treated as a flexible template -- a schema -- that models concepts and behaviors, with slots that bind groups of latent nodes to the specific observations or groundings. By treating learned latent graphs (schemas) as prior knowledge, new environments can be quickly learned as compositions of schemas and their newly learned bindings. We evaluate graph schemas on two previously published challenging tasks: the memory & planning game and one-shot StreetLearn, which are designed to test rapid task solving in novel environments. Graph schemas can be learned in far fewer episodes than previous baselines, and can model and plan in a few steps in novel variations of these tasks. We also demonstrate learning, matching, and reusing graph schemas in more challenging 2D and 3D environments with extensive perceptual aliasing and size variations, and show how different schemas can be composed to model larger and more complex environments. To summarize, our main contribution is a unified system, inspired by and grounded in cognitive science, that facilitates rapid transfer learning of new environments using schemas via map-induction and composition that handles perceptual aliasing.
Submitted 12 December, 2023; v1 submitted 14 February, 2023;
originally announced February 2023.
-
PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX
Authors:
Guangyao Zhou,
Antoine Dedieu,
Nishanth Kumar,
Wolfgang Lehrach,
Miguel Lázaro-Gredilla,
Shrinu Kushagra,
Dileep George
Abstract:
PGMax is an open-source Python package for (a) easily specifying discrete Probabilistic Graphical Models (PGMs) as factor graphs; and (b) automatically running efficient and scalable loopy belief propagation (LBP) in JAX. PGMax supports general factor graphs with tractable factors, and leverages modern accelerators like GPUs for inference. Compared with existing alternatives, PGMax obtains higher-quality inference results with up to three orders-of-magnitude inference time speedups. PGMax additionally interacts seamlessly with the rapidly growing JAX ecosystem, opening up new research possibilities. Our source code, examples and documentation are available at https://github.com/deepmind/PGMax.
Submitted 24 March, 2023; v1 submitted 8 February, 2022;
originally announced February 2022.
-
On sampling from data with duplicate records
Authors:
Alireza Heidari,
Shrinu Kushagra,
Ihab F. Ilyas
Abstract:
Data deduplication is the task of detecting records in a database that correspond to the same real-world entity. Our goal is to develop a procedure that samples uniformly from the set of entities present in the database in the presence of duplicates. We accomplish this by a two-stage process. In the first step, we estimate the frequencies of all the entities in the database. In the second step, we use rejection sampling to obtain an (approximately) uniform sample from the set of entities. However, efficiently estimating the frequency of all the entities is a non-trivial task and not attainable in the general case. Hence, we consider various natural properties of the data under which such frequency estimation (and consequently uniform sampling) is possible. Under each of those assumptions, we provide sampling algorithms and give proofs of the complexity (both statistical and computational) of our approach. We complement our study by conducting extensive experiments on both real and synthetic datasets.
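The second stage can be sketched as follows, under the simplifying assumption that the first stage has already produced exact per-entity frequencies (the paper only estimates them, so the result there is approximately uniform). The function name and data layout are hypothetical. A record drawn uniformly belongs to entity $e$ with probability proportional to its frequency $f_e$, so accepting it with probability $1/f_e$ makes every entity equally likely.

```python
import random
from typing import Dict, Hashable, List


def sample_entity_uniformly(
    records: List[Hashable],
    freq: Dict[Hashable, int],
    rng: random.Random,
) -> Hashable:
    """Rejection sampling over records, each tagged with its entity id.

    `freq[e]` is the number of records for entity `e` (assumed known here).
    """
    while True:
        entity = rng.choice(records)           # drawn with probability freq[e] / n
        if rng.random() < 1.0 / freq[entity]:  # accept with probability 1 / freq[e]
            return entity
```

The expected number of draws per accepted sample is the number of records divided by the number of distinct entities, so heavy duplication makes the sampler slower but does not bias it.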
Submitted 24 August, 2020;
originally announced August 2020.
-
Record fusion: A learning approach
Authors:
Alireza Heidari,
George Michalopoulos,
Shrinu Kushagra,
Ihab F. Ilyas,
Theodoros Rekatsinas
Abstract:
Record fusion is the task of aggregating multiple records that correspond to the same real-world entity in a database. We can view record fusion as a machine learning problem where the goal is to predict the "correct" value for each attribute for each entity. Given a database, we use a combination of attribute-level, record-level, and database-level signals to construct a feature vector for each cell (or (row, col)) of that database. We use this feature vector along with the ground-truth information to learn a classifier for each of the attributes of the database.
Our learning algorithm uses a novel stagewise additive model. At each stage, we construct a new feature vector by combining a part of the original feature vector with features computed by the predictions from the previous stage. We then learn a softmax classifier over the new feature space. This greedy stagewise approach can be viewed as a deep model where at each stage, we are adding more complicated non-linear transformations of the original feature vector. We show that our approach fuses records with an average precision of ~98% when source information of records is available, and ~94% without source information across a diverse array of real-world datasets. We compare our approach to a comprehensive collection of data fusion and entity consolidation methods considered in the literature. We show that our approach can achieve an average precision improvement of ~20%/~45% with/without source information respectively.
Submitted 17 June, 2020;
originally announced June 2020.
-
Adjoined Networks: A Training Paradigm with Applications to Network Compression
Authors:
Utkarsh Nath,
Shrinu Kushagra,
Yingzhen Yang
Abstract:
Compressing deep neural networks while maintaining accuracy is important when we want to deploy large, powerful models in production and/or on edge devices. One common technique used to achieve this goal is knowledge distillation. Typically, the output of a static pre-defined teacher (a large base network) is used as soft labels to train and transfer information to a student (or smaller) network. In this paper, we introduce Adjoined Networks, or AN, a learning paradigm that trains both the original base network and the smaller compressed network together. In our training approach, the parameters of the smaller network are shared across both the base and the compressed networks. Using our training paradigm, we can simultaneously compress (the student network) and regularize (the teacher network) any architecture. In this paper, we focus on popular CNN-based architectures used for computer vision tasks. We conduct an extensive experimental evaluation of our training paradigm on various large-scale datasets. Using ResNet-50 as the base network, AN achieves 71.8% top-1 accuracy with only 1.8M parameters and 1.6 GFLOPs on the ImageNet dataset. We further propose Differentiable Adjoined Networks (DAN), a training paradigm that augments AN by using neural architecture search to jointly learn both the width and the weights for each layer of the smaller network. DAN achieves ResNet-50 level accuracy on ImageNet with $3.8\times$ fewer parameters and $2.2\times$ fewer FLOPs.
Submitted 14 April, 2022; v1 submitted 9 June, 2020;
originally announced June 2020.
-
Three-dimensional matching is NP-Hard
Authors:
Shrinu Kushagra
Abstract:
The standard proof of NP-Hardness of 3DM provides a power-$4$ reduction of 3SAT to 3DM. In this note, we provide a linear-time reduction. Under the exponential time hypothesis, this reduction improves the runtime lower bound from $2^{o(\sqrt[4]{m})}$ (under the standard reduction) to $2^{o(m)}$.
Submitted 29 February, 2020;
originally announced March 2020.
-
Semi-supervised clustering for de-duplication
Authors:
Shrinu Kushagra,
Shai Ben-David,
Ihab Ilyas
Abstract:
Data de-duplication is the task of detecting multiple records that correspond to the same real-world entity in a database. In this work, we view de-duplication as a clustering problem where the goal is to put records corresponding to the same physical entity in the same cluster and records corresponding to different physical entities into different clusters.
We introduce a framework which we call promise correlation clustering. Given a complete graph $G$ with the edges labelled $0$ and $1$, the goal is to find a clustering that minimizes the number of $0$ edges within a cluster plus the number of $1$ edges across different clusters (or correlation loss). The optimal clustering can also be viewed as a complete graph $G^*$ with edges corresponding to points in the same cluster being labelled $0$ and other edges being labelled $1$. Under the promise that the edge difference between $G$ and $G^*$ is "small", we prove that finding the optimal clustering (or $G^*$) is still NP-Hard. [Ashtiani et al., 2016] introduced the framework of semi-supervised clustering, where the learning algorithm has access to an oracle, which answers whether two points belong to the same or different clusters. We further prove that even with access to a same-cluster oracle, the promise version is NP-Hard as long as the number of queries to the oracle is not too large ($o(n)$ where $n$ is the number of vertices).
Given these negative results, we consider a restricted version of correlation clustering. As before, the goal is to find a clustering that minimizes the correlation loss. However, we restrict ourselves to a given class $\mathcal F$ of clusterings. We offer a semi-supervised algorithmic approach to solve the restricted variant with success guarantees.
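As a small illustration of the correlation loss as it is stated in the abstract (the number of $0$ edges inside clusters plus the number of $1$ edges across clusters), the following sketch evaluates the loss of a given clustering. The function name and the edge-label encoding are assumptions made for the example; it only evaluates the objective and does not implement the paper's semi-supervised algorithm.

```python
from itertools import combinations
from typing import Dict, FrozenSet, Sequence


def correlation_loss(labels: Dict[FrozenSet[int], int], clusters: Sequence[int]) -> int:
    """Count 0-edges within clusters plus 1-edges across clusters.

    `labels[frozenset((u, v))]` is the 0/1 label of edge {u, v} in the complete
    graph; `clusters[u]` is the cluster id assigned to vertex u.
    """
    loss = 0
    for u, v in combinations(range(len(clusters)), 2):
        same_cluster = clusters[u] == clusters[v]
        edge = labels[frozenset((u, v))]
        if same_cluster and edge == 0:
            loss += 1  # 0-edge inside a cluster
        elif not same_cluster and edge == 1:
            loss += 1  # 1-edge across clusters
    return loss
```

On three vertices with edge labels $\{0,1\} \mapsto 1$, $\{0,2\} \mapsto 0$, $\{1,2\} \mapsto 1$, no clustering achieves loss zero, because the labels are not consistent with any partition.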
Submitted 10 October, 2018;
originally announced October 2018.
-
Provably noise-robust, regularised $k$-means clustering
Authors:
Shrinu Kushagra,
Yaoliang Yu,
Shai Ben-David
Abstract:
We consider the problem of clustering in the presence of noise. That is, when on top of cluster structure, the data also contains a subset of unstructured points. Our goal is to detect the clusters despite the presence of many unstructured points. Any algorithm that achieves this goal is noise-robust. We consider a regularisation method which converts any center-based clustering objective into a noise-robust one. We focus on the $k$-means objective and we prove that the regularised version of $k$-means is NP-Hard even for $k=1$. We consider two algorithms based on the convex (SDP and LP) relaxations of the regularised objective and prove robustness guarantees for both.
The SDP and LP relaxations of the standard (non-regularised) $k$-means objective have been previously studied by [ABC+15]. Under the stochastic ball model of the data they show that the SDP-based algorithm recovers the underlying structure as long as the balls are separated by $δ> 2\sqrt{2} + ε$. We improve upon this result in two ways. First, we show recovery even for $δ> 2 + ε$. Second, our regularised algorithm recovers the balls even in the presence of noise so long as the number of noisy points is not too large. We complement our theoretical analysis with simulations and analyse the effect of various parameters, such as the regularisation constant and the noise level, on the performance of our algorithm. In the presence of noise, our algorithm performs better than $k$-means++ on MNIST.
Submitted 27 August, 2018; v1 submitted 30 November, 2017;
originally announced November 2017.
-
Clustering with Same-Cluster Queries
Authors:
Hassan Ashtiani,
Shrinu Kushagra,
Shai Ben-David
Abstract:
We propose a framework for Semi-Supervised Active Clustering (SSAC), where the learner is allowed to interact with a domain expert, asking whether two given instances belong to the same cluster or not. We study the query and computational complexity of clustering in this framework. We consider a setting where the expert conforms to a center-based clustering with a notion of margin. We show that there is a trade-off between computational complexity and query complexity; we prove that for the case of $k$-means clustering (i.e., when the expert conforms to a solution of $k$-means), having access to relatively few such queries allows efficient solutions to otherwise NP-hard problems.
In particular, we provide a probabilistic polynomial-time (BPP) algorithm for clustering in this setting that asks $O(k^2\log k + k\log n)$ same-cluster queries and runs with time complexity $O(kn\log n)$ (where $k$ is the number of clusters and $n$ is the number of instances). The algorithm succeeds with high probability for data satisfying margin conditions under which, without queries, we show that the problem is NP-hard. We also prove a lower bound on the number of queries needed to have a computationally efficient clustering algorithm in this setting.
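To make the same-cluster query interface concrete, here is a naive baseline, not the paper's algorithm: it compares each point with one representative per discovered cluster, so it may ask up to $nk$ queries, far more than the $O(k^2\log k + k\log n)$ bound stated above. The function name and the oracle signature are assumptions made for this sketch.

```python
from typing import Callable, List, Tuple


def cluster_with_queries(
    n: int,
    same_cluster: Callable[[int, int], bool],
) -> Tuple[List[int], int]:
    """Naive clustering of points 0..n-1 using only same-cluster queries.

    Returns (assignment, number_of_queries). `same_cluster(i, j)` plays the
    role of the domain expert (hypothetical signature).
    """
    representatives: List[int] = []  # one known point per discovered cluster
    assignment: List[int] = []
    queries = 0
    for point in range(n):
        for cluster_id, rep in enumerate(representatives):
            queries += 1
            if same_cluster(point, rep):
                assignment.append(cluster_id)
                break
        else:
            # No representative matched: the point starts a new cluster.
            representatives.append(point)
            assignment.append(len(representatives) - 1)
    return assignment, queries
```

The gap between this baseline's query count and the bound in the abstract is the point of the paper's approach, which uses the margin assumption to avoid querying most points.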
Submitted 22 November, 2016; v1 submitted 8 June, 2016;
originally announced June 2016.