Skip to content

Commit 66861e7

Browse files
Merge pull request #42 from x-tabdeveloping/color_scheme
Added option to use custom color scheme in figures API
2 parents 4066f5e + 0e18be2 commit 66861e7

File tree

6 files changed

+42
-13
lines changed

6 files changed

+42
-13
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "topic-wizard"
3-
version = "1.1.1"
3+
version = "1.1.2"
44
description = "Pretty and opinionated topic model visualization in Python."
55
authors = ["Márton Kardos <[email protected]>"]
66
license = "MIT"
@@ -13,7 +13,7 @@ dash = "^2.7.1"
1313
dash-extensions = "^1.0.4"
1414
dash-mantine-components = "~0.12.1"
1515
dash-iconify = "~0.1.2"
16-
joblib = "~1.2.0"
16+
joblib = "^1.2.0"
1717
scikit-learn = "^1.2.0"
1818
scipy = ">=1.8.0"
1919
umap-learn = ">=0.5.3"

topicwizard/figures/documents.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""External API for creating self-contained figures for documents."""
2+
23
from typing import List, Optional, Union
34

45
import numpy as np
@@ -56,7 +57,10 @@ def document_map(
5657

5758

5859
def document_topic_distribution(
59-
topic_data: TopicData, documents: Union[List[str], str], top_n: int = 8
60+
topic_data: TopicData,
61+
documents: Union[List[str], str],
62+
top_n: int = 8,
63+
color_scheme: str = "Portland",
6064
) -> go.Figure:
6165
"""Displays topic distribution on a bar plot for a document
6266
or a set of documents.
@@ -69,6 +73,8 @@ def document_topic_distribution(
6973
Documents to display topic distribution for.
7074
top_n: int, default 8
7175
Number of topics to display at most.
76+
color_scheme: str, default 'Portland'
77+
Name of the Plotly color scheme to use for the plot.
7278
"""
7379
transform = topic_data["transform"]
7480
if transform is None:
@@ -80,7 +86,7 @@ def document_topic_distribution(
8086
topic_importances = prepare.document_topic_importances(transform(documents))
8187
topic_importances = topic_importances.groupby(["topic_id"]).sum().reset_index()
8288
n_topics = topic_data["document_topic_matrix"].shape[-1]
83-
twilight = colors.get_colorscale("Portland")
89+
twilight = colors.get_colorscale(color_scheme)
8490
topic_colors = colors.sample_colorscale(twilight, np.arange(n_topics) / n_topics)
8591
topic_colors = np.array(topic_colors)
8692
return plots.document_topic_barplot(
@@ -89,7 +95,11 @@ def document_topic_distribution(
8995

9096

9197
def document_topic_timeline(
92-
topic_data: TopicData, document: str, window_size: int = 10, step_size: int = 1
98+
topic_data: TopicData,
99+
document: str,
100+
window_size: int = 10,
101+
step_size: int = 1,
102+
color_scheme: str = "Portland",
93103
) -> go.Figure:
94104
"""Plots the topic distribution of a single document over a rolling window as a timeline.
95105
@@ -103,6 +113,8 @@ def document_topic_timeline(
103113
The windows over which topic inference should be run.
104114
step_size: int, default 1
105115
Size of the steps for the rolling window.
116+
color_scheme: str, default 'Portland'
117+
Name of the Plotly color scheme to use for the plot.
106118
"""
107119
timeline = prepare.calculate_timeline(
108120
doc_id=0,
@@ -113,7 +125,7 @@ def document_topic_timeline(
113125
)
114126
topic_names = topic_data["topic_names"]
115127
n_topics = len(topic_names)
116-
twilight = colors.get_colorscale("Portland")
128+
twilight = colors.get_colorscale(color_scheme)
117129
topic_colors = colors.sample_colorscale(twilight, np.arange(n_topics) / n_topics)
118130
topic_colors = np.array(topic_colors)
119131
return plots.document_timeline(timeline, topic_names, topic_colors)

topicwizard/figures/groups.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""External API for creating self-contained figures for groups."""
2+
23
from typing import List
34

45
import numpy as np
@@ -68,7 +69,11 @@ def group_map(topic_data: TopicData, group_labels: List[str]) -> go.Figure:
6869

6970

7071
def group_topic_barcharts(
71-
topic_data: TopicData, group_labels: List[str], top_n: int = 5, n_columns: int = 4
72+
topic_data: TopicData,
73+
group_labels: List[str],
74+
top_n: int = 5,
75+
n_columns: int = 4,
76+
color_scheme: str = "Portland",
7277
):
7378
"""Displays the most important topics for each group.
7479
@@ -82,6 +87,8 @@ def group_topic_barcharts(
8287
Maximum number of topics to display for each group.
8388
n_columns: int, default 4
8489
Indicates how many columns the faceted plot should have.
90+
color_scheme: str, default 'Portland'
91+
Name of the Plotly color scheme to use for the figure.
8592
"""
8693
# Factorizing group labels
8794
group_id_labels, group_names = pd.factorize(group_labels)
@@ -105,7 +112,7 @@ def group_topic_barcharts(
105112
horizontal_spacing=0.01,
106113
)
107114
n_topics = len(topic_data["topic_names"])
108-
color_scheme = colors.get_colorscale("Portland")
115+
color_scheme = colors.get_colorscale(color_scheme)
109116
topic_colors = colors.sample_colorscale(
110117
color_scheme, np.arange(n_topics) / n_topics, low=0.25, high=1.0
111118
)

topicwizard/figures/topics.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def topic_wordclouds(
110110
topic_data: TopicData,
111111
top_n: int = 30,
112112
n_columns: int = 4,
113+
color_scheme: str = "copper",
113114
) -> go.Figure:
114115
"""Plots most relevant words as word clouds for every topic.
115116
@@ -121,6 +122,8 @@ def topic_wordclouds(
121122
Specifies the number of words to show for each topic.
122123
n_columns: int, default 4
123124
Number of columns in the subplot grid.
125+
color_scheme: str, default 'copper'
126+
Matplotlib color scheme to use for the wordcloud.
124127
"""
125128
n_topics = topic_data["topic_term_matrix"].shape[0]
126129
(
@@ -147,7 +150,7 @@ def topic_wordclouds(
147150
components=topic_term_importances,
148151
vocab=topic_data["vocab"],
149152
)
150-
subfig = plots.wordcloud(top_words)
153+
subfig = plots.wordcloud(top_words, color_scheme=color_scheme)
151154
row, column = (topic_id // n_columns) + 1, (topic_id % n_columns) + 1
152155
fig.add_trace(subfig.data[0], row=row, col=column)
153156
fig.update_layout(

topicwizard/figures/words.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def word_map(
1818
topic_data: TopicData,
1919
z_threshold: float = 2.0,
2020
topic_axes: Optional[Tuple[Union[str, int], Union[str, int]]] = None,
21+
color_scheme: str = "tempo",
2122
) -> go.Figure:
2223
"""Plots words on a scatter plot based on UMAP projections
2324
of their importances in topics into 2D space or by two topic axes.
@@ -38,6 +39,8 @@ def word_map(
3839
The topic axes along which the words should be displayed.
3940
If not specified, the axes on the graph are going to be
4041
UMAP projections' dimensions.
42+
color_scheme: str, default 'tempo'
43+
Name of the Plotly color scheme to use for the plot.
4144
"""
4245
topic_names = topic_data["topic_names"]
4346
if topic_axes is None:
@@ -58,7 +61,7 @@ def word_map(
5861
freq_z = zscore(word_frequencies)
5962
dominant_topic = prepare.dominant_topic(topic_data["topic_term_matrix"])
6063
dominant_topic = np.array(topic_data["topic_names"])[dominant_topic]
61-
tempo = colors.get_colorscale("tempo")
64+
tempo = colors.get_colorscale(color_scheme)
6265
n_topics = len(topic_data["topic_names"])
6366
topic_colors = colors.sample_colorscale(tempo, np.arange(n_topics) / n_topics)
6467
topic_colors = np.array(topic_colors)
@@ -99,6 +102,7 @@ def word_association_barchart(
99102
words: Union[List[str], str],
100103
n_association: int = 0,
101104
top_n: int = 20,
105+
color_scheme: str = "Rainbow",
102106
):
103107
"""Plots bar chart of most important topics for the given words and their closest
104108
associations in topic space.
@@ -114,6 +118,8 @@ def word_association_barchart(
114118
None get displayed by default.
115119
top_n: int, default 20
116120
Top N topics to display.
121+
color_scheme: str, default 'Rainbow'
122+
Name of the Plotly color scheme to use for the plot.
117123
"""
118124
if isinstance(words, str):
119125
words = [words]
@@ -128,7 +134,7 @@ def word_association_barchart(
128134
word_ids, topic_data["topic_term_matrix"], n_association
129135
)
130136
n_topics = topic_data["topic_term_matrix"].shape[0]
131-
tempo = colors.get_colorscale("Rainbow")
137+
tempo = colors.get_colorscale(color_scheme)
132138
topic_colors = colors.sample_colorscale(tempo, np.arange(n_topics) / n_topics)
133139
topic_colors = np.array(topic_colors)
134140
top_topics = prepare.top_topics(

topicwizard/plots/topics.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Module containing plotting utilities for topics."""
2+
23
from typing import List
34

45
import numpy as np
@@ -139,7 +140,7 @@ def topic_plot(top_words: pd.DataFrame):
139140
return fig
140141

141142

142-
def wordcloud(top_words: pd.DataFrame) -> go.Figure:
143+
def wordcloud(top_words: pd.DataFrame, color_scheme: str = "copper") -> go.Figure:
143144
"""Plots most relevant words for current topic as a wordcloud."""
144145
top_dict = {
145146
word: importance
@@ -151,7 +152,7 @@ def wordcloud(top_words: pd.DataFrame) -> go.Figure:
151152
width=800,
152153
height=1060,
153154
background_color="white",
154-
colormap="copper",
155+
colormap=color_scheme,
155156
scale=4,
156157
).generate_from_frequencies(top_dict)
157158
image = cloud.to_image()

0 commit comments

Comments
 (0)