|
| 1 | +# Copyright 2020 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +# [START aiplatform_create_training_pipeline_custom_training_managed_dataset_sample] |
| 16 | +from google.cloud import aiplatform |
| 17 | +from google.protobuf import json_format |
| 18 | +from google.protobuf.struct_pb2 import Value |
| 19 | + |
| 20 | + |
| 21 | +def create_training_pipeline_custom_training_managed_dataset_sample( |
| 22 | + project: str, |
| 23 | + display_name: str, |
| 24 | + model_display_name: str, |
| 25 | + dataset_id: str, |
| 26 | + annotation_schema_uri: str, |
| 27 | + training_container_spec_image_uri: str, |
| 28 | + model_container_spec_image_uri: str, |
| 29 | + base_output_uri_prefix: str, |
| 30 | + location: str = "us-central1", |
| 31 | + api_endpoint: str = "us-central1-aiplatform.googleapis.com", |
| 32 | +): |
| 33 | + client_options = {"api_endpoint": api_endpoint} |
| 34 | + # Initialize client that will be used to create and send requests. |
| 35 | + # This client only needs to be created once, and can be reused for |
| 36 | + # multiple requests. |
| 37 | + client = aiplatform.gapic.PipelineServiceClient( |
| 38 | + client_options=client_options) |
| 39 | + |
| 40 | + # input_data_config |
| 41 | + input_data_config = { |
| 42 | + "dataset_id": dataset_id, |
| 43 | + "annotation_schema_uri": annotation_schema_uri, |
| 44 | + "gcs_destination": {"output_uri_prefix": base_output_uri_prefix}, |
| 45 | + } |
| 46 | + |
| 47 | + # training_task_definition |
| 48 | + custom_task_definition = "gs://google-cloud-aiplatform/schema/" \ |
| 49 | + "trainingjob/definition/custom_task_1.0.0.yaml" |
| 50 | + |
| 51 | + # training_task_inputs |
| 52 | + training_container_spec = { |
| 53 | + "imageUri": training_container_spec_image_uri, |
| 54 | + # AIP_MODEL_DIR is set by the service according to baseOutputDirectory. |
| 55 | + "args": ["--model-dir=$(AIP_MODEL_DIR)",], |
| 56 | + } |
| 57 | + |
| 58 | + training_worker_pool_spec = { |
| 59 | + "replicaCount": 1, |
| 60 | + "machineSpec": {"machineType": "n1-standard-8"}, |
| 61 | + "containerSpec": training_container_spec, |
| 62 | + } |
| 63 | + |
| 64 | + training_task_inputs_dict = { |
| 65 | + "workerPoolSpecs": [training_worker_pool_spec], |
| 66 | + "baseOutputDirectory": {"outputUriPrefix": base_output_uri_prefix}, |
| 67 | + } |
| 68 | + |
| 69 | + training_task_inputs = json_format.ParseDict( |
| 70 | + training_task_inputs_dict, Value()) |
| 71 | + |
| 72 | + # model_to_upload |
| 73 | + model_container_spec = { |
| 74 | + "image_uri": model_container_spec_image_uri, |
| 75 | + "command": ["/bin/tensorflow_model_server"], |
| 76 | + "args": [ |
| 77 | + "--model_name=$(AIP_MODEL)", |
| 78 | + "--model_base_path=$(AIP_STORAGE_URI)", |
| 79 | + "--rest_api_port=8080", |
| 80 | + "--port=8500", |
| 81 | + "--file_system_poll_wait_seconds=31540000" |
| 82 | + ], |
| 83 | + } |
| 84 | + |
| 85 | + model = { |
| 86 | + "display_name": model_display_name, |
| 87 | + "container_spec": model_container_spec} |
| 88 | + |
| 89 | + training_pipeline = { |
| 90 | + "display_name": display_name, |
| 91 | + "input_data_config": input_data_config, |
| 92 | + "training_task_definition": custom_task_definition, |
| 93 | + "training_task_inputs": training_task_inputs, |
| 94 | + "model_to_upload": model, |
| 95 | + } |
| 96 | + parent = f"projects/{project}/locations/{location}" |
| 97 | + response = client.create_training_pipeline( |
| 98 | + parent=parent, training_pipeline=training_pipeline |
| 99 | + ) |
| 100 | + print("response:", response) |
| 101 | + |
| 102 | + |
| 103 | +# [END aiplatform_create_training_pipeline_custom_training_managed_dataset_sample] |
0 commit comments