Skip to content

Commit

Permalink
Merge pull request facebookresearch#191 from philipmorrisintl/master
Browse files Browse the repository at this point in the history
Elasticsearch integration
  • Loading branch information
ajfisch committed Mar 8, 2019
2 parents fdbdf1c + e387a26 commit 1f811de
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 3 deletions.
12 changes: 9 additions & 3 deletions drqa/pipeline/drqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,15 @@ def __init__(
annotators = tokenizers.get_annotators_for_model(self.reader)
tok_opts = {'annotators': annotators}

db_config = db_config or {}
db_class = db_config.get('class', DEFAULTS['db'])
db_opts = db_config.get('options', {})
# ElasticSearch is also used as the backend when it serves as the ranker
if hasattr(self.ranker, 'es'):
db_config = ranker_config
db_class = ranker_class
db_opts = ranker_opts
else:
db_config = db_config or {}
db_class = db_config.get('class', DEFAULTS['db'])
db_opts = db_config.get('options', {})

logger.info('Initializing tokenizers and document retrievers...')
self.num_workers = num_workers
Expand Down
4 changes: 4 additions & 0 deletions drqa/retriever/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
DATA_DIR,
'wikipedia/docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz'
),
'elastic_url': 'localhost:9200'
}


Expand All @@ -27,8 +28,11 @@ def get_class(name):
return TfidfDocRanker
if name == 'sqlite':
return DocDB
if name == 'elasticsearch':
return ElasticDocRanker
raise RuntimeError('Invalid retriever class: %s' % name)


from .doc_db import DocDB
from .tfidf_doc_ranker import TfidfDocRanker
from .elastic_doc_ranker import ElasticDocRanker
109 changes: 109 additions & 0 deletions drqa/retriever/elastic_doc_ranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Rank documents with an ElasticSearch index"""

import logging
import scipy.sparse as sp

from multiprocessing.pool import ThreadPool
from functools import partial
from elasticsearch import Elasticsearch

from . import utils
from . import DEFAULTS
from .. import tokenizers

logger = logging.getLogger(__name__)


class ElasticDocRanker(object):
    """Connect to an ElasticSearch index and rank documents.

    Doubles as both a ranker (``closest_docs``) and a document database
    (``get_doc_text`` / ``get_doc_ids``), so a single instance can back
    both the retrieval and the reading stages of the pipeline.
    """

    def __init__(self, elastic_url=None, elastic_index=None,
                 elastic_fields=None, elastic_field_doc_name=None,
                 strict=True, elastic_field_content=None):
        """
        Args:
            elastic_url: URL of the ElasticSearch server containing port
            elastic_index: Index name of ElasticSearch
            elastic_fields: Fields of the Elasticsearch index to search in
            elastic_field_doc_name: Field containing the name of the document
                (a single key as str, or a list of keys for a nested field)
            strict: fail on empty queries or continue (and return empty result).
                NOTE(review): stored but not consulted by any method in this
                class — confirm intended behavior before relying on it.
            elastic_field_content: Field containing the content of the document
                in plain text
        """
        elastic_url = elastic_url or DEFAULTS['elastic_url']
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info('Connecting to %s', elastic_url)
        self.es = Elasticsearch(hosts=elastic_url)
        self.elastic_index = elastic_index
        self.elastic_fields = elastic_fields
        self.elastic_field_doc_name = elastic_field_doc_name
        self.elastic_field_content = elastic_field_content
        self.strict = strict

    # --- Elastic Ranker ---------------------------------------------------

    def get_doc_index(self, doc_id):
        """Convert doc_id --> doc_index (the Elasticsearch internal ``_id``).

        Raises:
            IndexError: if no document matches ``doc_id``.
        """
        field_index = self.elastic_field_doc_name
        if isinstance(field_index, list):
            # Nested fields are addressed with dotted notation in ES queries.
            field_index = '.'.join(field_index)
        result = self.es.search(index=self.elastic_index,
                                body={'query': {'match': {field_index: doc_id}}})
        return result['hits']['hits'][0]['_id']

    def get_doc_id(self, doc_index):
        """Convert doc_index (Elasticsearch ``_id``) --> doc_id.

        Raises:
            IndexError: if no document has that ``_id``.
        """
        result = self.es.search(index=self.elastic_index,
                                body={'query': {'match': {"_id": doc_index}}})
        source = result['hits']['hits'][0]['_source']
        return utils.get_field(source, self.elastic_field_doc_name)

    def closest_docs(self, query, k=1):
        """Return the ``k`` closest doc ids and their scores for ``query``,
        via a ``most_fields`` multi_match over ``self.elastic_fields``."""
        results = self.es.search(index=self.elastic_index,
                                 body={'size': k,
                                       'query': {'multi_match': {
                                           'query': query,
                                           'type': 'most_fields',
                                           'fields': self.elastic_fields}}})
        hits = results['hits']['hits']
        doc_ids = [utils.get_field(row['_source'], self.elastic_field_doc_name)
                   for row in hits]
        doc_scores = [row['_score'] for row in hits]
        return doc_ids, doc_scores

    def batch_closest_docs(self, queries, k=1, num_workers=None):
        """Process a batch of closest_docs requests multithreaded.

        Plain threads suffice here: the work is network I/O against the
        Elasticsearch server, which releases the GIL while waiting.
        """
        with ThreadPool(num_workers) as threads:
            closest_docs = partial(self.closest_docs, k=k)
            results = threads.map(closest_docs, queries)
        return results

    # --- Elastic DB -------------------------------------------------------

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # BUGFIX: __exit__ was missing, so using this class in a `with`
        # statement raised AttributeError on block exit even though
        # __enter__ was defined. Close the connection here; returning
        # None lets any in-flight exception propagate.
        self.close()

    def close(self):
        """Close the connection to the database."""
        self.es = None

    def get_doc_ids(self):
        """Fetch all ids of docs stored in the db."""
        results = self.es.search(index=self.elastic_index,
                                 body={"query": {"match_all": {}}})
        doc_ids = [utils.get_field(result['_source'], self.elastic_field_doc_name)
                   for result in results['hits']['hits']]
        return doc_ids

    def get_doc_text(self, doc_id):
        """Fetch the raw text of the doc for 'doc_id'."""
        idx = self.get_doc_index(doc_id)
        result = self.es.get(index=self.elastic_index, doc_type='_doc', id=idx)
        return result if result is None else result['_source'][self.elastic_field_content]

12 changes: 12 additions & 0 deletions drqa/retriever/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,15 @@ def filter_ngram(gram, mode='any'):
return filtered[0] or filtered[-1]
else:
raise ValueError('Invalid mode: %s' % mode)

def get_field(d, field_list):
    """Return the (possibly nested) value of ``d`` at ``field_list``.

    E.g. ``['file', 'filename']`` maps to ``d['file']['filename']``.

    Args:
        d: dict-like document source (e.g. an Elasticsearch ``_source``).
        field_list: a single key (str), or a list of keys to follow into
            nested dicts.

    Returns:
        The value found at the given (nested) key.

    Raises:
        KeyError: if any key along the path is missing.
    """
    if isinstance(field_list, str):
        return d[field_list]
    # Walk the nested keys directly; the original shallow d.copy() was
    # unnecessary (we only read, never mutate) and required d to expose
    # a .copy() method.
    value = d
    for field in field_list:
        value = value[field]
    return value

0 comments on commit 1f811de

Please sign in to comment.