Skip to content

Commit

Permalink
Add training and batched inference test for DDPM vs DDIM (huggingface…
Browse files Browse the repository at this point in the history
…#140)

* Add torch_device to the VE pipeline

* Mark the training test with slow
  • Loading branch information
anton-l authored Jul 27, 2022
1 parent c24c3b3 commit 32ab6fc
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
2 changes: 1 addition & 1 deletion __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
get_scheduler,
)
Expand Down
36 changes: 36 additions & 0 deletions training_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,44 @@
import copy
import os
import random

import numpy as np
import torch


def enable_full_determinism(seed: int):
"""
Helper function for reproducible behavior during distributed training. See
- https://pytorch.org/docs/stable/notes/randomness.html for pytorch
"""
# set seed first
set_seed(seed)

# Enable PyTorch deterministic mode. This potentially requires either the environment
# variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
# depending on the CUDA version, so we set them both here
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(True)

# Enable CUDNN deterministic mode
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def set_seed(seed: int):
"""
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
Args:
seed (`int`): The seed to set.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# ^^ safe to call this function even if cuda is not available


class EMAModel:
"""
Exponential Moving Average of models weights
Expand Down

0 comments on commit 32ab6fc

Please sign in to comment.