Oscillatory State-Space Models (ICLR 2025 Oral)

This repository contains the official implementation for the paper Oscillatory State-Space Models by T. Konstantin Rusch and Daniela Rus.

NEW: For a complete JAX-based SSM library including LinOSS and other state-of-the-art SSM architectures, please check out our linax library.


We propose Linear Oscillatory State-Space models (LinOSS) for efficiently learning on long sequences. Inspired by cortical dynamics of biological neural networks, we base our proposed LinOSS model on a system of forced harmonic oscillators. A stable discretization, integrated over time using fast associative parallel scans, yields the proposed state-space model.
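To illustrate why a linear recurrence can be evaluated with an associative scan, here is a minimal sketch for a scalar recurrence x_k = a_k * x_{k-1} + b_k. It is not the LinOSS implementation: the function names are hypothetical, the scalar setting is a simplification, and the pure-Python loop only mimics the doubling pattern that a parallel primitive such as jax.lax.associative_scan would run concurrently.

```python
def combine(left, right):
    """Compose two affine maps x -> a*x + b, applying `left` first."""
    a1, b1 = left
    a2, b2 = right
    return (a2 * a1, a2 * b1 + b2)


def sequential_states(a, b, x0=0.0):
    """Reference: step through x_k = a_k * x_{k-1} + b_k one step at a time."""
    states, x = [], x0
    for a_k, b_k in zip(a, b):
        x = a_k * x + b_k
        states.append(x)
    return states


def scan_states(a, b, x0=0.0):
    """Same states via an inclusive prefix scan over `combine`.

    Each round doubles the window every element summarises, so only
    about log2(n) rounds are needed; within a round all combinations
    are independent and could run in parallel.
    """
    elems = list(zip(a, b))
    step = 1
    while step < len(elems):
        elems = [
            combine(elems[i - step], elems[i]) if i >= step else elems[i]
            for i in range(len(elems))
        ]
        step *= 2
    # elems[k] now maps x0 directly to x_{k+1}.
    return [a_k * x0 + b_k for a_k, b_k in elems]
```

The scan works because composing affine maps is associative, so the prefix compositions can be grouped in any order.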


Requirements

This repository is implemented in Python 3.10 and uses JAX as its machine learning framework. It extends https://github.com/Benjamin-Walker/log-neural-cdes.

Environment

The code for preprocessing the datasets and training LinOSS, S5, LRU, NCDE, NRDE, and Log-NCDE uses the following packages:

  • jax and jaxlib for automatic differentiation.
  • equinox for constructing neural networks.
  • optax for neural network optimisers.
  • diffrax for differential equation solvers.
  • signax for calculating the signature.
  • sktime for handling time series data in ARFF format.
  • tqdm for progress bars.
  • matplotlib for plotting.
  • pre-commit for code formatting.
conda create -n LinOSS python=3.10
conda activate LinOSS
conda install pre-commit=3.7.1 sktime=0.30.1 tqdm=4.66.4 matplotlib=3.8.4 -c conda-forge
# Substitute the correct JAX pip install for your platform: https://jax.readthedocs.io/en/latest/installation.html
pip install -U "jax[cuda12]" "jaxlib[cuda12]" equinox==0.11.4 optax==0.2.2 diffrax==0.5.1 signax==0.1.1

If running data_dir/process_uea.py fails with the error No module named 'packaging', run pip install packaging.

After installing the requirements, run pre-commit install to install the pre-commit hooks.
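The pinned versions from the install commands above can be checked with a short script. This is a sketch rather than part of the repository: the helper name find_mismatches is hypothetical, and it assumes each package exposes standard distribution metadata under the name shown.

```python
from importlib import metadata

# Versions copied from the conda and pip commands above.
PINNED = {
    "equinox": "0.11.4",
    "optax": "0.2.2",
    "diffrax": "0.5.1",
    "signax": "0.1.1",
    "sktime": "0.30.1",
    "tqdm": "4.66.4",
    "matplotlib": "3.8.4",
    "pre-commit": "3.7.1",
}


def find_mismatches(pinned, get_version=metadata.version):
    """Return {name: (wanted, found)} for every package whose version differs.

    `found` is None when the package is not installed.
    """
    mismatches = {}
    for name, wanted in pinned.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(PINNED).items():
        print(f"{name}: expected {wanted}, found {found or 'not installed'}")
```

jax and jaxlib are left out on purpose, since their versions depend on the platform-specific install.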


Data

The folder data_dir contains the scripts for downloading data, preprocessing the data, and creating dataloaders and datasets. Raw data should be downloaded into the data_dir/raw folder. Processed data should be saved into the data_dir/processed folder in the following format:

processed/{collection}/{dataset_name}/data.pkl, 
processed/{collection}/{dataset_name}/labels.pkl,
processed/{collection}/{dataset_name}/original_idxs.pkl (if the dataset has original data splits)

where data.pkl and labels.pkl are jnp.arrays with shape (n_samples, n_timesteps, n_features) and (n_samples, n_classes), respectively. If the dataset has original data splits, original_idxs.pkl should be a list of jnp.arrays with shapes [(n_train,), (n_val,), (n_test,)].
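The layout and shape rules above can be expressed as a small check to run before saving. This is a sketch under assumptions: processed_paths and check_shapes are hypothetical helpers, not functions from this repository, and the collection and dataset names in the usage note are placeholders.

```python
from pathlib import Path


def processed_paths(processed_dir, collection, dataset_name, has_original_splits=False):
    """Build the expected file paths for one processed dataset."""
    base = Path(processed_dir) / collection / dataset_name
    names = ["data.pkl", "labels.pkl"]
    if has_original_splits:
        names.append("original_idxs.pkl")
    return [base / name for name in names]


def check_shapes(data_shape, labels_shape, split_shapes=None):
    """Raise ValueError if the shapes do not match the documented format."""
    if len(data_shape) != 3:
        raise ValueError("data must have shape (n_samples, n_timesteps, n_features)")
    if len(labels_shape) != 2:
        raise ValueError("labels must have shape (n_samples, n_classes)")
    if data_shape[0] != labels_shape[0]:
        raise ValueError("data and labels disagree on n_samples")
    if split_shapes is not None:
        # Expect [(n_train,), (n_val,), (n_test,)].
        if len(split_shapes) != 3 or any(len(s) != 1 for s in split_shapes):
            raise ValueError("original_idxs must be three 1-D index arrays")
```

For example, with arrays `data` and `labels` in hand, `check_shapes(data.shape, labels.shape)` validates them, and `processed_paths("data_dir/processed", "UEA", "EigenWorms")` lists where the files would go.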

The UEA Datasets

The UEA datasets are a collection of multivariate time series classification benchmarks. They can be downloaded by running data_dir/download_uea.py and preprocessed by running data_dir/process_uea.py.

The PPG-DaLiA Dataset

The PPG-DaLiA dataset is a multivariate time series regression dataset, where the aim is to predict a person’s heart rate using data collected from a wrist-worn device. The dataset can be downloaded from the UCI Machine Learning Repository. The data should be unzipped and saved in the data_dir/raw folder in the following format: PPG_FieldStudy/S{i}/S{i}.pkl. The data can then be preprocessed by running the process_ppg.py script.


Experiments

The code for training and evaluating the models is contained in train.py. Experiments can be run using the run_experiment.py script. This script requires you to specify the names of the models you want to train, the names of the datasets you want to train on, and a directory which contains configuration files. By default, it will run the LinOSS experiments. The configuration files should be organised as config_dir/{model_name}/{dataset_name}.json and contain the following fields:

  • seeds: A list of seeds to use for training.
  • data_dir: The directory containing the data.
  • output_parent_dir: The directory to save the output.
  • lr_scheduler: A function which takes the learning rate and returns the new learning rate.
  • num_steps: The number of steps to train for.
  • print_steps: The number of steps between printing the loss.
  • batch_size: The batch size.
  • metric: The metric to use for evaluation.
  • classification: Whether the task is a classification task.
  • linoss_discretization: ONLY for LinOSS -- which discretization to use. Choices are ['IM', 'IMEX'].
  • lr: The initial learning rate.
  • time: Whether to include time as a channel.
  • Any further specific model parameters.

See experiment_configs/repeats for examples.
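As a rough illustration of the fields listed above, the following sketch writes a configuration file to config_dir/{model_name}/{dataset_name}.json. Every value is a placeholder, the helper names are hypothetical, and the way lr_scheduler is encoded in JSON is an assumption; the files in experiment_configs/repeats are the authoritative reference.

```python
import json
import tempfile
from pathlib import Path

# Field names follow the list above; all values are illustrative placeholders.
example_config = {
    "seeds": [0, 1, 2],
    "data_dir": "data_dir",
    "output_parent_dir": "outputs",
    "lr_scheduler": "lambda lr: lr",  # assumption: encoding of the function is not documented here
    "num_steps": 1000,
    "print_steps": 100,
    "batch_size": 32,
    "metric": "accuracy",  # assumption: accepted metric names are not listed here
    "classification": True,
    "linoss_discretization": "IM",  # LinOSS only: "IM" or "IMEX"
    "lr": 1e-3,
    "time": True,
}


def config_path(config_dir, model_name, dataset_name):
    """Return config_dir/{model_name}/{dataset_name}.json."""
    return Path(config_dir) / model_name / f"{dataset_name}.json"


def write_config(config, config_dir, model_name, dataset_name):
    """Write the config as JSON, creating the model directory if needed."""
    path = config_path(config_dir, model_name, dataset_name)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(config, indent=2))
    return path


# Write into a temporary directory and read the file back.
with tempfile.TemporaryDirectory() as tmp:
    written = write_config(example_config, tmp, "LinOSS", "EigenWorms")
    reloaded = json.loads(written.read_text())
```

Any further model-specific parameters would be added as extra keys in the same dictionary.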


Reproducing the Results

The configuration files for all the experiments with fixed hyperparameters can be found in the experiment_configs folder. run_experiment.py is currently configured to run the repeat experiments on the UEA datasets. The outputs folder contains a zip file of the output files from the UEA and PPG experiments.


Citation

If you found our work useful in your research, please cite our paper:

@inproceedings{rusch2025linoss,
  title={Oscillatory State-Space Models},
  author={Rusch, T Konstantin and Rus, Daniela},
  booktitle={International Conference on Learning Representations},
  year={2025}
}

(Also consider starring the project on GitHub.)
