From 51df86ee1390a51b82ffc015514ad1e145821a34 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Thu, 3 Aug 2023 11:46:10 -0700 Subject: [PATCH 1/8] feat: add model.evaluate() method to Model class PiperOrigin-RevId: 553544432 --- .../aiplatform/model_evaluation/__init__.py | 11 +- .../model_evaluation/model_evaluation.py | 41 +- .../model_evaluation/model_evaluation_job.py | 410 ++++++ google/cloud/aiplatform/models.py | 231 ++++ .../aiplatform/test_model_evaluation.py | 121 ++ tests/unit/aiplatform/conftest.py | 52 + tests/unit/aiplatform/constants.py | 26 + .../unit/aiplatform/test_model_evaluation.py | 1199 ++++++++++++++++- tests/unit/aiplatform/test_models.py | 511 ++++++- 9 files changed, 2587 insertions(+), 15 deletions(-) create mode 100644 google/cloud/aiplatform/model_evaluation/model_evaluation_job.py create mode 100644 tests/system/aiplatform/test_model_evaluation.py diff --git a/google/cloud/aiplatform/model_evaluation/__init__.py b/google/cloud/aiplatform/model_evaluation/__init__.py index 7dcbee2db5..9907cd92de 100644 --- a/google/cloud/aiplatform/model_evaluation/__init__.py +++ b/google/cloud/aiplatform/model_evaluation/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +15,11 @@ # limitations under the License. 
# -from google.cloud.aiplatform.model_evaluation.model_evaluation import ModelEvaluation +from google.cloud.aiplatform.model_evaluation.model_evaluation import ( + ModelEvaluation, +) +from google.cloud.aiplatform.model_evaluation.model_evaluation_job import ( + _ModelEvaluationJob, +) -__all__ = ("ModelEvaluation",) +__all__ = ("ModelEvaluation", "_ModelEvaluationJob") diff --git a/google/cloud/aiplatform/model_evaluation/model_evaluation.py b/google/cloud/aiplatform/model_evaluation/model_evaluation.py index 2c90e830ab..e574ecad78 100644 --- a/google/cloud/aiplatform/model_evaluation/model_evaluation.py +++ b/google/cloud/aiplatform/model_evaluation/model_evaluation.py @@ -15,14 +15,17 @@ # limitations under the License. # +from typing import List, Optional + +from google.protobuf import struct_pb2 + from google.auth import credentials as auth_credentials +from google.cloud import aiplatform from google.cloud.aiplatform import base -from google.cloud.aiplatform import utils from google.cloud.aiplatform import models -from google.protobuf import struct_pb2 - -from typing import List, Optional +from google.cloud.aiplatform import pipeline_jobs +from google.cloud.aiplatform import utils class ModelEvaluation(base.VertexAiResourceNounWithFutureManager): @@ -36,13 +39,35 @@ class ModelEvaluation(base.VertexAiResourceNounWithFutureManager): _format_resource_name_method = "model_evaluation_path" @property - def metrics(self) -> Optional[struct_pb2.Value]: + def metrics(self) -> struct_pb2.Value: """Gets the evaluation metrics from the Model Evaluation. + + Returns: + A struct_pb2.Value with model metrics created from the Model Evaluation + Raises: + ValueError: If the Model Evaluation doesn't have metrics. + """ + if self._gca_resource.metrics: + return self._gca_resource.metrics + + raise ValueError( + "This ModelEvaluation does not have any metrics, this could be because the Evaluation job failed. Check the logs for details." 
+ ) + + @property + def _backing_pipeline_job(self) -> Optional["pipeline_jobs.PipelineJob"]: + """The managed pipeline for this model evaluation job. Returns: - A dict with model metrics created from the Model Evaluation or - None if the metrics for this evaluation are empty. + The PipelineJob resource if this evaluation ran from a managed pipeline or None. """ - return self._gca_resource.metrics + if ( + "metadata" in self._gca_resource + and "pipeline_job_resource_name" in self._gca_resource.metadata + ): + return aiplatform.PipelineJob.get( + resource_name=self._gca_resource.metadata["pipeline_job_resource_name"], + credentials=self.credentials, + ) def __init__( self, diff --git a/google/cloud/aiplatform/model_evaluation/model_evaluation_job.py b/google/cloud/aiplatform/model_evaluation/model_evaluation_job.py new file mode 100644 index 0000000000..f8ee2b9e21 --- /dev/null +++ b/google/cloud/aiplatform/model_evaluation/model_evaluation_job.py @@ -0,0 +1,410 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Optional, List, Union + +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform._pipeline_based_service import ( + pipeline_based_service, +) +from google.cloud.aiplatform import model_evaluation +from google.cloud.aiplatform import pipeline_jobs + +from google.cloud.aiplatform.compat.types import ( + pipeline_state_v1 as gca_pipeline_state_v1, + pipeline_job_v1 as gca_pipeline_job_v1, + execution_v1 as gca_execution_v1, +) + +_LOGGER = base.Logger(__name__) + +_PIPELINE_TEMPLATE_ARTIFACT_REGISTRY_TAG = "1.0.0" +_BASE_URI = ( + "base_uri", + "https://us-kfp.pkg.dev/vertex-evaluation/pipeline-templates/evaluation", +) +_TAG = ("tag", _PIPELINE_TEMPLATE_ARTIFACT_REGISTRY_TAG) +_MODEL_EVAL_TEMPLATE_REF = frozenset((_BASE_URI, _TAG)) + + +class _ModelEvaluationJob(pipeline_based_service._VertexAiPipelineBasedService): + """Creates a Model Evaluation PipelineJob using _VertexAiPipelineBasedService.""" + + _template_ref = _MODEL_EVAL_TEMPLATE_REF + + _creation_log_message = "Created PipelineJob for your Model Evaluation." 
+ + _component_identifier = "fpc-model-evaluation" + + _template_name_identifier = None + + @property + def _metadata_output_artifact(self) -> Optional[str]: + """The resource uri for the ML Metadata output artifact from the evaluation component of the Model Evaluation pipeline""" + if self.state != gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED: + return + for task in self.backing_pipeline_job._gca_resource.job_detail.task_details: + if task.task_name == self.backing_pipeline_job.name: + return task.outputs["evaluation_metrics"].artifacts[0].name + + def __init__( + self, + evaluation_pipeline_run_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves a ModelEvaluationJob and instantiates its representation. + Example Usage: + my_evaluation = aiplatform.ModelEvaluationJob( + pipeline_job_name = "projects/123/locations/us-central1/pipelineJobs/456" + ) + my_evaluation = aiplatform.ModelEvaluationJob( + pipeline_job_name = "456" + ) + Args: + evaluation_pipeline_run_name (str): + Required. A fully-qualified pipeline job run ID. + Example: "projects/123/locations/us-central1/pipelineJobs/456" or + "456" when project and location are initialized or passed. + project (str): + Optional. Project to retrieve pipeline job from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve pipeline job from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve this pipeline job. Overrides + credentials set in aiplatform.init. 
+ """ + super().__init__( + pipeline_job_name=evaluation_pipeline_run_name, + project=project, + location=location, + credentials=credentials, + ) + + @staticmethod + def _get_template_url( + model_type: str, + feature_attributions: bool, + prediction_type: str, + ) -> str: + """Gets the pipeline template URL for this model evaluation job given the type of data + used to train the model and whether feature attributions should be generated. + + Args: + model_type (str): + Required. Whether the model is an AutoML Tabular model or not. Used to determine which pipeline template should be used. + feature_attributions (bool): + Required. Whether this evaluation job should generate feature attributions. + prediction_type (str): + Required. The type of prediction performed by the Model. One of "classification" or "regression". + + Returns: + (str): The pipeline template URL to use for this model evaluation job. + """ + + # Examples of formatted template URIs: + # model_type="automl_tabular", feature_attrubtions=True, prediction_type="classification" + # https://us-kfp.pkg.dev/vertex-evaluation/pipeline-templates/evaluation-automl-tabular-feature-attribution-classification-pipeline/1.0.0 + # model_type="other", feature_attributions=False, prediction_type="regression" + # https://us-kfp.pkg.dev/vertex-evaluation/pipeline-templates/evaluation-regression-pipeline/1.0.0 + model_type_uri_str = "automl-tabular" if model_type == "automl_tabular" else "" + feature_attributions_uri_str = ( + "feature-attribution" if feature_attributions else "" + ) + + template_ref_dict = dict(_ModelEvaluationJob._template_ref) + + uri_parts = [ + template_ref_dict["base_uri"], + model_type_uri_str, + feature_attributions_uri_str, + prediction_type, + "pipeline/" + template_ref_dict["tag"], + ] + template_url = "-".join(filter(None, uri_parts)) + + return template_url + + @classmethod + def submit( + cls, + model_name: Union[str, "aiplatform.Model"], + prediction_type: str, + target_field_name: str, 
+ pipeline_root: str, + model_type: str, + gcs_source_uris: Optional[List[str]] = None, + bigquery_source_uri: Optional[str] = None, + batch_predict_bigquery_destination_output_uri: Optional[str] = None, + class_labels: Optional[List[str]] = None, + prediction_label_column: Optional[str] = None, + prediction_score_column: Optional[str] = None, + generate_feature_attributions: Optional[bool] = False, + instances_format: Optional[str] = "jsonl", + evaluation_pipeline_display_name: Optional[str] = None, + evaluation_metrics_display_name: Optional[str] = None, + job_id: Optional[str] = None, + service_account: Optional[str] = None, + network: Optional[str] = None, + encryption_spec_key_name: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + experiment: Optional[Union[str, "aiplatform.Experiment"]] = None, + ) -> "_ModelEvaluationJob": + """Submits a Model Evaluation Job using aiplatform.PipelineJob and returns + the ModelEvaluationJob resource. + + Example usage: + + ``` + my_evaluation = _ModelEvaluationJob.submit( + model="projects/123/locations/us-central1/models/456", + prediction_type="classification", + pipeline_root="gs://my-pipeline-bucket/runpath", + gcs_source_uris=["gs://test-prediction-data"], + target_field_name=["prediction_class"], + instances_format="jsonl", + ) + + my_evaluation = _ModelEvaluationJob.submit( + model="projects/123/locations/us-central1/models/456", + prediction_type="regression", + pipeline_root="gs://my-pipeline-bucket/runpath", + gcs_source_uris=["gs://test-prediction-data"], + target_field_name=["price"], + instances_format="jsonl", + ) + ``` + + Args: + model_name (Union[str, "aiplatform.Model"]): + Required. An instance of aiplatform.Model or a fully-qualified model resource name or model ID to run the evaluation + job on. 
Example: "projects/123/locations/us-central1/models/456" or + "456" when project and location are initialized or passed. + prediction_type (str): + Required. The type of prediction performed by the Model. One of "classification" or "regression". + target_field_name (str): + Required. The name of your prediction column. + pipeline_root (str): + Required. The GCS directory to store output from the model evaluation PipelineJob. + model_type (str): + Required. One of "automl_tabular" or "other". This determines the Model Evaluation template used by this PipelineJob. + gcs_source_uris (List[str]): + Optional. A list of Cloud Storage data files containing the ground truth data to use for this + evaluation job, for example: ["gs://path/to/your/data.csv"]. These files should contain your + model's prediction column. The provided data files must be either CSV or JSONL. One of `gcs_source_uris` + or `bigquery_source_uri` is required. + bigquery_source_uri (str): + Optional. A bigquery table URI containing the ground truth data to use for this evaluation job. This uri should + be in the format 'bq://my-project-id.dataset.table'. One of `gcs_source_uris` or `bigquery_source_uri` is + required. + bigquery_destination_output_uri (str): + Optional. A bigquery table URI where the Batch Prediction job associated with your Model Evaluation will write + prediction output. This can be a BigQuery URI to a project ('bq://my-project'), a dataset + ('bq://my-project.my-dataset'), or a table ('bq://my-project.my-dataset.my-table'). Required if `bigquery_source_uri` + is provided. + class_labels (List[str]): + Optional. For custom (non-AutoML) classification models, a list of possible class names, in the + same order that predictions are generated. This argument is required when prediction_type is 'classification'. 
+ For example, in a classification model with 3 possible classes that are outputted in the format: [0.97, 0.02, 0.01] + with the class names "cat", "dog", and "fish", the value of `class_labels` should be `["cat", "dog", "fish"]` where + the class "cat" corresponds with 0.97 in the example above. + prediction_label_column (str): + Optional. The column name of the field containing classes the model is scoring. Formatted to be able to find nested + columns, delimited by `.`. If not set, defaulted to `prediction.classes` for classification. + prediction_score_column (str): + Optional. The column name of the field containing batch prediction scores. Formatted to be able to find nested columns, + delimited by `.`. If not set, defaulted to `prediction.scores` for a `classification` problem_type, `prediction.value` + for a `regression` problem_type. + generate_feature_attributions (boolean): + Optional. Whether the model evaluation job should generate feature attributions. Defaults to False if not specified. + instances_format (str): + The format in which instances are given, must be one of the Model's supportedInputStorageFormats. If not set, defaults to "jsonl". + evaluation_pipeline_display_name (str) + Optional. The user-defined name of the PipelineJob created by this Pipeline Based Service. + evaluation_metrics_display_name (str) + Optional. The user-defined name of the evaluation metrics resource uploaded to Vertex in the evaluation pipeline job. + job_id (str): + Optional. The unique ID of the job run. If not specified, pipeline name + timestamp will be used. + service_account (str): + Specifies the service account for workload run-as account for this Model Evaluation PipelineJob. + Users submitting jobs must have act-as permission on this run-as account. The service account running + this Model Evaluation job needs the following permissions: Dataflow Worker, Storage Admin, Vertex AI User. 
+ network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + encryption_spec_key_name (str): + Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the job. Has the + form: ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. The key needs to be in the same + region as where the compute resource is created. If this is set, then all + resources created by the PipelineJob for this Model Evaluation will be encrypted with the provided encryption key. + If not specified, encryption_spec of original PipelineJob will be used. + project (str): + Optional. The project to run this PipelineJob in. If not set, + the project set in aiplatform.init will be used. + location (str): + Optional. Location to create PipelineJob. If not set, + location set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to create the PipelineJob. + Overrides credentials set in aiplatform.init. + experiment (Union[str, experiments_resource.Experiment]): + Optional. The Vertex AI experiment name or instance to associate to the PipelineJob executing + this model evaluation job. + Returns: + (ModelEvaluationJob): Instantiated representation of the model evaluation job. 
+ """ + + if isinstance(model_name, aiplatform.Model): + model_resource_name = model_name.versioned_resource_name + else: + model_resource_name = aiplatform.Model( + model_name=model_name, + project=project, + location=location, + credentials=credentials, + ).versioned_resource_name + + if not evaluation_pipeline_display_name: + evaluation_pipeline_display_name = cls._generate_display_name() + + template_params = { + "batch_predict_instances_format": instances_format, + "model_name": model_resource_name, + "evaluation_display_name": evaluation_metrics_display_name, + "project": project or initializer.global_config.project, + "location": location or initializer.global_config.location, + "batch_predict_gcs_destination_output_uri": pipeline_root, + "target_field_name": target_field_name, + "encryption_spec_key_name": encryption_spec_key_name, + } + + if bigquery_source_uri: + template_params["batch_predict_predictions_format"] = "bigquery" + template_params["batch_predict_bigquery_source_uri"] = bigquery_source_uri + template_params[ + "batch_predict_bigquery_destination_output_uri" + ] = batch_predict_bigquery_destination_output_uri + elif gcs_source_uris: + template_params["batch_predict_gcs_source_uris"] = gcs_source_uris + + if prediction_type == "classification" and model_type == "other": + template_params["evaluation_class_labels"] = class_labels + + if prediction_label_column: + template_params[ + "evaluation_prediction_label_column" + ] = prediction_label_column + + if prediction_score_column: + template_params[ + "evaluation_prediction_score_column" + ] = prediction_score_column + + # If the user provides a SA, use it for the Dataflow job as well + if service_account is not None: + template_params["dataflow_service_account"] = service_account + + template_url = cls._get_template_url( + model_type, + generate_feature_attributions, + prediction_type, + ) + + eval_pipeline_run = cls._create_and_submit_pipeline_job( + template_params=template_params, + 
 template_path=template_url, + pipeline_root=pipeline_root, + display_name=evaluation_pipeline_display_name, + job_id=job_id, + service_account=service_account, + network=network, + encryption_spec_key_name=encryption_spec_key_name, + project=project, + location=location, + credentials=credentials, + experiment=experiment, + ) + + _LOGGER.info( + f"{_ModelEvaluationJob._creation_log_message} View it in the console: {eval_pipeline_run.pipeline_console_uri}" + ) + + return eval_pipeline_run + + def get_model_evaluation( + self, + ) -> Optional["model_evaluation.ModelEvaluation"]: + """Gets the ModelEvaluation created by this ModelEvaluationJob. + + Returns: + aiplatform.ModelEvaluation: Instantiated representation of the ModelEvaluation resource. + Raises: + RuntimeError: If the ModelEvaluationJob pipeline failed. + """ + eval_job_state = self.backing_pipeline_job.state + + if eval_job_state in pipeline_jobs._PIPELINE_ERROR_STATES: + raise RuntimeError( + f"Evaluation job failed. For more details see the logs: {self.pipeline_console_uri}" + ) + if eval_job_state not in pipeline_jobs._PIPELINE_COMPLETE_STATES: + _LOGGER.info( + f"Your evaluation job is still in progress. 
For more details see the logs {self.pipeline_console_uri}" + ) + return + + for component in self.backing_pipeline_job.task_details: + + # This assumes that task_details has a task with a task_name == backing_pipeline_job.name + if not component.task_name == self.backing_pipeline_job.name: + continue + + # If component execution didn't succeed or the execution wasn't cached, don't return an evaluation + if ( + component.state + not in ( + gca_pipeline_job_v1.PipelineTaskDetail.State.SUCCEEDED, + gca_pipeline_job_v1.PipelineTaskDetail.State.SKIPPED, + ) + and component.execution.state != gca_execution_v1.Execution.State.CACHED + ): + continue + + if "output:evaluation_resource_name" not in component.execution.metadata: + continue + + eval_resource_name = component.execution.metadata[ + "output:evaluation_resource_name" + ] + + eval_resource = model_evaluation.ModelEvaluation( + evaluation_name=eval_resource_name, + credentials=self.credentials, + ) + + return eval_resource diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index f678142781..4694bf2250 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -93,6 +93,11 @@ "saved_model.pbtxt", ] +_SUPPORTED_EVAL_PREDICTION_TYPES = [ + "classification", + "regression", +] + class VersionInfo(NamedTuple): """VersionInfo class envelopes returned Model version information. 
@@ -4895,6 +4900,232 @@ def get_model_evaluation( credentials=self.credentials, ) + def evaluate( + self, + prediction_type: str, + target_field_name: str, + gcs_source_uris: Optional[List[str]] = None, + bigquery_source_uri: Optional[str] = None, + bigquery_destination_output_uri: Optional[str] = None, + class_labels: Optional[List[str]] = None, + prediction_label_column: Optional[str] = None, + prediction_score_column: Optional[str] = None, + staging_bucket: Optional[str] = None, + service_account: Optional[str] = None, + generate_feature_attributions: bool = False, + evaluation_pipeline_display_name: Optional[str] = None, + evaluation_metrics_display_name: Optional[str] = None, + network: Optional[str] = None, + encryption_spec_key_name: Optional[str] = None, + experiment: Optional[Union[str, "aiplatform.Experiment"]] = None, + ) -> "model_evaluation._ModelEvaluationJob": + """Creates a model evaluation job running on Vertex Pipelines and returns the resulting + ModelEvaluationJob resource. + + Example usage: + + ``` + my_model = Model( + model_name="projects/123/locations/us-central1/models/456" + ) + my_evaluation_job = my_model.evaluate( + prediction_type="classification", + target_field_name="type", + data_source_uris=["gs://sdk-model-eval/my-prediction-data.csv"], + staging_bucket="gs://my-staging-bucket/eval_pipeline_root", + ) + my_evaluation_job.wait() + my_evaluation = my_evaluation_job.get_model_evaluation() + my_evaluation.metrics + ``` + + Args: + prediction_type (str): + Required. The problem type being addressed by this evaluation run. 'classification' and 'regression' + are the currently supported problem types. + target_field_name (str): + Required. The column name of the field containing the label for this prediction task. + gcs_source_uris (List[str]): + Optional. A list of Cloud Storage data files containing the ground truth data to use for this + evaluation job. These files should contain your model's prediction column. 
Currently only Google Cloud Storage + urls are supported, for example: "gs://path/to/your/data.csv". The provided data files must be + either CSV or JSONL. One of `gcs_source_uris` or `bigquery_source_uri` is required. + bigquery_source_uri (str): + Optional. A bigquery table URI containing the ground truth data to use for this evaluation job. This uri should + be in the format 'bq://my-project-id.dataset.table'. One of `gcs_source_uris` or `bigquery_source_uri` is + required. + bigquery_destination_output_uri (str): + Optional. A bigquery table URI where the Batch Prediction job associated with your Model Evaluation will write + prediction output. This can be a BigQuery URI to a project ('bq://my-project'), a dataset + ('bq://my-project.my-dataset'), or a table ('bq://my-project.my-dataset.my-table'). Required if `bigquery_source_uri` + is provided. + class_labels (List[str]): + Optional. For custom (non-AutoML) classification models, a list of possible class names, in the + same order that predictions are generated. This argument is required when prediction_type is 'classification'. + For example, in a classification model with 3 possible classes that are outputted in the format: [0.97, 0.02, 0.01] + with the class names "cat", "dog", and "fish", the value of `class_labels` should be `["cat", "dog", "fish"]` where + the class "cat" corresponds with 0.97 in the example above. + prediction_label_column (str): + Optional. The column name of the field containing classes the model is scoring. Formatted to be able to find nested + columns, delimited by `.`. If not set, defaulted to `prediction.classes` for classification. + prediction_score_column (str): + Optional. The column name of the field containing batch prediction scores. Formatted to be able to find nested columns, + delimited by `.`. If not set, defaulted to `prediction.scores` for a `classification` problem_type, `prediction.value` + for a `regression` problem_type. + staging_bucket (str): + Optional. 
The GCS directory to use for staging files from this evaluation job. Defaults to the value set in + aiplatform.init(staging_bucket=...) if not provided. Required if staging_bucket is not set in aiplatform.init(). + service_account (str): + Specifies the service account for workload run-as account for this Model Evaluation PipelineJob. + Users submitting jobs must have act-as permission on this run-as account. The service account running + this Model Evaluation job needs the following permissions: Dataflow Worker, Storage Admin, + Vertex AI Administrator, and Vertex AI Service Agent. + generate_feature_attributions (boolean): + Optional. Whether the model evaluation job should generate feature attributions. Defaults to False if not specified. + evaluation_pipeline_display_name (str): + Optional. The display name of your model evaluation job. This is the display name that will be applied to the + Vertex Pipeline run for your evaluation job. If not set, a display name will be generated automatically. + evaluation_metrics_display_name (str): + Optional. The display name of the model evaluation resource uploaded to Vertex from your Model Evaluation pipeline. + network (str): + The full name of the Compute Engine network to which the job + should be peered. For example, projects/12345/global/networks/myVPC. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + encryption_spec_key_name (str): + Optional. The Cloud KMS resource identifier of the customer managed encryption key used to protect the job. Has the + form: ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. The key needs to be in the same + region as where the compute resource is created. If this is set, then all + resources created by the PipelineJob for this Model Evaluation will be encrypted with the provided encryption key. + If not specified, encryption_spec of original PipelineJob will be used. 
+ experiment (Union[str, experiments_resource.Experiment]): + Optional. The Vertex AI experiment name or instance to associate to the PipelineJob executing + this model evaluation job. Metrics produced by the PipelineJob as system.Metric Artifacts + will be associated as metrics to the provided experiment, and parameters from this PipelineJob + will be associated as parameters to the provided experiment. + Returns: + model_evaluation.ModelEvaluationJob: Instantiated representation of the + _ModelEvaluationJob. + Raises: + ValueError: + If staging_bucket was not set in aiplatform.init() and staging_bucket was not provided. + If the provided `prediction_type` is not valid. + If the provided `data_source_uris` don't start with 'gs://'. + """ + + if (gcs_source_uris is None) == (bigquery_source_uri is None): + raise ValueError( + "Exactly one of `gcs_source_uris` or `bigquery_source_uri` must be provided." + ) + + if isinstance(gcs_source_uris, str): + gcs_source_uris = [gcs_source_uris] + + if bigquery_source_uri and not isinstance(bigquery_source_uri, str): + raise ValueError("The provided `bigquery_source_uri` must be a string.") + + if bigquery_source_uri and not bigquery_destination_output_uri: + raise ValueError( + "`bigquery_destination_output_uri` must be provided if `bigquery_source_uri` is used as the data source." 
+ ) + + if gcs_source_uris is not None and not all( + uri.startswith("gs://") for uri in gcs_source_uris + ): + raise ValueError("`gcs_source_uris` must start with 'gs://'.") + + if bigquery_source_uri is not None and not bigquery_source_uri.startswith( + "bq://" + ): + raise ValueError( + "`bigquery_source_uri` and `bigquery_destination_output_uri` must start with 'bq://'" + ) + + if ( + bigquery_destination_output_uri is not None + and not bigquery_destination_output_uri.startswith("bq://") + ): + raise ValueError( + "`bigquery_source_uri` and `bigquery_destination_output_uri` must start with 'bq://'" + ) + + SUPPORTED_INSTANCES_FORMAT_FILE_EXTENSIONS = [".jsonl", ".csv"] + + if not staging_bucket and initializer.global_config.staging_bucket: + staging_bucket = initializer.global_config.staging_bucket + elif not staging_bucket and not initializer.global_config.staging_bucket: + raise ValueError( + "Please provide `evaluation_staging_bucket` when calling evaluate or set one using aiplatform.init(staging_bucket=...)" + ) + + if prediction_type not in _SUPPORTED_EVAL_PREDICTION_TYPES: + raise ValueError( + f"Please provide a supported model prediction type, one of: {_SUPPORTED_EVAL_PREDICTION_TYPES}." + ) + + if generate_feature_attributions: + if not self._gca_resource.explanation_spec: + raise ValueError( + "To generate feature attributions with your evaluation, call evaluate on a model with an explanation spec. To run evaluation on the current model, call evaluate with `generate_feature_attributions=False`." 
+ ) + + instances_format = None + + if gcs_source_uris: + + data_file_path_obj = pathlib.Path(gcs_source_uris[0]) + + data_file_extension = data_file_path_obj.suffix + if data_file_extension not in SUPPORTED_INSTANCES_FORMAT_FILE_EXTENSIONS: + _LOGGER.warning( + f"Only the following data file extensions are currently supported: '{SUPPORTED_INSTANCES_FORMAT_FILE_EXTENSIONS}'" + ) + else: + instances_format = data_file_extension[1:] + + elif bigquery_source_uri: + instances_format = "bigquery" + + if ( + self._gca_resource.metadata_schema_uri + == "https://storage.googleapis.com/google-cloud-aiplatform/schema/model/metadata/automl_tabular_1.0.0.yaml" + ): + model_type = "automl_tabular" + else: + model_type = "other" + + if ( + model_type == "other" + and prediction_type == "classification" + and not class_labels + ): + raise ValueError( + "Please provide `class_labels` when running evaluation on a custom classification model." + ) + + return model_evaluation._ModelEvaluationJob.submit( + model_name=self.versioned_resource_name, + prediction_type=prediction_type, + target_field_name=target_field_name, + gcs_source_uris=gcs_source_uris, + bigquery_source_uri=bigquery_source_uri, + batch_predict_bigquery_destination_output_uri=bigquery_destination_output_uri, + class_labels=class_labels, + prediction_label_column=prediction_label_column, + prediction_score_column=prediction_score_column, + service_account=service_account, + pipeline_root=staging_bucket, + instances_format=instances_format, + model_type=model_type, + generate_feature_attributions=generate_feature_attributions, + evaluation_pipeline_display_name=evaluation_pipeline_display_name, + evaluation_metrics_display_name=evaluation_metrics_display_name, + network=network, + encryption_spec_key_name=encryption_spec_key_name, + credentials=self.credentials, + experiment=experiment, + ) + # TODO (b/232546878): Async support class ModelRegistry: diff --git a/tests/system/aiplatform/test_model_evaluation.py 
b/tests/system/aiplatform/test_model_evaluation.py new file mode 100644 index 0000000000..733ea0dc1d --- /dev/null +++ b/tests/system/aiplatform/test_model_evaluation.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib +import uuid + +import pytest + +from google.cloud import storage + +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from google.cloud.aiplatform import initializer +from tests.system.aiplatform import e2e_base + +from google.cloud.aiplatform.compat.types import ( + pipeline_state as gca_pipeline_state, +) + +_TEST_MODEL_EVAL_CLASS_LABELS = ["0", "1", "2"] +_TEST_TARGET_FIELD_NAME = "species" + +_TEST_PROJECT = e2e_base._PROJECT +_TEST_LOCATION = e2e_base._LOCATION +_EVAL_METRICS_KEYS_CLASSIFICATION = [ + "auPrc", + "auRoc", + "logLoss", + "confidenceMetrics", + "confusionMatrix", +] + + +_TEST_XGB_CLASSIFICATION_MODEL_ID = "6336857145803276288" +_TEST_EVAL_DATA_URI = ( + "gs://cloud-samples-data-us-central1/vertex-ai/model-evaluation/iris_training.csv" +) +_TEST_PERMANENT_CUSTOM_MODEL_CLASSIFICATION_RESOURCE_NAME = f"projects/{_TEST_PROJECT}/locations/us-central1/models/{_TEST_XGB_CLASSIFICATION_MODEL_ID}" + +_LOGGER = base.Logger(__name__) + + +@pytest.mark.usefixtures( + "prepare_staging_bucket", + "delete_staging_bucket", + "tear_down_resources", +) +class TestModelEvaluationJob(e2e_base.TestEndToEnd): + + _temp_prefix = 
"temp_vertex_sdk_model_evaluation_test" + + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + @pytest.fixture() + def storage_client(self): + yield storage.Client(project=_TEST_PROJECT) + + @pytest.fixture() + def staging_bucket(self, storage_client): + new_staging_bucket = f"temp-sdk-integration-{uuid.uuid4()}" + bucket = storage_client.create_bucket( + new_staging_bucket, location="us-central1" + ) + + yield bucket + + def test_model_evaluate_custom_tabular_model(self, staging_bucket, shared_state): + + custom_model = aiplatform.Model( + model_name=_TEST_PERMANENT_CUSTOM_MODEL_CLASSIFICATION_RESOURCE_NAME + ) + + eval_job = custom_model.evaluate( + gcs_source_uris=[_TEST_EVAL_DATA_URI], + prediction_type="classification", + class_labels=_TEST_MODEL_EVAL_CLASS_LABELS, + target_field_name=_TEST_TARGET_FIELD_NAME, + staging_bucket=f"gs://{staging_bucket.name}", + ) + + shared_state["resources"] = [eval_job.backing_pipeline_job] + + _LOGGER.info("%s, state before completion", eval_job.backing_pipeline_job.state) + + eval_job.wait() + + _LOGGER.info("%s, state after completion", eval_job.backing_pipeline_job.state) + + assert ( + eval_job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + assert eval_job.state == eval_job.backing_pipeline_job.state + + assert eval_job.resource_name == eval_job.backing_pipeline_job.resource_name + + model_eval = eval_job.get_model_evaluation() + + shared_state["resources"].append(model_eval) + + eval_metrics_dict = dict(model_eval.metrics) + + for metric_name in _EVAL_METRICS_KEYS_CLASSIFICATION: + assert metric_name in eval_metrics_dict diff --git a/tests/unit/aiplatform/conftest.py b/tests/unit/aiplatform/conftest.py index d01a1a4061..7976049498 100644 --- a/tests/unit/aiplatform/conftest.py +++ b/tests/unit/aiplatform/conftest.py @@ -26,7 +26,9 @@ from google.cloud import aiplatform from 
google.cloud.aiplatform.utils import source_utils import constants as test_constants +from google.cloud.aiplatform.metadata import constants as metadata_constants from google.cloud.aiplatform.compat.services import ( + metadata_service_client_v1, model_service_client, tensorboard_service_client, pipeline_service_client, @@ -35,6 +37,7 @@ from google.cloud.aiplatform.compat.types import ( context, endpoint, + metadata_store, endpoint_service, model, model_service, @@ -461,3 +464,52 @@ def mock_pipeline_service_create_and_get_with_fail(): ) yield mock_create_training_pipeline, mock_get_training_pipeline + + +# Experiment fixtures +@pytest.fixture +def get_experiment_mock(): + with mock.patch.object( + metadata_service_client_v1.MetadataServiceClient, "get_context" + ) as get_context_mock: + get_context_mock.return_value = ( + test_constants.ExperimentConstants._EXPERIMENT_MOCK + ) + yield get_context_mock + + +@pytest.fixture +def get_metadata_store_mock(): + with mock.patch.object( + metadata_service_client_v1.MetadataServiceClient, "get_metadata_store" + ) as get_metadata_store_mock: + get_metadata_store_mock.return_value = metadata_store.MetadataStore( + name=test_constants.ExperimentConstants._TEST_METADATASTORE, + ) + yield get_metadata_store_mock + + +@pytest.fixture +def get_context_mock(): + with mock.patch.object( + metadata_service_client_v1.MetadataServiceClient, "get_context" + ) as get_context_mock: + get_context_mock.return_value = context.Context( + name=test_constants.ExperimentConstants._TEST_CONTEXT_NAME, + display_name=test_constants.ExperimentConstants._TEST_EXPERIMENT, + description=test_constants.ExperimentConstants._TEST_EXPERIMENT_DESCRIPTION, + schema_title=metadata_constants.SYSTEM_EXPERIMENT, + schema_version=metadata_constants.SCHEMA_VERSIONS[ + metadata_constants.SYSTEM_EXPERIMENT + ], + metadata=metadata_constants.EXPERIMENT_METADATA, + ) + yield get_context_mock + + +@pytest.fixture +def add_context_children_mock(): + with 
mock.patch.object( + metadata_service_client_v1.MetadataServiceClient, "add_context_children" + ) as add_context_children_mock: + yield add_context_children_mock diff --git a/tests/unit/aiplatform/constants.py b/tests/unit/aiplatform/constants.py index 3db30f7d57..4ca82c7746 100644 --- a/tests/unit/aiplatform/constants.py +++ b/tests/unit/aiplatform/constants.py @@ -26,12 +26,14 @@ from google.cloud.aiplatform import explain from google.cloud.aiplatform import utils from google.cloud.aiplatform import schema +from google.cloud.aiplatform.metadata import constants as metadata_constants from google.cloud.aiplatform.compat.services import ( model_service_client, ) from google.cloud.aiplatform.compat.types import ( + context, custom_job, encryption_spec, endpoint, @@ -328,6 +330,30 @@ class DatasetConstants: _TEST_SOURCE_URI_GCS = "gs://my-bucket/my_index_file.jsonl" +@dataclasses.dataclass(frozen=True) +class ExperimentConstants: + """Defines constants used by Experiments and Metadata tests.""" + + _TEST_EXPERIMENT = "test-experiment" + _TEST_CONTEXT_ID = _TEST_EXPERIMENT + _TEST_METADATA_PARENT = f"projects/{ProjectConstants._TEST_PROJECT}/locations/{ProjectConstants._TEST_LOCATION}/metadataStores/default" + _TEST_CONTEXT_NAME = f"{_TEST_METADATA_PARENT}/contexts/{_TEST_CONTEXT_ID}" + _TEST_EXPERIMENT_DESCRIPTION = "test-experiment-description" + + _EXPERIMENT_MOCK = context.Context( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_EXPERIMENT, + description=_TEST_EXPERIMENT_DESCRIPTION, + schema_title=metadata_constants.SYSTEM_EXPERIMENT, + schema_version=metadata_constants.SCHEMA_VERSIONS[ + metadata_constants.SYSTEM_EXPERIMENT + ], + metadata={**metadata_constants.EXPERIMENT_METADATA}, + ) + + _TEST_METADATASTORE = f"projects/{ProjectConstants._TEST_PROJECT}/locations/{ProjectConstants._TEST_LOCATION}/metadataStores/default" + + @dataclasses.dataclass(frozen=True) class MatchingEngineConstants: """Defines constants used by tests that create MatchingEngine 
resources.""" diff --git a/tests/unit/aiplatform/test_model_evaluation.py b/tests/unit/aiplatform/test_model_evaluation.py index 0a06012e42..58f3e07ae1 100644 --- a/tests/unit/aiplatform/test_model_evaluation.py +++ b/tests/unit/aiplatform/test_model_evaluation.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,24 +17,47 @@ import datetime import pytest +import yaml +import json +from google.protobuf import json_format +from google.protobuf import struct_pb2 from unittest import mock - +from urllib import request from google.api_core import datetime_helpers +from google.auth import credentials as auth_credentials +from google.cloud.aiplatform.metadata import constants + +from google.cloud import storage from google.cloud import aiplatform from google.cloud.aiplatform import base from google.cloud.aiplatform import models +from google.cloud.aiplatform.utils import gcs_utils from google.cloud.aiplatform.compat.services import ( model_service_client, + metadata_service_client_v1 as metadata_service_client, + job_service_client_v1 as job_service_client, ) +from google.cloud.aiplatform.model_evaluation import model_evaluation_job +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client_v1, +) from google.cloud.aiplatform.compat.types import model as gca_model +from google.cloud.aiplatform_v1 import Execution as GapicExecution +from google.cloud.aiplatform_v1 import MetadataServiceClient + from google.cloud.aiplatform.compat.types import ( + pipeline_job as gca_pipeline_job, + pipeline_state as gca_pipeline_state, model_evaluation as gca_model_evaluation, + context as gca_context, + artifact as gca_artifact, + batch_prediction_job as gca_batch_prediction_job, ) import constants as test_constants @@ -44,6 +67,9 @@ _TEST_MODEL_NAME = 
"test-model" _TEST_MODEL_ID = test_constants.ModelConstants._TEST_ID _TEST_EVAL_ID = "1028944691210842622" +_TEST_EXPERIMENT = "test-experiment" +_TEST_BATCH_PREDICTION_JOB_ID = "614161631630327111" +_TEST_COMPONENT_IDENTIFIER = "fpc-model-evaluation" _TEST_MODEL_RESOURCE_NAME = test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME @@ -56,9 +82,259 @@ ) ) +_TEST_BATCH_PREDICTION_RESOURCE_NAME = ( + job_service_client.JobServiceClient.batch_prediction_job_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_BATCH_PREDICTION_JOB_ID + ) +) + _TEST_MODEL_EVAL_METRICS = test_constants.ModelConstants._TEST_MODEL_EVAL_METRICS +_TEST_INVALID_MODEL_RESOURCE_NAME = ( + f"prj/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_ID}" +) + +# pipeline job +_TEST_ID = "1028944691210842416" +_TEST_PIPELINE_JOB_DISPLAY_NAME = "sample-pipeline-job-display-name" +_TEST_PIPELINE_JOB_ID = "sample-test-pipeline-202111111" +_TEST_GCS_BUCKET_NAME = "my-bucket" +_TEST_CREDENTIALS = auth_credentials.AnonymousCredentials() +_TEST_SERVICE_ACCOUNT = "abcde@my-project.iam.gserviceaccount.com" +_TEST_PIPELINE_ROOT = f"gs://{_TEST_GCS_BUCKET_NAME}/pipeline_root" +_TEST_PIPELINE_CREATE_TIME = datetime.datetime.now() + +_TEST_KFP_TEMPLATE_URI = "https://us-kfp.pkg.dev/vertex-evaluation/pipeline-templates/evaluation-automl-tabular-classification-pipeline/1.0.0" + +_TEST_TEMPLATE_REF = { + "base_uri": "https://us-kfp.pkg.dev/vertex-evaluation/pipeline-templates/evaluation", + "tag": "20230713_1737", +} +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}" + +_TEST_PIPELINE_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}" +_TEST_INVALID_PIPELINE_JOB_NAME = ( + f"prj/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/{_TEST_PIPELINE_JOB_ID}" +) +_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME = "test-eval-job" +_TEST_EVAL_RESOURCE_DISPLAY_NAME = 
"my-eval-resource-display-name"
+
+_TEST_MODEL_EVAL_METADATA = {"pipeline_job_resource_name": _TEST_PIPELINE_JOB_NAME}
+
+_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES = {
+    "batch_predict_gcs_source_uris": ["gs://my-bucket/my-prediction-data.csv"],
+    "dataflow_service_account": _TEST_SERVICE_ACCOUNT,
+    "batch_predict_instances_format": "csv",
+    "model_name": _TEST_MODEL_RESOURCE_NAME,
+    "evaluation_display_name": _TEST_EVAL_RESOURCE_DISPLAY_NAME,
+    "project": _TEST_PROJECT,
+    "location": _TEST_LOCATION,
+    "batch_predict_gcs_destination_output_uri": _TEST_GCS_BUCKET_NAME,
+    "target_field_name": "predict_class",
+}
+
+_TEST_MODEL_EVAL_PREDICTION_TYPE = "classification"
+
+_TEST_JSON_FORMATTED_MODEL_EVAL_PIPELINE_PARAMETER_VALUES = {
+    "batch_predict_gcs_source_uris": '["gs://sdk-model-eval/batch-pred-heart.csv"]',
+    "dataflow_service_account": _TEST_SERVICE_ACCOUNT,
+    "batch_predict_instances_format": "csv",
+    "model_name": _TEST_MODEL_RESOURCE_NAME,
+    "project": _TEST_PROJECT,
+    "location": _TEST_LOCATION,
+    "batch_predict_gcs_destination_output_uri": _TEST_GCS_BUCKET_NAME,
+    "target_field_name": "predict_class",
+}
+
+_TEST_MODEL_EVAL_PIPELINE_SPEC = {
+    "pipelineInfo": {"name": "evaluation-default-pipeline"},
+    "root": {
+        "dag": {"tasks": {}},
+        "inputDefinitions": {
+            "parameters": {
+                "batch_predict_gcs_source_uris": {"type": "STRING"},
+                "dataflow_service_account": {"type": "STRING"},
+                "batch_predict_instances_format": {"type": "STRING"},
+                "batch_predict_machine_type": {"type": "STRING"},
+                "location": {"type": "STRING"},
+                "model_name": {"type": "STRING"},
+                "project": {"type": "STRING"},
+                "batch_predict_gcs_destination_output_uri": {"type": "STRING"},
+                "target_field_name": {"type": "STRING"},
+            }
+        },
+    },
+    "schemaVersion": "2.0.0",
+    "sdkVersion": "kfp-1.8.12",
+    "components": {},
+}
+
+_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON = json.dumps(
+    {
+        "pipelineInfo": {"name": "evaluation-default-pipeline"},
+        "root": {
+            "dag": {"tasks": {}},
+            
"inputDefinitions": { + "parameters": { + "batch_predict_gcs_source_uris": {"type": "STRING"}, + "dataflow_service_account": {"type": "STRING"}, + "batch_predict_instances_format": {"type": "STRING"}, + "batch_predict_machine_type": {"type": "STRING"}, + "evaluation_class_labels": {"type": "STRING"}, + "location": {"type": "STRING"}, + "model_name": {"type": "STRING"}, + "project": {"type": "STRING"}, + "batch_predict_gcs_destination_output_uri": {"type": "STRING"}, + "target_field_name": {"type": "STRING"}, + } + }, + }, + "schemaVersion": "2.0.0", + "sdkVersion": "kfp-1.8.12", + "components": {}, + } +) + +_TEST_INVALID_MODEL_EVAL_PIPELINE_SPEC = json.dumps( + { + "pipelineInfo": {"name": "my-pipeline"}, + "root": { + "dag": {"tasks": {}}, + "inputDefinitions": { + "parameters": { + "batch_predict_gcs_source_uris": {"type": "STRING"}, + "dataflow_service_account": {"type": "STRING"}, + "batch_predict_instances_format": {"type": "STRING"}, + "model_name": {"type": "STRING"}, + "project": {"type": "STRING"}, + "location": {"type": "STRING"}, + "batch_predict_gcs_destination_output_uri": {"type": "STRING"}, + "target_field_name": {"type": "STRING"}, + } + }, + }, + "schemaVersion": "2.0.0", + "sdkVersion": "kfp-1.8.12", + "components": {}, + } +) + +_TEST_MODEL_EVAL_PIPELINE_JOB = json.dumps( + { + "runtimeConfig": {"parameters": _TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES}, + "pipelineInfo": {"name": "evaluation-default-pipeline"}, + "root": { + "dag": {"tasks": {}}, + "inputDefinitions": { + "parameters": { + "batch_predict_gcs_source_uris": {"type": "STRING"}, + "dataflow_service_account": {"type": "STRING"}, + "batch_predict_instances_format": {"type": "STRING"}, + "batch_predict_machine_type": {"type": "STRING"}, + "evaluation_class_labels": {"type": "STRING"}, + "location": {"type": "STRING"}, + "model_name": {"type": "STRING"}, + "project": {"type": "STRING"}, + "batch_predict_gcs_destination_output_uri": {"type": "STRING"}, + "target_field_name": {"type": 
"STRING"}, + } + }, + }, + "schemaVersion": "2.0.0", + "sdkVersion": "kfp-1.8.12", + "components": {}, + } +) + +_TEST_INVALID_MODEL_EVAL_PIPELINE_JOB = json.dumps( + { + "runtimeConfig": { + "parameterValues": _TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES + }, + "pipelineInfo": {"name": "my-pipeline"}, + "root": { + "dag": {"tasks": {}}, + "inputDefinitions": { + "parameters": { + "batch_predict_gcs_source_uris": {"type": "STRING"}, + "batch_predict_instances_format": {"type": "STRING"}, + "model_name": {"type": "STRING"}, + "project": {"type": "STRING"}, + "location": {"type": "STRING"}, + "batch_predict_gcs_destination_output_uri": {"type": "STRING"}, + "target_field_name": {"type": "STRING"}, + } + }, + }, + "schemaVersion": "2.0.0", + "sdkVersion": "kfp-1.8.12", + "components": {"test_component": {}}, + } +) + +_EVAL_GCP_RESOURCES_STR = ( + '{\n "resources": [\n {\n "resourceType": "ModelEvaluation",\n "resourceUri": "https://us-central1-aiplatform.googleapis.com/v1/' + + _TEST_MODEL_EVAL_RESOURCE_NAME + + '"\n }\n ]\n}' +) + +_BP_JOB_GCP_RESOURCES_STR = ( + '{\n "resources": [\n {\n "resourceType": "BatchPredictionJob",\n "resourceUri": "https://us-central1-aiplatform.googleapis.com/v1/' + + _TEST_BATCH_PREDICTION_RESOURCE_NAME + + '"\n }\n ]\n}' +) + +_TEST_PIPELINE_JOB_DETAIL_EVAL = { + "output:evaluation_resource_name": _TEST_MODEL_EVAL_RESOURCE_NAME +} + +_TEST_PIPELINE_JOB_DETAIL_BP = { + "output:gcp_resources": _BP_JOB_GCP_RESOURCES_STR, +} + +_TEST_EVAL_METRICS_ARTIFACT_NAME = ( + "projects/123/locations/us-central1/metadataStores/default/artifacts/456" +) +_TEST_EVAL_METRICS_ARTIFACT_URI = "gs://test-bucket/eval_pipeline_root/123/evaluation-default-pipeline-20220615135923/model-evaluation-2_-789/evaluation_metrics" + +_TEST_EXPERIMENT = "test-experiment" +_TEST_METADATASTORE = f"{_TEST_PARENT}/metadataStores/default" +_TEST_CONTEXT_NAME = f"{_TEST_METADATASTORE}/contexts/{_TEST_EXPERIMENT}" + +# executions: this is used in 
test_list_pipeline_based_service +_TEST_EXECUTION_PARENT = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default" +) + +_TEST_RUN = "run-1" +_TEST_OTHER_RUN = "run-2" +_TEST_EXPERIMENT = "test-experiment" +_TEST_EXECUTION_ID = f"{_TEST_EXPERIMENT}-{_TEST_RUN}" +_TEST_EXECUTION_NAME = f"{_TEST_EXECUTION_PARENT}/executions/{_TEST_EXECUTION_ID}" + +_TEST_OTHER_EXECUTION_ID = f"{_TEST_EXPERIMENT}-{_TEST_OTHER_RUN}" +_TEST_OTHER_EXECUTION_NAME = ( + f"{_TEST_EXECUTION_PARENT}/executions/{_TEST_OTHER_EXECUTION_ID}" +) + +# execution metadata parameters: used in test_list_pipeline_based_service +_TEST_PARAM_KEY_1 = "learning_rate" +_TEST_PARAM_KEY_2 = "dropout" +_TEST_PIPELINE_PARAM_KEY = "pipeline_job_resource_name" +_TEST_PARAMS = { + _TEST_PARAM_KEY_1: 0.01, + _TEST_PARAM_KEY_2: 0.2, + _TEST_PIPELINE_PARAM_KEY: _TEST_PIPELINE_JOB_NAME, +} +_TEST_OTHER_PARAMS = {_TEST_PARAM_KEY_1: 0.02, _TEST_PARAM_KEY_2: 0.3} + +_TEST_MODEL_EVAL_CLASS_LABELS = ["0", "1", "2"] +_TEST_TARGET_FIELD_NAME = "species" + + +# model eval mocks @pytest.fixture def get_model_mock(): with mock.patch.object( @@ -86,7 +362,21 @@ def mock_model(): yield model -# ModelEvaluation mocks +@pytest.fixture +def mock_pipeline_service_create(): + with mock.patch.object( + pipeline_service_client_v1.PipelineServiceClient, "create_pipeline_job" + ) as mock_create_pipeline_job: + mock_create_pipeline_job.return_value = gca_pipeline_job.PipelineJob( + name=_TEST_PIPELINE_JOB_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + create_time=_TEST_PIPELINE_CREATE_TIME, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + ) + yield mock_create_pipeline_job + + @pytest.fixture def mock_model_eval_get(): with mock.patch.object( @@ -95,6 +385,7 @@ def mock_model_eval_get(): mock_get_model_eval.return_value = gca_model_evaluation.ModelEvaluation( name=_TEST_MODEL_EVAL_RESOURCE_NAME, metrics=_TEST_MODEL_EVAL_METRICS, + metadata=_TEST_MODEL_EVAL_METADATA, ) 
yield mock_get_model_eval @@ -130,6 +421,329 @@ def list_model_evaluations_mock(): yield list_model_evaluations_mock +@pytest.fixture +def mock_pipeline_bucket_exists(): + def mock_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist( + output_artifacts_gcs_dir=None, + service_account=None, + project=None, + location=None, + credentials=None, + ): + output_artifacts_gcs_dir = ( + output_artifacts_gcs_dir + or gcs_utils.generate_gcs_directory_for_pipeline_artifacts( + project=project, + location=location, + ) + ) + return output_artifacts_gcs_dir + + with mock.patch( + "google.cloud.aiplatform.utils.gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist", + wraps=mock_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist, + ) as mock_context: + yield mock_context + + +@pytest.fixture +def mock_artifact(): + artifact = mock.MagicMock(aiplatform.Artifact) + artifact._gca_resource = gca_artifact.Artifact( + display_name="evaluation_metrics", + name=_TEST_EVAL_METRICS_ARTIFACT_NAME, + uri=_TEST_EVAL_METRICS_ARTIFACT_URI, + ) + yield artifact + + +@pytest.fixture +def get_artifact_mock(): + with mock.patch.object( + metadata_service_client.MetadataServiceClient, "get_artifact" + ) as get_artifact_mock: + get_artifact_mock.return_value = gca_artifact.Artifact( + display_name="evaluation_metrics", + name=_TEST_EVAL_METRICS_ARTIFACT_NAME, + uri=_TEST_EVAL_METRICS_ARTIFACT_URI, + ) + + yield get_artifact_mock + + +@pytest.fixture +def get_batch_prediction_job_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_bp_job_mock: + get_bp_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_RESOURCE_NAME, + ) + yield get_bp_job_mock + + +def make_pipeline_job(state): + return gca_pipeline_job.PipelineJob( + name=_TEST_PIPELINE_JOB_NAME, + state=state, + create_time=_TEST_PIPELINE_CREATE_TIME, + service_account=_TEST_SERVICE_ACCOUNT, + 
network=_TEST_NETWORK, + job_detail=gca_pipeline_job.PipelineJobDetail( + pipeline_run_context=gca_context.Context( + name=_TEST_PIPELINE_JOB_NAME, + ), + task_details=[ + gca_pipeline_job.PipelineTaskDetail( + task_id=123, + task_name=_TEST_PIPELINE_JOB_ID, + state=gca_pipeline_job.PipelineTaskDetail.State.SUCCEEDED, + execution={ + "metadata": struct_pb2.Struct( + fields={ + key: struct_pb2.Value(string_value=value) + for key, value in _TEST_PIPELINE_JOB_DETAIL_EVAL.items() + }, + ), + }, + ), + gca_pipeline_job.PipelineTaskDetail( + task_id=123, + execution=GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + metadata={"component_type": _TEST_COMPONENT_IDENTIFIER}, + ), + ), + ], + ), + ) + + +@pytest.fixture +def mock_pipeline_service_get(): + with mock.patch.object( + pipeline_service_client_v1.PipelineServiceClient, "get_pipeline_job" + ) as mock_get_pipeline_job: + mock_get_pipeline_job.side_effect = [ + make_pipeline_job(gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + ] + + yield mock_get_pipeline_job + + +@pytest.fixture +def mock_pipeline_service_get_with_fail(): + with mock.patch.object( + pipeline_service_client_v1.PipelineServiceClient, 
"get_pipeline_job" + ) as mock_get_pipeline_job: + mock_get_pipeline_job.side_effect = [ + make_pipeline_job(gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING), + make_pipeline_job(gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING), + make_pipeline_job(gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED), + ] + + yield mock_get_pipeline_job + + +@pytest.fixture +def mock_pipeline_service_get_pending(): + with mock.patch.object( + pipeline_service_client_v1.PipelineServiceClient, "get_pipeline_job" + ) as mock_get_pipeline_job: + mock_get_pipeline_job.side_effect = [ + make_pipeline_job(gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING), + make_pipeline_job(gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING), + ] + + yield mock_get_pipeline_job + + +@pytest.fixture +def mock_load_json(job_spec_json): + with mock.patch.object(storage.Blob, "download_as_bytes") as mock_load_json: + mock_load_json.return_value = json.dumps(job_spec_json).encode() + yield mock_load_json + + +@pytest.fixture +def mock_load_yaml_and_json(job_spec): + with mock.patch.object( + storage.Blob, "download_as_bytes" + ) as mock_load_yaml_and_json: + mock_load_yaml_and_json.return_value = job_spec.encode() + yield mock_load_yaml_and_json + + +@pytest.fixture +def mock_invalid_model_eval_job_get(): + with mock.patch.object( + pipeline_service_client_v1.PipelineServiceClient, "get_pipeline_job" + ) as mock_get_model_eval_job: + mock_get_model_eval_job.return_value = gca_pipeline_job.PipelineJob( + name=_TEST_PIPELINE_JOB_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + create_time=_TEST_PIPELINE_CREATE_TIME, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + # pipeline_spec=_TEST_INVALID_MODEL_EVAL_PIPELINE_SPEC, + ) + yield mock_get_model_eval_job + + +@pytest.fixture +def mock_model_eval_job_create(): + with mock.patch.object( + pipeline_service_client_v1.PipelineServiceClient, "create_pipeline_job" + ) as 
mock_create_model_eval_job: + mock_create_model_eval_job.return_value = gca_pipeline_job.PipelineJob( + name=_TEST_PIPELINE_JOB_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + create_time=_TEST_PIPELINE_CREATE_TIME, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + pipeline_spec=_TEST_MODEL_EVAL_PIPELINE_SPEC, + ) + yield mock_create_model_eval_job + + +@pytest.fixture +def mock_model_eval_job_get(): + with mock.patch.object( + pipeline_service_client_v1.PipelineServiceClient, "get_pipeline_job" + ) as mock_get_model_eval_job: + mock_get_model_eval_job.return_value = make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + yield mock_get_model_eval_job + + +@pytest.fixture +def mock_successfully_completed_eval_job(): + with mock.patch.object( + pipeline_service_client_v1.PipelineServiceClient, "get_pipeline_job" + ) as mock_get_model_eval_job: + mock_get_model_eval_job.return_value = make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + yield mock_get_model_eval_job + + +@pytest.fixture +def mock_failed_completed_eval_job(): + with mock.patch.object( + pipeline_service_client_v1.PipelineServiceClient, "get_pipeline_job" + ) as mock_get_model_eval_job: + mock_get_model_eval_job.return_value = make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED + ) + yield mock_get_model_eval_job + + +@pytest.fixture +def mock_pending_eval_job(): + with mock.patch.object( + pipeline_service_client_v1.PipelineServiceClient, "get_pipeline_job" + ) as mock_get_model_eval_job: + mock_get_model_eval_job.return_value = make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING + ) + yield mock_get_model_eval_job + + +def make_failed_eval_job(): + model_evaluation_job._ModelEvaluationJob._template_ref = _TEST_TEMPLATE_REF + + eval_job_resource = model_evaluation_job._ModelEvaluationJob( + evaluation_pipeline_run_name=_TEST_PIPELINE_JOB_NAME + ) + 
eval_job_resource.backing_pipeline_job = gca_pipeline_job.PipelineJob( + name=_TEST_PIPELINE_JOB_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + create_time=_TEST_PIPELINE_CREATE_TIME, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + pipeline_spec=_TEST_MODEL_EVAL_PIPELINE_SPEC, + ) + return eval_job_resource + + +@pytest.fixture +def get_execution_mock(): + with mock.patch.object( + MetadataServiceClient, "get_execution" + ) as get_execution_mock: + get_execution_mock.return_value = GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + metadata={"component_type": _TEST_COMPONENT_IDENTIFIER}, + ) + yield get_execution_mock + + +@pytest.fixture +def list_executions_mock(): + with mock.patch.object( + MetadataServiceClient, "list_executions" + ) as list_executions_mock: + list_executions_mock.return_value = [ + GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + metadata=_TEST_PARAMS, + ), + GapicExecution( + name=_TEST_OTHER_EXECUTION_NAME, + display_name=_TEST_OTHER_RUN, + schema_title=constants.SYSTEM_RUN, + schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN], + metadata=_TEST_OTHER_PARAMS, + ), + ] + yield list_executions_mock + + +@pytest.fixture +def mock_request_urlopen(job_spec): + with mock.patch.object(request, "urlopen") as mock_urlopen: + mock_read_response = mock.MagicMock() + mock_decode_response = mock.MagicMock() + mock_decode_response.return_value = job_spec.encode() + mock_read_response.return_value.decode = mock_decode_response + mock_urlopen.return_value.read = mock_read_response + yield mock_urlopen + + @pytest.mark.usefixtures("google_auth_mock") class TestModelEvaluation: def test_init_model_evaluation_with_only_resource_name(self, 
mock_model_eval_get): @@ -216,3 +830,582 @@ def test_list_model_evaluations_with_order_by( ) assert metrics_list[0].create_time > metrics_list[1].create_time + + def test_get_model_evaluation_pipeline_job( + self, mock_model_eval_get, mock_pipeline_service_get + ): + aiplatform.init(project=_TEST_PROJECT) + + eval_pipeline_job = aiplatform.ModelEvaluation( + evaluation_name=_TEST_MODEL_EVAL_RESOURCE_NAME, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + )._backing_pipeline_job + + assert eval_pipeline_job.resource_name == _TEST_PIPELINE_JOB_NAME + + @pytest.mark.parametrize( + "job_spec", + [_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON], + ) + def test_get_model_evaluation_bp_job( + self, + mock_pipeline_service_create, + job_spec, + mock_load_yaml_and_json, + mock_model, + mock_artifact, + get_model_mock, + mock_model_eval_get, + mock_model_eval_job_get, + mock_pipeline_service_get, + mock_model_eval_job_create, + mock_successfully_completed_eval_job, + mock_pipeline_bucket_exists, + get_artifact_mock, + get_batch_prediction_job_mock, + mock_request_urlopen, + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + staging_bucket=_TEST_GCS_BUCKET_NAME, + ) + + test_model_eval_job = model_evaluation_job._ModelEvaluationJob.submit( + model_name=_TEST_MODEL_RESOURCE_NAME, + prediction_type=_TEST_MODEL_EVAL_PREDICTION_TYPE, + instances_format=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_instances_format" + ], + model_type="automl_tabular", + pipeline_root=_TEST_GCS_BUCKET_NAME, + target_field_name=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "target_field_name" + ], + evaluation_pipeline_display_name=_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME, + gcs_source_uris=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_gcs_source_uris" + ], + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + ) + + test_model_eval_job.wait() + + eval_resource = test_model_eval_job.get_model_evaluation() 
+ + assert isinstance(eval_resource, aiplatform.ModelEvaluation) + + @pytest.mark.parametrize( + "job_spec", + [_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON], + ) + def test_get_model_evaluation_mlmd_resource( + self, + mock_pipeline_service_create, + job_spec, + mock_load_yaml_and_json, + mock_model, + mock_artifact, + get_model_mock, + mock_model_eval_get, + mock_model_eval_job_get, + mock_pipeline_service_get, + mock_model_eval_job_create, + mock_successfully_completed_eval_job, + mock_pipeline_bucket_exists, + get_artifact_mock, + mock_request_urlopen, + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + staging_bucket=_TEST_GCS_BUCKET_NAME, + ) + + test_model_eval_job = model_evaluation_job._ModelEvaluationJob.submit( + model_name=_TEST_MODEL_RESOURCE_NAME, + prediction_type=_TEST_MODEL_EVAL_PREDICTION_TYPE, + instances_format=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_instances_format" + ], + model_type="automl_tabular", + pipeline_root=_TEST_GCS_BUCKET_NAME, + target_field_name=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "target_field_name" + ], + evaluation_pipeline_display_name=_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME, + gcs_source_uris=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_gcs_source_uris" + ], + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + ) + + test_model_eval_job.wait() + + eval_resource = test_model_eval_job.get_model_evaluation() + + assert isinstance(eval_resource, aiplatform.ModelEvaluation) + + +@pytest.mark.usefixtures("google_auth_mock") +class TestModelEvaluationJob: + @pytest.mark.parametrize( + "job_spec", + [_TEST_MODEL_EVAL_PIPELINE_JOB], + ) + def test_init_model_evaluation_job( + self, + job_spec, + mock_load_yaml_and_json, + mock_model_eval_job_get, + get_execution_mock, + ): + aiplatform.init(project=_TEST_PROJECT) + + model_evaluation_job._ModelEvaluationJob( + evaluation_pipeline_run_name=_TEST_PIPELINE_JOB_NAME + ) + + 
mock_model_eval_job_get.assert_called_with( + name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY + ) + + assert mock_model_eval_job_get.call_count == 2 + + get_execution_mock.assert_called_once + + @pytest.mark.parametrize( + "job_spec", + [_TEST_INVALID_MODEL_EVAL_PIPELINE_JOB], + ) + def test_init_model_evaluation_job_with_non_eval_pipeline_raises( + self, + job_spec, + mock_load_yaml_and_json, + mock_invalid_model_eval_job_get, + ): + """This should fail because we're passing in `mock_invalid_model_eval_job_get`. + + That mock uses a pipeline template that doesn't have the _component_identifier + defined in the ModelEvaluationJob class. + """ + + aiplatform.init(project=_TEST_PROJECT) + + with pytest.raises(ValueError): + model_evaluation_job._ModelEvaluationJob( + evaluation_pipeline_run_name=_TEST_PIPELINE_JOB_NAME + ) + + def test_init_model_evaluation_job_with_invalid_pipeline_job_name_raises( + self, + mock_pipeline_service_get, + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + + with pytest.raises(ValueError): + model_evaluation_job._ModelEvaluationJob( + evaluation_pipeline_run_name=_TEST_INVALID_PIPELINE_JOB_NAME, + ) + + @pytest.mark.parametrize( + "job_spec", + [_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON], + ) + @pytest.mark.usefixtures("mock_pipeline_service_create") + def test_model_evaluation_job_submit( + self, + job_spec, + mock_load_yaml_and_json, + mock_model, + get_model_mock, + mock_model_eval_job_get, + mock_pipeline_service_get, + mock_model_eval_job_create, + mock_pipeline_bucket_exists, + mock_request_urlopen, + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + staging_bucket=_TEST_GCS_BUCKET_NAME, + ) + + test_model_eval_job = model_evaluation_job._ModelEvaluationJob.submit( + model_name=_TEST_MODEL_RESOURCE_NAME, + prediction_type=_TEST_MODEL_EVAL_PREDICTION_TYPE, + 
instances_format=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_instances_format" + ], + model_type="automl_tabular", + pipeline_root=_TEST_GCS_BUCKET_NAME, + target_field_name=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "target_field_name" + ], + evaluation_pipeline_display_name=_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME, + gcs_source_uris=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_gcs_source_uris" + ], + job_id=_TEST_PIPELINE_JOB_ID, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + ) + + test_model_eval_job.wait() + + expected_runtime_config_dict = { + "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, + "parameters": { + "batch_predict_gcs_source_uris": { + "stringValue": '["gs://my-bucket/my-prediction-data.csv"]' + }, + "dataflow_service_account": {"stringValue": _TEST_SERVICE_ACCOUNT}, + "batch_predict_instances_format": {"stringValue": "csv"}, + "model_name": {"stringValue": _TEST_MODEL_RESOURCE_NAME}, + "project": {"stringValue": _TEST_PROJECT}, + "location": {"stringValue": _TEST_LOCATION}, + "batch_predict_gcs_destination_output_uri": { + "stringValue": _TEST_GCS_BUCKET_NAME + }, + "target_field_name": {"stringValue": "predict_class"}, + }, + } + + runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb + json_format.ParseDict(expected_runtime_config_dict, runtime_config) + + job_spec = yaml.safe_load(job_spec) + pipeline_spec = job_spec.get("pipelineSpec") or job_spec + + # Construct expected request + expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob( + display_name=_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME, + pipeline_spec={ + "components": {}, + "pipelineInfo": pipeline_spec["pipelineInfo"], + "root": pipeline_spec["root"], + "schemaVersion": "2.0.0", + "sdkVersion": "kfp-1.8.12", + }, + runtime_config=runtime_config, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + template_uri=_TEST_KFP_TEMPLATE_URI, + ) + + mock_model_eval_job_create.assert_called_with( + 
parent=_TEST_PARENT, + pipeline_job=expected_gapic_pipeline_job, + pipeline_job_id=_TEST_PIPELINE_JOB_ID, + timeout=None, + ) + + assert mock_model_eval_job_get.called_once + + assert mock_pipeline_service_get.called_once + + assert mock_model_eval_job_get.called_once + + @pytest.mark.parametrize( + "job_spec", + [_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON], + ) + def test_model_evaluation_job_submit_with_experiment( + self, + mock_pipeline_service_create, + job_spec, + mock_load_yaml_and_json, + mock_model, + get_model_mock, + get_experiment_mock, + mock_model_eval_job_get, + mock_pipeline_service_get, + mock_model_eval_job_create, + add_context_children_mock, + get_metadata_store_mock, + get_context_mock, + mock_pipeline_bucket_exists, + mock_request_urlopen, + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + staging_bucket=_TEST_GCS_BUCKET_NAME, + ) + + test_experiment = aiplatform.Experiment(_TEST_EXPERIMENT) + + test_model_eval_job = model_evaluation_job._ModelEvaluationJob.submit( + model_name=_TEST_MODEL_RESOURCE_NAME, + prediction_type=_TEST_MODEL_EVAL_PREDICTION_TYPE, + pipeline_root=_TEST_GCS_BUCKET_NAME, + target_field_name=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "target_field_name" + ], + model_type="automl_tabular", + evaluation_pipeline_display_name=_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME, + gcs_source_uris=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_gcs_source_uris" + ], + job_id=_TEST_PIPELINE_JOB_ID, + instances_format=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_instances_format" + ], + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + experiment=test_experiment, + ) + + test_model_eval_job.wait() + + expected_runtime_config_dict = { + "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, + "parameters": { + "batch_predict_gcs_source_uris": { + "stringValue": '["gs://my-bucket/my-prediction-data.csv"]' + }, + "dataflow_service_account": 
{"stringValue": _TEST_SERVICE_ACCOUNT}, + "batch_predict_instances_format": {"stringValue": "csv"}, + "model_name": {"stringValue": _TEST_MODEL_RESOURCE_NAME}, + "project": {"stringValue": _TEST_PROJECT}, + "location": {"stringValue": _TEST_LOCATION}, + "batch_predict_gcs_destination_output_uri": { + "stringValue": _TEST_GCS_BUCKET_NAME + }, + "target_field_name": {"stringValue": "predict_class"}, + }, + } + + runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb + json_format.ParseDict(expected_runtime_config_dict, runtime_config) + + job_spec = yaml.safe_load(job_spec) + pipeline_spec = job_spec.get("pipelineSpec") or job_spec + + # Construct expected request + expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob( + display_name=_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME, + pipeline_spec={ + "components": {}, + "pipelineInfo": pipeline_spec["pipelineInfo"], + "root": pipeline_spec["root"], + "schemaVersion": "2.0.0", + "sdkVersion": "kfp-1.8.12", + }, + runtime_config=runtime_config, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + template_uri=_TEST_KFP_TEMPLATE_URI, + ) + + mock_model_eval_job_create.assert_called_with( + parent=_TEST_PARENT, + pipeline_job=expected_gapic_pipeline_job, + pipeline_job_id=_TEST_PIPELINE_JOB_ID, + timeout=None, + ) + + get_context_mock.assert_called_with( + name=_TEST_CONTEXT_NAME, + retry=base._DEFAULT_RETRY, + ) + + @pytest.mark.parametrize( + "job_spec", + [_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON], + ) + def test_get_model_evaluation_with_successful_pipeline_run_returns_resource( + self, + mock_pipeline_service_create, + job_spec, + mock_load_yaml_and_json, + mock_model, + get_model_mock, + mock_model_eval_get, + mock_model_eval_job_get, + mock_pipeline_service_get, + mock_model_eval_job_create, + mock_successfully_completed_eval_job, + mock_pipeline_bucket_exists, + mock_request_urlopen, + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, 
+ staging_bucket=_TEST_GCS_BUCKET_NAME, + ) + + test_model_eval_job = model_evaluation_job._ModelEvaluationJob.submit( + model_name=_TEST_MODEL_RESOURCE_NAME, + prediction_type=_TEST_MODEL_EVAL_PREDICTION_TYPE, + pipeline_root=_TEST_GCS_BUCKET_NAME, + target_field_name=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "target_field_name" + ], + model_type="automl_tabular", + evaluation_pipeline_display_name=_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME, + gcs_source_uris=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_gcs_source_uris" + ], + instances_format=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_instances_format" + ], + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + ) + + test_model_eval_job.wait() + + assert ( + test_model_eval_job.backing_pipeline_job.resource_name + == _TEST_PIPELINE_JOB_NAME + ) + + assert isinstance( + test_model_eval_job.backing_pipeline_job, aiplatform.PipelineJob + ) + + test_eval = test_model_eval_job.get_model_evaluation() + + assert isinstance(test_eval, aiplatform.ModelEvaluation) + + assert test_eval.metrics == _TEST_MODEL_EVAL_METRICS + + mock_model_eval_get.assert_called_with( + name=_TEST_MODEL_EVAL_RESOURCE_NAME, retry=base._DEFAULT_RETRY + ) + + assert isinstance(test_eval._backing_pipeline_job, aiplatform.PipelineJob) + + @pytest.mark.parametrize( + "job_spec", + [_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON], + ) + def test_model_evaluation_job_get_model_evaluation_with_failed_pipeline_run_raises( + self, + mock_pipeline_service_create, + job_spec, + mock_load_yaml_and_json, + mock_model, + get_model_mock, + mock_model_eval_get, + mock_model_eval_job_get, + mock_pipeline_service_get, + mock_model_eval_job_create, + mock_failed_completed_eval_job, + mock_pipeline_bucket_exists, + mock_request_urlopen, + ): + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + staging_bucket=_TEST_GCS_BUCKET_NAME, + ) + + test_model_eval_job = 
model_evaluation_job._ModelEvaluationJob.submit( + model_name=_TEST_MODEL_RESOURCE_NAME, + prediction_type=_TEST_MODEL_EVAL_PREDICTION_TYPE, + pipeline_root=_TEST_GCS_BUCKET_NAME, + target_field_name=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "target_field_name" + ], + model_type="automl_tabular", + evaluation_pipeline_display_name=_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME, + gcs_source_uris=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_gcs_source_uris" + ], + instances_format=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_instances_format" + ], + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + ) + + with pytest.raises(RuntimeError): + test_model_eval_job.get_model_evaluation() + + @pytest.mark.parametrize( + "job_spec", + [_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON], + ) + def test_model_evaluation_job_get_model_evaluation_with_pending_pipeline_run_returns_none( + self, + mock_pipeline_service_create, + job_spec, + mock_load_yaml_and_json, + mock_model, + get_model_mock, + mock_model_eval_get, + mock_model_eval_job_get, + mock_pipeline_service_get, + mock_model_eval_job_create, + mock_pending_eval_job, + mock_pipeline_bucket_exists, + mock_request_urlopen, + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + staging_bucket=_TEST_GCS_BUCKET_NAME, + ) + + test_model_eval_job = model_evaluation_job._ModelEvaluationJob.submit( + model_name=_TEST_MODEL_RESOURCE_NAME, + prediction_type=_TEST_MODEL_EVAL_PREDICTION_TYPE, + pipeline_root=_TEST_GCS_BUCKET_NAME, + target_field_name=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "target_field_name" + ], + model_type="automl_tabular", + evaluation_pipeline_display_name=_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME, + gcs_source_uris=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_gcs_source_uris" + ], + instances_format=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ + "batch_predict_instances_format" + ], + 
service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + ) + + assert test_model_eval_job.get_model_evaluation() is None + + def test_get_template_url( + self, + ): + + template_url = model_evaluation_job._ModelEvaluationJob._get_template_url( + model_type="automl_tabular", + feature_attributions=False, + prediction_type=_TEST_MODEL_EVAL_PREDICTION_TYPE, + ) + + assert template_url == _TEST_KFP_TEMPLATE_URI + + regression_template_url = ( + model_evaluation_job._ModelEvaluationJob._get_template_url( + model_type="other", + feature_attributions=True, + prediction_type="regression", + ) + ) + + assert ( + regression_template_url + == "https://us-kfp.pkg.dev/vertex-evaluation/pipeline-templates/evaluation-feature-attribution-regression-pipeline/1.0.0" + ) diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index 842ab70b5c..32a294b7a9 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -17,21 +17,27 @@ import importlib from concurrent import futures +import json import pathlib import pytest import requests +from datetime import datetime from unittest import mock from unittest.mock import patch +from urllib import request from google.api_core import operation as ga_operation from google.api_core import exceptions as api_exceptions from google.auth import credentials as auth_credentials +from google.cloud import storage from google.cloud import aiplatform from google.cloud.aiplatform import base, explain from google.cloud.aiplatform import initializer from google.cloud.aiplatform import models from google.cloud.aiplatform import utils +from google.cloud.aiplatform.utils import gcs_utils +from google.cloud.aiplatform.metadata import constants as metadata_constants from google.cloud.aiplatform import constants from google.cloud.aiplatform.preview import models as preview_models @@ -63,11 +69,16 @@ model as gca_model, model_evaluation as gca_model_evaluation, model_service as 
gca_model_service, + pipeline_job as gca_pipeline_job, + pipeline_state as gca_pipeline_state, + context as gca_context, ) from google.cloud.aiplatform.prediction import LocalModel +from google.cloud.aiplatform_v1 import Execution as GapicExecution +from google.cloud.aiplatform.model_evaluation import model_evaluation_job -from google.protobuf import field_mask_pb2, timestamp_pb2 +from google.protobuf import field_mask_pb2, struct_pb2, timestamp_pb2 import constants as test_constants @@ -274,6 +285,116 @@ ), ] +# model.evaluate +_TEST_PIPELINE_JOB_ID = "sample-test-pipeline-202111111" +_TEST_PIPELINE_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}" +_TEST_PIPELINE_CREATE_TIME = datetime.now() +_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}" +_TEST_MODEL_EVAL_CLASS_LABELS = ["dog", "cat", "rabbit"] +_TEST_BIGQUERY_EVAL_INPUT_URI = "bq://my-project.my-dataset.my-table" +_TEST_BIGQUERY_EVAL_DESTINATION_URI = "bq://my-project.my-dataset" +_TEST_EVAL_RESOURCE_DISPLAY_NAME = "my-eval-resource-display-name" +_TEST_GCS_BUCKET_NAME = "my-bucket" + +_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES = { + "batch_predict_gcs_source_uris": ["gs://my-bucket/my-prediction-data.csv"], + "dataflow_service_account": _TEST_SERVICE_ACCOUNT, + "batch_predict_instances_format": "csv", + "model_name": _TEST_MODEL_RESOURCE_NAME, + "evaluation_display_name": _TEST_EVAL_RESOURCE_DISPLAY_NAME, + "prediction_type": "classification", + "project": _TEST_PROJECT, + "location": _TEST_LOCATION, + "batch_predict_gcs_destination_output_uri": _TEST_GCS_BUCKET_NAME, + "target_field_name": "predict_class", +} + + +_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON = json.dumps( + { + "pipelineInfo": {"name": "evaluation-default-pipeline"}, + "root": { + "dag": {"tasks": {}}, + "inputDefinitions": { + "parameters": { + "batch_predict_gcs_source_uris": {"type": "STRING"}, + "dataflow_service_account": {"type": "STRING"}, + 
"batch_predict_instances_format": {"type": "STRING"}, + "batch_predict_machine_type": {"type": "STRING"}, + "location": {"type": "STRING"}, + "model_name": {"type": "STRING"}, + "prediction_type": {"type": "STRING"}, + "project": {"type": "STRING"}, + "batch_predict_gcs_destination_output_uri": {"type": "STRING"}, + "target_field_name": {"type": "STRING"}, + } + }, + }, + "schemaVersion": "2.0.0", + "sdkVersion": "kfp-1.8.12", + "components": {}, + } +) + +_TEST_MODEL_EVAL_PIPELINE_JOB = json.dumps( + { + "runtimeConfig": {"parameters": _TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES}, + "pipelineInfo": {"name": "evaluation-default-pipeline"}, + "root": { + "dag": {"tasks": {}}, + "inputDefinitions": { + "parameters": { + "batch_predict_gcs_source_uris": {"type": "STRING"}, + "dataflow_service_account": {"type": "STRING"}, + "evaluation_class_labels": {"type": "STRING"}, + "batch_predict_instances_format": {"type": "STRING"}, + "batch_predict_machine_type": {"type": "STRING"}, + "location": {"type": "STRING"}, + "model_name": {"type": "STRING"}, + "prediction_type": {"type": "STRING"}, + "project": {"type": "STRING"}, + "batch_predict_gcs_destination_output_uri": {"type": "STRING"}, + "target_field_name": {"type": "STRING"}, + } + }, + }, + "schemaVersion": "2.0.0", + "sdkVersion": "kfp-1.8.12", + "components": {}, + } +) + +_TEST_MODEL_EVAL_PIPELINE_JOB_WITH_BQ_INPUT = json.dumps( + { + "runtimeConfig": {"parameters": _TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES}, + "pipelineInfo": {"name": "evaluation-default-pipeline"}, + "root": { + "dag": {"tasks": {}}, + "inputDefinitions": { + "parameters": { + "batch_predict_gcs_source_uris": {"type": "STRING"}, + "dataflow_service_account": {"type": "STRING"}, + "evaluation_class_labels": {"type": "STRING"}, + "batch_predict_instances_format": {"type": "STRING"}, + "batch_predict_predictions_format": {"type": "STRING"}, + "batch_predict_bigquery_source_uri": {"type": "STRING"}, + "batch_predict_bigquery_destination_output_uri": 
{"type": "STRING"}, + "batch_predict_machine_type": {"type": "STRING"}, + "location": {"type": "STRING"}, + "model_name": {"type": "STRING"}, + "prediction_type": {"type": "STRING"}, + "project": {"type": "STRING"}, + "batch_predict_gcs_destination_output_uri": {"type": "STRING"}, + "target_field_name": {"type": "STRING"}, + } + }, + }, + "schemaVersion": "2.0.0", + "sdkVersion": "kfp-1.8.12", + "components": {}, + } +) + _TEST_LOCAL_MODEL = LocalModel( serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, @@ -711,6 +832,18 @@ def mock_model_eval_get(): mock_get_model_eval.return_value = gca_model_evaluation.ModelEvaluation( name=_TEST_MODEL_EVAL_RESOURCE_NAME, metrics=_TEST_MODEL_EVAL_METRICS, + metadata={"pipeline_job_resource_name": _TEST_PIPELINE_JOB_NAME}, + ) + yield mock_get_model_eval + + +@pytest.fixture +def mock_get_model_evaluation(): + with mock.patch.object( + aiplatform.model_evaluation._ModelEvaluationJob, "get_model_evaluation" + ) as mock_get_model_eval: + mock_get_model_eval.return_value = aiplatform.ModelEvaluation( + evaluation_name=_TEST_MODEL_EVAL_RESOURCE_NAME ) yield mock_get_model_eval @@ -770,6 +903,201 @@ def merge_version_aliases_mock(): yield merge_version_aliases_mock +# model.evaluate fixtures +@pytest.fixture +def mock_pipeline_service_create(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_pipeline_job" + ) as mock_create_pipeline_job: + mock_create_pipeline_job.return_value = gca_pipeline_job.PipelineJob( + name=_TEST_PIPELINE_JOB_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + create_time=_TEST_PIPELINE_CREATE_TIME, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + ) + yield mock_create_pipeline_job + + +_TEST_COMPONENT_IDENTIFIER = "fpc-model-evaluation" +_TEST_BATCH_PREDICTION_JOB_ID = "614161631630327111" + +_TEST_BATCH_PREDICTION_RESOURCE_NAME = ( + 
job_service_client.JobServiceClient.batch_prediction_job_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_BATCH_PREDICTION_JOB_ID + ) +) + +_EVAL_GCP_RESOURCES_STR = ( + '{\n "resources": [\n {\n "resourceType": "ModelEvaluation",\n "resourceUri": "https://us-central1-aiplatform.googleapis.com/v1/' + + _TEST_MODEL_EVAL_RESOURCE_NAME + + '"\n }\n ]\n}' +) + +_BP_JOB_GCP_RESOURCES_STR = ( + '{\n "resources": [\n {\n "resourceType": "BatchPredictionJob",\n "resourceUri": "https://us-central1-aiplatform.googleapis.com/v1/' + + _TEST_BATCH_PREDICTION_RESOURCE_NAME + + '"\n }\n ]\n}' +) + +_TEST_PIPELINE_JOB_DETAIL_EVAL = { + "output:gcp_resources": _EVAL_GCP_RESOURCES_STR, + "component_type": _TEST_COMPONENT_IDENTIFIER, +} + +_TEST_PIPELINE_JOB_DETAIL_BP = { + "output:gcp_resources": _BP_JOB_GCP_RESOURCES_STR, +} + +_TEST_EVAL_METRICS_ARTIFACT_NAME = ( + "projects/123/locations/us-central1/metadataStores/default/artifacts/456" +) +_TEST_EVAL_METRICS_ARTIFACT_URI = "gs://test-bucket/eval_pipeline_root/123/evaluation-default-pipeline-20220615135923/model-evaluation-2_-789/evaluation_metrics" + + +# executions: this is used in test_list_pipeline_based_service +_TEST_EXECUTION_PARENT = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/metadataStores/default" +) + +_TEST_RUN = "run-1" +_TEST_EXPERIMENT = "test-experiment" +_TEST_EXECUTION_ID = f"{_TEST_EXPERIMENT}-{_TEST_RUN}" +_TEST_EXECUTION_NAME = f"{_TEST_EXECUTION_PARENT}/executions/{_TEST_EXECUTION_ID}" + + +def make_pipeline_job(state): + return gca_pipeline_job.PipelineJob( + name=_TEST_PIPELINE_JOB_NAME, + state=state, + create_time=_TEST_PIPELINE_CREATE_TIME, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + job_detail=gca_pipeline_job.PipelineJobDetail( + pipeline_run_context=gca_context.Context( + name=_TEST_PIPELINE_JOB_NAME, + ), + task_details=[ + gca_pipeline_job.PipelineTaskDetail( + task_id=123, + task_name=_TEST_PIPELINE_JOB_ID, + 
state=gca_pipeline_job.PipelineTaskDetail.State.SUCCEEDED, + execution={ + "metadata": struct_pb2.Struct( + fields={ + key: struct_pb2.Value(string_value=value) + for key, value in _TEST_PIPELINE_JOB_DETAIL_EVAL.items() + }, + ), + }, + ), + gca_pipeline_job.PipelineTaskDetail( + task_id=123, + execution=GapicExecution( + name=_TEST_EXECUTION_NAME, + display_name=_TEST_RUN, + schema_title=metadata_constants.SYSTEM_RUN, + schema_version=metadata_constants.SCHEMA_VERSIONS[ + metadata_constants.SYSTEM_RUN + ], + metadata={"component_type": _TEST_COMPONENT_IDENTIFIER}, + ), + ), + ], + ), + ) + + +@pytest.fixture +def mock_pipeline_service_get(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_pipeline_job" + ) as mock_get_pipeline_job: + mock_get_pipeline_job.side_effect = [ + make_pipeline_job(gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + ] + + yield mock_get_pipeline_job + + +@pytest.fixture +def mock_successfully_completed_eval_job(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_pipeline_job" + ) as mock_get_model_eval_job: + mock_get_model_eval_job.return_value = make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + yield mock_get_model_eval_job + + +@pytest.fixture +def 
mock_pipeline_bucket_exists(): + def mock_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist( + output_artifacts_gcs_dir=None, + service_account=None, + project=None, + location=None, + credentials=None, + ): + output_artifacts_gcs_dir = ( + output_artifacts_gcs_dir + or gcs_utils.generate_gcs_directory_for_pipeline_artifacts( + project=project, + location=location, + ) + ) + return output_artifacts_gcs_dir + + with mock.patch( + "google.cloud.aiplatform.utils.gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist", + wraps=mock_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist, + ) as mock_context: + yield mock_context + + +@pytest.fixture +def mock_load_yaml_and_json(job_spec_json): + with patch.object(storage.Blob, "download_as_bytes") as mock_load_yaml_and_json: + mock_load_yaml_and_json.return_value = job_spec_json.encode() + yield mock_load_yaml_and_json + + +@pytest.fixture +def mock_request_urlopen(job_spec_json): + with mock.patch.object(request, "urlopen") as mock_urlopen: + mock_read_response = mock.MagicMock() + mock_decode_response = mock.MagicMock() + mock_decode_response.return_value = job_spec_json.encode() + mock_read_response.return_value.decode = mock_decode_response + mock_urlopen.return_value.read = mock_read_response + yield mock_urlopen + + @pytest.fixture def get_drp_mock(): with mock.patch.object( @@ -3081,3 +3409,184 @@ def test_raw_predict(self, raw_predict_mock): data=_TEST_RAW_PREDICT_DATA, headers=_TEST_RAW_PREDICT_HEADER, ) + + @pytest.mark.parametrize( + "job_spec_json", + [_TEST_MODEL_EVAL_PIPELINE_JOB], + ) + def test_model_evaluate_with_gcs_input_uris( + self, + get_model_mock, + mock_model_eval_get, + mock_get_model_evaluation, + list_model_evaluations_mock, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_successfully_completed_eval_job, + mock_pipeline_bucket_exists, + mock_load_yaml_and_json, + job_spec_json, + mock_request_urlopen, + ): + 
aiplatform.init(project=_TEST_PROJECT) + + test_model = models.Model(model_name=_TEST_MODEL_RESOURCE_NAME) + + eval_job = test_model.evaluate( + prediction_type="classification", + target_field_name="class", + class_labels=_TEST_MODEL_EVAL_CLASS_LABELS, + staging_bucket="gs://my-eval-staging-path", + gcs_source_uris=["gs://test-bucket/test-file.csv"], + ) + + assert isinstance(eval_job, model_evaluation_job._ModelEvaluationJob) + + assert mock_pipeline_service_create.called_once + + assert mock_pipeline_service_get.called_once + + eval_job.wait() + + eval_resource = eval_job.get_model_evaluation() + + assert isinstance(eval_resource, aiplatform.ModelEvaluation) + + assert eval_resource.metrics == _TEST_MODEL_EVAL_METRICS + + assert isinstance(eval_resource._backing_pipeline_job, aiplatform.PipelineJob) + + @pytest.mark.parametrize( + "job_spec_json", + [_TEST_MODEL_EVAL_PIPELINE_JOB_WITH_BQ_INPUT], + ) + def test_model_evaluate_with_bigquery_input( + self, + get_model_mock, + mock_model_eval_get, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_load_yaml_and_json, + mock_pipeline_bucket_exists, + job_spec_json, + mock_request_urlopen, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket="gs://my-bucket") + + test_model = models.Model(model_name=_TEST_MODEL_RESOURCE_NAME) + + eval_job = test_model.evaluate( + prediction_type="classification", + target_field_name="class", + class_labels=_TEST_MODEL_EVAL_CLASS_LABELS, + bigquery_source_uri=_TEST_BIGQUERY_EVAL_INPUT_URI, + bigquery_destination_output_uri=_TEST_BIGQUERY_EVAL_DESTINATION_URI, + ) + + assert isinstance(eval_job, model_evaluation_job._ModelEvaluationJob) + + assert mock_pipeline_service_create.called_once + + assert mock_pipeline_service_get.called_once + + @pytest.mark.parametrize( + "job_spec_json", + [_TEST_MODEL_EVAL_PIPELINE_JOB], + ) + def test_model_evaluate_using_initialized_staging_bucket( + self, + get_model_mock, + mock_model_eval_get, + mock_pipeline_service_create, + 
mock_pipeline_service_get, + mock_pipeline_bucket_exists, + mock_load_yaml_and_json, + job_spec_json, + mock_request_urlopen, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket="gs://my-bucket") + + test_model = models.Model(model_name=_TEST_MODEL_RESOURCE_NAME) + + eval_job = test_model.evaluate( + prediction_type="classification", + target_field_name="class", + class_labels=_TEST_MODEL_EVAL_CLASS_LABELS, + gcs_source_uris=["gs://test-bucket/test-file.csv"], + ) + + assert isinstance(eval_job, model_evaluation_job._ModelEvaluationJob) + + assert mock_pipeline_service_create.called_once + + assert mock_pipeline_service_get.called_once + + def test_model_evaluate_with_no_staging_path_or_initialized_staging_bucket_raises( + self, + get_model_mock, + mock_model_eval_get, + ): + + aiplatform.init(project=_TEST_PROJECT) + + test_model = models.Model(model_name=_TEST_MODEL_RESOURCE_NAME) + + with pytest.raises(ValueError): + test_model.evaluate( + prediction_type="classification", + target_field_name="class", + class_labels=_TEST_MODEL_EVAL_CLASS_LABELS, + gcs_source_uris=["gs://test-bucket/test-file.csv"], + ) + + def test_model_evaluate_with_invalid_prediction_type_raises( + self, + get_model_mock, + mock_model_eval_get, + ): + + aiplatform.init(project=_TEST_PROJECT) + + test_model = models.Model(model_name=_TEST_MODEL_RESOURCE_NAME) + + with pytest.raises(ValueError): + test_model.evaluate( + prediction_type="invalid_prediction_type", + target_field_name="class", + gcs_source_uris=["gs://test-bucket/test-file.csv"], + ) + + def test_model_evaluate_with_invalid_gcs_uri_raises( + self, + get_model_mock, + mock_model_eval_get, + ): + + aiplatform.init(project=_TEST_PROJECT) + + test_model = models.Model(model_name=_TEST_MODEL_RESOURCE_NAME) + + with pytest.raises(ValueError): + test_model.evaluate( + prediction_type="classification", + target_field_name="class", + gcs_source_uris=["storage.googleapis.com/test-bucket/test-file.csv"], + ) + + def 
test_model_evaluate_with_invalid_bq_uri_raises( + self, + get_model_mock, + mock_model_eval_get, + ): + + aiplatform.init(project=_TEST_PROJECT) + + test_model = models.Model(model_name=_TEST_MODEL_RESOURCE_NAME) + + with pytest.raises(ValueError): + test_model.evaluate( + prediction_type="classification", + target_field_name="class", + bigquery_source_uri="my-project.my-dataset.my-table", + bigquery_destination_output_uri="bq://my-project.my-dataset.my-table", + ) From be01f31d982fb45aa0baa77bbf680c6f4b7e943b Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Thu, 3 Aug 2023 14:19:46 -0700 Subject: [PATCH 2/8] chore: update to LVM MultiModalEmbedding request PiperOrigin-RevId: 553588294 --- vertexai/vision_models/_vision_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vertexai/vision_models/_vision_models.py b/vertexai/vision_models/_vision_models.py index a9a7150076..63732d3d09 100644 --- a/vertexai/vision_models/_vision_models.py +++ b/vertexai/vision_models/_vision_models.py @@ -255,7 +255,6 @@ def get_embeddings( instance = { "image": {"bytesBase64Encoded": image._as_base64_string()}, - "features": [{"type": "IMAGE_EMBEDDING"}], } if contextual_text: From ff475130d9457640b94a1d834cfa2e03fdd89c5a Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Thu, 3 Aug 2023 17:06:06 -0700 Subject: [PATCH 3/8] chore: LLM - Switched to a more robust way to get the tuned model resource name from the pipeline job PiperOrigin-RevId: 553633219 --- tests/unit/aiplatform/test_language_models.py | 10 ++++++---- vertexai/language_models/_language_models.py | 14 +++++++------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index 7da36fbea3..63bfdcf718 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -360,12 +360,14 @@ def make_pipeline_job(state): task_details=[ gca_pipeline_job.PipelineTaskDetail( 
task_id=456, - task_name="upload-llm-model", + task_name="tune-large-model-20230724214903", execution=GapicExecution( - name="test-execution-name", - display_name="evaluation_metrics", + name="projects/123/locations/europe-west4/metadataStores/default/executions/...", + display_name="tune-large-model-20230724214903", + schema_title="system.Run", metadata={ - "output:model_resource_name": "projects/123/locations/us-central1/models/456" + "output:model_resource_name": "projects/123/locations/us-central1/models/456", + "output:endpoint_resource_name": "projects/123/locations/us-central1/endpoints/456", }, ), ), diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 378006a891..de76f99282 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -1037,19 +1037,19 @@ def result(self) -> "_LanguageModel": if self._model: return self._model self._job.wait() - upload_model_tasks = [ - task_info - for task_info in self._job.gca_resource.job_detail.task_details - if task_info.task_name == "upload-llm-model" + root_pipeline_tasks = [ + task_detail + for task_detail in self._job.gca_resource.job_detail.task_details + if task_detail.execution.schema_title == "system.Run" ] - if len(upload_model_tasks) != 1: + if len(root_pipeline_tasks) != 1: raise RuntimeError( f"Failed to get the model name from the tuning pipeline: {self._job.name}" ) - upload_model_task = upload_model_tasks[0] + root_pipeline_task = root_pipeline_tasks[0] # Trying to get model name from output parameter - vertex_model_name = upload_model_task.execution.metadata[ + vertex_model_name = root_pipeline_task.execution.metadata[ "output:model_resource_name" ].strip() _LOGGER.info(f"Tuning has completed. 
Created Vertex Model: {vertex_model_name}") From 38ec40a12cf863c9da3de8336dceba10d92f6f56 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Fri, 4 Aug 2023 08:07:16 -0700 Subject: [PATCH 4/8] feat: add support for providing only text to MultiModalEmbeddingModel.get_embeddings() PiperOrigin-RevId: 553809703 --- tests/unit/aiplatform/test_vision_models.py | 30 +++++++++++++++++++++ vertexai/vision_models/_vision_models.py | 20 ++++++++------ 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/tests/unit/aiplatform/test_vision_models.py b/tests/unit/aiplatform/test_vision_models.py index f42228e7e8..4f3c8e74e2 100644 --- a/tests/unit/aiplatform/test_vision_models.py +++ b/tests/unit/aiplatform/test_vision_models.py @@ -264,3 +264,33 @@ def test_image_embedding_model_with_image_and_text(self): assert embedding_response.image_embedding == test_embeddings assert embedding_response.text_embedding == test_embeddings + + def test_image_embedding_model_with_only_text(self): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT + ), + ): + model = vision_models.MultiModalEmbeddingModel.from_pretrained( + "multimodalembedding@001" + ) + + test_embeddings = [0, 0] + gca_predict_response = gca_prediction_service.PredictResponse() + gca_predict_response.predictions.append({"textEmbedding": test_embeddings}) + + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response, + ): + embedding_response = model.get_embeddings(contextual_text="hello world") + + assert not embedding_response.image_embedding + assert embedding_response.text_embedding == test_embeddings diff --git a/vertexai/vision_models/_vision_models.py b/vertexai/vision_models/_vision_models.py 
index 63732d3d09..6b6ca8f695 100644 --- a/vertexai/vision_models/_vision_models.py +++ b/vertexai/vision_models/_vision_models.py @@ -234,28 +234,32 @@ class MultiModalEmbeddingModel(_model_garden_models._ModelGardenModel): ) def get_embeddings( - self, image: Image, contextual_text: Optional[str] = None + self, image: Optional[Image] = None, contextual_text: Optional[str] = None ) -> "MultiModalEmbeddingResponse": """Gets embedding vectors from the provided image. Args: image (Image): - The image to generate embeddings for. + Optional. The image to generate embeddings for. One of `image` or `contextual_text` is required. contextual_text (str): Optional. Contextual text for your input image. If provided, the model will also generate an embedding vector for the provided contextual text. The returned image and text embedding vectors are in the same semantic space with the same dimensionality, and the vectors can be used interchangeably for use cases like searching image by text - or searching text by image. + or searching text by image. One of `image` or `contextual_text` is required. Returns: ImageEmbeddingResponse: The image and text embedding vectors. """ - instance = { - "image": {"bytesBase64Encoded": image._as_base64_string()}, - } + if not image and not contextual_text: + raise ValueError("One of `image` or `contextual_text` is required.") + + instance = {} + + if image: + instance["image"] = {"bytesBase64Encoded": image._as_base64_string()} if contextual_text: instance["text"] = contextual_text @@ -280,11 +284,11 @@ class MultiModalEmbeddingResponse: Attributes: image_embedding (List[float]): - The emebedding vector generated from your image. + Optional. The embedding vector generated from your image. text_embedding (List[float]): Optional. The embedding vector generated from the contextual text provided for your image. 
""" - image_embedding: List[float] _prediction_response: Any + image_embedding: Optional[List[float]] = None text_embedding: Optional[List[float]] = None From dec8ffd80d8804d4a5afaccc1c748d886225b6bf Mon Sep 17 00:00:00 2001 From: Amy Wu Date: Fri, 4 Aug 2023 16:55:23 -0700 Subject: [PATCH 5/8] Copybara import of the project: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit -- 471b38b85ebc8a4682c822c6d28879d4b3386e3b by gcf-owl-bot[bot] <78513119+gcf-owl-bot[bot]@users.noreply.github.com>: feat: add RaySepc to ResourceRuntimeSpec, and add ResourceRuntime to PersistentResource (#2387) * feat: add `PredictionService.ServerStreamingPredict` method feat: add `StreamingPredictRequest` type feat: add `StreamingPredictResponse` type feat: add `Tensor` type PiperOrigin-RevId: 551672526 Source-Link: https://github.com/googleapis/googleapis/commit/1b650d6c6ee9e50fe122562550a47ec258151498 Source-Link: https://github.com/googleapis/googleapis-gen/commit/62fb73702732ce727fbd2e5560d309ece0609850 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiNjJmYjczNzAyNzMyY2U3MjdmYmQyZTU1NjBkMzA5ZWNlMDYwOTg1MCJ9 * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * feat: add RaySepc to ResourceRuntimeSpec, and add ResourceRuntime to PersistentResource PiperOrigin-RevId: 551874408 Source-Link: https://github.com/googleapis/googleapis/commit/4d230ddc6b6d28ad7d2503926a97c6f7fa15483d Source-Link: https://github.com/googleapis/googleapis-gen/commit/9e603a7080be78b1ec8a6cf4ce1d06a5259efc23 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiOWU2MDNhNzA4MGJlNzhiMWVjOGE2Y2Y0Y2UxZDA2YTUyNTllZmMyMyJ9 * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --------- Co-authored-by: Owl Bot Co-authored-by: Amy Wu -- 14efbdff7f4b518a72e2bb88f91c798d81397bdc by yinghsienwu : Resolved merge conflict by 
incorporating both suggestions. -- 69acc808cb429afc2c2a88a3889f6a0e61c2cce7 by yinghsienwu : Resolved merge conflict by incorporating both suggestions. -- 0df44fd152f7a193885c00e76d92449e5737ffb8 by yinghsienwu : Resolve coverage by reverting files. -- 7e903f94f39aeb9c16cbd94231aac04b05d0d07e by yinghsienwu : Fix migration test. -- be1625b6637f40a3e46d74e117ff6ff27ed96e97 by yinghsienwu : Fix lint -- 03e45ac1398e34013c95313ddb9ee79d2a1f3090 by yinghsienwu : Revert COPYBARA_INTEGRATE_REVIEW=https://github.com/googleapis/python-aiplatform/pull/2407 from googleapis:wuamy-patch 03e45ac1398e34013c95313ddb9ee79d2a1f3090 PiperOrigin-RevId: 553948974 --- docs/aiplatform_v1/schedule_service.rst | 10 + docs/aiplatform_v1/services.rst | 1 + .../definition_v1/types/automl_tables.py | 3 + .../definition_v1beta1/types/automl_tables.py | 3 + google/cloud/aiplatform_v1/__init__.py | 28 + .../cloud/aiplatform_v1/gapic_metadata.json | 94 + .../prediction_service/async_client.py | 92 + .../services/prediction_service/client.py | 94 + .../prediction_service/transports/base.py | 17 + .../prediction_service/transports/grpc.py | 30 + .../transports/grpc_asyncio.py | 30 + .../services/schedule_service/__init__.py | 22 + .../services/schedule_service/async_client.py | 1747 ++++++ .../services/schedule_service/client.py | 2102 +++++++ .../services/schedule_service/pagers.py | 156 + .../schedule_service/transports/__init__.py | 33 + .../schedule_service/transports/base.py | 350 ++ .../schedule_service/transports/grpc.py | 670 +++ .../transports/grpc_asyncio.py | 678 +++ google/cloud/aiplatform_v1/types/__init__.py | 37 + .../cloud/aiplatform_v1/types/explanation.py | 14 +- .../types/model_deployment_monitoring_job.py | 1 + .../aiplatform_v1/types/model_monitoring.py | 7 +- .../aiplatform_v1/types/prediction_service.py | 71 + google/cloud/aiplatform_v1/types/schedule.py | 265 + .../aiplatform_v1/types/schedule_service.py | 295 + google/cloud/aiplatform_v1/types/types.py | 153 + 
google/cloud/aiplatform_v1beta1/__init__.py | 10 + .../aiplatform_v1beta1/gapic_metadata.json | 10 + .../prediction_service/async_client.py | 93 + .../services/prediction_service/client.py | 93 + .../prediction_service/transports/base.py | 17 + .../prediction_service/transports/grpc.py | 30 + .../transports/grpc_asyncio.py | 30 + .../aiplatform_v1beta1/types/__init__.py | 10 + .../aiplatform_v1beta1/types/explanation.py | 14 +- .../types/model_deployment_monitoring_job.py | 1 + .../types/model_monitoring.py | 7 +- .../types/persistent_resource.py | 62 +- .../types/prediction_service.py | 71 + .../cloud/aiplatform_v1beta1/types/types.py | 153 + ..._service_server_streaming_predict_async.py | 53 + ...n_service_server_streaming_predict_sync.py | 53 + ..._schedule_service_create_schedule_async.py | 59 + ...d_schedule_service_create_schedule_sync.py | 59 + ..._schedule_service_delete_schedule_async.py | 56 + ...d_schedule_service_delete_schedule_sync.py | 56 + ...ted_schedule_service_get_schedule_async.py | 52 + ...ated_schedule_service_get_schedule_sync.py | 52 + ...d_schedule_service_list_schedules_async.py | 53 + ...ed_schedule_service_list_schedules_sync.py | 53 + ...d_schedule_service_pause_schedule_async.py | 50 + ...ed_schedule_service_pause_schedule_sync.py | 50 + ..._schedule_service_resume_schedule_async.py | 50 + ...d_schedule_service_resume_schedule_sync.py | 50 + ..._schedule_service_update_schedule_async.py | 58 + ...d_schedule_service_update_schedule_sync.py | 58 + ..._service_server_streaming_predict_async.py | 53 + ...n_service_server_streaming_predict_sync.py | 53 + ...t_metadata_google.cloud.aiplatform.v1.json | 1292 +++++ ...adata_google.cloud.aiplatform.v1beta1.json | 153 + .../aiplatform_v1/test_prediction_service.py | 161 + .../aiplatform_v1/test_schedule_service.py | 5067 +++++++++++++++++ .../test_prediction_service.py | 161 + 64 files changed, 15393 insertions(+), 13 deletions(-) create mode 100644 docs/aiplatform_v1/schedule_service.rst create 
mode 100644 google/cloud/aiplatform_v1/services/schedule_service/__init__.py create mode 100644 google/cloud/aiplatform_v1/services/schedule_service/async_client.py create mode 100644 google/cloud/aiplatform_v1/services/schedule_service/client.py create mode 100644 google/cloud/aiplatform_v1/services/schedule_service/pagers.py create mode 100644 google/cloud/aiplatform_v1/services/schedule_service/transports/__init__.py create mode 100644 google/cloud/aiplatform_v1/services/schedule_service/transports/base.py create mode 100644 google/cloud/aiplatform_v1/services/schedule_service/transports/grpc.py create mode 100644 google/cloud/aiplatform_v1/services/schedule_service/transports/grpc_asyncio.py create mode 100644 google/cloud/aiplatform_v1/types/schedule.py create mode 100644 google/cloud/aiplatform_v1/types/schedule_service.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_prediction_service_server_streaming_predict_async.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_prediction_service_server_streaming_predict_sync.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_schedule_service_create_schedule_async.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_schedule_service_create_schedule_sync.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_schedule_service_delete_schedule_async.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_schedule_service_delete_schedule_sync.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_schedule_service_get_schedule_async.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_schedule_service_get_schedule_sync.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_schedule_service_list_schedules_async.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_schedule_service_list_schedules_sync.py create mode 100644 
samples/generated_samples/aiplatform_v1_generated_schedule_service_pause_schedule_async.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_schedule_service_pause_schedule_sync.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_schedule_service_resume_schedule_async.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_schedule_service_resume_schedule_sync.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_schedule_service_update_schedule_async.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_schedule_service_update_schedule_sync.py create mode 100644 samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_server_streaming_predict_async.py create mode 100644 samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_server_streaming_predict_sync.py create mode 100644 tests/unit/gapic/aiplatform_v1/test_schedule_service.py diff --git a/docs/aiplatform_v1/schedule_service.rst b/docs/aiplatform_v1/schedule_service.rst new file mode 100644 index 0000000000..227c5be458 --- /dev/null +++ b/docs/aiplatform_v1/schedule_service.rst @@ -0,0 +1,10 @@ +ScheduleService +--------------------------------- + +.. automodule:: google.cloud.aiplatform_v1.services.schedule_service + :members: + :inherited-members: + +.. 
automodule:: google.cloud.aiplatform_v1.services.schedule_service.pagers + :members: + :inherited-members: diff --git a/docs/aiplatform_v1/services.rst b/docs/aiplatform_v1/services.rst index fb40b751fe..93afd80841 100644 --- a/docs/aiplatform_v1/services.rst +++ b/docs/aiplatform_v1/services.rst @@ -17,6 +17,7 @@ Services for Google Cloud Aiplatform v1 API model_service pipeline_service prediction_service + schedule_service specialist_pool_service tensorboard_service vizier_service diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py index 76971673d2..5cbe201d37 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py @@ -110,6 +110,7 @@ class AutoMlTablesInputs(proto.Message): the prediction type. If the field is not set, a default objective function is used. classification (binary): + "maximize-au-roc" (default) - Maximize the area under the receiver operating characteristic (ROC) curve. @@ -122,9 +123,11 @@ class AutoMlTablesInputs(proto.Message): Maximize recall for a specified precision value. classification (multi-class): + "minimize-log-loss" (default) - Minimize log loss. regression: + "minimize-rmse" (default) - Minimize root-mean-squared error (RMSE). "minimize-mae" - Minimize mean-absolute error (MAE). 
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py index e69ff10b82..e7cb9d9bfa 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py @@ -110,6 +110,7 @@ class AutoMlTablesInputs(proto.Message): the prediction type. If the field is not set, a default objective function is used. classification (binary): + "maximize-au-roc" (default) - Maximize the area under the receiver operating characteristic (ROC) curve. @@ -122,9 +123,11 @@ class AutoMlTablesInputs(proto.Message): Maximize recall for a specified precision value. classification (multi-class): + "minimize-log-loss" (default) - Minimize log loss. regression: + "minimize-rmse" (default) - Minimize root-mean-squared error (RMSE). "minimize-mae" - Minimize mean-absolute error (MAE). 
diff --git a/google/cloud/aiplatform_v1/__init__.py b/google/cloud/aiplatform_v1/__init__.py index 172af5272c..8b53d0eabb 100644 --- a/google/cloud/aiplatform_v1/__init__.py +++ b/google/cloud/aiplatform_v1/__init__.py @@ -50,6 +50,8 @@ from .services.pipeline_service import PipelineServiceAsyncClient from .services.prediction_service import PredictionServiceClient from .services.prediction_service import PredictionServiceAsyncClient +from .services.schedule_service import ScheduleServiceClient +from .services.schedule_service import ScheduleServiceAsyncClient from .services.specialist_pool_service import SpecialistPoolServiceClient from .services.specialist_pool_service import SpecialistPoolServiceAsyncClient from .services.tensorboard_service import TensorboardServiceClient @@ -463,8 +465,19 @@ from .types.prediction_service import PredictRequest from .types.prediction_service import PredictResponse from .types.prediction_service import RawPredictRequest +from .types.prediction_service import StreamingPredictRequest +from .types.prediction_service import StreamingPredictResponse from .types.publisher_model import PublisherModel from .types.saved_query import SavedQuery +from .types.schedule import Schedule +from .types.schedule_service import CreateScheduleRequest +from .types.schedule_service import DeleteScheduleRequest +from .types.schedule_service import GetScheduleRequest +from .types.schedule_service import ListSchedulesRequest +from .types.schedule_service import ListSchedulesResponse +from .types.schedule_service import PauseScheduleRequest +from .types.schedule_service import ResumeScheduleRequest +from .types.schedule_service import UpdateScheduleRequest from .types.service_networking import PrivateServiceConnectConfig from .types.specialist_pool import SpecialistPool from .types.specialist_pool_service import CreateSpecialistPoolOperationMetadata @@ -544,6 +557,7 @@ from .types.types import DoubleArray from .types.types import Int64Array from 
.types.types import StringArray +from .types.types import Tensor from .types.unmanaged_container_model import UnmanagedContainerModel from .types.user_action_reference import UserActionReference from .types.value import Value @@ -585,6 +599,7 @@ "ModelServiceAsyncClient", "PipelineServiceAsyncClient", "PredictionServiceAsyncClient", + "ScheduleServiceAsyncClient", "SpecialistPoolServiceAsyncClient", "TensorboardServiceAsyncClient", "VizierServiceAsyncClient", @@ -674,6 +689,7 @@ "CreateModelDeploymentMonitoringJobRequest", "CreateNasJobRequest", "CreatePipelineJobRequest", + "CreateScheduleRequest", "CreateSpecialistPoolOperationMetadata", "CreateSpecialistPoolRequest", "CreateStudyRequest", @@ -720,6 +736,7 @@ "DeleteOperationMetadata", "DeletePipelineJobRequest", "DeleteSavedQueryRequest", + "DeleteScheduleRequest", "DeleteSpecialistPoolRequest", "DeleteStudyRequest", "DeleteTensorboardExperimentRequest", @@ -820,6 +837,7 @@ "GetNasTrialDetailRequest", "GetPipelineJobRequest", "GetPublisherModelRequest", + "GetScheduleRequest", "GetSpecialistPoolRequest", "GetStudyRequest", "GetTensorboardExperimentRequest", @@ -908,6 +926,8 @@ "ListPipelineJobsResponse", "ListSavedQueriesRequest", "ListSavedQueriesResponse", + "ListSchedulesRequest", + "ListSchedulesResponse", "ListSpecialistPoolsRequest", "ListSpecialistPoolsResponse", "ListStudiesRequest", @@ -968,6 +988,7 @@ "Neighbor", "NfsMount", "PauseModelDeploymentMonitoringJobRequest", + "PauseScheduleRequest", "PipelineFailurePolicy", "PipelineJob", "PipelineJobDetail", @@ -1018,11 +1039,14 @@ "RemoveDatapointsResponse", "ResourcesConsumed", "ResumeModelDeploymentMonitoringJobRequest", + "ResumeScheduleRequest", "SampleConfig", "SampledShapleyAttribution", "SamplingStrategy", "SavedQuery", "Scalar", + "Schedule", + "ScheduleServiceClient", "Scheduling", "SearchDataItemsRequest", "SearchDataItemsResponse", @@ -1037,6 +1061,8 @@ "SpecialistPoolServiceClient", "StopTrialRequest", "StratifiedSplit", + 
"StreamingPredictRequest", + "StreamingPredictResponse", "StreamingReadFeatureValuesRequest", "StringArray", "Study", @@ -1045,6 +1071,7 @@ "SuggestTrialsRequest", "SuggestTrialsResponse", "TFRecordDestination", + "Tensor", "Tensorboard", "TensorboardBlob", "TensorboardBlobSequence", @@ -1085,6 +1112,7 @@ "UpdateModelDeploymentMonitoringJobOperationMetadata", "UpdateModelDeploymentMonitoringJobRequest", "UpdateModelRequest", + "UpdateScheduleRequest", "UpdateSpecialistPoolOperationMetadata", "UpdateSpecialistPoolRequest", "UpdateTensorboardExperimentRequest", diff --git a/google/cloud/aiplatform_v1/gapic_metadata.json b/google/cloud/aiplatform_v1/gapic_metadata.json index 550f4836a6..c100fe1214 100644 --- a/google/cloud/aiplatform_v1/gapic_metadata.json +++ b/google/cloud/aiplatform_v1/gapic_metadata.json @@ -1806,6 +1806,11 @@ "methods": [ "raw_predict" ] + }, + "ServerStreamingPredict": { + "methods": [ + "server_streaming_predict" + ] } } }, @@ -1826,6 +1831,95 @@ "methods": [ "raw_predict" ] + }, + "ServerStreamingPredict": { + "methods": [ + "server_streaming_predict" + ] + } + } + } + } + }, + "ScheduleService": { + "clients": { + "grpc": { + "libraryClient": "ScheduleServiceClient", + "rpcs": { + "CreateSchedule": { + "methods": [ + "create_schedule" + ] + }, + "DeleteSchedule": { + "methods": [ + "delete_schedule" + ] + }, + "GetSchedule": { + "methods": [ + "get_schedule" + ] + }, + "ListSchedules": { + "methods": [ + "list_schedules" + ] + }, + "PauseSchedule": { + "methods": [ + "pause_schedule" + ] + }, + "ResumeSchedule": { + "methods": [ + "resume_schedule" + ] + }, + "UpdateSchedule": { + "methods": [ + "update_schedule" + ] + } + } + }, + "grpc-async": { + "libraryClient": "ScheduleServiceAsyncClient", + "rpcs": { + "CreateSchedule": { + "methods": [ + "create_schedule" + ] + }, + "DeleteSchedule": { + "methods": [ + "delete_schedule" + ] + }, + "GetSchedule": { + "methods": [ + "get_schedule" + ] + }, + "ListSchedules": { + "methods": [ + 
"list_schedules" + ] + }, + "PauseSchedule": { + "methods": [ + "pause_schedule" + ] + }, + "ResumeSchedule": { + "methods": [ + "resume_schedule" + ] + }, + "UpdateSchedule": { + "methods": [ + "update_schedule" + ] } } } diff --git a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py index 271505cf2c..4f870e77e5 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py @@ -22,6 +22,8 @@ MutableMapping, MutableSequence, Optional, + AsyncIterable, + Awaitable, Sequence, Tuple, Type, @@ -45,6 +47,7 @@ from google.api import httpbody_pb2 # type: ignore from google.cloud.aiplatform_v1.types import explanation from google.cloud.aiplatform_v1.types import prediction_service +from google.cloud.aiplatform_v1.types import types from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -548,6 +551,95 @@ async def sample_raw_predict(): # Done; return the response. return response + def server_streaming_predict( + self, + request: Optional[ + Union[prediction_service.StreamingPredictRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Awaitable[AsyncIterable[prediction_service.StreamingPredictResponse]]: + r"""Perform a server-side streaming online prediction + request for Vertex LLM streaming. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + async def sample_server_streaming_predict(): + # Create a client + client = aiplatform_v1.PredictionServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.StreamingPredictRequest( + endpoint="endpoint_value", + ) + + # Make the request + stream = await client.server_streaming_predict(request=request) + + # Handle the response + async for response in stream: + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1.types.StreamingPredictRequest, dict]]): + The request object. Request message for + [PredictionService.StreamingPredict][google.cloud.aiplatform.v1.PredictionService.StreamingPredict]. + + The first message must contain + [endpoint][google.cloud.aiplatform.v1.StreamingPredictRequest.endpoint] + field and optionally [input][]. The subsequent messages + must contain [input][]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + AsyncIterable[google.cloud.aiplatform_v1.types.StreamingPredictResponse]: + Response message for + [PredictionService.StreamingPredict][google.cloud.aiplatform.v1.PredictionService.StreamingPredict]. + + """ + # Create or coerce a protobuf request object. + request = prediction_service.StreamingPredictRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.server_streaming_predict, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + async def explain( self, request: Optional[Union[prediction_service.ExplainRequest, dict]] = None, diff --git a/google/cloud/aiplatform_v1/services/prediction_service/client.py b/google/cloud/aiplatform_v1/services/prediction_service/client.py index a61f34785e..505760b293 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/client.py @@ -16,12 +16,15 @@ from collections import OrderedDict import os import re + +import pkg_resources from typing import ( Dict, Mapping, MutableMapping, MutableSequence, Optional, + Iterable, Sequence, Tuple, Type, @@ -49,6 +52,7 @@ from google.api import httpbody_pb2 # type: ignore from google.cloud.aiplatform_v1.types import explanation from google.cloud.aiplatform_v1.types import prediction_service +from google.cloud.aiplatform_v1.types import types from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -795,6 +799,96 @@ def sample_raw_predict(): # Done; return the response. 
return response + def server_streaming_predict( + self, + request: Optional[ + Union[prediction_service.StreamingPredictRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Iterable[prediction_service.StreamingPredictResponse]: + r"""Perform a server-side streaming online prediction + request for Vertex LLM streaming. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + def sample_server_streaming_predict(): + # Create a client + client = aiplatform_v1.PredictionServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.StreamingPredictRequest( + endpoint="endpoint_value", + ) + + # Make the request + stream = client.server_streaming_predict(request=request) + + # Handle the response + for response in stream: + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1.types.StreamingPredictRequest, dict]): + The request object. Request message for + [PredictionService.StreamingPredict][google.cloud.aiplatform.v1.PredictionService.StreamingPredict]. + + The first message must contain + [endpoint][google.cloud.aiplatform.v1.StreamingPredictRequest.endpoint] + field and optionally [input][]. The subsequent messages + must contain [input][]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + Iterable[google.cloud.aiplatform_v1.types.StreamingPredictResponse]: + Response message for + [PredictionService.StreamingPredict][google.cloud.aiplatform.v1.PredictionService.StreamingPredict]. + + """ + # Create or coerce a protobuf request object. + # Minor optimization to avoid making a copy if the user passes + # in a prediction_service.StreamingPredictRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, prediction_service.StreamingPredictRequest): + request = prediction_service.StreamingPredictRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.server_streaming_predict] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + def explain( self, request: Optional[Union[prediction_service.ExplainRequest, dict]] = None, diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py index 8de37beab6..ef9167bb20 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py @@ -138,6 +138,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.server_streaming_predict: gapic_v1.method.wrap_method( + self.server_streaming_predict, + default_timeout=None, + client_info=client_info, + ), self.explain: gapic_v1.method.wrap_method( self.explain, default_timeout=None, @@ -175,6 +180,18 @@ def raw_predict( ]: raise NotImplementedError() + @property + def server_streaming_predict( + self, + ) -> Callable[ + [prediction_service.StreamingPredictRequest], + Union[ + prediction_service.StreamingPredictResponse, + Awaitable[prediction_service.StreamingPredictResponse], + ], + ]: + raise NotImplementedError() + @property def explain( self, diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py index 31fcac6fad..d109c44abb 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py @@ -298,6 +298,36 @@ def raw_predict( ) return self._stubs["raw_predict"] + @property + def server_streaming_predict( + self, + ) -> Callable[ + [prediction_service.StreamingPredictRequest], + prediction_service.StreamingPredictResponse, + ]: + r"""Return a callable for the server streaming predict method over gRPC. + + Perform a server-side streaming online prediction + request for Vertex LLM streaming. 
+ + Returns: + Callable[[~.StreamingPredictRequest], + ~.StreamingPredictResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "server_streaming_predict" not in self._stubs: + self._stubs["server_streaming_predict"] = self.grpc_channel.unary_stream( + "/google.cloud.aiplatform.v1.PredictionService/ServerStreamingPredict", + request_serializer=prediction_service.StreamingPredictRequest.serialize, + response_deserializer=prediction_service.StreamingPredictResponse.deserialize, + ) + return self._stubs["server_streaming_predict"] + @property def explain( self, diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py index 964652dd98..9affc60c31 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py @@ -304,6 +304,36 @@ def raw_predict( ) return self._stubs["raw_predict"] + @property + def server_streaming_predict( + self, + ) -> Callable[ + [prediction_service.StreamingPredictRequest], + Awaitable[prediction_service.StreamingPredictResponse], + ]: + r"""Return a callable for the server streaming predict method over gRPC. + + Perform a server-side streaming online prediction + request for Vertex LLM streaming. + + Returns: + Callable[[~.StreamingPredictRequest], + Awaitable[~.StreamingPredictResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "server_streaming_predict" not in self._stubs: + self._stubs["server_streaming_predict"] = self.grpc_channel.unary_stream( + "/google.cloud.aiplatform.v1.PredictionService/ServerStreamingPredict", + request_serializer=prediction_service.StreamingPredictRequest.serialize, + response_deserializer=prediction_service.StreamingPredictResponse.deserialize, + ) + return self._stubs["server_streaming_predict"] + @property def explain( self, diff --git a/google/cloud/aiplatform_v1/services/schedule_service/__init__.py b/google/cloud/aiplatform_v1/services/schedule_service/__init__.py new file mode 100644 index 0000000000..40f84efec0 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/schedule_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from .client import ScheduleServiceClient +from .async_client import ScheduleServiceAsyncClient + +__all__ = ( + "ScheduleServiceClient", + "ScheduleServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1/services/schedule_service/async_client.py b/google/cloud/aiplatform_v1/services/schedule_service/async_client.py new file mode 100644 index 0000000000..bed5bc7db6 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/schedule_service/async_client.py @@ -0,0 +1,1747 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from collections import OrderedDict +import functools +import re +from typing import ( + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from google.cloud.aiplatform_v1 import gapic_version as package_version + +from google.api_core.client_options import ClientOptions +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.api_core import operation as gac_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.schedule_service import pagers +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import pipeline_service +from google.cloud.aiplatform_v1.types import schedule +from google.cloud.aiplatform_v1.types import schedule as gca_schedule +from google.cloud.aiplatform_v1.types import schedule_service +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from .transports.base import ScheduleServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import ScheduleServiceGrpcAsyncIOTransport +from .client import ScheduleServiceClient + + +class ScheduleServiceAsyncClient: + """A service for creating and managing Vertex AI's 
Schedule
+    resources to periodically launch scheduled runs to make API
+    calls.
+    """
+
+    _client: ScheduleServiceClient
+
+    DEFAULT_ENDPOINT = ScheduleServiceClient.DEFAULT_ENDPOINT
+    DEFAULT_MTLS_ENDPOINT = ScheduleServiceClient.DEFAULT_MTLS_ENDPOINT
+
+    artifact_path = staticmethod(ScheduleServiceClient.artifact_path)
+    parse_artifact_path = staticmethod(ScheduleServiceClient.parse_artifact_path)
+    context_path = staticmethod(ScheduleServiceClient.context_path)
+    parse_context_path = staticmethod(ScheduleServiceClient.parse_context_path)
+    custom_job_path = staticmethod(ScheduleServiceClient.custom_job_path)
+    parse_custom_job_path = staticmethod(ScheduleServiceClient.parse_custom_job_path)
+    execution_path = staticmethod(ScheduleServiceClient.execution_path)
+    parse_execution_path = staticmethod(ScheduleServiceClient.parse_execution_path)
+    network_path = staticmethod(ScheduleServiceClient.network_path)
+    parse_network_path = staticmethod(ScheduleServiceClient.parse_network_path)
+    pipeline_job_path = staticmethod(ScheduleServiceClient.pipeline_job_path)
+    parse_pipeline_job_path = staticmethod(
+        ScheduleServiceClient.parse_pipeline_job_path
+    )
+    schedule_path = staticmethod(ScheduleServiceClient.schedule_path)
+    parse_schedule_path = staticmethod(ScheduleServiceClient.parse_schedule_path)
+    common_billing_account_path = staticmethod(
+        ScheduleServiceClient.common_billing_account_path
+    )
+    parse_common_billing_account_path = staticmethod(
+        ScheduleServiceClient.parse_common_billing_account_path
+    )
+    common_folder_path = staticmethod(ScheduleServiceClient.common_folder_path)
+    parse_common_folder_path = staticmethod(
+        ScheduleServiceClient.parse_common_folder_path
+    )
+    common_organization_path = staticmethod(
+        ScheduleServiceClient.common_organization_path
+    )
+    parse_common_organization_path = staticmethod(
+        ScheduleServiceClient.parse_common_organization_path
+    )
+    common_project_path = staticmethod(ScheduleServiceClient.common_project_path)
+    
parse_common_project_path = staticmethod( + ScheduleServiceClient.parse_common_project_path + ) + common_location_path = staticmethod(ScheduleServiceClient.common_location_path) + parse_common_location_path = staticmethod( + ScheduleServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ScheduleServiceAsyncClient: The constructed client. + """ + return ScheduleServiceClient.from_service_account_info.__func__(ScheduleServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ScheduleServiceAsyncClient: The constructed client. + """ + return ScheduleServiceClient.from_service_account_file.__func__(ScheduleServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. 
+        (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+        default client cert source exists, use the default one; otherwise the client cert
+        source is None.
+
+        The API endpoint is determined in the following order:
+        (1) if `client_options.api_endpoint` is provided, use the provided one.
+        (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the
+        default mTLS endpoint; if the environment variable is "never", use the default API
+        endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+        use the default API endpoint.
+
+        More details can be found at https://google.aip.dev/auth/4114.
+
+        Args:
+            client_options (google.api_core.client_options.ClientOptions): Custom options for the
+                client. Only the `api_endpoint` and `client_cert_source` properties may be used
+                in this method.
+
+        Returns:
+            Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+                client cert source to use.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+        """
+        return ScheduleServiceClient.get_mtls_endpoint_and_cert_source(client_options)  # type: ignore
+
+    @property
+    def transport(self) -> ScheduleServiceTransport:
+        """Returns the transport used by the client instance.
+
+        Returns:
+            ScheduleServiceTransport: The transport used by the client instance.
+        """
+        return self._client.transport
+
+    get_transport_class = functools.partial(
+        type(ScheduleServiceClient).get_transport_class, type(ScheduleServiceClient)
+    )
+
+    def __init__(
+        self,
+        *,
+        credentials: Optional[ga_credentials.Credentials] = None,
+        transport: Union[str, ScheduleServiceTransport] = "grpc_asyncio",
+        client_options: Optional[ClientOptions] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiates the schedule service client.
+ + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.ScheduleServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. 
+ """ + self._client = ScheduleServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def create_schedule( + self, + request: Optional[Union[schedule_service.CreateScheduleRequest, dict]] = None, + *, + parent: Optional[str] = None, + schedule: Optional[gca_schedule.Schedule] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_schedule.Schedule: + r"""Creates a Schedule. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + async def sample_create_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + schedule = aiplatform_v1.Schedule() + schedule.cron = "cron_value" + schedule.create_pipeline_job_request.parent = "parent_value" + schedule.display_name = "display_name_value" + schedule.max_concurrent_run_count = 2596 + + request = aiplatform_v1.CreateScheduleRequest( + parent="parent_value", + schedule=schedule, + ) + + # Make the request + response = await client.create_schedule(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1.types.CreateScheduleRequest, dict]]): + The request object. Request message for + [ScheduleService.CreateSchedule][google.cloud.aiplatform.v1.ScheduleService.CreateSchedule]. + parent (:class:`str`): + Required. The resource name of the Location to create + the Schedule in. 
Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + schedule (:class:`google.cloud.aiplatform_v1.types.Schedule`): + Required. The Schedule to create. + This corresponds to the ``schedule`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Schedule: + An instance of a Schedule + periodically schedules runs to make API + calls based on user specified time + specification and API request type. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, schedule]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = schedule_service.CreateScheduleRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + if schedule is not None: + request.schedule = schedule + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_schedule, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. 
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_schedule( + self, + request: Optional[Union[schedule_service.DeleteScheduleRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a Schedule. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + async def sample_delete_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.DeleteScheduleRequest( + name="name_value", + ) + + # Make the request + operation = client.delete_schedule(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1.types.DeleteScheduleRequest, dict]]): + The request object. Request message for + [ScheduleService.DeleteSchedule][google.cloud.aiplatform.v1.ScheduleService.DeleteSchedule]. + name (:class:`str`): + Required. The name of the Schedule resource to be + deleted. 
Format: + ``projects/{project}/locations/{location}/schedules/{schedule}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = schedule_service.DeleteScheduleRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_schedule, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. 
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty_pb2.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_schedule( + self, + request: Optional[Union[schedule_service.GetScheduleRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> schedule.Schedule: + r"""Gets a Schedule. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + async def sample_get_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.GetScheduleRequest( + name="name_value", + ) + + # Make the request + response = await client.get_schedule(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1.types.GetScheduleRequest, dict]]): + The request object. Request message for + [ScheduleService.GetSchedule][google.cloud.aiplatform.v1.ScheduleService.GetSchedule]. + name (:class:`str`): + Required. The name of the Schedule resource. 
Format: + ``projects/{project}/locations/{location}/schedules/{schedule}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Schedule: + An instance of a Schedule + periodically schedules runs to make API + calls based on user specified time + specification and API request type. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = schedule_service.GetScheduleRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_schedule, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + async def list_schedules( + self, + request: Optional[Union[schedule_service.ListSchedulesRequest, dict]] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSchedulesAsyncPager: + r"""Lists Schedules in a Location. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + async def sample_list_schedules(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.ListSchedulesRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_schedules(request=request) + + # Handle the response + async for response in page_result: + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1.types.ListSchedulesRequest, dict]]): + The request object. Request message for + [ScheduleService.ListSchedules][google.cloud.aiplatform.v1.ScheduleService.ListSchedules]. + parent (:class:`str`): + Required. The resource name of the Location to list the + Schedules from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.schedule_service.pagers.ListSchedulesAsyncPager: + Response message for + [ScheduleService.ListSchedules][google.cloud.aiplatform.v1.ScheduleService.ListSchedules] + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = schedule_service.ListSchedulesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_schedules, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListSchedulesAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + async def pause_schedule( + self, + request: Optional[Union[schedule_service.PauseScheduleRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Pauses a Schedule. Will mark + [Schedule.state][google.cloud.aiplatform.v1.Schedule.state] to + 'PAUSED'. If the schedule is paused, no new runs will be + created. Already created runs will NOT be paused or canceled. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + async def sample_pause_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.PauseScheduleRequest( + name="name_value", + ) + + # Make the request + await client.pause_schedule(request=request) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1.types.PauseScheduleRequest, dict]]): + The request object. Request message for + [ScheduleService.PauseSchedule][google.cloud.aiplatform.v1.ScheduleService.PauseSchedule]. + name (:class:`str`): + Required. The name of the Schedule resource to be + paused. Format: + ``projects/{project}/locations/{location}/schedules/{schedule}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = schedule_service.PauseScheduleRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.pause_schedule, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def resume_schedule( + self, + request: Optional[Union[schedule_service.ResumeScheduleRequest, dict]] = None, + *, + name: Optional[str] = None, + catch_up: Optional[bool] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Resumes a paused Schedule to start scheduling new runs. Will + mark [Schedule.state][google.cloud.aiplatform.v1.Schedule.state] + to 'ACTIVE'. Only paused Schedule can be resumed. + + When the Schedule is resumed, new runs will be scheduled + starting from the next execution time after the current time + based on the time_specification in the Schedule. 
If + [Schedule.catchUp][] is set up true, all missed runs will be + scheduled for backfill first. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + async def sample_resume_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.ResumeScheduleRequest( + name="name_value", + ) + + # Make the request + await client.resume_schedule(request=request) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1.types.ResumeScheduleRequest, dict]]): + The request object. Request message for + [ScheduleService.ResumeSchedule][google.cloud.aiplatform.v1.ScheduleService.ResumeSchedule]. + name (:class:`str`): + Required. The name of the Schedule resource to be + resumed. Format: + ``projects/{project}/locations/{location}/schedules/{schedule}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + catch_up (:class:`bool`): + Optional. Whether to backfill missed runs when the + schedule is resumed from PAUSED state. If set to true, + all missed runs will be scheduled. New runs will be + scheduled after the backfill is complete. This will also + update + [Schedule.catch_up][google.cloud.aiplatform.v1.Schedule.catch_up] + field. Default to false. + + This corresponds to the ``catch_up`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. 
+ timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name, catch_up]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = schedule_service.ResumeScheduleRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + if catch_up is not None: + request.catch_up = catch_up + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.resume_schedule, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def update_schedule( + self, + request: Optional[Union[schedule_service.UpdateScheduleRequest, dict]] = None, + *, + schedule: Optional[gca_schedule.Schedule] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_schedule.Schedule: + r"""Updates an active or paused Schedule. 
+ + When the Schedule is updated, new runs will be scheduled + starting from the updated next execution time after the update + time based on the time_specification in the updated Schedule. + All unstarted runs before the update time will be skipped while + already created runs will NOT be paused or canceled. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + async def sample_update_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + schedule = aiplatform_v1.Schedule() + schedule.cron = "cron_value" + schedule.create_pipeline_job_request.parent = "parent_value" + schedule.display_name = "display_name_value" + schedule.max_concurrent_run_count = 2596 + + request = aiplatform_v1.UpdateScheduleRequest( + schedule=schedule, + ) + + # Make the request + response = await client.update_schedule(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1.types.UpdateScheduleRequest, dict]]): + The request object. Request message for + [ScheduleService.UpdateSchedule][google.cloud.aiplatform.v1.ScheduleService.UpdateSchedule]. + schedule (:class:`google.cloud.aiplatform_v1.types.Schedule`): + Required. The Schedule which replaces the resource on + the server. The following restrictions will be applied: + + - The scheduled request type cannot be changed. + - The output_only fields will be ignored if specified. 
+ + This corresponds to the ``schedule`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The update mask applies to the resource. See + [google.protobuf.FieldMask][google.protobuf.FieldMask]. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Schedule: + An instance of a Schedule + periodically schedules runs to make API + calls based on user specified time + specification and API request type. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([schedule, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = schedule_service.UpdateScheduleRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if schedule is not None: + request.schedule = schedule + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_schedule, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. 
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("schedule.name", request.schedule.name),) + ), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.list_operations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.get_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + async def delete_operation( + self, + request: Optional[operations_pb2.DeleteOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes a long-running operation. + + This method indicates that the client is no longer interested + in the operation result. It does not cancel the operation. + If the server doesn't support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.DeleteOperationRequest`): + The request object. Request message for + `DeleteOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + None + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.DeleteOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.delete_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def cancel_operation( + self, + request: Optional[operations_pb2.CancelOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Starts asynchronous cancellation on a long-running operation. + + The server makes a best effort to cancel the operation, but success + is not guaranteed. If the server doesn't support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.CancelOperationRequest`): + The request object. Request message for + `CancelOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + None + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.CancelOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.cancel_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def wait_operation( + self, + request: Optional[operations_pb2.WaitOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Waits until the specified long-running operation is done or reaches at most + a specified timeout, returning the latest state. + + If the operation is already done, the latest state is immediately returned. + If the timeout specified is greater than the default HTTP/RPC timeout, the HTTP/RPC + timeout is used. If the server does not support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.WaitOperationRequest`): + The request object. Request message for + `WaitOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.WaitOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.wait_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def set_iam_policy( + self, + request: Optional[iam_policy_pb2.SetIamPolicyRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + r"""Sets the IAM access control policy on the specified function. + + Replaces any existing policy. + + Args: + request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`): + The request object. Request message for `SetIamPolicy` + method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.policy_pb2.Policy: + Defines an Identity and Access Management (IAM) policy. + It is used to specify access control policies for Cloud + Platform resources. + A ``Policy`` is a collection of ``bindings``. A + ``binding`` binds one or more ``members`` to a single + ``role``. Members can be user accounts, service + accounts, Google groups, and domains (such as G Suite). + A ``role`` is a named list of permissions (defined by + IAM or configured by users). A ``binding`` can + optionally specify a ``condition``, which is a logic + expression that further constrains the role binding + based on attributes about the request and/or target + resource. 
+
+                **JSON Example**
+
+                ::
+
+                    {
+                      "bindings": [
+                        {
+                          "role": "roles/resourcemanager.organizationAdmin",
+                          "members": [
+                            "user:mike@example.com",
+                            "group:admins@example.com",
+                            "domain:google.com",
+                            "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+                          ]
+                        },
+                        {
+                          "role": "roles/resourcemanager.organizationViewer",
+                          "members": ["user:eve@example.com"],
+                          "condition": {
+                            "title": "expirable access",
+                            "description": "Does not grant access after Sep 2020",
+                            "expression": "request.time <
+                            timestamp('2020-10-01T00:00:00.000Z')",
+                          }
+                        }
+                      ]
+                    }
+
+                **YAML Example**
+
+                ::
+
+                    bindings:
+                    - members:
+                      - user:mike@example.com
+                      - group:admins@example.com
+                      - domain:google.com
+                      - serviceAccount:my-project-id@appspot.gserviceaccount.com
+                      role: roles/resourcemanager.organizationAdmin
+                    - members:
+                      - user:eve@example.com
+                      role: roles/resourcemanager.organizationViewer
+                      condition:
+                        title: expirable access
+                        description: Does not grant access after Sep 2020
+                        expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+
+                For a description of IAM and its features, see the `IAM
+                developer's
+                guide <https://cloud.google.com/iam/docs>`__.
+        """
+        # Create or coerce a protobuf request object.
+
+        # The request isn't a proto-plus wrapped type,
+        # so it must be constructed via keyword expansion.
+        if isinstance(request, dict):
+            request = iam_policy_pb2.SetIamPolicyRequest(**request)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method.wrap_method(
+            self._client._transport.set_iam_policy,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+ return response + + async def get_iam_policy( + self, + request: Optional[iam_policy_pb2.GetIamPolicyRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + r"""Gets the IAM access control policy for a function. + + Returns an empty policy if the function exists and does not have a + policy set. + + Args: + request (:class:`~.iam_policy_pb2.GetIamPolicyRequest`): + The request object. Request message for `GetIamPolicy` + method. + retry (google.api_core.retry.Retry): Designation of what errors, if + any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.policy_pb2.Policy: + Defines an Identity and Access Management (IAM) policy. + It is used to specify access control policies for Cloud + Platform resources. + A ``Policy`` is a collection of ``bindings``. A + ``binding`` binds one or more ``members`` to a single + ``role``. Members can be user accounts, service + accounts, Google groups, and domains (such as G Suite). + A ``role`` is a named list of permissions (defined by + IAM or configured by users). A ``binding`` can + optionally specify a ``condition``, which is a logic + expression that further constrains the role binding + based on attributes about the request and/or target + resource. 
+
+                **JSON Example**
+
+                ::
+
+                    {
+                      "bindings": [
+                        {
+                          "role": "roles/resourcemanager.organizationAdmin",
+                          "members": [
+                            "user:mike@example.com",
+                            "group:admins@example.com",
+                            "domain:google.com",
+                            "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+                          ]
+                        },
+                        {
+                          "role": "roles/resourcemanager.organizationViewer",
+                          "members": ["user:eve@example.com"],
+                          "condition": {
+                            "title": "expirable access",
+                            "description": "Does not grant access after Sep 2020",
+                            "expression": "request.time <
+                            timestamp('2020-10-01T00:00:00.000Z')",
+                          }
+                        }
+                      ]
+                    }
+
+                **YAML Example**
+
+                ::
+
+                    bindings:
+                    - members:
+                      - user:mike@example.com
+                      - group:admins@example.com
+                      - domain:google.com
+                      - serviceAccount:my-project-id@appspot.gserviceaccount.com
+                      role: roles/resourcemanager.organizationAdmin
+                    - members:
+                      - user:eve@example.com
+                      role: roles/resourcemanager.organizationViewer
+                      condition:
+                        title: expirable access
+                        description: Does not grant access after Sep 2020
+                        expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+
+                For a description of IAM and its features, see the `IAM
+                developer's
+                guide <https://cloud.google.com/iam/docs>`__.
+        """
+        # Create or coerce a protobuf request object.
+
+        # The request isn't a proto-plus wrapped type,
+        # so it must be constructed via keyword expansion.
+        if isinstance(request, dict):
+            request = iam_policy_pb2.GetIamPolicyRequest(**request)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method.wrap_method(
+            self._client._transport.get_iam_policy,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+ return response + + async def test_iam_permissions( + self, + request: Optional[iam_policy_pb2.TestIamPermissionsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> iam_policy_pb2.TestIamPermissionsResponse: + r"""Tests the specified IAM permissions against the IAM access control + policy for a function. + + If the function does not exist, this will return an empty set + of permissions, not a NOT_FOUND error. + + Args: + request (:class:`~.iam_policy_pb2.TestIamPermissionsRequest`): + The request object. Request message for + `TestIamPermissions` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.iam_policy_pb2.TestIamPermissionsResponse: + Response message for ``TestIamPermissions`` method. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.TestIamPermissionsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.test_iam_permissions, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + async def get_location( + self, + request: Optional[locations_pb2.GetLocationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.Location: + r"""Gets information about a location. + + Args: + request (:class:`~.location_pb2.GetLocationRequest`): + The request object. Request message for + `GetLocation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.location_pb2.Location: + Location object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = locations_pb2.GetLocationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.get_location, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + async def list_locations( + self, + request: Optional[locations_pb2.ListLocationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.ListLocationsResponse: + r"""Lists information about the supported locations for this service. + + Args: + request (:class:`~.location_pb2.ListLocationsRequest`): + The request object. Request message for + `ListLocations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.location_pb2.ListLocationsResponse: + Response message for ``ListLocations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = locations_pb2.ListLocationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._client._transport.list_locations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + async def __aenter__(self) -> "ScheduleServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("ScheduleServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/schedule_service/client.py b/google/cloud/aiplatform_v1/services/schedule_service/client.py new file mode 100644 index 0000000000..d6249545b2 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/schedule_service/client.py @@ -0,0 +1,2102 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from collections import OrderedDict +import os +import re +from typing import ( + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +from google.cloud.aiplatform_v1 import gapic_version as package_version + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.api_core import operation as gac_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.schedule_service import pagers +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import pipeline_service +from google.cloud.aiplatform_v1.types import schedule +from google.cloud.aiplatform_v1.types import schedule as gca_schedule +from google.cloud.aiplatform_v1.types import schedule_service +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from .transports.base import ScheduleServiceTransport, DEFAULT_CLIENT_INFO 
+from .transports.grpc import ScheduleServiceGrpcTransport
+from .transports.grpc_asyncio import ScheduleServiceGrpcAsyncIOTransport
+
+
+class ScheduleServiceClientMeta(type):
+    """Metaclass for the ScheduleService client.
+
+    This provides class-level methods for building and retrieving
+    support objects (e.g. transport) without polluting the client instance
+    objects.
+    """
+
+    _transport_registry = (
+        OrderedDict()
+    )  # type: Dict[str, Type[ScheduleServiceTransport]]
+    _transport_registry["grpc"] = ScheduleServiceGrpcTransport
+    _transport_registry["grpc_asyncio"] = ScheduleServiceGrpcAsyncIOTransport
+
+    def get_transport_class(
+        cls,
+        label: Optional[str] = None,
+    ) -> Type[ScheduleServiceTransport]:
+        """Returns an appropriate transport class.
+
+        Args:
+            label: The name of the desired transport. If none is
+                provided, then the first transport in the registry is used.
+
+        Returns:
+            The transport class to use.
+        """
+        # If a specific transport is requested, return that one.
+        if label:
+            return cls._transport_registry[label]
+
+        # No transport is requested; return the default (that is, the first one
+        # in the dictionary).
+        return next(iter(cls._transport_registry.values()))
+
+
+class ScheduleServiceClient(metaclass=ScheduleServiceClientMeta):
+    """A service for creating and managing Vertex AI's Schedule
+    resources to periodically launch scheduled runs to make API
+    calls.
+    """
+
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Converts api endpoint to mTLS endpoint.
+
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+ ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ScheduleServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ScheduleServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> ScheduleServiceTransport: + """Returns the transport used by the client instance. + + Returns: + ScheduleServiceTransport: The transport used by the client + instance. 
+        """
+        return self._transport
+
+    @staticmethod
+    def artifact_path(
+        project: str,
+        location: str,
+        metadata_store: str,
+        artifact: str,
+    ) -> str:
+        """Returns a fully-qualified artifact string."""
+        return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format(
+            project=project,
+            location=location,
+            metadata_store=metadata_store,
+            artifact=artifact,
+        )
+
+    @staticmethod
+    def parse_artifact_path(path: str) -> Dict[str, str]:
+        """Parses a artifact path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/metadataStores/(?P<metadata_store>.+?)/artifacts/(?P<artifact>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def context_path(
+        project: str,
+        location: str,
+        metadata_store: str,
+        context: str,
+    ) -> str:
+        """Returns a fully-qualified context string."""
+        return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format(
+            project=project,
+            location=location,
+            metadata_store=metadata_store,
+            context=context,
+        )
+
+    @staticmethod
+    def parse_context_path(path: str) -> Dict[str, str]:
+        """Parses a context path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/metadataStores/(?P<metadata_store>.+?)/contexts/(?P<context>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def custom_job_path(
+        project: str,
+        location: str,
+        custom_job: str,
+    ) -> str:
+        """Returns a fully-qualified custom_job string."""
+        return "projects/{project}/locations/{location}/customJobs/{custom_job}".format(
+            project=project,
+            location=location,
+            custom_job=custom_job,
+        )
+
+    @staticmethod
+    def parse_custom_job_path(path: str) -> Dict[str, str]:
+        """Parses a custom_job path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/customJobs/(?P<custom_job>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def execution_path(
+        project: str,
+        location: str,
+        metadata_store: str,
+        execution: str,
+    ) -> str:
+        """Returns a fully-qualified execution string."""
+        return "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format(
+            project=project,
+            location=location,
+            metadata_store=metadata_store,
+            execution=execution,
+        )
+
+    @staticmethod
+    def parse_execution_path(path: str) -> Dict[str, str]:
+        """Parses a execution path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/metadataStores/(?P<metadata_store>.+?)/executions/(?P<execution>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def network_path(
+        project: str,
+        network: str,
+    ) -> str:
+        """Returns a fully-qualified network string."""
+        return "projects/{project}/global/networks/{network}".format(
+            project=project,
+            network=network,
+        )
+
+    @staticmethod
+    def parse_network_path(path: str) -> Dict[str, str]:
+        """Parses a network path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/global/networks/(?P<network>.+?)$", path
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def pipeline_job_path(
+        project: str,
+        location: str,
+        pipeline_job: str,
+    ) -> str:
+        """Returns a fully-qualified pipeline_job string."""
+        return "projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}".format(
+            project=project,
+            location=location,
+            pipeline_job=pipeline_job,
+        )
+
+    @staticmethod
+    def parse_pipeline_job_path(path: str) -> Dict[str, str]:
+        """Parses a pipeline_job path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/pipelineJobs/(?P<pipeline_job>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def schedule_path(
+        project: str,
+        location: str,
+        schedule: str,
+    ) -> str:
+        """Returns a fully-qualified schedule string."""
+        return "projects/{project}/locations/{location}/schedules/{schedule}".format(
+            project=project,
+            location=location,
+            schedule=schedule,
+        )
+
+    @staticmethod
+    def parse_schedule_path(path: str) -> Dict[str, str]:
+        """Parses a schedule path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/schedules/(?P<schedule>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_billing_account_path(
+        billing_account: str,
+    ) -> str:
+        """Returns a fully-qualified billing_account string."""
+        return "billingAccounts/{billing_account}".format(
+            billing_account=billing_account,
+        )
+
+    @staticmethod
+    def parse_common_billing_account_path(path: str) -> Dict[str, str]:
+        """Parse a billing_account path into its component segments."""
+        m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_folder_path(
+        folder: str,
+    ) -> str:
+        """Returns a fully-qualified folder string."""
+        return "folders/{folder}".format(
+            folder=folder,
+        )
+
+    @staticmethod
+    def parse_common_folder_path(path: str) -> Dict[str, str]:
+        """Parse a folder path into its component segments."""
+        m = re.match(r"^folders/(?P<folder>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_organization_path(
+        organization: str,
+    ) -> str:
+        """Returns a fully-qualified organization string."""
+        return "organizations/{organization}".format(
+            organization=organization,
+        )
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str, str]:
+        """Parse a organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(
+        project: str,
+    ) -> str:
+        """Returns a fully-qualified project string."""
+        return "projects/{project}".format(
+            project=project,
+        )
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str, str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(
+        project: str,
+        location: str,
+    ) -> str:
+        """Returns a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(
+            project=project,
+            location=location,
+        )
+
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str, str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @classmethod
+    def get_mtls_endpoint_and_cert_source(
+        cls, client_options: Optional[client_options_lib.ClientOptions] = None
+    ):
+        """Return the API endpoint and client cert source for mutual TLS.
+
+        The client cert source is determined in the following order:
+        (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+        client cert source is None.
+        (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+        default client cert source exists, use the default one; otherwise the client cert
+        source is None.
+
+        The API endpoint is determined in the following order:
+        (1) if `client_options.api_endpoint` if provided, use the provided one.
+        (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the
+        default mTLS endpoint; if the environment variable is "never", use the default API
+        endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+        use the default API endpoint.
+
+        More details can be found at https://google.aip.dev/auth/4114.
+
+        Args:
+            client_options (google.api_core.client_options.ClientOptions): Custom options for the
+                client. Only the `api_endpoint` and `client_cert_source` properties may be used
+                in this method.
+
+        Returns:
+            Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+                client cert source to use.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+ """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[Union[str, ScheduleServiceTransport]] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the schedule service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. 
+ transport (Union[str, ScheduleServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. 
+ """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + client_options = cast(client_options_lib.ClientOptions, client_options) + + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options + ) + + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, ScheduleServiceTransport): + # transport is a ScheduleServiceTransport instance. + if credentials or client_options.credentials_file or api_key_value: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." 
+ ) + self._transport = transport + else: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=client_options.api_audience, + ) + + def create_schedule( + self, + request: Optional[Union[schedule_service.CreateScheduleRequest, dict]] = None, + *, + parent: Optional[str] = None, + schedule: Optional[gca_schedule.Schedule] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_schedule.Schedule: + r"""Creates a Schedule. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + def sample_create_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + schedule = aiplatform_v1.Schedule() + schedule.cron = "cron_value" + schedule.create_pipeline_job_request.parent = "parent_value" + schedule.display_name = "display_name_value" + schedule.max_concurrent_run_count = 2596 + + request = aiplatform_v1.CreateScheduleRequest( + parent="parent_value", + schedule=schedule, + ) + + # Make the request + response = client.create_schedule(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1.types.CreateScheduleRequest, dict]): + The request object. Request message for + [ScheduleService.CreateSchedule][google.cloud.aiplatform.v1.ScheduleService.CreateSchedule]. + parent (str): + Required. The resource name of the Location to create + the Schedule in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + schedule (google.cloud.aiplatform_v1.types.Schedule): + Required. The Schedule to create. + This corresponds to the ``schedule`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Schedule: + An instance of a Schedule + periodically schedules runs to make API + calls based on user specified time + specification and API request type. 
+ + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, schedule]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a schedule_service.CreateScheduleRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, schedule_service.CreateScheduleRequest): + request = schedule_service.CreateScheduleRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + if schedule is not None: + request.schedule = schedule + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_schedule] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_schedule( + self, + request: Optional[Union[schedule_service.DeleteScheduleRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes a Schedule. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. 
+ # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + def sample_delete_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.DeleteScheduleRequest( + name="name_value", + ) + + # Make the request + operation = client.delete_schedule(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1.types.DeleteScheduleRequest, dict]): + The request object. Request message for + [ScheduleService.DeleteSchedule][google.cloud.aiplatform.v1.ScheduleService.DeleteSchedule]. + name (str): + Required. The name of the Schedule resource to be + deleted. Format: + ``projects/{project}/locations/{location}/schedules/{schedule}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. 
For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a schedule_service.DeleteScheduleRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, schedule_service.DeleteScheduleRequest): + request = schedule_service.DeleteScheduleRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_schedule] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty_pb2.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. 
+ return response + + def get_schedule( + self, + request: Optional[Union[schedule_service.GetScheduleRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> schedule.Schedule: + r"""Gets a Schedule. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + def sample_get_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.GetScheduleRequest( + name="name_value", + ) + + # Make the request + response = client.get_schedule(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1.types.GetScheduleRequest, dict]): + The request object. Request message for + [ScheduleService.GetSchedule][google.cloud.aiplatform.v1.ScheduleService.GetSchedule]. + name (str): + Required. The name of the Schedule resource. Format: + ``projects/{project}/locations/{location}/schedules/{schedule}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + google.cloud.aiplatform_v1.types.Schedule: + An instance of a Schedule + periodically schedules runs to make API + calls based on user specified time + specification and API request type. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a schedule_service.GetScheduleRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, schedule_service.GetScheduleRequest): + request = schedule_service.GetScheduleRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_schedule] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_schedules( + self, + request: Optional[Union[schedule_service.ListSchedulesRequest, dict]] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSchedulesPager: + r"""Lists Schedules in a Location. + + .. 
code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + def sample_list_schedules(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.ListSchedulesRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_schedules(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1.types.ListSchedulesRequest, dict]): + The request object. Request message for + [ScheduleService.ListSchedules][google.cloud.aiplatform.v1.ScheduleService.ListSchedules]. + parent (str): + Required. The resource name of the Location to list the + Schedules from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.schedule_service.pagers.ListSchedulesPager: + Response message for + [ScheduleService.ListSchedules][google.cloud.aiplatform.v1.ScheduleService.ListSchedules] + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. 
+ # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a schedule_service.ListSchedulesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, schedule_service.ListSchedulesRequest): + request = schedule_service.ListSchedulesRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_schedules] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListSchedulesPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def pause_schedule( + self, + request: Optional[Union[schedule_service.PauseScheduleRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Pauses a Schedule. 
Will mark + [Schedule.state][google.cloud.aiplatform.v1.Schedule.state] to + 'PAUSED'. If the schedule is paused, no new runs will be + created. Already created runs will NOT be paused or canceled. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + def sample_pause_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.PauseScheduleRequest( + name="name_value", + ) + + # Make the request + client.pause_schedule(request=request) + + Args: + request (Union[google.cloud.aiplatform_v1.types.PauseScheduleRequest, dict]): + The request object. Request message for + [ScheduleService.PauseSchedule][google.cloud.aiplatform.v1.ScheduleService.PauseSchedule]. + name (str): + Required. The name of the Schedule resource to be + paused. Format: + ``projects/{project}/locations/{location}/schedules/{schedule}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
+ has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a schedule_service.PauseScheduleRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, schedule_service.PauseScheduleRequest): + request = schedule_service.PauseScheduleRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.pause_schedule] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def resume_schedule( + self, + request: Optional[Union[schedule_service.ResumeScheduleRequest, dict]] = None, + *, + name: Optional[str] = None, + catch_up: Optional[bool] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Resumes a paused Schedule to start scheduling new runs. Will + mark [Schedule.state][google.cloud.aiplatform.v1.Schedule.state] + to 'ACTIVE'. Only paused Schedule can be resumed. + + When the Schedule is resumed, new runs will be scheduled + starting from the next execution time after the current time + based on the time_specification in the Schedule. If + [Schedule.catchUp][] is set up true, all missed runs will be + scheduled for backfill first. + + .. 
code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + def sample_resume_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.ResumeScheduleRequest( + name="name_value", + ) + + # Make the request + client.resume_schedule(request=request) + + Args: + request (Union[google.cloud.aiplatform_v1.types.ResumeScheduleRequest, dict]): + The request object. Request message for + [ScheduleService.ResumeSchedule][google.cloud.aiplatform.v1.ScheduleService.ResumeSchedule]. + name (str): + Required. The name of the Schedule resource to be + resumed. Format: + ``projects/{project}/locations/{location}/schedules/{schedule}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + catch_up (bool): + Optional. Whether to backfill missed runs when the + schedule is resumed from PAUSED state. If set to true, + all missed runs will be scheduled. New runs will be + scheduled after the backfill is complete. This will also + update + [Schedule.catch_up][google.cloud.aiplatform.v1.Schedule.catch_up] + field. Default to false. + + This corresponds to the ``catch_up`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name, catch_up]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a schedule_service.ResumeScheduleRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, schedule_service.ResumeScheduleRequest): + request = schedule_service.ResumeScheduleRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + if catch_up is not None: + request.catch_up = catch_up + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.resume_schedule] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def update_schedule( + self, + request: Optional[Union[schedule_service.UpdateScheduleRequest, dict]] = None, + *, + schedule: Optional[gca_schedule.Schedule] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_schedule.Schedule: + r"""Updates an active or paused Schedule. 
+ + When the Schedule is updated, new runs will be scheduled + starting from the updated next execution time after the update + time based on the time_specification in the updated Schedule. + All unstarted runs before the update time will be skipped while + already created runs will NOT be paused or canceled. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + def sample_update_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + schedule = aiplatform_v1.Schedule() + schedule.cron = "cron_value" + schedule.create_pipeline_job_request.parent = "parent_value" + schedule.display_name = "display_name_value" + schedule.max_concurrent_run_count = 2596 + + request = aiplatform_v1.UpdateScheduleRequest( + schedule=schedule, + ) + + # Make the request + response = client.update_schedule(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1.types.UpdateScheduleRequest, dict]): + The request object. Request message for + [ScheduleService.UpdateSchedule][google.cloud.aiplatform.v1.ScheduleService.UpdateSchedule]. + schedule (google.cloud.aiplatform_v1.types.Schedule): + Required. The Schedule which replaces the resource on + the server. The following restrictions will be applied: + + - The scheduled request type cannot be changed. + - The output_only fields will be ignored if specified. + + This corresponds to the ``schedule`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
+ update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the resource. See + [google.protobuf.FieldMask][google.protobuf.FieldMask]. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Schedule: + An instance of a Schedule + periodically schedules runs to make API + calls based on user specified time + specification and API request type. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([schedule, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a schedule_service.UpdateScheduleRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, schedule_service.UpdateScheduleRequest): + request = schedule_service.UpdateScheduleRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if schedule is not None: + request.schedule = schedule + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_schedule] + + # Certain fields should be provided within the metadata header; + # add these here. 
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("schedule.name", request.schedule.name),) + ), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "ScheduleServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method.wrap_method( + self._transport.list_operations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.get_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_operation( + self, + request: Optional[operations_pb2.DeleteOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes a long-running operation. + + This method indicates that the client is no longer interested + in the operation result. It does not cancel the operation. + If the server doesn't support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.DeleteOperationRequest`): + The request object. Request message for + `DeleteOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + None + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.DeleteOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.delete_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def cancel_operation( + self, + request: Optional[operations_pb2.CancelOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Starts asynchronous cancellation on a long-running operation. + + The server makes a best effort to cancel the operation, but success + is not guaranteed. If the server doesn't support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.CancelOperationRequest`): + The request object. Request message for + `CancelOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + None + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.CancelOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.cancel_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def wait_operation( + self, + request: Optional[operations_pb2.WaitOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Waits until the specified long-running operation is done or reaches at most + a specified timeout, returning the latest state. + + If the operation is already done, the latest state is immediately returned. + If the timeout specified is greater than the default HTTP/RPC timeout, the HTTP/RPC + timeout is used. If the server does not support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.WaitOperationRequest`): + The request object. Request message for + `WaitOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.WaitOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.wait_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def set_iam_policy( + self, + request: Optional[iam_policy_pb2.SetIamPolicyRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + r"""Sets the IAM access control policy on the specified function. + + Replaces any existing policy. + + Args: + request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`): + The request object. Request message for `SetIamPolicy` + method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.policy_pb2.Policy: + Defines an Identity and Access Management (IAM) policy. + It is used to specify access control policies for Cloud + Platform resources. + A ``Policy`` is a collection of ``bindings``. A + ``binding`` binds one or more ``members`` to a single + ``role``. Members can be user accounts, service + accounts, Google groups, and domains (such as G Suite). + A ``role`` is a named list of permissions (defined by + IAM or configured by users). A ``binding`` can + optionally specify a ``condition``, which is a logic + expression that further constrains the role binding + based on attributes about the request and/or target + resource. 
+ + **JSON Example** + + :: + + { + "bindings": [ + { + "role": "roles/resourcemanager.organizationAdmin", + "members": [ + "user:mike@example.com", + "group:admins@example.com", + "domain:google.com", + "serviceAccount:my-project-id@appspot.gserviceaccount.com" + ] + }, + { + "role": "roles/resourcemanager.organizationViewer", + "members": ["user:eve@example.com"], + "condition": { + "title": "expirable access", + "description": "Does not grant access after Sep 2020", + "expression": "request.time < + timestamp('2020-10-01T00:00:00.000Z')", + } + } + ] + } + + **YAML Example** + + :: + + bindings: + - members: + - user:mike@example.com + - group:admins@example.com + - domain:google.com + - serviceAccount:my-project-id@appspot.gserviceaccount.com + role: roles/resourcemanager.organizationAdmin + - members: + - user:eve@example.com + role: roles/resourcemanager.organizationViewer + condition: + title: expirable access + description: Does not grant access after Sep 2020 + expression: request.time < timestamp('2020-10-01T00:00:00.000Z') + + For a description of IAM and its features, see the `IAM + developer's + guide `__. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.SetIamPolicyRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.set_iam_policy, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + def get_iam_policy( + self, + request: Optional[iam_policy_pb2.GetIamPolicyRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + r"""Gets the IAM access control policy for a function. + + Returns an empty policy if the function exists and does not have a + policy set. + + Args: + request (:class:`~.iam_policy_pb2.GetIamPolicyRequest`): + The request object. Request message for `GetIamPolicy` + method. + retry (google.api_core.retry.Retry): Designation of what errors, if + any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.policy_pb2.Policy: + Defines an Identity and Access Management (IAM) policy. + It is used to specify access control policies for Cloud + Platform resources. + A ``Policy`` is a collection of ``bindings``. A + ``binding`` binds one or more ``members`` to a single + ``role``. Members can be user accounts, service + accounts, Google groups, and domains (such as G Suite). + A ``role`` is a named list of permissions (defined by + IAM or configured by users). A ``binding`` can + optionally specify a ``condition``, which is a logic + expression that further constrains the role binding + based on attributes about the request and/or target + resource. 
+ + **JSON Example** + + :: + + { + "bindings": [ + { + "role": "roles/resourcemanager.organizationAdmin", + "members": [ + "user:mike@example.com", + "group:admins@example.com", + "domain:google.com", + "serviceAccount:my-project-id@appspot.gserviceaccount.com" + ] + }, + { + "role": "roles/resourcemanager.organizationViewer", + "members": ["user:eve@example.com"], + "condition": { + "title": "expirable access", + "description": "Does not grant access after Sep 2020", + "expression": "request.time < + timestamp('2020-10-01T00:00:00.000Z')", + } + } + ] + } + + **YAML Example** + + :: + + bindings: + - members: + - user:mike@example.com + - group:admins@example.com + - domain:google.com + - serviceAccount:my-project-id@appspot.gserviceaccount.com + role: roles/resourcemanager.organizationAdmin + - members: + - user:eve@example.com + role: roles/resourcemanager.organizationViewer + condition: + title: expirable access + description: Does not grant access after Sep 2020 + expression: request.time < timestamp('2020-10-01T00:00:00.000Z') + + For a description of IAM and its features, see the `IAM + developer's + guide `__. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.GetIamPolicyRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.get_iam_policy, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + def test_iam_permissions( + self, + request: Optional[iam_policy_pb2.TestIamPermissionsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> iam_policy_pb2.TestIamPermissionsResponse: + r"""Tests the specified IAM permissions against the IAM access control + policy for a function. + + If the function does not exist, this will return an empty set + of permissions, not a NOT_FOUND error. + + Args: + request (:class:`~.iam_policy_pb2.TestIamPermissionsRequest`): + The request object. Request message for + `TestIamPermissions` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.iam_policy_pb2.TestIamPermissionsResponse: + Response message for ``TestIamPermissions`` method. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.TestIamPermissionsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.test_iam_permissions, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + def get_location( + self, + request: Optional[locations_pb2.GetLocationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.Location: + r"""Gets information about a location. + + Args: + request (:class:`~.location_pb2.GetLocationRequest`): + The request object. Request message for + `GetLocation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.location_pb2.Location: + Location object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = locations_pb2.GetLocationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.get_location, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_locations( + self, + request: Optional[locations_pb2.ListLocationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.ListLocationsResponse: + r"""Lists information about the supported locations for this service. 
+ + Args: + request (:class:`~.location_pb2.ListLocationsRequest`): + The request object. Request message for + `ListLocations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.location_pb2.ListLocationsResponse: + Response message for ``ListLocations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = locations_pb2.ListLocationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.list_locations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("ScheduleServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/schedule_service/pagers.py b/google/cloud/aiplatform_v1/services/schedule_service/pagers.py new file mode 100644 index 0000000000..3bd06809c4 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/schedule_service/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Sequence, + Tuple, + Optional, + Iterator, +) + +from google.cloud.aiplatform_v1.types import schedule +from google.cloud.aiplatform_v1.types import schedule_service + + +class ListSchedulesPager: + """A pager for iterating through ``list_schedules`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListSchedulesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``schedules`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListSchedules`` requests and continue to iterate + through the ``schedules`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListSchedulesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., schedule_service.ListSchedulesResponse], + request: schedule_service.ListSchedulesRequest, + response: schedule_service.ListSchedulesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListSchedulesRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListSchedulesResponse): + The initial response object. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = schedule_service.ListSchedulesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[schedule_service.ListSchedulesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterator[schedule.Schedule]: + for page in self.pages: + yield from page.schedules + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListSchedulesAsyncPager: + """A pager for iterating through ``list_schedules`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListSchedulesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``schedules`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListSchedules`` requests and continue to iterate + through the ``schedules`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListSchedulesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[schedule_service.ListSchedulesResponse]], + request: schedule_service.ListSchedulesRequest, + response: schedule_service.ListSchedulesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. 
+ request (google.cloud.aiplatform_v1.types.ListSchedulesRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListSchedulesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = schedule_service.ListSchedulesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[schedule_service.ListSchedulesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterator[schedule.Schedule]: + async def async_generator(): + async for page in self.pages: + for response in page.schedules: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/schedule_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/schedule_service/transports/__init__.py new file mode 100644 index 0000000000..bc0830480b --- /dev/null +++ b/google/cloud/aiplatform_v1/services/schedule_service/transports/__init__.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import ScheduleServiceTransport +from .grpc import ScheduleServiceGrpcTransport +from .grpc_asyncio import ScheduleServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[ScheduleServiceTransport]] +_transport_registry["grpc"] = ScheduleServiceGrpcTransport +_transport_registry["grpc_asyncio"] = ScheduleServiceGrpcAsyncIOTransport + +__all__ = ( + "ScheduleServiceTransport", + "ScheduleServiceGrpcTransport", + "ScheduleServiceGrpcAsyncIOTransport", +) diff --git a/google/cloud/aiplatform_v1/services/schedule_service/transports/base.py b/google/cloud/aiplatform_v1/services/schedule_service/transports/base.py new file mode 100644 index 0000000000..10c18b402f --- /dev/null +++ b/google/cloud/aiplatform_v1/services/schedule_service/transports/base.py @@ -0,0 +1,350 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +from google.cloud.aiplatform_v1 import gapic_version as package_version + +import google.auth # type: ignore +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import operations_v1 +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.cloud.aiplatform_v1.types import schedule +from google.cloud.aiplatform_v1.types import schedule as gca_schedule +from google.cloud.aiplatform_v1.types import schedule_service +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +class ScheduleServiceTransport(abc.ABC): + """Abstract transport class for ScheduleService.""" + + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + + DEFAULT_HOST: str = "aiplatform.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. 
+ credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + elif credentials is None: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. 
+ if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_schedule: gapic_v1.method.wrap_method( + self.create_schedule, + default_timeout=None, + client_info=client_info, + ), + self.delete_schedule: gapic_v1.method.wrap_method( + self.delete_schedule, + default_timeout=None, + client_info=client_info, + ), + self.get_schedule: gapic_v1.method.wrap_method( + self.get_schedule, + default_timeout=None, + client_info=client_info, + ), + self.list_schedules: gapic_v1.method.wrap_method( + self.list_schedules, + default_timeout=None, + client_info=client_info, + ), + self.pause_schedule: gapic_v1.method.wrap_method( + self.pause_schedule, + default_timeout=None, + client_info=client_info, + ), + self.resume_schedule: gapic_v1.method.wrap_method( + self.resume_schedule, + default_timeout=None, + client_info=client_info, + ), + self.update_schedule: gapic_v1.method.wrap_method( + self.update_schedule, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! 
+ """ + raise NotImplementedError() + + @property + def operations_client(self): + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def create_schedule( + self, + ) -> Callable[ + [schedule_service.CreateScheduleRequest], + Union[gca_schedule.Schedule, Awaitable[gca_schedule.Schedule]], + ]: + raise NotImplementedError() + + @property + def delete_schedule( + self, + ) -> Callable[ + [schedule_service.DeleteScheduleRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def get_schedule( + self, + ) -> Callable[ + [schedule_service.GetScheduleRequest], + Union[schedule.Schedule, Awaitable[schedule.Schedule]], + ]: + raise NotImplementedError() + + @property + def list_schedules( + self, + ) -> Callable[ + [schedule_service.ListSchedulesRequest], + Union[ + schedule_service.ListSchedulesResponse, + Awaitable[schedule_service.ListSchedulesResponse], + ], + ]: + raise NotImplementedError() + + @property + def pause_schedule( + self, + ) -> Callable[ + [schedule_service.PauseScheduleRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: + raise NotImplementedError() + + @property + def resume_schedule( + self, + ) -> Callable[ + [schedule_service.ResumeScheduleRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: + raise NotImplementedError() + + @property + def update_schedule( + self, + ) -> Callable[ + [schedule_service.UpdateScheduleRequest], + Union[gca_schedule.Schedule, Awaitable[gca_schedule.Schedule]], + ]: + raise NotImplementedError() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Union[ + operations_pb2.ListOperationsResponse, + Awaitable[operations_pb2.ListOperationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + 
Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def cancel_operation( + self, + ) -> Callable[[operations_pb2.CancelOperationRequest], None,]: + raise NotImplementedError() + + @property + def delete_operation( + self, + ) -> Callable[[operations_pb2.DeleteOperationRequest], None,]: + raise NotImplementedError() + + @property + def wait_operation( + self, + ) -> Callable[ + [operations_pb2.WaitOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def set_iam_policy( + self, + ) -> Callable[ + [iam_policy_pb2.SetIamPolicyRequest], + Union[policy_pb2.Policy, Awaitable[policy_pb2.Policy]], + ]: + raise NotImplementedError() + + @property + def get_iam_policy( + self, + ) -> Callable[ + [iam_policy_pb2.GetIamPolicyRequest], + Union[policy_pb2.Policy, Awaitable[policy_pb2.Policy]], + ]: + raise NotImplementedError() + + @property + def test_iam_permissions( + self, + ) -> Callable[ + [iam_policy_pb2.TestIamPermissionsRequest], + Union[ + iam_policy_pb2.TestIamPermissionsResponse, + Awaitable[iam_policy_pb2.TestIamPermissionsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_location( + self, + ) -> Callable[ + [locations_pb2.GetLocationRequest], + Union[locations_pb2.Location, Awaitable[locations_pb2.Location]], + ]: + raise NotImplementedError() + + @property + def list_locations( + self, + ) -> Callable[ + [locations_pb2.ListLocationsRequest], + Union[ + locations_pb2.ListLocationsResponse, + Awaitable[locations_pb2.ListLocationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ("ScheduleServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc.py new file mode 100644 index 
0000000000..a9b15b92f2 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc.py @@ -0,0 +1,670 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import grpc_helpers +from google.api_core import operations_v1 +from google.api_core import gapic_v1 +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1.types import schedule +from google.cloud.aiplatform_v1.types import schedule as gca_schedule +from google.cloud.aiplatform_v1.types import schedule_service +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from .base import ScheduleServiceTransport, DEFAULT_CLIENT_INFO + + +class ScheduleServiceGrpcTransport(ScheduleServiceTransport): + """gRPC backend transport for ScheduleService. 
+
+    A service for creating and managing Vertex AI's Schedule
+    resources to periodically launch scheduled runs to make API
+    calls.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _stubs: Dict[str, Callable]
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: Optional[ga_credentials.Credentials] = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        channel: Optional[grpc.Channel] = None,
+        api_mtls_endpoint: Optional[str] = None,
+        client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+        ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None,
+        client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        always_use_jwt_access: Optional[bool] = False,
+        api_audience: Optional[str] = None,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]):
+                The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional(Sequence[str])): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
+ channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
+ """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client: Optional[operations_v1.OperationsClient] = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. 
+ credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
+ """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service.""" + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Quick check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + + # Return the client from cache. + return self._operations_client + + @property + def create_schedule( + self, + ) -> Callable[[schedule_service.CreateScheduleRequest], gca_schedule.Schedule]: + r"""Return a callable for the create schedule method over gRPC. + + Creates a Schedule. + + Returns: + Callable[[~.CreateScheduleRequest], + ~.Schedule]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_schedule" not in self._stubs: + self._stubs["create_schedule"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/CreateSchedule", + request_serializer=schedule_service.CreateScheduleRequest.serialize, + response_deserializer=gca_schedule.Schedule.deserialize, + ) + return self._stubs["create_schedule"] + + @property + def delete_schedule( + self, + ) -> Callable[[schedule_service.DeleteScheduleRequest], operations_pb2.Operation]: + r"""Return a callable for the delete schedule method over gRPC. 
+ + Deletes a Schedule. + + Returns: + Callable[[~.DeleteScheduleRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_schedule" not in self._stubs: + self._stubs["delete_schedule"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/DeleteSchedule", + request_serializer=schedule_service.DeleteScheduleRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["delete_schedule"] + + @property + def get_schedule( + self, + ) -> Callable[[schedule_service.GetScheduleRequest], schedule.Schedule]: + r"""Return a callable for the get schedule method over gRPC. + + Gets a Schedule. + + Returns: + Callable[[~.GetScheduleRequest], + ~.Schedule]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_schedule" not in self._stubs: + self._stubs["get_schedule"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/GetSchedule", + request_serializer=schedule_service.GetScheduleRequest.serialize, + response_deserializer=schedule.Schedule.deserialize, + ) + return self._stubs["get_schedule"] + + @property + def list_schedules( + self, + ) -> Callable[ + [schedule_service.ListSchedulesRequest], schedule_service.ListSchedulesResponse + ]: + r"""Return a callable for the list schedules method over gRPC. + + Lists Schedules in a Location. + + Returns: + Callable[[~.ListSchedulesRequest], + ~.ListSchedulesResponse]: + A function that, when called, will call the underlying RPC + on the server. 
+ """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_schedules" not in self._stubs: + self._stubs["list_schedules"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/ListSchedules", + request_serializer=schedule_service.ListSchedulesRequest.serialize, + response_deserializer=schedule_service.ListSchedulesResponse.deserialize, + ) + return self._stubs["list_schedules"] + + @property + def pause_schedule( + self, + ) -> Callable[[schedule_service.PauseScheduleRequest], empty_pb2.Empty]: + r"""Return a callable for the pause schedule method over gRPC. + + Pauses a Schedule. Will mark + [Schedule.state][google.cloud.aiplatform.v1.Schedule.state] to + 'PAUSED'. If the schedule is paused, no new runs will be + created. Already created runs will NOT be paused or canceled. + + Returns: + Callable[[~.PauseScheduleRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "pause_schedule" not in self._stubs: + self._stubs["pause_schedule"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/PauseSchedule", + request_serializer=schedule_service.PauseScheduleRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["pause_schedule"] + + @property + def resume_schedule( + self, + ) -> Callable[[schedule_service.ResumeScheduleRequest], empty_pb2.Empty]: + r"""Return a callable for the resume schedule method over gRPC. + + Resumes a paused Schedule to start scheduling new runs. Will + mark [Schedule.state][google.cloud.aiplatform.v1.Schedule.state] + to 'ACTIVE'. 
Only paused Schedule can be resumed. + + When the Schedule is resumed, new runs will be scheduled + starting from the next execution time after the current time + based on the time_specification in the Schedule. If + [Schedule.catchUp][] is set up true, all missed runs will be + scheduled for backfill first. + + Returns: + Callable[[~.ResumeScheduleRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "resume_schedule" not in self._stubs: + self._stubs["resume_schedule"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/ResumeSchedule", + request_serializer=schedule_service.ResumeScheduleRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["resume_schedule"] + + @property + def update_schedule( + self, + ) -> Callable[[schedule_service.UpdateScheduleRequest], gca_schedule.Schedule]: + r"""Return a callable for the update schedule method over gRPC. + + Updates an active or paused Schedule. + + When the Schedule is updated, new runs will be scheduled + starting from the updated next execution time after the update + time based on the time_specification in the updated Schedule. + All unstarted runs before the update time will be skipped while + already created runs will NOT be paused or canceled. + + Returns: + Callable[[~.UpdateScheduleRequest], + ~.Schedule]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "update_schedule" not in self._stubs: + self._stubs["update_schedule"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/UpdateSchedule", + request_serializer=schedule_service.UpdateScheduleRequest.serialize, + response_deserializer=gca_schedule.Schedule.deserialize, + ) + return self._stubs["update_schedule"] + + def close(self): + self.grpc_channel.close() + + @property + def delete_operation( + self, + ) -> Callable[[operations_pb2.DeleteOperationRequest], None]: + r"""Return a callable for the delete_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_operation" not in self._stubs: + self._stubs["delete_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/DeleteOperation", + request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString, + response_deserializer=None, + ) + return self._stubs["delete_operation"] + + @property + def cancel_operation( + self, + ) -> Callable[[operations_pb2.CancelOperationRequest], None]: + r"""Return a callable for the cancel_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+        if "cancel_operation" not in self._stubs:
+            self._stubs["cancel_operation"] = self.grpc_channel.unary_unary(
+                "/google.longrunning.Operations/CancelOperation",
+                request_serializer=operations_pb2.CancelOperationRequest.SerializeToString,
+                response_deserializer=None,
+            )
+        return self._stubs["cancel_operation"]
+
+    @property
+    def wait_operation(
+        self,
+    ) -> Callable[[operations_pb2.WaitOperationRequest], None]:
+        r"""Return a callable for the wait_operation method over gRPC."""
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "wait_operation" not in self._stubs:
+            self._stubs["wait_operation"] = self.grpc_channel.unary_unary(
+                "/google.longrunning.Operations/WaitOperation",
+                request_serializer=operations_pb2.WaitOperationRequest.SerializeToString,
+                response_deserializer=None,
+            )
+        return self._stubs["wait_operation"]
+
+    @property
+    def get_operation(
+        self,
+    ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]:
+        r"""Return a callable for the get_operation method over gRPC."""
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+ if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def list_locations( + self, + ) -> Callable[ + [locations_pb2.ListLocationsRequest], locations_pb2.ListLocationsResponse + ]: + r"""Return a callable for the list locations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "list_locations" not in self._stubs: + self._stubs["list_locations"] = self.grpc_channel.unary_unary( + "/google.cloud.location.Locations/ListLocations", + request_serializer=locations_pb2.ListLocationsRequest.SerializeToString, + response_deserializer=locations_pb2.ListLocationsResponse.FromString, + ) + return self._stubs["list_locations"] + + @property + def get_location( + self, + ) -> Callable[[locations_pb2.GetLocationRequest], locations_pb2.Location]: + r"""Return a callable for the list locations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_location" not in self._stubs: + self._stubs["get_location"] = self.grpc_channel.unary_unary( + "/google.cloud.location.Locations/GetLocation", + request_serializer=locations_pb2.GetLocationRequest.SerializeToString, + response_deserializer=locations_pb2.Location.FromString, + ) + return self._stubs["get_location"] + + @property + def set_iam_policy( + self, + ) -> Callable[[iam_policy_pb2.SetIamPolicyRequest], policy_pb2.Policy]: + r"""Return a callable for the set iam policy method over gRPC. + Sets the IAM access control policy on the specified + function. Replaces any existing policy. + Returns: + Callable[[~.SetIamPolicyRequest], + ~.Policy]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "set_iam_policy" not in self._stubs: + self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/SetIamPolicy", + request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString, + response_deserializer=policy_pb2.Policy.FromString, + ) + return self._stubs["set_iam_policy"] + + @property + def get_iam_policy( + self, + ) -> Callable[[iam_policy_pb2.GetIamPolicyRequest], policy_pb2.Policy]: + r"""Return a callable for the get iam policy method over gRPC. + Gets the IAM access control policy for a function. + Returns an empty policy if the function exists and does + not have a policy set. + Returns: + Callable[[~.GetIamPolicyRequest], + ~.Policy]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_iam_policy" not in self._stubs: + self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/GetIamPolicy", + request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString, + response_deserializer=policy_pb2.Policy.FromString, + ) + return self._stubs["get_iam_policy"] + + @property + def test_iam_permissions( + self, + ) -> Callable[ + [iam_policy_pb2.TestIamPermissionsRequest], + iam_policy_pb2.TestIamPermissionsResponse, + ]: + r"""Return a callable for the test iam permissions method over gRPC. + Tests the specified permissions against the IAM access control + policy for a function. If the function does not exist, this will + return an empty set of permissions, not a NOT_FOUND error. + Returns: + Callable[[~.TestIamPermissionsRequest], + ~.TestIamPermissionsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "test_iam_permissions" not in self._stubs: + self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/TestIamPermissions", + request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString, + response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString, + ) + return self._stubs["test_iam_permissions"] + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ("ScheduleServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..2eb1b913d3 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/schedule_service/transports/grpc_asyncio.py @@ -0,0 +1,678 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union
+
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers_async
+from google.api_core import operations_v1
+from google.auth import credentials as ga_credentials # type: ignore
+from google.auth.transport.grpc import SslCredentials # type: ignore
+
+import grpc # type: ignore
+from grpc.experimental import aio # type: ignore
+
+from google.cloud.aiplatform_v1.types import schedule
+from google.cloud.aiplatform_v1.types import schedule as gca_schedule
+from google.cloud.aiplatform_v1.types import schedule_service
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
+from google.longrunning import operations_pb2
+from google.longrunning import operations_pb2 # type: ignore
+from google.protobuf import empty_pb2 # type: ignore
+from .base import ScheduleServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import ScheduleServiceGrpcTransport
+
+
+class ScheduleServiceGrpcAsyncIOTransport(ScheduleServiceTransport):
+    """gRPC AsyncIO backend transport for ScheduleService.
+
+    A service for creating and managing Vertex AI's Schedule
+    resources to periodically launch scheduled runs to make API
+    calls.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+ """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. 
+ """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[aio.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. 
+ If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
+ """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client: Optional[operations_v1.OperationsAsyncClient] = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. 
+ credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Quick check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_schedule( + self, + ) -> Callable[ + [schedule_service.CreateScheduleRequest], Awaitable[gca_schedule.Schedule] + ]: + r"""Return a callable for the create schedule method over gRPC. + + Creates a Schedule. + + Returns: + Callable[[~.CreateScheduleRequest], + Awaitable[~.Schedule]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "create_schedule" not in self._stubs: + self._stubs["create_schedule"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/CreateSchedule", + request_serializer=schedule_service.CreateScheduleRequest.serialize, + response_deserializer=gca_schedule.Schedule.deserialize, + ) + return self._stubs["create_schedule"] + + @property + def delete_schedule( + self, + ) -> Callable[ + [schedule_service.DeleteScheduleRequest], Awaitable[operations_pb2.Operation] + ]: + r"""Return a callable for the delete schedule method over gRPC. + + Deletes a Schedule. + + Returns: + Callable[[~.DeleteScheduleRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_schedule" not in self._stubs: + self._stubs["delete_schedule"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/DeleteSchedule", + request_serializer=schedule_service.DeleteScheduleRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["delete_schedule"] + + @property + def get_schedule( + self, + ) -> Callable[[schedule_service.GetScheduleRequest], Awaitable[schedule.Schedule]]: + r"""Return a callable for the get schedule method over gRPC. + + Gets a Schedule. + + Returns: + Callable[[~.GetScheduleRequest], + Awaitable[~.Schedule]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "get_schedule" not in self._stubs: + self._stubs["get_schedule"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/GetSchedule", + request_serializer=schedule_service.GetScheduleRequest.serialize, + response_deserializer=schedule.Schedule.deserialize, + ) + return self._stubs["get_schedule"] + + @property + def list_schedules( + self, + ) -> Callable[ + [schedule_service.ListSchedulesRequest], + Awaitable[schedule_service.ListSchedulesResponse], + ]: + r"""Return a callable for the list schedules method over gRPC. + + Lists Schedules in a Location. + + Returns: + Callable[[~.ListSchedulesRequest], + Awaitable[~.ListSchedulesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_schedules" not in self._stubs: + self._stubs["list_schedules"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/ListSchedules", + request_serializer=schedule_service.ListSchedulesRequest.serialize, + response_deserializer=schedule_service.ListSchedulesResponse.deserialize, + ) + return self._stubs["list_schedules"] + + @property + def pause_schedule( + self, + ) -> Callable[[schedule_service.PauseScheduleRequest], Awaitable[empty_pb2.Empty]]: + r"""Return a callable for the pause schedule method over gRPC. + + Pauses a Schedule. Will mark + [Schedule.state][google.cloud.aiplatform.v1.Schedule.state] to + 'PAUSED'. If the schedule is paused, no new runs will be + created. Already created runs will NOT be paused or canceled. + + Returns: + Callable[[~.PauseScheduleRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "pause_schedule" not in self._stubs: + self._stubs["pause_schedule"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/PauseSchedule", + request_serializer=schedule_service.PauseScheduleRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["pause_schedule"] + + @property + def resume_schedule( + self, + ) -> Callable[[schedule_service.ResumeScheduleRequest], Awaitable[empty_pb2.Empty]]: + r"""Return a callable for the resume schedule method over gRPC. + + Resumes a paused Schedule to start scheduling new runs. Will + mark [Schedule.state][google.cloud.aiplatform.v1.Schedule.state] + to 'ACTIVE'. Only paused Schedule can be resumed. + + When the Schedule is resumed, new runs will be scheduled + starting from the next execution time after the current time + based on the time_specification in the Schedule. If + [Schedule.catchUp][] is set up true, all missed runs will be + scheduled for backfill first. + + Returns: + Callable[[~.ResumeScheduleRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "resume_schedule" not in self._stubs: + self._stubs["resume_schedule"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/ResumeSchedule", + request_serializer=schedule_service.ResumeScheduleRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["resume_schedule"] + + @property + def update_schedule( + self, + ) -> Callable[ + [schedule_service.UpdateScheduleRequest], Awaitable[gca_schedule.Schedule] + ]: + r"""Return a callable for the update schedule method over gRPC. 
+ + Updates an active or paused Schedule. + + When the Schedule is updated, new runs will be scheduled + starting from the updated next execution time after the update + time based on the time_specification in the updated Schedule. + All unstarted runs before the update time will be skipped while + already created runs will NOT be paused or canceled. + + Returns: + Callable[[~.UpdateScheduleRequest], + Awaitable[~.Schedule]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_schedule" not in self._stubs: + self._stubs["update_schedule"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ScheduleService/UpdateSchedule", + request_serializer=schedule_service.UpdateScheduleRequest.serialize, + response_deserializer=gca_schedule.Schedule.deserialize, + ) + return self._stubs["update_schedule"] + + def close(self): + return self.grpc_channel.close() + + @property + def delete_operation( + self, + ) -> Callable[[operations_pb2.DeleteOperationRequest], None]: + r"""Return a callable for the delete_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+        if "delete_operation" not in self._stubs:
+            self._stubs["delete_operation"] = self.grpc_channel.unary_unary(
+                "/google.longrunning.Operations/DeleteOperation",
+                request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString,
+                response_deserializer=None,
+            )
+        return self._stubs["delete_operation"]
+
+    @property
+    def cancel_operation(
+        self,
+    ) -> Callable[[operations_pb2.CancelOperationRequest], None]:
+        r"""Return a callable for the cancel_operation method over gRPC."""
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "cancel_operation" not in self._stubs:
+            self._stubs["cancel_operation"] = self.grpc_channel.unary_unary(
+                "/google.longrunning.Operations/CancelOperation",
+                request_serializer=operations_pb2.CancelOperationRequest.SerializeToString,
+                response_deserializer=None,
+            )
+        return self._stubs["cancel_operation"]
+
+    @property
+    def wait_operation(
+        self,
+    ) -> Callable[[operations_pb2.WaitOperationRequest], None]:
+        r"""Return a callable for the wait_operation method over gRPC."""
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "wait_operation" not in self._stubs:
+            self._stubs["wait_operation"] = self.grpc_channel.unary_unary(
+                "/google.longrunning.Operations/WaitOperation",
+                request_serializer=operations_pb2.WaitOperationRequest.SerializeToString,
+                response_deserializer=None,
+            )
+        return self._stubs["wait_operation"]
+
+    @property
+    def get_operation(
+        self,
+    ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]:
+        r"""Return a callable for the get_operation method over gRPC."""
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def list_locations( + self, + ) -> Callable[ + [locations_pb2.ListLocationsRequest], locations_pb2.ListLocationsResponse + ]: + r"""Return a callable for the list locations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "list_locations" not in self._stubs: + self._stubs["list_locations"] = self.grpc_channel.unary_unary( + "/google.cloud.location.Locations/ListLocations", + request_serializer=locations_pb2.ListLocationsRequest.SerializeToString, + response_deserializer=locations_pb2.ListLocationsResponse.FromString, + ) + return self._stubs["list_locations"] + + @property + def get_location( + self, + ) -> Callable[[locations_pb2.GetLocationRequest], locations_pb2.Location]: + r"""Return a callable for the list locations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_location" not in self._stubs: + self._stubs["get_location"] = self.grpc_channel.unary_unary( + "/google.cloud.location.Locations/GetLocation", + request_serializer=locations_pb2.GetLocationRequest.SerializeToString, + response_deserializer=locations_pb2.Location.FromString, + ) + return self._stubs["get_location"] + + @property + def set_iam_policy( + self, + ) -> Callable[[iam_policy_pb2.SetIamPolicyRequest], policy_pb2.Policy]: + r"""Return a callable for the set iam policy method over gRPC. + Sets the IAM access control policy on the specified + function. Replaces any existing policy. + Returns: + Callable[[~.SetIamPolicyRequest], + ~.Policy]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "set_iam_policy" not in self._stubs: + self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/SetIamPolicy", + request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString, + response_deserializer=policy_pb2.Policy.FromString, + ) + return self._stubs["set_iam_policy"] + + @property + def get_iam_policy( + self, + ) -> Callable[[iam_policy_pb2.GetIamPolicyRequest], policy_pb2.Policy]: + r"""Return a callable for the get iam policy method over gRPC. + Gets the IAM access control policy for a function. + Returns an empty policy if the function exists and does + not have a policy set. + Returns: + Callable[[~.GetIamPolicyRequest], + ~.Policy]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_iam_policy" not in self._stubs: + self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/GetIamPolicy", + request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString, + response_deserializer=policy_pb2.Policy.FromString, + ) + return self._stubs["get_iam_policy"] + + @property + def test_iam_permissions( + self, + ) -> Callable[ + [iam_policy_pb2.TestIamPermissionsRequest], + iam_policy_pb2.TestIamPermissionsResponse, + ]: + r"""Return a callable for the test iam permissions method over gRPC. + Tests the specified permissions against the IAM access control + policy for a function. If the function does not exist, this will + return an empty set of permissions, not a NOT_FOUND error. + Returns: + Callable[[~.TestIamPermissionsRequest], + ~.TestIamPermissionsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "test_iam_permissions" not in self._stubs: + self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/TestIamPermissions", + request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString, + response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString, + ) + return self._stubs["test_iam_permissions"] + + +__all__ = ("ScheduleServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/types/__init__.py b/google/cloud/aiplatform_v1/types/__init__.py index 5826e69920..5e00d82b4d 100644 --- a/google/cloud/aiplatform_v1/types/__init__.py +++ b/google/cloud/aiplatform_v1/types/__init__.py @@ -522,6 +522,30 @@ PredictRequest, PredictResponse, RawPredictRequest, + StreamingPredictRequest, + StreamingPredictResponse, +) +from .publisher_model import ( + PublisherModel, +) +from .saved_query import ( + SavedQuery, +) +from .schedule import ( + Schedule, +) +from .schedule_service import ( + CreateScheduleRequest, + DeleteScheduleRequest, + GetScheduleRequest, + ListSchedulesRequest, + ListSchedulesResponse, + PauseScheduleRequest, + ResumeScheduleRequest, + UpdateScheduleRequest, +) +from .service_networking import ( + PrivateServiceConnectConfig, ) from .publisher_model import ( PublisherModel, @@ -631,6 +655,7 @@ DoubleArray, Int64Array, StringArray, + Tensor, ) from .unmanaged_container_model import ( UnmanagedContainerModel, @@ -1065,8 +1090,19 @@ "PredictRequest", "PredictResponse", "RawPredictRequest", + "StreamingPredictRequest", + "StreamingPredictResponse", "PublisherModel", "SavedQuery", + "Schedule", + "CreateScheduleRequest", + "DeleteScheduleRequest", + "GetScheduleRequest", + "ListSchedulesRequest", + "ListSchedulesResponse", + "PauseScheduleRequest", + "ResumeScheduleRequest", + "UpdateScheduleRequest", "PrivateServiceConnectConfig", "SpecialistPool", 
"CreateSpecialistPoolOperationMetadata", @@ -1146,6 +1182,7 @@ "DoubleArray", "Int64Array", "StringArray", + "Tensor", "UnmanagedContainerModel", "UserActionReference", "Value", diff --git a/google/cloud/aiplatform_v1/types/explanation.py b/google/cloud/aiplatform_v1/types/explanation.py index fdfd276f32..0203fafccb 100644 --- a/google/cloud/aiplatform_v1/types/explanation.py +++ b/google/cloud/aiplatform_v1/types/explanation.py @@ -487,7 +487,9 @@ class IntegratedGradientsAttribution(proto.Message): blurred image to the input image is created. Using a blurred baseline instead of zero (black image) is motivated by the BlurIG approach - explained here: https://arxiv.org/abs/2004.03383 + explained here: + + https://arxiv.org/abs/2004.03383 """ step_count: int = proto.Field( @@ -510,7 +512,9 @@ class XraiAttribution(proto.Message): r"""An explanation method that redistributes Integrated Gradients attributions to segmented regions, taking advantage of the model's fully differentiable structure. Refer to this paper for - more details: https://arxiv.org/abs/1906.02825 + more details: + + https://arxiv.org/abs/1906.02825 Supported only by image Models. @@ -537,7 +541,9 @@ class XraiAttribution(proto.Message): blurred image to the input image is created. Using a blurred baseline instead of zero (black image) is motivated by the BlurIG approach - explained here: https://arxiv.org/abs/2004.03383 + explained here: + + https://arxiv.org/abs/2004.03383 """ step_count: int = proto.Field( @@ -562,6 +568,7 @@ class SmoothGradConfig(proto.Message): gradients from noisy samples in the vicinity of the inputs. Adding noise can help improve the computed gradients. Refer to this paper for more details: + https://arxiv.org/pdf/1706.03825.pdf This message has `oneof`_ fields (mutually exclusive fields). @@ -675,6 +682,7 @@ class BlurBaselineConfig(proto.Message): the input image is created. 
Using a blurred baseline instead of zero (black image) is motivated by the BlurIG approach explained here: + https://arxiv.org/abs/2004.03383 Attributes: diff --git a/google/cloud/aiplatform_v1/types/model_deployment_monitoring_job.py b/google/cloud/aiplatform_v1/types/model_deployment_monitoring_job.py index f1819b0857..88b6d5d7a0 100644 --- a/google/cloud/aiplatform_v1/types/model_deployment_monitoring_job.py +++ b/google/cloud/aiplatform_v1/types/model_deployment_monitoring_job.py @@ -147,6 +147,7 @@ class ModelDeploymentMonitoringJob(proto.Message): the job under customer project. Customer could do their own query & analysis. There could be 4 log tables in maximum: + 1. Training data logging predict request/response 2. Serving data logging predict request/response diff --git a/google/cloud/aiplatform_v1/types/model_monitoring.py b/google/cloud/aiplatform_v1/types/model_monitoring.py index 89e4196a1f..aed3044bb5 100644 --- a/google/cloud/aiplatform_v1/types/model_monitoring.py +++ b/google/cloud/aiplatform_v1/types/model_monitoring.py @@ -385,9 +385,10 @@ class ThresholdConfig(proto.Message): value (float): Specify a threshold value that can trigger the alert. If this threshold config is for - feature distribution distance: 1. For - categorical feature, the distribution distance - is calculated by L-inifinity norm. + feature distribution distance: + + 1. For categorical feature, the distribution + distance is calculated by L-inifinity norm. 2. For numerical feature, the distribution distance is calculated by Jensen–Shannon divergence. 
diff --git a/google/cloud/aiplatform_v1/types/prediction_service.py b/google/cloud/aiplatform_v1/types/prediction_service.py index c30669cdc0..dc8692905c 100644 --- a/google/cloud/aiplatform_v1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1/types/prediction_service.py @@ -21,6 +21,7 @@ from google.api import httpbody_pb2 # type: ignore from google.cloud.aiplatform_v1.types import explanation +from google.cloud.aiplatform_v1.types import types from google.protobuf import struct_pb2 # type: ignore @@ -30,6 +31,8 @@ "PredictRequest", "PredictResponse", "RawPredictRequest", + "StreamingPredictRequest", + "StreamingPredictResponse", "ExplainRequest", "ExplainResponse", }, @@ -109,6 +112,10 @@ class PredictResponse(proto.Message): name][google.cloud.aiplatform.v1.Model.display_name] of the Model which is deployed as the DeployedModel that this prediction hits. + metadata (google.protobuf.struct_pb2.Value): + Output only. Request-level metadata returned + by the model. The metadata type will be + dependent upon the model implementation. """ predictions: MutableSequence[struct_pb2.Value] = proto.RepeatedField( @@ -132,6 +139,11 @@ class PredictResponse(proto.Message): proto.STRING, number=4, ) + metadata: struct_pb2.Value = proto.Field( + proto.MESSAGE, + number=6, + message=struct_pb2.Value, + ) class RawPredictRequest(proto.Message): @@ -176,6 +188,65 @@ class RawPredictRequest(proto.Message): ) +class StreamingPredictRequest(proto.Message): + r"""Request message for + [PredictionService.StreamingPredict][google.cloud.aiplatform.v1.PredictionService.StreamingPredict]. + + The first message must contain + [endpoint][google.cloud.aiplatform.v1.StreamingPredictRequest.endpoint] + field and optionally [input][]. The subsequent messages must contain + [input][]. + + Attributes: + endpoint (str): + Required. The name of the Endpoint requested to serve the + prediction. 
Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + inputs (MutableSequence[google.cloud.aiplatform_v1.types.Tensor]): + The prediction input. + parameters (google.cloud.aiplatform_v1.types.Tensor): + The parameters that govern the prediction. + """ + + endpoint: str = proto.Field( + proto.STRING, + number=1, + ) + inputs: MutableSequence[types.Tensor] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=types.Tensor, + ) + parameters: types.Tensor = proto.Field( + proto.MESSAGE, + number=3, + message=types.Tensor, + ) + + +class StreamingPredictResponse(proto.Message): + r"""Response message for + [PredictionService.StreamingPredict][google.cloud.aiplatform.v1.PredictionService.StreamingPredict]. + + Attributes: + outputs (MutableSequence[google.cloud.aiplatform_v1.types.Tensor]): + The prediction output. + parameters (google.cloud.aiplatform_v1.types.Tensor): + The parameters that govern the prediction. + """ + + outputs: MutableSequence[types.Tensor] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=types.Tensor, + ) + parameters: types.Tensor = proto.Field( + proto.MESSAGE, + number=2, + message=types.Tensor, + ) + + class ExplainRequest(proto.Message): r"""Request message for [PredictionService.Explain][google.cloud.aiplatform.v1.PredictionService.Explain]. diff --git a/google/cloud/aiplatform_v1/types/schedule.py b/google/cloud/aiplatform_v1/types/schedule.py new file mode 100644 index 0000000000..7bceaf70e1 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/schedule.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.cloud.aiplatform_v1.types import pipeline_service +from google.protobuf import timestamp_pb2 # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "Schedule", + }, +) + + +class Schedule(proto.Message): + r"""An instance of a Schedule periodically schedules runs to make + API calls based on user specified time specification and API + request type. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + cron (str): + Cron schedule (https://en.wikipedia.org/wiki/Cron) to launch + scheduled runs. To explicitly set a timezone to the cron + tab, apply a prefix in the cron tab: + "CRON_TZ=${IANA_TIME_ZONE}" or "TZ=${IANA_TIME_ZONE}". The + ${IANA_TIME_ZONE} may only be a valid string from IANA time + zone database. For example, "CRON_TZ=America/New_York 1 \* + \* \* \*", or "TZ=America/New_York 1 \* \* \* \*". + + This field is a member of `oneof`_ ``time_specification``. + create_pipeline_job_request (google.cloud.aiplatform_v1.types.CreatePipelineJobRequest): + Request for + [PipelineService.CreatePipelineJob][google.cloud.aiplatform.v1.PipelineService.CreatePipelineJob]. + CreatePipelineJobRequest.parent field is required (format: + projects/{project}/locations/{location}). + + This field is a member of `oneof`_ ``request``. + name (str): + Output only. The resource name of the + Schedule. 
+ display_name (str): + Required. User provided name of the Schedule. + The name can be up to 128 characters long and + can consist of any UTF-8 characters. + start_time (google.protobuf.timestamp_pb2.Timestamp): + Optional. Timestamp after which the first run + can be scheduled. Default to Schedule create + time if not specified. + end_time (google.protobuf.timestamp_pb2.Timestamp): + Optional. Timestamp after which no new runs can be + scheduled. If specified, The schedule will be completed when + either end_time is reached or when scheduled_run_count >= + max_run_count. If not specified, new runs will keep getting + scheduled until this Schedule is paused or deleted. Already + scheduled runs will be allowed to complete. Unset if not + specified. + max_run_count (int): + Optional. Maximum run count of the schedule. If specified, + The schedule will be completed when either started_run_count + >= max_run_count or when end_time is reached. If not + specified, new runs will keep getting scheduled until this + Schedule is paused or deleted. Already scheduled runs will + be allowed to complete. Unset if not specified. + started_run_count (int): + Output only. The number of runs started by + this schedule. + state (google.cloud.aiplatform_v1.types.Schedule.State): + Output only. The state of this Schedule. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Schedule was + created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Schedule was + updated. + next_run_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Schedule should schedule + the next run. Having a next_run_time in the past means the + runs are being started behind schedule. + last_pause_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Schedule was + last paused. Unset if never paused. 
+ last_resume_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Schedule was + last resumed. Unset if never resumed from pause. + max_concurrent_run_count (int): + Required. Maximum number of runs that can be + started concurrently for this Schedule. This is + the limit for starting the scheduled requests + and not the execution of the operations/jobs + created by the requests (if applicable). + allow_queueing (bool): + Optional. Whether new scheduled runs can be queued when + max_concurrent_runs limit is reached. If set to true, new + runs will be queued instead of skipped. Default to false. + catch_up (bool): + Output only. Whether to backfill missed runs + when the schedule is resumed from PAUSED state. + If set to true, all missed runs will be + scheduled. New runs will be scheduled after the + backfill is complete. Default to false. + last_scheduled_run_response (google.cloud.aiplatform_v1.types.Schedule.RunResponse): + Output only. Response of the last scheduled + run. This is the response for starting the + scheduled requests and not the execution of the + operations/jobs created by the requests (if + applicable). Unset if no run has been scheduled + yet. + """ + + class State(proto.Enum): + r"""Possible state of the schedule. + + Values: + STATE_UNSPECIFIED (0): + Unspecified. + ACTIVE (1): + The Schedule is active. Runs are being + scheduled on the user-specified timespec. + PAUSED (2): + The schedule is paused. No new runs will be + created until the schedule is resumed. Already + started runs will be allowed to complete. + COMPLETED (3): + The Schedule is completed. No new runs will + be scheduled. Already started runs will be + allowed to complete. Schedules in completed + state cannot be paused or resumed. + """ + STATE_UNSPECIFIED = 0 + ACTIVE = 1 + PAUSED = 2 + COMPLETED = 3 + + class RunResponse(proto.Message): + r"""Status of a scheduled run. 
+ + Attributes: + scheduled_run_time (google.protobuf.timestamp_pb2.Timestamp): + The scheduled run time based on the + user-specified schedule. + run_response (str): + The response of the scheduled run. + """ + + scheduled_run_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=1, + message=timestamp_pb2.Timestamp, + ) + run_response: str = proto.Field( + proto.STRING, + number=2, + ) + + cron: str = proto.Field( + proto.STRING, + number=10, + oneof="time_specification", + ) + create_pipeline_job_request: pipeline_service.CreatePipelineJobRequest = ( + proto.Field( + proto.MESSAGE, + number=14, + oneof="request", + message=pipeline_service.CreatePipelineJobRequest, + ) + ) + name: str = proto.Field( + proto.STRING, + number=1, + ) + display_name: str = proto.Field( + proto.STRING, + number=2, + ) + start_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=3, + message=timestamp_pb2.Timestamp, + ) + end_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=4, + message=timestamp_pb2.Timestamp, + ) + max_run_count: int = proto.Field( + proto.INT64, + number=16, + ) + started_run_count: int = proto.Field( + proto.INT64, + number=17, + ) + state: State = proto.Field( + proto.ENUM, + number=5, + enum=State, + ) + create_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=6, + message=timestamp_pb2.Timestamp, + ) + update_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=19, + message=timestamp_pb2.Timestamp, + ) + next_run_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=7, + message=timestamp_pb2.Timestamp, + ) + last_pause_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=8, + message=timestamp_pb2.Timestamp, + ) + last_resume_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=9, + message=timestamp_pb2.Timestamp, + ) + max_concurrent_run_count: int = proto.Field( + proto.INT64, + number=11, + ) + 
allow_queueing: bool = proto.Field( + proto.BOOL, + number=12, + ) + catch_up: bool = proto.Field( + proto.BOOL, + number=13, + ) + last_scheduled_run_response: RunResponse = proto.Field( + proto.MESSAGE, + number=18, + message=RunResponse, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/schedule_service.py b/google/cloud/aiplatform_v1/types/schedule_service.py new file mode 100644 index 0000000000..77ac5bc58a --- /dev/null +++ b/google/cloud/aiplatform_v1/types/schedule_service.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.cloud.aiplatform_v1.types import schedule as gca_schedule +from google.protobuf import field_mask_pb2 # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "CreateScheduleRequest", + "GetScheduleRequest", + "ListSchedulesRequest", + "ListSchedulesResponse", + "DeleteScheduleRequest", + "PauseScheduleRequest", + "ResumeScheduleRequest", + "UpdateScheduleRequest", + }, +) + + +class CreateScheduleRequest(proto.Message): + r"""Request message for + [ScheduleService.CreateSchedule][google.cloud.aiplatform.v1.ScheduleService.CreateSchedule]. + + Attributes: + parent (str): + Required. 
The resource name of the Location to create the + Schedule in. Format: + ``projects/{project}/locations/{location}`` + schedule (google.cloud.aiplatform_v1.types.Schedule): + Required. The Schedule to create. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + schedule: gca_schedule.Schedule = proto.Field( + proto.MESSAGE, + number=2, + message=gca_schedule.Schedule, + ) + + +class GetScheduleRequest(proto.Message): + r"""Request message for + [ScheduleService.GetSchedule][google.cloud.aiplatform.v1.ScheduleService.GetSchedule]. + + Attributes: + name (str): + Required. The name of the Schedule resource. Format: + ``projects/{project}/locations/{location}/schedules/{schedule}`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class ListSchedulesRequest(proto.Message): + r"""Request message for + [ScheduleService.ListSchedules][google.cloud.aiplatform.v1.ScheduleService.ListSchedules]. + + Attributes: + parent (str): + Required. The resource name of the Location to list the + Schedules from. Format: + ``projects/{project}/locations/{location}`` + filter (str): + Lists the Schedules that match the filter expression. The + following fields are supported: + + - ``display_name``: Supports ``=``, ``!=`` comparisons, and + ``:`` wildcard. + - ``state``: Supports ``=`` and ``!=`` comparisons. + - ``request``: Supports existence of the + check. (e.g. ``create_pipeline_job_request:*`` --> + Schedule has create_pipeline_job_request). + - ``create_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``start_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``end_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, ``>=`` comparisons and ``:*`` existence check. + Values must be in RFC 3339 format. + - ``next_run_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. 
Values must be in RFC + 3339 format. + + Filter expressions can be combined together using logical + operators (``NOT``, ``AND`` & ``OR``). The syntax to define + filter expression is based on https://google.aip.dev/160. + + Examples: + + - ``state="ACTIVE" AND display_name:"my_schedule_*"`` + - ``NOT display_name="my_schedule"`` + - ``create_time>"2021-05-18T00:00:00Z"`` + - ``end_time>"2021-05-18T00:00:00Z" OR NOT end_time:*`` + - ``create_pipeline_job_request:*`` + page_size (int): + The standard list page size. + Default to 100 if not specified. + page_token (str): + The standard list page token. Typically obtained via + [ListSchedulesResponse.next_page_token][google.cloud.aiplatform.v1.ListSchedulesResponse.next_page_token] + of the previous + [ScheduleService.ListSchedules][google.cloud.aiplatform.v1.ScheduleService.ListSchedules] + call. + order_by (str): + A comma-separated list of fields to order by. The default + sort order is in ascending order. Use "desc" after a field + name for descending. You can have multiple order_by fields + provided. + + For example, using "create_time desc, end_time" will order + results by create time in descending order, and if there are + multiple schedules having the same create time, order them + by the end time in ascending order. + + If order_by is not specified, it will order by default with + create_time in descending order. 
+ + Supported fields: + + - ``create_time`` + - ``start_time`` + - ``end_time`` + - ``next_run_time`` + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + filter: str = proto.Field( + proto.STRING, + number=2, + ) + page_size: int = proto.Field( + proto.INT32, + number=3, + ) + page_token: str = proto.Field( + proto.STRING, + number=4, + ) + order_by: str = proto.Field( + proto.STRING, + number=5, + ) + + +class ListSchedulesResponse(proto.Message): + r"""Response message for + [ScheduleService.ListSchedules][google.cloud.aiplatform.v1.ScheduleService.ListSchedules] + + Attributes: + schedules (MutableSequence[google.cloud.aiplatform_v1.types.Schedule]): + List of Schedules in the requested page. + next_page_token (str): + A token to retrieve the next page of results. Pass to + [ListSchedulesRequest.page_token][google.cloud.aiplatform.v1.ListSchedulesRequest.page_token] + to obtain that page. + """ + + @property + def raw_page(self): + return self + + schedules: MutableSequence[gca_schedule.Schedule] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=gca_schedule.Schedule, + ) + next_page_token: str = proto.Field( + proto.STRING, + number=2, + ) + + +class DeleteScheduleRequest(proto.Message): + r"""Request message for + [ScheduleService.DeleteSchedule][google.cloud.aiplatform.v1.ScheduleService.DeleteSchedule]. + + Attributes: + name (str): + Required. The name of the Schedule resource to be deleted. + Format: + ``projects/{project}/locations/{location}/schedules/{schedule}`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class PauseScheduleRequest(proto.Message): + r"""Request message for + [ScheduleService.PauseSchedule][google.cloud.aiplatform.v1.ScheduleService.PauseSchedule]. + + Attributes: + name (str): + Required. The name of the Schedule resource to be paused. 
+ Format: + ``projects/{project}/locations/{location}/schedules/{schedule}`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class ResumeScheduleRequest(proto.Message): + r"""Request message for + [ScheduleService.ResumeSchedule][google.cloud.aiplatform.v1.ScheduleService.ResumeSchedule]. + + Attributes: + name (str): + Required. The name of the Schedule resource to be resumed. + Format: + ``projects/{project}/locations/{location}/schedules/{schedule}`` + catch_up (bool): + Optional. Whether to backfill missed runs when the schedule + is resumed from PAUSED state. If set to true, all missed + runs will be scheduled. New runs will be scheduled after the + backfill is complete. This will also update + [Schedule.catch_up][google.cloud.aiplatform.v1.Schedule.catch_up] + field. Default to false. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + catch_up: bool = proto.Field( + proto.BOOL, + number=2, + ) + + +class UpdateScheduleRequest(proto.Message): + r"""Request message for + [ScheduleService.UpdateSchedule][google.cloud.aiplatform.v1.ScheduleService.UpdateSchedule]. + + Attributes: + schedule (google.cloud.aiplatform_v1.types.Schedule): + Required. The Schedule which replaces the resource on the + server. The following restrictions will be applied: + + - The scheduled request type cannot be changed. + - The output_only fields will be ignored if specified. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the resource. See + [google.protobuf.FieldMask][google.protobuf.FieldMask]. 
+ """ + + schedule: gca_schedule.Schedule = proto.Field( + proto.MESSAGE, + number=1, + message=gca_schedule.Schedule, + ) + update_mask: field_mask_pb2.FieldMask = proto.Field( + proto.MESSAGE, + number=2, + message=field_mask_pb2.FieldMask, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/types.py b/google/cloud/aiplatform_v1/types/types.py index 43fe4225de..3814552c6c 100644 --- a/google/cloud/aiplatform_v1/types/types.py +++ b/google/cloud/aiplatform_v1/types/types.py @@ -27,6 +27,7 @@ "DoubleArray", "Int64Array", "StringArray", + "Tensor", }, ) @@ -87,4 +88,156 @@ class StringArray(proto.Message): ) +class Tensor(proto.Message): + r"""A tensor value type. + + Attributes: + dtype (google.cloud.aiplatform_v1.types.Tensor.DataType): + The data type of tensor. + shape (MutableSequence[int]): + Shape of the tensor. + bool_val (MutableSequence[bool]): + Type specific representations that make it easy to create + tensor protos in all languages. Only the representation + corresponding to "dtype" can be set. The values hold the + flattened representation of the tensor in row major order. 
+ + [BOOL][google.aiplatform.master.Tensor.DataType.BOOL] + string_val (MutableSequence[str]): + [STRING][google.aiplatform.master.Tensor.DataType.STRING] + bytes_val (MutableSequence[bytes]): + [STRING][google.aiplatform.master.Tensor.DataType.STRING] + float_val (MutableSequence[float]): + [FLOAT][google.aiplatform.master.Tensor.DataType.FLOAT] + double_val (MutableSequence[float]): + [DOUBLE][google.aiplatform.master.Tensor.DataType.DOUBLE] + int_val (MutableSequence[int]): + [INT_8][google.aiplatform.master.Tensor.DataType.INT8] + [INT_16][google.aiplatform.master.Tensor.DataType.INT16] + [INT_32][google.aiplatform.master.Tensor.DataType.INT32] + int64_val (MutableSequence[int]): + [INT64][google.aiplatform.master.Tensor.DataType.INT64] + uint_val (MutableSequence[int]): + [UINT8][google.aiplatform.master.Tensor.DataType.UINT8] + [UINT16][google.aiplatform.master.Tensor.DataType.UINT16] + [UINT32][google.aiplatform.master.Tensor.DataType.UINT32] + uint64_val (MutableSequence[int]): + [UINT64][google.aiplatform.master.Tensor.DataType.UINT64] + list_val (MutableSequence[google.cloud.aiplatform_v1.types.Tensor]): + A list of tensor values. + struct_val (MutableMapping[str, google.cloud.aiplatform_v1.types.Tensor]): + A map of string to tensor. + tensor_val (bytes): + Serialized raw tensor content. + """ + + class DataType(proto.Enum): + r"""Data type of the tensor. + + Values: + DATA_TYPE_UNSPECIFIED (0): + Not a legal value for DataType. Used to + indicate a DataType field has not been set. + BOOL (1): + Data types that all computation devices are + expected to be capable to support. + STRING (2): + No description available. + FLOAT (3): + No description available. + DOUBLE (4): + No description available. + INT8 (5): + No description available. + INT16 (6): + No description available. + INT32 (7): + No description available. + INT64 (8): + No description available. + UINT8 (9): + No description available. + UINT16 (10): + No description available. 
+ UINT32 (11): + No description available. + UINT64 (12): + No description available. + """ + DATA_TYPE_UNSPECIFIED = 0 + BOOL = 1 + STRING = 2 + FLOAT = 3 + DOUBLE = 4 + INT8 = 5 + INT16 = 6 + INT32 = 7 + INT64 = 8 + UINT8 = 9 + UINT16 = 10 + UINT32 = 11 + UINT64 = 12 + + dtype: DataType = proto.Field( + proto.ENUM, + number=1, + enum=DataType, + ) + shape: MutableSequence[int] = proto.RepeatedField( + proto.INT64, + number=2, + ) + bool_val: MutableSequence[bool] = proto.RepeatedField( + proto.BOOL, + number=3, + ) + string_val: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=14, + ) + bytes_val: MutableSequence[bytes] = proto.RepeatedField( + proto.BYTES, + number=15, + ) + float_val: MutableSequence[float] = proto.RepeatedField( + proto.FLOAT, + number=5, + ) + double_val: MutableSequence[float] = proto.RepeatedField( + proto.DOUBLE, + number=6, + ) + int_val: MutableSequence[int] = proto.RepeatedField( + proto.INT32, + number=7, + ) + int64_val: MutableSequence[int] = proto.RepeatedField( + proto.INT64, + number=8, + ) + uint_val: MutableSequence[int] = proto.RepeatedField( + proto.UINT32, + number=9, + ) + uint64_val: MutableSequence[int] = proto.RepeatedField( + proto.UINT64, + number=10, + ) + list_val: MutableSequence["Tensor"] = proto.RepeatedField( + proto.MESSAGE, + number=11, + message="Tensor", + ) + struct_val: MutableMapping[str, "Tensor"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=12, + message="Tensor", + ) + tensor_val: bytes = proto.Field( + proto.BYTES, + number=13, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index ba9958bb7d..7575620d9c 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -465,7 +465,9 @@ from .types.operation import DeleteOperationMetadata from .types.operation import GenericOperationMetadata from 
.types.persistent_resource import PersistentResource +from .types.persistent_resource import RaySpec from .types.persistent_resource import ResourcePool +from .types.persistent_resource import ResourceRuntime from .types.persistent_resource import ResourceRuntimeSpec from .types.persistent_resource import ServiceAccountSpec from .types.persistent_resource_service import CreatePersistentResourceOperationMetadata @@ -498,6 +500,8 @@ from .types.prediction_service import PredictRequest from .types.prediction_service import PredictResponse from .types.prediction_service import RawPredictRequest +from .types.prediction_service import StreamingPredictRequest +from .types.prediction_service import StreamingPredictResponse from .types.publisher_model import PublisherModel from .types.saved_query import SavedQuery from .types.schedule import Schedule @@ -590,6 +594,7 @@ from .types.types import DoubleArray from .types.types import Int64Array from .types.types import StringArray +from .types.types import Tensor from .types.unmanaged_container_model import UnmanagedContainerModel from .types.user_action_reference import UserActionReference from .types.value import Value @@ -1076,6 +1081,7 @@ "QueryDeployedModelsResponse", "QueryExecutionInputsAndOutputsRequest", "RawPredictRequest", + "RaySpec", "ReadFeatureValuesRequest", "ReadFeatureValuesResponse", "ReadIndexDatapointsRequest", @@ -1093,6 +1099,7 @@ "RemoveDatapointsRequest", "RemoveDatapointsResponse", "ResourcePool", + "ResourceRuntime", "ResourceRuntimeSpec", "ResourcesConsumed", "ResumeModelDeploymentMonitoringJobRequest", @@ -1119,6 +1126,8 @@ "SpecialistPoolServiceClient", "StopTrialRequest", "StratifiedSplit", + "StreamingPredictRequest", + "StreamingPredictResponse", "StreamingReadFeatureValuesRequest", "StringArray", "Study", @@ -1127,6 +1136,7 @@ "SuggestTrialsRequest", "SuggestTrialsResponse", "TFRecordDestination", + "Tensor", "Tensorboard", "TensorboardBlob", "TensorboardBlobSequence", diff --git 
a/google/cloud/aiplatform_v1beta1/gapic_metadata.json b/google/cloud/aiplatform_v1beta1/gapic_metadata.json index 9eb752c8a2..47b9306b67 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_metadata.json +++ b/google/cloud/aiplatform_v1beta1/gapic_metadata.json @@ -1924,6 +1924,11 @@ "methods": [ "raw_predict" ] + }, + "ServerStreamingPredict": { + "methods": [ + "server_streaming_predict" + ] } } }, @@ -1944,6 +1949,11 @@ "methods": [ "raw_predict" ] + }, + "ServerStreamingPredict": { + "methods": [ + "server_streaming_predict" + ] } } } diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py index ba91baf7b0..79fa9943e4 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py @@ -16,12 +16,15 @@ from collections import OrderedDict import functools import re +import pkg_resources from typing import ( Dict, Mapping, MutableMapping, MutableSequence, Optional, + AsyncIterable, + Awaitable, Sequence, Tuple, Type, @@ -45,6 +48,7 @@ from google.api import httpbody_pb2 # type: ignore from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import prediction_service +from google.cloud.aiplatform_v1beta1.types import types from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -548,6 +552,95 @@ async def sample_raw_predict(): # Done; return the response. 
return response + def server_streaming_predict( + self, + request: Optional[ + Union[prediction_service.StreamingPredictRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Awaitable[AsyncIterable[prediction_service.StreamingPredictResponse]]: + r"""Perform a server-side streaming online prediction + request for Vertex LLM streaming. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + async def sample_server_streaming_predict(): + # Create a client + client = aiplatform_v1beta1.PredictionServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.StreamingPredictRequest( + endpoint="endpoint_value", + ) + + # Make the request + stream = await client.server_streaming_predict(request=request) + + # Handle the response + async for response in stream: + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1beta1.types.StreamingPredictRequest, dict]]): + The request object. Request message for + [PredictionService.StreamingPredict][google.cloud.aiplatform.v1beta1.PredictionService.StreamingPredict]. + + The first message must contain + [endpoint][google.cloud.aiplatform.v1beta1.StreamingPredictRequest.endpoint] + field and optionally [input][]. The subsequent messages + must contain [input][]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + AsyncIterable[google.cloud.aiplatform_v1beta1.types.StreamingPredictResponse]: + Response message for + [PredictionService.StreamingPredict][google.cloud.aiplatform.v1beta1.PredictionService.StreamingPredict]. + + """ + # Create or coerce a protobuf request object. + request = prediction_service.StreamingPredictRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.server_streaming_predict, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + async def explain( self, request: Optional[Union[prediction_service.ExplainRequest, dict]] = None, diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py index 1416fa5730..ba928c663c 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py @@ -16,12 +16,14 @@ from collections import OrderedDict import os import re +import pkg_resources from typing import ( Dict, Mapping, MutableMapping, MutableSequence, Optional, + Iterable, Sequence, Tuple, Type, @@ -49,6 +51,7 @@ from google.api import httpbody_pb2 # type: ignore from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import prediction_service +from google.cloud.aiplatform_v1beta1.types import types from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -795,6 +798,96 @@ def sample_raw_predict(): # Done; return the response. return response + def server_streaming_predict( + self, + request: Optional[ + Union[prediction_service.StreamingPredictRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Iterable[prediction_service.StreamingPredictResponse]: + r"""Perform a server-side streaming online prediction + request for Vertex LLM streaming. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + def sample_server_streaming_predict(): + # Create a client + client = aiplatform_v1beta1.PredictionServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.StreamingPredictRequest( + endpoint="endpoint_value", + ) + + # Make the request + stream = client.server_streaming_predict(request=request) + + # Handle the response + for response in stream: + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1beta1.types.StreamingPredictRequest, dict]): + The request object. Request message for + [PredictionService.StreamingPredict][google.cloud.aiplatform.v1beta1.PredictionService.StreamingPredict]. + + The first message must contain + [endpoint][google.cloud.aiplatform.v1beta1.StreamingPredictRequest.endpoint] + field and optionally [input][]. The subsequent messages + must contain [input][]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + Iterable[google.cloud.aiplatform_v1beta1.types.StreamingPredictResponse]: + Response message for + [PredictionService.StreamingPredict][google.cloud.aiplatform.v1beta1.PredictionService.StreamingPredict]. + + """ + # Create or coerce a protobuf request object. + # Minor optimization to avoid making a copy if the user passes + # in a prediction_service.StreamingPredictRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. 
+ if not isinstance(request, prediction_service.StreamingPredictRequest): + request = prediction_service.StreamingPredictRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.server_streaming_predict] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + def explain( self, request: Optional[Union[prediction_service.ExplainRequest, dict]] = None, diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py index ea093fdce7..44ab49ffb0 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py @@ -138,6 +138,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.server_streaming_predict: gapic_v1.method.wrap_method( + self.server_streaming_predict, + default_timeout=None, + client_info=client_info, + ), self.explain: gapic_v1.method.wrap_method( self.explain, default_timeout=5.0, @@ -175,6 +180,18 @@ def raw_predict( ]: raise NotImplementedError() + @property + def server_streaming_predict( + self, + ) -> Callable[ + [prediction_service.StreamingPredictRequest], + Union[ + prediction_service.StreamingPredictResponse, + Awaitable[prediction_service.StreamingPredictResponse], + ], + ]: + raise NotImplementedError() + @property def explain( self, diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py 
b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py index 677304c25d..a6017621e7 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py @@ -298,6 +298,36 @@ def raw_predict( ) return self._stubs["raw_predict"] + @property + def server_streaming_predict( + self, + ) -> Callable[ + [prediction_service.StreamingPredictRequest], + prediction_service.StreamingPredictResponse, + ]: + r"""Return a callable for the server streaming predict method over gRPC. + + Perform a server-side streaming online prediction + request for Vertex LLM streaming. + + Returns: + Callable[[~.StreamingPredictRequest], + ~.StreamingPredictResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "server_streaming_predict" not in self._stubs: + self._stubs["server_streaming_predict"] = self.grpc_channel.unary_stream( + "/google.cloud.aiplatform.v1beta1.PredictionService/ServerStreamingPredict", + request_serializer=prediction_service.StreamingPredictRequest.serialize, + response_deserializer=prediction_service.StreamingPredictResponse.deserialize, + ) + return self._stubs["server_streaming_predict"] + @property def explain( self, diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py index 4bb8fdefc9..841c3c7898 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py @@ -304,6 +304,36 @@ def raw_predict( ) return self._stubs["raw_predict"] + @property + def server_streaming_predict( + self, + ) -> Callable[ + [prediction_service.StreamingPredictRequest], + Awaitable[prediction_service.StreamingPredictResponse], + ]: + r"""Return a callable for the server streaming predict method over gRPC. + + Perform a server-side streaming online prediction + request for Vertex LLM streaming. + + Returns: + Callable[[~.StreamingPredictRequest], + Awaitable[~.StreamingPredictResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "server_streaming_predict" not in self._stubs: + self._stubs["server_streaming_predict"] = self.grpc_channel.unary_stream( + "/google.cloud.aiplatform.v1beta1.PredictionService/ServerStreamingPredict", + request_serializer=prediction_service.StreamingPredictRequest.serialize, + response_deserializer=prediction_service.StreamingPredictResponse.deserialize, + ) + return self._stubs["server_streaming_predict"] + @property def explain( self, diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index f464f1e2bc..cd77aed5bc 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -512,7 +512,9 @@ ) from .persistent_resource import ( PersistentResource, + RaySpec, ResourcePool, + ResourceRuntime, ResourceRuntimeSpec, ServiceAccountSpec, ) @@ -551,6 +553,8 @@ PredictRequest, PredictResponse, RawPredictRequest, + StreamingPredictRequest, + StreamingPredictResponse, ) from .publisher_model import ( PublisherModel, @@ -675,6 +679,7 @@ DoubleArray, Int64Array, StringArray, + Tensor, ) from .unmanaged_container_model import ( UnmanagedContainerModel, @@ -1097,7 +1102,9 @@ "DeleteOperationMetadata", "GenericOperationMetadata", "PersistentResource", + "RaySpec", "ResourcePool", + "ResourceRuntime", "ResourceRuntimeSpec", "ServiceAccountSpec", "CreatePersistentResourceOperationMetadata", @@ -1130,6 +1137,8 @@ "PredictRequest", "PredictResponse", "RawPredictRequest", + "StreamingPredictRequest", + "StreamingPredictResponse", "PublisherModel", "SavedQuery", "Schedule", @@ -1222,6 +1231,7 @@ "DoubleArray", "Int64Array", "StringArray", + "Tensor", "UnmanagedContainerModel", "UserActionReference", "Value", diff --git a/google/cloud/aiplatform_v1beta1/types/explanation.py b/google/cloud/aiplatform_v1beta1/types/explanation.py index 89c796ba37..cc6821ad12 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation.py +++ 
b/google/cloud/aiplatform_v1beta1/types/explanation.py @@ -487,7 +487,9 @@ class IntegratedGradientsAttribution(proto.Message): blurred image to the input image is created. Using a blurred baseline instead of zero (black image) is motivated by the BlurIG approach - explained here: https://arxiv.org/abs/2004.03383 + explained here: + + https://arxiv.org/abs/2004.03383 """ step_count: int = proto.Field( @@ -510,7 +512,9 @@ class XraiAttribution(proto.Message): r"""An explanation method that redistributes Integrated Gradients attributions to segmented regions, taking advantage of the model's fully differentiable structure. Refer to this paper for - more details: https://arxiv.org/abs/1906.02825 + more details: + + https://arxiv.org/abs/1906.02825 Supported only by image Models. @@ -537,7 +541,9 @@ class XraiAttribution(proto.Message): blurred image to the input image is created. Using a blurred baseline instead of zero (black image) is motivated by the BlurIG approach - explained here: https://arxiv.org/abs/2004.03383 + explained here: + + https://arxiv.org/abs/2004.03383 """ step_count: int = proto.Field( @@ -562,6 +568,7 @@ class SmoothGradConfig(proto.Message): gradients from noisy samples in the vicinity of the inputs. Adding noise can help improve the computed gradients. Refer to this paper for more details: + https://arxiv.org/pdf/1706.03825.pdf This message has `oneof`_ fields (mutually exclusive fields). @@ -675,6 +682,7 @@ class BlurBaselineConfig(proto.Message): the input image is created. 
Using a blurred baseline instead of zero (black image) is motivated by the BlurIG approach explained here: + https://arxiv.org/abs/2004.03383 Attributes: diff --git a/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py b/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py index 7ef0ffc263..d2b23355c0 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py +++ b/google/cloud/aiplatform_v1beta1/types/model_deployment_monitoring_job.py @@ -147,6 +147,7 @@ class ModelDeploymentMonitoringJob(proto.Message): the job under customer project. Customer could do their own query & analysis. There could be 4 log tables in maximum: + 1. Training data logging predict request/response 2. Serving data logging predict request/response diff --git a/google/cloud/aiplatform_v1beta1/types/model_monitoring.py b/google/cloud/aiplatform_v1beta1/types/model_monitoring.py index 323cfb1235..48a85eb95b 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_monitoring.py +++ b/google/cloud/aiplatform_v1beta1/types/model_monitoring.py @@ -445,9 +445,10 @@ class ThresholdConfig(proto.Message): value (float): Specify a threshold value that can trigger the alert. If this threshold config is for - feature distribution distance: 1. For - categorical feature, the distribution distance - is calculated by L-inifinity norm. + feature distribution distance: + + 1. For categorical feature, the distribution + distance is calculated by L-inifinity norm. 2. For numerical feature, the distribution distance is calculated by Jensen–Shannon divergence. 
diff --git a/google/cloud/aiplatform_v1beta1/types/persistent_resource.py b/google/cloud/aiplatform_v1beta1/types/persistent_resource.py index a0b5c0e085..7f7ba49e2d 100644 --- a/google/cloud/aiplatform_v1beta1/types/persistent_resource.py +++ b/google/cloud/aiplatform_v1beta1/types/persistent_resource.py @@ -31,6 +31,8 @@ "PersistentResource", "ResourcePool", "ResourceRuntimeSpec", + "RaySpec", + "ResourceRuntime", "ServiceAccountSpec", }, ) @@ -44,7 +46,7 @@ class PersistentResource(proto.Message): Attributes: name (str): - Output only. Resource name of a + Immutable. Resource name of a PersistentResource. display_name (str): Optional. The display name of the @@ -104,6 +106,9 @@ class PersistentResource(proto.Message): resource_runtime_spec (google.cloud.aiplatform_v1beta1.types.ResourceRuntimeSpec): Optional. Persistent Resource runtime spec. Used for e.g. Ray cluster configuration. + resource_runtime (google.cloud.aiplatform_v1beta1.types.ResourceRuntime): + Output only. Runtime information of the + Persistent Resource. reserved_ip_ranges (MutableSequence[str]): Optional. A list of names for the reserved ip ranges under the VPC network that can be used for this persistent @@ -198,6 +203,11 @@ class State(proto.Enum): number=13, message="ResourceRuntimeSpec", ) + resource_runtime: "ResourceRuntime" = proto.Field( + proto.MESSAGE, + number=14, + message="ResourceRuntime", + ) reserved_ip_ranges: MutableSequence[str] = proto.RepeatedField( proto.STRING, number=15, @@ -317,6 +327,10 @@ class ResourceRuntimeSpec(proto.Message): service_account_spec (google.cloud.aiplatform_v1beta1.types.ServiceAccountSpec): Optional. Configure the use of workload identity on the PersistentResource + ray_spec (google.cloud.aiplatform_v1beta1.types.RaySpec): + Ray cluster configuration. + Required when creating a dedicated RayCluster on + the PersistentResource. 
""" service_account_spec: "ServiceAccountSpec" = proto.Field( @@ -324,6 +338,52 @@ class ResourceRuntimeSpec(proto.Message): number=2, message="ServiceAccountSpec", ) + ray_spec: "RaySpec" = proto.Field( + proto.MESSAGE, + number=1, + message="RaySpec", + ) + + +class RaySpec(proto.Message): + r"""Configuration information for the Ray cluster. + For experimental launch, Ray cluster creation and Persistent + cluster creation are 1:1 mapping: We will provision all the + nodes within the Persistent cluster as Ray nodes. + + Attributes: + image_uri (str): + Optional. Default image for user to choose a preferred ML + framework(e.g. tensorflow or Pytorch) by choosing from + Vertex prebuild + images(https://cloud.google.com/vertex-ai/docs/training/pre-built-containers). + Either this or the resource_pool_images is required. Use + this field if you need all the resource pools to have the + same Ray image, Otherwise, use the {@code + resource_pool_images} field. + """ + + image_uri: str = proto.Field( + proto.STRING, + number=1, + ) + + +class ResourceRuntime(proto.Message): + r"""Persistent Cluster runtime information as output + + Attributes: + access_uris (MutableMapping[str, str]): + Output only. URIs for user to connect to the Cluster. 
+ Example: { "RAY_HEAD_NODE_INTERNAL_IP": "head-node-IP:10001" + "RAY_DASHBOARD_URI": "ray-dashboard-address:8888" } + """ + + access_uris: MutableMapping[str, str] = proto.MapField( + proto.STRING, + proto.STRING, + number=1, + ) class ServiceAccountSpec(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/types/prediction_service.py b/google/cloud/aiplatform_v1beta1/types/prediction_service.py index 1365ccc32a..ee907296da 100644 --- a/google/cloud/aiplatform_v1beta1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1beta1/types/prediction_service.py @@ -21,6 +21,7 @@ from google.api import httpbody_pb2 # type: ignore from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import types from google.protobuf import struct_pb2 # type: ignore @@ -30,6 +31,8 @@ "PredictRequest", "PredictResponse", "RawPredictRequest", + "StreamingPredictRequest", + "StreamingPredictResponse", "ExplainRequest", "ExplainResponse", }, @@ -110,6 +113,10 @@ class PredictResponse(proto.Message): name][google.cloud.aiplatform.v1beta1.Model.display_name] of the Model which is deployed as the DeployedModel that this prediction hits. + metadata (google.protobuf.struct_pb2.Value): + Output only. Request-level metadata returned + by the model. The metadata type will be + dependent upon the model implementation. """ predictions: MutableSequence[struct_pb2.Value] = proto.RepeatedField( @@ -133,6 +140,11 @@ class PredictResponse(proto.Message): proto.STRING, number=4, ) + metadata: struct_pb2.Value = proto.Field( + proto.MESSAGE, + number=6, + message=struct_pb2.Value, + ) class RawPredictRequest(proto.Message): @@ -178,6 +190,65 @@ class RawPredictRequest(proto.Message): ) +class StreamingPredictRequest(proto.Message): + r"""Request message for + [PredictionService.StreamingPredict][google.cloud.aiplatform.v1beta1.PredictionService.StreamingPredict]. 
+ + The first message must contain + [endpoint][google.cloud.aiplatform.v1beta1.StreamingPredictRequest.endpoint] + field and optionally [input][]. The subsequent messages must contain + [input][]. + + Attributes: + endpoint (str): + Required. The name of the Endpoint requested to serve the + prediction. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + inputs (MutableSequence[google.cloud.aiplatform_v1beta1.types.Tensor]): + The prediction input. + parameters (google.cloud.aiplatform_v1beta1.types.Tensor): + The parameters that govern the prediction. + """ + + endpoint: str = proto.Field( + proto.STRING, + number=1, + ) + inputs: MutableSequence[types.Tensor] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=types.Tensor, + ) + parameters: types.Tensor = proto.Field( + proto.MESSAGE, + number=3, + message=types.Tensor, + ) + + +class StreamingPredictResponse(proto.Message): + r"""Response message for + [PredictionService.StreamingPredict][google.cloud.aiplatform.v1beta1.PredictionService.StreamingPredict]. + + Attributes: + outputs (MutableSequence[google.cloud.aiplatform_v1beta1.types.Tensor]): + The prediction output. + parameters (google.cloud.aiplatform_v1beta1.types.Tensor): + The parameters that govern the prediction. + """ + + outputs: MutableSequence[types.Tensor] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=types.Tensor, + ) + parameters: types.Tensor = proto.Field( + proto.MESSAGE, + number=2, + message=types.Tensor, + ) + + class ExplainRequest(proto.Message): r"""Request message for [PredictionService.Explain][google.cloud.aiplatform.v1beta1.PredictionService.Explain]. 
diff --git a/google/cloud/aiplatform_v1beta1/types/types.py b/google/cloud/aiplatform_v1beta1/types/types.py index 95fa72137b..5f64363797 100644 --- a/google/cloud/aiplatform_v1beta1/types/types.py +++ b/google/cloud/aiplatform_v1beta1/types/types.py @@ -27,6 +27,7 @@ "DoubleArray", "Int64Array", "StringArray", + "Tensor", }, ) @@ -87,4 +88,156 @@ class StringArray(proto.Message): ) +class Tensor(proto.Message): + r"""A tensor value type. + + Attributes: + dtype (google.cloud.aiplatform_v1beta1.types.Tensor.DataType): + The data type of tensor. + shape (MutableSequence[int]): + Shape of the tensor. + bool_val (MutableSequence[bool]): + Type specific representations that make it easy to create + tensor protos in all languages. Only the representation + corresponding to "dtype" can be set. The values hold the + flattened representation of the tensor in row major order. + + [BOOL][google.aiplatform.master.Tensor.DataType.BOOL] + string_val (MutableSequence[str]): + [STRING][google.aiplatform.master.Tensor.DataType.STRING] + bytes_val (MutableSequence[bytes]): + [STRING][google.aiplatform.master.Tensor.DataType.STRING] + float_val (MutableSequence[float]): + [FLOAT][google.aiplatform.master.Tensor.DataType.FLOAT] + double_val (MutableSequence[float]): + [DOUBLE][google.aiplatform.master.Tensor.DataType.DOUBLE] + int_val (MutableSequence[int]): + [INT_8][google.aiplatform.master.Tensor.DataType.INT8] + [INT_16][google.aiplatform.master.Tensor.DataType.INT16] + [INT_32][google.aiplatform.master.Tensor.DataType.INT32] + int64_val (MutableSequence[int]): + [INT64][google.aiplatform.master.Tensor.DataType.INT64] + uint_val (MutableSequence[int]): + [UINT8][google.aiplatform.master.Tensor.DataType.UINT8] + [UINT16][google.aiplatform.master.Tensor.DataType.UINT16] + [UINT32][google.aiplatform.master.Tensor.DataType.UINT32] + uint64_val (MutableSequence[int]): + [UINT64][google.aiplatform.master.Tensor.DataType.UINT64] + list_val 
(MutableSequence[google.cloud.aiplatform_v1beta1.types.Tensor]): + A list of tensor values. + struct_val (MutableMapping[str, google.cloud.aiplatform_v1beta1.types.Tensor]): + A map of string to tensor. + tensor_val (bytes): + Serialized raw tensor content. + """ + + class DataType(proto.Enum): + r"""Data type of the tensor. + + Values: + DATA_TYPE_UNSPECIFIED (0): + Not a legal value for DataType. Used to + indicate a DataType field has not been set. + BOOL (1): + Data types that all computation devices are + expected to be capable to support. + STRING (2): + No description available. + FLOAT (3): + No description available. + DOUBLE (4): + No description available. + INT8 (5): + No description available. + INT16 (6): + No description available. + INT32 (7): + No description available. + INT64 (8): + No description available. + UINT8 (9): + No description available. + UINT16 (10): + No description available. + UINT32 (11): + No description available. + UINT64 (12): + No description available. 
+ """ + DATA_TYPE_UNSPECIFIED = 0 + BOOL = 1 + STRING = 2 + FLOAT = 3 + DOUBLE = 4 + INT8 = 5 + INT16 = 6 + INT32 = 7 + INT64 = 8 + UINT8 = 9 + UINT16 = 10 + UINT32 = 11 + UINT64 = 12 + + dtype: DataType = proto.Field( + proto.ENUM, + number=1, + enum=DataType, + ) + shape: MutableSequence[int] = proto.RepeatedField( + proto.INT64, + number=2, + ) + bool_val: MutableSequence[bool] = proto.RepeatedField( + proto.BOOL, + number=3, + ) + string_val: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=14, + ) + bytes_val: MutableSequence[bytes] = proto.RepeatedField( + proto.BYTES, + number=15, + ) + float_val: MutableSequence[float] = proto.RepeatedField( + proto.FLOAT, + number=5, + ) + double_val: MutableSequence[float] = proto.RepeatedField( + proto.DOUBLE, + number=6, + ) + int_val: MutableSequence[int] = proto.RepeatedField( + proto.INT32, + number=7, + ) + int64_val: MutableSequence[int] = proto.RepeatedField( + proto.INT64, + number=8, + ) + uint_val: MutableSequence[int] = proto.RepeatedField( + proto.UINT32, + number=9, + ) + uint64_val: MutableSequence[int] = proto.RepeatedField( + proto.UINT64, + number=10, + ) + list_val: MutableSequence["Tensor"] = proto.RepeatedField( + proto.MESSAGE, + number=11, + message="Tensor", + ) + struct_val: MutableMapping[str, "Tensor"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=12, + message="Tensor", + ) + tensor_val: bytes = proto.Field( + proto.BYTES, + number=13, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/samples/generated_samples/aiplatform_v1_generated_prediction_service_server_streaming_predict_async.py b/samples/generated_samples/aiplatform_v1_generated_prediction_service_server_streaming_predict_async.py new file mode 100644 index 0000000000..b67602b865 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_prediction_service_server_streaming_predict_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ServerStreamingPredict +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_PredictionService_ServerStreamingPredict_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +async def sample_server_streaming_predict(): + # Create a client + client = aiplatform_v1.PredictionServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.StreamingPredictRequest( + endpoint="endpoint_value", + ) + + # Make the request + stream = await client.server_streaming_predict(request=request) + + # Handle the response + async for response in stream: + print(response) + +# [END aiplatform_v1_generated_PredictionService_ServerStreamingPredict_async] diff --git a/samples/generated_samples/aiplatform_v1_generated_prediction_service_server_streaming_predict_sync.py b/samples/generated_samples/aiplatform_v1_generated_prediction_service_server_streaming_predict_sync.py new file mode 100644 index 0000000000..51d6510487 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_prediction_service_server_streaming_predict_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ServerStreamingPredict +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. 
+ +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_PredictionService_ServerStreamingPredict_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +def sample_server_streaming_predict(): + # Create a client + client = aiplatform_v1.PredictionServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.StreamingPredictRequest( + endpoint="endpoint_value", + ) + + # Make the request + stream = client.server_streaming_predict(request=request) + + # Handle the response + for response in stream: + print(response) + +# [END aiplatform_v1_generated_PredictionService_ServerStreamingPredict_sync] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_create_schedule_async.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_create_schedule_async.py new file mode 100644 index 0000000000..302dc65065 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_create_schedule_async.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateSchedule +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_CreateSchedule_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +async def sample_create_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + schedule = aiplatform_v1.Schedule() + schedule.cron = "cron_value" + schedule.create_pipeline_job_request.parent = "parent_value" + schedule.display_name = "display_name_value" + schedule.max_concurrent_run_count = 2596 + + request = aiplatform_v1.CreateScheduleRequest( + parent="parent_value", + schedule=schedule, + ) + + # Make the request + response = await client.create_schedule(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_ScheduleService_CreateSchedule_async] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_create_schedule_sync.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_create_schedule_sync.py new file mode 100644 index 0000000000..c3a247e58a --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_create_schedule_sync.py @@ -0,0 
+1,59 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateSchedule +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_CreateSchedule_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +def sample_create_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + schedule = aiplatform_v1.Schedule() + schedule.cron = "cron_value" + schedule.create_pipeline_job_request.parent = "parent_value" + schedule.display_name = "display_name_value" + schedule.max_concurrent_run_count = 2596 + + request = aiplatform_v1.CreateScheduleRequest( + parent="parent_value", + schedule=schedule, + ) + + # Make the request + response = client.create_schedule(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_ScheduleService_CreateSchedule_sync] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_delete_schedule_async.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_delete_schedule_async.py new file mode 100644 index 0000000000..c25a6b3c5c --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_delete_schedule_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! 
+# +# Snippet for DeleteSchedule +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_DeleteSchedule_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +async def sample_delete_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.DeleteScheduleRequest( + name="name_value", + ) + + # Make the request + operation = client.delete_schedule(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_ScheduleService_DeleteSchedule_async] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_delete_schedule_sync.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_delete_schedule_sync.py new file mode 100644 index 0000000000..04b4725181 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_delete_schedule_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteSchedule +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_DeleteSchedule_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +def sample_delete_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.DeleteScheduleRequest( + name="name_value", + ) + + # Make the request + operation = client.delete_schedule(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_ScheduleService_DeleteSchedule_sync] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_get_schedule_async.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_get_schedule_async.py new file mode 100644 index 0000000000..52af15e3f1 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_get_schedule_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetSchedule +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. 
+ +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_GetSchedule_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +async def sample_get_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.GetScheduleRequest( + name="name_value", + ) + + # Make the request + response = await client.get_schedule(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_ScheduleService_GetSchedule_async] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_get_schedule_sync.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_get_schedule_sync.py new file mode 100644 index 0000000000..7befbd5bcd --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_get_schedule_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetSchedule +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_GetSchedule_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +def sample_get_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.GetScheduleRequest( + name="name_value", + ) + + # Make the request + response = client.get_schedule(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_ScheduleService_GetSchedule_sync] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_list_schedules_async.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_list_schedules_async.py new file mode 100644 index 0000000000..907158ddc2 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_list_schedules_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListSchedules +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_ListSchedules_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +async def sample_list_schedules(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.ListSchedulesRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_schedules(request=request) + + # Handle the response + async for response in page_result: + print(response) + +# [END aiplatform_v1_generated_ScheduleService_ListSchedules_async] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_list_schedules_sync.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_list_schedules_sync.py new file mode 100644 index 0000000000..ff397d54d2 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_list_schedules_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListSchedules +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. 
+ +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_ListSchedules_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +def sample_list_schedules(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.ListSchedulesRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_schedules(request=request) + + # Handle the response + for response in page_result: + print(response) + +# [END aiplatform_v1_generated_ScheduleService_ListSchedules_sync] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_pause_schedule_async.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_pause_schedule_async.py new file mode 100644 index 0000000000..68fcf5b8c4 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_pause_schedule_async.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for PauseSchedule +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_PauseSchedule_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +async def sample_pause_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.PauseScheduleRequest( + name="name_value", + ) + + # Make the request + await client.pause_schedule(request=request) + + +# [END aiplatform_v1_generated_ScheduleService_PauseSchedule_async] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_pause_schedule_sync.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_pause_schedule_sync.py new file mode 100644 index 0000000000..7ac0981a59 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_pause_schedule_sync.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for PauseSchedule +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_PauseSchedule_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +def sample_pause_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.PauseScheduleRequest( + name="name_value", + ) + + # Make the request + client.pause_schedule(request=request) + + +# [END aiplatform_v1_generated_ScheduleService_PauseSchedule_sync] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_resume_schedule_async.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_resume_schedule_async.py new file mode 100644 index 0000000000..9df03c71f9 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_resume_schedule_async.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ResumeSchedule +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. 
+ +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_ResumeSchedule_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +async def sample_resume_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.ResumeScheduleRequest( + name="name_value", + ) + + # Make the request + await client.resume_schedule(request=request) + + +# [END aiplatform_v1_generated_ScheduleService_ResumeSchedule_async] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_resume_schedule_sync.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_resume_schedule_sync.py new file mode 100644 index 0000000000..894959fcb1 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_resume_schedule_sync.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Generated code. DO NOT EDIT! +# +# Snippet for ResumeSchedule +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_ResumeSchedule_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +def sample_resume_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.ResumeScheduleRequest( + name="name_value", + ) + + # Make the request + client.resume_schedule(request=request) + + +# [END aiplatform_v1_generated_ScheduleService_ResumeSchedule_sync] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_update_schedule_async.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_update_schedule_async.py new file mode 100644 index 0000000000..113435f7e1 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_update_schedule_async.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateSchedule +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_UpdateSchedule_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +async def sample_update_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceAsyncClient() + + # Initialize request argument(s) + schedule = aiplatform_v1.Schedule() + schedule.cron = "cron_value" + schedule.create_pipeline_job_request.parent = "parent_value" + schedule.display_name = "display_name_value" + schedule.max_concurrent_run_count = 2596 + + request = aiplatform_v1.UpdateScheduleRequest( + schedule=schedule, + ) + + # Make the request + response = await client.update_schedule(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_ScheduleService_UpdateSchedule_async] diff --git a/samples/generated_samples/aiplatform_v1_generated_schedule_service_update_schedule_sync.py b/samples/generated_samples/aiplatform_v1_generated_schedule_service_update_schedule_sync.py new file mode 100644 index 0000000000..d61e95102b --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_schedule_service_update_schedule_sync.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! 
+# +# Snippet for UpdateSchedule +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_ScheduleService_UpdateSchedule_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +def sample_update_schedule(): + # Create a client + client = aiplatform_v1.ScheduleServiceClient() + + # Initialize request argument(s) + schedule = aiplatform_v1.Schedule() + schedule.cron = "cron_value" + schedule.create_pipeline_job_request.parent = "parent_value" + schedule.display_name = "display_name_value" + schedule.max_concurrent_run_count = 2596 + + request = aiplatform_v1.UpdateScheduleRequest( + schedule=schedule, + ) + + # Make the request + response = client.update_schedule(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_ScheduleService_UpdateSchedule_sync] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_server_streaming_predict_async.py b/samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_server_streaming_predict_async.py new file mode 100644 index 0000000000..648386531c --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_server_streaming_predict_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# 
you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ServerStreamingPredict +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_PredictionService_ServerStreamingPredict_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +async def sample_server_streaming_predict(): + # Create a client + client = aiplatform_v1beta1.PredictionServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.StreamingPredictRequest( + endpoint="endpoint_value", + ) + + # Make the request + stream = await client.server_streaming_predict(request=request) + + # Handle the response + async for response in stream: + print(response) + +# [END aiplatform_v1beta1_generated_PredictionService_ServerStreamingPredict_async] diff --git a/samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_server_streaming_predict_sync.py b/samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_server_streaming_predict_sync.py new file mode 100644 index 0000000000..f712d4d0c9 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1beta1_generated_prediction_service_server_streaming_predict_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ServerStreamingPredict +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. 
+ +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1beta1_generated_PredictionService_ServerStreamingPredict_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1beta1 + + +def sample_server_streaming_predict(): + # Create a client + client = aiplatform_v1beta1.PredictionServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.StreamingPredictRequest( + endpoint="endpoint_value", + ) + + # Make the request + stream = client.server_streaming_predict(request=request) + + # Handle the response + for response in stream: + print(response) + +# [END aiplatform_v1beta1_generated_PredictionService_ServerStreamingPredict_sync] diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index 78aaf3e5fd..89ca575886 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -26784,6 +26784,1298 @@ ], "title": "aiplatform_v1_generated_prediction_service_raw_predict_sync.py" }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1.PredictionServiceAsyncClient", + "shortName": "PredictionServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1.PredictionServiceAsyncClient.server_streaming_predict", + "method": { + "fullName": 
"google.cloud.aiplatform.v1.PredictionService.ServerStreamingPredict", + "service": { + "fullName": "google.cloud.aiplatform.v1.PredictionService", + "shortName": "PredictionService" + }, + "shortName": "ServerStreamingPredict" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.StreamingPredictRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "Iterable[google.cloud.aiplatform_v1.types.StreamingPredictResponse]", + "shortName": "server_streaming_predict" + }, + "description": "Sample for ServerStreamingPredict", + "file": "aiplatform_v1_generated_prediction_service_server_streaming_predict_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_PredictionService_ServerStreamingPredict_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_prediction_service_server_streaming_predict_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1.PredictionServiceClient", + "shortName": "PredictionServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1.PredictionServiceClient.server_streaming_predict", + "method": { + "fullName": "google.cloud.aiplatform.v1.PredictionService.ServerStreamingPredict", + "service": { + "fullName": "google.cloud.aiplatform.v1.PredictionService", + "shortName": "PredictionService" + }, + "shortName": "ServerStreamingPredict" + }, + 
"parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.StreamingPredictRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "Iterable[google.cloud.aiplatform_v1.types.StreamingPredictResponse]", + "shortName": "server_streaming_predict" + }, + "description": "Sample for ServerStreamingPredict", + "file": "aiplatform_v1_generated_prediction_service_server_streaming_predict_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_PredictionService_ServerStreamingPredict_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_prediction_service_server_streaming_predict_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient", + "shortName": "ScheduleServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient.create_schedule", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.CreateSchedule", + "service": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "CreateSchedule" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.CreateScheduleRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "schedule", + "type": "google.cloud.aiplatform_v1.types.Schedule" + }, + { + 
"name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1.types.Schedule", + "shortName": "create_schedule" + }, + "description": "Sample for CreateSchedule", + "file": "aiplatform_v1_generated_schedule_service_create_schedule_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_ScheduleService_CreateSchedule_async", + "segments": [ + { + "end": 58, + "start": 27, + "type": "FULL" + }, + { + "end": 58, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 52, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 55, + "start": 53, + "type": "REQUEST_EXECUTION" + }, + { + "end": 59, + "start": 56, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_create_schedule_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient", + "shortName": "ScheduleServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient.create_schedule", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.CreateSchedule", + "service": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "CreateSchedule" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.CreateScheduleRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "schedule", + "type": "google.cloud.aiplatform_v1.types.Schedule" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": 
"google.cloud.aiplatform_v1.types.Schedule", + "shortName": "create_schedule" + }, + "description": "Sample for CreateSchedule", + "file": "aiplatform_v1_generated_schedule_service_create_schedule_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_ScheduleService_CreateSchedule_sync", + "segments": [ + { + "end": 58, + "start": 27, + "type": "FULL" + }, + { + "end": 58, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 52, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 55, + "start": 53, + "type": "REQUEST_EXECUTION" + }, + { + "end": 59, + "start": 56, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_create_schedule_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient", + "shortName": "ScheduleServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient.delete_schedule", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.DeleteSchedule", + "service": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "DeleteSchedule" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.DeleteScheduleRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.api_core.operation_async.AsyncOperation", + "shortName": "delete_schedule" + }, + "description": "Sample for DeleteSchedule", + "file": "aiplatform_v1_generated_schedule_service_delete_schedule_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": 
"aiplatform_v1_generated_ScheduleService_DeleteSchedule_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_delete_schedule_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient", + "shortName": "ScheduleServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient.delete_schedule", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.DeleteSchedule", + "service": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "DeleteSchedule" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.DeleteScheduleRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.api_core.operation.Operation", + "shortName": "delete_schedule" + }, + "description": "Sample for DeleteSchedule", + "file": "aiplatform_v1_generated_schedule_service_delete_schedule_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_ScheduleService_DeleteSchedule_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": 
"REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_delete_schedule_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient", + "shortName": "ScheduleServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient.get_schedule", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.GetSchedule", + "service": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "GetSchedule" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.GetScheduleRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1.types.Schedule", + "shortName": "get_schedule" + }, + "description": "Sample for GetSchedule", + "file": "aiplatform_v1_generated_schedule_service_get_schedule_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_ScheduleService_GetSchedule_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_get_schedule_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + 
"fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient", + "shortName": "ScheduleServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient.get_schedule", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.GetSchedule", + "service": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "GetSchedule" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.GetScheduleRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1.types.Schedule", + "shortName": "get_schedule" + }, + "description": "Sample for GetSchedule", + "file": "aiplatform_v1_generated_schedule_service_get_schedule_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_ScheduleService_GetSchedule_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_get_schedule_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient", + "shortName": "ScheduleServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient.list_schedules", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.ListSchedules", + "service": { + "fullName": 
"google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "ListSchedules" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.ListSchedulesRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1.services.schedule_service.pagers.ListSchedulesAsyncPager", + "shortName": "list_schedules" + }, + "description": "Sample for ListSchedules", + "file": "aiplatform_v1_generated_schedule_service_list_schedules_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_ScheduleService_ListSchedules_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_list_schedules_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient", + "shortName": "ScheduleServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient.list_schedules", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.ListSchedules", + "service": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "ListSchedules" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.ListSchedulesRequest" + }, + { + "name": "parent", + "type": "str" + }, + 
{ + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1.services.schedule_service.pagers.ListSchedulesPager", + "shortName": "list_schedules" + }, + "description": "Sample for ListSchedules", + "file": "aiplatform_v1_generated_schedule_service_list_schedules_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_ScheduleService_ListSchedules_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_list_schedules_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient", + "shortName": "ScheduleServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient.pause_schedule", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.PauseSchedule", + "service": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "PauseSchedule" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.PauseScheduleRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "shortName": "pause_schedule" + }, + "description": "Sample for 
PauseSchedule", + "file": "aiplatform_v1_generated_schedule_service_pause_schedule_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_ScheduleService_PauseSchedule_async", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_pause_schedule_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient", + "shortName": "ScheduleServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient.pause_schedule", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.PauseSchedule", + "service": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "PauseSchedule" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.PauseScheduleRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "shortName": "pause_schedule" + }, + "description": "Sample for PauseSchedule", + "file": "aiplatform_v1_generated_schedule_service_pause_schedule_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_ScheduleService_PauseSchedule_sync", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" 
+ }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_pause_schedule_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient", + "shortName": "ScheduleServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient.resume_schedule", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.ResumeSchedule", + "service": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "ResumeSchedule" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.ResumeScheduleRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "catch_up", + "type": "bool" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "shortName": "resume_schedule" + }, + "description": "Sample for ResumeSchedule", + "file": "aiplatform_v1_generated_schedule_service_resume_schedule_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_ScheduleService_ResumeSchedule_async", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_resume_schedule_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": 
{ + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient", + "shortName": "ScheduleServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient.resume_schedule", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.ResumeSchedule", + "service": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "ResumeSchedule" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.ResumeScheduleRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "catch_up", + "type": "bool" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "shortName": "resume_schedule" + }, + "description": "Sample for ResumeSchedule", + "file": "aiplatform_v1_generated_schedule_service_resume_schedule_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_ScheduleService_ResumeSchedule_sync", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_resume_schedule_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient", + "shortName": "ScheduleServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceAsyncClient.update_schedule", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.UpdateSchedule", + "service": { + "fullName": 
"google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "UpdateSchedule" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.UpdateScheduleRequest" + }, + { + "name": "schedule", + "type": "google.cloud.aiplatform_v1.types.Schedule" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1.types.Schedule", + "shortName": "update_schedule" + }, + "description": "Sample for UpdateSchedule", + "file": "aiplatform_v1_generated_schedule_service_update_schedule_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_ScheduleService_UpdateSchedule_async", + "segments": [ + { + "end": 57, + "start": 27, + "type": "FULL" + }, + { + "end": 57, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 51, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 54, + "start": 52, + "type": "REQUEST_EXECUTION" + }, + { + "end": 58, + "start": 55, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_update_schedule_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient", + "shortName": "ScheduleServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1.ScheduleServiceClient.update_schedule", + "method": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService.UpdateSchedule", + "service": { + "fullName": "google.cloud.aiplatform.v1.ScheduleService", + "shortName": "ScheduleService" + }, + "shortName": "UpdateSchedule" + }, + "parameters": [ + { + "name": "request", + "type": 
"google.cloud.aiplatform_v1.types.UpdateScheduleRequest" + }, + { + "name": "schedule", + "type": "google.cloud.aiplatform_v1.types.Schedule" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1.types.Schedule", + "shortName": "update_schedule" + }, + "description": "Sample for UpdateSchedule", + "file": "aiplatform_v1_generated_schedule_service_update_schedule_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_ScheduleService_UpdateSchedule_sync", + "segments": [ + { + "end": 57, + "start": 27, + "type": "FULL" + }, + { + "end": 57, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 51, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 54, + "start": 52, + "type": "REQUEST_EXECUTION" + }, + { + "end": 58, + "start": 55, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_schedule_service_update_schedule_sync.py" + }, { "canonical": true, "clientMethod": { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index 05a52b9c3b..fa25d9eb2c 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -28265,6 +28265,159 @@ ], "title": "aiplatform_v1beta1_generated_prediction_service_raw_predict_sync.py" }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.PredictionServiceAsyncClient", + "shortName": "PredictionServiceAsyncClient" 
+ }, + "fullName": "google.cloud.aiplatform_v1beta1.PredictionServiceAsyncClient.server_streaming_predict", + "method": { + "fullName": "google.cloud.aiplatform.v1beta1.PredictionService.ServerStreamingPredict", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.PredictionService", + "shortName": "PredictionService" + }, + "shortName": "ServerStreamingPredict" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.StreamingPredictRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "Iterable[google.cloud.aiplatform_v1beta1.types.StreamingPredictResponse]", + "shortName": "server_streaming_predict" + }, + "description": "Sample for ServerStreamingPredict", + "file": "aiplatform_v1beta1_generated_prediction_service_server_streaming_predict_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_PredictionService_ServerStreamingPredict_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_prediction_service_server_streaming_predict_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1beta1.PredictionServiceClient", + "shortName": "PredictionServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1beta1.PredictionServiceClient.server_streaming_predict", + "method": { + "fullName": 
"google.cloud.aiplatform.v1beta1.PredictionService.ServerStreamingPredict", + "service": { + "fullName": "google.cloud.aiplatform.v1beta1.PredictionService", + "shortName": "PredictionService" + }, + "shortName": "ServerStreamingPredict" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1beta1.types.StreamingPredictRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "Iterable[google.cloud.aiplatform_v1beta1.types.StreamingPredictResponse]", + "shortName": "server_streaming_predict" + }, + "description": "Sample for ServerStreamingPredict", + "file": "aiplatform_v1beta1_generated_prediction_service_server_streaming_predict_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1beta1_generated_PredictionService_ServerStreamingPredict_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1beta1_generated_prediction_service_server_streaming_predict_sync.py" + }, { "canonical": true, "clientMethod": { diff --git a/tests/unit/gapic/aiplatform_v1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1/test_prediction_service.py index cc6f7dabec..a0f9d5944e 100644 --- a/tests/unit/gapic/aiplatform_v1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_prediction_service.py @@ -48,6 +48,7 @@ from google.cloud.aiplatform_v1.types import explanation from google.cloud.aiplatform_v1.types import io from google.cloud.aiplatform_v1.types import 
prediction_service +from google.cloud.aiplatform_v1.types import types from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import options_pb2 # type: ignore @@ -1149,6 +1150,165 @@ async def test_raw_predict_flattened_error_async(): ) +@pytest.mark.parametrize( + "request_type", + [ + prediction_service.StreamingPredictRequest, + dict, + ], +) +def test_server_streaming_predict(request_type, transport: str = "grpc"): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.server_streaming_predict), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = iter([prediction_service.StreamingPredictResponse()]) + response = client.server_streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == prediction_service.StreamingPredictRequest() + + # Establish that the response is the type that we expect. + for message in response: + assert isinstance(message, prediction_service.StreamingPredictResponse) + + +def test_server_streaming_predict_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.server_streaming_predict), "__call__" + ) as call: + client.server_streaming_predict() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == prediction_service.StreamingPredictRequest() + + +@pytest.mark.asyncio +async def test_server_streaming_predict_async( + transport: str = "grpc_asyncio", + request_type=prediction_service.StreamingPredictRequest, +): + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.server_streaming_predict), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[prediction_service.StreamingPredictResponse()] + ) + response = await client.server_streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == prediction_service.StreamingPredictRequest() + + # Establish that the response is the type that we expect. + message = await response.read() + assert isinstance(message, prediction_service.StreamingPredictResponse) + + +@pytest.mark.asyncio +async def test_server_streaming_predict_async_from_dict(): + await test_server_streaming_predict_async(request_type=dict) + + +def test_server_streaming_predict_field_headers(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = prediction_service.StreamingPredictRequest() + + request.endpoint = "endpoint_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.server_streaming_predict), "__call__" + ) as call: + call.return_value = iter([prediction_service.StreamingPredictResponse()]) + client.server_streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "endpoint=endpoint_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_server_streaming_predict_field_headers_async(): + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = prediction_service.StreamingPredictRequest() + + request.endpoint = "endpoint_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.server_streaming_predict), "__call__" + ) as call: + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[prediction_service.StreamingPredictResponse()] + ) + await client.server_streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "endpoint=endpoint_value", + ) in kw["metadata"] + + @pytest.mark.parametrize( "request_type", [ @@ -1473,6 +1633,7 @@ def test_prediction_service_base_transport(): methods = ( "predict", "raw_predict", + "server_streaming_predict", "explain", "set_iam_policy", "get_iam_policy", diff --git a/tests/unit/gapic/aiplatform_v1/test_schedule_service.py b/tests/unit/gapic/aiplatform_v1/test_schedule_service.py new file mode 100644 index 0000000000..0af4dcff66 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1/test_schedule_service.py @@ -0,0 +1,5067 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule +from proto.marshal.rules import wrappers + +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.api_core import path_template +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1.services.schedule_service import ( + ScheduleServiceAsyncClient, +) +from google.cloud.aiplatform_v1.services.schedule_service import ScheduleServiceClient +from google.cloud.aiplatform_v1.services.schedule_service import pagers +from google.cloud.aiplatform_v1.services.schedule_service import transports +from google.cloud.aiplatform_v1.types import artifact +from google.cloud.aiplatform_v1.types import context +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import execution +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import pipeline_failure_policy +from google.cloud.aiplatform_v1.types import pipeline_job +from google.cloud.aiplatform_v1.types import pipeline_service +from google.cloud.aiplatform_v1.types import pipeline_state +from google.cloud.aiplatform_v1.types import schedule +from google.cloud.aiplatform_v1.types import schedule as gca_schedule +from 
google.cloud.aiplatform_v1.types import schedule_service +from google.cloud.aiplatform_v1.types import value +from google.cloud.location import locations_pb2 +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import options_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import any_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import struct_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from google.rpc import status_pb2 # type: ignore +import google.auth + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
+def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert ScheduleServiceClient._get_default_mtls_endpoint(None) is None + assert ( + ScheduleServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + ScheduleServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + ScheduleServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + ScheduleServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + ScheduleServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (ScheduleServiceClient, "grpc"), + (ScheduleServiceAsyncClient, "grpc_asyncio"), + ], +) +def test_schedule_service_client_from_service_account_info( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ("aiplatform.googleapis.com:443") + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.ScheduleServiceGrpcTransport, "grpc"), + (transports.ScheduleServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def 
test_schedule_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (ScheduleServiceClient, "grpc"), + (ScheduleServiceAsyncClient, "grpc_asyncio"), + ], +) +def test_schedule_service_client_from_service_account_file( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ("aiplatform.googleapis.com:443") + + +def test_schedule_service_client_get_transport_class(): + transport = ScheduleServiceClient.get_transport_class() + available_transports = [ + transports.ScheduleServiceGrpcTransport, + ] + assert transport in available_transports + + transport = ScheduleServiceClient.get_transport_class("grpc") + assert transport == transports.ScheduleServiceGrpcTransport + + +@pytest.mark.parametrize( + 
"client_class,transport_class,transport_name", + [ + (ScheduleServiceClient, transports.ScheduleServiceGrpcTransport, "grpc"), + ( + ScheduleServiceAsyncClient, + transports.ScheduleServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + ScheduleServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ScheduleServiceClient), +) +@mock.patch.object( + ScheduleServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ScheduleServiceAsyncClient), +) +def test_schedule_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(ScheduleServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(ScheduleServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class(transport=transport_name) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class(transport=transport_name) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + ScheduleServiceClient, + transports.ScheduleServiceGrpcTransport, + "grpc", + "true", + ), + ( + ScheduleServiceAsyncClient, + transports.ScheduleServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + ScheduleServiceClient, + transports.ScheduleServiceGrpcTransport, + "grpc", + "false", + ), + ( + ScheduleServiceAsyncClient, + transports.ScheduleServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + ScheduleServiceClient, + 
"DEFAULT_ENDPOINT", + modify_default_endpoint(ScheduleServiceClient), +) +@mock.patch.object( + ScheduleServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ScheduleServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_schedule_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class", [ScheduleServiceClient, ScheduleServiceAsyncClient] +) +@mock.patch.object( + ScheduleServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ScheduleServiceClient), +) +@mock.patch.object( + ScheduleServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ScheduleServiceAsyncClient), +) +def test_schedule_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ScheduleServiceClient, transports.ScheduleServiceGrpcTransport, "grpc"), + ( + ScheduleServiceAsyncClient, + transports.ScheduleServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_schedule_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + ScheduleServiceClient, + transports.ScheduleServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + ScheduleServiceAsyncClient, + transports.ScheduleServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_schedule_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. 
+ options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +def test_schedule_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1.services.schedule_service.transports.ScheduleServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = ScheduleServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + ScheduleServiceClient, + transports.ScheduleServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + ScheduleServiceAsyncClient, + transports.ScheduleServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_schedule_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. 
+ options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + schedule_service.CreateScheduleRequest, + dict, + ], +) +def test_create_schedule(request_type, transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+ request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_schedule.Schedule( + name="name_value", + display_name="display_name_value", + max_run_count=1410, + started_run_count=1843, + state=gca_schedule.Schedule.State.ACTIVE, + max_concurrent_run_count=2596, + allow_queueing=True, + catch_up=True, + cron="cron_value", + ) + response = client.create_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.CreateScheduleRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_schedule.Schedule) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.max_run_count == 1410 + assert response.started_run_count == 1843 + assert response.state == gca_schedule.Schedule.State.ACTIVE + assert response.max_concurrent_run_count == 2596 + assert response.allow_queueing is True + assert response.catch_up is True + + +def test_create_schedule_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.create_schedule), "__call__") as call: + client.create_schedule() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.CreateScheduleRequest() + + +@pytest.mark.asyncio +async def test_create_schedule_async( + transport: str = "grpc_asyncio", request_type=schedule_service.CreateScheduleRequest +): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_schedule.Schedule( + name="name_value", + display_name="display_name_value", + max_run_count=1410, + started_run_count=1843, + state=gca_schedule.Schedule.State.ACTIVE, + max_concurrent_run_count=2596, + allow_queueing=True, + catch_up=True, + ) + ) + response = await client.create_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.CreateScheduleRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, gca_schedule.Schedule) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.max_run_count == 1410 + assert response.started_run_count == 1843 + assert response.state == gca_schedule.Schedule.State.ACTIVE + assert response.max_concurrent_run_count == 2596 + assert response.allow_queueing is True + assert response.catch_up is True + + +@pytest.mark.asyncio +async def test_create_schedule_async_from_dict(): + await test_create_schedule_async(request_type=dict) + + +def test_create_schedule_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = schedule_service.CreateScheduleRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_schedule), "__call__") as call: + call.return_value = gca_schedule.Schedule() + client.create_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_schedule_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = schedule_service.CreateScheduleRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.create_schedule), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_schedule.Schedule() + ) + await client.create_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +def test_create_schedule_flattened(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_schedule.Schedule() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_schedule( + parent="parent_value", + schedule=gca_schedule.Schedule(cron="cron_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + arg = args[0].schedule + mock_val = gca_schedule.Schedule(cron="cron_value") + assert arg == mock_val + + +def test_create_schedule_flattened_error(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.create_schedule( + schedule_service.CreateScheduleRequest(), + parent="parent_value", + schedule=gca_schedule.Schedule(cron="cron_value"), + ) + + +@pytest.mark.asyncio +async def test_create_schedule_flattened_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_schedule.Schedule() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_schedule.Schedule() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_schedule( + parent="parent_value", + schedule=gca_schedule.Schedule(cron="cron_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + arg = args[0].schedule + mock_val = gca_schedule.Schedule(cron="cron_value") + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_create_schedule_flattened_error_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.create_schedule( + schedule_service.CreateScheduleRequest(), + parent="parent_value", + schedule=gca_schedule.Schedule(cron="cron_value"), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + schedule_service.DeleteScheduleRequest, + dict, + ], +) +def test_delete_schedule(request_type, transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.delete_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.DeleteScheduleRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_schedule_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.delete_schedule), "__call__") as call: + client.delete_schedule() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.DeleteScheduleRequest() + + +@pytest.mark.asyncio +async def test_delete_schedule_async( + transport: str = "grpc_asyncio", request_type=schedule_service.DeleteScheduleRequest +): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.delete_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.DeleteScheduleRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_schedule_async_from_dict(): + await test_delete_schedule_async(request_type=dict) + + +def test_delete_schedule_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = schedule_service.DeleteScheduleRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.delete_schedule), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.delete_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_schedule_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = schedule_service.DeleteScheduleRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_schedule), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.delete_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_delete_schedule_flattened(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_schedule), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = operations_pb2.Operation(name="operations/op") + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_schedule( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_delete_schedule_flattened_error(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_schedule( + schedule_service.DeleteScheduleRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_schedule_flattened_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_schedule( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_delete_schedule_flattened_error_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_schedule( + schedule_service.DeleteScheduleRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + schedule_service.GetScheduleRequest, + dict, + ], +) +def test_get_schedule(request_type, transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = schedule.Schedule( + name="name_value", + display_name="display_name_value", + max_run_count=1410, + started_run_count=1843, + state=schedule.Schedule.State.ACTIVE, + max_concurrent_run_count=2596, + allow_queueing=True, + catch_up=True, + cron="cron_value", + ) + response = client.get_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.GetScheduleRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, schedule.Schedule) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.max_run_count == 1410 + assert response.started_run_count == 1843 + assert response.state == schedule.Schedule.State.ACTIVE + assert response.max_concurrent_run_count == 2596 + assert response.allow_queueing is True + assert response.catch_up is True + + +def test_get_schedule_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_schedule), "__call__") as call: + client.get_schedule() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.GetScheduleRequest() + + +@pytest.mark.asyncio +async def test_get_schedule_async( + transport: str = "grpc_asyncio", request_type=schedule_service.GetScheduleRequest +): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_schedule), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + schedule.Schedule( + name="name_value", + display_name="display_name_value", + max_run_count=1410, + started_run_count=1843, + state=schedule.Schedule.State.ACTIVE, + max_concurrent_run_count=2596, + allow_queueing=True, + catch_up=True, + ) + ) + response = await client.get_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.GetScheduleRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, schedule.Schedule) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.max_run_count == 1410 + assert response.started_run_count == 1843 + assert response.state == schedule.Schedule.State.ACTIVE + assert response.max_concurrent_run_count == 2596 + assert response.allow_queueing is True + assert response.catch_up is True + + +@pytest.mark.asyncio +async def test_get_schedule_async_from_dict(): + await test_get_schedule_async(request_type=dict) + + +def test_get_schedule_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = schedule_service.GetScheduleRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_schedule), "__call__") as call: + call.return_value = schedule.Schedule() + client.get_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_schedule_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = schedule_service.GetScheduleRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_schedule), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(schedule.Schedule()) + await client.get_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_get_schedule_flattened(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = schedule.Schedule() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_schedule( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_get_schedule_flattened_error(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_schedule( + schedule_service.GetScheduleRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_schedule_flattened_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = schedule.Schedule() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(schedule.Schedule()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_schedule( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_get_schedule_flattened_error_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.get_schedule( + schedule_service.GetScheduleRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + schedule_service.ListSchedulesRequest, + dict, + ], +) +def test_list_schedules(request_type, transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_schedules), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = schedule_service.ListSchedulesResponse( + next_page_token="next_page_token_value", + ) + response = client.list_schedules(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.ListSchedulesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListSchedulesPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_schedules_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_schedules), "__call__") as call: + client.list_schedules() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.ListSchedulesRequest() + + +@pytest.mark.asyncio +async def test_list_schedules_async( + transport: str = "grpc_asyncio", request_type=schedule_service.ListSchedulesRequest +): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_schedules), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + schedule_service.ListSchedulesResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_schedules(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.ListSchedulesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListSchedulesAsyncPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_schedules_async_from_dict(): + await test_list_schedules_async(request_type=dict) + + +def test_list_schedules_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = schedule_service.ListSchedulesRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_schedules), "__call__") as call: + call.return_value = schedule_service.ListSchedulesResponse() + client.list_schedules(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_schedules_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = schedule_service.ListSchedulesRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_schedules), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + schedule_service.ListSchedulesResponse() + ) + await client.list_schedules(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +def test_list_schedules_flattened(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_schedules), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = schedule_service.ListSchedulesResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_schedules( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +def test_list_schedules_flattened_error(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_schedules( + schedule_service.ListSchedulesRequest(), + parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_schedules_flattened_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_schedules), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = schedule_service.ListSchedulesResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + schedule_service.ListSchedulesResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_schedules( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_list_schedules_flattened_error_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_schedules( + schedule_service.ListSchedulesRequest(), + parent="parent_value", + ) + + +def test_list_schedules_pager(transport_name: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials, + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_schedules), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + schedule_service.ListSchedulesResponse( + schedules=[ + schedule.Schedule(), + schedule.Schedule(), + schedule.Schedule(), + ], + next_page_token="abc", + ), + schedule_service.ListSchedulesResponse( + schedules=[], + next_page_token="def", + ), + schedule_service.ListSchedulesResponse( + schedules=[ + schedule.Schedule(), + ], + next_page_token="ghi", + ), + schedule_service.ListSchedulesResponse( + schedules=[ + schedule.Schedule(), + schedule.Schedule(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_schedules(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, schedule.Schedule) for i in results) + + +def test_list_schedules_pages(transport_name: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials, + transport=transport_name, + ) + + # Mock the actual call within the gRPC 
stub, and fake the request. + with mock.patch.object(type(client.transport.list_schedules), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + schedule_service.ListSchedulesResponse( + schedules=[ + schedule.Schedule(), + schedule.Schedule(), + schedule.Schedule(), + ], + next_page_token="abc", + ), + schedule_service.ListSchedulesResponse( + schedules=[], + next_page_token="def", + ), + schedule_service.ListSchedulesResponse( + schedules=[ + schedule.Schedule(), + ], + next_page_token="ghi", + ), + schedule_service.ListSchedulesResponse( + schedules=[ + schedule.Schedule(), + schedule.Schedule(), + ], + ), + RuntimeError, + ) + pages = list(client.list_schedules(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_schedules_async_pager(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_schedules), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + schedule_service.ListSchedulesResponse( + schedules=[ + schedule.Schedule(), + schedule.Schedule(), + schedule.Schedule(), + ], + next_page_token="abc", + ), + schedule_service.ListSchedulesResponse( + schedules=[], + next_page_token="def", + ), + schedule_service.ListSchedulesResponse( + schedules=[ + schedule.Schedule(), + ], + next_page_token="ghi", + ), + schedule_service.ListSchedulesResponse( + schedules=[ + schedule.Schedule(), + schedule.Schedule(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_schedules( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, schedule.Schedule) for i in responses) + + +@pytest.mark.asyncio +async def test_list_schedules_async_pages(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_schedules), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + schedule_service.ListSchedulesResponse( + schedules=[ + schedule.Schedule(), + schedule.Schedule(), + schedule.Schedule(), + ], + next_page_token="abc", + ), + schedule_service.ListSchedulesResponse( + schedules=[], + next_page_token="def", + ), + schedule_service.ListSchedulesResponse( + schedules=[ + schedule.Schedule(), + ], + next_page_token="ghi", + ), + schedule_service.ListSchedulesResponse( + schedules=[ + schedule.Schedule(), + schedule.Schedule(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_schedules(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + schedule_service.PauseScheduleRequest, + dict, + ], +) +def test_pause_schedule(request_type, transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.pause_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.pause_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.PauseScheduleRequest() + + # Establish that the response is the type that we expect. 
+ assert response is None + + +def test_pause_schedule_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.pause_schedule), "__call__") as call: + client.pause_schedule() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.PauseScheduleRequest() + + +@pytest.mark.asyncio +async def test_pause_schedule_async( + transport: str = "grpc_asyncio", request_type=schedule_service.PauseScheduleRequest +): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.pause_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.pause_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.PauseScheduleRequest() + + # Establish that the response is the type that we expect. 
+ assert response is None + + +@pytest.mark.asyncio +async def test_pause_schedule_async_from_dict(): + await test_pause_schedule_async(request_type=dict) + + +def test_pause_schedule_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = schedule_service.PauseScheduleRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.pause_schedule), "__call__") as call: + call.return_value = None + client.pause_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_pause_schedule_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = schedule_service.PauseScheduleRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.pause_schedule), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.pause_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_pause_schedule_flattened(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.pause_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.pause_schedule( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_pause_schedule_flattened_error(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.pause_schedule( + schedule_service.PauseScheduleRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_pause_schedule_flattened_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.pause_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+ response = await client.pause_schedule( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_pause_schedule_flattened_error_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.pause_schedule( + schedule_service.PauseScheduleRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + schedule_service.ResumeScheduleRequest, + dict, + ], +) +def test_resume_schedule(request_type, transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.resume_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.resume_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.ResumeScheduleRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_resume_schedule_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.resume_schedule), "__call__") as call: + client.resume_schedule() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.ResumeScheduleRequest() + + +@pytest.mark.asyncio +async def test_resume_schedule_async( + transport: str = "grpc_asyncio", request_type=schedule_service.ResumeScheduleRequest +): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.resume_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.resume_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.ResumeScheduleRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_resume_schedule_async_from_dict(): + await test_resume_schedule_async(request_type=dict) + + +def test_resume_schedule_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = schedule_service.ResumeScheduleRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.resume_schedule), "__call__") as call: + call.return_value = None + client.resume_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_resume_schedule_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = schedule_service.ResumeScheduleRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.resume_schedule), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.resume_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_resume_schedule_flattened(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.resume_schedule), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = None + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.resume_schedule( + name="name_value", + catch_up=True, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + arg = args[0].catch_up + mock_val = True + assert arg == mock_val + + +def test_resume_schedule_flattened_error(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.resume_schedule( + schedule_service.ResumeScheduleRequest(), + name="name_value", + catch_up=True, + ) + + +@pytest.mark.asyncio +async def test_resume_schedule_flattened_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.resume_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.resume_schedule( + name="name_value", + catch_up=True, + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + arg = args[0].catch_up + mock_val = True + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_resume_schedule_flattened_error_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.resume_schedule( + schedule_service.ResumeScheduleRequest(), + name="name_value", + catch_up=True, + ) + + +@pytest.mark.parametrize( + "request_type", + [ + schedule_service.UpdateScheduleRequest, + dict, + ], +) +def test_update_schedule(request_type, transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_schedule.Schedule( + name="name_value", + display_name="display_name_value", + max_run_count=1410, + started_run_count=1843, + state=gca_schedule.Schedule.State.ACTIVE, + max_concurrent_run_count=2596, + allow_queueing=True, + catch_up=True, + cron="cron_value", + ) + response = client.update_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.UpdateScheduleRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, gca_schedule.Schedule) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.max_run_count == 1410 + assert response.started_run_count == 1843 + assert response.state == gca_schedule.Schedule.State.ACTIVE + assert response.max_concurrent_run_count == 2596 + assert response.allow_queueing is True + assert response.catch_up is True + + +def test_update_schedule_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_schedule), "__call__") as call: + client.update_schedule() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.UpdateScheduleRequest() + + +@pytest.mark.asyncio +async def test_update_schedule_async( + transport: str = "grpc_asyncio", request_type=schedule_service.UpdateScheduleRequest +): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_schedule), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_schedule.Schedule( + name="name_value", + display_name="display_name_value", + max_run_count=1410, + started_run_count=1843, + state=gca_schedule.Schedule.State.ACTIVE, + max_concurrent_run_count=2596, + allow_queueing=True, + catch_up=True, + ) + ) + response = await client.update_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == schedule_service.UpdateScheduleRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_schedule.Schedule) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.max_run_count == 1410 + assert response.started_run_count == 1843 + assert response.state == gca_schedule.Schedule.State.ACTIVE + assert response.max_concurrent_run_count == 2596 + assert response.allow_queueing is True + assert response.catch_up is True + + +@pytest.mark.asyncio +async def test_update_schedule_async_from_dict(): + await test_update_schedule_async(request_type=dict) + + +def test_update_schedule_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = schedule_service.UpdateScheduleRequest() + + request.schedule.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_schedule), "__call__") as call: + call.return_value = gca_schedule.Schedule() + client.update_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "schedule.name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_schedule_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = schedule_service.UpdateScheduleRequest() + + request.schedule.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_schedule), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_schedule.Schedule() + ) + await client.update_schedule(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "schedule.name=name_value", + ) in kw["metadata"] + + +def test_update_schedule_flattened(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_schedule.Schedule() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_schedule( + schedule=gca_schedule.Schedule(cron="cron_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].schedule + mock_val = gca_schedule.Schedule(cron="cron_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +def test_update_schedule_flattened_error(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_schedule( + schedule_service.UpdateScheduleRequest(), + schedule=gca_schedule.Schedule(cron="cron_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_schedule_flattened_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_schedule), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_schedule.Schedule() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_schedule.Schedule() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_schedule( + schedule=gca_schedule.Schedule(cron="cron_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].schedule + mock_val = gca_schedule.Schedule(cron="cron_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_update_schedule_flattened_error_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_schedule( + schedule_service.UpdateScheduleRequest(), + schedule=gca_schedule.Schedule(cron="cron_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.ScheduleServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.ScheduleServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ScheduleServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.ScheduleServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ScheduleServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. 
+ options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ScheduleServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.ScheduleServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ScheduleServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.ScheduleServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = ScheduleServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.ScheduleServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.ScheduleServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ScheduleServiceGrpcTransport, + transports.ScheduleServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. 
+ with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + ], +) +def test_transport_kind(transport_name): + transport = ScheduleServiceClient.get_transport_class(transport_name)( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert transport.kind == transport_name + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.ScheduleServiceGrpcTransport, + ) + + +def test_schedule_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.ScheduleServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_schedule_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1.services.schedule_service.transports.ScheduleServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.ScheduleServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
+ methods = ( + "create_schedule", + "delete_schedule", + "get_schedule", + "list_schedules", + "pause_schedule", + "resume_schedule", + "update_schedule", + "set_iam_policy", + "get_iam_policy", + "test_iam_permissions", + "get_location", + "list_locations", + "get_operation", + "wait_operation", + "cancel_operation", + "delete_operation", + "list_operations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_schedule_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.schedule_service.transports.ScheduleServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.ScheduleServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_schedule_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. 
+ with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.cloud.aiplatform_v1.services.schedule_service.transports.ScheduleServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.ScheduleServiceTransport() + adc.assert_called_once() + + +def test_schedule_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + ScheduleServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ScheduleServiceGrpcTransport, + transports.ScheduleServiceGrpcAsyncIOTransport, + ], +) +def test_schedule_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
+ with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ScheduleServiceGrpcTransport, + transports.ScheduleServiceGrpcAsyncIOTransport, + ], +) +def test_schedule_service_transport_auth_gdch_credentials(transport_class): + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, "default", autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with(e) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.ScheduleServiceGrpcTransport, grpc_helpers), + (transports.ScheduleServiceGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_schedule_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
+ with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=["1", "2"], + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ScheduleServiceGrpcTransport, + transports.ScheduleServiceGrpcAsyncIOTransport, + ], +) +def test_schedule_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. 
+ with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + ], +) +def test_schedule_service_host_no_port(transport_name): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + transport=transport_name, + ) + assert client.transport._host == ("aiplatform.googleapis.com:443") + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + ], +) +def test_schedule_service_host_with_port(transport_name): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + transport=transport_name, + ) + assert client.transport._host == ("aiplatform.googleapis.com:8000") + + +def test_schedule_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.ScheduleServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_schedule_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. 
+ transport = transports.ScheduleServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.ScheduleServiceGrpcTransport, + transports.ScheduleServiceGrpcAsyncIOTransport, + ], +) +def test_schedule_service_transport_channel_mtls_with_client_cert_source( + transport_class, +): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, 
client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.ScheduleServiceGrpcTransport, + transports.ScheduleServiceGrpcAsyncIOTransport, + ], +) +def test_schedule_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_schedule_service_grpc_lro_client(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. 
+ assert transport.operations_client is transport.operations_client + + +def test_schedule_service_grpc_lro_async_client(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_artifact_path(): + project = "squid" + location = "clam" + metadata_store = "whelk" + artifact = "octopus" + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}".format( + project=project, + location=location, + metadata_store=metadata_store, + artifact=artifact, + ) + actual = ScheduleServiceClient.artifact_path( + project, location, metadata_store, artifact + ) + assert expected == actual + + +def test_parse_artifact_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + "metadata_store": "cuttlefish", + "artifact": "mussel", + } + path = ScheduleServiceClient.artifact_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ScheduleServiceClient.parse_artifact_path(path) + assert expected == actual + + +def test_context_path(): + project = "winkle" + location = "nautilus" + metadata_store = "scallop" + context = "abalone" + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/contexts/{context}".format( + project=project, + location=location, + metadata_store=metadata_store, + context=context, + ) + actual = ScheduleServiceClient.context_path( + project, location, metadata_store, context + ) + assert expected == actual + + +def test_parse_context_path(): + expected = { + "project": "squid", + "location": "clam", + "metadata_store": "whelk", + "context": "octopus", + } + path = ScheduleServiceClient.context_path(**expected) + + # Check that the path construction is reversible. + actual = ScheduleServiceClient.parse_context_path(path) + assert expected == actual + + +def test_custom_job_path(): + project = "oyster" + location = "nudibranch" + custom_job = "cuttlefish" + expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, + location=location, + custom_job=custom_job, + ) + actual = ScheduleServiceClient.custom_job_path(project, location, custom_job) + assert expected == actual + + +def test_parse_custom_job_path(): + expected = { + "project": "mussel", + "location": "winkle", + "custom_job": "nautilus", + } + path = ScheduleServiceClient.custom_job_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ScheduleServiceClient.parse_custom_job_path(path) + assert expected == actual + + +def test_execution_path(): + project = "scallop" + location = "abalone" + metadata_store = "squid" + execution = "clam" + expected = "projects/{project}/locations/{location}/metadataStores/{metadata_store}/executions/{execution}".format( + project=project, + location=location, + metadata_store=metadata_store, + execution=execution, + ) + actual = ScheduleServiceClient.execution_path( + project, location, metadata_store, execution + ) + assert expected == actual + + +def test_parse_execution_path(): + expected = { + "project": "whelk", + "location": "octopus", + "metadata_store": "oyster", + "execution": "nudibranch", + } + path = ScheduleServiceClient.execution_path(**expected) + + # Check that the path construction is reversible. + actual = ScheduleServiceClient.parse_execution_path(path) + assert expected == actual + + +def test_network_path(): + project = "cuttlefish" + network = "mussel" + expected = "projects/{project}/global/networks/{network}".format( + project=project, + network=network, + ) + actual = ScheduleServiceClient.network_path(project, network) + assert expected == actual + + +def test_parse_network_path(): + expected = { + "project": "winkle", + "network": "nautilus", + } + path = ScheduleServiceClient.network_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ScheduleServiceClient.parse_network_path(path) + assert expected == actual + + +def test_pipeline_job_path(): + project = "scallop" + location = "abalone" + pipeline_job = "squid" + expected = ( + "projects/{project}/locations/{location}/pipelineJobs/{pipeline_job}".format( + project=project, + location=location, + pipeline_job=pipeline_job, + ) + ) + actual = ScheduleServiceClient.pipeline_job_path(project, location, pipeline_job) + assert expected == actual + + +def test_parse_pipeline_job_path(): + expected = { + "project": "clam", + "location": "whelk", + "pipeline_job": "octopus", + } + path = ScheduleServiceClient.pipeline_job_path(**expected) + + # Check that the path construction is reversible. + actual = ScheduleServiceClient.parse_pipeline_job_path(path) + assert expected == actual + + +def test_schedule_path(): + project = "oyster" + location = "nudibranch" + schedule = "cuttlefish" + expected = "projects/{project}/locations/{location}/schedules/{schedule}".format( + project=project, + location=location, + schedule=schedule, + ) + actual = ScheduleServiceClient.schedule_path(project, location, schedule) + assert expected == actual + + +def test_parse_schedule_path(): + expected = { + "project": "mussel", + "location": "winkle", + "schedule": "nautilus", + } + path = ScheduleServiceClient.schedule_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ScheduleServiceClient.parse_schedule_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "scallop" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = ScheduleServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "abalone", + } + path = ScheduleServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = ScheduleServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "squid" + expected = "folders/{folder}".format( + folder=folder, + ) + actual = ScheduleServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "clam", + } + path = ScheduleServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = ScheduleServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "whelk" + expected = "organizations/{organization}".format( + organization=organization, + ) + actual = ScheduleServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "octopus", + } + path = ScheduleServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ScheduleServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "oyster" + expected = "projects/{project}".format( + project=project, + ) + actual = ScheduleServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "nudibranch", + } + path = ScheduleServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = ScheduleServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "cuttlefish" + location = "mussel" + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + actual = ScheduleServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "winkle", + "location": "nautilus", + } + path = ScheduleServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ScheduleServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.ScheduleServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.ScheduleServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = ScheduleServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_delete_operation(transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.DeleteOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.delete_operation(request) + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_operation_async(transport: str = "grpc"): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.DeleteOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_operation_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.DeleteOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + call.return_value = None + + client.delete_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_operation_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.DeleteOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_delete_operation_from_dict(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.delete_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_delete_operation_from_dict_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_cancel_operation(transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.CancelOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.cancel_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_operation_async(transport: str = "grpc"): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.CancelOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.cancel_operation(request) + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_operation_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.CancelOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + call.return_value = None + + client.cancel_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_cancel_operation_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.CancelOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.cancel_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_cancel_operation_from_dict(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_cancel_operation_from_dict_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.cancel_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_wait_operation(transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.WaitOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.wait_operation(request) + # Establish that the underlying gRPC stub method was called. 
+    assert len(call.mock_calls) == 1
+    _, args, _ = call.mock_calls[0]
+    assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, operations_pb2.Operation)
+
+
+@pytest.mark.asyncio
+async def test_wait_operation_async(transport: str = "grpc"):
+    client = ScheduleServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = operations_pb2.WaitOperationRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.wait_operation), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation()
+        )
+        response = await client.wait_operation(request)
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+        # Establish that the response is the type that we expect.
+        assert isinstance(response, operations_pb2.Operation)
+
+
+def test_wait_operation_field_headers():
+    client = ScheduleServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = operations_pb2.WaitOperationRequest()
+    request.name = "locations"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.wait_operation), "__call__") as call:
+        call.return_value = operations_pb2.Operation()
+
+        client.wait_operation(request)
+        # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_wait_operation_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.WaitOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.wait_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_wait_operation_from_dict(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.wait_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_wait_operation_from_dict_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.wait_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_get_operation(transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +@pytest.mark.asyncio +async def test_get_operation_async(transport: str = "grpc"): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_get_operation_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_operation_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_get_operation_from_dict(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_operation_from_dict_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_operations(transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+ request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + response = client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +@pytest.mark.asyncio +async def test_list_operations_async(transport: str = "grpc"): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_list_operations_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = operations_pb2.ListOperationsResponse() + + client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_operations_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_operations_from_dict(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = operations_pb2.ListOperationsResponse() + + response = client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_operations_from_dict_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_locations(transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = locations_pb2.ListLocationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = locations_pb2.ListLocationsResponse() + response = client.list_locations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, locations_pb2.ListLocationsResponse) + + +@pytest.mark.asyncio +async def test_list_locations_async(transport: str = "grpc"): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = locations_pb2.ListLocationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.ListLocationsResponse() + ) + response = await client.list_locations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, locations_pb2.ListLocationsResponse) + + +def test_list_locations_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = locations_pb2.ListLocationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + call.return_value = locations_pb2.ListLocationsResponse() + + client.list_locations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_locations_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = locations_pb2.ListLocationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.ListLocationsResponse() + ) + await client.list_locations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_locations_from_dict(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = locations_pb2.ListLocationsResponse() + + response = client.list_locations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_locations_from_dict_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.ListLocationsResponse() + ) + response = await client.list_locations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_get_location(transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = locations_pb2.GetLocationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_location), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = locations_pb2.Location() + response = client.get_location(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, locations_pb2.Location) + + +@pytest.mark.asyncio +async def test_get_location_async(transport: str = "grpc_asyncio"): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = locations_pb2.GetLocationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_location), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.Location() + ) + response = await client.get_location(request) + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, locations_pb2.Location) + + +def test_get_location_field_headers(): + client = ScheduleServiceClient(credentials=ga_credentials.AnonymousCredentials()) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = locations_pb2.GetLocationRequest() + request.name = "locations/abc" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_location), "__call__") as call: + call.return_value = locations_pb2.Location() + + client.get_location(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations/abc", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_location_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials() + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = locations_pb2.GetLocationRequest() + request.name = "locations/abc" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_location), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.Location() + ) + await client.get_location(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+    _, _, kw = call.mock_calls[0]
+    assert (
+        "x-goog-request-params",
+        "name=locations/abc",
+    ) in kw["metadata"]
+
+
+def test_get_location_from_dict():
+    client = ScheduleServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.get_location), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = locations_pb2.Location()
+
+        response = client.get_location(
+            request={
+                "name": "locations/abc",
+            }
+        )
+        call.assert_called()
+
+
+@pytest.mark.asyncio
+async def test_get_location_from_dict_async():
+    client = ScheduleServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.get_location), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            locations_pb2.Location()
+        )
+        response = await client.get_location(
+            request={
+                "name": "locations",
+            }
+        )
+        call.assert_called()
+
+
+def test_set_iam_policy(transport: str = "grpc"):
+    client = ScheduleServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = iam_policy_pb2.SetIamPolicyRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = policy_pb2.Policy(
+            version=774,
+            etag=b"etag_blob",
+        )
+        response = client.set_iam_policy(request)
+        # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, policy_pb2.Policy) + + assert response.version == 774 + + assert response.etag == b"etag_blob" + + +@pytest.mark.asyncio +async def test_set_iam_policy_async(transport: str = "grpc_asyncio"): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.SetIamPolicyRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + policy_pb2.Policy( + version=774, + etag=b"etag_blob", + ) + ) + response = await client.set_iam_policy(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, policy_pb2.Policy) + + assert response.version == 774 + + assert response.etag == b"etag_blob" + + +def test_set_iam_policy_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.SetIamPolicyRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + call.return_value = policy_pb2.Policy() + + client.set_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_set_iam_policy_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.SetIamPolicyRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(policy_pb2.Policy()) + + await client.set_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +def test_set_iam_policy_from_dict(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = policy_pb2.Policy() + + response = client.set_iam_policy( + request={ + "resource": "resource_value", + "policy": policy_pb2.Policy(version=774), + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_set_iam_policy_from_dict_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(policy_pb2.Policy()) + + response = await client.set_iam_policy( + request={ + "resource": "resource_value", + "policy": policy_pb2.Policy(version=774), + } + ) + call.assert_called() + + +def test_get_iam_policy(transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.GetIamPolicyRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = policy_pb2.Policy( + version=774, + etag=b"etag_blob", + ) + + response = client.get_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, policy_pb2.Policy) + + assert response.version == 774 + + assert response.etag == b"etag_blob" + + +@pytest.mark.asyncio +async def test_get_iam_policy_async(transport: str = "grpc_asyncio"): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.GetIamPolicyRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + policy_pb2.Policy( + version=774, + etag=b"etag_blob", + ) + ) + + response = await client.get_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, policy_pb2.Policy) + + assert response.version == 774 + + assert response.etag == b"etag_blob" + + +def test_get_iam_policy_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.GetIamPolicyRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + call.return_value = policy_pb2.Policy() + + client.get_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_iam_policy_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.GetIamPolicyRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(policy_pb2.Policy()) + + await client.get_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +def test_get_iam_policy_from_dict(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = policy_pb2.Policy() + + response = client.get_iam_policy( + request={ + "resource": "resource_value", + "options": options_pb2.GetPolicyOptions(requested_policy_version=2598), + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_iam_policy_from_dict_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(policy_pb2.Policy()) + + response = await client.get_iam_policy( + request={ + "resource": "resource_value", + "options": options_pb2.GetPolicyOptions(requested_policy_version=2598), + } + ) + call.assert_called() + + +def test_test_iam_permissions(transport: str = "grpc"): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.TestIamPermissionsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = iam_policy_pb2.TestIamPermissionsResponse( + permissions=["permissions_value"], + ) + + response = client.test_iam_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, iam_policy_pb2.TestIamPermissionsResponse) + + assert response.permissions == ["permissions_value"] + + +@pytest.mark.asyncio +async def test_test_iam_permissions_async(transport: str = "grpc_asyncio"): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.TestIamPermissionsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + iam_policy_pb2.TestIamPermissionsResponse( + permissions=["permissions_value"], + ) + ) + + response = await client.test_iam_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, iam_policy_pb2.TestIamPermissionsResponse) + + assert response.permissions == ["permissions_value"] + + +def test_test_iam_permissions_field_headers(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.TestIamPermissionsRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + call.return_value = iam_policy_pb2.TestIamPermissionsResponse() + + client.test_iam_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_test_iam_permissions_field_headers_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.TestIamPermissionsRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + iam_policy_pb2.TestIamPermissionsResponse() + ) + + await client.test_iam_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +def test_test_iam_permissions_from_dict(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+ call.return_value = iam_policy_pb2.TestIamPermissionsResponse() + + response = client.test_iam_permissions( + request={ + "resource": "resource_value", + "permissions": ["permissions_value"], + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_test_iam_permissions_from_dict_async(): + client = ScheduleServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + iam_policy_pb2.TestIamPermissionsResponse() + ) + + response = await client.test_iam_permissions( + request={ + "resource": "resource_value", + "permissions": ["permissions_value"], + } + ) + call.assert_called() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = ScheduleServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. 
+ with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (ScheduleServiceClient, transports.ScheduleServiceGrpcTransport), + (ScheduleServiceAsyncClient, transports.ScheduleServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py index ba6d14d736..c40139c0e3 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py @@ -48,6 +48,7 @@ from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import prediction_service +from google.cloud.aiplatform_v1beta1.types import types from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import options_pb2 # type: ignore @@ -1149,6 +1150,165 @@ async def test_raw_predict_flattened_error_async(): ) +@pytest.mark.parametrize( + "request_type", + 
[ + prediction_service.StreamingPredictRequest, + dict, + ], +) +def test_server_streaming_predict(request_type, transport: str = "grpc"): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.server_streaming_predict), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = iter([prediction_service.StreamingPredictResponse()]) + response = client.server_streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == prediction_service.StreamingPredictRequest() + + # Establish that the response is the type that we expect. + for message in response: + assert isinstance(message, prediction_service.StreamingPredictResponse) + + +def test_server_streaming_predict_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.server_streaming_predict), "__call__" + ) as call: + client.server_streaming_predict() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == prediction_service.StreamingPredictRequest() + + +@pytest.mark.asyncio +async def test_server_streaming_predict_async( + transport: str = "grpc_asyncio", + request_type=prediction_service.StreamingPredictRequest, +): + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.server_streaming_predict), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[prediction_service.StreamingPredictResponse()] + ) + response = await client.server_streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == prediction_service.StreamingPredictRequest() + + # Establish that the response is the type that we expect. + message = await response.read() + assert isinstance(message, prediction_service.StreamingPredictResponse) + + +@pytest.mark.asyncio +async def test_server_streaming_predict_async_from_dict(): + await test_server_streaming_predict_async(request_type=dict) + + +def test_server_streaming_predict_field_headers(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = prediction_service.StreamingPredictRequest() + + request.endpoint = "endpoint_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.server_streaming_predict), "__call__" + ) as call: + call.return_value = iter([prediction_service.StreamingPredictResponse()]) + client.server_streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "endpoint=endpoint_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_server_streaming_predict_field_headers_async(): + client = PredictionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = prediction_service.StreamingPredictRequest() + + request.endpoint = "endpoint_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.server_streaming_predict), "__call__" + ) as call: + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[prediction_service.StreamingPredictResponse()] + ) + await client.server_streaming_predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "endpoint=endpoint_value", + ) in kw["metadata"] + + @pytest.mark.parametrize( "request_type", [ @@ -1473,6 +1633,7 @@ def test_prediction_service_base_transport(): methods = ( "predict", "raw_predict", + "server_streaming_predict", "explain", "set_iam_policy", "get_iam_policy", From 40f3e411fe18a686defe2b099ddd20957a8518e3 Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Wed, 9 Aug 2023 05:10:36 -0700 Subject: [PATCH 6/8] chore: De-hardcoded model parameter defaults Model interface classes support different models that might have different defaults for their parameters. SDK should not hardcode these parameters by default, letting the user to either use the model's defaults or explicitly override them. There was a recent similar case where the tuning parameter defaults were different for different tuning methods. PiperOrigin-RevId: 555129237 --- tests/unit/aiplatform/test_language_models.py | 6 +- vertexai/language_models/_language_models.py | 174 +++++++++--------- 2 files changed, 91 insertions(+), 89 deletions(-) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index 63bfdcf718..e246f2cd19 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -1126,7 +1126,6 @@ def test_code_generation(self): # Validating the parameters predict_temperature = 0.1 predict_max_output_tokens = 100 - default_temperature = language_models.CodeGenerationModel._DEFAULT_TEMPERATURE default_max_output_tokens = ( language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS ) @@ -1149,7 +1148,7 @@ def test_code_generation(self): prefix="Write a function that checks if a year is a leap year.", ) prediction_parameters = mock_predict.call_args[1]["parameters"] - assert prediction_parameters["temperature"] == default_temperature + assert "temperature" not in prediction_parameters assert 
prediction_parameters["maxOutputTokens"] == default_max_output_tokens def test_code_completion(self): @@ -1192,7 +1191,6 @@ def test_code_completion(self): # Validating the parameters predict_temperature = 0.1 predict_max_output_tokens = 100 - default_temperature = language_models.CodeGenerationModel._DEFAULT_TEMPERATURE default_max_output_tokens = ( language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS ) @@ -1215,7 +1213,7 @@ def test_code_completion(self): prefix="def reverse_string(s):", ) prediction_parameters = mock_predict.call_args[1]["parameters"] - assert prediction_parameters["temperature"] == default_temperature + assert "temperature" not in prediction_parameters assert prediction_parameters["maxOutputTokens"] == default_max_output_tokens def test_text_embedding(self): diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index de76f99282..d1f7cdd345 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -237,28 +237,25 @@ class _TextGenerationModel(_LanguageModel): _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml" - _DEFAULT_TEMPERATURE = 0.0 _DEFAULT_MAX_OUTPUT_TOKENS = 128 - _DEFAULT_TOP_P = 0.95 - _DEFAULT_TOP_K = 40 def predict( self, prompt: str, *, - max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS, - temperature: float = _DEFAULT_TEMPERATURE, - top_k: int = _DEFAULT_TOP_K, - top_p: float = _DEFAULT_TOP_P, + max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, ) -> "TextGenerationResponse": """Gets model response for a single prompt. Args: prompt: Question to ask the model. - max_output_tokens: Max length of the output text in tokens. - temperature: Controls the randomness of predictions. Range: [0, 1]. 
- top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. - top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. + max_output_tokens: Max length of the output text in tokens. Range: [1, 1024]. + temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40. + top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. Returns: A `TextGenerationResponse` object that contains the text produced by the model. @@ -275,19 +272,19 @@ def predict( def _batch_predict( self, prompts: List[str], - max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS, - temperature: float = _DEFAULT_TEMPERATURE, - top_k: int = _DEFAULT_TOP_K, - top_p: float = _DEFAULT_TOP_P, + max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, ) -> List["TextGenerationResponse"]: """Gets model response for a single prompt. Args: prompts: Questions to ask the model. - max_output_tokens: Max length of the output text in tokens. - temperature: Controls the randomness of predictions. Range: [0, 1]. - top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. - top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. + max_output_tokens: Max length of the output text in tokens. Range: [1, 1024]. + temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40. 
+ top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. Returns: A list of `TextGenerationResponse` objects that contain the texts produced by the model. @@ -458,17 +455,17 @@ class _ChatModel(_TextGenerationModel): def start_chat( self, max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS, - temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE, - top_k: int = _TextGenerationModel._DEFAULT_TOP_K, - top_p: float = _TextGenerationModel._DEFAULT_TOP_P, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, ) -> "_ChatSession": """Starts a chat session with the model. Args: - max_output_tokens: Max length of the output text in tokens. - temperature: Controls the randomness of predictions. Range: [0, 1]. - top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. - top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. + max_output_tokens: Max length of the output text in tokens. Range: [1, 1024]. + temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40. + top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. Returns: A `ChatSession` object. 
@@ -492,9 +489,9 @@ def __init__( self, model: _ChatModel, max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS, - temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE, - top_k: int = _TextGenerationModel._DEFAULT_TOP_K, - top_p: float = _TextGenerationModel._DEFAULT_TOP_P, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, ): self._model = model self._history = [] @@ -517,13 +514,13 @@ def send_message( Args: message: Message to send to the model - max_output_tokens: Max length of the output text in tokens. + max_output_tokens: Max length of the output text in tokens. Range: [1, 1024]. Uses the value specified when calling `ChatModel.start_chat` by default. - temperature: Controls the randomness of predictions. Range: [0, 1]. + temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0. Uses the value specified when calling `ChatModel.start_chat` by default. - top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40. Uses the value specified when calling `ChatModel.start_chat` by default. - top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. + top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. Uses the value specified when calling `ChatModel.start_chat` by default. 
Returns: @@ -633,10 +630,10 @@ def start_chat( *, context: Optional[str] = None, examples: Optional[List[InputOutputTextPair]] = None, - max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS, - temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE, - top_k: int = _TextGenerationModel._DEFAULT_TOP_K, - top_p: float = _TextGenerationModel._DEFAULT_TOP_P, + max_output_tokens: Optional[int] = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, message_history: Optional[List[ChatMessage]] = None, ) -> "ChatSession": """Starts a chat session with the model. @@ -646,10 +643,10 @@ def start_chat( For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style examples: List of structured messages to the model to learn how to respond to the conversation. A list of `InputOutputTextPair` objects. - max_output_tokens: Max length of the output text in tokens. - temperature: Controls the randomness of predictions. Range: [0, 1]. - top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40] - top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. + max_output_tokens: Max length of the output text in tokens. Range: [1, 1024]. + temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40. + top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. message_history: A list of previously sent and received messages. 
Returns: @@ -717,19 +714,18 @@ class CodeChatModel(_ChatModelBase): _LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE _DEFAULT_MAX_OUTPUT_TOKENS = 128 - _DEFAULT_TEMPERATURE = 0.5 def start_chat( self, *, - max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS, - temperature: float = _DEFAULT_TEMPERATURE, + max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS, + temperature: Optional[float] = None, message_history: Optional[List[ChatMessage]] = None, ) -> "CodeChatSession": """Starts a chat session with the code chat model. Args: - max_output_tokens: Max length of the output text in tokens. + max_output_tokens: Max length of the output text in tokens. Range: [1, 1000]. temperature: Controls the randomness of predictions. Range: [0, 1]. Returns: @@ -754,11 +750,10 @@ def __init__( model: _ChatModelBase, context: Optional[str] = None, examples: Optional[List[InputOutputTextPair]] = None, - max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS, - temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE, - top_k: int = _TextGenerationModel._DEFAULT_TOP_K, - top_p: float = _TextGenerationModel._DEFAULT_TOP_P, - is_code_chat_session: bool = False, + max_output_tokens: Optional[int] = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, message_history: Optional[List[ChatMessage]] = None, ): self._model = model @@ -768,7 +763,6 @@ def __init__( self._temperature = temperature self._top_k = top_k self._top_p = top_p - self._is_code_chat_session = is_code_chat_session self._message_history: List[ChatMessage] = message_history or [] @property @@ -789,30 +783,36 @@ def send_message( Args: message: Message to send to the model - max_output_tokens: Max length of the output text in tokens. + max_output_tokens: Max length of the output text in tokens. Range: [1, 1024]. Uses the value specified when calling `ChatModel.start_chat` by default. 
- temperature: Controls the randomness of predictions. Range: [0, 1]. + temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0. Uses the value specified when calling `ChatModel.start_chat` by default. - top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40. Uses the value specified when calling `ChatModel.start_chat` by default. - top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. + top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. Uses the value specified when calling `ChatModel.start_chat` by default. Returns: A `TextGenerationResponse` object that contains the text produced by the model. """ - prediction_parameters = { - "temperature": temperature - if temperature is not None - else self._temperature, - "maxDecodeSteps": max_output_tokens - if max_output_tokens is not None - else self._max_output_tokens, - } + prediction_parameters = {} + + max_output_tokens = max_output_tokens or self._max_output_tokens + if max_output_tokens: + prediction_parameters["maxDecodeSteps"] = max_output_tokens - if not self._is_code_chat_session: - prediction_parameters["topP"] = top_p if top_p is not None else self._top_p - prediction_parameters["topK"] = top_k if top_k is not None else self._top_k + if temperature is None: + temperature = self._temperature + if temperature is not None: + prediction_parameters["temperature"] = temperature + + top_p = top_p or self._top_p + if top_p: + prediction_parameters["topP"] = top_p + + top_k = top_k or self._top_k + if top_k: + prediction_parameters["topK"] = top_k message_structs = [] for past_message in self._message_history: @@ -830,9 +830,9 @@ def send_message( ) prediction_instance = 
{"messages": message_structs} - if not self._is_code_chat_session and self._context: + if self._context: prediction_instance["context"] = self._context - if not self._is_code_chat_session and self._examples: + if self._examples: prediction_instance["examples"] = [ { "input": {"content": example.input_text}, @@ -885,10 +885,10 @@ def __init__( model: ChatModel, context: Optional[str] = None, examples: Optional[List[InputOutputTextPair]] = None, - max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS, - temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE, - top_k: int = _TextGenerationModel._DEFAULT_TOP_K, - top_p: float = _TextGenerationModel._DEFAULT_TOP_P, + max_output_tokens: Optional[int] = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, message_history: Optional[List[ChatMessage]] = None, ): super().__init__( @@ -913,14 +913,13 @@ def __init__( self, model: CodeChatModel, max_output_tokens: int = CodeChatModel._DEFAULT_MAX_OUTPUT_TOKENS, - temperature: float = CodeChatModel._DEFAULT_TEMPERATURE, + temperature: Optional[float] = None, message_history: Optional[List[ChatMessage]] = None, ): super().__init__( model=model, max_output_tokens=max_output_tokens, temperature=temperature, - is_code_chat_session=True, message_history=message_history, ) @@ -935,7 +934,7 @@ def send_message( Args: message: Message to send to the model - max_output_tokens: Max length of the output text in tokens. + max_output_tokens: Max length of the output text in tokens. Range: [1, 1000]. Uses the value specified when calling `CodeChatModel.start_chat` by default. temperature: Controls the randomness of predictions. Range: [0, 1]. Uses the value specified when calling `CodeChatModel.start_chat` by default. 
@@ -970,33 +969,38 @@ class CodeGenerationModel(_LanguageModel): _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml" _LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE - _DEFAULT_TEMPERATURE = 0.0 _DEFAULT_MAX_OUTPUT_TOKENS = 128 def predict( self, prefix: str, - suffix: Optional[str] = "", + suffix: Optional[str] = None, *, - max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS, - temperature: float = _DEFAULT_TEMPERATURE, + max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS, + temperature: Optional[float] = None, ) -> "TextGenerationResponse": """Gets model response for a single prompt. Args: prefix: Code before the current point. suffix: Code after the current point. - max_output_tokens: Max length of the output text in tokens. + max_output_tokens: Max length of the output text in tokens. Range: [1, 1000]. temperature: Controls the randomness of predictions. Range: [0, 1]. Returns: A `TextGenerationResponse` object that contains the text produced by the model. 
""" - instance = {"prefix": prefix, "suffix": suffix} - prediction_parameters = { - "temperature": temperature, - "maxOutputTokens": max_output_tokens, - } + instance = {"prefix": prefix} + if suffix: + instance["suffix"] = suffix + + prediction_parameters = {} + + if temperature is not None: + prediction_parameters["temperature"] = temperature + + if max_output_tokens: + prediction_parameters["maxOutputTokens"] = max_output_tokens prediction_response = self._endpoint.predict( instances=[instance], From 57806fb947e5b692cd8d4701e572eaf54585d383 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Wed, 9 Aug 2023 20:51:39 -0700 Subject: [PATCH 7/8] fix: LLM - Fixed filter in `list_tuned_model_names` PiperOrigin-RevId: 555359431 --- vertexai/language_models/_language_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index d1f7cdd345..2795d94bd8 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -1083,7 +1083,7 @@ def _get_tuned_models_dir_uri(model_id: str) -> str: def _list_tuned_model_names(model_id: str) -> List[str]: tuned_models = aiplatform.Model.list( - filter=f'labels.{_TUNING_BASE_MODEL_ID_LABEL_KEY}="{model_id}"', + filter=f'labels.{_TUNING_BASE_MODEL_ID_LABEL_KEY}="{model_id.replace("@", "-")}"', # TODO(b/275444096): Remove the explicit location once models are deployed to the user's selected location location=_TUNED_MODEL_LOCATION, ) From 06c9d1859279d09adb3a4f5351bcc9ec408df318 Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Thu, 10 Aug 2023 16:09:21 -0400 Subject: [PATCH 8/8] chore(main): release 1.30.0 (#2404) Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> --- .release-please-manifest.json | 2 +- CHANGELOG.md | 13 +++++++++++++ google/cloud/aiplatform/gapic_version.py | 
2 +- .../v1/schema/predict/instance/gapic_version.py | 2 +- .../v1/schema/predict/instance_v1/gapic_version.py | 2 +- .../v1/schema/predict/params/gapic_version.py | 2 +- .../v1/schema/predict/params_v1/gapic_version.py | 2 +- .../v1/schema/predict/prediction/gapic_version.py | 2 +- .../schema/predict/prediction_v1/gapic_version.py | 2 +- .../schema/trainingjob/definition/gapic_version.py | 2 +- .../trainingjob/definition_v1/gapic_version.py | 2 +- .../schema/predict/instance/gapic_version.py | 2 +- .../predict/instance_v1beta1/gapic_version.py | 2 +- .../v1beta1/schema/predict/params/gapic_version.py | 2 +- .../schema/predict/params_v1beta1/gapic_version.py | 2 +- .../schema/predict/prediction/gapic_version.py | 2 +- .../predict/prediction_v1beta1/gapic_version.py | 2 +- .../schema/trainingjob/definition/gapic_version.py | 2 +- .../trainingjob/definition_v1beta1/gapic_version.py | 2 +- google/cloud/aiplatform/version.py | 2 +- google/cloud/aiplatform_v1/gapic_version.py | 2 +- google/cloud/aiplatform_v1beta1/gapic_version.py | 2 +- ...snippet_metadata_google.cloud.aiplatform.v1.json | 2 +- ...et_metadata_google.cloud.aiplatform.v1beta1.json | 2 +- 24 files changed, 36 insertions(+), 23 deletions(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 57a1ffd7da..bc40583868 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.29.0" + ".": "1.30.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index 96186f54e8..0921998b6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,19 @@ # Changelog +## [1.30.0](https://github.com/googleapis/python-aiplatform/compare/v1.29.0...v1.30.0) (2023-08-10) + + +### Features + +* Add model.evaluate() method to Model class ([51df86e](https://github.com/googleapis/python-aiplatform/commit/51df86ee1390a51b82ffc015514ad1e145821a34)) +* Add support for providing only text to MultiModalEmbeddingModel.get_embeddings() 
([38ec40a](https://github.com/googleapis/python-aiplatform/commit/38ec40a12cf863c9da3de8336dceba10d92f6f56)) + + +### Bug Fixes + +* LLM - Fixed filter in `list_tuned_model_names` ([57806fb](https://github.com/googleapis/python-aiplatform/commit/57806fb947e5b692cd8d4701e572eaf54585d383)) + ## [1.29.0](https://github.com/googleapis/python-aiplatform/compare/v1.28.1...v1.29.0) (2023-08-02) diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/gapic_version.py +++ b/google/cloud/aiplatform/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py index 507af32a89..7751a594bb 100644 --- a/google/cloud/aiplatform/version.py +++ b/google/cloud/aiplatform/version.py @@ -15,4 +15,4 @@ # limitations under the License. 
# -__version__ = "1.29.0" +__version__ = "1.30.0" diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform_v1/gapic_version.py +++ b/google/cloud/aiplatform_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py index ed6cb05766..ba61327cef 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.29.0" # {x-release-please-version} +__version__ = "1.30.0" # {x-release-please-version} diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index 89ca575886..19e83ef7bb 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.29.0" + "version": "1.30.0" }, "snippets": [ { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index fa25d9eb2c..036f353740 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.29.0" + 
"version": "1.30.0" }, "snippets": [ {