Model Predictive Control of the Neural Manifold

Christof Fehrman [1], C. Daniel Meliza*

[1] Psychology Department, University of Virginia, Charlottesville, Virginia, United States of America

[2] Neuroscience Graduate Program, University of Virginia, Charlottesville, Virginia, United States of America
Abstract

Neural manifolds are an attractive theoretical framework for characterizing the complex behaviors of neural populations. However, many of the tools for identifying these low-dimensional subspaces are correlational and provide limited insight into the underlying dynamics. The ability to precisely control this latent activity would allow researchers to investigate the structure and function of neural manifolds. Employing techniques from the field of optimal control, we simulate controlling the latent dynamics of a neural population using closed-loop, dynamically generated sensory inputs. Using a spiking neural network (SNN) as a model of a neural circuit, we find low-dimensional representations of both the network activity (the neural manifold) and a set of salient visual stimuli. With a data-driven latent dynamics model, we apply model predictive control (MPC) to provide anticipatory, optimal control over the trajectory of the circuit in a latent space. We are able to control the latent dynamics of the SNN to follow several reference trajectories despite observing only a subset of neurons and despite a substantial amount of unknown noise injected into the network. These results provide a framework to experimentally test for causal relationships between manifold dynamics and other variables of interest, such as organismal behavior and brain-computer interface (BCI) performance.

keywords:
Neural Manifold, Model Predictive Control, Data-Driven Modeling, Optimal Control, Spiking Neural Network

1 Introduction

Neural circuits are composed of large numbers of interconnected neurons whose activity depends on other neurons in the circuit. Because of these dependencies, simultaneously recorded populations of neurons usually exhibit high levels of correlation. Equivalently, most of the variance within the high-dimensional space corresponding to the firing rates of individual neurons is confined to a lower-dimensional subspace, or neural manifold [1]. Because of their relative simplicity, manifolds have become a popular framework for understanding the complex dynamics of large neural populations. In many different systems, activity on neural manifolds has been shown to correlate with salient features of stimuli, physical position, and internal cognitive states [2, 3, 4, 5]. However, how much insight manifolds can provide into the underlying dynamics and computational principles of these circuits remains contested.

Broadly speaking, neural manifolds may be seen either as descriptive tools for dealing with the inherently correlated nature of neural data arising from highly interconnected circuits, or as a way of revealing more fundamental dynamics that exist within a latent space [6]. Using linear subspaces with autonomous dynamics as an illustrative example, the descriptive perspective can be seen as classic dimensionality reduction with

$$\mathbf{z}_t = \mathbf{G}\mathbf{x}_t, \qquad (1)$$

where $\mathbf{x}_t$ is a column vector of the activity of $n$ neurons at time $t$ and $\mathbf{z}_t$ is a reduced-dimension representation of $\mathbf{x}_t$ given by the linear transformation $\mathbf{G}$ (which parameterizes the neural manifold). This model is descriptive because the latent trajectories $\mathbf{z}_t$ are treated as a convenient summary of high-dimensional trajectories that result from the strongly coupled dynamics among the individual neurons $\mathbf{x}_t$; as such, they provide limited information for inference or mechanistic understanding [6]. In contrast, the generative perspective can be modeled as a latent factor model of the form

$$\mathbf{x}_t = \mathbf{F}\mathbf{z}_t + \boldsymbol{\epsilon}_t, \qquad (2)$$

where $\mathbf{x}_t$ is a column vector of the activity of $n$ neurons at time $t$, $\mathbf{z}_t$ are the latent factors that span the neural manifold with some smaller dimension $k$, $\mathbf{F}$ is the matrix of factor loadings, and $\boldsymbol{\epsilon}_t$ is a sample from some noise distribution (often Gaussian). This perspective views the measured neural activity as a function of the latent dynamics of $\mathbf{z}_t$, which emerge from, but are simpler than, the dynamics of $\mathbf{x}_t$. The descriptive and generative perspectives may also be seen as bottom-up and top-down approaches, respectively, to the question of how computations are implemented by neural circuits [7, 6]. Early work on neural manifolds used linear methods such as principal component analysis (as in equation 1) and factor analysis (as in equation 2), but there is now broad consensus that neural manifolds are often nonlinearly embedded in the full state space, requiring more sophisticated methods to identify them [8].
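As a toy illustration of how the two perspectives relate, the sketch below draws data from a one-factor version of equation (2) with a hypothetical loading vector $\mathbf{F} = (3, 2, 1)$ and small Gaussian noise, then recovers a descriptive projection $\mathbf{G}$ of the form in equation (1) with principal component analysis (power iteration on the sample covariance, after the usual mean-centering). The loading vector, noise level, and sample size are assumptions chosen for illustration; none of them come from the study.

```python
import math
import random


def sample_factor_model(F, n_samples, noise_sd, seed=0):
    """Draw x_t = F z_t + eps_t (equation 2) with a 1-D latent z_t ~ N(0, 1)."""
    rng = random.Random(seed)
    X = []
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)
        X.append([f * z + rng.gauss(0.0, noise_sd) for f in F])
    return X


def leading_pc(X, iters=100):
    """Leading principal component via power iteration on the sample covariance."""
    n, d = len(X), len(X[0])
    mean = [sum(row[j] for row in X) / n for j in range(d)]
    C = [[sum((row[i] - mean[i]) * (row[j] - mean[j]) for row in X) / (n - 1)
          for j in range(d)] for i in range(d)]
    g = [1.0] * d
    for _ in range(iters):
        g = [sum(C[i][j] * g[j] for j in range(d)) for i in range(d)]
        norm = math.sqrt(sum(v * v for v in g))
        g = [v / norm for v in g]          # keep the direction unit-length
    return g, mean


F = [3.0, 2.0, 1.0]                        # hypothetical factor loadings
X = sample_factor_model(F, n_samples=2000, noise_sd=0.1)
G, mean = leading_pc(X)                    # G acts as the 1 x n matrix in equation (1)
# Descriptive latent trajectory z_t = G (x_t - mean)
z_hat = [sum(g * (x - m) for g, x, m in zip(G, row, mean)) for row in X]
```

Because the noise is small and isotropic, the principal direction found by PCA lines up with the generative loadings; with structured or large noise the two can diverge, which is one reason the distinction between the perspectives matters.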

A potential weakness of the generative approach to understanding neural manifolds is that most methods of dimensionality reduction are static: they produce a time series of snapshots from an informative angle in the neural state space [9], but the dynamics must be inferred by other means. In contrast to the bottom-up approach, where a robust foundation of biophysics supports models of circuit dynamics at the level of cells and synapses, how best to model dynamics in the latent space remains an active area of research [10, 11, 12, 13]. Testing these models, and their causal relationship to behavior, would benefit from methods for experimentally controlling activity on the neural manifold.

In this study, we develop a framework for controlling latent dynamics in the context of a sensory system. We express activity on the neural manifold with a general state-space model

$$\mathbf{z}_{t+1} = g(\mathbf{z}_t, \mathbf{u}_t, \boldsymbol{\epsilon}_t) \qquad (3)$$
$$\mathbf{x}_t = f(\mathbf{z}_t), \qquad (4)$$

where the latent dynamics on the manifold are defined by a function of the current state of the system $\mathbf{z}_t$, an external, time-varying stimulus $\mathbf{u}_t$ that will be used for control, and an intrinsic, uncontrolled source of noise $\boldsymbol{\epsilon}_t$. The high-dimensional measured neural activity $\mathbf{x}_t$ is obtained with the observation function in equation (4). We show that the data-generating process (i.e., the dynamics of $\mathbf{x}_t$) does not need to match the latent dynamics model for the framework to be useful (see Methods), and that control of a highly nonlinear neural system is possible even without knowledge of the true structure of the neural manifold.

Given the latent dynamics specified in equation (3), the control problem is to find an external stimulus $\mathbf{u}$ such that the latent state evolves along a specified trajectory $\mathbf{z}^*$. Control theory provides a rich mathematical foundation for finding system inputs that achieve desired system outputs. Many techniques exist, and feedback (closed-loop) methods are particularly attractive because they can correct for unknown perturbations to the system. Broadly speaking, feedback controllers can be categorized as either reactive or anticipatory. Reactive controllers use present and past errors in state tracking (the difference between $\mathbf{z}^*$ and $\mathbf{z}$) to compute a control signal $\mathbf{u}$. The classic proportional-integral-derivative (PID) controller is a popular reactive controller because of its computational simplicity and strong performance. A notable limitation of reactive controllers, however, is that they can only correct errors in state tracking after those errors have occurred. This can be unacceptable if certain errors in state tracking are associated with pathological states. For example, suppose that a particular region of neural state space corresponded to epileptic firing. A reactive controller could only respond after the system entered this region, at which point it may be much more difficult to re-establish control. In contrast, anticipatory controllers are designed to predict future errors, which can allow them to prevent the system from ever entering undesirable regions of state space. Model predictive control (MPC) is an anticipatory controller that uses a model of the system dynamics to predict how present and future inputs will affect errors in state tracking [14].
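To make the reactive behavior of PID control concrete, the following minimal sketch runs a discrete-time PID controller on a hypothetical scalar plant $z_{t+1} = 0.9\,z_t + 0.5\,u_t$. The plant coefficients and the gains are illustrative assumptions, not values from the study. Starting exactly at the set point, the controller outputs nothing until the leaky plant drifts and an error appears, which is the limitation described above.

```python
class PID:
    """Discrete-time PID controller: reacts only to present and past errors."""

    def __init__(self, kp, ki, kd):
        self.kp, self.ki, self.kd = kp, ki, kd
        self.integral = 0.0
        self.prev_error = None

    def __call__(self, reference, measurement):
        error = reference - measurement
        self.integral += error
        # No derivative term on the first call, since there is no previous error yet
        derivative = 0.0 if self.prev_error is None else error - self.prev_error
        self.prev_error = error
        return self.kp * error + self.ki * self.integral + self.kd * derivative


def simulate(pid, z0, reference, steps, a=0.9, b=0.5):
    """Closed loop on a hypothetical scalar plant z_{t+1} = a*z_t + b*u_t."""
    z, history = z0, []
    for _ in range(steps):
        u = pid(reference, z)
        history.append((z, u))       # state and the control applied at that step
        z = a * z + b * u
    return z, history


# Starting at the set point: no error, so no action until the state drifts away
_, hist = simulate(PID(0.8, 0.2, 0.0), z0=1.0, reference=1.0, steps=2)
```

With these gains the closed loop is stable and the integral term removes the steady-state offset, but every correction still lags the error that caused it.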
MPC is also an optimal controller: it seeks a control input sequence $\mathbf{u}_{1:T}$ that minimizes a loss function of the form

$$J(\mathbf{x}_0) = \sum_{i=0}^{T} \ell(\mathbf{x}_i, \mathbf{u}_i), \qquad (5)$$

with constraints

$$\begin{aligned}
\mathbf{x}_{n+1} &= f(\mathbf{x}_n, \mathbf{u}_n) \\
\mathbf{x}_{LB} &\leq \mathbf{x} \leq \mathbf{x}_{UB} \\
\mathbf{u}_{LB} &\leq \mathbf{u} \leq \mathbf{u}_{UB},
\end{aligned}$$

where $\mathbf{x}_0$ is the value of the state at the current time step and $\ell(\mathbf{x}_i, \mathbf{u}_i)$ is the loss associated with the $i$th time step (relative to $\mathbf{x}_0$), which is a function of the state variable(s) $\mathbf{x}$ and input(s) $\mathbf{u}$. The controller uses the dynamical model to simulate $T$ time steps into the future. Many types of loss functions are possible, but they typically penalize the state error and the energy cost of the control signal. The constraints specify the dynamics of the system and give lower and upper bounds for the state variables and inputs. More sophisticated versions of MPC allow for additional constraints that incorporate knowledge of measurement or process noise [15]. Although MPC is only guaranteed to be globally optimal for linear systems with convex loss functions, it can also be used in many nonlinear systems [14, 16], in part because the controller can correct for model errors.
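To make equation (5) concrete, a common stage cost for tracking problems is a quadratic penalty on the deviation from a reference $\mathbf{x}^*_i$ and on the input magnitude. This form, and the weighting matrices $\mathbf{Q} \succeq 0$ and $\mathbf{R} \succ 0$, are illustrative assumptions; this section does not state which loss the study used:

$$\ell(\mathbf{x}_i, \mathbf{u}_i) = (\mathbf{x}_i - \mathbf{x}^*_i)^\top \mathbf{Q}\,(\mathbf{x}_i - \mathbf{x}^*_i) + \mathbf{u}_i^\top \mathbf{R}\,\mathbf{u}_i$$

The first term penalizes the state error and the second the energy of the control signal. When $f$ is linear and the bounds are simple boxes, every $\mathbf{x}_i$ is an affine function of the input sequence, so $J$ is a convex quadratic and the MPC problem becomes a convex quadratic program, which is the case in which a global optimum is guaranteed.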

Only the first value of the sequence $\mathbf{u}_{1:T}$ is applied as input to the system, and the optimization is performed again at the next time step. By repeatedly solving this optimization problem and applying only the first value, the controller can correct for errors in system modeling and anticipate future changes in the desired state trajectory. This anticipation can result in better controller performance than traditional reactive controllers such as PID [16].
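The receding-horizon procedure can be sketched for a hypothetical scalar linear model $x_{i+1} = a x_i + b u_i$ with a quadratic tracking cost and box bounds on the input. Everything specific here is an assumption for illustration: the model coefficients ($a = 0.9$, $b = 0.5$), the horizon of 5 steps, the input bounds of $\pm 1$, and the solver. Projected gradient descent is used only because it is short and correct for this convex, box-constrained problem; it is not the optimizer used in the study, and state bounds are omitted.

```python
def predict(x0, u_seq, a, b):
    """Roll the model x_{i+1} = a*x_i + b*u_i forward over the horizon."""
    xs = [x0]
    for u in u_seq:
        xs.append(a * xs[-1] + b * u)
    return xs


def solve_horizon(x0, x_ref, u_init, a, b, lam, u_lb, u_ub, lr=0.1, iters=1000):
    """Projected gradient descent on
    J = sum_{i=1..H} (x_i - x_ref)^2 + lam * sum_{i=0..H-1} u_i^2,
    subject to u_lb <= u_i <= u_ub (x_0 is fixed, so its cost term is constant)."""
    u = list(u_init)
    for _ in range(iters):
        xs = predict(x0, u, a, b)
        grad = [0.0] * len(u)
        adjoint = 0.0  # dJ/dx_{j+1}, accumulated backward through the model
        for j in reversed(range(len(u))):
            adjoint = 2.0 * (xs[j + 1] - x_ref) + a * adjoint
            grad[j] = 2.0 * lam * u[j] + b * adjoint
        # Gradient step, then projection back onto the input bounds
        u = [min(max(uj - lr * gj, u_lb), u_ub) for uj, gj in zip(u, grad)]
    return u


def run_mpc(x0, x_ref, steps, horizon=5, a=0.9, b=0.5, lam=0.0, u_lb=-1.0, u_ub=1.0):
    """Receding horizon: re-solve at every step and apply only the first input."""
    x, u_plan = x0, [0.0] * horizon
    xs, applied = [x0], []
    for _ in range(steps):
        u_plan = solve_horizon(x, x_ref, u_plan, a, b, lam, u_lb, u_ub)
        u0 = u_plan[0]
        x = a * x + b * u0              # "true" system, identical to the model here
        xs.append(x)
        applied.append(u0)
        u_plan = u_plan[1:] + u_plan[-1:]   # warm start: shift the plan one step
    return xs, applied


xs, applied = run_mpc(x0=0.0, x_ref=1.0, steps=15)
```

Starting from $x_0 = 0$, the first applied input saturates at the upper bound and the state reaches the reference within a few steps. In a real application the "true" system would differ from the model, and re-solving at each step is what lets the controller absorb that mismatch.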

There are two major issues when using MPC in practice that are relevant to the control of neural dynamics. The first is that the optimization procedure is computationally expensive, which can degrade controller performance when the time steps between measurements are small [14]. Neural recordings are often high-dimensional (e.g., extracellular recordings with high-density silicon probes) and evolve on fast time scales. The stimuli (corresponding to the input $\mathbf{u}$) may also be high-dimensional, which further increases the complexity of optimizing equation (5). This is where modeling activity on the neural manifold as a latent generative process becomes useful: we can solve the optimization problem in the low-dimensional state $\mathbf{z}$ and thereby reduce the computational complexity. If we used the descriptive approach of equation (1), we would need to optimize in the original measurement space $\mathbf{x}$ to force the activity on the manifold to follow a specified reference trajectory.

The second issue with applying MPC to neural systems is that it requires a dynamical model of the system to be controlled [17]. Although there are many biologically grounded dynamical models of individual neurons, the putative latent dynamics of a neural manifold are an emergent property that is difficult to model from first principles. This requires a data-driven approach, in which unknown parts of the system are modeled via function approximation and used to predict the time evolution of the system in response to various inputs.

Fitting these models is achieved by observing a temporal sequence of the state and input variables with some sampling period $\Delta t$,

$$\mathbf{Z} = [\mathbf{z}_0, \mathbf{z}_1, \ldots, \mathbf{z}_N], \quad \mathbf{U} = [\mathbf{u}_0, \mathbf{u}_1, \ldots, \mathbf{u}_N], \qquad (6)$$

where $\mathbf{z}_n = \mathbf{z}(n\Delta t)$. A discrete-time model can be parameterized such that

$$\hat{\mathbf{z}}_{n+1} = f_\theta(\mathbf{z}_n, \mathbf{u}_n) \qquad (7)$$

Models of this type are often referred to as forecasting models because they predict how the system will change over time. The goal is to find a set of parameters $\theta$ such that, given some initial state $\mathbf{z}_0$ and some known input $\mathbf{U}$,

$$[\mathbf{z}_0, \hat{\mathbf{z}}_1, \ldots, \hat{\mathbf{z}}_N] \approx [\mathbf{z}_0, \mathbf{z}_1, \ldots, \mathbf{z}_N] \qquad (8)$$

for any general temporal data sequence produced from the true dynamical system.
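A minimal instance of equations (6) to (8) is a linear forecasting model $f_\theta(z_n, u_n) = a z_n + b u_n$ with $\theta = (a, b)$, fit by one-step least squares and then rolled out from $z_0$ using only the known inputs. The scalar linear form, the sinusoidal test input, and the "true" coefficients $0.8$ and $0.3$ are assumptions for illustration; this section does not specify the latent dynamics model used in the study.

```python
import math


def fit_linear_forecaster(Z, U):
    """Least-squares fit of z_{n+1} ~ a*z_n + b*u_n for scalar sequences Z, U."""
    X = list(zip(Z[:-1], U[:-1]))   # regressors (z_n, u_n)
    y = Z[1:]                       # one-step-ahead targets z_{n+1}
    szz = sum(z * z for z, _ in X)
    szu = sum(z * u for z, u in X)
    suu = sum(u * u for _, u in X)
    szy = sum(z * t for (z, _), t in zip(X, y))
    suy = sum(u * t for (_, u), t in zip(X, y))
    det = szz * suu - szu * szu     # 2x2 normal equations solved by Cramer's rule
    a = (suu * szy - szu * suy) / det
    b = (szz * suy - szu * szy) / det
    return a, b


def forecast(z0, U, a, b):
    """Multi-step rollout [z_0, z_hat_1, ..., z_hat_N] as in equation (8)."""
    zs = [z0]
    for u in U[:-1]:
        zs.append(a * zs[-1] + b * u)
    return zs


# Hypothetical training data from a "true" system z_{n+1} = 0.8 z_n + 0.3 u_n
U = [math.sin(0.3 * n) for n in range(50)]
Z = [0.0]
for u in U[:-1]:
    Z.append(0.8 * Z[-1] + 0.3 * u)

a_hat, b_hat = fit_linear_forecaster(Z, U)
Z_hat = forecast(Z[0], U, a_hat, b_hat)
```

Because the data here are noise-free and generated by the same model class, the fit is exact; with real neural data the one-step fit and the multi-step rollout error of equation (8) can differ substantially, which is why the latter is the more demanding test.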

Data-driven approaches have been successfully applied to MPC problems in diverse fields [18, 19, 15, 20, 21], and the field is rapidly growing. In order for data-driven models to be useful for MPC applications in neuroscience, these models must be able to accurately predict the states to be controlled based only on observable state measurements, be agnostic to the number of hidden states, and generalize to a control scheme where control signals may be outside the training set. These challenges are not unique to neuroscience, but are still important to consider when selecting a data-driven approach to model the system dynamics. MPC has already been successfully applied in multiple areas of neuroscience research. At the individual cell level, simulated Hodgkin-Huxley neurons have been shown to be controllable via MPC using both biophysical [22, 23] and data-driven dynamics models [24, 25]. Additionally, both simulated and in vivo systems of neurons have been controlled with MPC using optogenetic stimulation as the control signal [26]. However, to the best of our knowledge, there has been no explicit attempt to apply these methods to activity on the neural manifold.

Our goal for this study was to provide a framework for controlling the latent dynamics on a neural manifold with MPC. An artificial neural circuit was simulated using a spiking neural network (SNN) driven by images of handwritten digits $\mathbf{u}$. The activity of the network was measured in a manner analogous to an extracellular recording experiment, with the spike times of the SNN serving as the neural states $\mathbf{x}$. Modern extracellular probes can record from dozens to hundreds of neurons simultaneously in both anesthetized and awake animals, but this is only a small subset of the neurons in a typical local circuit. Thus, in this simulation we used only a subset of the simulated neurons to fit a latent dynamics model of the whole network. We also added a substantial amount of random synaptic noise to simulate the influence of uncontrolled spontaneous activity arising from unobserved neurons within and outside the local circuit.

Using variational autoencoders (VAEs), we found low-dimensional representations of both the neural states and the stimulus images. A latent dynamics model for the controller was fit using the lower-dimensional states and inputs (referred to as $\mathbf{z}$ and $\mathbf{v}$, respectively). We then used MPC to find optimal latent inputs that force the latent states on the manifold to follow specified reference trajectories $\mathbf{z}^*$. We first controlled the latent dynamics of the SNN to stay at specified set points in the presence of intrinsic system noise. This result provided a proof of principle that MPC can control activity on the neural manifold in a simulation of an experiment using extracellular recording. We then showed that the controller can force the latent states to follow multiple time-varying reference trajectories, with the optimized visual stimuli showing striking differences. The ability to control the system to follow distinct trajectories would allow experimenters to test whether there is a causal relationship between activity on the manifold and macro-scale behaviors. Finally, we examined the relationship between the proportion of measurable neurons and MPC performance.

Figure 1: MPC Control Loop of Artificial Circuit. Exponentially filtered spike trains are encoded into the latent state $\mathbf{z}$, which is compared to a reference trajectory $\mathbf{z}^*$ to produce an error signal. The controller uses a model of the latent dynamics to find an optimal input that minimizes a loss function of the state error. This input is then projected back into the original stimulus dimension using the sVAE decoder, and the resulting stimulus drives the artificial circuit.
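The first stage of the loop in Figure 1 converts binary spike trains into continuous signals before encoding. A minimal sketch of a causal exponential filter is shown below. The update form (each spike adds 1 and the trace decays by $e^{-\Delta t/\tau}$ per step), the 1 ms step, and the 10 ms time constant are assumptions for illustration; the caption does not give the filter parameters.

```python
import math


def exp_filter(spike_train, tau, dt=1.0):
    """Causal exponential filter: r_n = exp(-dt/tau) * r_{n-1} + s_n."""
    decay = math.exp(-dt / tau)
    r, trace = 0.0, []
    for s in spike_train:
        r = decay * r + s
        trace.append(r)
    return trace


# Filter each observed neuron's binary spike train separately (hypothetical data)
spikes = [[1, 0, 0, 1, 0], [0, 0, 1, 0, 0]]
traces = [exp_filter(s, tau=10.0) for s in spikes]
```

Each filtered trace is a smoothed, rate-like version of one spike train, which is a more convenient input for an encoder than raw 0/1 events.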

2 Methods

2.1 Artificial Circuit

2.1.1 Architecture

The activity of an artificial circuit evoked by an external visual stimulus was simulated with an SNN composed of three layers: sensory, reservoir, and output. Each neuron in the SNN was modeled with a recurrent leaky integrate-and-fire (rLIF) model, with the discrete-time approximation

$$V_{n+1} = \begin{cases} \beta V_n + w^T X_{n+1} + r^T S_n, & \text{if } V_n < \Theta \\ 0, & \text{if } V_n \geq \Theta \end{cases} \qquad (9)$$

where

$V_n$: membrane voltage at the $n$th time step
$\Theta$: spiking threshold
$\beta$: decay parameter
$X_n$: feedforward input vector at the $n$th time step
$w$: feedforward weights
$S_n$: layer spiking vector at the $n$th time step
$r$: recurrent weights

Whenever $V$ was reset to $0$, a spike was recorded at that time step. For a given layer with $N$ neurons, this produced a binary vector $S_n = [s_1, s_2, \ldots, s_N]^T$, where each element was either $0$ or $1$, indicating whether the corresponding neuron had fired at that time step. Through the recurrent term $r^T S_n$, the activity of each neuron in a layer could be affected not only by its own firing (i.e., spiking inhibition/facilitation), but also by inputs from the other neurons in that layer.
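A plain-Python sketch of one update of equation (9) for a whole layer is given below. It assumes that the spike vector $S_n$ marks the neurons whose voltage $V_n$ has reached $\Theta$ (so a spike at step $n$ produces the reset $V_{n+1} = 0$), and it stores weights as one row per neuron. The weights and parameter values in the usage example are hypothetical; the study's rLIF hyperparameters are in its Supplementary Information.

```python
def rlif_step(v, x_next, w, r, beta, theta):
    """One update of a layer of rLIF neurons following equation (9).

    v      -- membrane voltages V_n (one float per neuron)
    x_next -- feedforward input vector X_{n+1}
    w      -- feedforward weights; w[i] is the input weight vector of neuron i
    r      -- recurrent weights; r[i][j] weights the spike of neuron j onto neuron i
    Returns (V_{n+1}, S_n), with S_n[i] = 1 when V_n[i] >= theta.
    """
    s = [1 if vi >= theta else 0 for vi in v]   # spike vector S_n
    v_next = []
    for i, vi in enumerate(v):
        if s[i]:
            v_next.append(0.0)                  # reset to 0 after a spike
        else:
            ff = sum(wij * xj for wij, xj in zip(w[i], x_next))   # w^T X_{n+1}
            rec = sum(rij * sj for rij, sj in zip(r[i], s))       # r^T S_n
            v_next.append(beta * vi + ff + rec)
    return v_next, s


# A single neuron with constant input: it charges up, spikes once, then resets
v, spikes = [0.0], []
for _ in range(5):
    v, s = rlif_step(v, [1.0], w=[[0.6]], r=[[0.0]], beta=0.5, theta=1.0)
    spikes.append(s[0])
# spikes == [0, 0, 0, 1, 0]
```

Setting a negative diagonal entry in `r` would give each neuron self-inhibition after a spike, and off-diagonal entries couple neurons within the layer, which is the role of the recurrent term described above.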

Each neuron in the sensory layer received feedforward input in the form of a grayscale image reshaped into a 784-dimensional vector (28 × 28 pixels). This input served as the external stimulus and was the primary driver of SNN activity. Gaussian noise was also added to the feedforward inputs of each neuron in every layer to make activity stochastic. This noise modeled natural variability in neural firing and the effects of unknown exogenous inputs. Thus, the subthreshold activity of the three layers was given by the equations

$$\text{Sensory:} \quad V_{n+1}^{sen} = \beta V_n^{sen} + w_{sen}^T (I_n + \epsilon) + r_{sen}^T S_n^{sen} \qquad (10)$$
$$\text{Reservoir:} \quad V_{n+1}^{res} = \beta V_n^{res} + w_{res}^T (S_n^{sen} + \epsilon) + r_{res}^T S_n^{res} \qquad (11)$$
$$\text{Output:} \quad V_{n+1}^{out} = \beta V_n^{out} + w_{out}^T (S_n^{res} + \epsilon) + r_{out}^T S_n^{out} \qquad (12)$$

where $\epsilon \sim \mathcal{N}(0, \eta^2)$ for each element of the feedforward input and $I_n$ is the stimulus image presented at the $n$th time step. Note that $\beta$ incorporates both the membrane time constant and the time step of the discretization of the continuous LIF model [27], so the time units of the simulation can be chosen arbitrarily by modifying this parameter. For simplicity, each time step represents 1 millisecond (ms) in all simulations, which implicitly imposes a refractory period of 1 ms. See Supplementary Information T1 for SNN training and rLIF hyperparameters.
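Two details of equations (10) to (12) can be made concrete in a few lines. First, one standard way to fold a membrane time constant $\tau$ and a step $\Delta t$ into $\beta$ is $\beta = e^{-\Delta t/\tau}$, the exact decay of $dV/dt = -V/\tau$ over one step; whether the study used exactly this mapping is an assumption here. Second, the noise term adds an independent $\mathcal{N}(0, \eta^2)$ sample to every element of a feedforward input. The 5 ms time constant, $\eta = 0.1$, and the blank placeholder image are hypothetical values.

```python
import math
import random


def lif_beta(tau_mem, dt=1.0):
    """Decay factor: exact decay of dV/dt = -V / tau_mem over one step of length dt."""
    return math.exp(-dt / tau_mem)


def noisy_feedforward(inputs, eta, rng):
    """Add an independent N(0, eta^2) sample to every element of a feedforward input."""
    return [x + rng.gauss(0.0, eta) for x in inputs]


beta = lif_beta(tau_mem=5.0)      # hypothetical 5 ms membrane constant, dt = 1 ms
rng = random.Random(0)
image = [0.0] * 784               # placeholder for a flattened 28 x 28 image I_n
x_noisy = noisy_feedforward(image, eta=0.1, rng=rng)
```

Under this mapping, a longer time constant pushes $\beta$ toward 1 (slower leak), and changing the assumed duration of a time step changes $\beta$ without changing the update rule, which is why the time units are a modeling choice.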

2.1.2 Training the Network

In principle, the feedforward and recurrent weights for each neuron could be randomly distributed, as in reservoir computing [28]. However, to give the network dynamics that implement a useful computation [29], the weights were trained to perform a classification task. Using the Python package snntorch [27], the SNN was trained to classify digits from the MNIST data set [30]. Training SNNs requires additional considerations compared to traditional artificial neural networks. Because a neuron's membrane voltage is reset when it reaches the threshold $\Theta$, spiking is a non-differentiable step function of the voltage [27], which prevents training with standard gradient descent. One solution to this problem is surrogate gradient descent [31], in which the step function is preserved in the forward pass of the network, but its derivative is replaced with that of a sigmoid function during the backward pass. This provides well-defined, nonzero gradients and allows the network to be trained with gradient descent.
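The surrogate-gradient idea can be shown without any framework: the forward pass uses the hard threshold, while the backward pass substitutes the derivative of a steep sigmoid. The sketch below works on $u = V - \Theta$; the slope of 25 is a hypothetical choice. In practice snntorch's own surrogate-gradient functions (in its `surrogate` module) would be used with automatic differentiation, rather than this hand-written version.

```python
import math


def heaviside(u):
    """Forward pass: emit a spike (1.0) when u = V - Theta >= 0, else 0.0."""
    return 1.0 if u >= 0.0 else 0.0


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def surrogate_grad(u, slope=25.0):
    """Backward pass: derivative of sigmoid(slope * u), used in place of dH/du."""
    s = sigmoid(slope * u)
    return slope * s * (1.0 - s)


for u in (-0.2, 0.0, 0.05):
    print(u, heaviside(u), round(surrogate_grad(u), 4))
```

The true derivative of the step is zero everywhere except at threshold, so no learning signal would flow; the surrogate is largest near threshold and decays smoothly away from it, and the slope controls how sharply it approximates the step.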

The MNIST data set is composed of images of the handwritten digits 0 through 9. Classification was performed by associating the label of each image with a corresponding neuron in the output layer of the SNN (e.g., label 4 with neuron 4). For time step $n$, the cross-entropy $\ell_n$ was given by

$$p_n^i = \frac{\exp(V_n^i)}{\sum_{k=0}^{9} \exp(V_n^k)} \qquad (13)$$
$$\ell_n = -\sum_{i=0}^{9} y_i \log(p_n^i), \qquad (14)$$

where $V_n^i$ is the membrane voltage of the $i$th output neuron, which corresponds to the prediction of label $i$, and $y$ is a one-hot encoded vector of the true label. Because SNNs are inherently temporal, one must specify for how many time steps a stimulus is presented before class prediction takes place. This can be interpreted as a combination of reaction time and evidence accumulation. Thus the loss function to be minimized is given by

$$\mathcal{L}_{CE} = \sum_{n=0}^{T} \ell_n, \qquad (15)$$

where the cross-entropy loss at each time step is summed over a trial of length $T$. This encourages the output neuron corresponding to the true class to be the most active neuron in the output layer. Over the time window during which the image is presented, the feedforward and recurrent weights in the SNN are updated using backpropagation through time (BPTT).
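Equations (13)–(15) can be evaluated directly from the output-layer voltages of one trial. The sketch below assumes the voltages are available as an array of shape `(T + 1, 10)` (one row per time step $n = 0, \dots, T$); the function name is illustrative and not taken from the original code.

```python
import numpy as np

def summed_cross_entropy(V, label):
    """Eqs. (13)-(15): softmax over output voltages at each time step,
    cross-entropy against the true label, summed over time steps.

    V     : array of shape (T + 1, 10), output-layer voltages for n = 0..T
    label : integer class in 0..9
    """
    V = np.asarray(V, dtype=float)
    # Log-softmax computed via log-sum-exp for numerical stability (Eq. 13)
    V_shift = V - V.max(axis=1, keepdims=True)
    log_p = V_shift - np.log(np.exp(V_shift).sum(axis=1, keepdims=True))
    y = np.zeros(V.shape[1])
    y[label] = 1.0                        # one-hot target vector
    ell = -(log_p * y).sum(axis=1)        # per-step loss ell_n (Eq. 14)
    return ell.sum()                      # L_CE (Eq. 15)
```

With uninformative (all-equal) voltages every step contributes $\log 10$, so a trial of five steps costs $5 \log 10$.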

The first half of the MNIST data set (n = 35,000) was used for training and validation of the SNN with an 80/20 split. Because of the stochastic noise added to every neuron in the AC, performance after training was assessed by presenting the stimuli 30 times to obtain an accuracy distribution. Accuracies on the training (n = 28,000, M = 43.7%, s = 8.4%) and validation (n = 7,000, M = 43.9%, s = 8.4%) sets were largely similar, indicating that there was no overfitting to the training data. Although these accuracies are far below competitive performance on this classification task, the purpose of training the SNN was only to ensure that the connections between the neurons were not random.

2.2 Dimensionality Reduction of Stimuli and Neural States

The dimensionalities of the MNIST digit stimuli and SNN neural activity were reduced using variational autoencoders (VAEs). The parameters for encoders $f_\theta(\cdot)$ and decoders $h_\phi(\cdot)$ were found by minimizing the loss function

$$\ell_{VAE} = \ell_{recon} + \alpha\,\ell_{KL} \qquad (16)$$
$$\ell_{recon} = \sum_{i=1}^{B} \lVert \mathbf{x}_i - h_\phi(\mathbf{z}_i) \rVert_2^2 \qquad (17)$$
$$\ell_{KL} = -\frac{1}{2}\sum_{i=1}^{B}\sum_{j=1}^{k} \left(1 + \log\sigma_{ij}^2 - \mu_{ij}^2 - \sigma_{ij}^2\right) \qquad (18)$$

where $\alpha$ is a hyperparameter that scales the importance of the KL-divergence term in the loss function, $B$ is the batch size used for gradient descent, $k$ is the dimensionality of the latent space, and $\mu_{ij}$ and $\sigma_{ij}^2$ are the mean and variance produced by the encoder for latent dimension $j$ of sample $i$. Because the latent variable $\mathbf{z}$ is parameterized by a Gaussian, minimizing this loss function corresponds to maximizing the evidence lower bound (ELBO) [32].
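A minimal NumPy sketch of the batch loss in Eqs. (16)–(18), together with the reparameterization trick that makes the sampled $\mathbf{z}$ differentiable with respect to the encoder outputs. It assumes the encoder returns `mu` and `log_var` arrays of shape `(B, k)` and uses the standard closed-form Gaussian KL term, $\tfrac{1}{2}\sum(\mu^2 + \sigma^2 - 1 - \log\sigma^2)$. The function names are illustrative; in practice this would be written in PyTorch so gradients flow automatically.

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

def vae_loss(x, x_hat, mu, log_var, alpha=1.0):
    """Batch loss: summed squared reconstruction error plus alpha times the
    KL divergence from N(mu, sigma^2) to the standard normal prior."""
    recon = np.sum((x - x_hat) ** 2)
    kl = -0.5 * np.sum(1.0 + log_var - mu ** 2 - np.exp(log_var))
    return recon + alpha * kl, recon, kl
```

Setting `alpha = 0` drops the KL term and reduces the objective to that of a plain autoencoder.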

The VAE framework was chosen for two reasons. First, VAEs embed high-dimensional data into nonlinear low-dimensional subspaces. This provides a flexible approach for finding the latent dynamics of the SNN and a more generalizable compression of the data than linear transformations [33]. Second, the KL-divergence term encourages nearby values of $\mathbf{z}$ in the latent space to be decoded into similar values in the original space of $\mathbf{x}$ [34]. This makes it easier to interpolate between training data points in the latent space when constructing the dynamics model and finding the optimal latent inputs with MPC. If a basic autoencoder (i.e., $\alpha = 0$) were used instead of a VAE, nearby points in the latent space would not be guaranteed to be similar in the original measurement space, which would be detrimental to modeling dynamics in the latent space.

The control problem can now be stated as follows. Given latent embeddings of the neural states $\mathbf{x}$ and stimulus states $\mathbf{u}$,

$$\mathbf{z}_n = \mathbb{E}\left[f_\theta^{\text{neural}}(\mathbf{x}_n)\right], \quad \mathbf{v}_n = \mathbb{E}\left[f_\theta^{\text{stimulus}}(\mathbf{u}_n)\right] \qquad (19)$$

we seek to find an optimal set of latent inputs

$$\mathbf{v}_{1:T}^* = \operatorname*{arg\,min}_{\mathbf{v}_{1:T}} \sum_{n=0}^{T} \ell(\mathbf{z}_n, \mathbf{v}_n) \qquad (20)$$

with the dynamics model

$$\mathbf{z}_{n+1} = g(\mathbf{z}_n, \mathbf{v}_n). \qquad (21)$$

The SNN can then be stimulated with the decoded latent inputs

$$\mathbf{u}_n = h_\phi^{\text{stimulus}}(\mathbf{v}_n^*) \qquad (22)$$

to produce the neural states $\mathbf{x}$ at the next time step.
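Equations (19)–(22) describe a closed loop: encode the observed neural state, optimize the latent inputs, decode the optimal input into a stimulus, and present it to the circuit. The sketch below wires these steps together using hypothetical callables (`encode_neural`, `plan_latent_input`, `decode_stimulus`, `snn_step`) that stand in for the nVAE encoder mean, the MPC optimizer, the sVAE decoder, and the SNN. None of these names come from the original work.

```python
import numpy as np

def closed_loop_step(x_n, encode_neural, plan_latent_input, decode_stimulus, snn_step):
    """One pass through the latent control loop of Eqs. (19)-(22).

    encode_neural     : x_n  -> z_n      (posterior mean of the neural encoder, Eq. 19)
    plan_latent_input : z_n  -> v*_n     (optimizer over Eqs. 20-21, e.g. MPC)
    decode_stimulus   : v*_n -> u_n      (stimulus decoder, Eq. 22)
    snn_step          : u_n  -> x_{n+1}  (the circuit being controlled)
    """
    z_n = encode_neural(x_n)
    v_star = plan_latent_input(z_n)
    u_n = decode_stimulus(v_star)
    x_next = snn_step(u_n)
    return x_next, z_n, v_star
```

Calling this repeatedly, feeding each returned `x_next` back in, yields the closed-loop trajectory.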

2.3 Stimulus VAE

A VAE was trained to find a low-dimensional representation of the MNIST digit stimuli (sVAE). The sVAE encoder was composed of two single-channel convolutional layers and a single feedforward layer, each using a ReLU activation function. The size of the latent space $\mathbf{v}$ was chosen to be 2. The sVAE decoder mirrored the architecture of the encoder with one key difference: a sigmoid activation function was used on the final output layer to ensure that the values of the reconstructed stimuli were between 0 and 1 (the bounds of all pixel values in the training images). See Supplementary Information T2 for layer hyperparameters.

The second half of the MNIST data set (n = 35,000) was used for training and validation of the sVAE. Eighty percent of this data (n = 28,000) was used for training and the remaining 20% (n = 7,000) for validation. The sVAE was trained via gradient descent in PyTorch using the Adam optimizer.

A sequence of latent inputs for identifying the latent dynamics model was constructed using the validation set of images. Discrete points in the latent sVAE space were obtained by running k-means clustering (100 centers) on the low-dimensional embedding of the validation images. Half of the centers were used for training the dynamics model ($V_{train}$) and half for testing ($V_{test}$). These discrete points were converted to a continuous time series by one of three methods: step function, fast interpolation, or slow interpolation. The step function method took a sequence of centers and held each center constant for 500 time steps (500 ms). The fast and slow methods linearly interpolated between consecutive elements of the sequence at different time scales (200 ms for fast and 1000 ms for slow). Both the $V_{train}$ and $V_{test}$ sequences were 46.3 seconds long (46,300 time steps). These three methods were used to produce an input sequence that could show how the latent states of the SNN responded to inputs changing at different time scales and frequencies. Figure 2 shows the latent input sequences and their sVAE-decoded values.
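The step and interpolation methods can be sketched as follows, assuming one time step equals 1 ms and that the k-means centers are already available as an array of shape `(K, 2)` (for example, from `sklearn.cluster.KMeans(n_clusters=100)` fit on the sVAE embeddings, via its `cluster_centers_` attribute). How the centers were ordered and how the three methods were concatenated into the 46.3 s sequences is not specified here, so the function names and the joining convention are illustrative assumptions.

```python
import numpy as np

def step_sequence(centers, hold=500):
    """Step method: hold each latent center for `hold` time steps (1 step = 1 ms)."""
    return np.repeat(np.asarray(centers, dtype=float), hold, axis=0)

def interp_sequence(centers, steps_per_segment):
    """Linear interpolation between consecutive centers, spending
    `steps_per_segment` time steps per segment (200 = fast, 1000 = slow).
    The final center is appended so the sequence ends exactly on it."""
    centers = np.asarray(centers, dtype=float)
    segments = [np.linspace(a, b, steps_per_segment, endpoint=False)
                for a, b in zip(centers[:-1], centers[1:])]
    return np.vstack(segments + [centers[-1:]])
```

For two centers, the fast method yields 201 samples whose midpoint (sample 100) lies halfway between them.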

Figure 2: Latent Embedding of MNIST Digits. A) Each colored dot is one of the MNIST digits embedded in the 2-dimensional nonlinear subspace of the sVAE encoder. Notice the clustering of the digits by label (color), indicating that digits with identical labels were often embedded in nearby regions of the latent space. To generate a latent sequence of inputs used to stimulate the artificial circuit, points from this latent space were sampled using k-means clustering (100 centers). B) Training and testing latent sequences were each generated using half of the centers from A. C) Three points from the training inputs decoded into the original stimulus dimension.

2.4 Neural VAE

Figure 3: Latent Embedding of Neural Activity. A) Results of the VAE on training SNN neural activity (nVAE). The exponentially filtered spikes are embedded in a latent 2D space through the nVAE encoder. Activity in this latent space can be projected back into the original dimension with the nVAE decoder. B) Results of the nVAE on the testing SNN neural activity.

The SNN was stimulated using the sVAE-decoded $V_{train}$ and $V_{test}$ input sequences, and the resulting spiking activity was used to fit a VAE for the neural states (nVAE). A random sample of 20% of the neurons (n = 122) in the SNN was used to build the model. These neurons were the only units in the SNN that were measured throughout Experiments I and II. In a typical extracellular recording, only a subset of the neurons is observable; the purpose of this random sampling was to mimic the incomplete information that would be obtained in a real recording.

The measured binary spiking states were converted to continuous states with an exponential filter, where the smoothed state $\mathbf{x}_n$ of spiking state $\mathbf{y}_n$ is given by

$$\mathbf{x}_{n+1} = \omega\,\mathbf{y}_n + (1-\omega)\,\mathbf{x}_n \qquad (23)$$

where $\omega$ was chosen to be 0.5 and $\mathbf{x}_0 = \mathbf{y}_0$.
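Equation (23) translates directly into a short recursion. The sketch assumes the spikes are stored as an array of shape `(N, n_neurons)` with one row per time step; note that $\mathbf{x}_{n+1}$ depends on $\mathbf{y}_n$, so the filtered trace lags the spikes by one step.

```python
import numpy as np

def exponential_filter(Y, omega=0.5):
    """Eq. (23): x_{n+1} = omega * y_n + (1 - omega) * x_n, with x_0 = y_0.

    Y : (N, n_neurons) array of binary spikes; returns smoothed states X."""
    Y = np.asarray(Y, dtype=float)
    X = np.empty_like(Y)
    X[0] = Y[0]
    for n in range(len(Y) - 1):
        X[n + 1] = omega * Y[n] + (1.0 - omega) * X[n]
    return X
```

With $\omega = 0.5$, a single spike at $n = 0$ produces the trace 1, 1, 0.5, 0.25, ...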

The nVAE encoder and decoder were symmetric, each having five feedforward layers. The smoothed state $\mathbf{x}_n$ was z-scored on entering the first layer of the encoder. A small value $\epsilon = 1 \times 10^{-5}$ was added to the denominator of the standardization because some time steps had near-identical values. Each hidden layer used a ReLU activation function followed by a batch normalization layer. The size of the latent dimension $\mathbf{z}$ was chosen to be 2. Because the noise added to the rLIF neurons produced temporal spiking jitter in response to the same input stimulus, the variance of the latent representation $\mathbf{z}$ was constrained to be above 1. This was achieved by having the encoder learn the log of the variance instead of the variance directly and then applying the softplus function to the estimate. This minimum variance acted as a regularizer against the noise present in the training data. See Supplementary Information T3 for layer and training hyperparameters.

2.5 Latent Dynamics Model

Figure 4: Forecasting Performance of Latent Dynamics Model in Experiment I. (Left) Latent state $\mathbf{z}$ (red) forecasting on the training data. On top is the forecast for $\mathbf{z}_1$ and on bottom $\mathbf{z}_2$. In black is the actual latent trajectory of the training data. The product-moment correlation between the actual and predicted latent states is shown above each panel. (Right) Same as on the left, but with the testing data. The training and testing forecasts fit the data nearly identically.

A linear latent dynamics model of the form

$$\mathbf{z}_{n+1} = A\mathbf{z}_n + B\mathbf{v}_n \qquad (24)$$

was estimated using ridge regression, with the L2 hyperparameter chosen via leave-one-out cross-validation. All model fitting was performed using the sklearn Python package. The dynamics model was fit to the data produced by the $V_{train}$ input sequence. Model performance was assessed by starting from an initial value of $\mathbf{z}$ and forecasting for the entire length of the training sequence. At every time step, the known value of $V_{train}$ and the model's previous prediction of $\mathbf{z}$ were used to forecast the next value. An inadequate model would produce a time series that poorly approximated the actual latent state trajectory produced by $V_{train}$, and could even diverge due to compounding errors.
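The fitting and free-running forecast described above can be sketched in NumPy. This is not the authors' sklearn code. In sklearn, `RidgeCV` with its default `cv=None` performs efficient leave-one-out selection of the penalty, and `fit_intercept=False` would match Eq. (24). Here the same leave-one-out score is computed in closed form from the hat-matrix diagonal. The candidate `alphas` grid is an assumed example.

```python
import numpy as np

def fit_linear_dynamics(Z, V, alphas=(1e-4, 1e-2, 1.0, 100.0)):
    """Fit z_{n+1} = A z_n + B v_n (Eq. 24) by ridge regression, choosing the
    L2 penalty by closed-form leave-one-out error. Z: (N, d_z), V: (N, d_v)."""
    X = np.hstack([Z[:-1], V[:-1]])           # regressors [z_n, v_n]
    Y = Z[1:]                                 # targets z_{n+1}
    best = None
    for alpha in alphas:
        G = X.T @ X + alpha * np.eye(X.shape[1])
        W = np.linalg.solve(G, X.T @ Y)       # (d_z + d_v, d_z)
        h = np.einsum("ij,ji->i", X, np.linalg.solve(G, X.T))  # hat-matrix diagonal
        loo_resid = (Y - X @ W) / (1.0 - h)[:, None]
        score = np.mean(loo_resid ** 2)
        if best is None or score < best[0]:
            best = (score, alpha, W)
    _, alpha, W = best
    d_z = Z.shape[1]
    return W[:d_z].T, W[d_z:].T, alpha        # A, B, chosen penalty

def free_run_forecast(A, B, z0, V):
    """Forecast from z0 using known inputs and the model's own previous predictions."""
    Z_hat = np.empty((len(V), len(z0)))
    Z_hat[0] = z0
    for n in range(len(V) - 1):
        Z_hat[n + 1] = A @ Z_hat[n] + B @ V[n]
    return Z_hat

def forecast_correlation(Z, Z_hat):
    """Product-moment (Pearson) correlation for each latent dimension."""
    return np.array([np.corrcoef(Z[:, k], Z_hat[:, k])[0, 1] for k in range(Z.shape[1])])
```

Because the forecast never sees the measured $\mathbf{z}$ after the first step, errors in $A$ or $B$ compound over time. This makes the free-run test stricter than one-step-ahead prediction.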

As an example, the forecasted values of the training data in Experiment I closely fit the actual $\mathbf{z}$ training trajectories, as measured with the product-moment correlation ($R_{z_1} = 0.87$, $R_{z_2} = 0.73$). The possibility of overfitting was assessed by forecasting the latent states elicited by $V_{test}$. The resulting predicted trajectory was also similar to the actual latent trajectory ($R_{z_1} = 0.87$, $R_{z_2} = 0.72$), indicating that the dynamics model was not overfit and would be useful for control with MPC. See Figure 4 for the forecasted latent trajectories of the training and testing data in Experiment I.

3 Results

3.1 Experiment I: MPC Can Force the Latent Dynamics on the Neural Manifold to Stay at Set Points

As a simple first experiment, the latent dynamics were controlled to follow a step function. This reference trajectory was composed of two set points in latent space, each held constant for 500 ms. The values were chosen by running k-means clustering on the latent state training data and using the resulting centroids as the set points. The controller had a predictive time horizon $T$ of 30 time steps and optimized the loss function

$$J(\mathbf{z}_0) = \mathbf{z}_T^\intercal \mathbf{S}\,\mathbf{z}_T + \sum_{i=0}^{T-1}\left(\mathbf{z}_i^\intercal \mathbf{Q}\,\mathbf{z}_i + \Delta\mathbf{v}_i^\intercal \mathbf{R}\,\Delta\mathbf{v}_i\right) \qquad (25)$$

where

$$\mathbf{Q} = \mathbf{S} = \begin{bmatrix}150 & 0\\ 0 & 150\end{bmatrix}, \quad \mathbf{R} = \begin{bmatrix}10{,}000 & 0\\ 0 & 10{,}000\end{bmatrix}. \qquad (26)$$
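With a linear model and a quadratic cost, and no constraints beyond the dynamics, one receding-horizon step of Eq. (25) reduces to a single linear solve. The sketch below is our NumPy illustration, not the authors' do-mpc/IPOPT setup. It assumes that $\mathbf{z}_i$ in Eq. (25) denotes the tracking error relative to the reference trajectory, and that $\Delta\mathbf{v}_i = \mathbf{v}_i - \mathbf{v}_{i-1}$, with the previously applied input playing the role of $\mathbf{v}_{-1}$. Only the first planned input is applied before re-solving at the next time step.

```python
import numpy as np

def mpc_step(A, B, z0, v_prev, ref, Q, R, S):
    """One receding-horizon step for the linear model z_{n+1} = A z_n + B v_n.

    ref has shape (T + 1, d) and holds the reference for z_0..z_T. Minimizes
        sum_{i=1}^{T-1} e_i' Q e_i + e_T' S e_T + sum_{i=0}^{T-1} dv_i' R dv_i,
    with e_i = z_i - ref_i and dv_i = v_i - v_{i-1} (v_{-1} = v_prev).
    e_0 does not depend on the inputs, so it is omitted.
    Returns the first planned input v_0, the only one that is applied."""
    d, m = B.shape
    T = ref.shape[0] - 1
    powers = [np.linalg.matrix_power(A, k) for k in range(T + 1)]
    # Stacked predictions z_1..z_T = Phi z0 + Gam [v_0; ...; v_{T-1}]
    Phi = np.zeros((T * d, d))
    Gam = np.zeros((T * d, T * m))
    for i in range(1, T + 1):
        Phi[(i - 1) * d:i * d] = powers[i]
        for j in range(i):
            Gam[(i - 1) * d:i * d, j * m:(j + 1) * m] = powers[i - 1 - j] @ B
    # Inputs as cumulative increments: v_j = v_prev + sum_{k<=j} dv_k
    L = np.kron(np.tril(np.ones((T, T))), np.eye(m))
    E_free = Phi @ z0 + Gam @ np.tile(v_prev, T) - ref[1:].ravel()
    M = Gam @ L                               # sensitivity of errors to increments
    W = np.kron(np.eye(T), Q)
    W[-d:, -d:] = S                           # terminal weight on z_T
    H = M.T @ W @ M + np.kron(np.eye(T), R)
    dV = np.linalg.solve(H, -M.T @ W @ E_free)
    return v_prev + dV[:m]
```

Calling `mpc_step` inside a loop, re-measuring $\mathbf{z}$ from the network after each applied input, gives the closed-loop controller. For the Experiment I weights, one would pass `Q = S = 150 * np.eye(2)` and `R = 10_000 * np.eye(2)`. Because only input changes are penalized, the controller can hold whatever steady input places the equilibrium on the set point.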

No constraints other than the dynamics model were included. Because the SNN responses are stochastic, 50 independent trials of MPC were performed. Controller performance on the latent state was quantified with the normalized mean square error (nMSE), which divides the mean square error by the difference between the maximum and minimum values of the reference $\mathbf{z}^*$. This normalization makes controller performance comparable across dimensions with different scales. All MPC optimizations and implementations were performed using the do-mpc Python package [35], which uses CasADi [36] for automatic differentiation and IPOPT [37] for interior-point optimization.
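The nMSE as described, computed separately for each latent dimension, might look like the sketch below. It assumes trajectories are stored as arrays of shape `(N, d)`. It divides by the range of the reference itself (not the squared range), which is a literal reading of the description above.

```python
import numpy as np

def nmse(z, z_ref):
    """Per-dimension mean square error between controlled and reference
    latent trajectories, divided by the range (max - min) of the reference."""
    z, z_ref = np.asarray(z, dtype=float), np.asarray(z_ref, dtype=float)
    mse = np.mean((z - z_ref) ** 2, axis=0)
    return mse / (z_ref.max(axis=0) - z_ref.min(axis=0))
```

A constant reference would give a zero range, so this normalization assumes the reference varies in each dimension, as it does when the set point switches.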

The controller performed well, particularly given the noise in the system and that only 20% of the SNN's neurons were observable. Although the noise added to every neuron's input was Gaussian, the nonlinearities in the rLIF models propagate it through the network as highly complex, structured fluctuations. At the beginning of control and after the set point was switched, the system converged to the desired location in the latent space within 20–50 time steps and then remained there despite continual noise inputs that produced large perturbations (Figure 5A). When the set point was changed, the control stimulus made a large excursion within its latent space, and the input fluctuated substantially throughout the experiment as the controller corrected for noise-induced perturbations.

Because the latent space has only two dimensions, we can visualize the dynamics of the forecasting model as a vector field, with the current and desired states of the system as points within this field. The dynamics can be decomposed into an autonomous component $A\mathbf{z}_n$ and the forcing from the stimulus $B\mathbf{v}_n$. Because the combined system is linear, for a fixed input $\mathbf{v}$ it has a unique equilibrium, $\mathbf{z}_{eq} = (I - A)^{-1} B\mathbf{v}$, whenever $I - A$ is invertible. As seen in Figure 5B, at time points when the set point was constant, the controller chose stimuli such that the dynamics had a stable equilibrium close to the set point. During the transition to a new set point, the controller initially moved the equilibrium to a location beyond the set point but then settled on an input that created a fixed point at the desired location. This overshoot may have allowed the controller to move the system to its desired location more rapidly.
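The fixed points drawn in Figure 5B follow from setting $\mathbf{z}_{n+1} = \mathbf{z}_n$ in Eq. (24) for the current latent input. A minimal sketch: it solves $(I - A)\mathbf{z} = B\mathbf{v}$ and reports discrete-time stability, which requires every eigenvalue of $A$ to lie strictly inside the unit circle. The function name is illustrative.

```python
import numpy as np

def latent_fixed_point(A, B, v):
    """Equilibrium of z_{n+1} = A z_n + B v for a constant input v:
    z_eq = (I - A)^{-1} B v, defined when I - A is invertible."""
    d = A.shape[0]
    z_eq = np.linalg.solve(np.eye(d) - A, B @ v)
    # Discrete-time stability: all eigenvalues of A inside the unit circle
    stable = bool(np.all(np.abs(np.linalg.eigvals(A)) < 1.0))
    return z_eq, stable
```

Since $A$ does not depend on $\mathbf{v}$, the stimulus only moves the location of the equilibrium; its stability is fixed by the autonomous dynamics.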

As seen in Figure 5C, there was a qualitative change in the activity of the population when the reference point changed at 500 ms. Interestingly, while there was an initial drop in the average firing rate, it returned to a similar value as the system converged on the new set point. This indicates that the latent states are not just functions of the population firing rate, but are functions of particular firing patterns and activity from specific neurons. We can also examine the stimuli that were actually presented to the SNN during control. As seen in Figure 5D, the control stimuli all had digit-like characteristics, which undoubtedly reflects the fact that the sVAE was trained only on digits. Notice, however, that the image presented during the transition to the second reference point is a larger and more intense ‘2’ than the image presented after the system had converged.

Figure 5: Results of Experiment I. A) (Left) The results of MPC of the SNN across 50 trials. The reference trajectories (black) were composed of two set points that changed at 500 ms. In light red are the controlled latent trajectories across the 50 trials, and the average of these trajectories is shown in dark red. (Right) The latent inputs produced by the optimizer. In light blue are the inputs used across the 50 trials, with the average input shown in dark blue. B) At each time step, the latent dynamics model predicts a vector field on the latent states. Three snapshots of this vector field are shown at 200, 500, and 800 ms. The desired reference points are shown by a black X and the controlled latent states are indicated by red dots. Notice the magnitude of the vectors when the reference point changes at 500 ms. The blue dots are the fixed points of the dynamics model given the value of the latent inputs $\mathbf{v}$. C) (Above) The smoothed spike trains of the measured SNN neurons across the 50 trials are shown in light blue with the average shown in black. The smoothed spikes of a sampled trial are shown in dark blue. (Below) The measured spike trains of the sampled trial corresponding to the dark blue curve above. D) Decoded latent inputs from the optimizer at three time points of the sampled trial. These are the visual stimuli that drove activity in the network. An animated visualization of the latent state and the stimuli used for control can be found in Supplementary Information V1. Link: https://doi.org/10.6084/m9.figshare.26072803.v1

3.2 Experiment II: Distinct Reference Trajectories Produce Different Levels of Controller Performance

The previous experiment showed how the controller used the stimulus-dependent dynamics of the forecasting model to drive the network to specific locations in the latent space. However, the forecasting model is only a linear approximation of the simulated network's dynamics, which are nonlinear and of much higher dimension. To use control to probe the underlying dynamics of the network, it is not sufficient to characterize the start and end points of a trajectory; the specific path taken between them also matters. For our next experiment, the reference trajectory was replaced with two time-varying functions (reference trajectories 1 and 2). Both reference trajectories had the same initial ($\mathbf{z}_0^*$) and final ($\mathbf{z}_f^*$) values but took different paths through the latent state space. Using a parameterized circle passing through $\mathbf{z}_0^*$ and $\mathbf{z}_f^*$, trajectories 1 and 2 were defined as the opposite arcs of the resulting circle. This ensured that both trajectories were of equal length through the latent state space.
The observation model from Experiment I was used for this task (20% observability), and the values of $\mathbf{z}_0^*$ and $\mathbf{z}_f^*$ were the set points from that experiment. Fifty control trials were performed for each reference trajectory, each using the same MPC hyperparameters as in Experiment I.
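The two reference trajectories can be generated as opposite halves of a circle. This sketch assumes that $\mathbf{z}_0^*$ and $\mathbf{z}_f^*$ lie at opposite ends of a diameter, which is the configuration in which the two arcs have equal length ($\pi r$ each). It also assumes uniform angular speed over `n_steps` samples. Both are illustrative choices rather than the documented construction.

```python
import numpy as np

def opposite_arcs(z0, zf, n_steps):
    """Two reference trajectories from z0 to zf along opposite halves of the
    circle whose diameter joins z0 and zf (both arcs have length pi * r)."""
    z0, zf = np.asarray(z0, dtype=float), np.asarray(zf, dtype=float)
    c = 0.5 * (z0 + zf)                          # circle center
    r = 0.5 * np.linalg.norm(zf - z0)            # radius
    theta0 = np.arctan2(z0[1] - c[1], z0[0] - c[0])
    s = np.linspace(0.0, 1.0, n_steps)
    arcs = []
    for direction in (1.0, -1.0):                # counterclockwise, then clockwise
        th = theta0 + direction * np.pi * s
        arcs.append(c + r * np.column_stack([np.cos(th), np.sin(th)]))
    return arcs[0], arcs[1]
```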

Figure 6: Latent States and Inputs for Different Reference Trajectories. A) The controlled paths (green and indigo) for the two reference trajectories (black) in Experiment II. Each of the reference trajectories had the same initial and final values ($\mathbf{z}_0^*$ and $\mathbf{z}_f^*$, respectively). The paths for 50 trials with trajectory 1 as the reference are shown in green, and those for trajectory 2 in indigo. For comparison, the paths from Experiment I 200 ms before and after the change in set point are shown in red. The set points were identical to the $\mathbf{z}_0^*$ and $\mathbf{z}_f^*$ used here, but the path between them was not constrained to follow specific values. The dark colors show the average of the respective path types. B) Latent states $\mathbf{z}_1$ (teal) and $\mathbf{z}_2$ (magenta) show differing regions of poor control. While the average paths (dark color) are near the reference values, each of the fifty trials (light color) may oscillate around the reference trajectory. Below are the latent inputs used to control the system across the fifty trials (light color), with the average values (dark color). C) Differences in nMSE of latent states and RMS of latent inputs across trajectories. Notice the difference in relative errors between the latent states across conditions.
This indicates that errors in control along one dimension are not independent of errors in the other dimension. The RMS of the latent inputs is higher for trajectory 2 in both input dimensions. Since the same sVAE was used in both conditions, the numerical values are directly comparable. See Supplementary Information V2 and V3 for animated visualizations of the controlled trajectories.

As seen in Figure 6A, the controller forced the network to take two different trajectories through the latent space, both of which differed from the trajectories the system took in Experiment I. However, the paths the system took often deviated substantially from the reference trajectory. For both trajectories 1 and 2, the system appeared to be attracted to specific regions of the state space, and there were other regions that the controller was unable to force the system to enter. These regions likely correspond to attractors and separatrices in the true dynamics of the SNN. Interestingly, individual trajectories showed large oscillations in some regions of the state space, which could reflect rapid switching between stable equilibria or a limit cycle. It remains to be seen whether a nonlinear forecasting model would enable the controller to better force the system through repulsive regions of the state space (see Discussion). The two trajectories differed in the average state error (Figure 6B) and in the variance of the latent inputs produced by the controller (Figure 6C). This indicates that more "power" in the latent space was needed to control the system to follow trajectory 2, even though the two trajectories had the same length.

Examining the activity in the measurement space, we see obvious differences in the spiking behavior of the SNN (Figure 7A). The visual stimuli produced by decoding the latent inputs were initially similar between the two reference trajectories but then diverged as the trajectories separated (Figure 7B). This is consistent with the behavior of the SNN, which was trained to produce distinct responses to different digits; thus, different regions of the state space are likely to correspond to specific digits. One implication is that MPC of the latent dynamics could reveal whether specific kinds of stimuli correspond to particular latent trajectories or whether the activity in the latent space is driven by differences in the stimulus at every time step (e.g., prediction error [38]).

Figure 7: Recorded Spikes and Control Inputs for Experiment II. A) Spike trains from a single trial for the reference trajectory 1 and 2 conditions. Note that the high densities of spiking in the trajectory 2 condition occur in the same time interval as the large oscillations in the latent space ($\sim$500–800 ms). B) The decoded latent inputs that were used to stimulate the SNN at specific time points of the single trials in A.

3.3 Experiment III: The Proportion of Observable Neurons Affects Controller Performance

To investigate how robust the control strategy was to partial observation of the neural population, the procedure from Experiment I was performed on different subsets of neurons randomly sampled from the SNN. For each level of neuron observability, 10 independent ensembles were sampled from the SNN ($n=80$). An nVAE was trained for each ensemble, but to ensure a fair comparison, the architecture was kept constant (except for the size of the input layer). As shown in Figure 8A, there was a tradeoff between the performance of the nVAE and the performance of the forecasting model. With a small number of neurons (1% observability, 6 neurons), the nVAEs achieved almost perfect reconstruction for both training and testing data. However, the forecasting models performed very poorly, indicating that the stimulus-dependent dynamics of the full SNN could not be inferred from the behavior of just a few neurons. As larger proportions of the population were observed, the performance of the nVAEs decreased, indicating that a two-dimensional latent space was not sufficient to capture all the variance in the neural activity. This degradation had a floor: reconstruction errors plateaued at around 30% observability, which is consistent with a high level of correlation within the SNN. In contrast, the performance of the forecasting models increased with the number of observed neurons. With 20% or more of the neurons included in the model, the dynamics of the SNN within the latent space could be predicted with high accuracy.
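The sampling of observability conditions can be sketched as follows. The sketch assumes a population of 600 neurons (our inference from 1% corresponding to 6 neurons, not a number stated here) and draws each ensemble independently without replacement; `sample_ensembles`, `observe`, and the listed observability levels are hypothetical and are not the authors' code.

```python
import numpy as np

N_NEURONS = 600  # assumption: consistent with 1% of the SNN being 6 neurons


def sample_ensembles(n_neurons, fraction, n_ensembles=10, seed=0):
    """Draw independent random subsets of neuron indices.

    Each ensemble is sampled without replacement, but different ensembles
    are drawn independently and may share neurons.
    """
    rng = np.random.default_rng(seed)
    size = max(1, int(round(fraction * n_neurons)))
    return [np.sort(rng.choice(n_neurons, size=size, replace=False))
            for _ in range(n_ensembles)]


def observe(activity, idx):
    """Keep only the observed neurons from a (time, neuron) activity array."""
    return activity[:, idx]


# Hypothetical observability levels (only some of the paper's levels are named in the text)
for frac in (0.01, 0.2, 0.3, 0.4):
    ensembles = sample_ensembles(N_NEURONS, frac)
    print(f"{frac:.0%}: {len(ensembles)} ensembles of {len(ensembles[0])} neurons")
```

Only the width of the observed array changes between conditions, which mirrors the choice to vary only the size of the nVAE input layer.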

The latent dynamics inferred at each level of observability were controlled to follow a step function between two set points obtained from k-means clustering, as in Experiment I. Ten trials of MPC were run for each of the models using the same loss function and hyperparameters as before. In practice, the hyperparameters of the controller would be tuned after fitting the forecasting model, but keeping them the same provides a fairer basis for comparison. Overall, controller performance was better with larger proportions of neurons observed (Figure 8B,C). Additionally, the variability of average nMSE across ensembles was typically higher when fewer neurons from the population were observed. This is unsurprising, since there are many more possible ensembles to sample in the lower observation conditions, and these may differ in their controllability. However, it was still possible to sample bad ensembles when observing a large proportion of the neurons: each of the three highest levels of neuron observability had an ensemble that produced a higher average nMSE for a single latent state than the mean of the nMSE distribution for the three lowest levels (Figure 8B).
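A step reference of this kind can be illustrated as follows. The sketch assumes the set points are centroids from plain Lloyd's k-means run on encoded latent states (with farthest-point initialization, our choice) and that the reference switches from one centroid to the other at a fixed time step. Which centroids and which switch time were used is not specified in this section, and all names and data are hypothetical.

```python
import numpy as np


def kmeans(x, k, n_iter=100, seed=0):
    """Plain Lloyd's k-means with farthest-point initialization.

    x: (n_points, n_dims) array. Returns (centroids, labels).
    """
    x = np.asarray(x, dtype=float)
    rng = np.random.default_rng(seed)
    # Initialization: one random point, then repeatedly the point farthest
    # from all centroids chosen so far
    centroids = [x[rng.integers(len(x))]]
    for _ in range(1, k):
        dist = np.linalg.norm(x[:, None, :] - np.array(centroids)[None, :, :], axis=2)
        centroids.append(x[np.argmax(dist.min(axis=1))])
    centroids = np.array(centroids)
    for _ in range(n_iter):
        # Assign every point to its nearest centroid
        dist = np.linalg.norm(x[:, None, :] - centroids[None, :, :], axis=2)
        labels = np.argmin(dist, axis=1)
        # Move each centroid to the mean of its points (keep it if empty)
        new = np.array([x[labels == j].mean(axis=0) if np.any(labels == j)
                        else centroids[j] for j in range(k)])
        if np.allclose(new, centroids):
            break
        centroids = new
    return centroids, labels


def step_reference(z_start, z_end, n_steps, switch_step):
    """Reference trajectory that holds z_start, then jumps to z_end."""
    ref = np.tile(np.asarray(z_start, dtype=float), (n_steps, 1))
    ref[switch_step:] = z_end
    return ref


# Hypothetical 2-D latent states forming two clusters
rng = np.random.default_rng(0)
latents = np.vstack([rng.normal([0.0, 0.0], 0.1, (50, 2)),
                     rng.normal([5.0, 5.0], 0.1, (50, 2))])
centroids, _ = kmeans(latents, k=2)
ref = step_reference(centroids[0], centroids[1], n_steps=1000, switch_step=500)
```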

Figure 8: Effects of Percentage of Observable Neurons on Control. A) The performance of the nVAE and forecasting models plotted as a function of the proportion of neurons observed. For each observation condition, ten independent ensembles were sampled from the SNN. The average performance across ensembles is plotted for each observation percentage. One ensemble in the 1% observation condition produced divergent control trajectories and was removed from the analysis. B) The average nMSE of the latent states across ensembles for each observation percentage. As the percentage increases, the nMSE trends toward smaller values. However, as seen clearly in the 40% observation condition, bad samples are still possible. C) Reference (black) and controlled (red) trajectories for exemplar ensembles at each observation percentage. Light traces show trajectories for individual trials ($n=10$) and dark red shows the average.

4 Discussion

In this study, we demonstrated that MPC can be used to control the latent states of an artificial spiking neural network using low-dimensional, data-driven models of the dynamics. Despite the highly nonlinear dynamics of the SNN, the latent states could be held at fixed set points or forced to follow arbitrary, time-varying reference trajectories. Control was possible even when only a limited proportion of the neurons was observed and when there were unknown sources of noise in the network. Reducing the dimensionality of the neural activity and of the stimuli used to control the network can therefore make MPC computationally tractable for high-dimensional neural systems. Given this proof of principle, it should be possible to use MPC for real-time control of latent dynamics in sensory-driven biological networks using extracellular recordings. Achieving this level of control over neural networks would enable experimenters to probe the dynamics of neural circuits in the ways we have begun to outline here.
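To make the receding-horizon idea concrete, the sketch below implements a simple unconstrained linear MPC on a latent model assumed to have the form $\mathbf{z}_{t+1} = A\mathbf{z}_t + B\mathbf{u}_t$. It minimizes squared tracking error over a finite horizon plus a penalty `lam` on input magnitude, solves that quadratic problem in closed form, and applies only the first input before re-planning. It is not the authors' controller: it omits constraints, noise, and the encoding and decoding steps through the VAEs, and the matrices, horizon, and penalty are illustrative assumptions.

```python
import numpy as np


def prediction_matrices(A, B, horizon):
    """Stack predictions so that Z = Phi @ z0 + Gamma @ U over the horizon.

    Z = [z_1; ...; z_H] and U = [u_0; ...; u_{H-1}], with
    z_{i+1} = A^{i+1} z0 + sum_{j<=i} A^{i-j} B u_j.
    """
    n, m = B.shape
    powers = [np.eye(n)]
    for _ in range(horizon):
        powers.append(powers[-1] @ A)
    Phi = np.vstack(powers[1:])
    Gamma = np.zeros((horizon * n, horizon * m))
    for i in range(horizon):
        for j in range(i + 1):
            Gamma[i * n:(i + 1) * n, j * m:(j + 1) * m] = powers[i - j] @ B
    return Phi, Gamma


def mpc_step(A, B, z, ref, horizon, lam=1e-3):
    """Return the first input of the optimal unconstrained input sequence.

    Minimizes sum_t ||z_t - ref_t||^2 + lam * ||u_t||^2 over the horizon;
    ref is a (horizon, n) array of future reference states.
    """
    m = B.shape[1]
    Phi, Gamma = prediction_matrices(A, B, horizon)
    r = np.asarray(ref, dtype=float).reshape(-1)
    # Normal equations of the regularized least-squares problem
    H = Gamma.T @ Gamma + lam * np.eye(horizon * m)
    U = np.linalg.solve(H, Gamma.T @ (r - Phi @ z))
    return U[:m]  # receding horizon: apply only the first input


# Illustrative 2-D latent model (assumed, not fitted to any SNN)
A = np.array([[0.95, 0.1], [-0.1, 0.95]])
B = 0.5 * np.eye(2)
target = np.array([1.0, -1.0])
horizon = 10
z = np.zeros(2)
for t in range(60):
    u = mpc_step(A, B, z, np.tile(target, (horizon, 1)), horizon)
    z = A @ z + B @ u  # in a real experiment, z would be re-encoded from spikes
```

The closed-form solve is only possible because this toy problem has no constraints; bounded inputs or a nonlinear latent model would require a numerical optimizer at every step.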

In Experiment II, we found that the state of the network could be forced to take different trajectories by providing distinct sequences of latent inputs, but there were regions of the latent space where control was better or worse. One possibility is that the latent dynamics model made better predictions in certain regions of the latent state space, where the dynamics of the full network were more linear. A nonlinear dynamics model may provide a better approximation of the latent vector field, but at the cost of added complexity in both model fitting and MPC optimization. Another possibility is that the VAEs learned a dimensional transformation that preserved the distributional information in the high-dimensional space but distorted the temporal relationships needed to infer the latent dynamics in certain regions [11]. Fitting the VAEs and latent dynamics models in two distinct phases may also have produced dynamics models that were not optimal for the subspaces found by the VAEs. An alternative approach would be to learn the dimensionality reduction and latent dynamics model simultaneously in a single step, for example by extending a recurrent switching linear dynamical systems approach [13] to nonlinear dimensionality reduction or by using structured variational autoencoders [39]. Although this would likely improve the performance of the controller, it would also make model fitting more complex, as additional hyperparameters would be needed to scale the influences of the VAE reconstruction, latent forecasting, and measurement space forecasting. Extensive work may be needed to identify the best approaches for balancing forecasting accuracy with computational efficiency for a given system and set of scientific questions.
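In the two-phase approach discussed above, the second phase can be sketched as an ordinary least-squares fit of a linear latent model to sequences of encoded states and inputs, in the spirit of dynamic mode decomposition with control. This is an illustrative assumption about the fitting step, not a description of the paper's Methods, and `fit_linear_latent_model` is a hypothetical name.

```python
import numpy as np


def fit_linear_latent_model(Z, U):
    """Least-squares fit of z[t+1] ≈ A z[t] + B u[t].

    Z: (T, n) latent states from the neural encoder.
    U: (T, m) latent inputs applied at each step (the last row is unused).
    Returns A (n, n) and B (n, m).
    """
    Z = np.asarray(Z, dtype=float)
    U = np.asarray(U, dtype=float)
    X = np.hstack([Z[:-1], U[:-1]])   # regressors, one row per time step
    Y = Z[1:]                          # next-step targets
    theta, *_ = np.linalg.lstsq(X, Y, rcond=None)  # Y ≈ X @ theta
    n = Z.shape[1]
    return theta[:n].T, theta[n:].T


# Synthetic check: simulate a known model with random inputs, then refit
rng = np.random.default_rng(0)
A_true = np.array([[0.9, 0.2], [-0.2, 0.9]])
B_true = np.array([[0.5, 0.0], [0.1, 0.3]])
U = rng.normal(size=(200, 2))
Z = np.zeros((200, 2))
for t in range(199):
    Z[t + 1] = A_true @ Z[t] + B_true @ U[t]
A_hat, B_hat = fit_linear_latent_model(Z, U)
```

Because the encoder is frozen before this fit, any temporal distortion it introduced is simply inherited by `A_hat` and `B_hat`, which is the limitation of the two-phase approach noted above.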

Recent years have seen the development of high-density silicon electrodes that can record extracellular spikes from hundreds to thousands of neurons simultaneously [40, 41], but this is still only a tiny fraction of the number of cells that participate in local neural circuits. In Experiment III, we found that the average performance of the forecasting model and the controller decreased as the proportion of neurons observed in the SNN decreased. However, the lower observation conditions also had higher variance in the distribution of average nMSE, indicating that good control is still possible with some small ensembles. It is also well known that when only a subset of the states of a dynamical system is observed, a time-delay embedding of the measurements carries information about the full dynamics [42]. For simplicity, time-delay embedding was not used to fit the VAEs or latent dynamics models in any of the experiments here. Although this did not appear to affect controller performance for the models with higher observation percentages, it may have contributed to poorer performance as the percentage decreased. It would be interesting to examine whether using time delays improves performance even when the number of observed neurons is very low. Other work has shown the success of time-lagged autoencoders in finding latent dynamics models [43], and methods exist to find optimal time delays and the number of embedding dimensions [44]. In practice, however, this would introduce additional complexity, because both the dimensionality reduction and the dynamics model need to be fit quickly when recording from living neurons; otherwise, the experiment may not be feasible in a laboratory setting due to cell death or electrode drift.
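A time-delay embedding of the observed activity can be built by stacking lagged copies of each measurement. The sketch below uses a fixed lag and number of delays chosen by hand; in practice these would be selected with methods such as those in [44]. Feeding the embedded array to the nVAE in place of the raw activity is a hypothetical extension, not something done in these experiments, and `delay_embed` is a hypothetical name.

```python
import numpy as np


def delay_embed(x, n_delays, lag=1):
    """Time-delay embedding of a (T,) or (T, d) series.

    Row t of the output holds [x[t], x[t - lag], ..., x[t - (n_delays - 1) * lag]]
    for every t with a full history, giving shape
    (T - (n_delays - 1) * lag, d * n_delays).
    """
    x = np.asarray(x, dtype=float)
    if x.ndim == 1:
        x = x[:, None]
    span = (n_delays - 1) * lag
    T = x.shape[0]
    if T <= span:
        raise ValueError("series too short for the requested embedding")
    # Each block is the whole series shifted back by k * lag samples
    blocks = [x[span - k * lag:T - k * lag] for k in range(n_delays)]
    return np.hstack(blocks)


# Hypothetical binned spike counts from 6 observed neurons
rates = np.random.default_rng(0).poisson(2.0, size=(1000, 6))
emb = delay_embed(rates, n_delays=4, lag=5)
```

The embedding multiplies the input width by `n_delays`, so a 6-neuron ensemble becomes a 24-dimensional input here; the cost is a larger encoder and a short warm-up period before the first full row is available.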

The divide between the descriptive and generative perspectives on neural manifolds is related to the question of whether a low-dimensional dynamical system emerges from the much larger and more complex dynamical system defined by the biophysics of intrinsic and synaptic currents in large populations of neurons. In this study, the true dynamics of the SNN were defined by a large, nonlinear dynamical system composed of hundreds of interconnected artificial neurons. The activity of this network could be represented in a low-dimensional neural manifold (Figure 3) with a clustered structure suggestive of attractor basins. The MPC framework developed in this study provides a method for experimentally probing the dynamics on this manifold to better understand its structure and function. Our results are consistent with the generative perspective in that a simple linear approximation of the dynamics in the latent space was sufficient to achieve a high level of control over the network. However, we also found that there were regions of the state space where the latent dynamical model was not a good enough model of the underlying system to provide strong control, a result that can be interpreted either as support for the descriptive perspective or as a source of insight into how to improve the latent model. If MPC can be applied in biological systems, it could provide a strong test of whether manifold activity is causal by enabling experimenters to see whether specific organismal behaviors can be produced by controlling the latent neural dynamics. This would be a powerful tool for furthering our understanding of how complex behaviors and computations emerge from the structure of neural circuits and the dynamics of their activity.

It may be important that there was a topological alignment between the latent spaces for the neural activity and the stimulus set. The sVAE discovered a dimensional mapping that separated the digits and their variants into 10 distinct, well-separated clusters (Figure 2A). There was also clear evidence of clustering in the neural latent space that mapped in an orderly way onto the identity of the stimulus digit (Figure 3). Though it is beyond the scope of the present study, it would be interesting to explore how the stimulus latent spaces discovered by other dimensionality reduction methods affect the performance of the latent dynamics model in forecasting and control. It is also interesting to consider that the dynamics of the SNN and the latent mapping of the sVAE were both learned from the distribution of the stimuli, but through radically different methods. If topological alignment is required for control in this classification task and in other computational problems such as those explored by Sussillo and Barak [29], it may speak to a simple but profound theory that derives from the ideas of James [45] and Hebb [46]: that learning is a process of aligning the latent dynamics of neural circuits to the latent dynamics of the physical world. Optimal control, both as a theory and as a method for more precise experimental manipulation, may be of benefit in testing this theory in biological systems.

References

  • Cunningham and Yu [2014] Cunningham, J.P., Yu, B.M.: Dimensionality reduction for large-scale neural recordings. Nature Neuroscience 17(11), 1500–1509 (2014)
  • Mante et al. [2013] Mante, V., Sussillo, D., Shenoy, K.V., Newsome, W.T.: Context-dependent computation by recurrent dynamics in prefrontal cortex. Nature 503(7474), 78–84 (2013)
  • Kim et al. [2017] Kim, S.S., Rouault, H., Druckmann, S., Jayaraman, V.: Ring attractor dynamics in the Drosophila central brain. Science 356(6340), 849–853 (2017) https://doi.org/10.1126/science.aal4835
  • Chaudhuri et al. [2019] Chaudhuri, R., Gerçek, B., Pandey, B., Peyrache, A., Fiete, I.: The intrinsic attractor manifold and population dynamics of a canonical cognitive circuit across waking and sleep. Nature Neuroscience 22(9), 1512–1520 (2019)
  • Chung and Abbott [2021] Chung, S., Abbott, L.F.: Neural population geometry: An approach for understanding biological and artificial neural networks. Current Opinion in Neurobiology 70, 137–144 (2021)
  • Langdon et al. [2023] Langdon, C., Genkin, M., Engel, T.A.: A unifying perspective on neural manifolds and circuits for cognition. Nature Reviews Neuroscience, 1–15 (2023)
  • Gallego et al. [2017] Gallego, J.A., Perich, M.G., Miller, L.E., Solla, S.A.: Neural manifolds for the control of movement. Neuron 94(5), 978–984 (2017)
  • Fortunato et al. [2023] Fortunato, C., Bennasar-Vázquez, J., Park, J., Chang, J.C., Miller, L.E., Dudman, J.T., Perich, M.G., Gallego, J.A.: Nonlinear manifolds underlie neural population activity during behaviour. bioRxiv (2023)
  • Pang et al. [2016] Pang, R., Lansdell, B.J., Fairhall, A.L.: Dimensionality reduction in neuroscience. Current Biology 26(14), 656–660 (2016)
  • Florian et al. [2011] Florian, B., Sepp, K., Joshua, H., Richard, H.: Hidden Markov models in the neurosciences. Hidden Markov Models, Theory and Applications, 169 (2011)
  • Lusch et al. [2018] Lusch, B., Kutz, J.N., Brunton, S.L.: Deep learning for universal linear embeddings of nonlinear dynamics. Nature Communications 9(1), 4950 (2018)
  • Sussillo et al. [2015] Sussillo, D., Churchland, M.M., Kaufman, M.T., Shenoy, K.V.: A neural network that finds a naturalistic solution for the production of muscle activity. Nature Neuroscience 18(7), 1025–1033 (2015)
  • Linderman et al. [2017] Linderman, S., Johnson, M., Miller, A., Adams, R., Blei, D., Paninski, L.: Bayesian learning and inference in recurrent switching linear dynamical systems. In: Artificial Intelligence and Statistics, pp. 914–922 (2017). PMLR
  • Raković and Levine [2019] Raković, S.V., Levine, W.S. (eds.): Handbook of Model Predictive Control. Control Engineering. Springer, Cham (2019). https://doi.org/10.1007/978-3-319-77489-3
  • Hewing et al. [2020] Hewing, L., Wabersich, K.P., Menner, M., Zeilinger, M.N.: Learning-based model predictive control: Toward safe learning in control. Annual Review of Control, Robotics, and Autonomous Systems 3(1), 269–296 (2020) https://doi.org/10.1146/annurev-control-090419-075625
  • Brunton and Kutz [2019] Brunton, S.L., Kutz, J.N.: Data-driven Science and Engineering: Machine Learning, Dynamical Systems, and Control. Cambridge University Press, Cambridge (2019)
  • Schwenzer et al. [2021] Schwenzer, M., Ay, M., Bergs, T., Abel, D.: Review on model predictive control: an engineering perspective. The International Journal of Advanced Manufacturing Technology 117(5-6), 1327–1349 (2021) https://doi.org/10.1007/s00170-021-07682-3
  • Bieker et al. [2019] Bieker, K., Peitz, S., Brunton, S.L., Kutz, J.N., Dellnitz, M.: Deep model predictive control with online learning for complex physical systems. arXiv preprint arXiv:1905.10094 (2019) https://doi.org/10.48550/ARXIV.1905.10094
  • Kaiser et al. [2018] Kaiser, E., Kutz, J.N., Brunton, S.L.: Sparse identification of nonlinear dynamics for model predictive control in the low-data limit. Proceedings of the Royal Society A: Mathematical, Physical and Engineering Sciences 474(2219), 20180335 (2018) https://doi.org/10.1098/rspa.2018.0335
  • Salzmann et al. [2023] Salzmann, T., Kaufmann, E., Arrizabalaga, J., Pavone, M., Scaramuzza, D., Ryll, M.: Real-time neural MPC: Deep learning model predictive control for quadrotors and agile robotic platforms. IEEE Robotics and Automation Letters 8(4), 2397–2404 (2023) https://doi.org/10.1109/LRA.2023.3246839
  • Zheng and Wu [2023] Zheng, Y., Wu, Z.: Physics-informed online machine learning and predictive control of nonlinear processes with parameter uncertainty. Industrial & Engineering Chemistry Research 62(6), 2804–2818 (2023) https://doi.org/10.1021/acs.iecr.2c03691
  • Fröhlich and Jezernik [2005] Fröhlich, F., Jezernik, S.: Feedback control of Hodgkin–Huxley nerve cell dynamics. Control Engineering Practice 13(9), 1195–1206 (2005) https://doi.org/10.1016/j.conengprac.2004.10.008
  • Yue et al. [2022] Yue, R., Tomastik, R., Dutta, A.: Non-linear model-based control of neural cell dynamics. preprint, In Review (May 2022). https://doi.org/10.21203/rs.3.rs-580874/v2
  • Senthilvelmurugan and Subbian [2023] Senthilvelmurugan, N.N., Subbian, S.: Active fault tolerant deep brain stimulator for epilepsy using deep neural network. Biomedical Engineering / Biomedizinische Technik 68(4), 373–392 (2023) https://doi.org/10.1515/bmt-2021-0302
  • Fehrman and Meliza [2023] Fehrman, C., Meliza, C.D.: Nonlinear model predictive control of a conductance-based neuron model via data-driven forecasting. arXiv preprint arXiv:2312.14274 (2023)
  • Bolus et al. [2021] Bolus, M.F., Willats, A.A., Rozell, C.J., Stanley, G.B.: State-space optimal feedback control of optogenetically driven neural activity. Journal of Neural Engineering 18(3), 036006 (2021) https://doi.org/10.1088/1741-2552/abb89c
  • Eshraghian et al. [2023] Eshraghian, J.K., Ward, M., Neftci, E.O., Wang, X., Lenz, G., Dwivedi, G., Bennamoun, M., Jeong, D.S., Lu, W.D.: Training spiking neural networks using lessons from deep learning. Proceedings of the IEEE 111(9), 1016–1054 (2023) https://doi.org/10.1109/JPROC.2023.3308088
  • Maass [2011] Maass, W.: Liquid state machines: motivation, theory, and applications. Computability in context: computation and logic in the real world, 275–296 (2011)
  • Sussillo and Barak [2013] Sussillo, D., Barak, O.: Opening the black box: low-dimensional dynamics in high-dimensional recurrent neural networks. Neural Computation 25(3), 626–649 (2013)
  • LeCun et al. [1998] LeCun, Y., Bottou, L., Bengio, Y., Haffner, P.: Gradient-based learning applied to document recognition. Proceedings of the IEEE 86(11), 2278–2324 (1998)
  • Neftci et al. [2019] Neftci, E.O., Mostafa, H., Zenke, F.: Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks. IEEE Signal Processing Magazine 36(6), 51–63 (2019) https://doi.org/10.1109/MSP.2019.2931595
  • Odaibo [2019] Odaibo, S.: Tutorial: Deriving the standard variational autoencoder (VAE) loss function. arXiv preprint arXiv:1907.08956 (2019)
  • Gomari et al. [2022] Gomari, D.P., Schweickart, A., Cerchietti, L., Paietta, E., Fernandez, H., Al-Amin, H., Suhre, K., Krumsiek, J.: Variational autoencoders learn transferrable representations of metabolomics data. Communications Biology 5(1), 645 (2022)
  • Kingma and Welling [2013] Kingma, D.P., Welling, M.: Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114 (2013)
  • Fiedler et al. [2023] Fiedler, F., Karg, B., Lüken, L., Brandner, D., Heinlein, M., Brabender, F., Lucia, S.: do-mpc: Towards FAIR nonlinear and robust model predictive control. Control Engineering Practice 140, 105676 (2023) https://doi.org/10.1016/j.conengprac.2023.105676
  • Andersson et al. [2019] Andersson, J.A.E., Gillis, J., Horn, G., Rawlings, J.B., Diehl, M.: CasADi: a software framework for nonlinear optimization and optimal control. Mathematical Programming Computation 11(1), 1–36 (2019) https://doi.org/10.1007/s12532-018-0139-4
  • Wächter and Biegler [2006] Wächter, A., Biegler, L.T.: On the implementation of an interior-point filter line-search algorithm for large-scale nonlinear programming. Mathematical Programming 106(1), 25–57 (2006) https://doi.org/10.1007/s10107-004-0559-y
  • Egner et al. [2010] Egner, T., Monti, J.M., Summerfield, C.: Expectation and surprise determine neural population responses in the ventral visual stream. Journal of Neuroscience 30(49), 16601–16608 (2010)
  • Connor et al. [2021] Connor, M., Canal, G., Rozell, C.: Variational autoencoder with learned latent structure. In: International Conference on Artificial Intelligence and Statistics, pp. 2359–2367 (2021). PMLR
  • Yang et al. [2020] Yang, L., Lee, K., Villagracia, J., Masmanidis, S.C.: Open source silicon microprobes for high throughput neural recording. Journal of Neural Engineering 17(1), 016036 (2020)
  • Steinmetz et al. [2021] Steinmetz, N.A., Aydin, C., Lebedeva, A., Okun, M., Pachitariu, M., Bauza, M., Beau, M., Bhagat, J., Böhm, C., Broux, M., et al.: Neuropixels 2.0: A miniaturized high-density probe for stable, long-term brain recordings. Science 372(6539), 4588 (2021)
  • Clark et al. [2022] Clark, R., Fuller, L., Platt, J.A., Abarbanel, H.D.I.: Reduced-dimension, biophysical neuron models constructed from observed data. Neural Computation 34(7), 1545–1587 (2022) https://doi.org/10.1162/neco_a_01515
  • Wehmeyer and Noé [2018] Wehmeyer, C., Noé, F.: Time-lagged autoencoders: Deep learning of slow collective variables for molecular kinetics. The Journal of Chemical Physics 148(24) (2018)
  • Sugihara and May [1990] Sugihara, G., May, R.M.: Nonlinear forecasting as a way of distinguishing chaos from measurement error in time series. Nature 344(6268), 734–741 (1990)
  • James [1890] James, W.: The Principles of Psychology vol. 1. Dover, New York (1890)
  • Hebb [1949] Hebb, D.O.: The Organization of Behavior. John Wiley and Sons, Incorporated, New York (1949)