Skip to content
This repository has been archived by the owner on Feb 19, 2022. It is now read-only.

Commit

Permalink
feat: fit sklearn model on new data if it's not provided
Browse files Browse the repository at this point in the history
  • Loading branch information
jerpint authored Jun 23, 2020
1 parent d83080c commit 2d19950
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 69 deletions.
15 changes: 11 additions & 4 deletions covidfaq/bert_en_model/config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
# train_file: '/home/jeremy/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3_newformat/train_overlapping.json'
# train_file: 'covidfaq/data/covidfaq_data/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3_newformat/train_overlapping.json'
# dev_files:
# dev_file_1: '/home/jeremy/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3_newformat/val_overlapping.json'
# test_file: '/home/jeremy/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3_newformat/test.json'
# dev_file_1: 'covidfaq/data/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3_newformat/val_overlapping.json'
# test_file: 'covidfaq/data/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3_newformat/test.json'
# keep_ood: false

ckpt_to_resume: covidfaq/bert_en_model/output/bert_model.ckpt
outlier_model_pickle: covidfaq/bert_en_model/output/sklearn_outlier_model.pkl

outlier:
ood_filename: 'covidfaq/bert_en_model/en_ood_model.pkl'
model_name: 'local_outlier_factor'
training_data_files:
- covidfaq/data/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3/train.json
- covidfaq/data/covidfaq_data/crowdsource/20200416_quebec_faq_en_cleaned_collection1_2_3/validation.json
- covidfaq/data/covidfaq_data/crowdsource/20200522_quebec_faq_en_cleaned_collection4/test.json

accumulate_grad_batches: 1
batch_size: 64
Expand Down
3 changes: 0 additions & 3 deletions covidfaq/bert_en_model/output/sklearn_outlier_model.pkl

This file was deleted.

11 changes: 9 additions & 2 deletions covidfaq/bert_fr_model/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@ dev_files:
test_file: '../../covidfaq_data/crowdsource/20200416_quebec_faq_fr_cleaned_collection1_2_3_newformat/validation.json'
keep_ood: false

ckpt_to_resume: covidfaq/bert_fr_model/output/bert_model.ckpt
outlier_model_pickle: covidfaq/bert_fr_model/output/sklearn_outlier_model.pkl
ckpt_to_resume: covidfaq/bert_fr_model/output/bert_model.ckpt

outlier:
ood_filename: 'covidfaq/bert_fr_model/fr_ood_model.pkl'
model_name: 'local_outlier_factor'
training_data_files:
- covidfaq/data/covidfaq_data/crowdsource/20200416_quebec_faq_fr_cleaned_collection1_2_3/train.json
- covidfaq/data/covidfaq_data/crowdsource/20200416_quebec_faq_fr_cleaned_collection1_2_3/validation.json
- covidfaq/data/covidfaq_data/crowdsource/20200522_quebec_faq_fr_cleaned_collection4/test.json

accumulate_grad_batches: 1
batch_size: 64
Expand Down
3 changes: 0 additions & 3 deletions covidfaq/bert_fr_model/output/sklearn_outlier_model.pkl

This file was deleted.

6 changes: 4 additions & 2 deletions covidfaq/evaluating/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import logging
import time

from bert_reranker.data.data_loader import get_passages_by_source
from tqdm import tqdm

