Skip to content

Commit

Permalink
Implement 3D support for scatterplot (jbesomi#126)
Browse files Browse the repository at this point in the history
* Implement 3D support for scatterplot.

Additionally:
- clean up the function
- add `hover_name` as argument
- improve docstring
- don't show figure when `return_figure` is set to True

Co-authored-by: Maximilian Krahn <[email protected]>

* Implement suggested changes.

Co-authored-by: Henri Froese <[email protected]>
Co-authored-by: Maximilian Krahn <[email protected]>
  • Loading branch information
3 people authored Jul 30, 2020
1 parent 6106035 commit f66f23c
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 17 deletions.
20 changes: 20 additions & 0 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,26 @@ def load_tests(loader, tests, ignore):


class TestVisualization(PandasTestCase):
"""
Test scatterplot.
"""

def test_scatterplot_dimension_too_high(self):
s = pd.Series([[1, 2, 3, 4], [1, 2, 3, 4]])
df = pd.DataFrame(s)
self.assertRaises(ValueError, visualization.scatterplot, df, col=0)

def test_scatterplot_dimension_too_low(self):
s = pd.Series([[1], [1]])
df = pd.DataFrame(s)
self.assertRaises(ValueError, visualization.scatterplot, df, col=0)

def test_scatterplot_return_figure(self):
s = pd.Series([[1, 2, 3], [1, 2, 3]])
df = pd.DataFrame(s)
ret = visualization.scatterplot(df, col=0, return_figure=True)
self.assertIsNotNone(ret)

"""
Test top_words.
"""
Expand Down
69 changes: 52 additions & 17 deletions texthero/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import pandas as pd
import numpy as np
import plotly.express as px

from wordcloud import WordCloud
Expand All @@ -20,58 +21,92 @@ def scatterplot(
df: pd.DataFrame,
col: str,
color: str = None,
hover_name: str = None,
hover_data: [] = None,
title="",
return_figure=False,
):
"""
Show scatterplot of DataFrame column using python plotly scatter.
Plot the values in column col. For example, if every cell in df[col]
is a list of three values (e.g. from doing PCA with 3 components),
a 3D-Plot is created and every cell entry [x, y, z] is visualized
as the point (x, y, z).
Parameters
----------
df: DataFrame with a column to be visualized.
col: str
The name of the column of the DataFrame to use for x and y axis.
The name of the column of the DataFrame to use for x and y (and z) axis.
color: str, default to None.
Name of the column to use for coloring (rows with same value get same color).
title: str, default to "".
Title of the plot.
return_figure: optional, default to False.
Function returns the figure if set to True.
hover_name: str, default to None
Name of the column to supply title of hover data when hovering over a point.
hover_data: List[str], default to [].
List of column names to supply data when hovering over a point.
hover_name: str, default to None
Name of the column to supply title of hover data when hovering over a point.
title: str, default to "".
Title of the plot.
return_figure: optional, default to False.
Function returns the figure instead of showing it if set to True.
Examples
--------
>>> import texthero as hero
>>> import pandas as pd
>>> df = pd.DataFrame(["Football, Sports, Soccer", "music, violin, orchestra", "football, fun, sports"], columns=["texts"])
>>> df = pd.DataFrame(["Football, Sports, Soccer", "music, violin, orchestra", "football, fun, sports", "music, fun, guitar"], columns=["texts"])
>>> df["texts"] = hero.clean(df["texts"]).pipe(hero.tokenize)
>>> df["pca"] = hero.tfidf(df["texts"]).pipe(hero.pca)
>>> df["pca"] = hero.tfidf(df["texts"]).pipe(hero.pca, n_components=3)
>>> df["topics"] = hero.tfidf(df["texts"]).pipe(hero.kmeans, n_clusters=2)
>>> hero.scatterplot(df, col="pca", color="topics", hover_data=["texts"]) # doctest: +SKIP
"""

pca0 = df[col].apply(lambda x: x[0])
pca1 = df[col].apply(lambda x: x[1])
plot_values = np.stack(df[col], axis=1)
dimension = len(plot_values)

fig = px.scatter(
df, x=pca0, y=pca1, color=color, hover_data=hover_data, title=title
)
# fig.show(config={'displayModeBar': False})
fig.show()
if dimension < 2 or dimension > 3:
raise ValueError(
"The column you want to visualize has dimension < 2 or dimension > 3."
" The function can only visualize 2- and 3-dimensional data."
)

if dimension == 2:
x, y = plot_values[0], plot_values[1]

fig = px.scatter(
df,
x=x,
y=y,
color=color,
hover_data=hover_data,
title=title,
hover_name=hover_name,
)

else:
x, y, z = plot_values[0], plot_values[1], plot_values[2]

fig = px.scatter_3d(
df,
x=x,
y=y,
z=z,
color=color,
hover_data=hover_data,
title=title,
hover_name=hover_name,
)

if return_figure:
return fig
else:
fig.show()


"""
Expand Down

0 comments on commit f66f23c

Please sign in to comment.