diff --git a/bigframes/operations/_matplotlib/core.py b/bigframes/operations/_matplotlib/core.py
index 2c1c2bc4ac..04534e20a9 100644
--- a/bigframes/operations/_matplotlib/core.py
+++ b/bigframes/operations/_matplotlib/core.py
@@ -14,7 +14,6 @@
 
 import abc
 import typing
-import uuid
 
 import pandas as pd
 
@@ -115,6 +114,20 @@ def _compute_plot_data(self):
         if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE:
             sample[c] = sample[c].astype("object")
 
+        # To avoid Matplotlib's automatic conversion of `Float64` or `Int64` columns
+        # to `object` types (which breaks float-like behavior), this code proactively
+        # converts the column to a compatible format.
+        #
+        # Unlike `c`, an integer `s` is never a positional column index: pandas
+        # treats it as a column label or a scalar marker size, so `s=50` must not
+        # be looked up in `self.data.columns` (that would raise IndexError).
+        s = self.kwargs.get("s", None)
+        if self._is_column_name(s, sample):
+            if sample[s].dtype == dtypes.INT_DTYPE:
+                sample[s] = sample[s].astype("int64")
+            elif sample[s].dtype == dtypes.FLOAT_DTYPE:
+                sample[s] = sample[s].astype("float64")
+
         return sample
 
     def _is_sequence_arg(self, arg):
@@ -130,9 +143,3 @@ def _is_column_name(self, arg, data):
             and pd.core.dtypes.common.is_hashable(arg)
             and arg in data.columns
         )
-
-    def _generate_new_column_name(self, data):
-        col_name = None
-        while col_name is None or col_name in data.columns:
-            col_name = f"plot_temp_{str(uuid.uuid4())[:8]}"
-        return col_name
diff --git a/tests/system/small/operations/test_plotting.py b/tests/system/small/operations/test_plotting.py
index 824125adf2..6542ce6de3 100644
--- a/tests/system/small/operations/test_plotting.py
+++ b/tests/system/small/operations/test_plotting.py
@@ -240,6 +240,32 @@ def test_scatter_args_c(c):
     )
 
 
+@pytest.mark.parametrize(
+    ("s"),
+    [
+        pytest.param([10, 34, 50], id="int"),
+        pytest.param([1.0, 3.4, 5.0], id="float"),
+        pytest.param(
+            [True, True, False], id="bool", marks=pytest.mark.xfail(raises=ValueError)
+        ),
+    ],
+)
+def test_scatter_args_s(s):
+    data = {
+        "a": [1, 2, 3],
+        "b": [1, 2, 3],
+    }
+    data["s"] = s
+    df = bpd.DataFrame(data)
+    pd_df = pd.DataFrame(data)
+
+    ax = df.plot.scatter(x="a", y="b", s="s")
+    pd_ax = pd_df.plot.scatter(x="a", y="b", s="s")
+    tm.assert_numpy_array_equal(
+        ax.collections[0].get_sizes(), pd_ax.collections[0].get_sizes()
+    )
+
+
 @pytest.mark.parametrize(
     ("arg_name"),
     [