Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 7182374

Browse filesBrowse files
nnegreyleahecole
andauthored
automl: video beta - move beta samples out of branch and into master (GoogleCloudPlatform#2750)
* automl: video beta - move beta samples out of branch and into master * lint * update error message on batch predict Co-authored-by: Leah E. Cole <6719667+leahecole@users.noreply.github.com>
1 parent 18dc311 commit 7182374
Copy full SHA for 7182374
Expand file treeCollapse file tree

6 files changed

+283
-0
lines changed

‎automl/beta/batch_predict.py

Copy file name to clipboard
+52Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def batch_predict(project_id, model_id, input_uri, output_uri):
17+
"""Batch predict"""
18+
# [START automl_batch_predict_beta]
19+
from google.cloud import automl_v1beta1 as automl
20+
21+
# TODO(developer): Uncomment and set the following variables
22+
# project_id = "YOUR_PROJECT_ID"
23+
# model_id = "YOUR_MODEL_ID"
24+
# input_uri = "gs://YOUR_BUCKET_ID/path/to/your/input/csv_or_jsonl"
25+
# output_uri = "gs://YOUR_BUCKET_ID/path/to/save/results/"
26+
27+
prediction_client = automl.PredictionServiceClient()
28+
29+
# Get the full path of the model.
30+
model_full_id = prediction_client.model_path(
31+
project_id, "us-central1", model_id
32+
)
33+
34+
gcs_source = automl.types.GcsSource(input_uris=[input_uri])
35+
36+
input_config = automl.types.BatchPredictInputConfig(gcs_source=gcs_source)
37+
gcs_destination = automl.types.GcsDestination(output_uri_prefix=output_uri)
38+
output_config = automl.types.BatchPredictOutputConfig(
39+
gcs_destination=gcs_destination
40+
)
41+
42+
response = prediction_client.batch_predict(
43+
model_full_id, input_config, output_config
44+
)
45+
46+
print("Waiting for operation to complete...")
47+
print(
48+
"Batch Prediction results saved to Cloud Storage bucket. {}".format(
49+
response.result()
50+
)
51+
)
52+
# [END automl_batch_predict_beta]

‎automl/beta/batch_predict_test.py

Copy file name to clipboard
+47Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific ladnguage governing permissions and
13+
# limitations under the License.
14+
15+
import datetime
16+
import os
17+
18+
import batch_predict
19+
20+
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
21+
BUCKET_ID = "{}-lcm".format(PROJECT_ID)
22+
MODEL_ID = "TEN0000000000000000000"
23+
PREFIX = "TEST_EXPORT_OUTPUT_" + datetime.datetime.now().strftime(
24+
"%Y%m%d%H%M%S"
25+
)
26+
27+
28+
def test_batch_predict(capsys):
29+
# As batch prediction can take a long time. Try to batch predict on a model
30+
# and confirm that the model was not found, but other elements of the
31+
# request were valid.
32+
try:
33+
input_uri = "gs://{}/entity-extraction/input.jsonl".format(BUCKET_ID)
34+
output_uri = "gs://{}/{}/".format(BUCKET_ID, PREFIX)
35+
batch_predict.batch_predict(
36+
PROJECT_ID, MODEL_ID, input_uri, output_uri
37+
)
38+
out, _ = capsys.readouterr()
39+
assert (
40+
"does not exist"
41+
in out
42+
)
43+
except Exception as e:
44+
assert (
45+
"does not exist"
46+
in e.message
47+
)
+45Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def create_dataset(project_id, display_name):
17+
"""Create a dataset."""
18+
# [START automl_video_classification_create_dataset_beta]
19+
from google.cloud import automl_v1beta1 as automl
20+
21+
# TODO(developer): Uncomment and set the following variables
22+
# project_id = "YOUR_PROJECT_ID"
23+
# display_name = "your_datasets_display_name"
24+
25+
client = automl.AutoMlClient()
26+
27+
# A resource that represents Google Cloud Platform location.
28+
project_location = client.location_path(project_id, "us-central1")
29+
metadata = automl.types.VideoClassificationDatasetMetadata()
30+
dataset = automl.types.Dataset(
31+
display_name=display_name,
32+
video_classification_dataset_metadata=metadata,
33+
)
34+
35+
# Create a dataset with the dataset metadata in the region.
36+
created_dataset = client.create_dataset(project_location, dataset)
37+
38+
# Display the dataset information
39+
print("Dataset name: {}".format(created_dataset.name))
40+
# To get the dataset id, you have to parse it out of the `name` field.
41+
# As dataset Ids are required for other methods.
42+
# Name Form:
43+
# `projects/{project_id}/locations/{location_id}/datasets/{dataset_id}`
44+
print("Dataset id: {}".format(created_dataset.name.split("/")[-1]))
45+
# [END automl_video_classification_create_dataset_beta]
+51Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import datetime
16+
import os
17+
18+
from google.cloud import automl_v1beta1 as automl
19+
import pytest
20+
21+
import video_classification_create_dataset
22+
23+
24+
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
25+
pytest.DATASET_ID = None
26+
27+
28+
@pytest.fixture(scope="function", autouse=True)
29+
def teardown():
30+
yield
31+
32+
# Delete the created dataset
33+
client = automl.AutoMlClient()
34+
dataset_full_id = client.dataset_path(
35+
PROJECT_ID, "us-central1", pytest.DATASET_ID
36+
)
37+
response = client.delete_dataset(dataset_full_id)
38+
response.result()
39+
40+
41+
def test_video_classification_create_dataset(capsys):
42+
# create dataset
43+
dataset_name = "test_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")
44+
video_classification_create_dataset.create_dataset(
45+
PROJECT_ID, dataset_name
46+
)
47+
out, _ = capsys.readouterr()
48+
assert "Dataset id: " in out
49+
50+
# Get the the created dataset id for deletion
51+
pytest.DATASET_ID = out.splitlines()[1].split()[2]
+42Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def create_model(project_id, dataset_id, display_name):
17+
"""Create a model."""
18+
# [START automl_video_classification_create_model_beta]
19+
from google.cloud import automl_v1beta1 as automl
20+
21+
# TODO(developer): Uncomment and set the following variables
22+
# project_id = "YOUR_PROJECT_ID"
23+
# dataset_id = "YOUR_DATASET_ID"
24+
# display_name = "your_models_display_name"
25+
26+
client = automl.AutoMlClient()
27+
28+
# A resource that represents Google Cloud Platform location.
29+
project_location = client.location_path(project_id, "us-central1")
30+
metadata = automl.types.VideoClassificationModelMetadata()
31+
model = automl.types.Model(
32+
display_name=display_name,
33+
dataset_id=dataset_id,
34+
video_classification_model_metadata=metadata,
35+
)
36+
37+
# Create a model with the model metadata in the region.
38+
response = client.create_model(project_location, model)
39+
40+
print("Training operation name: {}".format(response.operation.name))
41+
print("Training started...")
42+
# [END automl_video_classification_create_model_beta]
+46Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
from google.cloud import automl_v1beta1 as automl
18+
import pytest
19+
20+
import video_classification_create_model
21+
22+
PROJECT_ID = os.environ["GCLOUD_PROJECT"]
23+
DATASET_ID = "VCN510437278078730240"
24+
pytest.OPERATION_ID = None
25+
26+
27+
@pytest.fixture(scope="function", autouse=True)
28+
def teardown():
29+
yield
30+
31+
# Cancel the operation
32+
client = automl.AutoMlClient()
33+
client.transport._operations_client.cancel_operation(pytest.OPERATION_ID)
34+
35+
36+
def test_video_classification_create_model(capsys):
37+
video_classification_create_model.create_model(
38+
PROJECT_ID, DATASET_ID, "classification_test_create_model"
39+
)
40+
out, _ = capsys.readouterr()
41+
assert "Training started" in out
42+
43+
# Get the the operation id for cancellation
44+
pytest.OPERATION_ID = out.split("Training operation name: ")[1].split(
45+
"\n"
46+
)[0]

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.