This repository contains the official implementation for the paper Oscillatory State-Space Models by T. Konstantin Rusch and Daniela Rus.
For a complete JAX-based SSM library including LinOSS and other state-of-the-art SSM architectures, please check out our linax library.
We propose Linear Oscillatory State-Space models (LinOSS) for efficiently learning on long sequences. Inspired by cortical dynamics of biological neural networks, we base our proposed LinOSS model on a system of forced harmonic oscillators. A stable discretization, integrated over time using fast associative parallel scans, yields the proposed state-space model.
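To illustrate the parallel-scan idea, here is a minimal sketch (not the LinOSS implementation itself) of how a diagonal linear recurrence can be evaluated in parallel with `jax.lax.associative_scan`, the primitive used for fast associative parallel scans. All names and shapes are illustrative only:

```python
import jax
import jax.numpy as jnp

def binary_op(e1, e2):
    # Compose two affine maps (a1, b1) and (a2, b2):
    # applying e1 then e2 gives x -> a2 * (a1 * x + b1) + b2.
    a1, b1 = e1
    a2, b2 = e2
    return a2 * a1, a2 * b1 + b2

def parallel_linear_recurrence(a, b):
    # a, b: arrays of shape (T, d) defining x_t = a_t * x_{t-1} + b_t
    # with x_0 = 0. Returns all states x_1..x_T in O(log T) depth.
    _, x = jax.lax.associative_scan(binary_op, (a, b))
    return x

T, d = 8, 4
a = jnp.full((T, d), 0.9)
b = jnp.ones((T, d))
x = parallel_linear_recurrence(a, b)

# Sanity check against the sequential recurrence.
state = jnp.zeros(d)
x_seq = []
for t in range(T):
    state = a[t] * state + b[t]
    x_seq.append(state)
assert jnp.allclose(x, jnp.stack(x_seq), atol=1e-5)
```

Because the affine-map composition is associative, the scan can be evaluated tree-style in logarithmic depth, which is what makes training on very long sequences efficient.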
This repository is implemented in Python 3.10 and uses JAX as its machine learning framework. It is an extension of https://github.com/Benjamin-Walker/log-neural-cdes.
The code for preprocessing the datasets, training LinOSS, S5, LRU, NCDE, NRDE, and Log-NCDE uses the following packages:
- jax and jaxlib for automatic differentiation.
- equinox for constructing neural networks.
- optax for neural network optimisers.
- diffrax for differential equation solvers.
- signax for calculating the signature.
- sktime for handling time series data in ARFF format.
- tqdm for progress bars.
- matplotlib for plotting.
- pre-commit for code formatting.
conda create -n LinOSS python=3.10
conda activate LinOSS
conda install pre-commit=3.7.1 sktime=0.30.1 tqdm=4.66.4 matplotlib=3.8.4 -c conda-forge
# Substitute the correct JAX pip install for your system: https://jax.readthedocs.io/en/latest/installation.html
pip install -U "jax[cuda12]" "jaxlib[cuda12]" equinox==0.11.4 optax==0.2.2 diffrax==0.5.1 signax==0.1.1
If running data_dir/process_uea.py throws the error No module named 'packaging', then run: pip install packaging
After installing the requirements, run pre-commit install to install the pre-commit hooks.
The folder data_dir contains the scripts for downloading data, preprocessing the data, and creating dataloaders and
datasets. Raw data should be downloaded into the data_dir/raw folder. Processed data should be saved into the data_dir/processed
folder in the following format:
processed/{collection}/{dataset_name}/data.pkl,
processed/{collection}/{dataset_name}/labels.pkl,
processed/{collection}/{dataset_name}/original_idxs.pkl (if the dataset has original data splits)
where data.pkl and labels.pkl are jnp.arrays with shape (n_samples, n_timesteps, n_features) and (n_samples, n_classes) respectively. If the dataset had original_idxs then those should be saved as a list of jnp.arrays with shape [(n_train,), (n_val,), (n_test,)].
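The layout above can be sketched as follows. This is a hypothetical example: toy_collection, toy_dataset, and the array contents are placeholders, and a temporary directory stands in for data_dir/processed:

```python
import pickle
import tempfile
from pathlib import Path

import jax.numpy as jnp

# Stand-in for data_dir/processed; in the repository you would use that path.
root = Path(tempfile.mkdtemp())
out_dir = root / "toy_collection" / "toy_dataset"
out_dir.mkdir(parents=True, exist_ok=True)

n_samples, n_timesteps, n_features, n_classes = 10, 50, 3, 2
data = jnp.zeros((n_samples, n_timesteps, n_features))
labels = jnp.zeros((n_samples, n_classes))  # e.g. one-hot class labels

with open(out_dir / "data.pkl", "wb") as f:
    pickle.dump(data, f)
with open(out_dir / "labels.pkl", "wb") as f:
    pickle.dump(labels, f)

# Reload to confirm the arrays round-trip with the expected shapes.
with open(out_dir / "data.pkl", "rb") as f:
    loaded = pickle.load(f)
```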
The UEA datasets are a collection of multivariate time series classification benchmarks. They can be downloaded by
running data_dir/download_uea.py and preprocessed by running data_dir/process_uea.py.
The PPG-DaLiA dataset is a multivariate time series regression dataset,
where the aim is to predict a person’s heart rate using data
collected from a wrist-worn device. The dataset can be downloaded from the
UCI Machine Learning Repository. The data should be
unzipped and saved in the data_dir/raw folder in the following format: PPG_FieldStudy/S{i}/S{i}.pkl. The data can be
preprocessed by running the process_ppg.py script.
The code for training and evaluating the models is contained in train.py. Experiments can be run using the run_experiment.py script.
This script requires you to specify the names of the models you want to train,
the names of the datasets you want to train on, and a directory which contains configuration files. By default,
it will run the LinOSS experiments. The configuration files should be organised as config_dir/{model_name}/{dataset_name}.json and contain the
following fields:
- seeds: A list of seeds to use for training.
- data_dir: The directory containing the data.
- output_parent_dir: The directory to save the output.
- lr_scheduler: A function which takes the learning rate and returns the new learning rate.
- num_steps: The number of steps to train for.
- print_steps: The number of steps between printing the loss.
- batch_size: The batch size.
- metric: The metric to use for evaluation.
- classification: Whether the task is a classification task.
- linoss_discretization: Only for LinOSS -- which discretization to use. Choices are ['IM', 'IMEX'].
- lr: The initial learning rate.
- time: Whether to include time as a channel.
- Any further model-specific parameters.
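For illustration, a configuration file might look like the sketch below. All values are placeholders (the lr_scheduler representation and hidden_dim are guesses, not taken from the repository); consult the files in experiment_configs for the actual formats used:

```json
{
  "seeds": [0, 1, 2],
  "data_dir": "data_dir",
  "output_parent_dir": "outputs",
  "lr_scheduler": "lambda lr: lr",
  "num_steps": 100000,
  "print_steps": 1000,
  "batch_size": 32,
  "metric": "accuracy",
  "classification": true,
  "linoss_discretization": "IM",
  "lr": 0.001,
  "time": true,
  "hidden_dim": 64
}
```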
See experiment_configs/repeats for examples.
The configuration files for all the experiments with fixed hyperparameters can be found in the experiment_configs folder and
run_experiment.py is currently configured to run the repeat experiments on the UEA datasets.
The outputs folder contains a zip file of the output files from the UEA and PPG experiments.
If you found our work useful in your research, please cite our paper at:
@inproceedings{rusch2025linoss,
title={Oscillatory State-Space Models},
author={Rusch, T Konstantin and Rus, Daniela},
booktitle={International Conference on Learning Representations},
year={2025}
}

(Also consider starring the project on GitHub.)
