diff --git a/README.md b/README.md index 3bd23df5..64982dd1 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,12 @@ BERTopic is a topic modeling technique that leverages 🤗 transformers and c-TF-IDF to create dense clusters -allowing for easily interpretable topics whilst keeping important words in the topic descriptions. It even supports -visualizations similar to LDAvis! +allowing for easily interpretable topics whilst keeping important words in the topic descriptions. + +BERTopic supports +[**guided**](https://maartengr.github.io/BERTopic/tutorial/guided/guided.html), +(semi-) [**supervised**](https://maartengr.github.io/BERTopic/tutorial/supervised/supervised.html), +and [**dynamic**](https://maartengr.github.io/BERTopic/tutorial/topicsovertime/topicsovertime.html) topic modeling. It even supports visualizations similar to LDAvis! Corresponding medium posts can be found [here](https://towardsdatascience.com/topic-modeling-with-bert-779f7db187e6?source=friends_link&sk=0b5a470c006d1842ad4c8a3057063a99) and [here](https://towardsdatascience.com/interactive-topic-modeling-with-bertopic-1ea55e7d73d8?sk=03c2168e9e74b6bda2a1f3ed953427e4). @@ -66,7 +70,7 @@ from sklearn.datasets import fetch_20newsgroups docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] topic_model = BERTopic() -topics, _ = topic_model.fit_transform(docs) +topics, probs = topic_model.fit_transform(docs) ``` After generating topics, we can access the frequent topics that were generated: @@ -114,7 +118,6 @@ topic_model.visualize_topics() - ## Embedding Models BERTopic supports many embedding models that can be used to embed the documents and words: * Sentence-Transformers @@ -123,29 +126,16 @@ BERTopic supports many embedding models that can be used to embed the documents * Gensim * USE -Click [here](https://maartengr.github.io/BERTopic/tutorial/embeddings/embeddings.html) -for a full overview of all supported embedding models. - -### Sentence-Transformers -You can select any model from sentence-transformers [here](https://www.sbert.net/docs/pretrained_models.html) -and pass it to BERTopic: +[**Sentence-Transformers**]() is typically used as it has shown great results embedding documents +meant for semantic similarity. Simply select any from their documentation +[here](https://www.sbert.net/docs/pretrained_models.html) and pass it to BERTopic: ```python topic_model = BERTopic(embedding_model="paraphrase-MiniLM-L6-v2") ``` -Or select a SentenceTransformer model with your own parameters: - -```python -from sentence_transformers import SentenceTransformer - -sentence_model = SentenceTransformer("paraphrase-MiniLM-L6-v2") -topic_model = BERTopic(embedding_model=sentence_model) -``` - -### Flair -[Flair](https://github.com/flairNLP/flair) allows you to choose almost any embedding model that -is publicly available. Flair can be used as follows: +[**Flair**](https://github.com/flairNLP/flair) allows you to choose almost any 🤗 transformers model. Simply +select any from [here](https://huggingface.co/models) and pass it to BERTopic: ```python from flair.embeddings import TransformerDocumentEmbeddings @@ -154,20 +144,13 @@ roberta = TransformerDocumentEmbeddings('roberta-base') topic_model = BERTopic(embedding_model=roberta) ``` -You can select any 🤗 transformers model [here](https://huggingface.co/models). 
- -**Custom Embeddings** -You can also use previously generated embeddings by passing it to `fit_transform()`: - -```python -topic_model = BERTopic() -topics, _ = topic_model.fit_transform(docs, embeddings) -``` +Click [here](https://maartengr.github.io/BERTopic/tutorial/embeddings/embeddings.html) +for a full overview of all supported embedding models. ## Dynamic Topic Modeling Dynamic topic modeling (DTM) is a collection of techniques aimed at analyzing the evolution of topics -over time. These methods allow you to understand how a topic is represented across different times. -Here, we will be using all of Donald Trump's tweet so see how he talked over certain topics over time: +over time. These methods allow you to understand how a topic is represented over time. +Here, we will be using all of Donald Trump's tweet to see how he talked over certain topics over time: ```python import re @@ -186,7 +169,7 @@ Then, we need to extract the global topic representations by simply creating and ```python topic_model = BERTopic(verbose=True) -topics, _ = topic_model.fit_transform(tweets) +topics, probs = topic_model.fit_transform(tweets) ``` From these topics, we are going to generate the topic representations at each timestamp for each topic. We do this diff --git a/bertopic/__init__.py b/bertopic/__init__.py index d24f0fb6..9786c65a 100644 --- a/bertopic/__init__.py +++ b/bertopic/__init__.py @@ -1,6 +1,6 @@ from bertopic._bertopic import BERTopic -__version__ = "0.8.1" +__version__ = "0.9.0" __all__ = [ "BERTopic", diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index 1d330387..6f54a265 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -43,7 +43,7 @@ class BERTopic: from sklearn.datasets import fetch_20newsgroups docs = fetch_20newsgroups(subset='all')['data'] - topic_model = BERTopic(calculate_probabilities=True) + topic_model = BERTopic() topics, probabilities = topic_model.fit_transform(docs) ``` @@ -72,7 +72,8 @@ def __init__(self, nr_topics: Union[int, str] = None, low_memory: bool = False, calculate_probabilities: bool = False, - embedding_model = None, + seed_topic_list: List[List[str]] = None, + embedding_model=None, umap_model: UMAP = None, hdbscan_model: hdbscan.HDBSCAN = None, vectorizer_model: CountVectorizer = None, @@ -101,13 +102,15 @@ def __init__(self, "auto" to automatically reduce topics that have a similarity of at least 0.9, do not maps all others. low_memory: Sets UMAP low memory to True to make sure less memory is used. - calculate_probabilities: Whether to calculate the topic probabilities. This could - slow down the extraction of topics if you have many - documents (> 100_000). Set this only to True if you - have a low amount of documents or if you do not mind - more computation time. - NOTE: since probabilities are not calculated, you cannot - use the corresponding visualization `visualize_probabilities`. + calculate_probabilities: Whether to calculate the probabilities of all topics + per document instead of the probability of the assigned + topic per document. This could slow down the extraction + of topics if you have many documents (> 100_000). Set this + only to True if you have a low amount of documents or if + you do not mind more computation time. + NOTE: If false you cannot use the corresponding + visualization method `visualize_probabilities`. + seed_topic_list: A list of seed words per topic to converge around verbose: Changes the verbosity of the model, Set to True if you want to track the stages of the model. 
embedding_model: Use a custom embedding model. @@ -133,6 +136,7 @@ def __init__(self, self.low_memory = low_memory self.calculate_probabilities = calculate_probabilities self.verbose = verbose + self.seed_topic_list = seed_topic_list # Embedding model self.language = language if not embedding_model else None @@ -158,8 +162,10 @@ def __init__(self, self.topics = None self.topic_sizes = None self.mapped_topics = None + self.merged_topics = None self.topic_embeddings = None self.topic_sim_matrix = None + self.representative_docs = None if verbose: logger.set_level("DEBUG") @@ -222,10 +228,11 @@ def fit_transform(self, Returns: predictions: Topic predictions for each documents - probabilities: The topic probability distribution which is returned by default. - If `calculate_probabilities` in BERTopic is set to False, then the - probabilities are not calculated to speed up computation and - decrease memory usage. + probabilities: The probability of the assigned topic per document. + If `calculate_probabilities` in BERTopic is set to True, then + it calculates the probabilities of all topics across all documents + instead of only the assigned topic. This, however, slows down + computation and may increase memory usage. Usage: @@ -234,7 +241,7 @@ def fit_transform(self, from sklearn.datasets import fetch_20newsgroups docs = fetch_20newsgroups(subset='all')['data'] - topic_model = BERTopic(calculate_probabilities=True) + topic_model = BERTopic() topics, probs = topic_model.fit_transform(docs) ``` @@ -251,7 +258,7 @@ def fit_transform(self, embeddings = sentence_model.encode(docs, show_progress_bar=True) # Create topic model - topic_model = BERTopic(calculate_probabilities=True) + topic_model = BERTopic() topics, probs = topic_model.fit_transform(docs, embeddings) ``` """ @@ -276,6 +283,8 @@ def fit_transform(self, language=self.language) # Reduce dimensionality with UMAP + if self.seed_topic_list is not None and self.embedding_model is not None: + y, embeddings = self._guided_topic_modeling(embeddings) umap_embeddings = self._reduce_dimensionality(embeddings, y) # Cluster UMAP embeddings with HDBSCAN @@ -291,7 +300,9 @@ def fit_transform(self, # Reduce topics if self.nr_topics: documents = self._reduce_topics(documents) - probabilities = self._map_probabilities(probabilities) + + self._map_representative_docs() + probabilities = self._map_probabilities(probabilities) predictions = documents.Topic.to_list() return predictions, probabilities @@ -321,7 +332,7 @@ def transform(self, docs = fetch_20newsgroups(subset='all')['data'] topic_model = BERTopic().fit(docs) - topics, _ = topic_model.transform(docs) + topics, probs = topic_model.transform(docs) ``` If you want to use your own embeddings: @@ -338,7 +349,7 @@ def transform(self, # Create topic model topic_model = BERTopic().fit(docs, embeddings) - topics, _ = topic_model.transform(docs, embeddings) + topics, probs = topic_model.transform(docs, embeddings) ``` """ check_is_fitted(self) @@ -353,7 +364,7 @@ def transform(self, verbose=self.verbose) umap_embeddings = self.umap_model.transform(embeddings) - predictions, _ = hdbscan.approximate_predict(self.hdbscan_model, umap_embeddings) + predictions, probabilities = hdbscan.approximate_predict(self.hdbscan_model, umap_embeddings) if self.calculate_probabilities: probabilities = hdbscan.membership_vector(self.hdbscan_model, umap_embeddings) @@ -423,7 +434,7 @@ def topics_over_time(self, ```python from bertopic import BERTopic topic_model = BERTopic() - topics, _ = topic_model.fit_transform(docs) + 
topics, probs = topic_model.fit_transform(docs)
        topics_over_time = topic_model.topics_over_time(docs, topics, timestamps, nr_bins=20)
        ```
        """
@@ -539,7 +550,7 @@ def topics_per_class(self,
        ```python
        from bertopic import BERTopic
        topic_model = BERTopic()
-        topics, _ = topic_model.fit_transform(docs)
+        topics, probs = topic_model.fit_transform(docs)
        topics_per_class = topic_model.topics_per_class(docs, topics, classes)
        ```
        """
@@ -768,6 +779,36 @@ def get_topic_freq(self, topic: int = None) -> Union[pd.DataFrame, int]:
            return pd.DataFrame(self.topic_sizes.items(), columns=['Topic', 'Count']).sort_values("Count", ascending=False)

+    def get_representative_docs(self, topic: int = None) -> List[str]:
+        """ Extract representative documents per topic
+
+        Arguments:
+            topic: A specific topic for which you want
+                   the representative documents
+
+        Returns:
+            Representative documents of the chosen topic
+
+        Usage:
+
+        To extract the representative docs of all topics:
+
+        ```python
+        representative_docs = topic_model.get_representative_docs()
+        ```
+
+        To get the representative docs of a single topic:
+
+        ```python
+        representative_docs = topic_model.get_representative_docs(12)
+        ```
+        """
+        check_is_fitted(self)
+        if isinstance(topic, int):
+            return self.representative_docs[topic]
+        else:
+            return self.representative_docs
+
    def reduce_topics(self,
                      docs: List[str],
                      topics: List[int],
@@ -807,7 +848,7 @@ def reduce_topics(self,
        If probabilities were not calculated simply run the function without them:

        ```python
-        new_topics, _= topic_model.reduce_topics(docs, topics, nr_topics=30)
+        new_topics, new_probs = topic_model.reduce_topics(docs, topics, nr_topics=30)
        ```
        """
        check_is_fitted(self)
@@ -1231,7 +1272,7 @@ def load(cls,
        with open(path, 'rb') as file:
            if embedding_model:
                topic_model = joblib.load(file)
-                topic_model.embedding_model = embedding_model
+                topic_model.embedding_model = select_backend(embedding_model)
            else:
                topic_model = joblib.load(file)
            return topic_model
@@ -1337,16 +1378,54 @@ def _cluster_embeddings(self,
        """
        self.hdbscan_model.fit(umap_embeddings)
        documents['Topic'] = self.hdbscan_model.labels_
+        probabilities = self.hdbscan_model.probabilities_
        if self.calculate_probabilities:
            probabilities = hdbscan.all_points_membership_vectors(self.hdbscan_model)
-        else:
-            probabilities = None
        self._update_topic_size(documents)
+        self._save_representative_docs(documents)
        logger.info("Clustered UMAP embeddings with HDBSCAN")

        return documents, probabilities

+    def _guided_topic_modeling(self, embeddings: np.ndarray) -> Tuple[List[int], np.array]:
+        """ Apply Guided Topic Modeling
+
+        We transform the seeded topics to embeddings using the
+        same embedder as used for generating document embeddings.
+
+        Then, we apply cosine similarity between the embeddings
+        and set labels for documents that are more similar to
+        one of the topics than the average document.
+
+        If a document is more similar to the average document
+        than any of the topics, it gets the -1 label and is
+        thereby not included in UMAP.
+
+        Arguments:
+            embeddings: The document embeddings
+
+        Returns:
+            y: The labels for each seeded topic
+            embeddings: Updated embeddings
+        """
+        # Create embeddings from the seeded topics
+        seed_topic_list = [" ".join(seed_topic) for seed_topic in self.seed_topic_list]
+        seed_topic_embeddings = self._extract_embeddings(seed_topic_list, verbose=self.verbose)
+        seed_topic_embeddings = np.vstack([seed_topic_embeddings, embeddings.mean(axis=0)])
+
+        # Label documents that are most similar to one of the seeded topics
+        sim_matrix = cosine_similarity(embeddings, seed_topic_embeddings)
+        y = [np.argmax(sim_matrix[index]) for index in range(sim_matrix.shape[0])]
+        y = [val if val != len(seed_topic_list) else -1 for val in y]
+
+        # Average the document embeddings related to the seeded topics with the
+        # embedding of the seeded topic to force the documents into a cluster
+        for seed_topic in range(len(seed_topic_list)):
+            indices = [index for index, topic in enumerate(y) if topic == seed_topic]
+            embeddings[indices] = np.average([embeddings[indices], seed_topic_embeddings[seed_topic]], weights=[3, 1])
+        return y, embeddings
+
    def _extract_topics(self, documents: pd.DataFrame):
        """ Extract topics from the clusters using a class-based TF-IDF
@@ -1364,6 +1443,68 @@ def _extract_topics(self, documents: pd.DataFrame):
                       for key, values in self.topics.items()}

+    def _save_representative_docs(self, documents: pd.DataFrame):
+        """ Save the most representative docs (3) per topic
+
+        The most representative docs are extracted by taking
+        the exemplars from the HDBSCAN-generated clusters.
+
+        Full instructions can be found here:
+            https://hdbscan.readthedocs.io/en/latest/soft_clustering_explanation.html
+
+        Arguments:
+            documents: Dataframe with documents and their corresponding IDs
+        """
+        # Prepare the condensed tree and leaf clusters beneath a given cluster
+        condensed_tree = self.hdbscan_model.condensed_tree_
+        raw_tree = condensed_tree._raw_tree
+        clusters = condensed_tree._select_clusters()
+        cluster_tree = raw_tree[raw_tree['child_size'] > 1]
+
+        # Find the points with maximum lambda value in each leaf
+        representative_docs = {}
+        for topic in documents['Topic'].unique():
+            if topic != -1:
+                leaves = hdbscan.plots._recurse_leaf_dfs(cluster_tree, clusters[topic])
+
+                result = np.array([])
+                for leaf in leaves:
+                    max_lambda = raw_tree['lambda_val'][raw_tree['parent'] == leaf].max()
+                    points = raw_tree['child'][(raw_tree['parent'] == leaf) & (raw_tree['lambda_val'] == max_lambda)]
+                    result = np.hstack((result, points))
+
+                representative_docs[topic] = list(np.random.choice(result, 3, replace=False).astype(int))
+
+        # Convert indices to documents
+        self.representative_docs = {topic: [documents.iloc[doc_id].Document for doc_id in doc_ids]
+                                    for topic, doc_ids in representative_docs.items()}
+
+    def _map_representative_docs(self):
+        """ Map the representative docs per topic to the correct topics
+
+        If topics were reduced, remove documents from topics that were
+        merged into larger topics as we assume that the documents from
+        larger topics are better representative of the entire merged
+        topic.
+        """
+        representative_docs = self.representative_docs.copy()
+
+        # Remove merged topics, as we assume that the larger topics
+        # they were merged into contain the better
+        # representative documents
+        if self.merged_topics:
+            for topic_to_remove in self.merged_topics:
+                del representative_docs[topic_to_remove]
+
+        # Update the representative documents
+        updated_representative_docs = {}
+        for old_topic, docs in representative_docs.items():
+            new_topic = self.mapped_topics[old_topic]
+            updated_representative_docs[new_topic] = docs
+
+        self.representative_docs = updated_representative_docs
+
    def _create_topic_vectors(self):
        """ Creates embeddings per topics based on their topic representation
@@ -1422,8 +1563,14 @@ def _c_tf_idf(self, documents_per_topic: pd.DataFrame, m: int, fit: bool = True)
        words = self.vectorizer_model.get_feature_names()
        X = self.vectorizer_model.transform(documents)

+        if self.seed_topic_list:
+            seed_topic_list = [seed for seeds in self.seed_topic_list for seed in seeds]
+            multiplier = np.array([1.2 if word in seed_topic_list else 1 for word in words])
+        else:
+            multiplier = None
+
        if fit:
-            self.transformer = ClassTFIDF().fit(X, n_samples=m)
+            self.transformer = ClassTFIDF().fit(X, n_samples=m, multiplier=multiplier)

        c_tf_idf = self.transformer.transform(X)
@@ -1524,9 +1671,14 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame) -> pd.DataFrame:
        Returns:
            documents: Updated dataframe with documents and the reduced number of Topics
        """
+        # Track the mapping of topics
        if not self.mapped_topics:
            self.mapped_topics = {topic: topic for topic in set(self.hdbscan_model.labels_)}

+        # Track which topics were originally merged
+        if not self.merged_topics:
+            self.merged_topics = []
+
        # Create topic similarity matrix
        if self.topic_embeddings is not None:
            similarities = cosine_similarity(np.array(self.topic_embeddings))
@@ -1540,6 +1692,7 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame) -> pd.DataFrame:
            topic_to_merge = self.get_topic_freq().iloc[-1].Topic
            topic_to_merge_into = np.argmax(similarities[topic_to_merge + 1]) - 1
            similarities[:, topic_to_merge + 1] = -1
+            self.merged_topics.append(topic_to_merge)

            # Update Topic labels
            documents.loc[documents.Topic == topic_to_merge, "Topic"] = topic_to_merge_into
@@ -1566,7 +1719,7 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame) -> pd.DataFrame:
        return documents

    def _auto_reduce_topics(self, documents: pd.DataFrame) -> pd.DataFrame:
-        """ Reduce the number of topics as long as it exceeds a minimum similarity of 0.915
+        """ Reduce the number of topics automatically using HDBSCAN

        Arguments:
            documents: Dataframe with documents and their corresponding IDs and Topics

        Returns:
            documents: Updated dataframe with documents and the reduced number of Topics
        """
+        # Track the mapping of topics
        if not self.mapped_topics:
            self.mapped_topics = {topic: topic for topic in set(self.hdbscan_model.labels_)}

+        # Track which topics were originally merged
+        if not self.merged_topics:
+            self.merged_topics = []
+
        unique_topics = sorted(list(documents.Topic.unique()))[1:]
        max_topic = unique_topics[-1]
@@ -1597,6 +1755,13 @@ def _auto_reduce_topics(self, documents: pd.DataFrame) -> pd.DataFrame:
                         if prediction != -1}
        documents.Topic = documents.Topic.map(mapped_topics).fillna(documents.Topic).astype(int)

+        # Track merged topics
+        df = pd.DataFrame({"Topic": mapped_topics.keys(), "Group": mapped_topics.values()})
+        df["Size"] = df["Topic"].map(self.topic_sizes)
+        mask = df.groupby(['Topic'])['Size'].transform('max')
+        df = df[~(df['Size'] == mask)]
+        self.merged_topics = df.Topic.values.tolist()
+
        # Update mapped topics with new clusters
        self.mapped_topics = {og_topic: mapped_topics[topic]
                              if topic in mapped_topics
@@ -1631,15 +1796,15 @@ def _sort_mappings_by_frequency(self, documents: pd.DataFrame) -> pd.DataFrame:
            documents: Updated dataframe with documents and the mapped and re-ordered topic ids
        """
-        self._update_topic_size(documents)
        if not self.mapped_topics:
            self.mapped_topics = {topic: topic for topic in set(self.hdbscan_model.labels_)}

        # Map topics based on frequency
-        sorted_topics = {topic: index - 1 for index, topic
-                         in enumerate(self.topic_sizes.keys())}
+        df = pd.DataFrame(self.topic_sizes.items(), columns=["Old_Topic", "Size"]).sort_values("Size", ascending=False)
+        df = df[df.Old_Topic != -1]
+        sorted_topics = {**{-1: -1}, **dict(zip(df.Old_Topic, range(len(df))))}
        self.mapped_topics = {og_topic: sorted_topics[topic]
                              if topic in sorted_topics
                              else topic
@@ -1659,21 +1824,19 @@ def _map_probabilities(self, probabilities: Union[np.ndarray, None]) -> Union[np
        Arguments:
            probabilities: An array containing probabilities

        Returns:
-            new_probabilities: Updated probabilities
-
+            mapped_probabilities: Updated probabilities
        """
-        if isinstance(probabilities, np.ndarray):
-            new_probabilities = probabilities.copy()
+        # Map array of probabilities (probability for assigned topic per document)
+        if len(probabilities.shape) == 2 and self.get_topic(-1):
+            mapped_probabilities = np.zeros((probabilities.shape[0],
+                                             len(set(self.mapped_topics.values()))-1))
            for from_topic, to_topic in self.mapped_topics.items():
                if to_topic != -1 and from_topic != -1:
-                    new_probabilities[:, to_topic] += new_probabilities[:, from_topic]
-                    new_probabilities[:, from_topic] = 0
+                    mapped_probabilities[:, to_topic] += probabilities[:, from_topic]
+            return mapped_probabilities

-            return new_probabilities.round(3)
-        else:
-            return None
+        return probabilities

    def _preprocess_text(self, documents: np.ndarray) -> List[str]:
        """ Basic preprocessing of text
diff --git a/bertopic/_ctfidf.py b/bertopic/_ctfidf.py
index d55d6d6f..9953ed6b 100644
--- a/bertopic/_ctfidf.py
+++ b/bertopic/_ctfidf.py
@@ -21,7 +21,7 @@ class ClassTFIDF(TfidfTransformer):
    def __init__(self, *args, **kwargs):
        super(ClassTFIDF, self).__init__(*args, **kwargs)

-    def fit(self, X: sp.csr_matrix, n_samples: int):
+    def fit(self, X: sp.csr_matrix, n_samples: int, multiplier: np.ndarray = None):
        """Learn the idf vector (global term weights).

        Arguments:
@@ -38,6 +38,8 @@ def fit(self, X: sp.csr_matrix, n_samples: int):
            df = np.squeeze(np.asarray(X.sum(axis=0)))
            avg_nr_samples = int(X.sum(axis=1).mean())
            idf = np.log(avg_nr_samples / df)
+            if multiplier is not None:
+                idf = idf * multiplier
            self._idf_diag = sp.diags(idf, offsets=0,
                                      shape=(n_features, n_features),
                                      format='csr',
diff --git a/bertopic/plotting/_distribution.py b/bertopic/plotting/_distribution.py
index e0e04ceb..84445053 100644
--- a/bertopic/plotting/_distribution.py
+++ b/bertopic/plotting/_distribution.py
@@ -35,6 +35,9 @@ def visualize_distribution(topic_model,

    """
+    if len(probabilities.shape) != 2:
+        raise ValueError("This visualization cannot be used if you have set `calculate_probabilities` to False "
+                         "as it uses the topic probabilities of all topics. 
") if len(probabilities[probabilities > min_probability]) == 0: raise ValueError("There are no values where `min_probability` is higher than the " "probabilities that were supplied. Lower `min_probability` to prevent this error.") diff --git a/bertopic/plotting/_topics_over_time.py b/bertopic/plotting/_topics_over_time.py index d8718e37..e5b94111 100644 --- a/bertopic/plotting/_topics_over_time.py +++ b/bertopic/plotting/_topics_over_time.py @@ -1,12 +1,14 @@ import pandas as pd from typing import List import plotly.graph_objects as go +from sklearn.preprocessing import normalize def visualize_topics_over_time(topic_model, topics_over_time: pd.DataFrame, top_n_topics: int = None, topics: List[int] = None, + normalize_frequency: bool = False, width: int = 1250, height: int = 450) -> go.Figure: """ Visualize topics over time @@ -17,6 +19,7 @@ def visualize_topics_over_time(topic_model, corresponding topic representation top_n_topics: To visualize the most frequent topics instead of all topics: Select which topics you would like to be visualized + normalize_frequency: Whether to normalize each topic's frequency individually width: The width of the figure. height: The height of the figure. @@ -63,7 +66,11 @@ def visualize_topics_over_time(topic_model, trace_data = data.loc[data.Topic == topic, :] topic_name = trace_data.Name.values[0] words = trace_data.Words.values - fig.add_trace(go.Scatter(x=trace_data.Timestamp, y=trace_data.Frequency, + if normalize_frequency: + y = normalize(trace_data.Frequency.values.reshape(1, -1))[0] + else: + y = trace_data.Frequency + fig.add_trace(go.Scatter(x=trace_data.Timestamp, y=y, mode='lines', marker_color=colors[index % 7], hoverinfo="text", @@ -74,7 +81,7 @@ def visualize_topics_over_time(topic_model, fig.update_xaxes(showgrid=True) fig.update_yaxes(showgrid=True) fig.update_layout( - yaxis_title="Frequency", + yaxis_title="Normalized Frequency" if normalize_frequency else "Frequency", title={ 'text': "Topics over Time", 'y': .95, diff --git a/bertopic/plotting/_topics_per_class.py b/bertopic/plotting/_topics_per_class.py index 71fe0721..d66170a9 100644 --- a/bertopic/plotting/_topics_per_class.py +++ b/bertopic/plotting/_topics_per_class.py @@ -1,12 +1,14 @@ import pandas as pd from typing import List import plotly.graph_objects as go +from sklearn.preprocessing import normalize def visualize_topics_per_class(topic_model, topics_per_class: pd.DataFrame, top_n_topics: int = 10, topics: List[int] = None, + normalize_frequency: bool = False, width: int = 1250, height: int = 900) -> go.Figure: """ Visualize topics per class @@ -17,6 +19,7 @@ def visualize_topics_per_class(topic_model, corresponding topic representation top_n_topics: To visualize the most frequent topics instead of all topics: Select which topics you would like to be visualized + normalize_frequency: Whether to normalize each topic's frequency individually width: The width of the figure. height: The height of the figure. 
@@ -67,8 +70,12 @@ def visualize_topics_per_class(topic_model, trace_data = data.loc[data.Topic == topic, :] topic_name = trace_data.Name.values[0] words = trace_data.Words.values + if normalize_frequency: + x = normalize(trace_data.Frequency.values.reshape(1, -1))[0] + else: + x = trace_data.Frequency fig.add_trace(go.Bar(y=trace_data.Class, - x=trace_data.Frequency, + x=x, visible=visible, marker_color=colors[index % 7], hoverinfo="text", @@ -80,7 +87,7 @@ def visualize_topics_per_class(topic_model, fig.update_xaxes(showgrid=True) fig.update_yaxes(showgrid=True) fig.update_layout( - xaxis_title="Frequency", + xaxis_title="Normalized Frequency" if normalize_frequency else "Frequency", yaxis_title="Class", title={ 'text': "Topics per Class", diff --git a/docs/changelog.md b/docs/changelog.md index 6d39ed0e..8a996909 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,3 +1,52 @@ +## **Version 0.9** +*Release date: 9 August, 2021* + +**Highlights**: + +* Implemented a [**Guided BERTopic**](https://maartengr.github.io/BERTopic/tutorial/guided/guided.html) -> Use seeds to steer the Topic Modeling +* Get the most representative documents per topic: `topic_model.get_representative_docs(topic=1)` + * This allows users to see which documents are good representations of a topic and better understand the topics that were created +* Added `normalize_frequency` parameter to `visualize_topics_per_class` and `visualize_topics_over_time` in order to better compare the relative topic frequencies between topics +* Return flat probabilities as default, only calculate the probabilities of all topics per document if `calculate_probabilities` is True +* Added several FAQs + +**Fixes**: + +* Fix loading pre-trained BERTopic model +* Fix mapping of probabilities +* Fix [#190](https://github.com/MaartenGr/BERTopic/issues/190) + + +**Guided BERTopic**: + +Guided BERTopic works in two ways: + +First, we create embeddings for each seeded topics by joining them and passing them through the document embedder. +These embeddings will be compared with the existing document embeddings through cosine similarity and assigned a label. +If the document is most similar to a seeded topic, then it will get that topic's label. +If it is most similar to the average document embedding, it will get the -1 label. +These labels are then passed through UMAP to create a semi-supervised approach that should nudge the topic creation to the seeded topics. + +Second, we take all words in `seed_topic_list` and assign them a multiplier larger than 1. +Those multipliers will be used to increase the IDF values of the words across all topics thereby increasing +the likelihood that a seeded topic word will appear in a topic. This does, however, also increase the chance of an +irrelevant topic having unrelated words. In practice, this should not be an issue since the IDF value is likely to +remain low regardless of the multiplier. The multiplier is now a fixed value but may change to something more elegant, +like taking the distribution of IDF values and its position into account when defining the multiplier. 
+
+```python
+seed_topic_list = [["company", "billion", "quarter", "shrs", "earnings"],
+                   ["acquisition", "procurement", "merge"],
+                   ["exchange", "currency", "trading", "rate", "euro"],
+                   ["grain", "wheat", "corn"],
+                   ["coffee", "cocoa"],
+                   ["natural", "gas", "oil", "fuel", "products", "petrol"]]
+
+topic_model = BERTopic(seed_topic_list=seed_topic_list)
+topics, probs = topic_model.fit_transform(docs)
+```
+
+
 ## **Version 0.8.1**
 *Release date:  08 June, 2021*

diff --git a/docs/faq.md b/docs/faq.md
index 097caf99..d628a8d6 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -33,10 +33,26 @@ typically require a GPU and using only a CPU can slow down computation time quit
 However, if you do not have access to a GPU, looking into quantization might help.

 ## **I am facing memory issues. Help!**
-To prevent any memory issues, it is advised to set `low_memory` to True. This will result in UMAP being
-a bit slower, but consuming significantly less memory. Moreover, calculating the probabilities of topics
-is quite computationally consuming and might impact memory. Setting `calculate_probabilities` to False
-could similarly help.
+There are several ways to perform computation with large datasets.
+First, you can set `low_memory` to True when instantiating BERTopic.
+This may prevent UMAP from blowing up the memory.
+
+Second, setting `calculate_probabilities` to False when instantiating BERTopic prevents a huge document-topic
+probability matrix from being created. Moreover, HDBSCAN is quite slow when it tries to calculate probabilities on large datasets.
+
+Third, you can set the minimum frequency of words in the CountVectorizer class to reduce the size of the resulting
+sparse c-TF-IDF matrix. You can do this as follows:
+
+```python
+from bertopic import BERTopic
+from sklearn.feature_extraction.text import CountVectorizer
+
+vectorizer_model = CountVectorizer(ngram_range=(1, 2), stop_words="english", min_df=10)
+topic_model = BERTopic(vectorizer_model=vectorizer_model)
+```
+
+The [min_df](https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html)
+parameter is used to indicate the minimum frequency of words. Setting this value larger than 1 can significantly reduce memory.

 If the problem persists, then this could be an issue related to your available memory. The processing of
 millions of documents is quite computationally expensive and sufficient RAM is necessary.
@@ -56,8 +72,29 @@ Third, although this does not happen very often, there simply aren't that many t
 in your documents. You can often see this when you have many `-1` topics, which is actually
 not a topic but a category of outliers.

-## **Why are the probabilities not calculated?**
-Although it is possible to calculate the probabilities, the process of doing so is quite computationally
+## **I have too many topics, how do I decrease them?**
+If you have a large dataset, then it is possible to generate thousands of topics. Especially with large
+datasets, there is a good chance they actually contain many small topics. In practice, you might want
+a few hundred topics at most in order to interpret them nicely.
+
+There are a few ways of decreasing the number of generated topics:
+
+First, we can set the `min_topic_size` in the BERTopic initialization much higher (e.g., 300)
+to make sure that those small clusters will not be generated. This is an HDBSCAN parameter that
+specifies the minimum number of documents needed in a cluster. More documents in a cluster
+means fewer topics will be generated.
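+For example, assuming `docs` holds your documents, a minimal sketch could look as follows (the value of 300 is
+only an illustration and should be tuned to your dataset):
+
+```python
+from bertopic import BERTopic
+
+# Only form topics (clusters) that contain at least 300 documents
+topic_model = BERTopic(min_topic_size=300)
+topics, probs = topic_model.fit_transform(docs)
+```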
+
+Second, you can create a custom UMAP model and set `n_neighbors` much higher than the default 15 (e.g., 200).
+This also prevents those micro clusters from being generated as it needs quite a number of neighboring
+documents to create a cluster.
+
+Third, we can set `nr_topics` to a value that seems logical to the user. Do note that topics are forced
+to merge together, which might result in a lower quality of topics. In practice, I would advise using
+`nr_topics="auto"` as that will merge topics together that are very similar. Dissimilar topics will
+therefore remain separated.
+
+## **How do I calculate the probabilities of all topics in a document?**
+Although it is possible to calculate all the probabilities, the process of doing so is quite computationally
 inefficient and might significantly increase the computation time. To prevent this, the probabilities are
 not calculated as a default. In order to calculate, you will have to set `calculate_probabilities` to True:
@@ -81,6 +118,26 @@ I would suggest doing one of the following:
 * Use the above step also with numpy as it is part of the issue
 * Install BERTopic in a fresh environment using these steps.

+## **How can I run BERTopic without an internet connection?**
+The great thing about using sentence-transformers is that it searches automatically for an embedding model locally.
+If it cannot find one, it will download the pre-trained model from its servers.
+Make sure that you set the correct path for sentence-transformers to work. You can find a bit more about that
+[here](https://github.com/UKPLab/sentence-transformers/issues/888).
+
+You can download the corresponding model [here](https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/)
+and unzip it. Then, simply use the following to create your embedding model:
+
+```python
+from sentence_transformers import SentenceTransformer
+embedding_model = SentenceTransformer('path/to/unzipped/model')
+```
+
+Then, pass it to BERTopic:
+
+```python
+from bertopic import BERTopic
+topic_model = BERTopic(embedding_model=embedding_model)
+```

 ## **Can I use the GPU to speed up the model?**
 Yes and no. The GPU is automatically used when you use a SentenceTransformer or Flair embedding model. Using a CPU
@@ -88,12 +145,28 @@ would then definitely slow things down. However, UMAP and HDBSCAN are not GPU-ac
 the near future. For now, a GPU does help tremendously for extracting embeddings but does not speed up all
 aspects of BERtopic.

-## **Should I preprocess the data?**
-No. By using document embeddings there is typically no need to preprocess the data as all parts of a document
-are important in understanding the general topic of the document. Although this holds true in 99% of cases, if you
-have data that contains a lot of noise, for example, HTML-tags, then it would be best to remove them. HTML-tags
-typically do not contribute to the meaning of a document and should therefore be removed. However, if you apply
-topic modeling to HTML-code to extract topics of code, then it becomes important.
+## **How can I use BERTopic with Chinese documents?**
+Currently, CountVectorizer tokenizes text by splitting whitespace which does not work for Chinese.
+
+In order to get it to work, you will have to create a custom `CountVectorizer` with `jieba`:
+
+```python
+from sklearn.feature_extraction.text import CountVectorizer
+import jieba
+
+def tokenize_zh(text):
+    words = jieba.lcut(text)
+    return words
+
+vectorizer = CountVectorizer(tokenizer=tokenize_zh)
+```
+
+Next, we pass our custom vectorizer to BERTopic, select a multilingual embedding model, and create our topic model:
+
+```python
+from bertopic import BERTopic
+topic_model = BERTopic(language="multilingual", verbose=True, vectorizer_model=vectorizer)
+topics, probs = topic_model.fit_transform(docs)
+```

 ## **Why does it take so long to import BERTopic?**
 The main culprit here seems to be UMAP. After running tests with [Tuna](https://github.com/nschloe/tuna) we
@@ -103,4 +176,10 @@ can see that most of the resources when importing BERTopic can be dedicated to U

 Unfortunately, there currently is no fix for this issue. The most recent ticket regarding this issue can be
 found [here](https://github.com/lmcinnes/umap/issues/631).
-
+
+## **Should I preprocess the data?**
+No. By using document embeddings there is typically no need to preprocess the data as all parts of a document
+are important in understanding the general topic of the document. Although this holds true in 99% of cases, if you
+have data that contains a lot of noise, for example, HTML-tags, then it would be best to remove them. HTML-tags
+typically do not contribute to the meaning of a document and should therefore be removed. However, if you apply
+topic modeling to HTML-code to extract topics of code, then it becomes important.
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index b2e4d2b2..f113558a 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -44,7 +44,7 @@ from sklearn.datasets import fetch_20newsgroups
 docs = fetch_20newsgroups(subset='all',  remove=('headers', 'footers', 'quotes'))['data']

 topic_model = BERTopic()
-topics, _ = topic_model.fit_transform(docs)
+topics, probs = topic_model.fit_transform(docs)
 ```

 After generating topics and their probabilities, we can access the frequent topics that were generated:
diff --git a/docs/tutorial/embeddings/embeddings.md b/docs/tutorial/embeddings/embeddings.md
index 3ba14755..047702d0 100644
--- a/docs/tutorial/embeddings/embeddings.md
+++ b/docs/tutorial/embeddings/embeddings.md
@@ -184,7 +184,7 @@ embeddings = sentence_model.encode(docs, show_progress_bar=False)

 # Create topic model and use the custom embeddings
 topic_model = BERTopic()
-topics, _ = topic_model.fit_transform(docs, embeddings)
+topics, probs = topic_model.fit_transform(docs, embeddings)
 ```

 As you can see above, we used a SentenceTransformer model to create the embedding. You could also have used
@@ -208,7 +208,7 @@ embeddings = vectorizer.fit_transform(docs)

 # 
 topic_model = BERTopic(stop_words="english")
-topics, _ = topic_model.fit_transform(docs, embeddings)
+topics, probs = topic_model.fit_transform(docs, embeddings)
 ```

 Here, you will probably notice that creating the embeddings is quite fast whereas `fit_transform` is quite slow.
diff --git a/docs/tutorial/guided/guided.md b/docs/tutorial/guided/guided.md
new file mode 100644
index 00000000..5a2b81bd
--- /dev/null
+++ b/docs/tutorial/guided/guided.md
@@ -0,0 +1,52 @@
+## **Guided Topic Modeling**
+Guided Topic Modeling or Seeded Topic Modeling is a collection of techniques that guides the topic modeling approach
+by setting a number of seed topics to which the model will converge.
+These techniques allow the user to set a pre-defined number of topic representations that are sure to be in
+documents. For example, take an IT-business that has a ticket system for the software their clients use. Those
+tickets may typically contain information about a specific bug regarding login issues that the IT-business is aware of.
+
+To model that bug, we can create a seed topic representation containing the words `bug`, `login`, `password`,
+and `username`. By defining those words, a Guided Topic Modeling approach will try to converge at least one topic to those words.
+
+Guided BERTopic has two main steps:
+
+First, we create embeddings for each seeded topic by joining its words and passing them through the document embedder.
+These embeddings will be compared with the existing document embeddings through cosine similarity and each document is assigned a label.
+If the document is most similar to a seeded topic, then it will get that topic's label.
+If it is most similar to the average document embedding, it will get the -1 label.
+These labels are then passed through UMAP to create a semi-supervised approach that should nudge
+the topic creation to the seeded topics.
+
+Second, we take all words in `seed_topic_list` and assign them a multiplier larger than 1.
+Those multipliers will be used to increase the IDF values of the words across all topics thereby increasing
+the likelihood that a seeded topic word will appear in a topic. This does, however, also increase the chance of an
+irrelevant topic having unrelated words. In practice, this should not be an issue since the IDF value is likely
+to remain low regardless of the multiplier. The multiplier is now a fixed value but may change to something
+more elegant, like taking the distribution of IDF values and its position into account when defining the multiplier.
+
+### **Example**
+To demonstrate Guided BERTopic, we use the 20 Newsgroups dataset as our example. We have frequently used this
+dataset in BERTopic examples and we sometimes see a topic generated about health with words such as `drug` and `cancer`
+being important. However, due to the stochastic nature of UMAP, this topic is not always found.
+
+In order to guide BERTopic to that topic, we create a seed topic list that we pass to our model. However,
+there may be several other topics that we know should be in the documents. Let's also initialize those:
+
+```python
+from bertopic import BERTopic
+from sklearn.datasets import fetch_20newsgroups
+
+docs = fetch_20newsgroups(subset='all',  remove=('headers', 'footers', 'quotes'))["data"]
+
+seed_topic_list = [["drug", "cancer", "drugs", "doctor"],
+                   ["windows", "drive", "dos", "file"],
+                   ["space", "launch", "orbit", "lunar"]]
+
+topic_model = BERTopic(seed_topic_list=seed_topic_list)
+topics, probs = topic_model.fit_transform(docs)
+```
+
+As you can see above, the `seed_topic_list` contains a list of topic representations. By defining the above topics,
+BERTopic is more likely to model the defined seeded topics. However, BERTopic is merely nudged towards creating those
+topics. In practice, if the seeded topics do not exist or might be divided into smaller topics, then they will
+not be modeled. Thus, seed topics need to be accurate in order for the model to converge towards them.
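+
+As a quick, optional sanity check (assuming the model above has been fitted), you can search the resulting topics for
+one of the seed words and inspect the topic representation it maps to:
+
+```python
+# Find the topic that is most similar to one of the seed words
+similar_topics, similarity = topic_model.find_topics("drug", top_n=1)
+
+# Inspect the words that represent that topic
+topic_model.get_topic(similar_topics[0])
+```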
\ No newline at end of file diff --git a/docs/tutorial/quickstart/quickstart.md b/docs/tutorial/quickstart/quickstart.md index 06b31310..81a78d2b 100644 --- a/docs/tutorial/quickstart/quickstart.md +++ b/docs/tutorial/quickstart/quickstart.md @@ -32,7 +32,7 @@ from sklearn.datasets import fetch_20newsgroups docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] topic_model = BERTopic() -topics, _ = topic_model.fit_transform(docs) +topics, probs = topic_model.fit_transform(docs) ``` After generating topics, we can access the frequent topics that were generated: diff --git a/docs/tutorial/search/search.md b/docs/tutorial/search/search.md index aa4a2856..a4d1a86f 100644 --- a/docs/tutorial/search/search.md +++ b/docs/tutorial/search/search.md @@ -10,7 +10,7 @@ from sklearn.datasets import fetch_20newsgroups # Create topics docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] topic_model = BERTopic() -topics, _ = topic_model.fit_transform(docs) +topics, probs = topic_model.fit_transform(docs) ``` After having trained our model, we can use `find_topics` to search for topics that are similar diff --git a/docs/tutorial/topicrepresentation/topicrepresentation.md b/docs/tutorial/topicrepresentation/topicrepresentation.md index fa16f345..f0368608 100644 --- a/docs/tutorial/topicrepresentation/topicrepresentation.md +++ b/docs/tutorial/topicrepresentation/topicrepresentation.md @@ -15,7 +15,7 @@ from sklearn.datasets import fetch_20newsgroups # Create topics docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] topic_model = BERTopic(n_gram_range=(2, 3)) -topics, _ = topic_model.fit_transform(docs) +topics, probs = topic_model.fit_transform(docs) ``` From the model created above, one of the most frequent topics is the following: diff --git a/docs/tutorial/topicsovertime/topicsovertime.md b/docs/tutorial/topicsovertime/topicsovertime.md index d2a5c644..b0c2766c 100644 --- a/docs/tutorial/topicsovertime/topicsovertime.md +++ b/docs/tutorial/topicsovertime/topicsovertime.md @@ -53,7 +53,7 @@ Then, we need to extract the global topic representations by simply creating and from bertopic import BERTopic topic_model = BERTopic(verbose=True) -topics, _ = topic_model.fit_transform(tweets) +topics, probs = topic_model.fit_transform(tweets) ``` From these topics, we are going to generate the topic representations at each timestamp for each topic. 
We do this diff --git a/docs/tutorial/topicsperclass/topicsperclass.md b/docs/tutorial/topicsperclass/topicsperclass.md index 7bafa3d9..d27625e7 100644 --- a/docs/tutorial/topicsperclass/topicsperclass.md +++ b/docs/tutorial/topicsperclass/topicsperclass.md @@ -25,7 +25,7 @@ Next, we want to extract the topics across all documents without taking the cate ```python topic_model = BERTopic(verbose=True) -topics, _ = topic_model.fit_transform(docs) +topics, probs = topic_model.fit_transform(docs) ``` Now that we have created our global topic model, let us calculate the topic representations across each category: diff --git a/docs/tutorial/visualization/visualization.md b/docs/tutorial/visualization/visualization.md index 9e5f715c..ba2dfac5 100644 --- a/docs/tutorial/visualization/visualization.md +++ b/docs/tutorial/visualization/visualization.md @@ -15,7 +15,7 @@ from sklearn.datasets import fetch_20newsgroups docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] topic_model = BERTopic() -topics, _ = topic_model.fit_transform(docs) +topics, probs = topic_model.fit_transform(docs) ``` Then, we simply call `topic_model.visualize_topics()` in order to visualize our topics. The resulting graph is a @@ -114,7 +114,7 @@ tweets = trump.text.to_list() # Create topics over time model = BERTopic(verbose=True) -topics, _ = model.fit_transform(tweets) +topics, probs = model.fit_transform(tweets) topics_over_time = model.topics_over_time(tweets, topics, timestamps) ``` @@ -144,7 +144,7 @@ classes = [data["target_names"][i] for i in data["target"]] # Create topic model and calculate topics per class topic_model = BERTopic() -topics, _ = topic_model.fit_transform(docs) +topics, probs = topic_model.fit_transform(docs) topics_per_class = topic_model.topics_per_class(docs, topics, classes=classes) ``` diff --git a/mkdocs.yml b/mkdocs.yml index e2aea614..9eb24f89 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -19,6 +19,7 @@ nav: - Topics per Class: tutorial/topicsperclass/topicsperclass.md - (semi)-Supervised Topic Modeling: tutorial/supervised/supervised.md - Dynamic Topic Modeling: tutorial/topicsovertime/topicsovertime.md + - Guided Topic Modeling: tutorial/guided/guided.md - API: - BERTopic: api/bertopic.md - cTFIDF: api/ctfidf.md diff --git a/setup.py b/setup.py index 5adb2872..d958e86c 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,7 @@ setup( name="bertopic", packages=find_packages(exclude=["notebooks", "docs"]), - version="0.8.1", + version="0.9.0", author="Maarten P. Grootendorst", author_email="maartengrootendorst@gmail.com", description="BERTopic performs topic Modeling with state-of-the-art transformer models.",