From 2d1995082882926349d10cc5203e85051c2410be Mon Sep 17 00:00:00 2001 From: Jeremy Pinto Date: Tue, 23 Jun 2020 09:46:39 -0400 Subject: [PATCH] feat: fit sklearn model on new data if its not provided --- covidfaq/bert_en_model/config.yaml | 15 +++- .../output/sklearn_outlier_model.pkl | 3 - covidfaq/bert_fr_model/config.yaml | 11 ++- .../output/sklearn_outlier_model.pkl | 3 - covidfaq/evaluating/evaluator.py | 6 +- covidfaq/evaluating/model/bert_plus_ood.py | 18 ++--- ...edding_based_reranker_plus_ood_detector.py | 73 +++++++++++++++++-- covidfaq/main.py | 3 +- covidfaq/scrape/scrape.py | 52 ++++++++++++- poetry.lock | 1 + scripts/fetch_model.py | 34 --------- scripts/scraper.py | 33 ++++++++- 12 files changed, 183 insertions(+), 69 deletions(-) delete mode 100644 covidfaq/bert_en_model/output/sklearn_outlier_model.pkl delete mode 100644 covidfaq/bert_fr_model/output/sklearn_outlier_model.pkl delete mode 100644 scripts/fetch_model.py diff --git a/covidfaq/bert_en_model/config.yaml b/covidfaq/bert_en_model/config.yaml index b36d79c..c1f2242 100644 --- a/covidfaq/bert_en_model/config.yaml +++ b/covidfaq/bert_en_model/config.yaml @@ -1,11 +1,18 @@ -# train_file: '/home/jeremy/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3_newformat/train_overlapping.json' +# train_file: 'covidfaq/data/covidfaq_data/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3_newformat/train_overlapping.json' # dev_files: - # dev_file_1: '/home/jeremy/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3_newformat/val_overlapping.json' -# test_file: '/home/jeremy/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3_newformat/test.json' + # dev_file_1: 'covidfaq/data/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3_newformat/val_overlapping.json' +# test_file: 'covidfaq/data/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3_newformat/test.json' # keep_ood: false ckpt_to_resume: covidfaq/bert_en_model/output/bert_model.ckpt -outlier_model_pickle: covidfaq/bert_en_model/output/sklearn_outlier_model.pkl + +outlier: + ood_filename: 'covidfaq/bert_en_model/en_ood_model.pkl' + model_name: 'local_outlier_factor' + training_data_files: + - covidfaq/data/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3/train.json + - covidfaq/data/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3/validation.json + - covidfaq/data/covidfaq_data/crowdsource/20200522_quebec_faq_en_cleaned_collection4/test.json accumulate_grad_batches: 1 batch_size: 64 diff --git a/covidfaq/bert_en_model/output/sklearn_outlier_model.pkl b/covidfaq/bert_en_model/output/sklearn_outlier_model.pkl deleted file mode 100644 index ba463de..0000000 --- a/covidfaq/bert_en_model/output/sklearn_outlier_model.pkl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c2c6973bcde7ef51e498ac828a22812e7a2b44b3a505269e6ec7f0ff7fa941e7 -size 3088588 diff --git a/covidfaq/bert_fr_model/config.yaml b/covidfaq/bert_fr_model/config.yaml index 13dd55d..5748872 100644 --- a/covidfaq/bert_fr_model/config.yaml +++ b/covidfaq/bert_fr_model/config.yaml @@ -4,8 +4,15 @@ dev_files: test_file: '../../covidfaq_data/crowdsource/20200416_quebec_faq_fr_cleaned_collection1_2_3_newformat/validation.json' keep_ood: false -ckpt_to_resume: covidfaq/bert_fr_model/output/bert_model.ckpt -outlier_model_pickle: covidfaq/bert_fr_model/output/sklearn_outlier_model.pkl +ckpt_to_resume: 
covidfaq/bert_fr_model/output/bert_model.ckpt + +outlier: + ood_filename: 'covidfaq/bert_fr_model/fr_ood_model.pkl' + model_name: 'local_outlier_factor' + training_data_files: + - covidfaq/data/covidfaq_data/crowdsource/20200416_quebec_faq_fr_cleaned_collection1_2_3/train.json + - covidfaq/data/covidfaq_data/crowdsource/20200416_quebec_faq_fr_cleaned_collection1_2_3/validation.json + - covidfaq/data/covidfaq_data/crowdsource/20200522_quebec_faq_fr_cleaned_collection4/test.json accumulate_grad_batches: 1 batch_size: 64 diff --git a/covidfaq/bert_fr_model/output/sklearn_outlier_model.pkl b/covidfaq/bert_fr_model/output/sklearn_outlier_model.pkl deleted file mode 100644 index 63ef361..0000000 --- a/covidfaq/bert_fr_model/output/sklearn_outlier_model.pkl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:09fe4b26c9543198e35cad6198ec7a493bc1ffed407c867a8b7ab2c46c6c2c61 -size 59498475 diff --git a/covidfaq/evaluating/evaluator.py b/covidfaq/evaluating/evaluator.py index feca826..4356ffe 100644 --- a/covidfaq/evaluating/evaluator.py +++ b/covidfaq/evaluating/evaluator.py @@ -5,9 +5,9 @@ import logging import time -from bert_reranker.data.data_loader import get_passages_by_source from tqdm import tqdm +from bert_reranker.data.data_loader import get_passages_by_source from covidfaq.evaluating.model.cheating_model import CheatingModel from covidfaq.evaluating.model.embedding_based_reranker import EmbeddingBasedReRanker from covidfaq.evaluating.model.embedding_based_reranker_plus_ood_detector import ( @@ -125,7 +125,9 @@ def main(): elif args.model_type == "embedding_based_reranker_plus_ood": if args.config is None: raise ValueError("model embedding_based_reranker requires --config") - model_to_evaluate = EmbeddingBasedReRankerPlusOODDetector(args.config) + model_to_evaluate = EmbeddingBasedReRankerPlusOODDetector( + args.config, lang="en" + ) elif args.model_type == "cheating_model": _, _, passage_id2index = get_passages_by_source(test_data) model_to_evaluate = CheatingModel(test_data, passage_id2index) diff --git a/covidfaq/evaluating/model/bert_plus_ood.py b/covidfaq/evaluating/model/bert_plus_ood.py index d8ef08c..64c37af 100644 --- a/covidfaq/evaluating/model/bert_plus_ood.py +++ b/covidfaq/evaluating/model/bert_plus_ood.py @@ -1,8 +1,8 @@ import json -from bert_reranker.data.data_loader import get_passages_by_source from structlog import get_logger +from bert_reranker.data.data_loader import get_passages_by_source from covidfaq.evaluating.model.embedding_based_reranker_plus_ood_detector import ( EmbeddingBasedReRankerPlusOODDetector, ) @@ -16,10 +16,10 @@ class BertPlusOODEn: class __BertPlusOODEn: def __init__(self): self.model = EmbeddingBasedReRankerPlusOODDetector( - "covidfaq/bert_en_model/config.yaml" + "covidfaq/bert_en_model/config.yaml", lang="en" ) - test_data = get_latest_scrape(lang="en") + test_data, _ = get_latest_scrape(lang="en") ( self.source2passages, @@ -28,7 +28,7 @@ def __init__(self): ) = get_passages_by_source(test_data, keep_ood=False) self.model.collect_answers(self.source2passages) - self.get_answer("what are the symptoms of covid") + self.get_answer("what are the symptoms of covid-19") def get_answer(self, question): idx = self.model.answer_question(question, SOURCE) @@ -69,10 +69,10 @@ class BertPlusOODFr: class __BertPlusOODFr: def __init__(self): self.model = EmbeddingBasedReRankerPlusOODDetector( - "covidfaq/bert_fr_model/config.yaml" + "covidfaq/bert_fr_model/config.yaml", lang="fr" ) - test_data = get_latest_scrape(lang="fr") + 
test_data, _ = get_latest_scrape(lang="fr") ( self.source2passages, @@ -120,9 +120,9 @@ def __getattr__(self, name): def get_latest_scrape(lang="en"): - latest_scrape = "covidfaq/scrape/source_" + lang + "_faq_passages.json" + latest_scrape_fname = "covidfaq/scrape/source_" + lang + "_faq_passages.json" - with open(latest_scrape) as in_stream: + with open(latest_scrape_fname) as in_stream: test_data = json.load(in_stream) - return test_data + return test_data, latest_scrape_fname diff --git a/covidfaq/evaluating/model/embedding_based_reranker_plus_ood_detector.py b/covidfaq/evaluating/model/embedding_based_reranker_plus_ood_detector.py index 12119f2..36887d8 100644 --- a/covidfaq/evaluating/model/embedding_based_reranker_plus_ood_detector.py +++ b/covidfaq/evaluating/model/embedding_based_reranker_plus_ood_detector.py @@ -1,23 +1,84 @@ +import os import pickle import numpy as np import yaml +from structlog import get_logger from yaml import load from covidfaq.evaluating.model.embedding_based_reranker import EmbeddingBasedReRanker +log = get_logger() + + +def fit_OOD_detector(ret_trainee, hyper_params, lang="en", faq_json_file=None): + """ + Prepare the best possible dataset for fitting the OOD model + and fit on that dataset. + The dataset consists of all crowdsourced questions that still + exist in the new FAQ scrape. + Here, we extract all the crowdsourced questions that align with the + latest scrape, compute their embeddings using BERT + and then fit the OOD model on the result. + """ + from covidfaq.evaluating.model.bert_plus_ood import get_latest_scrape + from bert_reranker.data.predict import generate_embeddings + from bert_reranker.models.sklearn_outliers_model import fit_sklearn_model + from bert_reranker.scripts.filter_user_questions import filter_user_questions + + if not faq_json_file: + faq_data, faq_json_file = get_latest_scrape(lang=lang) + + all_question_embs = [] + faq_questions_set = set( + [passage["reference"]["section_headers"][0] for passage in faq_data["passages"]] + ) + + # get the crowdsourced questions that align with the new scrape + for user_question_file in hyper_params["outlier"]["training_data_files"]: + user_json_data = filter_user_questions(user_question_file, faq_questions_set) + + embeddings_dict = generate_embeddings( + ret_trainee, json_data=user_json_data, embed_passages=False + ) + all_question_embs.extend(embeddings_dict["question_embs"]) + + # get the questions directly from the scrape + embeddings_dict = generate_embeddings( + ret_trainee, json_data=faq_data, embed_passages=True + ) + all_question_embs.extend(embeddings_dict["passage_header_embs"]) + + # Fit the new OOD model on all the questions + clf = fit_sklearn_model( + all_question_embs, + model_name=hyper_params["outlier"]["model_name"], + output_filename="covidfaq/bert_" + lang + "_model/" + lang + "_ood_model.pkl", + n_neighbors=4, + ) + return clf + class EmbeddingBasedReRankerPlusOODDetector(EmbeddingBasedReRanker): - def __init__(self, config): + def __init__(self, config, lang): super(EmbeddingBasedReRankerPlusOODDetector, self).__init__(config) with open(config, "r") as stream: hyper_params = load(stream, Loader=yaml.FullLoader) - outlier_model_pickle = hyper_params["outlier_model_pickle"] - with open(outlier_model_pickle, "rb") as file: - outlier_detector_model = pickle.load(file) - # predictor = PredictorWithOutlierDetector(self.ret_trainee, sklearn_model) - self.outlier_detector_model = outlier_detector_model + self.lang = lang + + # If a model exists, load it, otherwise fit it on 
the new scrape + ood_filename = hyper_params["outlier"]["ood_filename"] + if os.path.isfile(ood_filename): + log.info("Loading pretrained sklearn OOD model, not fitting on newest data") + with open(ood_filename, "rb") as file: + outlier_detector_model = pickle.load(file) + self.outlier_detector_model = outlier_detector_model + else: + log.info("Fitting the sklearn OOD model on the latest data...") + self.outlier_detector_model = fit_OOD_detector( + self.ret_trainee, hyper_params, self.lang + ) def collect_answers(self, source2passages): super(EmbeddingBasedReRankerPlusOODDetector, self).collect_answers( diff --git a/covidfaq/main.py b/covidfaq/main.py index 259e317..1c3c34d 100644 --- a/covidfaq/main.py +++ b/covidfaq/main.py @@ -4,7 +4,7 @@ from covidfaq import config, routers from covidfaq.evaluating.model.bert_plus_ood import BertPlusOODEn, BertPlusOODFr -from covidfaq.scrape.scrape import load_latest_source_data +from covidfaq.scrape.scrape import load_latest_source_data, download_OOD_model app = FastAPI() app.include_router(routers.health.router) @@ -23,5 +23,6 @@ def on_startup(): load_latest_source_data() + download_OOD_model() BertPlusOODEn() BertPlusOODFr() diff --git a/covidfaq/scrape/scrape.py b/covidfaq/scrape/scrape.py index e41050a..213d07c 100644 --- a/covidfaq/scrape/scrape.py +++ b/covidfaq/scrape/scrape.py @@ -7,6 +7,7 @@ from datetime import datetime from unicodedata import normalize from urllib.parse import urljoin +from zipfile import ZipFile import boto3 import bs4 @@ -47,13 +48,19 @@ def load_latest_source_data(): client = boto3.client("s3") objs = client.list_objects_v2(Bucket=BUCKET_NAME)["Contents"] + last_added_en = [ - obj["Key"] for obj in sorted(objs, key=get_last_modified) if "en" in obj["Key"] - ][0] + obj["Key"] + for obj in sorted(objs, key=get_last_modified) + if "source_en_faq_passages" in obj["Key"] + ][-1] last_added_fr = [ - obj["Key"] for obj in sorted(objs, key=get_last_modified) if "fr" in obj["Key"] - ][0] + obj["Key"] + for obj in sorted(objs, key=get_last_modified) + if "source_fr_faq_passages" in obj["Key"] + ][-1] + log.info("Downloading latest scrape") s3 = boto3.resource("s3") s3.Bucket(BUCKET_NAME).download_file( last_added_en, "covidfaq/scrape/source_en_faq_passages.json" @@ -62,6 +69,43 @@ def load_latest_source_data(): s3.Bucket(BUCKET_NAME).download_file( last_added_fr, "covidfaq/scrape/source_fr_faq_passages.json" ) + log.info("data downloaded") + + +def download_crowdsourced_data(): + BUCKET_NAME = os.environ.get("BUCKET_NAME") + # Download the covidfaq_data folder + log.info("Downloading crowdsourced questions from s3") + data_dir = "covidfaq/data" + file_name = f"{data_dir}/covidfaq_data.zip" + s3 = boto3.resource("s3") + s3.Bucket(BUCKET_NAME).download_file( + "covidfaq_data.zip", file_name, + ) + log.info("extracting data") + with ZipFile(file_name, "r") as zip: + zip.extractall(path=data_dir) + log.info("data extracted") + + +def download_OOD_model(): + BUCKET_NAME = os.environ.get("BUCKET_NAME") + # Download the OOD model + log.info("Downloading OOD models from s3") + s3 = boto3.resource("s3") + + file_name_en = "covidfaq/bert_en_model/en_ood_model.pkl" + file_name_fr = "covidfaq/bert_fr_model/fr_ood_model.pkl" + + s3.Bucket(BUCKET_NAME).download_file( + "en_ood_model.pkl", file_name_en, + ) + + s3.Bucket(BUCKET_NAME).download_file( + "fr_ood_model.pkl", file_name_fr, + ) + + log.info("OOD models retrieved from s3") def remove_html_tags(data): diff --git a/poetry.lock b/poetry.lock index 4cf21af..939a11c 100644 --- a/poetry.lock 
+++ b/poetry.lock @@ -286,6 +286,7 @@ version = "0.15.2" [[package]] category = "main" + description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" name = "fastapi" optional = false diff --git a/scripts/fetch_model.py b/scripts/fetch_model.py deleted file mode 100644 index 7d72f27..0000000 --- a/scripts/fetch_model.py +++ /dev/null @@ -1,34 +0,0 @@ -import os -from zipfile import ZipFile - -import boto3 -import structlog - -log = structlog.get_logger() - - -def fetch(): - rerank_dir = "covidfaq/rerank" - model_name = f"{rerank_dir}/model.zip" - - log.info("downloading model from s3") - s3 = boto3.client("s3") - s3.download_file( - "coviddata.dialoguecorp.com", - "mirko/bert_rerank_model__for_testing.zip", - model_name, - ) - log.info("model downloaded") - - log.info("extracting model") - with ZipFile(model_name, "r") as zip: - zip.extractall(path=rerank_dir) - log.info("model extracted") - - log.info("cleanup") - os.remove(model_name) - log.info("cleanup done") - - -if __name__ == "__main__": - fetch() diff --git a/scripts/scraper.py b/scripts/scraper.py index d31354d..eeb711c 100644 --- a/scripts/scraper.py +++ b/scripts/scraper.py @@ -1,6 +1,37 @@ -from covidfaq.k8s import rollout_restart +import os + +import boto3 +import structlog + from covidfaq.scrape import scrape +from covidfaq.k8s import rollout_restart +from covidfaq.evaluating.model.bert_plus_ood import BertPlusOODEn, BertPlusOODFr + +log = structlog.get_logger() + + +def upload_OOD_to_s3(): + client = boto3.client("s3") + BUCKET_NAME = os.environ.get("BUCKET_NAME") + file_to_upload_en = "covidfaq/bert_en_model/en_ood_model.pkl" + file_to_upload_fr = "covidfaq/bert_fr_model/fr_ood_model.pkl" + client.upload_file(file_to_upload_en, BUCKET_NAME, "en_ood_model.pkl") + client.upload_file(file_to_upload_fr, BUCKET_NAME, "fr_ood_model.pkl") + log.info("OOD model uploaded to s3 bucket") + + +def instantiate_OOD(): + # instantiating the model will train and save the OOD detector + scrape.load_latest_source_data() + scrape.download_crowdsourced_data() + BertPlusOODEn() + BertPlusOODFr() + upload_OOD_to_s3() + if __name__ == "__main__": + # Scrape the quebec sites, upload the results to aws scrape.run("covidfaq/scrape/quebec-sites.yaml", "covidfaq/scrape/") + # This will train the OOD detector + instantiate_OOD() rollout_restart()
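
For context on what the new `outlier` config section ultimately produces: with `model_name` set to `local_outlier_factor`, `fit_OOD_detector` collects question embeddings (crowdsourced questions filtered against the latest scrape, plus the scraped FAQ headers), hands them to `fit_sklearn_model` from `bert_reranker`, and writes the fitted detector to the `ood_filename` path that `download_OOD_model` and `upload_OOD_to_s3` move to and from S3. The internals of `fit_sklearn_model` are not part of this patch; the sketch below only illustrates the idea under the assumption that it wraps a scikit-learn LocalOutlierFactor in novelty mode, and `fit_local_outlier_factor` is a hypothetical stand-in rather than the library's actual function.

    import pickle

    import numpy as np
    from sklearn.neighbors import LocalOutlierFactor


    def fit_local_outlier_factor(question_embs, output_filename, n_neighbors=4):
        """Sketch (assumption): fit a novelty-detection LOF on in-domain
        question embeddings and pickle it, roughly what fit_sklearn_model
        is expected to do for model_name='local_outlier_factor'."""
        embeddings = np.asarray(question_embs)
        # novelty=True enables predict() on unseen questions at serving time.
        clf = LocalOutlierFactor(n_neighbors=n_neighbors, novelty=True)
        clf.fit(embeddings)
        with open(output_filename, "wb") as f:
            pickle.dump(clf, f)
        return clf

At serving time, loading such a pickle and calling `clf.predict(question_emb.reshape(1, -1))` returns 1 for an in-domain question and -1 for an outlier, which is the kind of signal the loaded `outlier_detector_model` can use to reject out-of-scope questions.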