from bert_reranker.data.data_loader import get_passages_by_source
from covidfaq.evaluating.model.cheating_model import CheatingModel
from covidfaq.evaluating.model.embedding_based_reranker import EmbeddingBasedReRanker
from covidfaq.evaluating.model.embedding_based_reranker_plus_ood_detector import (
Expand Down Expand Up @@ -125,7 +125,9 @@ def main():
elif args.model_type == "embedding_based_reranker_plus_ood":
if args.config is None:
raise ValueError("model embedding_based_reranker requires --config")
model_to_evaluate = EmbeddingBasedReRankerPlusOODDetector(args.config)
model_to_evaluate = EmbeddingBasedReRankerPlusOODDetector(
args.config, lang="en"
)
elif args.model_type == "cheating_model":
_, _, passage_id2index = get_passages_by_source(test_data)
model_to_evaluate = CheatingModel(test_data, passage_id2index)
Expand Down
18 changes: 9 additions & 9 deletions covidfaq/evaluating/model/bert_plus_ood.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json

from bert_reranker.data.data_loader import get_passages_by_source
from structlog import get_logger

from bert_reranker.data.data_loader import get_passages_by_source
from covidfaq.evaluating.model.embedding_based_reranker_plus_ood_detector import (
EmbeddingBasedReRankerPlusOODDetector,
)
Expand All @@ -16,10 +16,10 @@ class BertPlusOODEn:
class __BertPlusOODEn:
def __init__(self):
self.model = EmbeddingBasedReRankerPlusOODDetector(
"covidfaq/bert_en_model/config.yaml"
"covidfaq/bert_en_model/config.yaml", lang="en"
)

test_data = get_latest_scrape(lang="en")
test_data, _ = get_latest_scrape(lang="en")

(
self.source2passages,
Expand All @@ -28,7 +28,7 @@ def __init__(self):
) = get_passages_by_source(test_data, keep_ood=False)
self.model.collect_answers(self.source2passages)

self.get_answer("what are the symptoms of covid")
self.get_answer("what are the symptoms of covid-19")

def get_answer(self, question):
idx = self.model.answer_question(question, SOURCE)
Expand Down Expand Up @@ -69,10 +69,10 @@ class BertPlusOODFr:
class __BertPlusOODFr:
def __init__(self):
self.model = EmbeddingBasedReRankerPlusOODDetector(
"covidfaq/bert_fr_model/config.yaml"
"covidfaq/bert_fr_model/config.yaml", lang="fr"
)

test_data = get_latest_scrape(lang="fr")
test_data, _ = get_latest_scrape(lang="fr")

(
self.source2passages,
Expand Down Expand Up @@ -120,9 +120,9 @@ def __getattr__(self, name):

def get_latest_scrape(lang="en"):

latest_scrape = "covidfaq/scrape/source_" + lang + "_faq_passages.json"
latest_scrape_fname = "covidfaq/scrape/source_" + lang + "_faq_passages.json"

with open(latest_scrape) as in_stream:
with open(latest_scrape_fname) as in_stream:
test_data = json.load(in_stream)

return test_data
return test_data, latest_scrape_fname
Original file line number Diff line number Diff line change
@@ -1,23 +1,84 @@
import os
import pickle

import numpy as np
import yaml
from structlog import get_logger
from yaml import load

from covidfaq.evaluating.model.embedding_based_reranker import EmbeddingBasedReRanker

log = get_logger()


def fit_OOD_detector(ret_trainee, hyper_params, lang="en", faq_json_file=None):
    """
    Prepare the best possible dataset for fitting the OOD model
    and fit on that dataset.

    The dataset consists of all crowdsourced questions that still
    exist in the new FAQ scrape, plus the FAQ questions (section headers)
    taken directly from the scrape. Embeddings for both are computed with
    ``generate_embeddings`` and the OOD model is fitted on the result.

    Args:
        ret_trainee: model object handed to ``generate_embeddings`` to
            compute the embeddings.
        hyper_params: parsed config; reads
            ``hyper_params["outlier"]["training_data_files"]`` and
            ``hyper_params["outlier"]["model_name"]``.
        lang: language code used to pick the latest scrape and to build the
            output filename of the fitted model.
        faq_json_file: optional path to a FAQ scrape JSON file. When omitted,
            the latest scrape for ``lang`` is used.

    Returns:
        Whatever ``fit_sklearn_model`` returns (the fitted outlier model).
    """
    import json

    # Imported here rather than at module level: bert_plus_ood itself imports
    # from this module, so a top-level import would be circular.
    from covidfaq.evaluating.model.bert_plus_ood import get_latest_scrape
    from bert_reranker.data.predict import generate_embeddings
    from bert_reranker.models.sklearn_outliers_model import fit_sklearn_model
    from bert_reranker.scripts.filter_user_questions import filter_user_questions

    if faq_json_file:
        # An explicit scrape file was given: load it. Previously faq_data was
        # left undefined on this path, which raised NameError below.
        with open(faq_json_file) as in_stream:
            faq_data = json.load(in_stream)
    else:
        faq_data, _ = get_latest_scrape(lang=lang)

    all_question_embs = []
    # The first section header of each passage is the FAQ question itself.
    faq_questions_set = {
        passage["reference"]["section_headers"][0]
        for passage in faq_data["passages"]
    }

    # get the crowdsourced questions that align with the new scrape
    for user_question_file in hyper_params["outlier"]["training_data_files"]:
        user_json_data = filter_user_questions(user_question_file, faq_questions_set)

        embeddings_dict = generate_embeddings(
            ret_trainee, json_data=user_json_data, embed_passages=False
        )
        all_question_embs.extend(embeddings_dict["question_embs"])

    # get the questions directly from the scrape
    embeddings_dict = generate_embeddings(
        ret_trainee, json_data=faq_data, embed_passages=True
    )
    all_question_embs.extend(embeddings_dict["passage_header_embs"])

    # Fit the new OOD model on all the questions
    # NOTE(review): the output path is built from `lang` rather than read from
    # hyper_params["outlier"]["ood_filename"]; keep the two in sync.
    clf = fit_sklearn_model(
        all_question_embs,
        model_name=hyper_params["outlier"]["model_name"],
        output_filename="covidfaq/bert_" + lang + "_model/" + lang + "_ood_model.pkl",
        n_neighbors=4,
    )
    return clf


class EmbeddingBasedReRankerPlusOODDetector(EmbeddingBasedReRanker):
def __init__(self, config):
def __init__(self, config, lang):
super(EmbeddingBasedReRankerPlusOODDetector, self).__init__(config)
with open(config, "r") as stream:
hyper_params = load(stream, Loader=yaml.FullLoader)
outlier_model_pickle = hyper_params["outlier_model_pickle"]

with open(outlier_model_pickle, "rb") as file:
outlier_detector_model = pickle.load(file)
# predictor = PredictorWithOutlierDetector(self.ret_trainee, sklearn_model)
self.outlier_detector_model = outlier_detector_model
self.lang = lang

# If a model exists, load it, otherwise fit it on the new scrape
ood_filename = hyper_params["outlier"]["ood_filename"]
if os.path.isfile(ood_filename):
log.info("Loading pretrained sklearn OOD model, not fitting on newest data")
with open(ood_filename, "rb") as file:
outlier_detector_model = pickle.load(file)
self.outlier_detector_model = outlier_detector_model
else:
log.info("Fitting the sklearn OOD model on the latest data...")
self.outlier_detector_model = fit_OOD_detector(
self.ret_trainee, hyper_params, self.lang
)

def collect_answers(self, source2passages):
super(EmbeddingBasedReRankerPlusOODDetector, self).collect_answers(
Expand Down
3 changes: 2 additions & 1 deletion covidfaq/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from covidfaq import config, routers
from covidfaq.evaluating.model.bert_plus_ood import BertPlusOODEn, BertPlusOODFr
from covidfaq.scrape.scrape import load_latest_source_data
from covidfaq.scrape.scrape import load_latest_source_data, download_OOD_model

app = FastAPI()
app.include_router(routers.health.router)
Expand All @@ -23,5 +23,6 @@ def on_startup():

load_latest_source_data()

download_OOD_model()
BertPlusOODEn()
BertPlusOODFr()
52 changes: 48 additions & 4 deletions covidfaq/scrape/scrape.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from datetime import datetime
from unicodedata import normalize
from urllib.parse import urljoin
from zipfile import ZipFile

import boto3
import bs4
Expand Down Expand Up @@ -47,13 +48,19 @@ def load_latest_source_data():

client = boto3.client("s3")
objs = client.list_objects_v2(Bucket=BUCKET_NAME)["Contents"]

last_added_en = [
obj["Key"] for obj in sorted(objs, key=get_last_modified) if "en" in obj["Key"]
][0]
obj["Key"]
for obj in sorted(objs, key=get_last_modified)
if "source_en_faq_passages" in obj["Key"]
][-1]
last_added_fr = [
obj["Key"] for obj in sorted(objs, key=get_last_modified) if "fr" in obj["Key"]
][0]
obj["Key"]
for obj in sorted(objs, key=get_last_modified)
if "source_fr_faq_passages" in obj["Key"]
][-1]

log.info("Downloading latest scrape")
s3 = boto3.resource("s3")
s3.Bucket(BUCKET_NAME).download_file(
last_added_en, "covidfaq/scrape/source_en_faq_passages.json"
Expand All @@ -62,6 +69,43 @@ def load_latest_source_data():
s3.Bucket(BUCKET_NAME).download_file(
last_added_fr, "covidfaq/scrape/source_fr_faq_passages.json"
)
log.info("data downloaded")


def download_crowdsourced_data():
    """Download the crowdsourced-questions archive from S3 and unpack it.

    Reads the bucket name from the ``BUCKET_NAME`` environment variable,
    downloads ``covidfaq_data.zip`` into ``covidfaq/data`` and extracts the
    archive into that same directory.
    """
    BUCKET_NAME = os.environ.get("BUCKET_NAME")
    # Download the covidfaq_data folder
    log.info("Downloading crowdsourced questions from s3")
    data_dir = "covidfaq/data"
    file_name = f"{data_dir}/covidfaq_data.zip"
    # The download writes into data_dir, so make sure it exists first.
    os.makedirs(data_dir, exist_ok=True)
    s3 = boto3.resource("s3")
    s3.Bucket(BUCKET_NAME).download_file(
        "covidfaq_data.zip", file_name,
    )
    log.info("extracting data")
    # Named `archive` rather than `zip` to avoid shadowing the builtin.
    with ZipFile(file_name, "r") as archive:
        archive.extractall(path=data_dir)
    log.info("data extracted")


def download_OOD_model():
    """Fetch the English and French OOD model pickles from the S3 bucket.

    The bucket name comes from the ``BUCKET_NAME`` environment variable; each
    pickle is written next to its language's BERT model directory.
    """
    bucket_name = os.environ.get("BUCKET_NAME")
    log.info("Downloading OOD models from s3")
    s3_resource = boto3.resource("s3")

    # (S3 object key, local destination) pairs, English first then French.
    downloads = (
        ("en_ood_model.pkl", "covidfaq/bert_en_model/en_ood_model.pkl"),
        ("fr_ood_model.pkl", "covidfaq/bert_fr_model/fr_ood_model.pkl"),
    )
    for object_key, destination in downloads:
        s3_resource.Bucket(bucket_name).download_file(object_key, destination)

    log.info("OOD models retrieved from s3")


def remove_html_tags(data):
Expand Down
1 change: 1 addition & 0 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 0 additions & 34 deletions scripts/fetch_model.py

This file was deleted.

33 changes: 32 additions & 1 deletion scripts/scraper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
from covidfaq.k8s import rollout_restart
import os

import boto3
import structlog

from covidfaq.scrape import scrape
from covidfaq.k8s import rollout_restart
from covidfaq.evaluating.model.bert_plus_ood import BertPlusOODEn, BertPlusOODFr

log = structlog.get_logger()


def upload_OOD_to_s3():
    """Upload the English and French OOD model pickles to the S3 bucket.

    The bucket name comes from the ``BUCKET_NAME`` environment variable.
    """
    s3_client = boto3.client("s3")
    bucket_name = os.environ.get("BUCKET_NAME")

    # (local pickle path, S3 object key) pairs, English first then French.
    uploads = (
        ("covidfaq/bert_en_model/en_ood_model.pkl", "en_ood_model.pkl"),
        ("covidfaq/bert_fr_model/fr_ood_model.pkl", "fr_ood_model.pkl"),
    )
    for local_path, object_key in uploads:
        s3_client.upload_file(local_path, bucket_name, object_key)

    log.info("OOD model uploaded to s3 bucket")


def instantiate_OOD():
    """Build the English and French OOD detectors and upload them to S3.

    Loads the latest scrape and the crowdsourced questions, instantiates both
    BERT+OOD wrappers, then uploads the resulting model pickles.
    """
    # instantiating the model will train and save the OOD detector
    # NOTE(review): fitting appears to happen only when no pickle already
    # exists at the configured ood_filename; confirm for this job.
    scrape.load_latest_source_data()
    scrape.download_crowdsourced_data()
    BertPlusOODEn()
    BertPlusOODFr()
    upload_OOD_to_s3()


if __name__ == "__main__":
    # Scrape the quebec sites, upload the results to aws
    scrape.run("covidfaq/scrape/quebec-sites.yaml", "covidfaq/scrape/")
    # This will train the OOD detector on the fresh scrape and upload it to s3
    instantiate_OOD()
    # Restart the k8s rollout; presumably so the running service picks up the
    # new scrape and OOD models (confirm against covidfaq.k8s).
    rollout_restart()

0 comments on commit 2d19950

Please sign in to comment.