add flax whisper implementation #20479

Merged Feb 20, 2023: 125 commits merged into huggingface:main from andyehrenberg:flax_whisper.

Commits (the diff below shows the changes from the first commit only):
7d3b6ef
add flax whisper implementation
andyehrenberg Nov 28, 2022
a9bed4c
revert change to setup
andyehrenberg Nov 28, 2022
0312993
remove unused imports
andyehrenberg Nov 28, 2022
c71fe4f
revert generation changes
andyehrenberg Nov 29, 2022
828d800
flax whisper docs
andyehrenberg Nov 29, 2022
baafb1c
docs
andyehrenberg Dec 1, 2022
7dba8b5
Merge branch 'huggingface:main' into flax_whisper
andyehrenberg Dec 1, 2022
2da5a58
import order
andyehrenberg Dec 1, 2022
5ee9c1f
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Dec 1, 2022
00f695f
import sorting
andyehrenberg Dec 1, 2022
0ecc03b
isort
andyehrenberg Dec 1, 2022
f66a005
add dummy objects
andyehrenberg Dec 1, 2022
175f344
doc formatting
andyehrenberg Dec 1, 2022
3329e6c
formatting
andyehrenberg Dec 1, 2022
c05089b
remove trailing whitespaces
andyehrenberg Dec 1, 2022
7551181
fix flax whisper docs
andyehrenberg Dec 1, 2022
153f2cb
Merge branch 'huggingface:main' into flax_whisper
andyehrenberg Dec 1, 2022
e255a97
add generation logic to unlock flax whisper
andyehrenberg Dec 2, 2022
f8009d7
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Dec 2, 2022
d003074
remove scans
andyehrenberg Dec 2, 2022
ba8a358
give credits to Flax Bart implementation
andyehrenberg Dec 2, 2022
9f4578d
remove unused imports
andyehrenberg Dec 2, 2022
be33fbd
add license
andyehrenberg Dec 2, 2022
8b1338b
remove assert
andyehrenberg Dec 2, 2022
c567f79
more credits to Bart
andyehrenberg Dec 2, 2022
fbe4e25
fix style
andyehrenberg Dec 2, 2022
cde5afd
formatting
andyehrenberg Dec 2, 2022
6aeb8c8
support left padding
andyehrenberg Dec 2, 2022
ec9ca19
add flax whisper generation test
andyehrenberg Dec 5, 2022
8bce923
Merge branch 'huggingface:main' into flax_whisper
andyehrenberg Dec 6, 2022
3f902f6
remove copied from comments whenever not a full copy
andyehrenberg Dec 7, 2022
3fd0a7c
fix docstrings for logits processors
andyehrenberg Dec 7, 2022
abc14a1
revert change to FlaxForceTokensLogitsProcessor
andyehrenberg Dec 7, 2022
d784a23
revert doc changes
andyehrenberg Dec 7, 2022
3dd8282
improve generation docs
andyehrenberg Dec 7, 2022
77fce32
reorganize
andyehrenberg Dec 7, 2022
fefefde
formatting
andyehrenberg Dec 7, 2022
04ad651
cleanup docs
andyehrenberg Dec 7, 2022
14e19c0
add tests
andyehrenberg Dec 7, 2022
cf67b38
handle empty list case
andyehrenberg Dec 7, 2022
3de7509
fix forced decoder ids in flax tests
andyehrenberg Dec 8, 2022
1077588
Merge branch 'huggingface:main' into flax_whisper
andyehrenberg Dec 9, 2022
5e2256a
add flax whisper to inits
andyehrenberg Dec 12, 2022
ada32b8
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Dec 12, 2022
669db4e
update dummy objects
andyehrenberg Dec 12, 2022
bea6cf0
docs for FlaxAutoModelForSpeechSeq2Seq
andyehrenberg Dec 12, 2022
e4270b4
fix decoder_position_ids computation in pretrained model decode/__cal…
andyehrenberg Dec 14, 2022
135b634
add Copied from statements as necessary
andyehrenberg Dec 15, 2022
21fe767
compute position_ids only in __call__ and decode methods of pretraine…
andyehrenberg Dec 16, 2022
a901674
improve readability of compute positional embeddings
andyehrenberg Dec 16, 2022
f8d4686
check dimensionality of input_features instead of hidden_states
andyehrenberg Dec 16, 2022
b407611
copied from statement for init_cache
andyehrenberg Dec 16, 2022
8e78c86
formatting
andyehrenberg Dec 16, 2022
810358c
fix copies
andyehrenberg Dec 16, 2022
b06a6ba
fix copies
andyehrenberg Dec 16, 2022
45efd60
pass attention mask to encoder layers
andyehrenberg Dec 21, 2022
718f53b
fix decoder module outputs
andyehrenberg Dec 21, 2022
07a24a8
set dtype
andyehrenberg Dec 22, 2022
43c4ed8
smaller flax model for whisper test
andyehrenberg Dec 22, 2022
ecaac58
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Dec 22, 2022
7b35907
Update src/transformers/generation/flax_utils.py
andyehrenberg Dec 31, 2022
8a4d990
Update src/transformers/models/whisper/modeling_flax_whisper.py
andyehrenberg Dec 31, 2022
17c22fe
Update tests/models/whisper/test_modeling_flax_whisper.py
andyehrenberg Dec 31, 2022
8c021ae
cleanup
andyehrenberg Dec 31, 2022
2aed9af
Update src/transformers/models/whisper/modeling_flax_whisper.py
andyehrenberg Dec 31, 2022
64da8fa
bias cleanup
andyehrenberg Dec 31, 2022
6fc7404
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Dec 31, 2022
618f85b
doc fix
andyehrenberg Dec 31, 2022
8b56bf4
align style for force tokens processor
andyehrenberg Jan 2, 2023
209834d
readability
andyehrenberg Jan 3, 2023
fac30a0
fix input shape in tests
andyehrenberg Jan 3, 2023
aa87c98
revert FlaxGenerationMixin docstring
andyehrenberg Jan 3, 2023
23af05b
formatting
andyehrenberg Jan 3, 2023
b8086b6
fix tests
andyehrenberg Jan 3, 2023
acef3e0
fix imports
andyehrenberg Jan 3, 2023
da1df33
consistent encoder hidden states
andyehrenberg Jan 3, 2023
4cdba95
consistent hidden states
andyehrenberg Jan 3, 2023
dd7473b
input shapes
andyehrenberg Jan 3, 2023
c5621f7
typo
andyehrenberg Jan 3, 2023
46aec12
partial class trick
andyehrenberg Jan 3, 2023
a003616
partial class for input shape
andyehrenberg Jan 3, 2023
a9604a5
base_class with correct input shape
andyehrenberg Jan 3, 2023
5120afe
partial base classes
andyehrenberg Jan 3, 2023
c6b1ae4
match by name
andyehrenberg Jan 3, 2023
4c239fc
set main_input_name
andyehrenberg Jan 4, 2023
279ceb6
compare on names
andyehrenberg Jan 4, 2023
b81630e
Merge branch 'main' into flax_whisper
andyehrenberg Jan 9, 2023
797fab1
formatting
andyehrenberg Jan 9, 2023
f3173d8
remove unused import
andyehrenberg Jan 9, 2023
b4696ca
safer position ids computation
andyehrenberg Jan 10, 2023
1c11ca6
safer position id computation
andyehrenberg Jan 10, 2023
c128fd8
Update src/transformers/models/whisper/modeling_flax_whisper.py
andyehrenberg Jan 18, 2023
2ae5b08
Update src/transformers/models/whisper/modeling_flax_whisper.py
andyehrenberg Jan 18, 2023
48583bd
remove identical inherited tests
andyehrenberg Jan 18, 2023
c93232f
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Jan 18, 2023
1c18f61
fix prompt ids in tests
andyehrenberg Jan 18, 2023
c3b1d34
use generation config
andyehrenberg Jan 18, 2023
bf15d5f
use jnp array
andyehrenberg Jan 18, 2023
c5fc14b
better var names
andyehrenberg Jan 18, 2023
161cb8a
more explicit bias use
andyehrenberg Jan 18, 2023
d9cedb9
Merge branch 'main' into flax_whisper
andyehrenberg Jan 18, 2023
bb9d0af
import transformers
andyehrenberg Jan 18, 2023
f1d90d2
formatting
andyehrenberg Jan 18, 2023
733ae2b
test formatting
andyehrenberg Jan 18, 2023
6295691
remove unused imports
andyehrenberg Jan 18, 2023
902555e
remove unused imports
andyehrenberg Jan 18, 2023
cba4942
formatting
andyehrenberg Jan 18, 2023
0173945
isort
andyehrenberg Jan 18, 2023
48640e5
docs
andyehrenberg Jan 18, 2023
1daee2b
fix ln orders for encoder hidden states
andyehrenberg Jan 26, 2023
fdb0a61
Merge branch 'main' into flax_whisper
andyehrenberg Feb 3, 2023
632c4be
whisper unique generation stuff
andyehrenberg Feb 3, 2023
95403d6
Merge branch 'flax_whisper' of github.com:andyehrenberg/transformers …
andyehrenberg Feb 3, 2023
c5c3ac1
flake
andyehrenberg Feb 3, 2023
907905f
use finfo for attention bias
andyehrenberg Feb 3, 2023
9dbcda8
docs
andyehrenberg Feb 3, 2023
d36cd2c
Update src/transformers/generation/flax_utils.py
andyehrenberg Feb 14, 2023
ab01cfc
docs
andyehrenberg Feb 14, 2023
62d172a
add timestamp flax test
andyehrenberg Feb 14, 2023
455b8bf
jit for timestamps
andyehrenberg Feb 14, 2023
89658d0
formatting
andyehrenberg Feb 14, 2023
a75fd03
clean up timestamps processor
andyehrenberg Feb 15, 2023
758d56c
formatting
andyehrenberg Feb 15, 2023
f9ac652
remove if_true
andyehrenberg Feb 17, 2023
94a526e
cleanup
andyehrenberg Feb 17, 2023
Commit 7d3b6ef3ac10feecb29a7a4a4f26325856f2d782: add flax whisper implementation
andyehrenberg committed Nov 28, 2022
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
@@ -354,7 +354,7 @@ Flax), PyTorch, and/or TensorFlow.
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
| Wav2Vec2-Conformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
-| Whisper | ✅ | ❌ | ✅ | ✅ | ❌ |
+| Whisper | ✅ | ❌ | ✅ | ✅ | ✅ |
| X-CLIP | ❌ | ❌ | ✅ | ❌ | ❌ |
| XGLM | ✅ | ✅ | ✅ | ✅ | ✅ |
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
11 changes: 11 additions & 0 deletions docs/source/en/model_doc/whisper.mdx
@@ -79,3 +79,14 @@ The original code can be found [here](https://github.com/openai/whisper).

[[autodoc]] TFWhisperForConditionalGeneration
- call


## FlaxWhisperModel

[[autodoc]] FlaxWhisperModel
- __call__

## FlaxWhisperForConditionalGeneration

[[autodoc]] FlaxWhisperForConditionalGeneration
- __call__
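
For orientation, here is a minimal usage sketch of the two new classes documented above (not part of this diff; it assumes the `openai/whisper-tiny` checkpoint and a 16 kHz float waveform `raw_audio`):

```python
# Minimal sketch of the Flax Whisper API added in this PR.
# `raw_audio` is assumed to be a 16 kHz mono waveform (e.g. a numpy array).
from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = FlaxWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny", from_pt=True  # convert PyTorch weights if no Flax weights exist
)

# input_features is a log-mel spectrogram of shape (batch, num_mel_bins, 3000)
inputs = processor(raw_audio, sampling_rate=16000, return_tensors="np")
output_ids = model.generate(inputs["input_features"]).sequences
print(processor.batch_decode(output_ids, skip_special_tokens=True))
```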
6 changes: 3 additions & 3 deletions setup.py
@@ -159,7 +159,7 @@
"starlette",
"tensorflow-cpu>=2.4,<2.11",
"tensorflow>=2.4,<2.11",
"tensorflow-text",
#"tensorflow-text",
"tf2onnx",
"timeout-decorator",
"timm",
@@ -247,8 +247,8 @@ def run(self):
extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic", "sudachipy", "sudachidict_core", "pyknp")
extras["sklearn"] = deps_list("scikit-learn")

extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text")
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text")
extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx")
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx")

extras["torch"] = deps_list("torch")
extras["accelerate"] = deps_list("accelerate")
12 changes: 12 additions & 0 deletions src/transformers/__init__.py
@@ -3223,6 +3223,13 @@
_import_structure["models.wav2vec2"].extend(
["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
)
_import_structure["models.whisper"].extend(
[
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
]
)
_import_structure["models.xglm"].extend(
[
"FlaxXGLMForCausalLM",
@@ -5872,6 +5879,11 @@
FlaxWav2Vec2Model,
FlaxWav2Vec2PreTrainedModel,
)
from .models.whisper import (
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
FlaxWhisperPreTrainedModel,
)
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
from .models.xlm_roberta import (
FlaxXLMRobertaForMaskedLM,
45 changes: 44 additions & 1 deletion src/transformers/generation/flax_logits_process.py
@@ -259,10 +259,53 @@ def __init__(self, min_length: int, eos_token_id: int):
self.eos_token_id = eos_token_id

def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:

# create boolean flag to decide if min length penalty should be applied
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)

scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)

return scores
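
The `apply_penalty` flag above is a branchless way to gate the EOS mask: a Python `if` on a traced value would break `jax.jit`, so the condition is computed arithmetically and routed through `jnp.where`. A standalone sketch of the arithmetic:

```python
import jax.numpy as jnp

# apply_penalty is 1 while cur_len <= min_length (EOS masked) and 0 afterwards,
# with no Python branching, so the processor stays jit-traceable.
min_length = 5
for cur_len in (3, 5, 6):
    apply_penalty = 1 - jnp.clip(cur_len - min_length, 0, 1)
    print(cur_len, int(apply_penalty))  # 3 -> 1, 5 -> 1, 6 -> 0
```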


class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor):
r"""
[`FlaxSuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not
sampled at the beginning of the generation.
"""

def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_index = begin_index

def __call__(self, input_ids, scores, cur_len: int):
if input_ids.shape[1] == self.begin_index:
scores = scores.at[:, self.begin_suppress_tokens].set(-float("inf"))

return scores


class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor):
def __init__(self, suppress_tokens: list):
self.suppress_tokens = list(suppress_tokens)

def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
scores = scores.at[..., self.suppress_tokens].set(-float("inf"))

return scores
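
A toy demonstration of the two suppression processors (a sketch with fake logits; the import path follows this file's location in the tree):

```python
import jax.numpy as jnp
from transformers.generation.flax_logits_process import (
    FlaxSuppressTokensAtBeginLogitsProcessor,
    FlaxSuppressTokensLogitsProcessor,
)

scores = jnp.zeros((1, 6))    # batch of 1, toy vocab of 6
input_ids = jnp.array([[4]])  # one decoder token generated so far

always = FlaxSuppressTokensLogitsProcessor(suppress_tokens=[1, 4])
print(always(input_ids, scores, cur_len=1))    # tokens 1 and 4 -> -inf at every step

at_begin = FlaxSuppressTokensAtBeginLogitsProcessor([3], begin_index=1)
print(at_begin(input_ids, scores, cur_len=1))  # token 3 -> -inf only while length == begin_index
```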


class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor):
r"""This processor can be used to force a list of tokens. The processor will set their log probs to `inf` so that they
are sampled at their corresponding index."""

def __init__(self, force_token_map):
self.force_token_map = dict(force_token_map)

def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int):
generation_idx = input_ids.shape[-1]
current_token = self.force_token_map.get(generation_idx, None)
if current_token is not None:
scores = scores.at[:, :].set(-float("inf"))
scores = scores.at[:, current_token].set(0)
return scores
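
For example, Whisper forces its language and task tokens at fixed generation indices. A sketch of the processor in isolation (token ids are illustrative; 50259/50359 stand in for Whisper's `<|en|>` and `<|transcribe|>`):

```python
import jax.numpy as jnp
from transformers.generation.flax_logits_process import FlaxForceTokensLogitsProcessor

# force_token_map maps generation index -> token id that must be produced there.
processor = FlaxForceTokensLogitsProcessor([[1, 50259], [2, 50359]])

input_ids = jnp.array([[50258]])  # one token generated so far -> next index is 1
scores = jnp.zeros((1, 51865))    # toy Whisper-sized logits
out = processor(input_ids, scores, cur_len=1)
# At index 1, every token is -inf except 50259, whose log prob is 0.
assert out[0, 50259] == 0.0 and out[0, 0] == -float("inf")
```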
99 changes: 94 additions & 5 deletions src/transformers/generation/flax_utils.py
@@ -18,7 +18,7 @@
import inspect
import warnings
from functools import partial
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

import numpy as np

@@ -36,8 +36,11 @@
from .flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxForceTokensLogitsProcessor,
FlaxLogitsProcessorList,
FlaxMinLengthLogitsProcessor,
FlaxSuppressTokensAtBeginLogitsProcessor,
FlaxSuppressTokensLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
@@ -155,6 +158,35 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, mode
model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
return model_kwargs

def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
decoder_start_token_id: int = None,
bos_token_id: int = None,
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
) -> jnp.ndarray:
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
# Only use this arg if not None, otherwise just remove from model_kwargs
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
if decoder_input_ids is not None:
return decoder_input_ids
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
return jnp.array(decoder_start_token_id).reshape(1, -1).repeat(batch_size, axis=0)

def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id

if decoder_start_token_id is not None:
return decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
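
Taken together, these two helpers resolve a start token (explicit argument first, then `config.decoder_start_token_id`, then `config.bos_token_id`) and broadcast it across the batch. A sketch of the broadcasting step:

```python
import jax.numpy as jnp

# One start token per batch element, shape (batch_size, 1): exactly what the
# decoder expects as its first input during generation.
batch_size, decoder_start_token_id = 3, 50258
decoder_input_ids = jnp.array(decoder_start_token_id).reshape(1, -1).repeat(batch_size, axis=0)
print(decoder_input_ids.shape)  # (3, 1)
```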

@staticmethod
def _expand_to_num_beams(tensor, num_beams):
return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
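
`_expand_to_num_beams` (shown above for context) tiles each batch row across a new beam axis; for instance:

```python
import jax.numpy as jnp

tensor = jnp.arange(6).reshape(2, 3)  # (batch=2, seq=3)
num_beams = 4
expanded = jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
print(expanded.shape)  # (2, 4, 3): batch, beams, seq
```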
@@ -227,6 +259,9 @@ def generate(
min_length: Optional[int] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
suppress_tokens: Optional[List[int]] = None,
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[int]] = None,
length_penalty: Optional[float] = None,
early_stopping: Optional[bool] = None,
trace: bool = True,
@@ -334,12 +369,19 @@
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)

batch_size = input_ids.shape[0]

if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
if model_kwargs.get("encoder_outputs") is None:
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
# prepare decoder_input_ids for generation
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=decoder_start_token_id,
bos_token_id=bos_token_id,
model_kwargs=model_kwargs,
)

# Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
@@ -382,7 +424,16 @@

if not do_sample and num_beams == 1:
logits_processor = self._get_logits_processor(
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
no_repeat_ngram_size,
min_length,
max_length,
eos_token_id,
forced_bos_token_id,
forced_eos_token_id,
input_ids_seq_length,
suppress_tokens=suppress_tokens,
begin_suppress_tokens=begin_suppress_tokens,
forced_decoder_ids=forced_decoder_ids,
)
return self._greedy_search(
input_ids,
@@ -397,7 +448,16 @@
elif do_sample and num_beams == 1:
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
logits_processor = self._get_logits_processor(
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
no_repeat_ngram_size,
min_length,
max_length,
eos_token_id,
forced_bos_token_id,
forced_eos_token_id,
input_ids_seq_length,
suppress_tokens=suppress_tokens,
begin_suppress_tokens=begin_suppress_tokens,
forced_decoder_ids=forced_decoder_ids,
)
return self._sample(
input_ids,
@@ -426,7 +486,16 @@
)

logits_processor = self._get_logits_processor(
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
no_repeat_ngram_size,
min_length,
max_length,
eos_token_id,
forced_bos_token_id,
forced_eos_token_id,
input_ids_seq_length,
suppress_tokens=suppress_tokens,
begin_suppress_tokens=begin_suppress_tokens,
forced_decoder_ids=forced_decoder_ids,
)

return self._beam_search(
@@ -478,6 +547,10 @@ def _get_logits_processor(
eos_token_id: int,
forced_bos_token_id: int,
forced_eos_token_id: int,
input_ids_seq_length: int,
suppress_tokens: Optional[List[int]] = None,
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[int]] = None,
) -> FlaxLogitsProcessorList:
"""
This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`]
Expand All @@ -496,6 +569,12 @@ def _get_logits_processor(
forced_eos_token_id = (
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
)
suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens
begin_suppress_tokens = (
begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens
)
if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"):
forced_decoder_ids = self.config.forced_decoder_ids

# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
Expand All @@ -505,6 +584,16 @@ def _get_logits_processor(
processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))
if forced_eos_token_id is not None:
processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
if suppress_tokens is not None:
processors.append(FlaxSuppressTokensLogitsProcessor(suppress_tokens))
if begin_suppress_tokens is not None:
begin_index = input_ids_seq_length
begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1
if forced_decoder_ids is not None:
begin_index += forced_decoder_ids[-1][0] # generation starts after the last token that is forced
processors.append(FlaxSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index))
if forced_decoder_ids is not None:
processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids))
return processors
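
The `begin_index` arithmetic above decides when `begin_suppress_tokens` fire: suppression must wait until the first step that is not forced. A worked sketch for a typical Whisper decode (token ids illustrative):

```python
# Prompt is just the decoder start token, so input_ids_seq_length == 1,
# and tokens are forced at generation indices 1 and 2.
input_ids_seq_length = 1
forced_bos_token_id = None
forced_decoder_ids = [[1, 50259], [2, 50359]]

begin_index = input_ids_seq_length                    # 1
if input_ids_seq_length <= 1 and forced_bos_token_id is not None:
    begin_index += 1                                  # not taken here
begin_index += forced_decoder_ids[-1][0]              # 1 + 2 = 3
print(begin_index)  # begin_suppress_tokens are masked at the first unforced step
```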

def _greedy_search(
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -53,6 +53,7 @@
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("vit", "FlaxViTModel"),
("wav2vec2", "FlaxWav2Vec2Model"),
("whisper", "FlaxWhisperModel"),
("xglm", "FlaxXGLMModel"),
("xlm-roberta", "FlaxXLMRobertaModel"),
]
@@ -73,6 +74,7 @@
("roformer", "FlaxRoFormerForMaskedLM"),
("t5", "FlaxT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
("whisper", "FlaxWhisperForConditionalGeneration"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
]
)
@@ -208,6 +210,7 @@
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
("whisper", "FlaxWhisperForConditionalGeneration"),
]
)
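
With this mapping entry, the speech-seq2seq auto class resolves Whisper checkpoints directly (a sketch, assuming the checkpoint's weights can be loaded into Flax):

```python
from transformers import FlaxAutoModelForSpeechSeq2Seq

# The auto class looks up "whisper" in the mapping above and instantiates
# FlaxWhisperForConditionalGeneration; from_pt converts PyTorch weights if needed.
model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny", from_pt=True)
print(type(model).__name__)  # FlaxWhisperForConditionalGeneration
```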

33 changes: 32 additions & 1 deletion src/transformers/models/whisper/__init__.py
@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)


_import_structure = {
@@ -54,6 +60,19 @@
"TFWhisperPreTrainedModel",
]

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_whisper"] = [
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig, WhisperOnnxConfig
from .feature_extraction_whisper import WhisperFeatureExtractor
@@ -86,6 +105,18 @@
TFWhisperPreTrainedModel,
)

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_whisper import (
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
FlaxWhisperPreTrainedModel,
)

else:
import sys
