Skip to content

Commit

Permalink
importing and integration
Browse files Browse the repository at this point in the history
  • Loading branch information
raminmh committed Jun 14, 2021
1 parent 2512766 commit 55b15e0
Show file tree
Hide file tree
Showing 29 changed files with 3,847 additions and 2,450 deletions.
Binary file added .DS_Store
Binary file not shown.
21 changes: 0 additions & 21 deletions LICENSE

This file was deleted.

55 changes: 54 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,56 @@
# Closed Form Solution of Neural ODEs
# Closed-form Continuous-depth Models

This archive is the code supplementary materials of the paper Closed-form Continuous-depth Models

## Requirements

- Python3.6 or newer
- Tensorflow 2.4 or newer
- PyTorch 1.8 or newer
- pytorch-lightning 1.3.0 or newer
- scikit-learn 0.24.2 or newer

## Module description

- ```tf_cfc.py``` Implementation of the CfC (various versions) in Tensorflow 2.x
- ```torch_cfc.py``` Implementation of the CfC (various versions) in PyTorch
- ```train_physio.py``` Trains the CfC models on the Physionet 2012 dataset in PyTorch (code adapted from Rubanova et al. 2019)
- ```train_xor.py``` Trains the CfC models on the XOR dataset in Tensorflow (code adapted from Lechner & Hasani, 2020)
- ```train_imdb.py``` Trains the CfC models on the IMDB dataset in Tensorflow (code adapted from Keras examples website)
- ```train_walker.py``` Trains the CfC models on the Walker2d dataset in Tensorflow (code adapted from Lechner & Hasani, 2020)
- ```irregular_sampled_datasets.py``` Datasets (same splits) from Lechner & Hasani (2020)
- ```duv_physionet.py``` and ```duv_utils.py``` Physionet dataset (same split) from Rubanova et al. (2019)

## Usage

All training scripts except the following three flags

- ```no_gate``` Runs the CfC without the (1-sigmoid) part
- ```minimal``` Runs the CfC direct solution
- ```use_ltc``` Runs an LTC with a semi-implicit ODE solver instead of a CfC
- ```use_mixed``` Mixes the CfC's RNN-state with a LSTM to avoid vanishing gradients

If none of these flags are provided, the full CfC model is used

For instance

```bash
python3 train_physio.py
```

train the full CfC model on the Physionet dataset.

Similarly

```bash
train_walker.py --minimal
```

runs the direct CfC solution on the walker2d dataset.

For downloading the Walker2d dataset of Lechner & Hasani 2020, run

```bash
source download_dataset.sh
```

8 changes: 8 additions & 0 deletions download_dataset.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/bin/bash

mkdir data
mkdir data/person

wget https://pub.ist.ac.at/~mlechner/datasets/walker.zip
unzip walker.zip -d data/
rm walker.zip
37 changes: 0 additions & 37 deletions download_datasets.sh

This file was deleted.

Loading

0 comments on commit 55b15e0

Please sign in to comment.