Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ffae2ed

Browse filesBrowse files
elibixbyJon Wayne Parrott
authored andcommitted
Fix ml_engine tests (GoogleCloudPlatform#850)
1 parent 87f7d24 commit ffae2ed
Copy full SHA for ffae2ed

File tree

Expand file treeCollapse file tree

2 files changed

+16
-18
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+16
-18
lines changed

‎ml_engine/online_prediction/predict.py

Copy file name to clipboardExpand all lines: ml_engine/online_prediction/predict.py
+6-6Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ def predict_json(project, model, instances, version=None):
6262

6363

6464
# [START predict_tf_records]
65-
def predict_tf_records(project,
66-
model,
67-
example_bytes_list,
68-
version=None):
65+
def predict_examples(project,
66+
model,
67+
example_bytes_list,
68+
version=None):
6969
"""Send protocol buffer data to a deployed model for prediction.
7070
7171
Args:
@@ -119,7 +119,7 @@ def census_to_example_bytes(json_instance):
119119
"""
120120
import tensorflow as tf
121121
feature_dict = {}
122-
for key, data in json_instance.iteritems():
122+
for key, data in six.iteritems(json_instance):
123123
if isinstance(data, six.string_types):
124124
feature_dict[key] = tf.train.Feature(
125125
bytes_list=tf.train.BytesList(value=[str(data)]))
@@ -153,7 +153,7 @@ def main(project, model, version=None, force_tfrecord=False):
153153
census_to_example_bytes(e)
154154
for e in user_input
155155
]
156-
result = predict_tf_records(
156+
result = predict_examples(
157157
project, model, example_bytes_list, version=version)
158158
else:
159159
result = predict_json(

‎ml_engine/online_prediction/predict_test.py

Copy file name to clipboardExpand all lines: ml_engine/online_prediction/predict_test.py
+10-12Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222

2323

2424
MODEL = 'census'
25-
VERSION = 'v1'
26-
TF_RECORDS_VERSION = 'v1tfrecord'
25+
JSON_VERSION = 'v1json'
26+
EXAMPLES_VERSION = 'v1example'
2727
PROJECT = 'python-docs-samples-tests'
2828
JSON = {
2929
'age': 25,
@@ -41,22 +41,21 @@
4141
'native_country': ' United-States'
4242
}
4343
EXPECTED_OUTPUT = {
44-
u'probabilities': [0.9942260384559631, 0.005774002522230148],
45-
u'logits': [-5.148599147796631],
46-
u'classes': 0,
47-
u'logistic': [0.005774001590907574]
44+
u'confidence': 0.7760371565818787,
45+
u'predictions': u' <=50K'
4846
}
4947

5048

5149
def test_predict_json():
5250
result = predict.predict_json(
53-
PROJECT, MODEL, [JSON, JSON], version=VERSION)
51+
PROJECT, MODEL, [JSON, JSON], version=JSON_VERSION)
5452
assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result
5553

5654

5755
def test_predict_json_error():
5856
with pytest.raises(RuntimeError):
59-
predict.predict_json(PROJECT, MODEL, [{"foo": "bar"}], version=VERSION)
57+
predict.predict_json(
58+
PROJECT, MODEL, [{"foo": "bar"}], version=JSON_VERSION)
6059

6160

6261
@pytest.mark.slow
@@ -66,9 +65,8 @@ def test_census_example_to_bytes():
6665

6766

6867
@pytest.mark.slow
69-
@pytest.mark.xfail('Single placeholder inputs broken in service b/35778449')
70-
def test_predict_tfrecords():
68+
def test_predict_examples():
7169
b = predict.census_to_example_bytes(JSON)
72-
result = predict.predict_tfrecords(
73-
PROJECT, MODEL, [b, b], version=TF_RECORDS_VERSION)
70+
result = predict.predict_examples(
71+
PROJECT, MODEL, [b, b], version=EXAMPLES_VERSION)
7472
assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.