|
10 | 10 | # the License.
|
11 | 11 | """Tests for predict.py ."""
|
12 | 12 | import base64
|
| 13 | + |
13 | 14 | import pytest
|
14 |
| -from predict import predict_json, predict_tf_records, census_to_example_bytes |
| 15 | + |
| 16 | +from predict import census_to_example_bytes, predict_json |
15 | 17 |
|
16 | 18 |
|
17 | 19 | MODEL = 'census'
|
18 | 20 | VERSION = 'v1'
|
19 |
| -PROJECT = 'python-docs-samples-test' |
20 |
| -JSON = {'age': 25, 'workclass': ' Private', 'education': ' 11th', 'education_num': 7, 'marital_status': ' Never-married', 'occupation': ' Machine-op-inspct', 'relationship': ' Own-child', 'race': ' Black', 'gender': ' Male', 'capital_gain': 0, 'capital_loss': 0, 'hours_per_week': 40, 'native_country': ' United-States'} |
21 |
| -EXAMPLE_BYTE_STRING = 'CuoCChoKDmhvdXJzX3Blcl93ZWVrEggSBgoEAAAgQgoZCgl3b3JrY2xhc3MSDAoKCgggUHJpdmF0ZQoeCgxyZWxhdGlvbnNoaXASDgoMCgogT3duLWNoaWxkChMKBmdlbmRlchIJCgcKBSBNYWxlCg8KA2FnZRIIEgYKBAAAyEEKJAoObWFyaXRhbF9zdGF0dXMSEgoQCg4gTmV2ZXItbWFycmllZAoSCgRyYWNlEgoKCAoGIEJsYWNrChkKDWVkdWNhdGlvbl9udW0SCBIGCgQAAOBACiQKDm5hdGl2ZV9jb3VudHJ5EhIKEAoOIFVuaXRlZC1TdGF0ZXMKGAoMY2FwaXRhbF9sb3NzEggSBgoEAAAAAAoWCgllZHVjYXRpb24SCQoHCgUgMTF0aAoYCgxjYXBpdGFsX2dhaW4SCBIGCgQAAAAACiQKCm9jY3VwYXRpb24SFgoUChIgTWFjaGluZS1vcC1pbnNwY3Q=' |
22 |
| - |
23 |
| -EXPECTED_OUTPUT = {u'probabilities': [0.9942260384559631, 0.005774002522230148], u'logits': [-5.148599147796631], u'classes': 0, u'logistic': [0.005774001590907574]} |
| 21 | +PROJECT = 'python-docs-samples-tests' |
| 22 | +JSON = { |
| 23 | + 'age': 25, |
| 24 | + 'workclass': ' Private', |
| 25 | + 'education': ' 11th', |
| 26 | + 'education_num': 7, |
| 27 | + 'marital_status': ' Never-married', |
| 28 | + 'occupation': ' Machine-op-inspct', |
| 29 | + 'relationship': ' Own-child', |
| 30 | + 'race': ' Black', |
| 31 | + 'gender': ' Male', |
| 32 | + 'capital_gain': 0, |
| 33 | + 'capital_loss': 0, |
| 34 | + 'hours_per_week': 40, |
| 35 | + 'native_country': ' United-States' |
| 36 | +} |
| 37 | +EXPECTED_OUTPUT = { |
| 38 | + u'probabilities': [0.9942260384559631, 0.005774002522230148], |
| 39 | + u'logits': [-5.148599147796631], |
| 40 | + u'classes': 0, |
| 41 | + u'logistic': [0.005774001590907574] |
| 42 | +} |
24 | 43 |
|
25 | 44 |
|
26 | 45 | def test_predict_json():
|
27 | 46 | result = predict_json(PROJECT, MODEL, [JSON, JSON], version=VERSION)
|
28 | 47 | assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result
|
29 | 48 |
|
| 49 | + |
30 | 50 | def test_predict_json_error():
|
31 | 51 | with pytest.raises(RuntimeError):
|
32 | 52 | predict_json(PROJECT, MODEL, [{"foo": "bar"}], version=VERSION)
|
33 | 53 |
|
| 54 | + |
34 | 55 | def test_census_example_to_bytes():
|
35 | 56 | b = census_to_example_bytes(JSON)
|
36 |
| - assert EXAMPLE_BYTE_STRING == base64.b64encode(b) |
| 57 | + assert base64.b64encode(b) is not None |
| 58 | + |
37 | 59 |
|
38 | 60 | def test_predict_tfrecord():
|
39 | 61 | # Using the same model for TFRecords and
|
|
0 commit comments