VarteX: Enhancing Weather Forecast through Distributed Variable Representation

Ayumu Ueyama    Kazuhiko Kawamoto    Hiroshi Kera
Abstract

Weather forecasting is essential for various human activities. Recent data-driven models based on deep learning have outperformed numerical weather prediction in forecasting performance. However, challenges remain in efficiently handling multiple meteorological variables. To address this challenge, this study proposes a new variable aggregation scheme and an efficient learning framework. Experiments show that VarteX outperforms the conventional model in forecasting performance while requiring significantly fewer parameters and resources. The results demonstrate the effectiveness of learning through multiple variable aggregations and regional split training, enabling more efficient and accurate deep learning-based weather forecasting.

Machine Learning, ICML

1 Introduction

From strategies for addressing extreme weather to daily societal activities, weather forecasting plays an indispensable role in human activities (Bauer et al., 2015). Recently, there has been increasing interest in applying data-driven deep learning models to weather forecasting (Scher & Messori, 2019; Weyn et al., 2019; Rasp et al., 2020; Weyn et al., 2021; Keisler, 2022; Lam et al., 2023). Once trained intensively on meteorological data, such models can generate forecasts within seconds (Lynch, 2008), whereas numerical weather prediction must solve complex partial differential equations, which leads to significantly longer forecasting times. Several recent studies have reported that data-driven models outperform numerical weather prediction models even in forecasting skill (Bi et al., 2023; Lam et al., 2023; Chen et al., 2023b).

Many data-driven weather forecasting models (Pathak et al., 2022; Bi et al., 2023; Chen et al., 2023a, b; Man et al., 2023; Nguyen et al., 2023; Ni, 2023; Nguyen et al., 2024; Ramavajjala, 2024) are based on the Vision Transformer (ViT; Dosovitskiy et al. (2021)), a powerful attention-based model in computer vision. This is because meteorological data closely resemble image data in structure, having height, width, and channel dimensions. A critical difference lies in the channel dimension. Image data have only RGB channels, which share similar information about the entire image. In contrast, meteorological data have many more channels, one per meteorological variable (such as temperature and humidity), each with its own characteristics. The large number of meteorological variables increases the computational cost, and their diversity makes learning challenging. ClimaX (Nguyen et al., 2023) addressed this challenge with minimal modifications to ViT. In particular, it is equipped with a variable aggregation module that sums the meteorological variables into one representative variable using attention weights. Their experiments show that such input-dependent variable aggregation leads to more successful training than input-agnostic convolution over variables, the standard way to summarize RGB channels in the image domain.

In this study, we propose a new variable aggregation scheme and training method for efficient learning from meteorological data with a ViT-based weather forecasting model. Our variable aggregation scheme is based on the hypothesis that the representative variable (or its $D$-dimensional embedding vector) obtained by ClimaX variable aggregation may internally contain several components, because a single representative variable would otherwise be too restrictive. If so, it is better to model them explicitly as $R>1$ representative variables with $(D/R)$-dimensional embedding vectors. During encoding, our model processes these embedding vectors separately with Transformer encoders and then mixes them with a mixing layer. The smaller embedding dimension leads to smaller Transformer encoders, reducing the parameter count of each to $1/R^2$ of the original. While this reduces the model size, the memory cost of the forward pass remains unchanged because the size of the attention map is determined by the number of tokens (i.e., image patches). To address this, we introduce regional split training, in which the model is trained on only a cropped region in each forward pass. This training decreases the final accuracy, but with a proper choice of the crop size ratio $S$, the decrease is moderate, and the training time and memory cost are reduced drastically in return. Specifically, the space complexity is reduced by a factor of $\mathcal{O}(1/S^2)$. In experiments, we trained our model, VarteX, as well as ClimaX, on the WeatherBench dataset from scratch. The results show that the forecasting accuracy of VarteX is 50% higher on average than that of ClimaX, and the gap is even larger for wind speed forecasting. In terms of latitude-weighted root mean squared error (RMSE) and latitude-weighted anomaly correlation coefficient (ACC), VarteX improves on ClimaX for all target variables, highlighting the effectiveness of learning through multiple representative variables. VarteX achieved these results with 55% of the model size, 50% of the training time, and 35% of the memory usage of ClimaX.

2 Problem setting

Meteorological data are time-series data of meteorological variables, such as temperature, geopotential, and wind speed, at each grid point of the world. Suppose we have $H\times W$ grid points and $V$ variables; then we have $X_t\in\mathbb{R}^{H\times W\times V}$ at time step $t$. Learning-based weather forecasting aims to find a forecasting function $f_\theta: X_t\mapsto X_{t+\Delta t}$ with a predesignated lead time $\Delta t$ by training a deep learning model with parameters $\theta$ on a regression task.

An important characteristic of meteorological data is their large number of variables (e.g., $V=48$). Using attention-based models such as the Vision Transformer means handling as many as $N=HWV$ tokens, and the attention computation needed to capture spatial and inter-variable interactions grows quadratically with $N$. As each meteorological variable has its own temporal and spatial dynamics, we cannot resort to simply concatenating them as we do with the R, G, and B channels of image data.

The focus of our study lies in the aggregation of meteorological variables. In particular, we are interested in an aggregation function $\mathcal{A}: X_t\in\mathbb{R}^{H\times W\times V}\mapsto Z_t\in\mathbb{R}^{H\times W\times R}$, which reduces the number of variables from $V$ to $R$ while still allowing successful learning. ClimaX offers such an aggregation function with $R=1$. We suspect that reducing to a single variable may be too aggressive and extend the idea to general $R$. However, as $R$ increases, the number of matrices in the attention mechanism also increases, raising the computational cost. Therefore, we must also address the reduction in computational cost.

3 Methodology

We propose a new variable aggregation scheme and an efficient training framework. As demonstrated in Section 4, the former reduces the model size and significantly improves the prediction performance, and the latter reduces both the training time and the memory cost to half or less.

3.1 Variable aggregation in ClimaX

First, we review the cross-attention-based variable aggregation of ClimaX. Let $\widetilde{X}\in\mathbb{R}^{H\times W\times V\times D}$ be the input data embedded into a $D$-dimensional space. In the following, the same operations are applied uniformly at every spatial position, so for notational simplicity we focus on a position $(h,w)\in\{1,\ldots,H\}\times\{1,\ldots,W\}$ and redefine $\widetilde{X}$ as $\widetilde{X}_{hw}\in\mathbb{R}^{V\times D}$. With a trainable query vector $\bm{q}\in\mathbb{R}^{D}$, the cross-attention aggregates the variables as follows.

$$\bm{z}^{\top}=\mathrm{softmax}\left(\frac{\bm{q}^{\top}\widetilde{K}^{\top}}{\sqrt{D}}\right)\widetilde{V}\in\mathbb{R}^{1\times D}, \tag{1}$$

where $\widetilde{K}=\widetilde{X}W_{\mathrm{K}}$ and $\widetilde{V}=\widetilde{X}W_{\mathrm{V}}$ with trainable weights $W_{\mathrm{K}},W_{\mathrm{V}}\in\mathbb{R}^{D\times D}$, and $\mathrm{softmax}(\,\cdot\,)$ is the softmax operation. Namely, this cross-attention computes a weighted sum of the rows of a linear transformation of the input, $\widetilde{V}=\widetilde{X}W_{\mathrm{V}}$, with the attention weights $\bm{a}^{\top}=\mathrm{softmax}\left(\frac{\bm{q}^{\top}\widetilde{K}^{\top}}{\sqrt{D}}\right)$, thereby aggregating $V$ variables into one representative variable.
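To make Eq. (1) concrete, here is a minimal NumPy sketch at one grid position. The function name `climax_aggregate`, the random weights, and the small sizes are illustrative assumptions, not code from the ClimaX repository; the value matrix is called `Vals` to avoid clashing with the variable count $V$.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def climax_aggregate(X, q, W_K, W_V):
    """Cross-attention variable aggregation at one grid position (Eq. 1).

    X:        (V, D) embedded variables at position (h, w)
    q:        (D,)   trainable query vector
    W_K, W_V: (D, D) key and value projections
    Returns the representative embedding z (D,) and attention weights a (V,).
    """
    D = X.shape[1]
    K = X @ W_K                          # (V, D) keys
    Vals = X @ W_V                       # (V, D) values
    a = softmax(q @ K.T / np.sqrt(D))    # (V,) weights over variables, sum to 1
    z = a @ Vals                         # (D,) weighted sum of the value rows
    return z, a


rng = np.random.default_rng(0)
V, D = 48, 16
X = rng.standard_normal((V, D))
q = rng.standard_normal(D)
W_K = rng.standard_normal((D, D)) / np.sqrt(D)
W_V = rng.standard_normal((D, D)) / np.sqrt(D)
z, a = climax_aggregate(X, q, W_K, W_V)
print(z.shape, round(a.sum(), 6))  # (16,) 1.0
```

Because the weights `a` depend on the input through the keys, the aggregation is input-dependent, which is what distinguishes it from a fixed convolution over channels.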

3.2 Model architecture




Figure 1: Comparison of ClimaX and VarteX architectures. ClimaX aggregates $V$ meteorological variables into a single representative variable, whereas VarteX aggregates them into $R$ representative variables. VarteX has layers that learn each representative variable separately and layers that learn their mixture.
Table 1: Comparison of the proposed model (VarteX) and ClimaX trained on the ERA5 dataset from scratch (no pretraining). VarteX has $R$ representative variables, each of which is embedded into a $(D/R)$-dimensional space. The only exception is $(R=4)^*$, where the embedding dimension of each representative variable is doubled, i.e., $D=2048$.

Lead  Model             U10            T2m            Z500             T850           Params
time                    ACC↑   RMSE↓   ACC↑   RMSE↓   ACC↑   RMSE↓     ACC↑   RMSE↓   (M)
6h    ClimaX            0.59   3.35    0.66   4.19    0.76   667.34    0.71   3.53    108.08
      VarteX (R=2)      0.89   1.89    0.79   3.24    0.97   247.33    0.92   1.90    59.86
      VarteX (R=4)      0.19   4.54    0.32   7.43    0.53   966.76    0.51   4.65    20.76
      VarteX (R=4)*     0.76   2.66    0.56   5.13    0.81   599.42    0.80   2.93    80.47
24h   ClimaX            0.37   3.86    0.66   4.20    0.69   734.47    0.70   3.58    108.08
      VarteX (R=2)      0.64   3.16    0.83   2.99    0.88   478.42    0.85   2.59    59.86
      VarteX (R=4)      0.10   4.76    0.31   7.33    0.48   986.85    0.47   4.77    20.76
      VarteX (R=4)*     0.52   3.52    0.57   4.99    0.75   660.40    0.75   3.25    80.47

While ClimaX variable aggregation reduces $V$ variables into one, we hypothesize that the representative variable (and its embedding vector) may internally consist of several components, because it is hard to believe that $V$ (typically $V>40$) variables can be represented by a single variable, even with input-dependent attention weights. If this is the case, we should be able to split the $D$-dimensional space into $R$ subspaces, where $R$ is the potential number of representative variables. As ViT has many square matrices whose size equals the embedding dimension $D$, this split directly reduces the model size from $\mathcal{O}(D^2)$ to $\mathcal{O}(R\times D^2/R^2)=\mathcal{O}(D^2/R)$.
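The parameter arithmetic behind this claim can be checked with a short calculation. The hypothetical helper below counts only the weights of square projection matrices (no biases, normalization, mixing layer, or prediction head), so it illustrates the $1/R^2$-per-branch and $1/R$-total scaling rather than reproducing the parameter counts in Table 1.

```python
def square_matrix_params(D, R):
    """Weights in one D x D matrix vs. R matrices of size (D/R) x (D/R).

    Hypothetical helper: counts only square-matrix weights and ignores
    biases, norms, the mixing layer, and the prediction head.
    """
    assert D % R == 0, "D must be divisible by R"
    full = D * D                   # one matrix in a single D-dimensional encoder
    per_branch = (D // R) ** 2     # one matrix in a (D/R)-dim encoder: 1/R^2 of full
    total_split = R * per_branch   # all R branches together: D^2 / R
    return full, per_branch, total_split


for R in (1, 2, 4):
    full, per_branch, total = square_matrix_params(1024, R)
    print(R, full, per_branch, total, total / full)
```

With $D=1024$, the ratio `total / full` printed by the loop is 1.0, 0.5, and 0.25 for $R=1,2,4$: each branch shrinks by $1/R^2$, and the $R$ branches together shrink by $1/R$.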

We thus propose to use more than one representative variable. The embedded input $\widetilde{X}\in\mathbb{R}^{V\times D}$ is split into

$$\widetilde{X}=\left[\widetilde{X}_{1}\ \cdots\ \widetilde{X}_{R}\right], \tag{2}$$

where $\widetilde{X}_{k}\in\mathbb{R}^{V\times(D/R)}$ for $k=1,\ldots,R$. We prepare trainable query vectors $\bm{q}_{1},\ldots,\bm{q}_{R}\in\mathbb{R}^{D/R}$ and apply ClimaX variable aggregation to each pair $(\bm{q}_{k},\widetilde{X}_{k})$ to obtain $\bm{z}_{k}$.
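A minimal NumPy sketch of Eq. (2) followed by per-block aggregation. It assumes that each block $k$ has its own $(D/R)\times(D/R)$ key and value projections and that the logits are scaled by $\sqrt{D/R}$, the width of $\bm{q}_k$; both choices are our assumptions, and `vartex_aggregate` is a hypothetical name.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def vartex_aggregate(X, queries, W_Ks, W_Vs):
    """Split X (V, D) into R blocks of width D/R (Eq. 2) and aggregate each.

    queries[k]: (D/R,); W_Ks[k], W_Vs[k]: (D/R, D/R), one set per block k
    (assumed: separate projections per block).
    Returns Z: (R, D/R), one representative embedding z_k per row.
    """
    R = len(queries)
    blocks = np.split(X, R, axis=1)   # R arrays of shape (V, D/R)
    Z = []
    for Xk, qk, WK, WV in zip(blocks, queries, W_Ks, W_Vs):
        d = Xk.shape[1]
        # Assumed scaling by sqrt(D/R), the width of q_k.
        a = softmax(qk @ (Xk @ WK).T / np.sqrt(d))  # (V,) weights over variables
        Z.append(a @ (Xk @ WV))                     # (D/R,) representative z_k
    return np.stack(Z)


rng = np.random.default_rng(0)
V, D, R = 48, 16, 2
d = D // R
X = rng.standard_normal((V, D))
queries = [rng.standard_normal(d) for _ in range(R)]
W_Ks = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(R)]
W_Vs = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(R)]
Z = vartex_aggregate(X, queries, W_Ks, W_Vs)
print(Z.shape)  # (2, 8)
```

Each $\bm{z}_k$ depends only on its own slice $\widetilde{X}_k$, so the $R$ aggregations can attend to different variables; interaction between them is deferred to the mixing layer.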

The embedding vectors $\{\bm{z}_{k}\}_{k}$ are then processed repeatedly by two alternating types of Transformer blocks. The first block is a standard encoder layer, consisting of self-attention and feed-forward networks, that extracts cross-token and token-wise features. The second block is also a standard encoder layer, but it is introduced to let the $R$ representative variables interact. Specifically, the $R$ representative variables are stacked into $Z=[\bm{z}_{1}\cdots\bm{z}_{R}]^{\top}\in\mathbb{R}^{R\times(D/R)}$ and input to this small encoder layer. Its self-attention layer computes an $R\times R$ attention map and thereby mixes the representative variables.
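The mixing step can be sketched as self-attention over the $R$ representative embeddings at each spatial token, which yields one $R\times R$ attention map per token. This is a simplified single-head sketch with only a residual connection; the feed-forward network, layer normalization, and multi-head split of a full encoder layer are omitted, and all names are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def mix_representatives(Z, W_Q, W_K, W_V):
    """Single-head self-attention across the R representative variables.

    Z: (P, R, d) -- P spatial tokens, each with R embeddings of width d = D/R.
    W_Q, W_K, W_V: (d, d) projections shared across positions.
    Returns mixed embeddings (P, R, d) and attention maps (P, R, R).
    """
    d = Z.shape[-1]
    Q, K, Vals = Z @ W_Q, Z @ W_K, Z @ W_V               # each (P, R, d)
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d))   # (P, R, R) attention maps
    return Z + A @ Vals, A                               # residual connection only


rng = np.random.default_rng(0)
P, R, d = 512, 2, 8
Z = rng.standard_normal((P, R, d))
Ws = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3)]
out, A = mix_representatives(Z, *Ws)
print(out.shape, A.shape)  # (512, 2, 8) (512, 2, 2)
```

Because the attention map is only $R\times R$ per token, this layer is cheap compared with the spatial self-attention in the first block type.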

3.3 Regional split training

Data-driven models suffer from high computational costs during training, which can be attributed to spatial resolution: higher spatial resolution increases the number of tokens, and the memory cost of attention computation grows quadratically with the number of tokens. However, the weather forecast at a particular point should be affected mainly by its local region. Thus, we can naturally expect that training on regional inputs (i.e., spatially cropped inputs), as long as the cropped regions together cover the entire space, leads to decent training, even if it is not as successful as global training. Here, we examine how much time and memory can be saved by this simple strategy and how much it affects forecasting performance. During training, an input $X_t\in\mathbb{R}^{H\times W\times V}$ is cropped into a sub-region $\mathcal{C}(X_t)\in\mathbb{R}^{(H/S)\times(W/S)\times V}$, and the loss on the model's output is measured only on this region. The region to crop can be determined randomly or selected canonically without overlap. This makes the cost of attention computation $\mathcal{O}(1/S^2)$ times smaller.
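The $\mathcal{O}(1/S^2)$ figure can be checked with a small calculation. This sketch assumes the $32\times 64$ grid and patch size 2 from Table 3 and counts one token per patch (a simplification that ignores the $R$ branches and attention heads); the helper name is hypothetical. Under these assumptions, the attention map of a single crop shrinks by $1/S^4$, while the total over the $S^2$ crops that cover the globe shrinks by $1/S^2$.

```python
def attention_cost(H=32, W=64, patch=2, S=1):
    """Token count and attention-map entries for one crop and for all S*S crops.

    Assumes one token per patch after variable aggregation (hypothetical
    simplification that ignores the R branches and the head count).
    """
    tokens = (H // S // patch) * (W // S // patch)  # tokens in one crop
    per_crop = tokens ** 2                          # entries in one attention map
    all_crops = S * S * per_crop                    # summed over the S^2 crops
    return tokens, per_crop, all_crops


full = attention_cost(S=1)[2]
for S in (1, 2, 4, 8):
    tokens, per_crop, all_crops = attention_cost(S=S)
    print(S, tokens, per_crop, all_crops / full)
```

The loop prints 512, 128, 32, and 8 tokens per crop for $S=1,2,4,8$, with total-cost ratios of 1.0, 0.25, 0.0625, and 0.015625.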

Table 2: Comparison of global training and regional split training for VarteX with $R=2$ and lead time $\Delta t=6$ h. The result of ClimaX with global training is provided as a reference.

Crop size         U10            T2m            Z500             T850           Train time  Memory
                  ACC↑   RMSE↓   ACC↑   RMSE↓   ACC↑   RMSE↓     ACC↑   RMSE↓   (h)         (GB)
Global            0.89   1.89    0.79   3.24    0.97   247.33    0.92   1.90    14.60       33.02
16×32             0.88   1.97    0.78   3.39    0.96   291.19    0.91   2.08    6.10        7.70
8×16              0.76   2.74    0.63   5.23    0.80   628.71    0.77   3.34    4.18        6.23
4×8               0.81   2.48    0.68   4.57    0.90   458.73    0.82   2.96    4.01        6.11
ClimaX            0.59   3.35    0.66   4.19    0.76   667.34    0.71   3.53    12.28       21.73

4 Experiments

We now evaluate the forecasting ability and the efficiency of the proposed model, VarteX, and ClimaX as a baseline.

4.1 Dataset and training setup

We train VarteX and ClimaX on ERA5 (Hersbach et al., 2020) from scratch, following the training setup of Nguyen et al. (2023).

Dataset.

ERA5 is a publicly accessible atmospheric reanalysis dataset provided by ECMWF. Its full spatial resolution is 0.25° ($721\times 1440$ grid points). As in Nguyen et al. (2023), we use ERA5 data downsampled to a spatial resolution of 5.625° ($32\times 64$ grid points), provided by WeatherBench (Rasp et al., 2020). Of these, 48 meteorological variables are used in training, and four are forecast targets: geopotential at 500 hPa (Z500), temperature at 850 hPa (T850), temperature at 2 m above the ground (T2m), and zonal wind speed at 10 m above the ground (U10). Each channel is standardized to have a mean of 0 and a standard deviation of 1. The dataset contains hourly data from 2006 to 2018; 2006 to 2015 are used for training, 2016 for validation, and 2017 to 2018 for testing, so the validation and test years are disjoint from the training years.
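The year-based split and per-channel standardization can be sketched as follows. The split years come from the text; computing the mean and standard deviation on the training years only is our assumption (a common practice that the text does not state), and the helper names and toy array sizes are hypothetical.

```python
import numpy as np


def split_by_year(years):
    """Boolean masks over samples for the train/val/test split described above."""
    years = np.asarray(years)
    train = (years >= 2006) & (years <= 2015)
    val = years == 2016
    test = (years >= 2017) & (years <= 2018)
    return train, val, test


def standardize(X, train_mask):
    """Per-channel z-score of X with shape (T, H, W, V).

    Assumption: statistics are computed on the training samples only.
    """
    mean = X[train_mask].mean(axis=(0, 1, 2))  # (V,) one mean per channel
    std = X[train_mask].std(axis=(0, 1, 2))    # (V,) one std per channel
    return (X - mean) / std, mean, std


rng = np.random.default_rng(0)
years = np.repeat(np.arange(2006, 2019), 4)            # 4 toy samples per year
X = rng.normal(5.0, 3.0, size=(len(years), 4, 8, 3))   # toy (T, H, W, V) data
train, val, test = split_by_year(years)
Xn, mean, std = standardize(X, train)
```

Fitting the statistics on the training years keeps information from the validation and test years out of preprocessing.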

Training.

Both VarteX and ClimaX are trained with a latitude-weighted mean squared error (MSE) loss to predict, from a meteorological sample $X_t\in\mathbb{R}^{H\times W\times V}$ at time $t$, the state $\Delta t$ hours later. We trained separate models for lead times $\Delta t=6$ and $24$ hours. The embedding dimension of ClimaX is $D=1024$, and that of each VarteX representative variable is $D/R$. Other architecture parameters, such as the number of attention heads, are shared between VarteX and ClimaX. Other detailed experimental settings follow those of ClimaX (cf. Section A.1).

4.2 Effect of representative variables

Table 1 compares the predictive performance of VarteX and ClimaX at 6-hour and 24-hour lead times using RMSE and ACC. VarteX with two representative variables significantly outperforms ClimaX in both RMSE and ACC for all target variables, with an approximately 45% smaller model. This supports our hypothesis that explicitly handling multiple representative variables, rather than having them implicitly within a single representative variable, improves learning efficiency. Increasing the number of representative variables in VarteX further reduces the model size; however, a drastic performance drop is observed. We attribute this to the embedding dimension $D/R$ per representative variable, which limits the network capacity. To examine this, we tested the case of $R=4$ again, increasing the embedding dimension per representative variable from $D/R$ to $2D/R$, which makes it equal to that in the $R=2$ case. We observed a sharp performance improvement, but this model still does not outperform the $R=2$ case on either metric at either lead time.

4.3 Effect of Crop Size on Predictive Performance

We next compare standard training (referred to as global training) and regional split training. Given the results in Section 4.2, we focus on VarteX with $R=2$. We train VarteX with regional split training using $S=2,4,8$, which correspond to $16\times 32$, $8\times 16$, and $4\times 8$ grid points per cropped region. During training, the loss is computed independently for each cropped region; at inference, an input is split into regions, each region is fed to the model, and the full field is reconstructed from the outputs. Note that the regions are split canonically; for $S=2$, the input is spatially divided into top-left, top-right, bottom-left, and bottom-right regions. We also tested random selection of the cropping regions, but it was not as successful as the canonical division.
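The canonical split and reconstruction described above can be sketched with array slicing. `split_regions` and `merge_regions` are hypothetical helpers, and the model and per-region loss are omitted. Plain slicing treats the longitude axis as non-periodic, so in this sketch no crop wraps around the edge of the longitude range.

```python
import numpy as np


def split_regions(X, S):
    """Canonical non-overlapping split of X (H, W, V) into S*S crops.

    Crops are ordered row-major: top-left, top-right, ..., bottom-right.
    """
    H, W, _ = X.shape
    assert H % S == 0 and W % S == 0, "H and W must be divisible by S"
    h, w = H // S, W // S
    return [X[i * h:(i + 1) * h, j * w:(j + 1) * w] for i in range(S) for j in range(S)]


def merge_regions(crops, S):
    """Inverse of split_regions: stitch per-crop outputs back into (H, W, V)."""
    rows = [np.concatenate(crops[i * S:(i + 1) * S], axis=1) for i in range(S)]
    return np.concatenate(rows, axis=0)


X = np.arange(32 * 64 * 3).reshape(32, 64, 3)  # toy (H, W, V) field
crops = split_regions(X, 2)
print(len(crops), crops[0].shape)  # 4 (16, 32, 3)
```

At inference, each crop would be passed through the model and the per-crop predictions stitched back with `merge_regions(predictions, S)`; with an identity "model", the merge recovers the input exactly.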

Table 2 compares regional split training and global training for VarteX with two representative variables at a 6-hour lead time. The results indicate that a crop size of $16\times 32$ achieves the best trade-off between performance and training cost (note that the optimal number of splits can depend on the resolution). Smaller crop sizes deteriorate the forecasting performance with only limited further improvement in training time and memory consumption. This is because reducing the number of tokens per crop from $N$ to $N/4$ ($S=2$), which reduces the size of the attention map quadratically, is a large saving in absolute terms, whereas further reductions to $N/16$ and $N/64$ are marginal, and other factors become the bottleneck. To summarize, our experiments suggest that VarteX with $R=2$ and regional split training with $S=2$ (equivalent to $16\times 32$) is the current best practice. Compared to the original ClimaX with global training, it significantly improves ACC and RMSE with 55% of the model size, about 50% of the training time, and about 35% of the memory consumption (about 40% and 25%, respectively, relative to VarteX with global training).

5 Conclusions

In this study, we addressed efficient learning over many diverse meteorological variables with a ViT-based model. Inspired by ClimaX, we proposed a new variable aggregation scheme that explicitly extracts several representative variables, significantly improving forecasting performance while reducing the model size. Further, we examined regional split training, which reduces training time and memory cost by a large margin at a small cost in forecasting scores. While this paper focuses on training on ERA5 from scratch, our results could be applied to building a foundation model by following the large-scale training of ClimaX.

6 Acknowledgments

This work was supported by JST Moonshot R&D Program Grant Number JPMJMS2389.

References

  • Bauer et al. (2015) Bauer, P., Thorpe, A., and Brunet, G. The quiet revolution of numerical weather prediction. Nature, 525(7567):47–55, 2015.
  • Bi et al. (2023) Bi, K., Xie, L., Zhang, H., Chen, X., Gu, X., and Tian, Q. Accurate medium-range global weather forecasting with 3D neural networks. Nature, 619(7970):533–538, 2023.
  • Chen et al. (2023a) Chen, K., Han, T., Gong, J., Bai, L., Ling, F., Luo, J.-J., Chen, X., Ma, L., Zhang, T., Su, R., et al. FengWu: Pushing the skillful global medium-range weather forecast beyond 10 days lead. arXiv preprint arXiv:2304.02948, 2023a.
  • Chen et al. (2023b) Chen, L., Zhong, X., Zhang, F., Cheng, Y., Xu, Y., Qi, Y., and Li, H. FuXi: a cascade machine learning forecasting system for 15-day global weather forecast. npj Climate and Atmospheric Science, 6(1):190, 2023b.
  • Dosovitskiy et al. (2021) Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., and Houlsby, N. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021.
  • Hersbach et al. (2020) Hersbach, H., Bell, B., Berrisford, P., Hirahara, S., Horányi, A., Muñoz-Sabater, J., Nicolas, J., Peubey, C., Radu, R., Schepers, D., et al. The ERA5 global reanalysis. Quarterly Journal of the Royal Meteorological Society, 146(730):1999–2049, 2020.
  • Keisler (2022) Keisler, R. Forecasting global weather with graph neural networks. arXiv preprint arXiv:2202.07575, 2022.
  • Lam et al. (2023) Lam, R., Sanchez-Gonzalez, A., Willson, M., Wirnsberger, P., Fortunato, M., Alet, F., Ravuri, S., Ewalds, T., Eaton-Rosen, Z., Hu, W., et al. Learning skillful medium-range global weather forecasting. Science, 382(6677):1416–1421, 2023.
  • Lynch (2008) Lynch, P. The origins of computer weather prediction and climate modeling. Journal of Computational Physics, 227(7):3431–3444, 2008.
  • Man et al. (2023) Man, X., Zhang, C., Feng, J., Li, C., and Shao, J. W-MAE: Pre-trained weather model with masked autoencoder for multi-variable weather forecasting. arXiv preprint arXiv:2304.08754, 2023.
  • Nguyen et al. (2023) Nguyen, T., Brandstetter, J., Kapoor, A., Gupta, J., and Grover, A. ClimaX: A foundation model for weather and climate. In ICLR 2023 Workshop on Tackling Climate Change with Machine Learning, 2023.
  • Nguyen et al. (2024) Nguyen, T., Shah, R., Bansal, H., Arcomano, T., Madireddy, S., Maulik, R., Kotamarthi, V., Foster, I., and Grover, A. Scaling Transformers for Skillful and Reliable Medium-range Weather Forecasting. In ICLR 2024 Workshop on AI4DifferentialEquations In Science, 2024.
  • Ni (2023) Ni, Z. Kunyu: A High-Performing Global Weather Model Beyond Regression Losses. arXiv preprint arXiv:2312.08264, 2023.
  • Pathak et al. (2022) Pathak, J., Subramanian, S., Harrington, P., Raja, S., Chattopadhyay, A., Mardani, M., Kurth, T., Hall, D., Li, Z., Azizzadenesheli, K., et al. FourCastNet: A global data-driven high-resolution weather model using adaptive Fourier neural operators. arXiv preprint arXiv:2202.11214, 2022.
  • Ramavajjala (2024) Ramavajjala, V. HEAL-ViT: Vision Transformers on a spherical mesh for medium-range weather forecasting. arXiv preprint arXiv:2403.17016, 2024.
  • Rasp et al. (2020) Rasp, S., Dueben, P. D., Scher, S., Weyn, J. A., Mouatadid, S., and Thuerey, N. WeatherBench: a benchmark data set for data-driven weather forecasting. Journal of Advances in Modeling Earth Systems, 12(11):e2020MS002203, 2020.
  • Scher & Messori (2019) Scher, S. and Messori, G. Weather and climate forecasting with neural networks: using general circulation models (GCMs) with different complexity as a study ground. Geoscientific Model Development, 12(7):2797–2809, 2019.
  • Weyn et al. (2019) Weyn, J. A., Durran, D. R., and Caruana, R. Can machines learn to predict weather? Using deep learning to predict gridded 500-hPa geopotential height from historical weather data. Journal of Advances in Modeling Earth Systems, 11(8):2680–2693, 2019.
  • Weyn et al. (2021) Weyn, J. A., Durran, D. R., Caruana, R., and Cresswell-Clay, N. Sub-seasonal forecasting with a large ensemble of deep-learning weather prediction models. Journal of Advances in Modeling Earth Systems, 13(7):e2021MS002502, 2021.

Appendix A Experiment details

A.1 Training details

In this study, the experimental conditions follow the ClimaX paper. Both VarteX and ClimaX are trained using the AdamW optimizer with a learning rate of $5\times 10^{-7}$ and a weight decay of $1\times 10^{-5}$. The training schedule consists of a linear warmup over five epochs followed by 45 epochs of cosine annealing, for a total of 50 epochs. The batch size is 128, obtained by accumulating gradients over 4 steps with a per-step batch size of 32.
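The schedule and batch arithmetic above can be sketched as follows. This is a per-epoch approximation under stated assumptions: warmup starts from a learning rate of 0 and cosine annealing decays to 0. The ClimaX code may instead update the learning rate per optimizer step with different start and end values, which this sketch does not model; `lr_at` is a hypothetical name.

```python
import math

BASE_LR = 5e-7            # learning rate stated in the text
WARMUP_EPOCHS = 5
TOTAL_EPOCHS = 50
MICRO_BATCH, ACCUM_STEPS = 32, 4
EFFECTIVE_BATCH = MICRO_BATCH * ACCUM_STEPS  # 128, matching the stated batch size


def lr_at(epoch, base_lr=BASE_LR, warmup=WARMUP_EPOCHS, total=TOTAL_EPOCHS, min_lr=0.0):
    """Linear warmup followed by cosine annealing, evaluated per epoch.

    Assumptions: warmup starts from 0 and annealing ends at min_lr = 0.
    """
    if epoch < warmup:
        return base_lr * epoch / warmup              # linear ramp from 0 to base_lr
    progress = (epoch - warmup) / (total - warmup)   # 0 after warmup, 1 at the end
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

Under these assumptions, the learning rate peaks at $5\times 10^{-7}$ at epoch 5, falls to half of that midway through annealing (epoch 27.5), and reaches 0 at epoch 50.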

A.2 Software and Hardware

We use the ClimaX repository from GitHub (https://github.com/microsoft/ClimaX). All experiments use an NVIDIA RTX A6000 with 48GB of memory.

A.3 Hyperparameters

Tables 3 and 4 present the hyperparameters of the models used in this experiment. The hyperparameters of ClimaX are taken from the original paper, and VarteX uses the same values for the parameters it shares with ClimaX. Table 5 lists the meteorological variables contained in the input data.

Table 3: Hyperparameters of VarteX.

Hyperparameter Value
Default variables All variables in Table 5
Image size [32, 64]
Patch size 2
Embedding dimension 1024
Number of ViT blocks 8
Number of attention heads 16
Number of representative variables 2
MLP ratio 4
Prediction depth 2
Hidden dimension in prediction head 1024
Drop path 0.1
Drop rate 0.1
Table 4: Hyperparameters of ClimaX (Nguyen et al., 2023).

Hyperparameter Value
Default variables All variables in Table 5
Image size [32, 64]
Patch size 2
Embedding dimension 1024
Number of ViT blocks 8
Number of attention heads 16
MLP ratio 4
Prediction depth 2
Hidden dimension in prediction head 1024
Drop path 0.1
Drop rate 0.1
Table 5: The variables from WeatherBench used in the model. The same variables are utilized as in ClimaX.

Variable name Abbrev. Pressure levels
Land-sea mask LSM
Orography
2 metre temperature T2m
10 metre U wind component U10
10 metre V wind component V10
Geopotential Z 50, 250, 500, 600, 700, 850, 925
U wind component U 50, 250, 500, 600, 700, 850, 925
V wind component V 50, 250, 500, 600, 700, 850, 925
Temperature T 50, 250, 500, 600, 700, 850, 925
Specific humidity Q 50, 250, 500, 600, 700, 850, 925
Relative humidity R 50, 250, 500, 600, 700, 850, 925
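Assuming, as in ClimaX, that each pressure-level variable contributes one input channel per level, Table 5 expands to 5 single-level fields plus 6 × 7 = 42 pressure-level channels, i.e. 47 input channels. The sketch below enumerates them; channel labels such as `Z500` are a hypothetical naming scheme, not identifiers from the ClimaX code.

```python
# Table 5 as data. Labels like "Z500" are a hypothetical naming scheme.
SINGLE_LEVEL = ["LSM", "Orography", "T2m", "U10", "V10"]
PRESSURE_LEVEL = ["Z", "U", "V", "T", "Q", "R"]
LEVELS_HPA = [50, 250, 500, 600, 700, 850, 925]


def input_channels() -> list:
    """One channel per single-level field plus one per (variable, pressure level) pair."""
    return SINGLE_LEVEL + [f"{v}{lvl}" for v in PRESSURE_LEVEL for lvl in LEVELS_HPA]


if __name__ == "__main__":
    chans = input_channels()
    print(len(chans))  # 5 + 6 * 7 = 47
    print(chans[5:12])  # the seven geopotential levels
```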

A.4 Loss Function and Metrics

This section defines the evaluation metrics used in the experiments. $\widetilde{Y}$ and $Y$ denote the forecast and the ground truth, respectively; $K$ is the total number of test samples; and $H$ and $W$ are the numbers of grid points along the latitude and longitude directions. The climatology $C$ is defined as the time average over the entire test set, $C = \frac{1}{K}\sum_{k} Y_{k}$.

Latitude weighting factor

$$L(h)=\frac{\cos\big(\mathrm{lat}(h)\big)}{\frac{1}{H}\sum_{h'=1}^{H}\cos\big(\mathrm{lat}(h')\big)} \tag{3}$$

where $\mathrm{lat}(h)$ is the latitude of grid row $h$.
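The latitude weighting in Eq. (3) can be computed as follows. This sketch assumes an equiangular grid with cell-centered latitudes (for H = 32 this gives rows from -87.1875° to 87.1875°, spaced 5.625° apart); the actual layout comes from the data and is not specified here. Because each cosine is divided by the mean cosine, the weights average to 1, so polar rows are down-weighted and equatorial rows up-weighted without changing the overall scale of the error. The function names are hypothetical.

```python
import math


def latitudes(num_rows: int) -> list:
    """Cell-centered latitudes in degrees, south to north (assumed grid layout)."""
    step = 180.0 / num_rows
    return [-90.0 + step * (i + 0.5) for i in range(num_rows)]


def lat_weights(lats_deg) -> list:
    """L(h) from Eq. (3): cos(lat(h)) divided by the mean cosine over all H rows."""
    cos_vals = [math.cos(math.radians(lat)) for lat in lats_deg]
    mean_cos = sum(cos_vals) / len(cos_vals)
    return [c / mean_cos for c in cos_vals]


if __name__ == "__main__":
    w = lat_weights(latitudes(32))  # H = 32 rows, matching the [32, 64] image size
    print(f"polar row: {w[0]:.3f}, near-equator row: {w[16]:.3f}")
    print(f"mean weight: {sum(w) / len(w):.6f}")
```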

Latitude-weighted mean square error (MSE)

$$\mathrm{MSE}=\mathbb{E}\Big[\frac{1}{H\times W}\sum_{h=1}^{H}\sum_{w=1}^{W}L(h)\big(\widetilde{Y}_{k,h,w}-Y_{k,h,w}\big)^{2}\Big]. \tag{4}$$

Latitude-weighted root mean square error (RMSE)

$$\mathrm{RMSE}=\frac{1}{K}\sum_{k=1}^{K}\sqrt{\frac{1}{H\times W}\sum_{h=1}^{H}\sum_{w=1}^{W}L(h)\big(\widetilde{Y}_{k,h,w}-Y_{k,h,w}\big)^{2}}. \tag{5}$$
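Eqs. (4) and (5) differ in where the square root is taken: the MSE averages the weighted squared error over samples, whereas the RMSE takes the square root per sample before averaging, so the RMSE is in general not the square root of the MSE. A minimal pure-Python sketch on nested lists of shape [K][H][W] follows; the expectation in Eq. (4) is estimated by the sample mean, and the helper `cos_lat_weights` is hypothetical, building L(h) on an assumed equiangular, cell-centered grid.

```python
import math


def cos_lat_weights(H: int) -> list:
    """L(h) from Eq. (3) on an assumed equiangular, cell-centered latitude grid."""
    lats = [math.radians(-90.0 + 180.0 / H * (i + 0.5)) for i in range(H)]
    c = [math.cos(x) for x in lats]
    return [v * H / sum(c) for v in c]  # normalized so the weights average to 1


def weighted_sq_err(pred, truth, weights) -> float:
    """(1/(H*W)) * sum_{h,w} L(h) * (pred - truth)^2 for one [H][W] sample."""
    H, W = len(pred), len(pred[0])
    total = 0.0
    for h in range(H):
        for w in range(W):
            total += weights[h] * (pred[h][w] - truth[h][w]) ** 2
    return total / (H * W)


def lat_weighted_mse(preds, truths, weights) -> float:
    """Eq. (4): per-sample weighted squared error, averaged over the K samples."""
    errs = [weighted_sq_err(p, t, weights) for p, t in zip(preds, truths)]
    return sum(errs) / len(errs)


def lat_weighted_rmse(preds, truths, weights) -> float:
    """Eq. (5): square root taken per sample, then averaged over the K samples."""
    roots = [math.sqrt(weighted_sq_err(p, t, weights)) for p, t in zip(preds, truths)]
    return sum(roots) / len(roots)


if __name__ == "__main__":
    H, W = 4, 8
    wts = cos_lat_weights(H)
    truth = [[[0.0] * W for _ in range(H)] for _ in range(2)]
    # One perfect sample and one with a uniform error of 1.0.
    mixed = [[[0.0] * W for _ in range(H)], [[1.0] * W for _ in range(H)]]
    mse = lat_weighted_mse(mixed, truth, wts)
    rmse = lat_weighted_rmse(mixed, truth, wts)
    print(f"MSE = {mse:.3f}, RMSE = {rmse:.3f}, sqrt(MSE) = {math.sqrt(mse):.3f}")
```

In the example, the MSE is 0.5 while the RMSE is also 0.5 rather than $\sqrt{0.5}$, because the per-sample roots (0 and 1) are averaged.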

Latitude-weighted anomaly correlation coefficient (ACC)

$$\mathrm{ACC}=\frac{\sum_{k,h,w}L(h)\,\widetilde{Y}'_{k,h,w}\,Y'_{k,h,w}}{\sqrt{\sum_{k,h,w}L(h)\,\widetilde{Y}'^{2}_{k,h,w}\;\sum_{k,h,w}L(h)\,Y'^{2}_{k,h,w}}} \tag{6}$$
$$\widetilde{Y}'=\widetilde{Y}-C,\qquad Y'=Y-C \tag{7}$$
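Eqs. (6) and (7) amount to a latitude-weighted correlation between forecast and ground-truth anomalies, where both anomalies are taken with respect to the test-set climatology $C$ (computed per grid point). A minimal sketch, assuming one variable at a time and inputs given as nested lists of shape [K][H][W]; the function names are hypothetical.

```python
import math


def climatology(truths):
    """C = (1/K) * sum_k Y_k, computed separately at every grid point."""
    K, H, W = len(truths), len(truths[0]), len(truths[0][0])
    return [[sum(t[h][w] for t in truths) / K for w in range(W)] for h in range(H)]


def lat_weighted_acc(preds, truths, weights) -> float:
    """Eqs. (6)-(7): weighted correlation of anomalies w.r.t. the test-set climatology."""
    C = climatology(truths)
    num = den_p = den_t = 0.0
    for p, t in zip(preds, truths):
        for h, row in enumerate(C):
            for w, c in enumerate(row):
                pa, ta = p[h][w] - c, t[h][w] - c  # forecast and true anomalies
                num += weights[h] * pa * ta
                den_p += weights[h] * pa * pa
                den_t += weights[h] * ta * ta
    return num / math.sqrt(den_p * den_t)


if __name__ == "__main__":
    # K = 3 samples on a 2 x 3 grid with non-zero anomalies.
    truths = [[[float(k * (h + 1) + w) for w in range(3)] for h in range(2)] for k in range(3)]
    wts = [0.8, 1.2]  # illustrative latitude weights
    C = climatology(truths)
    scaled = [[[C[h][w] + 2 * (t[h][w] - C[h][w]) for w in range(3)] for h in range(2)]
              for t in truths]
    print(lat_weighted_acc(truths, truths, wts))  # perfect forecast
    print(lat_weighted_acc(scaled, truths, wts))  # anomalies twice too large
```

Because the anomalies are normalized in the denominator, the ACC is insensitive to the amplitude of the forecast anomalies: a forecast whose anomalies are exactly twice the true ones still scores 1, whereas its RMSE would be non-zero.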

Appendix B Qualitative evaluation

This appendix presents a qualitative evaluation of VarteX forecasts for all target variables at a 6-hour lead time, for both global training and regional split training. In each figure, the first column shows the initial state, the second the ground truth, the third the prediction, and the fourth the bias between the ground truth and the prediction. These visualizations are for reference only: the models in this experiment were not pre-trained in the way ClimaX was, and we expect that pre-training as in the ClimaX paper (Nguyen et al., 2023) would yield better results.



Figure 2: An example of VarteX’s forecasts with two representative variables and the ground truth at a 6-hour lead time.


Figure 3: An example of VarteX’s forecasts with four representative variables and the ground truth at a 6-hour lead time.



Figure 4: An example of VarteX’s forecasts with two representative variables and the ground truth at a 6-hour lead time, with the embedding dimension set to 2048.


Figure 5: An example of VarteX’s forecasts and the ground truth at a 6-hour lead time, using a 16 × 32 crop size with regional split training.


Figure 6: An example of VarteX’s forecasts and the ground truth at a 6-hour lead time, using an 8 × 16 crop size with regional split training.


Figure 7: An example of VarteX’s forecasts and the ground truth at a 6-hour lead time, using a 4 × 8 crop size with regional split training.
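Figures 5 to 7 use regional split training with crop sizes of 16 × 32, 8 × 16, and 4 × 8. This appendix does not spell out how the crops are formed; purely for illustration, the sketch below assumes that the 32 × 64 global grid is tiled into non-overlapping crops, which would give 4, 16, and 64 regions, respectively. The function `split_into_regions` is hypothetical and is not taken from the VarteX or ClimaX code.

```python
def split_into_regions(grid, crop_h: int, crop_w: int) -> list:
    """Tile an H x W grid (nested lists) into non-overlapping crop_h x crop_w regions.

    Assumption for illustration only: regions are disjoint tiles ordered row by row.
    """
    H, W = len(grid), len(grid[0])
    if H % crop_h or W % crop_w:
        raise ValueError("crop size must evenly divide the grid size")
    return [
        [row[j:j + crop_w] for row in grid[i:i + crop_h]]
        for i in range(0, H, crop_h)
        for j in range(0, W, crop_w)
    ]


if __name__ == "__main__":
    grid = [[(i, j) for j in range(64)] for i in range(32)]  # 32 x 64 global grid
    patch = 2  # patch size from Table 3
    for ch, cw in ((16, 32), (8, 16), (4, 8)):
        regions = split_into_regions(grid, ch, cw)
        print(f"{ch} x {cw}: {len(regions)} regions, "
              f"{(ch // patch) * (cw // patch)} patches per region")
```

Under this assumption, the smallest 4 × 8 crop still yields 8 patch tokens per region at a patch size of 2.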