diff --git a/python/pyspark/pandas/plot/plotly.py b/python/pyspark/pandas/plot/plotly.py index dfcc13931d4bb..ebf23416344d4 100644 --- a/python/pyspark/pandas/plot/plotly.py +++ b/python/pyspark/pandas/plot/plotly.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import inspect from typing import TYPE_CHECKING, Union import pandas as pd @@ -109,7 +110,11 @@ def plot_histogram(data: Union["ps.DataFrame", "ps.Series"], **kwargs): ) ) - fig = go.Figure(data=bars, layout=go.Layout(barmode="stack")) + layout_keys = inspect.signature(go.Layout).parameters.keys() + layout_kwargs = {k: v for k, v in kwargs.items() if k in layout_keys} + + fig = go.Figure(data=bars, layout=go.Layout(**layout_kwargs)) + fig["layout"]["barmode"] = "stack" fig["layout"]["xaxis"]["title"] = "value" fig["layout"]["yaxis"]["title"] = "count" return fig diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py index 7be00d593ee36..2937ef1813f74 100644 --- a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py @@ -186,6 +186,13 @@ def check_pie_plot(psdf): # ) # check_pie_plot(psdf1) + def test_hist_layout_kwargs(self): + s = ps.Series([1, 3, 2]) + plt = s.plot.hist(title="Title", foo="xxx") + self.assertEqual(plt.layout.barmode, "stack") + self.assertEqual(plt.layout.title.text, "Title") + self.assertFalse(hasattr(plt.layout, "foo")) + def test_hist_plot(self): def check_hist_plot(psdf): bins = np.array([1.0, 5.9, 10.8, 15.7, 20.6, 25.5, 30.4, 35.3, 40.2, 45.1, 50.0])