VarteX: Enhancing Weather Forecast through Distributed Variable Representation
Abstract
Weather forecasting is essential for various human activities. Recent data-driven models that use deep learning have outperformed numerical weather prediction in forecasting performance. However, challenges remain in efficiently handling multiple meteorological variables. This study proposes a new variable aggregation scheme and an efficient learning framework to address this challenge. Experiments show that VarteX outperforms the conventional model, ClimaX, in forecast performance while requiring significantly fewer parameters and resources. The effectiveness of learning through multiple aggregations and regional split training is demonstrated, enabling more efficient and accurate deep learning-based weather forecasting.
1 Introduction
From strategies for addressing extreme weather to daily societal activities, weather forecasting plays an indispensable role in human activities (Bauer et al., 2015). Recently, there has been increasing interest in applying data-driven models based on deep learning to weather forecasting (Scher & Messori, 2019; Weyn et al., 2019; Rasp et al., 2020; Weyn et al., 2021; Keisler, 2022; Lam et al., 2023). After intensive training on meteorological data, such models can generate forecasts within seconds (Lynch, 2008), whereas numerical weather prediction must solve complex partial differential equations, which leads to significantly longer forecasting times. Several recent studies have reported that data-driven models outperform numerical weather prediction models even in forecasting skill (Bi et al., 2023; Lam et al., 2023; Chen et al., 2023b).
Many data-driven weather forecasting models (Pathak et al., 2022; Bi et al., 2023; Chen et al., 2023a, b; Man et al., 2023; Nguyen et al., 2023; Ni, 2023; Nguyen et al., 2024; Ramavajjala, 2024) are based on the Vision Transformer (ViT; Dosovitskiy et al. (2021)), a powerful attention-based model in computer vision. This is because meteorological data closely resemble image data in structure, having height, width, and channel dimensions. A critical difference lies in the channel dimension. Image data have only RGB channels, which carry similar information about the entire image. In contrast, meteorological data have many more channels, one per meteorological variable such as temperature or humidity, each with its own characteristics. The large number of meteorological variables increases the computational cost, and their diversity makes learning challenging. ClimaX (Nguyen et al., 2023) addressed this challenge with minimal modifications to ViT. In particular, it is equipped with a variable aggregation module that sums the meteorological variables into one representative variable with attention weights. Its experiments show that such input-dependent variable aggregation leads to more successful training than input-agnostic convolution over variables, the standard way of summarizing RGB channels in the image domain.
In this study, we propose a new variable aggregation scheme and training method for efficient learning from meteorological data with a ViT-based weather forecasting model. Our variable aggregation scheme is based on the hypothesis that the representative variable (or its $D$-dimensional embedding vector) obtained by ClimaX variable aggregation may internally contain several components, because a single representative variable would otherwise be too restrictive. If so, it is better to model them explicitly as $N$ representative variables with $D/N$-dimensional embedding vectors. During encoding, our model processes these embedding vectors separately with Transformer encoders and then mixes them with a mixing layer. The smaller embedding dimension yields smaller Transformer encoders, reducing the number of parameters of each by a factor of $N^2$. While this reduces the model size, the memory cost of the forward pass remains unchanged because the size of the attention map is determined by the number of tokens (i.e., image patches). To address this, we introduce regional split training, where the model is trained on only a cropped region in each forward pass. This training decreases the final accuracy, but with a proper choice of crop size the decrease is moderate, and the training time and memory cost are reduced drastically in return; in particular, the attention map shrinks quadratically with the number of tokens per crop. In experiments, we trained our model, VarteX, as well as ClimaX, on the WeatherBench dataset from scratch. The results show that VarteX forecasting accuracy is 50% higher on average than that of ClimaX, and the gap is even larger for wind speed forecasting. In terms of latitude-weighted root mean squared error (RMSE) and latitude-weighted anomaly correlation coefficient (ACC) for all target variables, the results highlight the effectiveness of learning through multiple representative variable aggregations. VarteX achieved these results with 55% of the model size, 50% of the training time, and 35% of the memory usage of ClimaX.
2 Problem setting
Meteorological data are time-series data of meteorological variables, such as temperature, geopotential, and wind speed, at each grid point of the globe. Suppose we have $H \times W$ grid points and $V$ variables; then the data at time step $t$ are $X_t \in \mathbb{R}^{V \times H \times W}$. Learning-based weather forecasting aims to find a forecasting function $f_\theta: X_t \mapsto X_{t+\Delta t}$ for a predesignated lead time $\Delta t$ by training a deep learning model with parameters $\theta$ on a regression task.
An important characteristic of meteorological data is their large number of variables (e.g., $V = 48$ in our experiments). Using attention-based models such as the Vision Transformer to capture both spatial and inter-variable interactions means handling up to $V$ tokens per spatial patch, and the attention computation grows quadratically with the number of tokens. As each meteorological variable has its own temporal and spatial dynamics, we cannot resort to simply concatenating them as we do with the R, G, and B channels of image data.
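For a rough sense of scale, the sketch below counts tokens when every variable is patchified separately, so that attention spans both space and variables. The grid size (32 × 64), patch size 2, and V = 48 are taken from the experimental setup in Section 4 and Table 3; the function names are hypothetical, and this is arithmetic only, not a model.

```python
def num_tokens(num_vars: int, height: int, width: int, patch: int) -> int:
    """Tokens when each variable is split into non-overlapping patch x patch tokens."""
    return num_vars * (height // patch) * (width // patch)


def attention_entries(tokens: int) -> int:
    """Entries in one full self-attention map over `tokens` tokens (quadratic growth)."""
    return tokens * tokens


if __name__ == "__main__":
    # 32 x 64 grid, patch size 2: one aggregated variable versus all 48 variables
    for v in (1, 48):
        n = num_tokens(v, 32, 64, 2)
        print(f"V={v:2d}: {n:6d} tokens, {attention_entries(n):,} attention entries")
```

Going from one aggregated variable to 48 multiplies the attention map by 48² = 2304, which is why aggregating variables before the Transformer matters.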
The focus of our study is the aggregation of meteorological variables. In particular, we are interested in an aggregation function that reduces the variables from $V$ to $N$ representative variables while still enabling successful learning. ClimaX offers such an aggregation function with $N = 1$. We suspect that reducing the variables to a single one may be too aggressive, and we extend the idea to general $N$. However, as $N$ increases, the number of matrices in the attention mechanism also increases, raising the computational cost. Therefore, we also need to reduce the computational cost.
3 Methodology
We propose a new variable aggregation scheme and an efficient training framework. As demonstrated in Section 4, the former reduces the model size and significantly improves prediction performance, and the latter cuts the training time roughly in half and the memory cost by even more.
3.1 Variable aggregation in ClimaX
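First, we review the cross-attention-based variable aggregation of ClimaX. Let $x \in \mathbb{R}^{V \times L \times D}$ be the input data embedded into a $D$-dimensional space, where $L$ is the number of spatial tokens (patches). In what follows, the same operations are applied uniformly at every spatial position, so we focus on a single position $l$ for notational simplicity and redefine $x := x_{:,l,:} \in \mathbb{R}^{V \times D}$. With a trainable query vector $q \in \mathbb{R}^{1 \times D}$, the cross-attention aggregates the variables as follows:
$$z = \mathrm{softmax}\!\left(\frac{q K^\top}{\sqrt{D}}\right) U \in \mathbb{R}^{1 \times D}, \tag{1}$$
where $K = x W_K$ and $U = x W_U$ with trainable weights $W_K, W_U \in \mathbb{R}^{D \times D}$, and the softmax is taken over the $V$ variables. Namely, this cross-attention computes a weighted sum of a linear transformation $U$ of the input with the attention weights $\mathrm{softmax}(q K^\top / \sqrt{D}) \in \mathbb{R}^{1 \times V}$, thereby aggregating $V$ variables into one representative variable.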
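A minimal pure-Python sketch of the single-query aggregation in Eq. (1) at one spatial position follows. It assumes a single attention head with the $\sqrt{D}$ scaling of standard scaled dot-product attention (the ClimaX code may use multi-head attention); the sizes are toy values, and `aggregate_variables`, `w_k`, and `w_u` are hypothetical names standing for the aggregation, $W_K$, and $W_U$.

```python
import math
import random


def matmul(a, b):
    """Plain nested-list matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def aggregate_variables(x, q, w_k, w_u):
    """Single-query cross-attention over the V variables at one spatial position.

    x: V x D embeddings, q: length-D query, w_k / w_u: D x D weights (W_K, W_U).
    Returns the attention weights (length V) and the aggregated length-D vector z.
    """
    d = len(q)
    keys = matmul(x, w_k)
    values = matmul(x, w_u)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    weights = softmax(scores)
    z = [sum(a * v[j] for a, v in zip(weights, values)) for j in range(d)]
    return weights, z


def random_matrix(rng, rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    V, D = 5, 4  # toy sizes, far smaller than the real model
    x = random_matrix(rng, V, D)
    w_k, w_u = random_matrix(rng, D, D), random_matrix(rng, D, D)
    q = [rng.gauss(0.0, 1.0) for _ in range(D)]
    weights, z = aggregate_variables(x, q, w_k, w_u)
    print("attention weights:", [round(a, 3) for a in weights])
    print("sum of weights:", round(sum(weights), 6))
    print("representative vector z:", [round(v, 3) for v in z])
```

Because the weights come from a softmax over variables, $z$ is always a convex combination of the transformed variables, and the combination changes with the input.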
3.2 Model architecture
Table 1: Forecast performance (ACC, higher is better; RMSE, lower is better) and model size of ClimaX and VarteX for 6-hour and 24-hour lead times.
Lead time | Model | U10 ACC | U10 RMSE | T2m ACC | T2m RMSE | Z500 ACC | Z500 RMSE | T850 ACC | T850 RMSE | Parameters (M) |
---|---|---|---|---|---|---|---|---|---|---|
6h | ClimaX | 0.59 | 3.35 | 0.66 | 4.19 | 0.76 | 667.34 | 0.71 | 3.53 | 108.08 |
6h | VarteX (N = 2) | 0.89 | 1.89 | 0.79 | 3.24 | 0.97 | 247.33 | 0.92 | 1.90 | 59.86 |
6h | VarteX (N > 2) | 0.19 | 4.54 | 0.32 | 7.43 | 0.53 | 966.76 | 0.51 | 4.65 | 20.76 |
6h | VarteX (N > 2, per-variable dim. D/2) | 0.76 | 2.66 | 0.56 | 5.13 | 0.81 | 599.42 | 0.80 | 2.93 | 80.47 |
24h | ClimaX | 0.37 | 3.86 | 0.66 | 4.20 | 0.69 | 734.47 | 0.70 | 3.58 | 108.08 |
24h | VarteX (N = 2) | 0.64 | 3.16 | 0.83 | 2.99 | 0.88 | 478.42 | 0.85 | 2.59 | 59.86 |
24h | VarteX (N > 2) | 0.10 | 4.76 | 0.31 | 7.33 | 0.48 | 986.85 | 0.47 | 4.77 | 20.76 |
24h | VarteX (N > 2, per-variable dim. D/2) | 0.52 | 3.52 | 0.57 | 4.99 | 0.75 | 660.40 | 0.75 | 3.25 | 80.47 |
While ClimaX variable aggregation reduces the variables to one, we hypothesize that the representative variable (and its embedding vector) may internally consist of several components, because it is hard to believe that $V$ (48 in our experiments) variables can be represented by a single variable, even with input-dependent attention weights. If this is the case, we should be able to split the $D$-dimensional space into $N$ subspaces of dimension $D/N$, where $N$ is the potential number of representative variables. Because ViT contains many square matrices of size $D \times D$, this split directly reduces the model size from $O(D^2)$ to $O(N \cdot (D/N)^2) = O(D^2/N)$.
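The size argument can be checked with simple arithmetic. The sketch below counts only the weights of square $D \times D$ matrices versus $N$ separate $(D/N) \times (D/N)$ matrices, using $D = 1024$ from Tables 3 and 4; it ignores biases, embeddings, the prediction head, and the mixing layers, so it is not expected to reproduce the exact parameter counts in Table 1. The function name is hypothetical.

```python
def split_square_params(d: int, n: int = 1) -> int:
    """Weights in N separate (D/N x D/N) matrices that replace one D x D matrix."""
    if d % n:
        raise ValueError("D must be divisible by N")
    sub = d // n
    return n * sub * sub


if __name__ == "__main__":
    D = 1024  # embedding dimension listed in Tables 3 and 4
    for n in (1, 2, 4):
        p = split_square_params(D, n)
        print(f"N={n}: {p:,} weights ({p / split_square_params(D):.2f} of N=1)")
```

Each of the $N$ smaller matrices is $N^2$ times smaller, so the total shrinks by a factor of $N$.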
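We thus propose to use more than one representative variable. The embedded input is split along the embedding axis into
$$x = \left[x^{(1)}, \dots, x^{(N)}\right], \tag{2}$$
where $x^{(n)} \in \mathbb{R}^{V \times D/N}$ for $n = 1, \dots, N$. We prepare $N$ trainable query vectors $q^{(n)} \in \mathbb{R}^{1 \times D/N}$ and apply ClimaX variable aggregation (Eq. 1), with separate weights, to each pair $(q^{(n)}, x^{(n)})$ to obtain $z^{(n)} \in \mathbb{R}^{1 \times D/N}$.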
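The split-then-aggregate step can be sketched as follows: the embedding axis is cut into $N$ contiguous column blocks (Eq. 2), and each block gets its own query and weights. This is a minimal pure-Python illustration under the same single-head assumption as Eq. (1); the names `split_and_aggregate` and `cross_attend` and the parameter layout are hypothetical.

```python
import math


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def cross_attend(x, q, w_k, w_u):
    """ClimaX-style aggregation (Eq. 1): one query attends over the V rows of x."""
    d = len(q)
    keys, values = matmul(x, w_k), matmul(x, w_u)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(d)]


def split_and_aggregate(x, params):
    """Split x (V x D) into N column blocks of width D/N (Eq. 2) and aggregate each
    block with its own query and weights.

    params: N tuples (q_n, w_k_n, w_u_n) of sizes D/N, (D/N x D/N), (D/N x D/N).
    Returns N representative vectors z^(n), each of length D/N.
    """
    n_groups, d = len(params), len(x[0])
    if d % n_groups:
        raise ValueError("D must be divisible by N")
    width = d // n_groups
    reps = []
    for n, (q, w_k, w_u) in enumerate(params):
        block = [row[n * width:(n + 1) * width] for row in x]  # x^(n): V x D/N
        reps.append(cross_attend(block, q, w_k, w_u))
    return reps


if __name__ == "__main__":
    x = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]  # V=3, D=4
    eye = [[1.0, 0.0], [0.0, 1.0]]
    # N=2 groups; zero queries give uniform weights, so each z^(n) is a column mean
    print(split_and_aggregate(x, [([0.0, 0.0], eye, eye), ([0.0, 0.0], eye, eye)]))
```

Because each group has its own query, the $N$ groups can attend to different subsets of variables.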
The embedding vectors $z^{(1)}, \dots, z^{(N)}$ are then processed repeatedly and alternately by two types of Transformer blocks. The first is the standard encoder layer, consisting of self-attention and a feed-forward network, which extracts cross-token and token-wise features. The second is also a standard encoder layer, but it is introduced to let the representative variables interact. Specifically, the concatenation of the $N$ representative variables is input to a small encoder layer, whose self-attention computes an $N \times N$ attention map and thereby mixes the representative variables.
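The mixing step can be pictured as self-attention in which the $N$ representative vectors at one spatial token are the sequence, giving an $N \times N$ attention map. The sketch below is a simplified single-head version with a residual connection; it omits layer normalization, the feed-forward sublayer, and multiple heads, and the per-token interpretation and all names (`mix_representatives`, `w_q`, `w_k`, `w_u`) are assumptions for illustration.

```python
import math


def matmul(a, b):
    """Plain nested-list matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def mix_representatives(reps, w_q, w_k, w_u):
    """Single-head self-attention over the N representative vectors at one position.

    reps: N x d (d = D/N); w_q, w_k, w_u: d x d. Returns the N x N attention map and
    the mixed vectors (input plus attention output, i.e. a residual connection).
    """
    d = len(reps[0])
    queries, keys, values = matmul(reps, w_q), matmul(reps, w_k), matmul(reps, w_u)
    attn = [softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys])
            for q in queries]
    mixed = [[r[j] + sum(a * v[j] for a, v in zip(row, values)) for j in range(d)]
             for r, row in zip(reps, attn)]
    return attn, mixed


if __name__ == "__main__":
    reps = [[1.0, 0.0], [0.0, 1.0]]  # N = 2 representative vectors of size D/N = 2
    eye = [[1.0, 0.0], [0.0, 1.0]]
    attn, mixed = mix_representatives(reps, eye, eye, eye)
    print("N x N attention map:", [[round(a, 3) for a in row] for row in attn])
    print("mixed vectors:", [[round(v, 3) for v in row] for row in mixed])
```

Since $N$ is small, this map is tiny compared with the spatial attention maps, so the mixing layer adds little cost.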
3.3 Regional split training
Data-driven models incur high computational costs during training, largely because of spatial resolution: a higher spatial resolution increases the number of tokens, and the memory cost of attention grows quadratically with the number of tokens. However, the weather at a particular point should be affected mainly by its local region. We can therefore expect that training on regional (i.e., spatially cropped) inputs, as long as the cropped regions together cover the entire globe, yields decent training, if not as successful as global training. Here, we examine how much time and memory this simple strategy saves and how much it affects forecasting performance. During training, an input is cropped into a sub-region, and the loss on the model's output is measured only on this region. The region to crop can be chosen randomly or selected canonically without overlap. If a crop keeps a fraction $r$ of the grid along each spatial axis, the number of tokens per forward pass shrinks by a factor of $r^2$ and the attention map by a factor of $r^4$.
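A minimal sketch of one regional training step follows: only the cropped input is fed to the model, and the error is measured only on the same crop of the target. For brevity it uses a plain (unweighted) MSE on a single channel, whereas the paper trains with a latitude-weighted MSE; `crop`, `regional_mse`, and the persistence "model" are hypothetical stand-ins.

```python
def crop(grid, top, left, height, width):
    """Return the height x width sub-grid whose top-left corner is (top, left)."""
    return [row[left:left + width] for row in grid[top:top + height]]


def regional_mse(model, x_full, y_full, top, left, height, width):
    """Feed only the cropped input to `model` and score its output on the same crop.

    Plain (unweighted) MSE on a single channel, for brevity.
    """
    x_crop = crop(x_full, top, left, height, width)
    y_crop = crop(y_full, top, left, height, width)
    pred = model(x_crop)
    errors = [(p - t) ** 2 for p_row, t_row in zip(pred, y_crop) for p, t in zip(p_row, t_row)]
    return sum(errors) / len(errors)


def persistence(grid):
    """Hypothetical stand-in for the network: predicts that nothing changes."""
    return grid


if __name__ == "__main__":
    x = [[float(8 * i + j) for j in range(8)] for i in range(4)]  # toy 4 x 8 field
    y = [[v + 1.0 for v in row] for row in x]                    # toy target
    print(regional_mse(persistence, x, y, top=0, left=4, height=2, width=4))
```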
Table 2: VarteX with two representative variables trained globally and with regional split training at different crop sizes, for a 6-hour lead time; ClimaX with global training is shown for reference.
Crop size | U10 ACC | U10 RMSE | T2m ACC | T2m RMSE | Z500 ACC | Z500 RMSE | T850 ACC | T850 RMSE | Train time (h) | Memory (GB) |
---|---|---|---|---|---|---|---|---|---|---|
Global (32 × 64) | 0.89 | 1.89 | 0.79 | 3.24 | 0.97 | 247.33 | 0.92 | 1.90 | 14.60 | 33.02 |
16 × 32 | 0.88 | 1.97 | 0.78 | 3.39 | 0.96 | 291.19 | 0.91 | 2.08 | 6.10 | 7.70 |
8 × 16 | 0.76 | 2.74 | 0.63 | 5.23 | 0.80 | 628.71 | 0.77 | 3.34 | 4.18 | 6.23 |
4 × 8 | 0.81 | 2.48 | 0.68 | 4.57 | 0.90 | 458.73 | 0.82 | 2.96 | 4.01 | 6.11 |
ClimaX (global) | 0.59 | 3.35 | 0.66 | 4.19 | 0.76 | 667.34 | 0.71 | 3.53 | 12.28 | 21.73 |
4 Experiments
We now evaluate the forecasting ability and efficiency of the proposed model, VarteX, using ClimaX as a baseline.
4.1 Dataset and training setup
We train VarteX and ClimaX on ERA5 (Hersbach et al., 2020) from scratch, following the training setup of Nguyen et al. (2023).
Dataset.
ERA5 is a publicly accessible atmospheric reanalysis dataset provided by the ECMWF. Its full spatial resolution is 0.25° (721 × 1440 grid points). As in Nguyen et al. (2023), we use ERA5 data downsampled to a spatial resolution of 5.625° (32 × 64 grid points) provided by WeatherBench (Rasp et al., 2020). In total, 48 meteorological variables are used for training, and four are forecast targets: geopotential at 500 hPa (Z500), temperature at 850 hPa (T850), temperature at 2 meters above the ground (T2m), and zonal wind speed at 10 meters above the ground (U10). Each channel is standardized to have a mean of 0 and a standard deviation of 1. The dataset consists of hourly data from 2006 to 2018, with 2006 to 2015 used for training, 2016 for validation, and 2017 to 2018 for testing, so the validation and test years are disjoint from the training years.
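The per-channel standardization and the year-based split can be sketched with the standard library. The text does not say which years the normalization statistics come from; computing them on the training years only is an assumption here, and the function names are hypothetical.

```python
from statistics import fmean, pstdev


def channel_stats(train_values):
    """Mean and standard deviation of one channel (assumed: training years only)."""
    return fmean(train_values), pstdev(train_values)


def standardize(values, mean, std):
    """Shift and scale one channel to zero mean and unit standard deviation."""
    return [(v - mean) / std for v in values]


def split_of_year(year):
    """Map a data year to its split: 2006-2015 train, 2016 validation, 2017-2018 test."""
    if 2006 <= year <= 2015:
        return "train"
    if year == 2016:
        return "val"
    if 2017 <= year <= 2018:
        return "test"
    raise ValueError(f"year {year} is outside 2006-2018")


if __name__ == "__main__":
    mean, std = channel_stats([280.0, 285.0, 290.0])  # toy temperature values (K)
    print(standardize([280.0, 300.0], mean, std))
    print([split_of_year(y) for y in (2006, 2015, 2016, 2017, 2018)])
```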
Training.
Both VarteX and ClimaX are trained with a latitude-weighted mean squared error (MSE) loss to predict the meteorological state $X_{t+\Delta t}$ from the sample $X_t$ at time $t$. We trained two models, for lead times of $\Delta t = 6$ and $24$ hours. The embedding dimension of ClimaX is $D = 1024$, and each of the $N$ representative variables in VarteX has dimension $D/N$. Other architecture parameters, such as the number of attention heads, are common to VarteX and ClimaX. Other detailed experimental settings follow those of ClimaX (cf. Section A.1).
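Building training pairs for a fixed lead time amounts to pairing each time index with the index $\Delta t$ steps later. The sketch below assumes hourly data and that every hourly step can serve as an input, which the text does not state explicitly; `make_pairs` is a hypothetical helper.

```python
def make_pairs(num_steps: int, lead_time_hours: int, step_hours: int = 1):
    """(input_index, target_index) pairs for a series sampled every `step_hours`."""
    if lead_time_hours % step_hours:
        raise ValueError("lead time must be a multiple of the sampling interval")
    shift = lead_time_hours // step_hours
    return [(t, t + shift) for t in range(num_steps - shift)]


if __name__ == "__main__":
    hours_per_year = 365 * 24
    for lead in (6, 24):  # the two lead times trained in this study
        pairs = make_pairs(hours_per_year, lead)
        print(f"lead {lead}h: {len(pairs)} pairs, first {pairs[0]}, last {pairs[-1]}")
```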
4.2 Effect of representative variables
Table 1 compares the predictive performance of VarteX and ClimaX with 6-hour and 24-hour lead times, using RMSE and ACC. VarteX with two representative variables ($N = 2$) significantly outperforms ClimaX in both RMSE and ACC for all target variables, with an approximately 45% reduction in model size. This supports our hypothesis that explicitly handling multiple representative variables, rather than implicitly packing them into a single representative variable, improves learning efficiency. Increasing the number of representative variables further reduces the model size; however, performance drops drastically. We attribute this to the smaller embedding dimension per representative variable, which limits the network capacity. To examine this, we re-ran the larger-$N$ configuration with the embedding dimension per representative variable increased to $D/2$, the same as in the $N = 2$ case. Performance improved sharply, but this configuration still does not outperform $N = 2$ in either ACC or RMSE, despite having more parameters (80.47M vs. 59.86M).
4.3 Effect of Crop Size on Predictive Performance
We next compare standard training (referred to as global training) with regional split training. Given the results in Section 4.2, we focus on VarteX with $N = 2$. We train VarteX with regional split training using crops of $16 \times 32$, $8 \times 16$, and $4 \times 8$ grid points, out of the full $32 \times 64$ grid. During training, the loss is computed for each cropped region independently; at inference, an input is split into regions, each region is fed to the model, and the global output is reconstructed from the regional outputs. The split is done canonically: with $16 \times 32$ crops, the input is spatially divided into top-left, top-right, bottom-left, and bottom-right quadrants. We also tested a random selection of crop regions, but it was not as successful as the canonical division.
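The canonical split and the reconstruction at inference can be sketched as a non-overlapping, row-major tiling followed by reassembly. This assumes, as in the text, that the crop size divides the $32 \times 64$ grid evenly; the function names are hypothetical and the "model" is any callable that maps a tile to a same-shaped tile.

```python
def split_tiles(grid, tile_h, tile_w):
    """Canonical non-overlapping split, row-major: top-left, top-right, ..., bottom-right."""
    height, width = len(grid), len(grid[0])
    if height % tile_h or width % tile_w:
        raise ValueError("tile size must divide the grid")
    return [((top, left), [row[left:left + tile_w] for row in grid[top:top + tile_h]])
            for top in range(0, height, tile_h)
            for left in range(0, width, tile_w)]


def merge_tiles(tiles, height, width):
    """Place each (position, tile) back into a height x width grid."""
    out = [[0.0] * width for _ in range(height)]
    for (top, left), tile in tiles:
        for i, row in enumerate(tile):
            out[top + i][left:left + len(row)] = row
    return out


def predict_by_regions(model, grid, tile_h, tile_w):
    """Run `model` on each tile independently and reassemble the global output."""
    tiles = split_tiles(grid, tile_h, tile_w)
    return merge_tiles([(pos, model(tile)) for pos, tile in tiles], len(grid), len(grid[0]))


if __name__ == "__main__":
    grid = [[float(64 * i + j) for j in range(64)] for i in range(32)]  # toy 32 x 64 field
    for h, w in ((16, 32), (8, 16), (4, 8)):
        print(f"{h}x{w}: {len(split_tiles(grid, h, w))} regions")
    print(predict_by_regions(lambda tile: tile, grid, 16, 32) == grid)
```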
Table 2 compares VarteX with two representative variables trained with regional split training and with global training, for a 6-hour lead time. The results indicate that a crop size of $16 \times 32$ achieves the best trade-off between performance and training cost (note that the optimal number of splits may depend on the resolution). Smaller crops deteriorate the forecasting performance while bringing only limited further savings in training time and memory. This is because, with patch size 2, the reduction in the number of tokens from 512 (global input) to 128 ($16 \times 32$ crops), which shrinks the attention map quadratically, is large in absolute terms, whereas the further reduction from 128 to 32 ($8 \times 16$ crops) is marginal, and other factors become the bottleneck. In summary, our experiments suggest that VarteX with $N = 2$ and regional split training on $16 \times 32$ crops is the current best practice. Compared with the original ClimaX with global training, it significantly improves ACC and RMSE with about 55% of the model size, 50% of the training time, and 35% of the memory consumption (about 40% of the training time and 25% of the memory of VarteX with global training).
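The token arithmetic behind this argument can be reproduced directly. The sketch assumes patch size 2 (Table 3) and counts the entries of one attention map per forward pass; it illustrates the scaling and is not a model of the measured memory in Table 2. The function name is hypothetical.

```python
def tokens_per_crop(crop_h: int, crop_w: int, patch: int = 2) -> int:
    """Spatial tokens in one crop with non-overlapping patch x patch patches."""
    return (crop_h // patch) * (crop_w // patch)


if __name__ == "__main__":
    crops = {"global": (32, 64), "16x32": (16, 32), "8x16": (8, 16), "4x8": (4, 8)}
    for name, (h, w) in crops.items():
        n = tokens_per_crop(h, w)
        print(f"{name:>6}: {n:3d} tokens, {n * n:6d} attention-map entries")
```

The first split removes about 246,000 attention-map entries per map, while the second removes only about 15,000, consistent with the diminishing savings observed in Table 2.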
5 Conclusions
In this study, we addressed efficient learning over many diverse meteorological variables with a ViT-based model. Inspired by ClimaX, we proposed a new variable aggregation scheme that explicitly extracts several representative variables, significantly improving forecasting performance while reducing model size. Furthermore, we examined region-wise training, which reduces training time and memory cost by a large margin at a modest cost in forecasting scores. While this paper focuses on training on ERA5 from scratch, our results could be applied to building a foundation model by following the large-scale training of ClimaX.
6 Acknowledgments
This work was supported by JST Moonshot R&D Program Grant Number JPMJMS2389.
References
- Bauer et al. (2015) Bauer, P., Thorpe, A., and Brunet, G. The quiet revolution of numerical weather prediction. Nature, 525(7567):47–55, 2015.
- Bi et al. (2023) Bi, K., Xie, L., Zhang, H., Chen, X., Gu, X., and Tian, Q. Accurate medium-range global weather forecasting with 3D neural networks. Nature, 619(7970):533–538, 2023.
- Chen et al. (2023a) Chen, K., Han, T., Gong, J., Bai, L., Ling, F., Luo, J.-J., Chen, X., Ma, L., Zhang, T., Su, R., et al. FengWu: Pushing the skillful global medium-range weather forecast beyond 10 days lead. arXiv preprint arXiv:2304.02948, 2023a.
- Chen et al. (2023b) Chen, L., Zhong, X., Zhang, F., Cheng, Y., Xu, Y., Qi, Y., and Li, H. FuXi: a cascade machine learning forecasting system for 15-day global weather forecast. npj Climate and Atmospheric Science, 6(1):190, 2023b.
- Dosovitskiy et al. (2021) Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., and Houlsby, N. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021.
- Hersbach et al. (2020) Hersbach, H., Bell, B., Berrisford, P., Hirahara, S., Horányi, A., Muñoz-Sabater, J., Nicolas, J., Peubey, C., Radu, R., Schepers, D., et al. The ERA5 global reanalysis. Quarterly Journal of the Royal Meteorological Society, 146(730):1999–2049, 2020.
- Keisler (2022) Keisler, R. Forecasting global weather with graph neural networks. arXiv preprint arXiv:2202.07575, 2022.
- Lam et al. (2023) Lam, R., Sanchez-Gonzalez, A., Willson, M., Wirnsberger, P., Fortunato, M., Alet, F., Ravuri, S., Ewalds, T., Eaton-Rosen, Z., Hu, W., et al. Learning skillful medium-range global weather forecasting. Science, 382(6677):1416–1421, 2023.
- Lynch (2008) Lynch, P. The origins of computer weather prediction and climate modeling. Journal of Computational Physics, 227(7):3431–3444, 2008.
- Man et al. (2023) Man, X., Zhang, C., Feng, J., Li, C., and Shao, J. W-MAE: Pre-trained weather model with masked autoencoder for multi-variable weather forecasting. arXiv preprint arXiv:2304.08754, 2023.
- Nguyen et al. (2023) Nguyen, T., Brandstetter, J., Kapoor, A., Gupta, J., and Grover, A. ClimaX: A foundation model for weather and climate. In ICLR 2023 Workshop on Tackling Climate Change with Machine Learning, 2023.
- Nguyen et al. (2024) Nguyen, T., Shah, R., Bansal, H., Arcomano, T., Madireddy, S., Maulik, R., Kotamarthi, V., Foster, I., and Grover, A. Scaling Transformers for Skillful and Reliable Medium-range Weather Forecasting. In ICLR 2024 Workshop on AI4DifferentialEquations In Science, 2024.
- Ni (2023) Ni, Z. Kunyu: A High-Performing Global Weather Model Beyond Regression Losses. arXiv preprint arXiv:2312.08264, 2023.
- Pathak et al. (2022) Pathak, J., Subramanian, S., Harrington, P., Raja, S., Chattopadhyay, A., Mardani, M., Kurth, T., Hall, D., Li, Z., Azizzadenesheli, K., et al. FourCastNet: A global data-driven high-resolution weather model using adaptive Fourier neural operators. arXiv preprint arXiv:2202.11214, 2022.
- Ramavajjala (2024) Ramavajjala, V. HEAL-ViT: Vision Transformers on a spherical mesh for medium-range weather forecasting. arXiv preprint arXiv:2403.17016, 2024.
- Rasp et al. (2020) Rasp, S., Dueben, P. D., Scher, S., Weyn, J. A., Mouatadid, S., and Thuerey, N. WeatherBench: a benchmark data set for data-driven weather forecasting. Journal of Advances in Modeling Earth Systems, 12(11):e2020MS002203, 2020.
- Scher & Messori (2019) Scher, S. and Messori, G. Weather and climate forecasting with neural networks: using general circulation models (GCMs) with different complexity as a study ground. Geoscientific Model Development, 12(7):2797–2809, 2019.
- Weyn et al. (2019) Weyn, J. A., Durran, D. R., and Caruana, R. Can machines learn to predict weather? Using deep learning to predict gridded 500-hPa geopotential height from historical weather data. Journal of Advances in Modeling Earth Systems, 11(8):2680–2693, 2019.
- Weyn et al. (2021) Weyn, J. A., Durran, D. R., Caruana, R., and Cresswell-Clay, N. Sub-seasonal forecasting with a large ensemble of deep-learning weather prediction models. Journal of Advances in Modeling Earth Systems, 13(7):e2021MS002502, 2021.
Appendix A Experiment details
A.1 Training details
In this study, the experimental conditions follow the ClimaX paper. The two models, VarteX and ClimaX, are trained using an AdamW optimizer with a learning rate of and a weight decay of . The training schedule consists of a linear warmup over 5 epochs followed by 45 epochs of cosine annealing, for a total of 50 epochs. The effective batch size is 128, obtained by accumulating gradients over 4 steps of mini-batches of size 32.
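The schedule and batching described above can be sketched as follows. The base learning rate is left as a parameter because its value is not given here; warming up from 0 and annealing to a minimum of 0 are assumptions, and `lr_at` is a hypothetical helper rather than the ClimaX scheduler.

```python
import math

MICRO_BATCH = 32   # samples per forward/backward pass
ACCUM_STEPS = 4    # gradient-accumulation steps per optimizer update
EFFECTIVE_BATCH = MICRO_BATCH * ACCUM_STEPS  # 128 samples per update


def lr_at(epoch: float, base_lr: float, warmup_epochs: int = 5,
          total_epochs: int = 50, min_lr: float = 0.0) -> float:
    """Linear warmup from 0 to base_lr, then cosine annealing down to min_lr."""
    if epoch < warmup_epochs:
        return base_lr * epoch / warmup_epochs
    progress = min(1.0, (epoch - warmup_epochs) / (total_epochs - warmup_epochs))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    base_lr = 1.0  # placeholder: normalized learning rate, not the paper's value
    for epoch in (0, 2.5, 5, 27.5, 50):
        print(f"epoch {epoch:>4}: lr = {lr_at(epoch, base_lr):.3f}")
```

Accumulating gradients over 4 mini-batches of 32 gives the same update as one batch of 128 while only holding 32 samples in GPU memory at a time.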
A.2 Software and Hardware
We use the ClimaX repository from GitHub (https://github.com/microsoft/ClimaX). All experiments use an NVIDIA RTX A6000 with 48GB of memory.
A.3 Hyperparameters
Table 3 and Table 4 present the hyperparameters of the models used in this experiment. The hyperparameters for ClimaX are adopted from the original paper, and VarteX uses the same values for the parameters it shares with ClimaX. Table 5 lists the meteorological variables contained in the input data.
Table 3: Hyperparameters of VarteX.
Hyperparameter | Value |
---|---|
Default variables | All variables in Table 5 |
Image size | [32, 64] |
Patch size | 2 |
Embedding dimension | 1024 |
Number of ViT blocks | 8 |
Number of attention heads | 16 |
Number of representative variables | 2 |
MLP ratio | 4 |
Prediction depth | 2 |
Hidden dimension in prediction head | 1024 |
Drop path | 0.1 |
Drop rate | 0.1 |
Table 4: Hyperparameters of ClimaX.
Hyperparameter | Value |
---|---|
Default variables | All variables in Table 5 |
Image size | [32, 64] |
Patch size | 2 |
Embedding dimension | 1024 |
Number of ViT blocks | 8 |
Number of attention heads | 16 |
MLP ratio | 4 |
Prediction depth | 2 |
Hidden dimension in prediction head | 1024 |
Drop path | 0.1 |
Drop rate | 0.1 |
Table 5: Meteorological variables in the input data.
Variable name | Abbrev. | Pressure levels (hPa) |
---|---|---|
Land-sea mask | LSM | |
Orography | | |
2 metre temperature | T2m | |
10 metre U wind component | U10 | |
10 metre V wind component | V10 | |
Geopotential | Z | 50, 250, 500, 600, 700, 850, 925 |
U wind component | U | 50, 250, 500, 600, 700, 850, 925 |
V wind component | V | 50, 250, 500, 600, 700, 850, 925 |
Temperature | T | 50, 250, 500, 600, 700, 850, 925 |
Specific humidity | Q | 50, 250, 500, 600, 700, 850, 925 |
Relative humidity | R | 50, 250, 500, 600, 700, 850, 925 |
A.4 Loss function and Metrics
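This section presents the evaluation metrics used in the experiments. Let $\hat{y}_{k,i,j}$ and $y_{k,i,j}$ denote the forecast and the ground truth for the $k$-th test sample at latitude index $i$ and longitude index $j$ of an $H \times W$ grid, and let $K$ denote the total number of test samples. Climatology $C$ is the time average over the entire test set, $C_{i,j} = \frac{1}{K}\sum_{k=1}^{K} y_{k,i,j}$.
Latitude weighting factor
$$L(i) = \frac{\cos(\mathrm{lat}(i))}{\frac{1}{H}\sum_{i'=1}^{H}\cos(\mathrm{lat}(i'))} \tag{3}$$
Latitude-weighted mean squared error (MSE)
$$\mathrm{MSE} = \frac{1}{K H W}\sum_{k=1}^{K}\sum_{i=1}^{H}\sum_{j=1}^{W} L(i)\left(\hat{y}_{k,i,j} - y_{k,i,j}\right)^2 \tag{4}$$
Latitude-weighted root mean squared error (RMSE)
$$\mathrm{RMSE} = \frac{1}{K}\sum_{k=1}^{K}\sqrt{\frac{1}{H W}\sum_{i=1}^{H}\sum_{j=1}^{W} L(i)\left(\hat{y}_{k,i,j} - y_{k,i,j}\right)^2} \tag{5}$$
Latitude-weighted anomaly correlation coefficient (ACC)
$$\mathrm{ACC} = \frac{\sum_{k,i,j} L(i)\,\tilde{\hat{y}}_{k,i,j}\,\tilde{y}_{k,i,j}}{\sqrt{\sum_{k,i,j} L(i)\,\tilde{\hat{y}}_{k,i,j}^{2}\;\sum_{k,i,j} L(i)\,\tilde{y}_{k,i,j}^{2}}} \tag{6}$$
$$\tilde{\hat{y}}_{k,i,j} = \hat{y}_{k,i,j} - C_{i,j}, \qquad \tilde{y}_{k,i,j} = y_{k,i,j} - C_{i,j} \tag{7}$$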
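A small pure-Python sketch of the latitude weights, RMSE, and ACC defined in Eqs. (3), (5), (6), and (7) for a single variable follows. Fields are nested lists of shape K × H × W with latitudes given in degrees; the function names are hypothetical, and a real evaluation would use vectorized arrays.

```python
import math


def lat_weights(lats_deg):
    """Eq. (3): cos(latitude) normalized by its mean over the H latitude rows."""
    cosines = [math.cos(math.radians(lat)) for lat in lats_deg]
    mean = sum(cosines) / len(cosines)
    return [c / mean for c in cosines]


def climatology(targets):
    """Per-grid-point time average of the ground truth over the K test samples."""
    k = len(targets)
    height, width = len(targets[0]), len(targets[0][0])
    return [[sum(t[i][j] for t in targets) / k for j in range(width)] for i in range(height)]


def weighted_rmse(preds, targets, lats_deg):
    """Eq. (5): per-sample latitude-weighted RMSE, averaged over the K samples."""
    w = lat_weights(lats_deg)
    total = 0.0
    for p, t in zip(preds, targets):
        height, width = len(p), len(p[0])
        sq = sum(w[i] * (p[i][j] - t[i][j]) ** 2 for i in range(height) for j in range(width))
        total += math.sqrt(sq / (height * width))
    return total / len(preds)


def weighted_acc(preds, targets, clim, lats_deg):
    """Eqs. (6)-(7): latitude-weighted correlation of anomalies w.r.t. climatology."""
    w = lat_weights(lats_deg)
    num = pred_sq = true_sq = 0.0
    for p, t in zip(preds, targets):
        for i, (p_row, t_row) in enumerate(zip(p, t)):
            for j, (pv, tv) in enumerate(zip(p_row, t_row)):
                pa, ta = pv - clim[i][j], tv - clim[i][j]
                num += w[i] * pa * ta
                pred_sq += w[i] * pa * pa
                true_sq += w[i] * ta * ta
    return num / math.sqrt(pred_sq * true_sq)


if __name__ == "__main__":
    lats = [-45.0, 45.0]  # toy 2 x 2 grid, K = 2 samples
    targets = [[[1.0, 2.0], [3.0, 4.0]], [[2.0, 3.0], [4.0, 5.0]]]
    preds = [[[v + 1.0 for v in row] for row in t] for t in targets]
    print("RMSE:", weighted_rmse(preds, targets, lats))
    print("ACC (perfect forecast):", weighted_acc(targets, targets, climatology(targets), lats))
```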
Appendix B Qualitative evaluation
We show a qualitative evaluation of VarteX predictions for all target variables with a 6-hour lead time, for both global training and regional split training. The first column shows the initial state, the second the ground truth, the third the predicted results, and the fourth the bias between the ground truth and the predictions. Note that these visualizations are for reference only: unlike ClimaX, the models in this experiment were not pre-trained. We expect that better results could be obtained with pre-training as in the ClimaX paper (Nguyen et al., 2023).