address reviewer comment
Signed-off-by: Yi Dong <[email protected]>
yidong72 committed Dec 8, 2022
1 parent 09d5854 commit 5f33b3b
Showing 2 changed files with 63 additions and 7 deletions.
64 changes: 60 additions & 4 deletions nemo/collections/nlp/modules/common/megatron/retrieval_service.py
@@ -49,6 +49,10 @@ def request_data(data, port=PORT_NUM):


class RetrievalService:
"""
Abstract class for Retrieval Service.
"""

@abc.abstractmethod
def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors: int):
pass
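For orientation only (not part of this commit), a toy implementation of the two-method interface above; the behaviour sketched in the comments is an assumption, since this hunk only shows the signatures:

from typing import List, Union

import torch


class ToyRetrievalService(RetrievalService):
    """Illustrative only: keeps raw documents and returns the first few as 'neighbors'."""

    def __init__(self):
        self.docs: List[str] = []

    def add_docs_to_index(self, docs: List[str], add_eos: bool = True):
        # A real service would tokenize, chunk, and index the documents here.
        self.docs.extend(docs)

    def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors: int):
        # A real service would return neighbor chunk tokens for each query chunk.
        return self.docs[:neighbors]
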
@@ -65,6 +69,10 @@ def add_docs_to_index(self, docs: List[str], add_eos: bool = True):


class ChunkStore:
"""
ChunkStore maps a chunk id to its tokens. It is used as in-memory storage for the dynamic retrieval DB.
"""

def __init__(self, chunk_size, pad_id):
self.store = {}
self._count = 0
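Only the start of ChunkStore is visible here. A rough sketch of the idea in the new docstring (an integer chunk id mapped to a fixed-size token array), not the committed implementation:

import numpy as np


class ToyChunkStore:
    """Illustrative in-memory chunk store: chunk id -> padded token array."""

    def __init__(self, chunk_size, pad_id):
        self.store = {}
        self._count = 0
        self.chunk_size = chunk_size
        self.pad_id = pad_id

    def add(self, tokens):
        # Pad (or truncate) to chunk_size and hand back the new chunk id.
        padded = np.full(self.chunk_size, self.pad_id, dtype=np.int64)
        tokens = tokens[: self.chunk_size]
        padded[: len(tokens)] = tokens
        chunk_id = self._count
        self.store[chunk_id] = padded
        self._count += 1
        return chunk_id

    def get_chunk(self, chunk_id):
        return self.store[chunk_id]

    def reset(self):
        self.store = {}
        self._count = 0
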
@@ -85,6 +93,11 @@ def reset(self):


class SentenceBertResource(Resource):
"""
SentenceBERT Flask resource.
The PUT method returns the embedding for a string or token query.
"""

def __init__(
self, bert_model, tokenizer, pool, sentence_bert_batch,
):
@@ -96,8 +109,6 @@ def __init__(
self.embedding_dim = self.bert_model.get_sentence_embedding_dimension()

def put(self):
# logging.info("request IP: " + str(request.remote_addr))
# logging.info(json.dumps(request.get_json()))
data = request.get_json()
if isinstance(data, dict):
return jsonify({'dim': self.embedding_dim})
@@ -122,6 +133,10 @@ def get_emb(self, query: Union[List[str], str, torch.Tensor]):


class SentenceBertServer(object):
"""
Flask SentenceBERT server, which computes string/token embeddings.
"""

def __init__(
self,
devices: str,
@@ -154,13 +169,24 @@ def run(self, url, port=PORT_NUM_BERT):
def start_sentence_bert_server(
devices: str, tokenizer: TokenizerSpec, sentence_bert: str = 'all-mpnet-base-v2', sentence_bert_batch: int = 4
):
"""
Start the SentenceBERT server.
The server is only started on the rank 0 worker.
Multiple nodes are not supported yet.
"""

if torch.distributed.get_rank() == 0:
server = SentenceBertServer(devices, tokenizer, sentence_bert, sentence_bert_batch,)
server.run("0.0.0.0")
torch.distributed.barrier()
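start_sentence_bert_server launches the Flask server on the rank 0 worker only, and every rank then waits on the barrier. A hedged client-side sketch, assuming the resource is mounted at the root path and that the server behaves as SentenceBertResource.put above (a dict payload returns the embedding dimension, any other JSON payload is treated as sentences to embed):

import requests


def get_embedding_dim(port, host="http://localhost"):
    # A dict payload hits the `isinstance(data, dict)` branch and returns {'dim': ...}.
    return requests.put(f"{host}:{port}", json={}).json()["dim"]


def get_embeddings(sentences, port, host="http://localhost"):
    # Non-dict JSON (e.g. a list of strings) is embedded by the SentenceBERT model.
    return requests.put(f"{host}:{port}", json=sentences).json()
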


class FaissRetrievalResource(Resource):
"""
Static Faiss retrieval Flask resource.
The PUT method returns the KNN neighbor tokens for the given queries.
"""

def __init__(
self, index, tokenizer, ds,
):
@@ -170,8 +196,6 @@ def __init__(
self.ds = ds

def put(self):
# logging.info("request IP: " + str(request.remote_addr))
# logging.info(json.dumps(request.get_json()))
data = request.get_json()
sentences = data['sentences']
num_neighbors = data['neighbors']
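The request format for the static retrieval resource is visible here: a JSON object with 'sentences' and 'neighbors' fields. A minimal client-side sketch under the same assumptions as the embedding client above (root path, port chosen when the server is started):

import requests


def knn_request(sentences, neighbors, port, host="http://localhost"):
    # Mirrors the fields read in FaissRetrievalResource.put.
    payload = {"sentences": sentences, "neighbors": neighbors}
    return requests.put(f"{host}:{port}", json=payload).json()
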
Expand Down Expand Up @@ -210,6 +234,10 @@ def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors: int):


class RetrievalServer(object):
"""
Flask retrieval server, which returns the KNN tokens for a given query chunk.
"""

def __init__(
self, faiss_index: str, faiss_devices: str, nprobe: int, retrieval_index: str, tokenizer: TokenizerSpec,
):
@@ -245,6 +273,11 @@ def run(self, url, port=PORT_NUM):


class DynamicRetrievalResource(FaissRetrievalResource):
"""
Dynamic Faiss retrieval Flask resource.
The PUT method can fetch KNN tokens, add new chunks to the index, or reset the index.
"""

def __init__(
self, index, tokenizer, chunk_size, stride, store,
):
@@ -310,6 +343,10 @@ def add_docs_to_index(self, docs: List[str], add_eos: bool = True):


class DynamicRetrievalServer(object):
"""
Flask dynamic retrieval server, which builds the dynamic retrieval index.
"""

def __init__(
self, faiss_devices: str, tokenizer: TokenizerSpec, chunk_size: int = 64, stride: int = 32,
):
@@ -351,6 +388,12 @@ def run(self, url, port=PORT_NUM_DYN):


class FaissRetrievalService(RetrievalService):
"""
Top-level static retrieval service class.
It starts the server on the rank 0 worker; multiple nodes are not supported yet.
It implements the retrieval service interface and provides a simple client for KNN queries.
"""

def __init__(
self, faiss_index: str, faiss_devices: str, nprobe: int, retrieval_index: str, tokenizer: TokenizerSpec,
):
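A hedged usage sketch for the static service, following the constructor signature above; the paths and device string are placeholders, and the tokenizer is assumed to come from the loaded NeMo model:

def build_static_service(tokenizer):
    # Placeholder paths and device string, for illustration only.
    return FaissRetrievalService(
        faiss_index="/path/to/knn_index.idx",
        faiss_devices="0,1",
        nprobe=100,
        retrieval_index="/path/to/retrieval_db",
        tokenizer=tokenizer,
    )


# Each worker can then fetch neighbor tokens for a batch of query chunks, e.g.
# build_static_service(tokenizer).get_knn(["some query chunk"], neighbors=2)
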
@@ -384,6 +427,13 @@ def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors):


class DynamicFaissRetrievalService(RetrievalService):
"""
Top-level dynamic retrieval service class.
It starts the server on the rank 0 worker; multiple nodes are not supported yet.
It implements the retrieval service interface and provides a simple client to add to, reset,
and query the dynamic retrieval index.
"""

def __init__(
self, faiss_devices: str, tokenizer: TokenizerSpec, chunk_size: int, stride: int,
):
@@ -432,6 +482,12 @@ def add_docs_to_index(self, query: List[str], add_eos: bool = True):


class ComboRetrievalService(RetrievalService):
"""
Top-level combo retrieval service class.
It combines multiple retrieval services into one.
It uses `weights` to determine how many neighbors to request from each member service.
"""

def __init__(self, retrieval_services, weights):
self.retrieval_services = retrieval_services
self.updatable = any([service.updatable for service in retrieval_services])
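The new docstring says `weights` controls how many neighbors are requested from each member service. The committed splitting rule is not visible in this hunk; one plausible way to apportion a total neighbor budget, purely for illustration:

def split_neighbors(total_neighbors, weights):
    # Illustrative apportioning only; the actual rule lives in ComboRetrievalService.get_knn,
    # which is outside this hunk.
    total = sum(weights)
    return [int(total_neighbors * w / total) for w in weights]


# e.g. split_neighbors(8, [1, 3]) -> [2, 6]
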
@@ -92,7 +92,7 @@ def tokenize_batch(self, sentences, max_len, add_BOS):
"""
tokenizer = self.model.tokenizer
if add_BOS:
- context_tokens = [[tokenizer.eos_id] + tokenizer.text_to_ids(s) for s in sentences]
+ context_tokens = [[tokenizer.bos_id] + tokenizer.text_to_ids(s) for s in sentences]
else:
context_tokens = [tokenizer.text_to_ids(s) for s in sentences]
context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_id, max_len)
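The substantive change in this second file is the one-token fix shown above: when add_BOS is requested, the prompt should start with the tokenizer's BOS id rather than its EOS id; padding still uses the EOS id. A minimal sketch of the corrected pattern with a purely hypothetical tokenizer (the real one comes from the NeMo model):

class ToyTokenizer:
    # Hypothetical ids, for illustration only.
    bos_id = 1
    eos_id = 2

    def text_to_ids(self, text):
        # Stand-in for real subword tokenization.
        return [10 + (ord(ch) % 80) for ch in text]


def toy_tokenize_batch(sentences, max_len, add_BOS, tokenizer=ToyTokenizer()):
    if add_BOS:
        # The fix: prepend BOS (previously EOS was prepended here by mistake).
        context_tokens = [[tokenizer.bos_id] + tokenizer.text_to_ids(s) for s in sentences]
    else:
        context_tokens = [tokenizer.text_to_ids(s) for s in sentences]
    # Padding to max_len still uses the EOS id, which this commit does not change.
    return [(t + [tokenizer.eos_id] * max_len)[:max_len] for t in context_tokens]
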
@@ -330,7 +330,7 @@ def tokenize_batch(self, sentences, max_len, add_BOS):
"""
tokenizer = self.model.tokenizer
if add_BOS:
- context_tokens = [[tokenizer.eos_id] + tokenizer.text_to_ids(s) for s in sentences]
+ context_tokens = [[tokenizer.bos_id] + tokenizer.text_to_ids(s) for s in sentences]
else:
context_tokens = [tokenizer.text_to_ids(s) for s in sentences]
if self.pad_token_for_retrieval:
@@ -360,7 +360,7 @@ def tokenize_batch_with_context_and_completion(self, sentences, max_len, add_BOS
tokenizer = self.model.tokenizer
if add_BOS:
context_tokens = [
- [[tokenizer.eos_id] + tokenizer.text_to_ids(s[0]), tokenizer.text_to_ids(s[1])] for s in sentences
+ [[tokenizer.bos_id] + tokenizer.text_to_ids(s[0]), tokenizer.text_to_ids(s[1])] for s in sentences
]
else:
context_tokens = [[tokenizer.text_to_ids(s[0]), tokenizer.text_to_ids(s[1])] for s in sentences]
