
Commit b012283

morgandu and dizcology authored
feat: add create_training_pipeline_custom_training_managed_dataset sample (#75)
Co-authored-by: Yu-Han Liu <[email protected]>
1 parent 4c60ad6 commit b012283

File tree

2 files changed: +210, -0

Lines changed: 103 additions & 0 deletions (create_training_pipeline_custom_training_managed_dataset_sample.py, the module name the test below imports)
@@ -0,0 +1,103 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# [START aiplatform_create_training_pipeline_custom_training_managed_dataset_sample]
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value


def create_training_pipeline_custom_training_managed_dataset_sample(
    project: str,
    display_name: str,
    model_display_name: str,
    dataset_id: str,
    annotation_schema_uri: str,
    training_container_spec_image_uri: str,
    model_container_spec_image_uri: str,
    base_output_uri_prefix: str,
    location: str = "us-central1",
    api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
    client_options = {"api_endpoint": api_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for
    # multiple requests.
    client = aiplatform.gapic.PipelineServiceClient(client_options=client_options)

    # input_data_config
    input_data_config = {
        "dataset_id": dataset_id,
        "annotation_schema_uri": annotation_schema_uri,
        "gcs_destination": {"output_uri_prefix": base_output_uri_prefix},
    }

    # training_task_definition
    custom_task_definition = (
        "gs://google-cloud-aiplatform/schema/"
        "trainingjob/definition/custom_task_1.0.0.yaml"
    )

    # training_task_inputs
    training_container_spec = {
        "imageUri": training_container_spec_image_uri,
        # AIP_MODEL_DIR is set by the service according to baseOutputDirectory.
        "args": ["--model-dir=$(AIP_MODEL_DIR)"],
    }

    training_worker_pool_spec = {
        "replicaCount": 1,
        "machineSpec": {"machineType": "n1-standard-8"},
        "containerSpec": training_container_spec,
    }

    training_task_inputs_dict = {
        "workerPoolSpecs": [training_worker_pool_spec],
        "baseOutputDirectory": {"outputUriPrefix": base_output_uri_prefix},
    }

    training_task_inputs = json_format.ParseDict(training_task_inputs_dict, Value())

    # model_to_upload
    model_container_spec = {
        "image_uri": model_container_spec_image_uri,
        "command": ["/bin/tensorflow_model_server"],
        "args": [
            "--model_name=$(AIP_MODEL)",
            "--model_base_path=$(AIP_STORAGE_URI)",
            "--rest_api_port=8080",
            "--port=8500",
            "--file_system_poll_wait_seconds=31540000",
        ],
    }

    model = {"display_name": model_display_name, "container_spec": model_container_spec}

    training_pipeline = {
        "display_name": display_name,
        "input_data_config": input_data_config,
        "training_task_definition": custom_task_definition,
        "training_task_inputs": training_task_inputs,
        "model_to_upload": model,
    }
    parent = f"projects/{project}/locations/{location}"
    response = client.create_training_pipeline(
        parent=parent, training_pipeline=training_pipeline
    )
    print("response:", response)


# [END aiplatform_create_training_pipeline_custom_training_managed_dataset_sample]
Lines changed: 107 additions & 0 deletions (test for the sample above)
@@ -0,0 +1,107 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import uuid

import pytest

from google.cloud import aiplatform

import helpers

import create_training_pipeline_custom_training_managed_dataset_sample

API_ENDPOINT = "us-central1-aiplatform.googleapis.com"

PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
DISPLAY_NAME = f"temp_create_training_pipeline_custom_training_managed_dataset_test_{uuid.uuid4()}"
MODEL_DISPLAY_NAME = f"Temp Model for {DISPLAY_NAME}"

DATASET_ID = "1084241610289446912"  # permanent_50_flowers_dataset
ANNOTATION_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml"

TRAINING_CONTAINER_SPEC_IMAGE_URI = "gcr.io/ucaip-test/custom-container-managed-dataset:latest"
MODEL_CONTAINER_SPEC_IMAGE_URI = "gcr.io/cloud-aiplatform/prediction/tf-gpu.1-15:latest"

BASE_OUTPUT_URI_PREFIX = "gs://ucaip-samples-us-central1/training_pipeline_output/custom_training_managed_dataset"


@pytest.fixture
def shared_state():
    state = {}
    yield state


@pytest.fixture
def pipeline_client():
    client_options = {"api_endpoint": API_ENDPOINT}
    pipeline_client = aiplatform.gapic.PipelineServiceClient(
        client_options=client_options
    )
    yield pipeline_client


@pytest.fixture
def model_client():
    client_options = {"api_endpoint": API_ENDPOINT}
    model_client = aiplatform.gapic.ModelServiceClient(client_options=client_options)
    yield model_client


@pytest.fixture(scope="function", autouse=True)
def teardown(shared_state, model_client, pipeline_client):
    yield
    model_client.delete_model(name=shared_state["model_name"])
    pipeline_client.delete_training_pipeline(
        name=shared_state["training_pipeline_name"]
    )


def test_create_training_pipeline_custom_training_managed_dataset_sample(
    capsys, shared_state, pipeline_client
):
    create_training_pipeline_custom_training_managed_dataset_sample.create_training_pipeline_custom_training_managed_dataset_sample(
        project=PROJECT_ID,
        display_name=DISPLAY_NAME,
        model_display_name=MODEL_DISPLAY_NAME,
        dataset_id=DATASET_ID,
        annotation_schema_uri=ANNOTATION_SCHEMA_URI,
        training_container_spec_image_uri=TRAINING_CONTAINER_SPEC_IMAGE_URI,
        model_container_spec_image_uri=MODEL_CONTAINER_SPEC_IMAGE_URI,
        base_output_uri_prefix=BASE_OUTPUT_URI_PREFIX,
    )

    out, _ = capsys.readouterr()

    # Save resource name of the newly created training pipeline.
    shared_state["training_pipeline_name"] = helpers.get_name(out)

    # Poll until the pipeline succeeds because we want to test the
    # model_upload step as well.
    helpers.wait_for_job_state(
        get_job_method=pipeline_client.get_training_pipeline,
        name=shared_state["training_pipeline_name"],
        expected_state="SUCCEEDED",
        timeout=1800,
        freq=20,
    )

    training_pipeline = pipeline_client.get_training_pipeline(
        name=shared_state["training_pipeline_name"]
    )

    # Check that the model indeed has been uploaded.
    assert training_pipeline.model_to_upload.name != ""

    shared_state["model_name"] = training_pipeline.model_to_upload.name
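
The test depends on a shared helpers module that is not included in this commit. Purely as a reading aid, here is a hypothetical sketch of the two helpers it calls, inferred only from the call sites above; the repository's actual helpers.py may differ.

# Hypothetical reconstruction of the helpers used by the test; not part of
# this commit, and the real helpers.py may differ.
import re
import time


def get_name(out: str) -> str:
    # Extract the training pipeline resource name from the sample's
    # printed response.
    match = re.search(
        r"projects/[^/\s]+/locations/[^/\s]+/trainingPipelines/\d+", out
    )
    assert match is not None, "no training pipeline resource name found in output"
    return match.group(0)


def wait_for_job_state(
    get_job_method, name, expected_state="SUCCEEDED", timeout=1800, freq=20
):
    # Poll get_job_method(name=name) every `freq` seconds until the job's
    # state contains `expected_state` or `timeout` seconds elapse.
    deadline = time.time() + timeout
    while time.time() < deadline:
        job = get_job_method(name=name)
        # The state is an enum such as PIPELINE_STATE_SUCCEEDED; a substring
        # match lets callers pass the short form "SUCCEEDED".
        if expected_state in str(job.state):
            return job
        time.sleep(freq)
    raise RuntimeError(f"{name} did not reach {expected_state} within {timeout}s")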
