
Commit

Keep the target words (words which we trained to recognize) in a set. Will be useful for integrating with the Pipeline
AngledLuffa committed Sep 19, 2024
1 parent aa8fc24 commit 3064b33
Showing 7 changed files with 31 additions and 11 deletions.
13 changes: 11 additions & 2 deletions stanza/models/lemma_classifier/base_model.py
@@ -19,10 +19,11 @@
logger = logging.getLogger('stanza.lemmaclassifier')

class LemmaClassifier(ABC, nn.Module):
- def __init__(self, label_decoder, *args, **kwargs):
+ def __init__(self, label_decoder, target_words, *args, **kwargs):
super().__init__(*args, **kwargs)

self.label_decoder = label_decoder
+ self.target_words = target_words
self.unsaved_modules = []

def add_unsaved_module(self, name, module):
@@ -49,6 +50,9 @@ def model_type(self):
return a ModelType
"""

+ def target_indices(self, sentence):
+ return [idx for idx, word in enumerate(sentence) if word.lower() in self.target_words]

@staticmethod
def from_checkpoint(checkpoint, args=None):
model_type = checkpoint['model_type']
@@ -81,6 +85,7 @@ def from_checkpoint(checkpoint, args=None):
label_decoder=checkpoint['label_decoder'],
upos_to_id=checkpoint['upos_to_id'],
known_words=checkpoint['known_words'],
+ target_words=checkpoint['target_words'],
use_charlm=use_charlm,
charlm_forward_file=charlm_forward_file,
charlm_backward_file=charlm_backward_file)
@@ -90,7 +95,11 @@ def from_checkpoint(checkpoint, args=None):
output_dim = len(checkpoint['label_decoder'])
saved_args = checkpoint['args']
bert_model = saved_args['bert_model']
- model = LemmaClassifierWithTransformer(model_args = saved_args, output_dim=output_dim, transformer_name=bert_model, label_decoder=checkpoint['label_decoder'])
+ model = LemmaClassifierWithTransformer(model_args=saved_args,
+ output_dim=output_dim,
+ transformer_name=bert_model,
+ label_decoder=checkpoint['label_decoder'],
+ target_words=checkpoint['target_words'])
else:
raise ValueError("Unknown model type %s" % model_type)

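For context, here is a minimal standalone sketch of what the new target_indices helper computes; this is the lookup a Pipeline integration would use to decide which tokens the classifier should run on. The sentence and target word below are made up, and the snippet simply mirrors the list comprehension added in base_model.py.

# Illustrative only: mirrors the lookup in LemmaClassifier.target_indices.
sentence = ["The", "book", "Lies", "on", "the", "table"]
target_words = {"lies"}   # hypothetical set of words the model was trained on

indices = [idx for idx, word in enumerate(sentence) if word.lower() in target_words]
print(indices)  # [2] -- "Lies" matches because each word is lowercased before the set lookup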
3 changes: 2 additions & 1 deletion stanza/models/lemma_classifier/base_trainer.py
@@ -60,8 +60,9 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
self.output_dim = len(label_decoder)
logger.info(f"Loaded dataset successfully from {train_file}")
logger.info(f"Using label decoder: {label_decoder} Output dimension: {self.output_dim}")
logger.info(f"Target words: {dataset.target_words}")

- self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words)
+ self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words, dataset.target_words)
self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

self.model.to(device)
6 changes: 4 additions & 2 deletions stanza/models/lemma_classifier/lstm_model.py
@@ -21,7 +21,7 @@ class LemmaClassifierLSTM(LemmaClassifier):
From the LSTM output, we get the embedding of the specific token that we classify on. That embedding
is fed into an MLP for classification.
"""
- def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_id, known_words,
+ def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_id, known_words, target_words,
use_charlm=False, charlm_forward_file=None, charlm_backward_file=None):
"""
Args:
@@ -30,6 +30,7 @@ def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_
upos_to_id (Mapping[str, int]): A dictionary mapping UPOS tag strings to their respective IDs
pt_embedding (Pretrain): pretrained embeddings
known_words (list(str)): Words which are in the training data
+ target_words (set(str)): a set of the words which might need lemmatization
use_charlm (bool): Whether or not to use the charlm embeddings
charlm_forward_file (str): The path to the forward pass model for the character language model
charlm_backward_file (str): The path to the backward pass model for the character language model.
@@ -41,7 +42,7 @@ def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_
Raises:
FileNotFoundError: if the forward or backward charlm file cannot be found.
"""
- super(LemmaClassifierLSTM, self).__init__(label_decoder)
+ super(LemmaClassifierLSTM, self).__init__(label_decoder, target_words)
self.model_args = model_args

self.hidden_dim = model_args['hidden_dim']
@@ -113,6 +114,7 @@ def get_save_dict(self):
"args": self.model_args,
"upos_to_id": self.upos_to_id,
"known_words": self.known_words,
"target_words": self.target_words,
}
skipped = [k for k in save_dict["params"].keys() if self.is_unsaved_module(k)]
for k in skipped:
4 changes: 2 additions & 2 deletions stanza/models/lemma_classifier/train_lstm_model.py
@@ -72,8 +72,8 @@ def __init__(self, model_args: dict, embedding_file: str, use_charlm: bool = Fal
else:
raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")

- def build_model(self, label_decoder, upos_to_id, known_words):
- return LemmaClassifierLSTM(self.model_args, self.output_dim, self.pt_embedding, label_decoder, upos_to_id, known_words,
+ def build_model(self, label_decoder, upos_to_id, known_words, target_words):
+ return LemmaClassifierLSTM(self.model_args, self.output_dim, self.pt_embedding, label_decoder, upos_to_id, known_words, target_words,
use_charlm=self.use_charlm, charlm_forward_file=self.charlm_forward_file, charlm_backward_file=self.charlm_backward_file)

def build_argparse():
4 changes: 2 additions & 2 deletions stanza/models/lemma_classifier/train_transformer_model.py
@@ -72,8 +72,8 @@ def set_layer_learning_rates(self, transformer_lr: float, mlp_lr: float) -> torc
])
return optimizer

- def build_model(self, label_decoder, upos_to_id, known_words):
- return LemmaClassifierWithTransformer(model_args=self.model_args, output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder)
+ def build_model(self, label_decoder, upos_to_id, known_words, target_words):
+ return LemmaClassifierWithTransformer(model_args=self.model_args, output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder, target_words=target_words)


def main(args=None, predefined_args=None):
6 changes: 4 additions & 2 deletions stanza/models/lemma_classifier/transformer_model.py
@@ -14,7 +14,7 @@
logger = logging.getLogger('stanza.lemmaclassifier')

class LemmaClassifierWithTransformer(LemmaClassifier):
- def __init__(self, model_args: dict, output_dim: int, transformer_name: str, label_decoder: Mapping):
+ def __init__(self, model_args: dict, output_dim: int, transformer_name: str, label_decoder: Mapping, target_words: set):
"""
Model architecture:
@@ -27,8 +27,9 @@ def __init__(self, model_args: dict, output_dim: int, transformer_name: str, lab
output_dim (int): Dimension of the output from the MLP
transformer_name (str): name of the HF transformer to use
label_decoder (dict): a map of the labels available to the model
+ target_words (set(str)): a set of the words which might need lemmatization
"""
- super(LemmaClassifierWithTransformer, self).__init__(label_decoder)
+ super(LemmaClassifierWithTransformer, self).__init__(label_decoder, target_words)
self.model_args = model_args

# Choose transformer
@@ -50,6 +51,7 @@ def get_save_dict(self):
save_dict = {
"params": self.state_dict(),
"label_decoder": self.label_decoder,
"target_words": self.target_words,
"model_type": self.model_type(),
"args": self.model_args,
}
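Because target_words now goes into get_save_dict and is read back in from_checkpoint, the set should survive a checkpoint round trip. The sketch below is a hedged illustration, not code from the repo: it assumes `model` is an already constructed LemmaClassifierWithTransformer (building one needs the usual model_args and a HuggingFace download, omitted here).

from stanza.models.lemma_classifier.base_model import LemmaClassifier

checkpoint = model.get_save_dict()     # `model` assumed to exist already
assert "target_words" in checkpoint    # key added by this commit

# from_checkpoint passes checkpoint['target_words'] to the rebuilt model,
# so the reloaded classifier still knows which words it was trained to handle.
reloaded = LemmaClassifier.from_checkpoint(checkpoint)
assert reloaded.target_words == model.target_words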
6 changes: 6 additions & 0 deletions stanza/models/lemma_classifier/utils.py
@@ -45,6 +45,10 @@ def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_coun

logger.debug("Final label decoder: %s Should be strings to ints", label_decoder)

+ # words which we are analyzing
+ target_words = set()

+ # all known words in the dataset, not just target words
known_words = set()

with open(data_path, "r+", encoding="utf-8") as f:
@@ -78,6 +82,7 @@ def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_coun
if get_counts:
counts[label_decoder[label]] += 1

+ target_words.add(words[target_idx])
known_words.update(words)

self.sentences = sentences
@@ -93,6 +98,7 @@ def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_coun
self.shuffle = shuffle

self.known_words = [x.lower() for x in sorted(known_words)]
+ self.target_words = set(x.lower() for x in target_words)

def __len__(self):
"""
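To make the new bookkeeping in utils.py concrete, here is a standalone sketch of the collection logic with made-up (words, target_idx) pairs; the real dataset file format is not reproduced here, only the loop body added above.

# Hypothetical parsed examples standing in for lines of the dataset file.
examples = [
    (["He", "lay", "down"], 1),
    (["The", "report", "lies", "on", "the", "desk"], 2),
]

target_words = set()
known_words = set()
for words, target_idx in examples:
    target_words.add(words[target_idx])   # only the word being classified
    known_words.update(words)             # every word in the sentence

known_words = [x.lower() for x in sorted(known_words)]
target_words = set(x.lower() for x in target_words)
print(sorted(target_words))   # ['lay', 'lies']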
