Skip to content

Commit

Permalink
resolve_is_lower_case fix (#1204)
Browse files Browse the repository at this point in the history
* Albert check
  • Loading branch information
zphang committed Oct 19, 2020
1 parent 5724fee commit b20f30a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
4 changes: 3 additions & 1 deletion jiant/shared/model_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,10 @@ def resolve_tokenizer_class(model_type):


# Diff view of resolve_is_lower_case: the removed and added lines of this commit are
# shown together, and the page extraction dropped the +/- markers and the indentation.
# Judging by the visible lines, the function reports whether a tokenizer lower-cases
# its input, returning False for tokenizer types it does not recognize.
def resolve_is_lower_case(tokenizer):
# Removed line (the file header reports "3 additions & 1 deletion"): the old combined
# check routed ALBERT tokenizers to the `basic_tokenizer` lookup below as well.
if isinstance(tokenizer, (transformers.BertTokenizer, transformers.AlbertTokenizer)):
# Added line: only BERT tokenizers now read the flag from `basic_tokenizer`.
if isinstance(tokenizer, transformers.BertTokenizer):
return tokenizer.basic_tokenizer.do_lower_case
# Added lines: ALBERT tokenizers read `do_lower_case` directly from the tokenizer.
# NOTE(review): presumably AlbertTokenizer has no `basic_tokenizer` attribute, which
# would be the bug this commit fixes -- confirm against the transformers version in use.
if isinstance(tokenizer, transformers.AlbertTokenizer):
return tokenizer.do_lower_case
# Any other tokenizer type is reported as not lower-casing.
else:
return False

Expand Down
6 changes: 4 additions & 2 deletions jiant/tasks/lib/templates/span_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from typing import List, Tuple

from jiant.shared.model_resolution import resolve_is_lower_case
import jiant.shared.model_resolution as model_resolution
from jiant.tasks.core import (
Task,
TaskTypes,
Expand Down Expand Up @@ -33,7 +33,9 @@ class Example(BaseExample):

def tokenize(self, tokenizer):
passage = (
self.passage.lower() if resolve_is_lower_case(tokenizer=tokenizer) else self.passage
self.passage.lower()
if model_resolution.resolve_is_lower_case(tokenizer=tokenizer)
else self.passage
)
passage_tokens = tokenizer.tokenize(passage)
token_aligner = TokenAligner(source=passage, target=passage_tokens)
Expand Down

0 comments on commit b20f30a

Please sign in to comment.