
xuesongwang/Neural-Process-with-Position-Relevant-Only-Variances


NP-PROV: Neural Processes with Position-Relevant-Only Variances

This repository is the official implementation of NP-PROV: Neural Processes with Position-Relevant-Only Variances.

Requirements

  • Python 3.6 or higher.

  • gcc and gfortran: On OS X, these are both installed with brew install gcc. On Linux, gcc is most likely already available, and gfortran can be installed with apt-get install gfortran.

Install the requirements, and you should be ready to go:

pip install -r requirements.txt
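Before or after installing, it can help to confirm that the interpreter and compilers listed above are available. The helper below is a hypothetical convenience script (check_environment is not part of the repository). It only checks the Python version and whether gcc and gfortran are on the PATH, not whether they are recent enough for the packages in requirements.txt.

```python
import shutil
import sys


def check_environment(min_python=(3, 6), tools=("gcc", "gfortran")):
    """Hypothetical helper: return a list of problems; empty means all checks passed."""
    problems = []
    found = tuple(sys.version_info[:2])
    if found < tuple(min_python):
        problems.append("Python {0}.{1}+ required, found {2}.{3}".format(
            min_python[0], min_python[1], found[0], found[1]))
    for tool in tools:
        # shutil.which returns None when the executable is not on PATH.
        if shutil.which(tool) is None:
            problems.append("{} not found on PATH".format(tool))
    return problems


if __name__ == "__main__":
    issues = check_environment()
    for issue in issues:
        print("WARNING:", issue)
    if not issues:
        print("Python, gcc and gfortran look OK")
```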

Training

Off-the-grid datasets

To train the model(s) for off-the-grid datasets, run this command:

python train_1d.py --name EQ --epochs 200 --learning_rate 3e-4 --weight_decay 1e-5

The first argument, name (default: EQ), specifies the data that the model will be trained on and should be one of the following:

  • EQ: samples from a GP with an exponentiated quadratic (EQ) kernel;
  • matern: samples from a GP with a Matern-5/2 kernel;
  • period: samples from a GP with a weakly-periodic kernel;
  • smart_meter: This dataset is taken from https://github.com/3springs/attentive-neural-processes/tree/RANPfSD/data. To train on smart_meter, change the argument indir of the function get_smartmeter_df in data/smart_meter.py to your own data path.
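The three synthetic datasets consist of functions drawn from GP priors. The sketch below shows one way to generate such an off-the-grid task with NumPy: inputs are sampled uniformly rather than on a fixed grid, a function sample is drawn through a Cholesky factor of the kernel matrix, and the points are split into context and target sets. All names, hyperparameters (lengthscales, period, input range, number of context points) and the particular weakly-periodic form (a periodic kernel multiplied by an EQ kernel) are assumptions for illustration; they are not taken from train_1d.py.

```python
import numpy as np


def eq_kernel(x1, x2, lengthscale=1.0):
    """Exponentiated quadratic (EQ) kernel matrix between 1-D input arrays."""
    r = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * (r / lengthscale) ** 2)


def matern52_kernel(x1, x2, lengthscale=1.0):
    """Matern-5/2 kernel matrix."""
    r = np.abs(x1[:, None] - x2[None, :]) / lengthscale
    sqrt5_r = np.sqrt(5.0) * r
    return (1.0 + sqrt5_r + 5.0 * r ** 2 / 3.0) * np.exp(-sqrt5_r)


def weakly_periodic_kernel(x1, x2, period=1.0, lengthscale=1.0, decay=2.0):
    """Assumed form: a periodic kernel times an EQ kernel with a longer lengthscale."""
    r = x1[:, None] - x2[None, :]
    periodic = np.exp(-2.0 * np.sin(np.pi * r / period) ** 2 / lengthscale ** 2)
    return periodic * eq_kernel(x1, x2, lengthscale=decay)


# Keys mirror the --name values listed above; the mapping itself is hypothetical.
KERNELS = {"EQ": eq_kernel, "matern": matern52_kernel, "period": weakly_periodic_kernel}


def sample_gp_task(name="EQ", num_points=50, num_context=10,
                   x_range=(-2.0, 2.0), jitter=1e-5, seed=None):
    """Draw one GP function at random inputs and split it into context/target sets."""
    rng = np.random.default_rng(seed)
    # Off-the-grid: inputs are sampled at arbitrary locations, not on a fixed grid.
    x = rng.uniform(x_range[0], x_range[1], size=num_points)
    # Small jitter keeps the kernel matrix numerically positive definite.
    cov = KERNELS[name](x, x) + jitter * np.eye(num_points)
    chol = np.linalg.cholesky(cov)
    y = chol @ rng.standard_normal(num_points)
    # Disjoint random context/target split (some NP setups include context in targets).
    perm = rng.permutation(num_points)
    ctx, tgt = perm[:num_context], perm[num_context:]
    return (x[ctx], y[ctx]), (x[tgt], y[tgt])
```

For example, sample_gp_task("matern", seed=0) returns ten context points and forty target points from one Matern-5/2 function.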

On-the-grid datasets

To train the models for on-the-grid datasets, run this command:

python train_2d.py --dataset mnist --batch-size 16 --learning-rate 5e-4 --epochs 100

The first argument, dataset (default: mnist), specifies the data that the model will be trained on and should be one of the following:

  • mnist: This dataset can be downloaded using torchvision.datasets. Change the path MNIST('./MNIST/mnist_data'...) in the function load_dataset in data/image_data.py to the location of your datasets;
  • svhn: This dataset can also be downloaded using torchvision;
  • celebA: This dataset can be downloaded from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html

We split each downloaded training set into training and validation sets in a 7:3 ratio and use the separate test sets for testing.
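A minimal sketch of a 7:3 split, assuming the split is random and made reproducible with a fixed seed (the README does not say how the repository chooses it). split_indices is a hypothetical helper that works on indices only, so it does not depend on any dataset library.

```python
import random


def split_indices(num_items, train_fraction=0.7, seed=0):
    """Hypothetical helper: randomly split range(num_items) into train/validation indices."""
    indices = list(range(num_items))
    # A dedicated Random instance keeps the split reproducible without touching global state.
    random.Random(seed).shuffle(indices)
    num_train = int(round(train_fraction * num_items))
    return indices[:num_train], indices[num_train:]
```

With PyTorch, the two index lists could then be wrapped as torch.utils.data.Subset(train_set, train_idx) and torch.utils.data.Subset(train_set, val_idx), where train_set is, for example, torchvision.datasets.MNIST(root, train=True, download=True).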

Evaluation

Off-the-grid datasets

To evaluate my model on off-the-grid datasets, run:

python eval_1d.py --name EQ

The argument is the same as in train_1d.py. A model file named name + _model.pt (e.g. EQ_model.pt for --name EQ) is loaded from the folder saved_model.

On-the-grid datasets

To evaluate my model on on-the-grid datasets, run:

python eval_2d.py --dataset mnist --batch-size 16

The arguments are the same as in train_2d.py. A model file named dataset + _model.pth.gz (e.g. mnist_model.pth.gz for --dataset mnist) is loaded from the folder saved_model.
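The naming convention for saved models can be captured in a small helper. checkpoint_path and require_checkpoint are hypothetical names: they only reproduce the file names described above (name + _model.pt for off-the-grid models, dataset + _model.pth.gz for on-the-grid models) and fail early with a clear message if training has not produced the file yet. How the evaluation scripts actually deserialize these files is not shown here.

```python
import os


def checkpoint_path(name, on_grid=False, folder="saved_model"):
    """Build the saved-model path: name + '_model.pt' (1-D) or '_model.pth.gz' (2-D)."""
    suffix = "_model.pth.gz" if on_grid else "_model.pt"
    return os.path.join(folder, name + suffix)


def require_checkpoint(name, on_grid=False, folder="saved_model"):
    """Return the checkpoint path, or raise if the trained model file is missing."""
    path = checkpoint_path(name, on_grid=on_grid, folder=folder)
    if not os.path.isfile(path):
        raise FileNotFoundError(
            "expected a trained model at {}; run the training script first".format(path))
    return path
```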

Results

Our model achieves the following log-likelihoods, reported as mean (variance), on the off-the-grid and on-the-grid datasets:

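Off-the-grid datasets:

| Model | EQ | Matern | Period | Smart Meter |
| --- | --- | --- | --- | --- |
| NP-PROV | 2.20 (0.02) | 0.90 (0.03) | -1.00 (0.02) | 2.32 (0.05) |

On-the-grid datasets:

| Model | MNIST | SVHN | celebA | miniImageNet |
| --- | --- | --- | --- | --- |
| NP-PROV | 2.66 (3e-2) | 8.24 (5e-2) | 5.11 (1e-2) | 4.39 (2e-1) |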
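The README does not define exactly how these scores are computed. Under a common convention for neural process models, assumed here rather than confirmed by the repository, each score is the Gaussian predictive log-density averaged over target points, and the mean and variance are taken across repeated evaluation runs. The sketch below implements that convention with NumPy; gaussian_log_likelihood and summarize_runs are hypothetical names, and the repository may normalize differently (per point, per pixel or per task) or use the sample rather than the population variance.

```python
import numpy as np


def gaussian_log_likelihood(y, mean, std):
    """Average of log N(y; mean, std^2) over target points (assumed metric)."""
    y, mean, std = (np.asarray(a, dtype=float) for a in (y, mean, std))
    log_probs = -0.5 * np.log(2.0 * np.pi * std ** 2) - 0.5 * ((y - mean) / std) ** 2
    return float(np.mean(log_probs))


def summarize_runs(per_run_scores):
    """Mean and population variance across evaluation runs, as in 'mean (variance)'."""
    scores = np.asarray(per_run_scores, dtype=float)
    return float(scores.mean()), float(scores.var())
```

Given one score per evaluation run, summarize_runs returns the pair that would be printed in the mean (variance) format used in the tables.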
