Skip to content

Commit

Permalink
Add option for weighting anchor word MI by anchor_strength when ranking words for topics
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanjgallagher committed Mar 22, 2021
1 parent a213671 commit a355ab8
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions corextopic/corextopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def calculate_mis(self, theta, log_p_y):
mis = self.h_x - p_y * binary_entropy(np.exp(theta[3]).T) - (1 - p_y) * binary_entropy(np.exp(theta[2]).T)
return (mis - 1. / (2. * self.n_samples)).clip(0.) # P-T bias correction

def get_topics(self, n_words=10, topic=None, print_words=True):
def get_topics(self, n_words=10, topic=None, print_words=True, weighted_rank=True):
"""
Gets the top words for topics. If `words` was not provided to CorEx,
then the indices of the top words are returned
Expand All @@ -543,6 +543,10 @@ def get_topics(self, n_words=10, topic=None, print_words=True):
the top words for all topics
print_words, bool
Whether to return string words or integer indices for the topics
weighted_rank, bool
Whether to weight the mutual information of anchor words with their
anchor strengths. If anchoring was not used, then this does not
affect the output
RETURNS
-------
Expand Down Expand Up @@ -576,12 +580,25 @@ def get_topics(self, n_words=10, topic=None, print_words=True):
# Get indices of which words belong to the topic
inds = np.where(self.alpha[n] >= 1.)[0]
# Sort topic words according to mutual information
inds = inds[np.argsort(-self.alpha[n,inds] * self.mis[n,inds])]
# Create topic to return
if print_words is True:
topic = [(self.col_index2word[ind], self.mis[n,ind], self.sign[n,ind]) for ind in inds[:n_words]]
if weighted_rank:
inds = inds[np.argsort(-self.alpha[n,inds] * self.mis[n,inds])]
else:
topic = [(ind, self.mis[n,ind], self.sign[n,ind]) for ind in inds[:n_words]]
inds = inds[np.argsort(-self.mis[n,inds])]
# Create topic to return
topic = []
for ind in inds[:n_words]:
if print_words is True:
word = self.col_index2word[ind]
else:
word = ind

if weighted_rank:
mi = self.alpha[n,inds] * self.mis[n,ind]
else:
mi = self.mis[n,ind]

topic.append((word, mi, self.sign[n,ind]))

# Add topic to list of topics if returning all topics. Otherwise, return topic
if len(topic_ns) != 1:
topics.append(topic)
Expand Down

0 comments on commit a355ab8

Please sign in to comment.