Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit b0417c9

Browse filesBrowse files
authored
Add samples for the Cloud ML Engine (GoogleCloudPlatform#824)
Samples for using online prediction with JSON and TFRecord inputs.
1 parent 8123cdd commit b0417c9
Copy full SHA for b0417c9

File tree

Expand file treeCollapse file tree

4 files changed

+274
-0
lines changed
Filter options
Expand file treeCollapse file tree

4 files changed

+274
-0
lines changed

‎ml_engine/online_prediction/README.md

Copy file name to clipboard
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
https://cloud.google.com/ml-engine/docs/concepts/prediction-overview
+198Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
#!/bin/python
2+
# Copyright 2017 Google Inc. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Examples of using the Cloud ML Engine's online prediction service."""
17+
import argparse
18+
import base64
19+
import json
20+
21+
# [START import_libraries]
22+
import googleapiclient.discovery
23+
# [END import_libraries]
24+
import six
25+
26+
27+
# [START predict_json]
28+
def predict_json(project, model, instances, version=None):
29+
"""Send json data to a deployed model for prediction.
30+
31+
Args:
32+
project (str): project where the Cloud ML Engine Model is deployed.
33+
model (str): model name.
34+
instances ([Mapping[str: Any]]): Keys should be the names of Tensors
35+
your deployed model expects as inputs. Values should be datatypes
36+
convertible to Tensors, or (potentially nested) lists of datatypes
37+
convertible to tensors.
38+
version: str, version of the model to target.
39+
Returns:
40+
Mapping[str: any]: dictionary of prediction results defined by the
41+
model.
42+
"""
43+
# Create the ML Engine service object.
44+
# To authenticate set the environment variable
45+
# GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
46+
service = googleapiclient.discovery.build('ml', 'v1beta1')
47+
name = 'projects/{}/models/{}'.format(project, model)
48+
49+
if version is not None:
50+
name += '/versions/{}'.format(version)
51+
52+
response = service.projects().predict(
53+
name=name,
54+
body={'instances': instances}
55+
).execute()
56+
57+
if 'error' in response:
58+
raise RuntimeError(response['error'])
59+
60+
return response['predictions']
61+
# [END predict_json]
62+
63+
64+
# [START predict_tf_records]
65+
def predict_tf_records(project,
66+
model,
67+
example_bytes_list,
68+
version=None):
69+
"""Send protocol buffer data to a deployed model for prediction.
70+
71+
Args:
72+
project (str): project where the Cloud ML Engine Model is deployed.
73+
model (str): model name.
74+
example_bytes_list ([str]): A list of bytestrings representing
75+
serialized tf.train.Example protocol buffers. The contents of this
76+
protocol buffer will change depending on the signature of your
77+
deployed model.
78+
version: str, version of the model to target.
79+
Returns:
80+
Mapping[str: any]: dictionary of prediction results defined by the
81+
model.
82+
"""
83+
service = googleapiclient.discovery.build('ml', 'v1beta1')
84+
name = 'projects/{}/models/{}'.format(project, model)
85+
86+
if version is not None:
87+
name += '/versions/{}'.format(version)
88+
89+
response = service.projects().predict(
90+
name=name,
91+
body={'instances': [
92+
{'b64': base64.b64encode(example_bytes)}
93+
for example_bytes in example_bytes_list
94+
]}
95+
).execute()
96+
97+
if 'error' in response:
98+
raise RuntimeError(response['error'])
99+
100+
return response['predictions']
101+
# [END predict_tf_records]
102+
103+
104+
# [START census_to_example_bytes]
105+
def census_to_example_bytes(json_instance):
106+
"""Serialize a JSON example to the bytes of a tf.train.Example.
107+
This method is specific to the signature of the Census example.
108+
See: https://cloud.google.com/ml-engine/docs/concepts/prediction-overview
109+
for details.
110+
111+
Args:
112+
json_instance (Mapping[str: Any]): Keys should be the names of Tensors
113+
your deployed model expects to parse using it's tf.FeatureSpec.
114+
Values should be datatypes convertible to Tensors, or (potentially
115+
nested) lists of datatypes convertible to tensors.
116+
Returns:
117+
str: A string as a container for the serialized bytes of
118+
tf.train.Example protocol buffer.
119+
"""
120+
import tensorflow as tf
121+
feature_dict = {}
122+
for key, data in json_instance.iteritems():
123+
if isinstance(data, six.string_types):
124+
feature_dict[key] = tf.train.Feature(
125+
bytes_list=tf.train.BytesList(value=[str(data)]))
126+
elif isinstance(data, float):
127+
feature_dict[key] = tf.train.Feature(
128+
float_list=tf.train.FloatList(value=[data]))
129+
elif isinstance(data, int):
130+
feature_dict[key] = tf.train.Feature(
131+
int64_list=tf.train.Int64List(value=[data]))
132+
return tf.train.Example(
133+
features=tf.train.Features(
134+
feature=feature_dict
135+
)
136+
).SerializeToString()
137+
# [END census_to_example_bytes]
138+
139+
140+
def main(project, model, version=None, force_tfrecord=False):
141+
"""Send user input to the prediction service."""
142+
while True:
143+
try:
144+
user_input = json.loads(raw_input("Valid JSON >>>"))
145+
except KeyboardInterrupt:
146+
return
147+
148+
if not isinstance(user_input, list):
149+
user_input = [user_input]
150+
try:
151+
if force_tfrecord:
152+
example_bytes_list = [
153+
census_to_example_bytes(e)
154+
for e in user_input
155+
]
156+
result = predict_tf_records(
157+
project, model, example_bytes_list, version=version)
158+
else:
159+
result = predict_json(
160+
project, model, user_input, version=version)
161+
except RuntimeError as err:
162+
print(str(err))
163+
else:
164+
print(result)
165+
166+
167+
if __name__ == '__main__':
168+
parser = argparse.ArgumentParser()
169+
parser.add_argument(
170+
'--project',
171+
help='Project in which the model is deployed',
172+
type=str,
173+
required=True
174+
)
175+
parser.add_argument(
176+
'--model',
177+
help='Model name',
178+
type=str,
179+
required=True
180+
)
181+
parser.add_argument(
182+
'--version',
183+
help='Name of the version.',
184+
type=str
185+
)
186+
parser.add_argument(
187+
'--force-tfrecord',
188+
help='Send predictions as TFRecords rather than raw JSON',
189+
action='store_true',
190+
default=False
191+
)
192+
args = parser.parse_args()
193+
main(
194+
args.project,
195+
args.model,
196+
version=args.version,
197+
force_tfrecord=args.force_tfrecord
198+
)
+74Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2017 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for predict.py ."""
16+
17+
import base64
18+
19+
import pytest
20+
21+
import predict
22+
23+
24+
MODEL = 'census'
25+
VERSION = 'v1'
26+
TF_RECORDS_VERSION = 'v1tfrecord'
27+
PROJECT = 'python-docs-samples-tests'
28+
JSON = {
29+
'age': 25,
30+
'workclass': ' Private',
31+
'education': ' 11th',
32+
'education_num': 7,
33+
'marital_status': ' Never-married',
34+
'occupation': ' Machine-op-inspct',
35+
'relationship': ' Own-child',
36+
'race': ' Black',
37+
'gender': ' Male',
38+
'capital_gain': 0,
39+
'capital_loss': 0,
40+
'hours_per_week': 40,
41+
'native_country': ' United-States'
42+
}
43+
EXPECTED_OUTPUT = {
44+
u'probabilities': [0.9942260384559631, 0.005774002522230148],
45+
u'logits': [-5.148599147796631],
46+
u'classes': 0,
47+
u'logistic': [0.005774001590907574]
48+
}
49+
50+
51+
def test_predict_json():
52+
result = predict.predict_json(
53+
PROJECT, MODEL, [JSON, JSON], version=VERSION)
54+
assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result
55+
56+
57+
def test_predict_json_error():
58+
with pytest.raises(RuntimeError):
59+
predict.predict_json(PROJECT, MODEL, [{"foo": "bar"}], version=VERSION)
60+
61+
62+
@pytest.mark.slow
63+
def test_census_example_to_bytes():
64+
b = predict.census_to_example_bytes(JSON)
65+
assert base64.b64encode(b) is not None
66+
67+
68+
@pytest.mark.slow
69+
@pytest.mark.xfail('Single placeholder inputs broken in service b/35778449')
70+
def test_predict_tfrecords():
71+
b = predict.census_to_example_bytes(JSON)
72+
result = predict.predict_tfrecords(
73+
PROJECT, MODEL, [b, b], version=TF_RECORDS_VERSION)
74+
assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tensorflow==1.0.0

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.