From 4597dcba90c5762808b8a6727505291469011864 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Tue, 10 Dec 2019 08:07:07 -0500 Subject: [PATCH] Convert PIL image objects to data uri strings in JSON serialization. This conversion is already done by the layout.image.source validator, this this way it will also happen when serializing from a dict without validation, and for images that show up elsewhere in the figure (as mapbox layers for example) --- .../plotly/_plotly_utils/basevalidators.py | 23 +++++++++++-------- packages/python/plotly/_plotly_utils/utils.py | 11 +++++++++ .../test_optional/test_utils/test_utils.py | 18 +++++++++++++++ 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/packages/python/plotly/_plotly_utils/basevalidators.py b/packages/python/plotly/_plotly_utils/basevalidators.py index 9bdf86adce9..32e06fd63ea 100644 --- a/packages/python/plotly/_plotly_utils/basevalidators.py +++ b/packages/python/plotly/_plotly_utils/basevalidators.py @@ -2347,20 +2347,25 @@ def validate_coerce(self, v): pass elif self._PIL and isinstance(v, self._PIL.Image.Image): # Convert PIL image to png data uri string - in_mem_file = io.BytesIO() - v.save(in_mem_file, format="PNG") - in_mem_file.seek(0) - img_bytes = in_mem_file.read() - base64_encoded_result_bytes = base64.b64encode(img_bytes) - base64_encoded_result_str = base64_encoded_result_bytes.decode("ascii") - v = "data:image/png;base64,{base64_encoded_result_str}".format( - base64_encoded_result_str=base64_encoded_result_str - ) + v = self.pil_image_to_uri(v) else: self.raise_invalid_val(v) return v + @staticmethod + def pil_image_to_uri(v): + in_mem_file = io.BytesIO() + v.save(in_mem_file, format="PNG") + in_mem_file.seek(0) + img_bytes = in_mem_file.read() + base64_encoded_result_bytes = base64.b64encode(img_bytes) + base64_encoded_result_str = base64_encoded_result_bytes.decode("ascii") + v = "data:image/png;base64,{base64_encoded_result_str}".format( + base64_encoded_result_str=base64_encoded_result_str + ) + return v + class CompoundValidator(BaseValidator): def __init__(self, plotly_name, parent_name, data_class_str, data_docs, **kwargs): diff --git a/packages/python/plotly/_plotly_utils/utils.py b/packages/python/plotly/_plotly_utils/utils.py index 26face3c25f..a2275374ea0 100644 --- a/packages/python/plotly/_plotly_utils/utils.py +++ b/packages/python/plotly/_plotly_utils/utils.py @@ -4,6 +4,7 @@ import re from _plotly_utils.optional_imports import get_module +from _plotly_utils.basevalidators import ImageUriValidator PY36_OR_LATER = sys.version_info.major == 3 and sys.version_info.minor >= 6 @@ -104,6 +105,7 @@ def default(self, obj): self.encode_as_date, self.encode_as_list, # because some values have `tolist` do last. self.encode_as_decimal, + self.encode_as_pil, ) for encoding_method in encoding_methods: try: @@ -192,6 +194,15 @@ def encode_as_decimal(obj): else: raise NotEncodable + @staticmethod + def encode_as_pil(obj): + """Attempt to convert PIL.Image.Image to base64 data uri""" + pil = get_module("PIL") + if isinstance(obj, pil.Image.Image): + return ImageUriValidator.pil_image_to_uri(obj) + else: + raise NotEncodable + class NotEncodable(Exception): pass diff --git a/packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py b/packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py index d23f567a6ba..d26e7ac6cea 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py @@ -16,9 +16,12 @@ from nose.plugins.attrib import attr from pandas.util.testing import assert_series_equal import json as _json +import os +import base64 from plotly import optional_imports, utils from plotly.graph_objs import Scatter, Scatter3d, Figure, Data +from PIL import Image matplotlylib = optional_imports.get_module("plotly.matplotlylib") @@ -274,6 +277,21 @@ def test_datetime_dot_date(self): j1 = _json.dumps(a, cls=utils.PlotlyJSONEncoder) assert j1 == '["2014-01-01", "2014-01-02"]' + def test_pil_image_encoding(self): + import _plotly_utils + + img_path = os.path.join( + _plotly_utils.__path__[0], "tests", "resources", "1x1-black.png" + ) + + with open(img_path, "rb") as f: + hex_bytes = base64.b64encode(f.read()).decode("ascii") + expected_uri = "data:image/png;base64," + hex_bytes + + img = Image.open(img_path) + j1 = _json.dumps({"source": img}, cls=utils.PlotlyJSONEncoder) + assert j1 == '{"source": "%s"}' % expected_uri + if matplotlylib: