bug: Experimental/Whisper notebook (speedup.ipynb) is not working #331

Open
1 of 2 tasks
Artyom17 opened this issue Jul 25, 2023 · 0 comments

Artyom17 commented Jul 25, 2023

Description

Experimental/Whisper notebook (speedup.ipynb) is not working.

When the unmodified notebook is run on an RTX 4080/4090 (i.e., with the large-v2 model), the 'optimize' step takes a long time, and at some point it starts printing the following messages:

[2023-07-25 16:24:47,436] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
   function: 'forward' (kernl/.venv/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:564)
   reasons:  ___check_obj_id(past_key_value, 94111367005152)
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.

and then many of these follow:

kernl/.venv/lib/python3.10/site-packages/torch/cuda/graphs.py:79: UserWarning: The CUDA Graph is empty. This ususally means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:191.)
  super().capture_end()

I tried increasing cache_size_limit up to 1024; the only effect is that it takes much longer before the cache-size warning is printed, and the final outcome is the same.
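One plausible reading of the `___check_obj_id(past_key_value, ...)` reason is that each compiled variant is guarded on the identity of the `past_key_value` object, while generation creates a new object at every decoding step. The toy below is not TorchDynamo's implementation; `GuardedCache`, `call`, and `simulate` are hypothetical names used only to illustrate why, under that assumption, a larger limit postpones the warning without removing it.

```python
class GuardedCache:
    """Toy compile cache: one entry per distinct object identity (hypothetical)."""

    def __init__(self, limit: int) -> None:
        self.limit = limit
        self.guarded_ids: set[int] = set()  # identities the "compiled" variants are tied to
        self.compiles = 0
        self.fallbacks = 0

    def call(self, past_key_value: object) -> str:
        key = id(past_key_value)
        if key in self.guarded_ids:
            return "hit"  # guard passes, the cached variant is reused
        if len(self.guarded_ids) >= self.limit:
            self.fallbacks += 1  # cache is full: give up and run unoptimized
            return "fallback"
        self.guarded_ids.add(key)  # guard fails: "recompile" for this object
        self.compiles += 1
        return "compiled"


def simulate(limit: int, steps: int) -> GuardedCache:
    cache = GuardedCache(limit)
    alive = []  # keep every object alive so that ids are never reused
    for step in range(steps):
        past_key_value = [step]  # a fresh object at every decoding step
        alive.append(past_key_value)
        cache.call(past_key_value)
    return cache


for limit in (64, 1024):
    cache = simulate(limit, steps=2000)
    print(f"limit={limit}: compiles={cache.compiles}, fallbacks={cache.fallbacks}")
```

In this toy, both limits end up with every cache slot consumed; the larger limit only means more compilations happen before the fallback path starts, which matches the longer wait described above.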

The final outcome: there is no speedup at all; moreover, the 'optimized' variant is usually slower (evidently because it was not fully optimized).

Steps to reproduce

Repro steps are pretty much the same as in the experimental/whisper/README.md:

DOCKER_BUILDKIT=1 docker build -t kernl .
docker run --rm -it --gpus all -v $(pwd):/kernl kernl
apt install libsndfile1-dev # used by a Python audio dependency
pip install datasets soundfile librosa jupyter notebook
jupyter nbconvert --execute --clear-output experimental/whisper/speedup.ipynb --log-level=10

Or, this script could be used:

import time

import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

from kernl.model_optimization import optimize_model


torch.set_float32_matmul_precision("high")
torch._dynamo.config.cache_size_limit = 64 # 1024
#torch._dynamo.config.dynamic_shapes = True
max_len = 50  # we do not expect more than 50 tokens per audio.
num_beams = 5
model_name = "openai/whisper-large-v2"  # "openai/whisper-tiny"

# audio_dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")  # small dataset for tests
audio_dataset = load_dataset("librispeech_asr", "clean", split="test")


def get_tokens(item: dict[str, dict]) -> torch.Tensor:
    tensor = processor(item["audio"]["array"], return_tensors="pt", sampling_rate=16_000).input_features
    return tensor.cuda()


processor = WhisperProcessor.from_pretrained(model_name)
inputs_warmup = get_tokens(audio_dataset[0])

model = WhisperForConditionalGeneration.from_pretrained(model_name).to("cuda").eval()

MAX_ITER = 100

timings_original = list()
transcriptions = list()
with torch.inference_mode(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"):
    # warmup
    model.generate(inputs_warmup, min_length=max_len, max_length=max_len, num_beams=num_beams, do_sample=False)
    torch.cuda.synchronize()
    i = 0
    for audio in audio_dataset:
        inputs = get_tokens(audio)
        torch.cuda.synchronize()
        start = time.time()
        predicted_ids = model.generate(inputs, min_length=1, max_length=max_len, num_beams=num_beams, do_sample=False)
        torch.cuda.synchronize()
        timings_original.append(time.time() - start)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
        print(f'{i}: {transcription}')
        transcriptions.append(transcription)
        i = i + 1
        if i == MAX_ITER:
            break
    len_audio_dataset = i

assert len_audio_dataset == len(transcriptions) == len(timings_original)

@staticmethod
def fix_reorder_cache(past, beam_idx):
    reordered_past = ()
    for layer_past in past:
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
        )
    return reordered_past


WhisperForConditionalGeneration._reorder_cache = fix_reorder_cache

print('###################################')
# uncomment 2 following lines and comment the third one to use vanilla torch.compile instead of Kernl
# model.model.decoder.forward_original = model.model.decoder.forward
# model.model.decoder.forward = torch.compile(model.model.decoder.forward_original, mode="reduce-overhead")
optimize_model(model.model.decoder)
nb_diff = 0
timings_optimized = list()
with torch.inference_mode(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"):
    start = time.time()
    model.generate(inputs_warmup, min_length=max_len, max_length=max_len, num_beams=num_beams, do_sample=False)
    torch.cuda.synchronize()
    print(f"time to warmup: {(time.time() - start)/60:.2f}min")
    i = 0
    for original_model_transcription, audio in zip(transcriptions, audio_dataset):
        inputs = get_tokens(audio)
        torch.cuda.synchronize()
        start = time.time()
        predicted_ids = model.generate(inputs, min_length=1, max_length=max_len, num_beams=num_beams, do_sample=False)
        torch.cuda.synchronize()
        timings_optimized.append(time.time() - start)
        optimized_transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
        print(f'{i}: {optimized_transcription}')
        nb_diff += original_model_transcription != optimized_transcription
        i = i + 1
        if i == MAX_ITER:
            break

original_mins = sum(timings_original) / 60
optimized_mins = sum(timings_optimized) / 60
speedup = original_mins / optimized_mins
print(f"Kernl speedup: {speedup:.1f}X ({optimized_mins:.1f} VS {original_mins:.1f} min)")
print(f"# different outputs: {nb_diff}/{len_audio_dataset} ({nb_diff / len_audio_dataset * 100:.2f}%)")

print("\nmemory footprint:")
print(f"* allocated: {torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024:.1f}GB")
print(f"* reserved: {torch.cuda.memory_reserved(0) / 1024 / 1024 / 1024:.1f}GB")
print(f"* max reserved: {torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024:.1f}GB")


Expected Behavior

The hope is to see the promised speedup in action.

Full log of running the script above (with just 100 samples):

Found cached dataset librispeech_asr (/home/artyom/.cache/huggingface/datasets/librispeech_asr/clean/2.1.0/cff5df6e7955c80a67f80e27e7e655de71c689e2d2364bece785b972acb37fe7)
[2023-07-25 16:39:39,587] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
   function: 'forward' (/home/artyom/kernl/.venv/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:564)
   reasons:  ___check_obj_id(past_key_value, 94076869202912)
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
[2023-07-25 16:40:23,679] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
   function: 'forward' (/home/artyom/kernl/.venv/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:346)
   reasons:  tensor 'past_key_value[0]' strides mismatch at index 0. expected 2560, actual 3840
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
/home/artyom/kernl/.venv/lib/python3.10/site-packages/torch/cuda/graphs.py:79: UserWarning: The CUDA Graph is empty. This ususally means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:191.)
  super().capture_end()
[the same UserWarning / super().capture_end() pair repeats 63 more times]
[2023-07-25 16:41:03,368] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
   function: '_shape' (/home/artyom/kernl/.venv/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:342)
   reasons:  ___check_obj_id(self, 140211086954640)
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
[2023-07-25 16:41:37,820] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
   function: '__setitem__' (/home/artyom/kernl/.venv/lib/python3.10/site-packages/transformers/utils/generic.py:328)
   reasons:  tensor 'self.past_key_values[0][0]' strides mismatch at index 0. expected 1280, actual 42240
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
###################################
time to warmup: 3.43min
Kernl speedup: 1.1X (1.0 VS 1.1 min)
# different outputs: 0/100 (0.00%)

memory footprint:
* allocated: 6.4GB
* reserved: 10.1GB
* max reserved: 12.1GB

Actual Behavior

Warnings about hitting cache_size_limit and about empty CUDA graphs are printed, and no meaningful speedup is reported (often the 'optimized' version is slower than the original).
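The two `strides mismatch` reasons in the log look consistent with a key/value cache that grows by one token per decoding step. The sketch below computes the strides of a contiguous tensor of shape `(batch, heads, seq_len, head_dim)`. The values of 20 heads and a head dimension of 64 are assumptions about whisper-large-v2 (hidden size 1280), the batch of 5 simply mirrors `num_beams`, and `contiguous_strides` is a made-up helper name.

```python
def contiguous_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Strides (in elements) of a row-major contiguous tensor with this shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


HEADS, HEAD_DIM = 20, 64  # assumed whisper-large-v2 decoder attention layout
BATCH = 5                 # assumed: one row per beam (num_beams = 5)

for seq_len in (1, 2, 3, 33):
    shape = (BATCH, HEADS, seq_len, HEAD_DIM)
    print(seq_len, contiguous_strides(shape)[0])
# prints:
# 1 1280
# 2 2560
# 3 3840
# 33 42240
```

Under these assumptions, the strides reported in the log (2560 vs 3840, and 1280 vs 42240) would correspond to cached sequence lengths of 2, 3, 1 and 33, so a guard on exact strides would fail each time the cache length changes.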

Your environment

  • Ubuntu 22.04 WSL
  • Python 3.10.6
  • Python package manager - pip 22.0.2
Package                  Version      Editable project location
------------------------ ------------ -------------------------
aiohttp                  3.8.5
aiosignal                1.3.1
appdirs                  1.4.4
asttokens                2.2.1
async-timeout            4.0.2
attrs                    23.1.0
audioread                3.0.0
backcall                 0.2.0
black                    23.7.0
certifi                  2023.5.7
cffi                     1.15.1
charset-normalizer       3.2.0
click                    8.1.6
cmake                    3.27.0
datasets                 2.13.1
decorator                5.1.1
dill                     0.3.6
exceptiongroup           1.1.2
executing                1.2.0
filelock                 3.12.2
flake8                   6.0.0
frozenlist               1.4.0
fsspec                   2023.6.0
huggingface-hub          0.16.4
idna                     3.4
iniconfig                2.0.0
ipython                  8.14.0
isort                    5.12.0
jedi                     0.18.2
Jinja2                   3.1.2
joblib                   1.3.1
kernl                    0.2.2        /home/artyom/kernl/src
lazy_loader              0.3
librosa                  0.10.0.post2
lit                      16.0.6
llvmlite                 0.40.1
MarkupSafe               2.1.3
matplotlib-inline        0.1.6
mccabe                   0.7.0
more-itertools           9.1.0
mpmath                   1.3.0
msgpack                  1.0.5
multidict                6.0.4
multiprocess             0.70.14
mypy-extensions          1.0.0
networkx                 3.1
numba                    0.57.1
numpy                    1.24.4
nvidia-cublas-cu11       11.10.3.66
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11        8.5.0.96
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.2.10.91
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusparse-cu11     11.7.4.91
nvidia-nccl-cu11         2.14.3
nvidia-nvtx-cu11         11.7.91
packaging                23.1
pandas                   2.0.3
parso                    0.8.3
pathspec                 0.11.1
pexpect                  4.8.0
pickleshare              0.7.5
pip                      22.0.2
platformdirs             3.9.1
pluggy                   1.2.0
pooch                    1.6.0
prompt-toolkit           3.0.39
ptyprocess               0.7.0
pure-eval                0.2.2
pyarrow                  12.0.1
pycodestyle              2.10.0
pycparser                2.21
pyflakes                 3.0.1
Pygments                 2.15.1
pytest                   7.4.0
python-dateutil          2.8.2
pytz                     2023.3
PyYAML                   6.0.1
regex                    2023.6.3
requests                 2.31.0
safetensors              0.3.1
scikit-learn             1.3.0
scipy                    1.11.1
setuptools               59.6.0
six                      1.16.0
soundfile                0.12.1
soxr                     0.3.5
stack-data               0.6.2
sympy                    1.12
tabulate                 0.9.0
termcolor                2.3.0
threadpoolctl            3.2.0
tokenize-rt              5.1.0
tokenizers               0.13.3
tomli                    2.0.1
torch                    2.0.0
tqdm                     4.65.0
traitlets                5.9.0
transformers             4.31.0
triton                   2.0.0
typing_extensions        4.7.1
tzdata                   2023.3
urllib3                  2.0.4
wcwidth                  0.2.6
wheel                    0.40.0
xxhash                   3.2.0
yarl                     1.9.2

Self-service

  • I would be willing to help fix this bug myself.

Code of Conduct

  • I agree to follow this project's Code of Conduct
Artyom17 changed the title from "bug:" to "bug: Experimental/Whisper notebook (speedup.ipynb) is not working" on Jul 25, 2023