Commit 770967a

Merge branch 'main' into llm-prompt-learning-improvements

vadam5 committed Sep 26, 2022
2 parents d23bf6c + e3ac280

Showing 27 changed files with 114 additions and 116 deletions.

2 changes: 1 addition & 1 deletion ci.groovy

@@ -15,7 +15,7 @@ spec:
       path: /vol/scratch1/scratch.okuchaiev_blossom
   containers:
   - name: cuda
-    image: nvcr.io/nvidia/pytorch:22.05-py3
+    image: nvcr.io/nvidia/pytorch:22.08-py3
     command:
     - cat
     volumeMounts:

2 changes: 2 additions & 0 deletions nemo/collections/asr/metrics/multi_binary_acc.py

@@ -67,6 +67,8 @@ def validation_epoch_end(self, outputs):
         F1 score calculated from the predicted value and binarized target values.
     """

+    full_state_update = False
+
     def __init__(self, dist_sync_on_step=False):
         super().__init__(dist_sync_on_step=dist_sync_on_step)
         self.total_correct_counts = 0

2 changes: 2 additions & 0 deletions nemo/collections/common/metrics/classification_accuracy.py

@@ -59,6 +59,8 @@ def validation_epoch_end(self, outputs):
         accuracy, compute acc=correct_count/total_count
     """

+    full_state_update = True
+
     def __init__(self, top_k=None, dist_sync_on_step=False):
         super().__init__(dist_sync_on_step=dist_sync_on_step)


2 changes: 2 additions & 0 deletions nemo/collections/common/metrics/global_average_loss_metric.py

@@ -42,6 +42,8 @@ class GlobalAverageLossMetric(Metric):
            values of :meth:`update` method ``loss`` argument has to be a sum of losses. default: ``True``
    """

+    full_state_update = True
+
     def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, take_avg_loss=True):
         super().__init__(
             compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group

2 changes: 2 additions & 0 deletions nemo/collections/common/metrics/perplexity.py

@@ -41,6 +41,8 @@ class Perplexity(Metric):
            ``probs`` last dim has to be valid probability distribution.
    """

+    full_state_update = True
+
     def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, validate_args=True):
         super().__init__(
             compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group

2 changes: 1 addition & 1 deletion nemo/collections/common/tokenizers/column_coder.py

@@ -187,7 +187,7 @@ def __init__(
         if transform == 'yeo-johnson':
             self.scaler = PowerTransformer(standardize=True)
         elif transform == 'quantile':
-            self.scaler = QuantileTransformer(output_distribution='uniform')
+            self.scaler = QuantileTransformer(output_distribution='uniform', n_quantiles=100)
         elif transform == 'robust':
             self.scaler = RobustScaler()
         else:
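
Note: a hedged sketch of what pinning n_quantiles avoids, assuming scikit-learn's QuantileTransformer (whose default is n_quantiles=1000): when a column has fewer rows than quantiles, fit() warns and clamps n_quantiles to n_samples, so a fixed 100 keeps small tabular columns quiet and deterministic.

    import numpy as np
    from sklearn.preprocessing import QuantileTransformer

    X = np.random.rand(200, 1)  # a small column: fewer samples than the default 1000 quantiles
    scaler = QuantileTransformer(output_distribution='uniform', n_quantiles=100)
    X_uniform = scaler.fit_transform(X)  # fits without the "n_quantiles > n_samples" warning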

@@ -333,12 +333,23 @@ def __init__(
         self.g2p = g2p

     def encode(self, text):
-        """See base class."""
-        ps, space, tokens = [], self.tokens[self.space], set(self.tokens)
+        """See base class for more information."""

         text = self.text_preprocessing_func(text)
         g2p_text = self.g2p(text)  # TODO: handle infer
+        return self.encode_from_g2p(g2p_text, text)
+
+    def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None):
+        """
+        Encodes text that has already been run through G2P.
+        Called for encoding to tokens after text preprocessing and G2P.
+        Args:
+            g2p_text: G2P's output, could be a mixture of phonemes and graphemes,
+                e.g. "see OOV" -> ['S', 'IY1', ' ', 'O', 'O', 'V']
+            raw_text: original raw input
+        """
+        ps, space, tokens = [], self.tokens[self.space], set(self.tokens)
         for p in g2p_text:  # noqa
             # Remove stress
             if p.isalnum() and len(p) == 3 and not self.stresses:
@@ -355,9 +366,10 @@ def encode(self, text):
                 ps.append(p)
             # Warn about unknown char/phoneme
             elif p != space:
-                logging.warning(
-                    f"Text: [{''.join(g2p_text)}] contains unknown char/phoneme: [{p}]. Original text: [{text}]. Symbol will be skipped."
-                )
+                message = f"Text: [{''.join(g2p_text)}] contains unknown char/phoneme: [{p}]."
+                if raw_text is not None:
+                    message += f"Original text: [{raw_text}]. Symbol will be skipped."
+                logging.warning(message)

         # Remove trailing spaces
         if ps:
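
Note: a hypothetical usage sketch of the refactor above (encode_text and its tokenizer argument are stand-ins for an instance of the patched class, not part of this commit). encode() keeps the full text-to-token path, while callers that already hold G2P output can enter at encode_from_g2p(); passing raw_text keeps the unknown-symbol warning informative.

    from typing import List

    def encode_text(tokenizer, text: str) -> List[int]:
        # Equivalent to tokenizer.encode(text) after the refactor: run G2P once,
        # then enter at encode_from_g2p with the raw text kept for warnings.
        g2p_out = tokenizer.g2p(text)  # e.g. "see OOV" -> ['S', 'IY1', ' ', 'O', 'O', 'V']
        return tokenizer.encode_from_g2p(g2p_out, raw_text=text)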

@@ -132,7 +132,7 @@ def write_longs(f, a):
     f.write(np.array(a, dtype=np.int64))


-dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float, 7: np.double, 8: np.uint16}
+dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float64, 7: np.double, 8: np.uint16}


 def code(dtype):
@@ -293,7 +293,7 @@ def __getitem__(self, idx):


 class IndexedDatasetBuilder(object):
-    element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, np.float: 4, np.double: 8}
+    element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, np.float64: 4, np.double: 8}

     def __init__(self, out_file, dtype=np.int32):
         self.out_file = open(out_file, 'wb')

@@ -32,7 +32,7 @@

 __all__ = ["KNNIndex", "MMapRetrievalIndexedDataset", "MMapRetrievalIndexedDatasetBuilder"]


-dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float, 7: np.double, 8: np.uint16}
+dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float64, 7: np.double, 8: np.uint16}


 def code(dtype):
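
Note on the np.float -> np.float64 edits in the two dataset modules above (a NumPy fact, not specific to this repo): np.float was only an alias for the builtin float, i.e. 64-bit, was deprecated in NumPy 1.20, and is removed in NumPy 1.24, where it raises AttributeError.

    import numpy as np

    # np.float64 is the concrete type the old alias always resolved to.
    assert np.float64 == np.dtype('float64').type
    x = np.zeros(4, dtype=np.float64)   # portable across NumPy versions
    # x = np.zeros(4, dtype=np.float)   # AttributeError on NumPy >= 1.24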

2 changes: 2 additions & 0 deletions nemo/collections/nlp/metrics/classification_report.py

@@ -60,6 +60,8 @@ def validation_epoch_end(self, outputs):
        aggregated precision, recall, f1, report
    """

+    full_state_update = True
+
     def __init__(
         self,
         num_classes: int,

@@ -346,10 +346,10 @@ def setup_test_data(self, cfg):
         )
         self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples)

-    def on_pretrain_routine_start(self) -> None:
+    def on_fit_start(self) -> None:
         # keep a copy of init_global_step
         self.init_global_step = self.trainer.global_step
-        return super().on_pretrain_routine_start()
+        return super().on_fit_start()

     def compute_consumed_samples(self, steps_since_resume=0):
         app_state = AppState()

@@ -414,10 +414,10 @@ def id_func(output_tensor):

         return fwd_output_only_func

-    def on_pretrain_routine_start(self) -> None:
+    def on_fit_start(self) -> None:
         # keep a copy of init_global_step
         self.init_global_step = self.trainer.global_step
-        return super().on_pretrain_routine_start()
+        return super().on_fit_start()

     def validation_step(self, batch, batch_idx):
         """

@@ -881,10 +881,10 @@ def setup_training_data(self, cfg):
             self._train_ds, consumed_samples, num_workers=self._cfg.data.num_workers
         )

-    def on_pretrain_routine_start(self) -> None:
+    def on_fit_start(self) -> None:
         # keep a copy of init_global_step
         self.init_global_step = self.trainer.global_step
-        return super().on_pretrain_routine_start()
+        return super().on_fit_start()

     def setup_validation_data(self, cfg):
         if hasattr(self, '_validation_ds'):

@@ -224,10 +224,10 @@ def forward(
         )
         return output_tensor

-    def on_pretrain_routine_start(self) -> None:
+    def on_fit_start(self) -> None:
         # keep a copy of init_global_step
         self.init_global_step = self.trainer.global_step
-        return super().on_pretrain_routine_start()
+        return super().on_fit_start()

     def training_step(self, batch, batch_idx):
         input_tokens_id = batch['tokens']
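
Note: the four identical hook renames above track PyTorch Lightning's retirement of on_pretrain_routine_start (deprecated in the 1.6 line and removed in 1.8; the exact versions are our assumption, not stated by this commit). on_fit_start is the surviving hook that fires once when trainer.fit() begins, early enough to snapshot the step counter:

    import pytorch_lightning as pl

    class MyModel(pl.LightningModule):
        def on_fit_start(self) -> None:
            # keep a copy of init_global_step, exactly as the diffs above do
            self.init_global_step = self.trainer.global_step
            return super().on_fit_start()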

2 changes: 1 addition & 1 deletion nemo/collections/tts/modules/submodules.py

@@ -271,7 +271,7 @@ def __init__(self, c):
         self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, bias=False)

         # Sample a random orthonormal matrix to initialize weights
-        W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
+        W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0]

         # Ensure determinant is 1.0 not -1.0
         if torch.det(W) < 0:
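
Note: torch.qr is deprecated in favor of torch.linalg.qr; both return a (Q, R) pair, so indexing [0] still selects the orthonormal factor. A small self-check sketch:

    import torch

    c = 8
    W = torch.linalg.qr(torch.randn(c, c))[0]  # Q: a random orthonormal matrix
    assert torch.allclose(W @ W.T, torch.eye(c), atol=1e-5)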

1 change: 0 additions & 1 deletion nemo/core/config/pytorch_lightning.py

@@ -61,7 +61,6 @@ class TrainerConfig:
     accelerator: Optional[str] = None
     sync_batchnorm: bool = False
     precision: Any = 32
-    weights_save_path: Optional[str] = None
     num_sanity_val_steps: int = 2
     resume_from_checkpoint: Optional[str] = None
     profiler: Optional[Any] = None

7 changes: 1 addition & 6 deletions nemo/utils/exp_manager.py

@@ -867,7 +867,7 @@ def configure_checkpointing(
     trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, resume: bool, params: 'DictConfig'
 ):
     """ Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
-    callback or if trainer.weights_save_path was passed to Trainer.
+    callback
     """
     for callback in trainer.callbacks:
         if isinstance(callback, ModelCheckpoint):
@@ -876,11 +876,6 @@ def configure_checkpointing(
                 "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                 "to False, or remove ModelCheckpoint from the lightning trainer"
             )
-    if Path(trainer.weights_save_path) != Path.cwd():
-        raise CheckpointMisconfigurationError(
-            "The pytorch lightning was passed weights_save_path. This variable is ignored by exp_manager"
-        )

     # Create the callback and attach it to trainer
     if "filepath" in params:
         if params.filepath is not None:
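
Note: this deletion and the weights_save_path field dropped from nemo/core/config/pytorch_lightning.py above both track PyTorch Lightning's removal of the Trainer(weights_save_path=...) argument. A sketch of the usual replacement (an assumption about intended usage, not something this commit adds): point the checkpoint callback at an explicit directory.

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    ckpt = ModelCheckpoint(dirpath="checkpoints/", monitor="val_loss", save_top_k=3)
    trainer = Trainer(callbacks=[ckpt])  # checkpoints land in ./checkpoints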

2 changes: 1 addition & 1 deletion nemo_text_processing/text_normalization/normalize.py

@@ -330,7 +330,7 @@ def split_text_into_sentences(self, text: str) -> List[str]:
         upper_case_unicode = '\u0410-\u042F'

         # Read and split transcript by utterance (roughly, sentences)
-        split_pattern = f"(?<!\w\.\w.)(?<![A-Z{upper_case_unicode}][a-z{lower_case_unicode}]+\.)(?<![A-Z{upper_case_unicode}]\.)(?<=\.|\?|\!|\.”|\?”\!”)\s(?![0-9]+[a-z]*\.)"
+        split_pattern = rf"(?<!\w\.\w.)(?<![A-Z{upper_case_unicode}][a-z{lower_case_unicode}]+\.)(?<![A-Z{upper_case_unicode}]\.)(?<=\.|\?|\!|\.”|\?”\!”)\s(?![0-9]+[a-z]*\.)"

         sentences = regex.split(split_pattern, text)
         return sentences
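
Note: the rf prefix matters because sequences like \w and \. are not valid Python string escapes, so a plain f-string literal emits a DeprecationWarning (slated to become a SyntaxError in a future Python); the raw prefix hands the backslashes to the regex engine unchanged. A reduced sketch with the same regex module (the sample sentence is illustrative):

    import regex

    upper_case_unicode = '\u0410-\u042F'  # А-Я, as in the method above
    split_pattern = rf"(?<![A-Z{upper_case_unicode}]\.)(?<=\.|\?|\!)\s"
    print(regex.split(split_pattern, "Привет. Как дела? Хорошо."))
    # ['Привет.', 'Как дела?', 'Хорошо.']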

125 changes: 49 additions & 76 deletions scripts/dataset_processing/get_ami_data.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,105 +20,78 @@

 from nemo.collections.asr.parts.utils.manifest_utils import create_manifest

-# todo: once https://github.com/tango4j/diarization_annotation/pull/1 merged, we can use the same repo
-test_rttm_url = (
-    "https://raw.githubusercontent.com/tango4j/diarization_annotation/main/AMI_corpus/test/split_rttms.tar.gz"
-)
-dev_rttm_url = (
-    "https://raw.githubusercontent.com/SeanNaren/diarization_annotation/dev/AMI_corpus/dev/split_rttms.tar.gz"
-)
+rttm_url = "https://raw.githubusercontent.com/BUTSpeechFIT/AMI-diarization-setup/main/only_words/rttms/{}/{}.rttm"
+uem_url = "https://raw.githubusercontent.com/BUTSpeechFIT/AMI-diarization-setup/main/uems/{}/{}.uem"
+list_url = "https://raw.githubusercontent.com/BUTSpeechFIT/AMI-diarization-setup/main/lists/{}.meetings.txt"

-test_set_ids = [
-    "EN2002a",
-    "EN2002b",
-    "EN2002c",
-    "EN2002d",
-    "ES2004a",
-    "ES2004b",
-    "ES2004c",
-    "ES2004d",
-    "ES2014a",
-    "ES2014b",
-    "ES2014c",
-    "ES2014d",
-    "IS1009a",
-    "IS1009b",
-    "IS1009c",
-    "IS1009d",
-    "TS3003a",
-    "TS3003b",
-    "TS3003c",
-    "TS3003d",
-    "TS3007a",
-    "TS3007b",
-    "TS3007c",
-    "TS3007d",
-]
-
-dev_set_ids = [
-    "IS1008a",
-    "IS1008b",
-    "IS1008c",
-    "IS1008d",
-    "ES2011a",
-    "ES2011b",
-    "ES2011c",
-    "ES2011d",
-    "TS3004a",
-    "TS3004b",
-    "TS3004c",
-    "TS3004d",
-    "IB4001",
-    "IB4002",
-    "IB4003",
-    "IB4004",
-    "IB4010",
-    "IB4011",
-]
+audio_types = ['Mix-Headset', 'Array1-01']
+
+# these two IDs in the train set are missing download links for Array1-01.
+# We exclude them as a result.
+not_found_ids = ['IS1007d', 'IS1003b']

 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Download the AMI Test Corpus Dataset for Speaker Diarization")
+    parser = argparse.ArgumentParser(description="Download the AMI Corpus Dataset for Speaker Diarization")
     parser.add_argument(
         "--test_manifest_filepath",
         help="path to output test manifest file",
         type=str,
-        default='AMItest_input_manifest.json',
+        default='AMI_test_manifest.json',
     )
     parser.add_argument(
-        "--dev_manifest_filepath",
-        help="path to output test manifest file",
+        "--dev_manifest_filepath", help="path to output dev manifest file", type=str, default='AMI_dev_manifest.json',
+    )
+    parser.add_argument(
+        "--train_manifest_filepath",
+        help="path to output train manifest file",
         type=str,
-        default='AMIdev_input_manifest.json',
+        default='AMI_train_manifest.json',
     )
     parser.add_argument("--data_root", help="path to output data directory", type=str, default="ami_dataset")
     args = parser.parse_args()

     data_path = os.path.abspath(args.data_root)
     os.makedirs(data_path, exist_ok=True)

-    for ids, manifest_path, split, rttm_url in (
-        (test_set_ids, args.test_manifest_filepath, 'test', test_rttm_url),
-        (dev_set_ids, args.dev_manifest_filepath, 'dev', dev_rttm_url),
+    for manifest_path, split in (
+        (args.test_manifest_filepath, 'test'),
+        (args.dev_manifest_filepath, 'dev'),
+        (args.train_manifest_filepath, 'train'),
     ):
         split_path = os.path.join(data_path, split)
         audio_path = os.path.join(split_path, "audio")
         os.makedirs(split_path, exist_ok=True)
-        rttm_path = os.path.join(split_path, "split_rttms")
+        rttm_path = os.path.join(split_path, "rttm")
+        uem_path = os.path.join(split_path, "uem")

-        for id in ids:
-            os.system(
-                f"wget -P {audio_path} https://groups.inf.ed.ac.uk/ami/AMICorpusMirror//amicorpus/{id}/audio/{id}.Mix-Headset.wav"
-            )
-
-        if not os.path.exists(f"{split_path}/split_rttms.tar.gz"):
-            os.system(f"wget -P {split_path} {rttm_url}")
-        os.system(f"tar -xzvf {split_path}/split_rttms.tar.gz -C {split_path}")
+        os.system(f"wget -P {split_path} {list_url.format(split)}")
+        with open(os.path.join(split_path, f"{split}.meetings.txt")) as f:
+            ids = f.read().strip().split('\n')
+        for id in [file_id for file_id in ids if file_id not in not_found_ids]:
+            for audio_type in audio_types:
+                audio_type_path = os.path.join(audio_path, audio_type)
+                os.makedirs(audio_type_path, exist_ok=True)
+                os.system(
+                    f"wget -P {audio_type_path} https://groups.inf.ed.ac.uk/ami/AMICorpusMirror//amicorpus/{id}/audio/{id}.{audio_type}.wav"
+                )
+            rttm_download = rttm_url.format(split, id)
+            os.system(f"wget -P {rttm_path} {rttm_download}")
+            uem_download = uem_url.format(split, id)
+            os.system(f"wget -P {uem_path} {uem_download}")

-        audio_files_path = os.path.join(split_path, 'audio_files.txt')
         rttm_files_path = os.path.join(split_path, 'rttm_files.txt')
-        with open(audio_files_path, 'w') as f:
-            f.write('\n'.join(os.path.join(audio_path, p) for p in os.listdir(audio_path)))
         with open(rttm_files_path, 'w') as f:
             f.write('\n'.join(os.path.join(rttm_path, p) for p in os.listdir(rttm_path)))

-        create_manifest(audio_files_path, manifest_path, rttm_path=rttm_files_path)
+        uem_files_path = os.path.join(split_path, 'uem_files.txt')
+        with open(uem_files_path, 'w') as f:
+            f.write('\n'.join(os.path.join(uem_path, p) for p in os.listdir(uem_path)))
+        for audio_type in audio_types:
+            audio_type_path = os.path.join(audio_path, audio_type)
+            audio_files_path = os.path.join(split_path, f'audio_files_{audio_type}.txt')
+            with open(audio_files_path, 'w') as f:
+                f.write('\n'.join(os.path.join(audio_type_path, p) for p in os.listdir(audio_type_path)))
+            audio_type_manifest_path = manifest_path.replace('.json', f'.{audio_type}.json')
+            create_manifest(
+                audio_files_path, audio_type_manifest_path, rttm_path=rttm_files_path, uem_path=uem_files_path
+            )
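
Note: after this rewrite each split produces one manifest per audio type, e.g. AMI_test_manifest.Mix-Headset.json and AMI_test_manifest.Array1-01.json. A small sketch of how the new URL templates expand ("ES2004a" is one meeting ID from the AMI test list; the real IDs are read from the downloaded {split}.meetings.txt):

    rttm_url = "https://raw.githubusercontent.com/BUTSpeechFIT/AMI-diarization-setup/main/only_words/rttms/{}/{}.rttm"
    print(rttm_url.format("test", "ES2004a"))
    # https://raw.githubusercontent.com/BUTSpeechFIT/AMI-diarization-setup/main/only_words/rttms/test/ES2004a.rttm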