Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 78d8a59

Browse filesBrowse files
Alec Glassfordtswast
authored andcommitted
Add custom prediction routine samples for AI Platform (GoogleCloudPlatform#2121)
* Add custom prediction routine samples Change-Id: I734bebd77970a3ab627b0cbffdcb8fef320c2de4 * Ensure line limit of 79 characters Change-Id: Ic3b512b7478a1e5052baf2978ed1fbc384793e2e
1 parent 3619a77 commit 78d8a59
Copy full SHA for 78d8a59

File tree

Expand file treeCollapse file tree

6 files changed

+287
-0
lines changed
Filter options
Expand file treeCollapse file tree

6 files changed

+287
-0
lines changed
+28Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Custom prediction routines (beta)
2+
3+
Read the AI Platform documentation about custom prediction routines to learn how
4+
to use these samples:
5+
6+
* [Custom prediction routines (with a TensorFlow Keras
7+
example)](https://cloud.google.com/ml-engine/docs/tensorflow/custom-prediction-routines)
8+
* [Custom prediction routines (with a scikit-learn
9+
example)](https://cloud.google.com/ml-engine/docs/scikit/custom-prediction-routines)
10+
11+
If you want to package a predictor directly from this directory, make sure to
12+
edit `setup.py`: replace the reference to `predictor.py` with either
13+
`tensorflow-predictor.py` or `scikit-predictor.py`.
14+
15+
## What's next
16+
17+
For a more complete example of how to train and deploy a custom prediction
18+
routine, check out one of the following tutorials:
19+
20+
* [Creating a custom prediction routine with
21+
Keras](https://cloud.google.com/ml-engine/docs/tensorflow/custom-prediction-routine-keras)
22+
(also available as [a Jupyter
23+
notebook](https://colab.research.google.com/github/GoogleCloudPlatform/cloudml-samples/blob/master/notebooks/tensorflow/custom-prediction-routine-keras.ipynb))
24+
25+
* [Creating a custom prediction routine with
26+
scikit-learn](https://cloud.google.com/ml-engine/docs/scikit/custom-prediction-routine-scikit-learn)
27+
(also available as [a Jupyter
28+
notebook](https://colab.research.google.com/github/GoogleCloudPlatform/cloudml-samples/blob/master/notebooks/scikit-learn/custom-prediction-routine-scikit-learn.ipynb))
+50Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2019 Google LLC
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
class Predictor(object):
17+
"""Interface for constructing custom predictors."""
18+
19+
def predict(self, instances, **kwargs):
20+
"""Performs custom prediction.
21+
22+
Instances are the decoded values from the request. They have already
23+
been deserialized from JSON.
24+
25+
Args:
26+
instances: A list of prediction input instances.
27+
**kwargs: A dictionary of keyword args provided as additional
28+
fields on the predict request body.
29+
30+
Returns:
31+
A list of outputs containing the prediction results. This list must
32+
be JSON serializable.
33+
"""
34+
raise NotImplementedError()
35+
36+
@classmethod
37+
def from_path(cls, model_dir):
38+
"""Creates an instance of Predictor using the given path.
39+
40+
Loading of the predictor should be done in this method.
41+
42+
Args:
43+
model_dir: The local directory that contains the exported model
44+
file along with any additional files uploaded when creating the
45+
version resource.
46+
47+
Returns:
48+
An instance implementing this Predictor class.
49+
"""
50+
raise NotImplementedError()
+43Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2019 Google LLC
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
17+
18+
class ZeroCenterer(object):
19+
"""Stores means of each column of a matrix and uses them for preprocessing.
20+
"""
21+
22+
def __init__(self):
23+
"""On initialization, is not tied to any distribution."""
24+
self._means = None
25+
26+
def preprocess(self, data):
27+
"""Transforms a matrix.
28+
29+
The first time this is called, it stores the means of each column of
30+
the input. Then it transforms the input so each column has mean 0. For
31+
subsequent calls, it subtracts the stored means from each column. This
32+
lets you 'center' data at prediction time based on the distribution of
33+
the original training data.
34+
35+
Args:
36+
data: A NumPy matrix of numerical data.
37+
38+
Returns:
39+
A transformed matrix with the same dimensions as the input.
40+
"""
41+
if self._means is None: # during training only
42+
self._means = np.mean(data, axis=0)
43+
return data - self._means
+73Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2019 Google LLC
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import pickle
17+
18+
import numpy as np
19+
from sklearn.externals import joblib
20+
21+
22+
class MyPredictor(object):
23+
"""An example Predictor for an AI Platform custom prediction routine."""
24+
25+
def __init__(self, model, preprocessor):
26+
"""Stores artifacts for prediction. Only initialized via `from_path`.
27+
"""
28+
self._model = model
29+
self._preprocessor = preprocessor
30+
31+
def predict(self, instances, **kwargs):
32+
"""Performs custom prediction.
33+
34+
Preprocesses inputs, then performs prediction using the trained
35+
scikit-learn model.
36+
37+
Args:
38+
instances: A list of prediction input instances.
39+
**kwargs: A dictionary of keyword args provided as additional
40+
fields on the predict request body.
41+
42+
Returns:
43+
A list of outputs containing the prediction results.
44+
"""
45+
inputs = np.asarray(instances)
46+
preprocessed_inputs = self._preprocessor.preprocess(inputs)
47+
outputs = self._model.predict(preprocessed_inputs)
48+
return outputs.tolist()
49+
50+
@classmethod
51+
def from_path(cls, model_dir):
52+
"""Creates an instance of MyPredictor using the given path.
53+
54+
This loads artifacts that have been copied from your model directory in
55+
Cloud Storage. MyPredictor uses them during prediction.
56+
57+
Args:
58+
model_dir: The local directory that contains the trained
59+
scikit-learn model and the pickled preprocessor instance. These
60+
are copied from the Cloud Storage model directory you provide
61+
when you deploy a version resource.
62+
63+
Returns:
64+
An instance of `MyPredictor`.
65+
"""
66+
model_path = os.path.join(model_dir, 'model.joblib')
67+
model = joblib.load(model_path)
68+
69+
preprocessor_path = os.path.join(model_dir, 'preprocessor.pkl')
70+
with open(preprocessor_path, 'rb') as f:
71+
preprocessor = pickle.load(f)
72+
73+
return cls(model, preprocessor)
+20Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2019 Google LLC
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from setuptools import setup
16+
17+
setup(
18+
name='my_custom_code',
19+
version='0.1',
20+
scripts=['predictor.py', 'preprocess.py'])
+73Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2019 Google LLC
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import pickle
17+
18+
import numpy as np
19+
from tensorflow import keras
20+
21+
22+
class MyPredictor(object):
23+
"""An example Predictor for an AI Platform custom prediction routine."""
24+
25+
def __init__(self, model, preprocessor):
26+
"""Stores artifacts for prediction. Only initialized via `from_path`.
27+
"""
28+
self._model = model
29+
self._preprocessor = preprocessor
30+
31+
def predict(self, instances, **kwargs):
32+
"""Performs custom prediction.
33+
34+
Preprocesses inputs, then performs prediction using the trained Keras
35+
model.
36+
37+
Args:
38+
instances: A list of prediction input instances.
39+
**kwargs: A dictionary of keyword args provided as additional
40+
fields on the predict request body.
41+
42+
Returns:
43+
A list of outputs containing the prediction results.
44+
"""
45+
inputs = np.asarray(instances)
46+
preprocessed_inputs = self._preprocessor.preprocess(inputs)
47+
outputs = self._model.predict(preprocessed_inputs)
48+
return outputs.tolist()
49+
50+
@classmethod
51+
def from_path(cls, model_dir):
52+
"""Creates an instance of MyPredictor using the given path.
53+
54+
This loads artifacts that have been copied from your model directory in
55+
Cloud Storage. MyPredictor uses them during prediction.
56+
57+
Args:
58+
model_dir: The local directory that contains the trained Keras
59+
model and the pickled preprocessor instance. These are copied
60+
from the Cloud Storage model directory you provide when you
61+
deploy a version resource.
62+
63+
Returns:
64+
An instance of `MyPredictor`.
65+
"""
66+
model_path = os.path.join(model_dir, 'model.h5')
67+
model = keras.models.load_model(model_path)
68+
69+
preprocessor_path = os.path.join(model_dir, 'preprocessor.pkl')
70+
with open(preprocessor_path, 'rb') as f:
71+
preprocessor = pickle.load(f)
72+
73+
return cls(model, preprocessor)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.