Commit dfd7fac

Merge pull request #117 from x-tabdeveloping/optimization
Auto n_components for multiple topic models
2 parents e3c981d + 96e47a4 commit dfd7fac
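
The headline feature of this PR is that several models can now infer the number of topics instead of requiring it up front. A minimal usage sketch, assuming the public `KeyNMF` wrapper forwards `n_components` to the `KeywordNMF` backend changed below (the corpus and slice size here are illustrative):

```python
from sklearn.datasets import fetch_20newsgroups

from turftopic import KeyNMF

corpus = fetch_20newsgroups(subset="all").data[:500]

# "auto" makes the underlying KeywordNMF pick the number of topics by
# minimizing a Gaussian BIC over candidate component counts.
model = KeyNMF(n_components="auto")
doc_topic_matrix = model.fit_transform(corpus)
model.print_topics()
```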

File tree

10 files changed: +757 −79 lines changed

docs/Topeax.md

Lines changed: 18 additions & 0 deletions

```diff
@@ -0,0 +1,18 @@
+# Topeax
+
+Topeax is a probabilistic topic model based on the Peax clustering model, which finds topics based on peaks in point density in the embedding space.
+It can recover the number of topics automatically.
+
+<br>
+<figure>
+<img src="../images/peax.png" width="100%" style="margin-left: auto;margin-right: auto;">
+<figcaption>Schematic overview of the steps of the Peax clustering algorithm</figcaption>
+</figure>
+
+:warning: **This part of the documentation is still under construction, as more details and a paper are on their way.** :warning:
+
+## API Reference
+
+::: turftopic.models.topeax.Topeax
+
+::: turftopic.models.topeax.Peax
```
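
Since `Topeax` is now exported from the package root (see the `turftopic/__init__.py` diff below), a minimal usage sketch might look like the following. No `n_components` is passed because, per the docs above, the model recovers the number of topics automatically; the standard `turftopic` methods `fit_transform` and `print_topics` are assumed here:

```python
from sklearn.datasets import fetch_20newsgroups

from turftopic import Topeax

corpus = fetch_20newsgroups(subset="all").data[:1000]

# Peax-based clustering finds topics as peaks in embedding-space
# point density, so the number of topics is recovered automatically.
model = Topeax()
doc_topic_matrix = model.fit_transform(corpus)
model.print_topics()
```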

docs/images/peax.png

Binary file added: 260 KB

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -9,7 +9,7 @@ profile = "black"
 
 [project]
 name = "turftopic"
-version = "0.19.1"
+version = "0.20.0"
 description = "Topic modeling with contextual representations from sentence transformers."
 authors = [
     { name = "Márton Kardos", email = "[email protected]" }
```

turftopic/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -7,6 +7,7 @@
 from turftopic.models.fastopic import FASTopic
 from turftopic.models.gmm import GMM
 from turftopic.models.keynmf import KeyNMF
+from turftopic.models.topeax import Topeax
 from turftopic.serialization import load_model
 
 try:
@@ -20,6 +21,7 @@
     "ClusteringTopicModel",
     "SemanticSignalSeparation",
     "GMM",
+    "Topeax",
     "KeyNMF",
     "AutoEncodingTopicModel",
     "ContextualModel",
```

turftopic/encoders/utils.py

Lines changed: 48 additions & 0 deletions

```diff
@@ -1,6 +1,10 @@
 import itertools
 from typing import Iterable, List
 
+import numpy as np
+import torch
+from tqdm import trange
+
 
 def batched(iterable, n: int) -> Iterable[List[str]]:
     "Batch data into tuples of length n. The last batch may be shorter."
@@ -10,3 +14,47 @@ def batched(iterable, n: int) -> Iterable[List[str]]:
     it = iter(iterable)
     while batch := list(itertools.islice(it, n)):
         yield batch
+
+
+def encode_chunks(
+    encoder,
+    sentences,
+    batch_size=64,
+    window_size=50,
+    step_size=40,
+    return_chunks=False,
+    show_progress_bar=False,
+):
+    chunks = []
+    chunk_embeddings = []
+    for start_index in trange(
+        0,
+        len(sentences),
+        batch_size,
+        desc="Encoding batches...",
+        disable=not show_progress_bar,
+    ):
+        batch = sentences[start_index : start_index + batch_size]
+        features = encoder.tokenize(batch)
+        with torch.no_grad():
+            output_features = encoder.forward(features)
+        n_tokens = output_features["attention_mask"].sum(axis=1)
+        for i_doc in range(len(batch)):
+            for chunk_start in range(0, n_tokens[i_doc], step_size):
+                chunk_end = min(chunk_start + window_size, n_tokens[i_doc])
+                _emb = output_features["token_embeddings"][
+                    i_doc, chunk_start:chunk_end, :
+                ].mean(axis=0)
+                chunk_embeddings.append(_emb)
+                if return_chunks:
+                    chunks.append(
+                        encoder.tokenizer.decode(
+                            features["input_ids"][i_doc, chunk_start:chunk_end]
+                        )
+                        .replace("[CLS]", "")
+                        .replace("[SEP]", "")
+                    )
+    if not return_chunks:
+        chunks = None
+    chunk_embeddings = np.stack(chunk_embeddings)
+    return chunk_embeddings, chunks
```
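
The new `encode_chunks` helper slides a window of `window_size` tokens, advancing by `step_size` tokens, over each document's token embeddings and mean-pools every window into one chunk embedding. A usage sketch, assuming `encoder` is a `sentence-transformers` model (those expose the `tokenize`, `forward`, and `tokenizer` attributes the helper relies on; the model name is illustrative):

```python
from sentence_transformers import SentenceTransformer

from turftopic.encoders.utils import encode_chunks

encoder = SentenceTransformer("all-MiniLM-L6-v2")
docs = ["First short document.", "A second, much longer document ..."]

# One embedding per 50-token window; consecutive windows overlap
# by 10 tokens (window_size - step_size).
embeddings, chunks = encode_chunks(
    encoder,
    docs,
    window_size=50,
    step_size=40,
    return_chunks=True,  # also return the decoded text of each chunk
    show_progress_bar=True,
)
print(embeddings.shape)  # (n_chunks, embedding_dim)
```

Note that `chunk_embeddings` stacks one vector per window across all documents, so the caller is responsible for mapping chunks back to their source documents if that association is needed.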

turftopic/models/_keynmf.py

Lines changed: 29 additions & 2 deletions

```diff
@@ -2,7 +2,8 @@
 import warnings
 from collections import defaultdict
 from datetime import datetime
-from typing import Iterable, Literal, Optional
+from functools import partial
+from typing import Iterable, Literal, Optional, Union
 
 import igraph as ig
 import numpy as np
@@ -21,6 +22,10 @@
 from sklearn.utils.validation import check_non_negative
 
 from turftopic.base import Encoder
+from turftopic.optimization import (
+    decomposition_gaussian_bic,
+    optimize_n_components,
+)
 
 NOT_MATCHING_ERROR = (
     "Document embedding dimensionality ({n_dims}) doesn't match term embedding dimensionality ({n_word_dims}). "
@@ -242,7 +247,7 @@ def batch_extract_keywords(
 class KeywordNMF:
     def __init__(
         self,
-        n_components: int,
+        n_components: Union[int, Literal["auto"]],
         seed: Optional[int] = None,
         top_n: Optional[int] = None,
     ):
@@ -318,6 +323,15 @@ def vectorize(
 
     def fit_transform(self, keywords: list[dict[str, float]]) -> np.ndarray:
         X = self.vectorize(keywords, fitting=True)
+        if self.n_components == "auto":
+            # Finding N components with BIC
+            bic_fn = partial(
+                decomposition_gaussian_bic,
+                decomp_class=NMF,
+                X=X,
+            )
+            n_components = optimize_n_components(bic_fn, min_n=1, verbose=True)
+            self.n_components = n_components
         check_non_negative(X, "NMF (input X)")
         W, H = _initialize_nmf(X, self.n_components, random_state=self.seed)
         W, H, self.n_iter = NMF(
@@ -339,6 +353,10 @@ def transform(self, keywords: list[dict[str, float]]):
         return W.astype(X.dtype)
 
     def partial_fit(self, keyword_batch: list[dict[str, float]]):
+        if self.n_components == "auto":
+            raise ValueError(
+                "Cannot infer number of components with BIC when online fitting the model."
+            )
         X = self.vectorize(keyword_batch, fitting=True)
         try:
             check_non_negative(X, "NMF (input X)")
@@ -365,6 +383,15 @@ def fit_transform_dynamic(
         n_bins = len(time_bin_edges) - 1
         document_term_matrix = self.vectorize(keywords, fitting=True)
         check_non_negative(document_term_matrix, "NMF (input X)")
+        if self.n_components == "auto":
+            # Finding N components with BIC
+            bic_fn = partial(
+                decomposition_gaussian_bic,
+                decomp_class=NMF,
+                X=document_term_matrix,
+            )
+            n_components = optimize_n_components(bic_fn, verbose=True)
+            self.n_components = n_components
         document_topic_matrix, H = _initialize_nmf(
             document_term_matrix,
             self.n_components,
```
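
The auto-selection above delegates to two new helpers in `turftopic.optimization` whose implementations are not part of this diff. As an illustration only, the general idea of scoring a decomposition rank with a Gaussian BIC and searching for the best rank might look like the sketch below; the real `decomposition_gaussian_bic` and `optimize_n_components` may differ (for instance, the actual search need not be a plain grid, and `max_n=50` is an assumption), and only the call signatures visible in the diff are taken as given:

```python
import numpy as np
from sklearn.decomposition import NMF


def decomposition_gaussian_bic(n_components: int, decomp_class, X) -> float:
    """Illustrative sketch: score a decomposition rank with a Gaussian BIC.

    Treats the reconstruction residuals as i.i.d. Gaussian noise and
    penalizes the number of free parameters in W and H.
    """
    model = decomp_class(n_components=n_components)
    W = model.fit_transform(X)
    H = model.components_
    residuals = np.asarray(X - W @ H)
    n_samples, n_features = X.shape
    n_obs = n_samples * n_features
    sigma2 = np.mean(residuals**2) + 1e-12
    # Gaussian log-likelihood of the residuals
    log_likelihood = -0.5 * n_obs * (np.log(2 * np.pi * sigma2) + 1)
    n_params = n_components * (n_samples + n_features)
    # BIC = k * ln(n) - 2 * ln(L); lower is better
    return n_params * np.log(n_obs) - 2 * log_likelihood


def optimize_n_components(
    bic_fn, min_n: int = 1, max_n: int = 50, verbose: bool = False
) -> int:
    """Illustrative sketch: pick the rank with the lowest BIC by grid search."""
    best_n, best_bic = min_n, np.inf
    for n in range(min_n, max_n + 1):
        bic = bic_fn(n)
        if verbose:
            print(f"n_components={n}: BIC={bic:.1f}")
        if bic < best_bic:
            best_n, best_bic = n, bic
    return best_n
```

Under this reading, `partial(decomposition_gaussian_bic, decomp_class=NMF, X=X)` produces exactly the single-argument callable that `optimize_n_components` evaluates at each candidate rank.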
