Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 86e57cb

Browse filesBrowse files
fix(spanner_dbapi): replace insecure pickle with json for partition deserialization (#17014)
This PR resolves a critical Insecure Deserialization vulnerability (potential Remote Code Execution) in the `spanner_dbapi` module [b/510871112](b/510871112) . Previously, the module utilized `pickle.loads()` to decode partition IDs provided by users via the `RUN PARTITION` statement, creating a possibility for arbitrary code execution attack payloads. We have fully eliminated `pickle` usage in this module and migrated to standard `json` serialization. --------- Co-authored-by: Knut Olav Løite <koloite@gmail.com>
1 parent 6b62cb6 commit 86e57cb
Copy full SHA for 86e57cb

5 files changed

+616-7Lines changed: 616 additions & 7 deletions

File tree

Expand file treeCollapse file tree
Open diff view settings
Filter options
Expand file treeCollapse file tree
Open diff view settings
Collapse file

‎packages/google-cloud-spanner/google/cloud/spanner_dbapi/partition_helper.py‎

Copy file name to clipboardExpand all lines: packages/google-cloud-spanner/google/cloud/spanner_dbapi/partition_helper.py
+126-4Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,145 @@
1313
# limitations under the License.
1414

1515
import base64
16+
import copy
17+
import datetime
1618
import gzip
17-
import pickle
19+
import json
1820
from dataclasses import dataclass
1921
from typing import Any
2022

23+
from google.protobuf.json_format import MessageToDict, ParseDict
24+
from google.protobuf.message import Message
25+
from google.protobuf.struct_pb2 import Struct
26+
2127
from google.cloud.spanner_v1 import BatchTransactionId
28+
from google.cloud.spanner_v1._helpers import _make_value_pb
29+
from google.cloud.spanner_v1.types import DirectedReadOptions, ExecuteSqlRequest, Type
30+
31+
_PROTO_CLASS_MAP = {
32+
"QueryOptions": ExecuteSqlRequest.QueryOptions,
33+
"DirectedReadOptions": DirectedReadOptions,
34+
"Struct": Struct,
35+
"Type": Type,
36+
}
37+
38+
39+
def _serialize_value(val: Any) -> Any:
40+
if isinstance(val, bytes):
41+
return {"__type__": "bytes", "value": base64.b64encode(val).decode("utf-8")}
42+
elif isinstance(val, datetime.datetime):
43+
return {"__type__": "datetime", "value": val.isoformat()}
44+
elif hasattr(val, "_pb"):
45+
return {
46+
"__type__": "protobuf",
47+
"class": val.__class__.__name__,
48+
"value": MessageToDict(val._pb, preserving_proto_field_name=True),
49+
}
50+
elif isinstance(val, Message):
51+
return {
52+
"__type__": "protobuf",
53+
"class": val.__class__.__name__,
54+
"value": MessageToDict(val, preserving_proto_field_name=True),
55+
}
56+
elif isinstance(val, dict):
57+
return {k: _serialize_value(v) for k, v in val.items()}
58+
elif isinstance(val, list):
59+
return [_serialize_value(v) for v in val]
60+
elif isinstance(val, tuple):
61+
return {"__type__": "tuple", "value": [_serialize_value(v) for v in val]}
62+
return val
63+
64+
65+
def _deserialize_value(val: Any) -> Any:
66+
if isinstance(val, dict):
67+
if "__type__" in val:
68+
t = val["__type__"]
69+
if t == "bytes":
70+
return base64.b64decode(val["value"])
71+
elif t == "datetime":
72+
dt_str = val["value"]
73+
if dt_str.endswith("Z"):
74+
dt_str = dt_str[:-1] + "+00:00"
75+
return datetime.datetime.fromisoformat(dt_str)
76+
elif t == "tuple":
77+
return tuple(_deserialize_value(x) for x in val["value"])
78+
elif t == "protobuf":
79+
cls_name = val.get("class")
80+
dict_val = val["value"]
81+
if cls_name in _PROTO_CLASS_MAP:
82+
cls = _PROTO_CLASS_MAP[cls_name]
83+
msg = cls()._pb if hasattr(cls(), "_pb") else cls()
84+
ParseDict(dict_val, msg)
85+
return cls(msg) if hasattr(cls(), "_pb") else msg
86+
return _deserialize_value(dict_val)
87+
return {k: _deserialize_value(v) for k, v in val.items()}
88+
elif isinstance(val, list):
89+
return [_deserialize_value(v) for v in val]
90+
return val
91+
92+
93+
def _unpack_value_pb(value):
94+
which = value.WhichOneof("kind")
95+
if which == "null_value":
96+
return None
97+
elif which == "number_value":
98+
return value.number_value
99+
elif which == "string_value":
100+
return value.string_value
101+
elif which == "bool_value":
102+
return value.bool_value
103+
elif which == "struct_value":
104+
return {k: _unpack_value_pb(v) for k, v in value.struct_value.fields.items()}
105+
elif which == "list_value":
106+
return [_unpack_value_pb(v) for v in value.list_value.values]
107+
return None
22108

23109

24110
def decode_from_string(encoded_partition_id):
25111
gzip_bytes = base64.b64decode(bytes(encoded_partition_id, "utf-8"))
26112
partition_id_bytes = gzip.decompress(gzip_bytes)
27-
return pickle.loads(partition_id_bytes)
113+
114+
data = json.loads(partition_id_bytes.decode("utf-8"))
115+
btid_data = data["batch_transaction_id"]
116+
btid = BatchTransactionId(
117+
transaction_id=_deserialize_value(btid_data["transaction_id"]),
118+
session_id=btid_data["session_id"],
119+
read_timestamp=_deserialize_value(btid_data["read_timestamp"]),
120+
)
121+
partition_result = _deserialize_value(data["partition_result"])
122+
123+
# Post-process query params back from Protobuf Struct to Python primitives
124+
if "query" in partition_result and "params" in partition_result["query"]:
125+
params_pb = partition_result["query"]["params"]
126+
if params_pb:
127+
partition_result["query"]["params"] = {
128+
k: _unpack_value_pb(v) for k, v in params_pb.fields.items()
129+
}
130+
131+
return PartitionId(btid, partition_result)
28132

29133

30134
def encode_to_string(batch_transaction_id, partition_result):
31-
partition_id = PartitionId(batch_transaction_id, partition_result)
32-
partition_id_bytes = pickle.dumps(partition_id)
135+
# Copy to avoid modifying the caller's dictionary in connection.py
136+
partition_result = copy.deepcopy(partition_result)
137+
138+
# Pre-process query params into a Protobuf Struct
139+
if "query" in partition_result and "params" in partition_result["query"]:
140+
params = partition_result["query"]["params"]
141+
if params:
142+
params_pb = Struct(fields={k: _make_value_pb(v) for k, v in params.items()})
143+
partition_result["query"]["params"] = params_pb
144+
145+
data = {
146+
"batch_transaction_id": {
147+
"transaction_id": _serialize_value(batch_transaction_id.transaction_id),
148+
"session_id": batch_transaction_id.session_id,
149+
"read_timestamp": _serialize_value(batch_transaction_id.read_timestamp),
150+
},
151+
"partition_result": _serialize_value(partition_result),
152+
}
153+
154+
partition_id_bytes = json.dumps(data).encode("utf-8")
33155
gzip_bytes = gzip.compress(partition_id_bytes)
34156
return str(base64.b64encode(gzip_bytes), "utf-8")
35157

Collapse file

‎packages/google-cloud-spanner/google/cloud/spanner_v1/testing/mock_spanner.py‎

Copy file name to clipboardExpand all lines: packages/google-cloud-spanner/google/cloud/spanner_v1/testing/mock_spanner.py
+14-2Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,21 @@ class MockSpanner:
3636
def __init__(self):
3737
self.results = {}
3838
self.execute_streaming_sql_results = {}
39+
self.partition_results = {}
3940
self.errors = {}
4041

4142
def clear_results(self):
4243
self.results = {}
4344
self.execute_streaming_sql_results = {}
45+
self.partition_results = {}
4446
self.errors = {}
4547

4648
def add_result(self, sql: str, result: result_set.ResultSet):
4749
self.results[sql.lower().strip()] = result
4850

51+
def add_partition_result(self, sql: str, result: spanner.PartitionResponse):
52+
self.partition_results[sql.lower().strip()] = result
53+
4954
def add_execute_streaming_sql_results(
5055
self, sql: str, partial_result_sets: list[result_set.PartialResultSet]
5156
):
@@ -57,6 +62,12 @@ def get_result(self, sql: str) -> result_set.ResultSet:
5762
raise ValueError(f"No result found for {sql}")
5863
return result
5964

65+
def get_partition_result(self, sql: str) -> spanner.PartitionResponse:
66+
result = self.partition_results.get(sql.lower().strip())
67+
if result is None:
68+
return spanner.PartitionResponse()
69+
return result
70+
6071
def add_error(self, method: str, error: _Status):
6172
if not hasattr(self, "_errors_list"):
6273
self._errors_list = {}
@@ -300,11 +311,12 @@ def Rollback(self, request, context):
300311

301312
def PartitionQuery(self, request, context):
302313
self._requests.append(request)
303-
return spanner.PartitionResponse()
314+
return self.mock_spanner.get_partition_result(request.sql)
304315

305316
def PartitionRead(self, request, context):
306317
self._requests.append(request)
307-
return spanner.PartitionResponse()
318+
# For reads, look up by target table name
319+
return self.mock_spanner.get_partition_result(request.table)
308320

309321
def BatchWrite(self, request, context):
310322
self._requests.append(request)
Collapse file

‎packages/google-cloud-spanner/tests/_helpers.py‎

Copy file name to clipboardExpand all lines: packages/google-cloud-spanner/tests/_helpers.py
+4-1Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from os import getenv
22
from unittest import IsolatedAsyncioTestCase
33

4-
import mock
4+
try:
5+
import mock
6+
except ImportError:
7+
import unittest.mock as mock
58

69
from google.cloud.spanner_v1 import gapic_version
710
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
Collapse file
+134Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from google.cloud.spanner_dbapi.connection import Connection
17+
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
18+
from google.cloud.spanner_v1 import TypeCode
19+
from google.cloud.spanner_v1.types import spanner as spanner_types
20+
from tests.mockserver_tests.mock_server_test_base import (
21+
MockServerTestBase,
22+
add_single_result,
23+
)
24+
25+
26+
class TestDbapiPartitionQuery(MockServerTestBase):
27+
def test_partition_query_and_run_partition(self):
28+
sql = "SELECT name FROM users WHERE active = true"
29+
30+
# 1. Set up mock results for PartitionQuery RPC in the mock servicer
31+
partition_response = spanner_types.PartitionResponse()
32+
partition_response.partitions.extend(
33+
[
34+
spanner_types.Partition(partition_token=b"mock-token-1"),
35+
spanner_types.Partition(partition_token=b"mock-token-2"),
36+
]
37+
)
38+
self.spanner_service.mock_spanner.add_partition_result(sql, partition_response)
39+
40+
# 2. Set up mock results for ExecuteSql when executing the partitions
41+
add_single_result(sql, "name", TypeCode.STRING, [("Alice",), ("Bob",)])
42+
43+
# 3. Connect via DB-API and mark connection as read-only (required for partitioning)
44+
connection = Connection(self.instance, self.database)
45+
connection._read_only = True
46+
47+
# Define partitioning parameters inside DB-API Statement
48+
from google.cloud.spanner_dbapi.parsed_statement import (
49+
ClientSideStatementType,
50+
StatementType,
51+
)
52+
53+
parsed = ParsedStatement(
54+
statement_type=StatementType.CLIENT_SIDE,
55+
statement=Statement(sql),
56+
client_side_statement_type=ClientSideStatementType.PARTITION_QUERY,
57+
client_side_statement_params=["SELECT name FROM users WHERE active = true"],
58+
)
59+
60+
# Generate serialized token strings (Base64 + GZip JSON)
61+
partition_ids = connection.partition_query(parsed)
62+
self.assertEqual(2, len(partition_ids))
63+
64+
# 4. Reconstruct & Execute the partitions by deserializing their tokens
65+
all_names = []
66+
for token in partition_ids:
67+
result_stream = connection.run_partition(token)
68+
for row in result_stream:
69+
all_names.append(row[0])
70+
71+
# Verify results are successfully round-tripped and parsed
72+
self.assertIn("Alice", all_names)
73+
self.assertIn("Bob", all_names)
74+
75+
def test_partition_query_with_complex_parameters(self):
76+
import datetime
77+
import decimal
78+
79+
sql = "SELECT name FROM users WHERE active = @active AND salary > @salary AND signup_time = @signup_time"
80+
81+
# Set up complex parameter values (bool, Decimal, datetime)
82+
params = {
83+
"active": True,
84+
"salary": decimal.Decimal("75000.50"),
85+
"signup_time": datetime.datetime(
86+
2026, 5, 10, 12, 34, 56, tzinfo=datetime.timezone.utc
87+
),
88+
}
89+
from google.cloud.spanner_v1 import Type
90+
91+
param_types = {
92+
"active": Type(code=TypeCode.BOOL),
93+
"salary": Type(code=TypeCode.NUMERIC),
94+
"signup_time": Type(code=TypeCode.TIMESTAMP),
95+
}
96+
97+
# 1. Mock results for the partition generation RPC
98+
partition_response = spanner_types.PartitionResponse()
99+
partition_response.partitions.extend(
100+
[spanner_types.Partition(partition_token=b"complex-mock-token-1")]
101+
)
102+
self.spanner_service.mock_spanner.add_partition_result(sql, partition_response)
103+
104+
# 2. Mock results for execution of partition streaming SQL
105+
add_single_result(sql, "name", TypeCode.STRING, [("Charlie",)])
106+
107+
# 3. Establish Connection
108+
connection = Connection(self.instance, self.database)
109+
connection._read_only = True
110+
111+
from google.cloud.spanner_dbapi.parsed_statement import (
112+
ClientSideStatementType,
113+
StatementType,
114+
)
115+
116+
parsed = ParsedStatement(
117+
statement_type=StatementType.CLIENT_SIDE,
118+
statement=Statement(sql, params=params, param_types=param_types),
119+
client_side_statement_type=ClientSideStatementType.PARTITION_QUERY,
120+
client_side_statement_params=[sql],
121+
)
122+
123+
# Execute partition generation - this serializes query parameters!
124+
partition_ids = connection.partition_query(parsed)
125+
self.assertEqual(1, len(partition_ids))
126+
127+
# 4. Reconstruct and run the partition E2E
128+
all_names = []
129+
for token in partition_ids:
130+
result_stream = connection.run_partition(token)
131+
for row in result_stream:
132+
all_names.append(row[0])
133+
134+
self.assertEqual(["Charlie"], all_names)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.