Skip to content

Commit

Permalink
Pass the label_decoder to the base class
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Sep 18, 2024
1 parent e474ee2 commit 3dfb5ad
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 7 deletions.
3 changes: 2 additions & 1 deletion stanza/models/lemma_classifier/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
logger = logging.getLogger('stanza.lemmaclassifier')

class LemmaClassifier(ABC, nn.Module):
def __init__(self, *args, **kwargs):
def __init__(self, label_decoder, *args, **kwargs):
super().__init__(*args, **kwargs)

self.label_decoder = label_decoder
self.unsaved_modules = []

def add_unsaved_module(self, name, module):
Expand Down
5 changes: 1 addition & 4 deletions stanza/models/lemma_classifier/lstm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_
Raises:
FileNotFoundError: if the forward or backward charlm file cannot be found.
"""
super(LemmaClassifierLSTM, self).__init__()
super(LemmaClassifierLSTM, self).__init__(label_decoder)
self.model_args = model_args

self.hidden_dim = model_args['hidden_dim']
Expand All @@ -63,9 +63,6 @@ def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_

self.input_size += self.embedding_dim

# TODO: pass this up to the parent class
self.label_decoder = label_decoder

# Optionally, include charlm embeddings
self.use_charlm = use_charlm

Expand Down
3 changes: 1 addition & 2 deletions stanza/models/lemma_classifier/transformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, model_args: dict, output_dim: int, transformer_name: str, lab
transformer_name (str): name of the HF transformer to use
label_decoder (dict): a map of the labels available to the model
"""
super(LemmaClassifierWithTransformer, self).__init__()
super(LemmaClassifierWithTransformer, self).__init__(label_decoder)
self.model_args = model_args

# Choose transformer
Expand All @@ -45,7 +45,6 @@ def __init__(self, model_args: dict, output_dim: int, transformer_name: str, lab
nn.ReLU(),
nn.Linear(64, output_dim)
)
self.label_decoder = label_decoder

def get_save_dict(self):
save_dict = {
Expand Down

0 comments on commit 3dfb5ad

Please sign in to comment.