11"""External API for creating self-contained figures for documents."""
2+
23from typing import List , Optional , Union
34
45import numpy as np
@@ -56,7 +57,10 @@ def document_map(
5657
5758
5859def document_topic_distribution (
59- topic_data : TopicData , documents : Union [List [str ], str ], top_n : int = 8
60+ topic_data : TopicData ,
61+ documents : Union [List [str ], str ],
62+ top_n : int = 8 ,
63+ color_scheme : str = "Portland" ,
6064) -> go .Figure :
6165 """Displays topic distribution on a bar plot for a document
6266 or a set of documents.
@@ -69,6 +73,8 @@ def document_topic_distribution(
6973 Documents to display topic distribution for.
7074 top_n: int, default 8
7175 Number of topics to display at most.
76+ color_scheme: str, default 'Portland'
77+ Name of the Plotly color scheme to use for the plot.
7278 """
7379 transform = topic_data ["transform" ]
7480 if transform is None :
@@ -80,7 +86,7 @@ def document_topic_distribution(
8086 topic_importances = prepare .document_topic_importances (transform (documents ))
8187 topic_importances = topic_importances .groupby (["topic_id" ]).sum ().reset_index ()
8288 n_topics = topic_data ["document_topic_matrix" ].shape [- 1 ]
83- twilight = colors .get_colorscale ("Portland" )
89+ twilight = colors .get_colorscale (color_scheme )
8490 topic_colors = colors .sample_colorscale (twilight , np .arange (n_topics ) / n_topics )
8591 topic_colors = np .array (topic_colors )
8692 return plots .document_topic_barplot (
@@ -89,7 +95,11 @@ def document_topic_distribution(
8995
9096
9197def document_topic_timeline (
92- topic_data : TopicData , document : str , window_size : int = 10 , step_size : int = 1
98+ topic_data : TopicData ,
99+ document : str ,
100+ window_size : int = 10 ,
101+ step_size : int = 1 ,
102+ color_scheme : str = "Portland" ,
93103) -> go .Figure :
94104 """Projects documents into 2d space and displays them on a scatter plot.
95105
@@ -103,6 +113,8 @@ def document_topic_timeline(
103113 The windows over which topic inference should be run.
104114 step_size: int, default 1
105115 Size of the steps for the rolling window.
116+ color_scheme: str, default 'Portland'
117+ Name of Plotly color scheme to use for the plot.
106118 """
107119 timeline = prepare .calculate_timeline (
108120 doc_id = 0 ,
@@ -113,7 +125,7 @@ def document_topic_timeline(
113125 )
114126 topic_names = topic_data ["topic_names" ]
115127 n_topics = len (topic_names )
116- twilight = colors .get_colorscale ("Portland" )
128+ twilight = colors .get_colorscale (color_scheme )
117129 topic_colors = colors .sample_colorscale (twilight , np .arange (n_topics ) / n_topics )
118130 topic_colors = np .array (topic_colors )
119131 return plots .document_timeline (timeline , topic_names , topic_colors )
0 commit comments