diff --git a/.kokoro/build.sh b/.kokoro/build.sh index 031fa2200a..64fe343878 100755 --- a/.kokoro/build.sh +++ b/.kokoro/build.sh @@ -33,6 +33,9 @@ export GOOGLE_APPLICATION_CREDENTIALS=${KOKORO_GFILE_DIR}/service-account.json # Setup project id. export PROJECT_ID=$(cat "${KOKORO_GFILE_DIR}/project-id.json") +# Setup staging endpoint. +export STAGING_ENDPOINT=$(cat "${KOKORO_KEYSTORE_DIR}/73713_vertexai-staging-endpoint") + # Remove old nox python3 -m pip uninstall --yes --quiet nox-automation diff --git a/.kokoro/continuous/common.cfg b/.kokoro/continuous/common.cfg index c8f353660a..d8ac15f5d4 100644 --- a/.kokoro/continuous/common.cfg +++ b/.kokoro/continuous/common.cfg @@ -16,6 +16,16 @@ gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/google-cloud-python" # Use the trampoline script to run in docker. build_file: "python-aiplatform/.kokoro/trampoline.sh" +# Fetch vertexai staging endpoint +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73713 + keyname: "vertexai-staging-endpoint" + } + } +} + # Configure the docker image for kokoro-trampoline. env_vars: { key: "TRAMPOLINE_IMAGE" diff --git a/.kokoro/release.sh b/.kokoro/release.sh index 41cedde8ef..5ddfe573df 100755 --- a/.kokoro/release.sh +++ b/.kokoro/release.sh @@ -16,7 +16,7 @@ set -eo pipefail # Start the releasetool reporter -python3 -m pip install --require-hashes -r github/python-aiplatform/.kokoro/requirements.txt +python3 -m pip install --require-hashes --no-deps -r github/python-aiplatform/.kokoro/requirements.txt python3 -m releasetool publish-reporter-script > /tmp/publisher-script; source /tmp/publisher-script # Disable buffering, so that the logs stream through. @@ -30,7 +30,7 @@ twine upload --username __token__ --password "${GCA_TWINE_PASSWORD}" dist/* # Move into the `vertexai` package, build the distribution and upload. 
VERTEXAI_TWINE_PASSWORD=$(cat "${KOKORO_KEYSTORE_DIR}/73713_vertexai-pypi-token-1") -cd github/python-aiplatform/pypi/_vertex_ai_placeholder +cd pypi/_vertex_ai_placeholder python3 -m build twine upload --username __token__ --password "${VERTEXAI_TWINE_PASSWORD}" dist/* diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 9172d96d2d..4023e05246 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.61.0" + ".": "1.62.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index 25e69f3dbf..8af68118f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,33 @@ # Changelog +## [1.62.0](https://github.com/googleapis/python-aiplatform/compare/v1.61.0...v1.62.0) (2024-08-13) + + +### Features + +* Add metadata to evaluation result. ([375095e](https://github.com/googleapis/python-aiplatform/commit/375095e72cc4f43611710372a1e36753a891a710)) +* Add Prompt class for multimodal prompt templating ([1bdc235](https://github.com/googleapis/python-aiplatform/commit/1bdc235ea64f8d63ce9d60d88cb873ee341d3ff9)) +* Add support for query method in Vertex AI Extension SDK ([0008735](https://github.com/googleapis/python-aiplatform/commit/0008735968606a716add88072cff76f2fc552d7b)) +* Add support for reservation affinity in custom training jobs. ([802609b](https://github.com/googleapis/python-aiplatform/commit/802609b1f5e5d8d41a77dafb5b1a2dbf01f2bd30)) +* Add support for strategy in custom training jobs. 
([a076191](https://github.com/googleapis/python-aiplatform/commit/a076191b8726363e1f7c47ef8343eb86cebf9918)) +* Adding spot, reservation_affinity to Vertex SDK ([3e785bd](https://github.com/googleapis/python-aiplatform/commit/3e785bd9c9d3d11197ef930f563ee96231a67d84)) +* Support api keys in initializer and create_client ([7404f67](https://github.com/googleapis/python-aiplatform/commit/7404f679246e41e0009ec2d49f05d669eb357f71)) +* Support creating optimized online store with private service connect ([659ba3f](https://github.com/googleapis/python-aiplatform/commit/659ba3f287f9aa78840d4b9b9ca216002d5f1e6a)) +* Support disable Cloud logging in Ray on Vertex ([accaa97](https://github.com/googleapis/python-aiplatform/commit/accaa9750d98b7a37b08da3bd2058d9cdd03bd5c)) +* Support PSC-Interface in Ray on Vertex ([accaa97](https://github.com/googleapis/python-aiplatform/commit/accaa9750d98b7a37b08da3bd2058d9cdd03bd5c)) + + +### Bug Fixes + +* Added credentials, project, and location on PipelineJobSchedule init ([281c171](https://github.com/googleapis/python-aiplatform/commit/281c1710afc6cac49c02d926bee7a6c43b6ef851)) +* Avoid breakage of langchain from orjson 3.10.7 ([c990f73](https://github.com/googleapis/python-aiplatform/commit/c990f73845f38e58ba2dddb372ad2f84d4a05479)) +* Deprecate disable_attribution in GoogleSearchRetrieval. ([c68d559](https://github.com/googleapis/python-aiplatform/commit/c68d559b9d0fd7288b6775f57d05f474f5f7920a)) + + +### Documentation + +* Update the docstring for compute_tokens method. 
([849e8d4](https://github.com/googleapis/python-aiplatform/commit/849e8d409e4838cad0a020231b806b0c9ef587ce)) + ## [1.61.0](https://github.com/googleapis/python-aiplatform/compare/v1.60.0...v1.61.0) (2024-08-05) diff --git a/docs/aiplatform_v1/evaluation_service.rst b/docs/aiplatform_v1/evaluation_service.rst new file mode 100644 index 0000000000..2ecf29d75a --- /dev/null +++ b/docs/aiplatform_v1/evaluation_service.rst @@ -0,0 +1,6 @@ +EvaluationService +----------------------------------- + +.. automodule:: google.cloud.aiplatform_v1.services.evaluation_service + :members: + :inherited-members: diff --git a/docs/aiplatform_v1/services_.rst b/docs/aiplatform_v1/services_.rst index a66d73dac9..3587aa5833 100644 --- a/docs/aiplatform_v1/services_.rst +++ b/docs/aiplatform_v1/services_.rst @@ -6,6 +6,7 @@ Services for Google Cloud Aiplatform v1 API dataset_service deployment_resource_pool_service endpoint_service + evaluation_service feature_online_store_admin_service feature_online_store_service feature_registry_service diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py index 5fa73056bc..ee80d10d22 100644 --- a/google/cloud/aiplatform/compat/types/__init__.py +++ b/google/cloud/aiplatform/compat/types/__init__.py @@ -93,6 +93,7 @@ pipeline_state as pipeline_state_v1beta1, prediction_service as prediction_service_v1beta1, publisher_model as publisher_model_v1beta1, + reservation_affinity as reservation_affinity_v1beta1, service_networking as service_networking_v1beta1, schedule as schedule_v1beta1, schedule_service as schedule_service_v1beta1, @@ -176,6 +177,7 @@ pipeline_state as pipeline_state_v1, prediction_service as prediction_service_v1, publisher_model as publisher_model_v1, + reservation_affinity as reservation_affinity_v1, schedule as schedule_v1, schedule_service as schedule_service_v1, service_networking as service_networking_v1, @@ -254,6 +256,7 @@ pipeline_state_v1, prediction_service_v1, 
publisher_model_v1, + reservation_affinity_v1, schedule_v1, schedule_service_v1, specialist_pool_v1, @@ -337,6 +340,7 @@ pipeline_state_v1beta1, prediction_service_v1beta1, publisher_model_v1beta1, + reservation_affinity_v1beta1, schedule_v1beta1, schedule_service_v1beta1, specialist_pool_v1beta1, diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/gapic_version.py +++ b/google/cloud/aiplatform/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index b437f1f4d1..08c3528c57 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -74,7 +74,7 @@ def _set_project_as_env_var_or_google_auth_default(self): the project and credentials have already been set. """ - if not self._project: + if not self._project and not self._api_key: # Project is not set. Trying to get it from the environment. 
# See https://github.com/googleapis/python-aiplatform/issues/852 # See https://github.com/googleapis/google-auth-library-python/issues/924 @@ -104,7 +104,7 @@ def _set_project_as_env_var_or_google_auth_default(self): self._credentials = self._credentials or credentials self._project = project - if not self._credentials: + if not self._credentials and not self._api_key: credentials, _ = google.auth.default() self._credentials = credentials @@ -117,6 +117,7 @@ def __init__(self): self._network = None self._service_account = None self._api_endpoint = None + self._api_key = None self._api_transport = None self._request_metadata = None self._resource_type = None @@ -137,6 +138,7 @@ def init( network: Optional[str] = None, service_account: Optional[str] = None, api_endpoint: Optional[str] = None, + api_key: Optional[str] = None, api_transport: Optional[str] = None, request_metadata: Optional[Sequence[Tuple[str, str]]] = None, ): @@ -197,6 +199,9 @@ def init( api_endpoint (str): Optional. The desired API endpoint, e.g., us-central1-aiplatform.googleapis.com + api_key (str): + Optional. The API key to use for service calls. + NOTE: Not all services support API keys. api_transport (str): Optional. The transport method which is either 'grpc' or 'rest'. 
NOTE: "rest" transport functionality is currently in a @@ -252,6 +257,8 @@ def init( self._service_account = service_account if request_metadata is not None: self._request_metadata = request_metadata + if api_key is not None: + self._api_key = api_key self._resource_type = None # Finally, perform secondary state updates @@ -304,6 +311,11 @@ def api_endpoint(self) -> Optional[str]: """Default API endpoint, if provided.""" return self._api_endpoint + @property + def api_key(self) -> Optional[str]: + """API Key, if provided.""" + return self._api_key + @property def project(self) -> str: """Default project.""" @@ -325,7 +337,7 @@ def project(self) -> str: except GoogleAuthError as exc: raise GoogleAuthError(project_not_found_exception_str) from exc - if not project_id: + if not project_id and not self.api_key: raise ValueError(project_not_found_exception_str) return project_id @@ -403,6 +415,7 @@ def get_client_options( location_override: Optional[str] = None, prediction_client: bool = False, api_base_path_override: Optional[str] = None, + api_key: Optional[str] = None, api_path_override: Optional[str] = None, ) -> client_options.ClientOptions: """Creates GAPIC client_options using location and type. @@ -414,6 +427,7 @@ def get_client_options( Vertex AI. prediction_client (str): Optional. flag to use a prediction endpoint. api_base_path_override (str): Optional. Override default API base path. + api_key (str): Optional. API key to use for the client. api_path_override (str): Optional. Override default api path. 
Returns: clients_options (google.api_core.client_options.ClientOptions): @@ -447,6 +461,11 @@ def get_client_options( else api_path_override ) + # Project/location take precedence over api_key + if api_key and not self._project: + return client_options.ClientOptions( + api_endpoint=api_endpoint, api_key=api_key + ) return client_options.ClientOptions(api_endpoint=api_endpoint) def common_location_path( @@ -479,6 +498,7 @@ def create_client( location_override: Optional[str] = None, prediction_client: bool = False, api_base_path_override: Optional[str] = None, + api_key: Optional[str] = None, api_path_override: Optional[str] = None, appended_user_agent: Optional[List[str]] = None, appended_gapic_version: Optional[str] = None, @@ -493,6 +513,7 @@ def create_client( Optional. Custom auth credentials. If not provided will use the current config. location_override (str): Optional. location override. prediction_client (str): Optional. flag to use a prediction endpoint. + api_key (str): Optional. API key to use for the client. api_base_path_override (str): Optional. Override default api base path. api_path_override (str): Optional. Override default api path. appended_user_agent (List[str]): @@ -539,6 +560,7 @@ def create_client( "client_options": self.get_client_options( location_override=location_override, prediction_client=prediction_client, + api_key=api_key, api_base_path_override=api_base_path_override, api_path_override=api_path_override, ), diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 70d5029013..6f6a8380d8 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -2214,6 +2214,7 @@ def run( create_request_timeout: Optional[float] = None, disable_retries: bool = False, persistent_resource_id: Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, ) -> None: """Run this configured CustomJob. 
@@ -2282,6 +2283,8 @@ def run( on-demand short-live machines. The network, CMEK, and node pool configs on the job should be consistent with those on the PersistentResource, otherwise, the job will be rejected. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. """ network = network or initializer.global_config.network service_account = service_account or initializer.global_config.service_account @@ -2299,6 +2302,7 @@ def run( create_request_timeout=create_request_timeout, disable_retries=disable_retries, persistent_resource_id=persistent_resource_id, + scheduling_strategy=scheduling_strategy, ) @base.optional_sync() @@ -2316,6 +2320,7 @@ def _run( create_request_timeout: Optional[float] = None, disable_retries: bool = False, persistent_resource_id: Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, ) -> None: """Helper method to ensure network synchronization and to run the configured CustomJob. @@ -2382,6 +2387,8 @@ def _run( on-demand short-live machines. The network, CMEK, and node pool configs on the job should be consistent with those on the PersistentResource, otherwise, the job will be rejected. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. """ self.submit( service_account=service_account, @@ -2395,6 +2402,7 @@ def _run( create_request_timeout=create_request_timeout, disable_retries=disable_retries, persistent_resource_id=persistent_resource_id, + scheduling_strategy=scheduling_strategy, ) self._block_until_complete() @@ -2413,6 +2421,7 @@ def submit( create_request_timeout: Optional[float] = None, disable_retries: bool = False, persistent_resource_id: Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, ) -> None: """Submit the configured CustomJob. @@ -2476,6 +2485,8 @@ def submit( on-demand short-live machines. 
The network, CMEK, and node pool configs on the job should be consistent with those on the PersistentResource, otherwise, the job will be rejected. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. Raises: ValueError: @@ -2498,12 +2509,18 @@ def submit( if network: self._gca_resource.job_spec.network = network - if timeout or restart_job_on_worker_restart or disable_retries: + if ( + timeout + or restart_job_on_worker_restart + or disable_retries + or scheduling_strategy + ): timeout = duration_pb2.Duration(seconds=timeout) if timeout else None self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling( timeout=timeout, restart_job_on_worker_restart=restart_job_on_worker_restart, disable_retries=disable_retries, + strategy=scheduling_strategy, ) if enable_web_access: @@ -2868,6 +2885,7 @@ def run( sync: bool = True, create_request_timeout: Optional[float] = None, disable_retries: bool = False, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, ) -> None: """Run this configured CustomJob. @@ -2916,6 +2934,8 @@ def run( Indicates if the job should retry for internal errors after the job starts running. If True, overrides `restart_job_on_worker_restart` to False. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. 
""" network = network or initializer.global_config.network service_account = service_account or initializer.global_config.service_account @@ -2930,6 +2950,7 @@ def run( sync=sync, create_request_timeout=create_request_timeout, disable_retries=disable_retries, + scheduling_strategy=scheduling_strategy, ) @base.optional_sync() @@ -2944,6 +2965,7 @@ def _run( sync: bool = True, create_request_timeout: Optional[float] = None, disable_retries: bool = False, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, ) -> None: """Helper method to ensure network synchronization and to run the configured CustomJob. @@ -2990,6 +3012,8 @@ def _run( Indicates if the job should retry for internal errors after the job starts running. If True, overrides `restart_job_on_worker_restart` to False. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. """ if service_account: self._gca_resource.trial_job_spec.service_account = service_account @@ -2997,13 +3021,19 @@ def _run( if network: self._gca_resource.trial_job_spec.network = network - if timeout or restart_job_on_worker_restart or disable_retries: + if ( + timeout + or restart_job_on_worker_restart + or disable_retries + or scheduling_strategy + ): duration = duration_pb2.Duration(seconds=timeout) if timeout else None self._gca_resource.trial_job_spec.scheduling = ( gca_custom_job_compat.Scheduling( timeout=duration, restart_job_on_worker_restart=restart_job_on_worker_restart, disable_retries=disable_retries, + strategy=scheduling_strategy, ) ) diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index fd2ff225cf..05cfea1add 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -247,6 +247,10 @@ def create( autoscaling_target_accelerator_duty_cycle: Optional[int] = None, sync=True, create_request_timeout: Optional[float] = None, + reservation_affinity_type: Optional[str] = 
None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, + spot: bool = False, ) -> "DeploymentResourcePool": """Creates a new DeploymentResourcePool. @@ -305,6 +309,20 @@ def create( when the Future has completed. create_request_timeout (float): Optional. The create request timeout in seconds. + reservation_affinity_type (str): + Optional. The type of reservation affinity. + One of NO_RESERVATION, ANY_RESERVATION, SPECIFIC_RESERVATION, + SPECIFIC_THEN_ANY_RESERVATION, SPECIFIC_THEN_NO_RESERVATION + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. + Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' + spot (bool): + Optional. Whether to schedule the deployment workload on spot VMs. 
Returns: DeploymentResourcePool @@ -327,8 +345,12 @@ def create( max_replica_count=max_replica_count, accelerator_type=accelerator_type, accelerator_count=accelerator_count, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization, autoscaling_target_accelerator_duty_cycle=autoscaling_target_accelerator_duty_cycle, + spot=spot, sync=sync, create_request_timeout=create_request_timeout, ) @@ -348,8 +370,12 @@ def _create( max_replica_count: int = 1, accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, + reservation_affinity_type: Optional[str] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, autoscaling_target_cpu_utilization: Optional[int] = None, autoscaling_target_accelerator_duty_cycle: Optional[int] = None, + spot: bool = False, sync=True, create_request_timeout: Optional[float] = None, ) -> "DeploymentResourcePool": @@ -398,6 +424,18 @@ def _create( NVIDIA_TESLA_A100. accelerator_count (int): Optional. The number of accelerators attached to each replica. + reservation_affinity_type (str): + Optional. The type of reservation affinity. + One of NO_RESERVATION, ANY_RESERVATION, SPECIFIC_RESERVATION, + SPECIFIC_THEN_ANY_RESERVATION, SPECIFIC_THEN_NO_RESERVATION + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. 
+ Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' autoscaling_target_cpu_utilization (int): Optional. Target CPU utilization value for autoscaling. A default value of 60 will be used if not specified. @@ -406,6 +444,8 @@ def _create( autoscaling. Must also set accelerator_type and accelerator count if specified. A default value of 60 will be used if accelerators are requested and this is not specified. + spot (bool): + Optional. Whether to schedule the deployment workload on spot VMs. sync (bool): Optional. Whether to execute this method synchronously. If False, this method will be executed in a concurrent Future and @@ -425,6 +465,7 @@ def _create( dedicated_resources = gca_machine_resources_compat.DedicatedResources( min_replica_count=min_replica_count, max_replica_count=max_replica_count, + spot=spot, ) machine_spec = gca_machine_resources_compat.MachineSpec( @@ -458,6 +499,13 @@ def _create( [autoscaling_metric_spec] ) + if reservation_affinity_type: + machine_spec.reservation_affinity = utils.get_reservation_affinity( + reservation_affinity_type, + reservation_affinity_key, + reservation_affinity_values, + ) + dedicated_resources.machine_spec = machine_spec gapic_drp = gca_deployment_resource_pool_compat.DeploymentResourcePool( @@ -1226,6 +1274,10 @@ def deploy( enable_access_logging=False, disable_container_logging: bool = False, deployment_resource_pool: Optional[DeploymentResourcePool] = None, + reservation_affinity_type: Optional[str] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, + spot: bool = False, ) -> None: """Deploys a Model to the Endpoint. @@ -1319,6 +1371,20 @@ def deploy( are deployed to the same DeploymentResourcePool will be hosted in a shared model server. If provided, will override replica count arguments. + reservation_affinity_type (str): + Optional. The type of reservation affinity. 
+ One of NO_RESERVATION, ANY_RESERVATION, SPECIFIC_RESERVATION, + SPECIFIC_THEN_ANY_RESERVATION, SPECIFIC_THEN_NO_RESERVATION + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. + Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' + spot (bool): + Optional. Whether to schedule the deployment workload on spot VMs. """ self._sync_gca_resource_if_skipped() @@ -1348,6 +1414,9 @@ def deploy( accelerator_type=accelerator_type, accelerator_count=accelerator_count, tpu_topology=tpu_topology, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, service_account=service_account, explanation_spec=explanation_spec, metadata=metadata, @@ -1355,6 +1424,7 @@ def deploy( deploy_request_timeout=deploy_request_timeout, autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization, autoscaling_target_accelerator_duty_cycle=autoscaling_target_accelerator_duty_cycle, + spot=spot, enable_access_logging=enable_access_logging, disable_container_logging=disable_container_logging, deployment_resource_pool=deployment_resource_pool, @@ -1373,6 +1443,9 @@ def _deploy( accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, tpu_topology: Optional[str] = None, + reservation_affinity_type: Optional[str] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, service_account: Optional[str] = None, explanation_spec: Optional[aiplatform.explain.ExplanationSpec] = None, metadata: 
Optional[Sequence[Tuple[str, str]]] = (), @@ -1380,6 +1453,7 @@ def _deploy( deploy_request_timeout: Optional[float] = None, autoscaling_target_cpu_utilization: Optional[int] = None, autoscaling_target_accelerator_duty_cycle: Optional[int] = None, + spot: bool = False, enable_access_logging=False, disable_container_logging: bool = False, deployment_resource_pool: Optional[DeploymentResourcePool] = None, @@ -1435,6 +1509,18 @@ def _deploy( tpu_topology (str): Optional. The TPU topology to use for the DeployedModel. Required for CloudTPU multihost deployments. + reservation_affinity_type (str): + Optional. The type of reservation affinity. + One of NO_RESERVATION, ANY_RESERVATION, SPECIFIC_RESERVATION, + SPECIFIC_THEN_ANY_RESERVATION, SPECIFIC_THEN_NO_RESERVATION + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. + Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' service_account (str): The service account that the DeployedModel's container runs as. Specify the email address of the service account. If this service account is not @@ -1460,6 +1546,8 @@ def _deploy( Target Accelerator Duty Cycle. Must also set accelerator_type and accelerator_count if specified. A default value of 60 will be used if not specified. + spot (bool): + Optional. Whether to schedule the deployment workload on spot VMs. enable_access_logging (bool): Whether to enable endpoint access logging. Defaults to False. 
disable_container_logging (bool): @@ -1490,12 +1578,16 @@ def _deploy( accelerator_type=accelerator_type, accelerator_count=accelerator_count, tpu_topology=tpu_topology, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, service_account=service_account, explanation_spec=explanation_spec, metadata=metadata, deploy_request_timeout=deploy_request_timeout, autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization, autoscaling_target_accelerator_duty_cycle=autoscaling_target_accelerator_duty_cycle, + spot=spot, enable_access_logging=enable_access_logging, disable_container_logging=disable_container_logging, deployment_resource_pool=deployment_resource_pool, @@ -1522,12 +1614,16 @@ def _deploy_call( accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, tpu_topology: Optional[str] = None, + reservation_affinity_type: Optional[str] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, service_account: Optional[str] = None, explanation_spec: Optional[aiplatform.explain.ExplanationSpec] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), deploy_request_timeout: Optional[float] = None, autoscaling_target_cpu_utilization: Optional[int] = None, autoscaling_target_accelerator_duty_cycle: Optional[int] = None, + spot: bool = False, enable_access_logging=False, disable_container_logging: bool = False, deployment_resource_pool: Optional[DeploymentResourcePool] = None, @@ -1593,6 +1689,18 @@ def _deploy_call( tpu_topology (str): Optional. The TPU topology to use for the DeployedModel. Required for CloudTPU multihost deployments. + reservation_affinity_type (str): + Optional. The type of reservation affinity. 
+ One of NO_RESERVATION, ANY_RESERVATION, SPECIFIC_RESERVATION, + SPECIFIC_THEN_ANY_RESERVATION, SPECIFIC_THEN_NO_RESERVATION + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. + Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' service_account (str): The service account that the DeployedModel's container runs as. Specify the email address of the service account. If this service account is not @@ -1615,6 +1723,8 @@ def _deploy_call( Optional. Target Accelerator Duty Cycle. Must also set accelerator_type and accelerator_count if specified. A default value of 60 will be used if not specified. + spot (bool): + Optional. Whether to schedule the deployment workload on spot VMs. enable_access_logging (bool): Whether to enable endpoint access logging. Defaults to False. 
disable_container_logging (bool): @@ -1743,6 +1853,7 @@ def _deploy_call( dedicated_resources = gca_machine_resources_compat.DedicatedResources( min_replica_count=min_replica_count, max_replica_count=max_replica_count, + spot=spot, ) machine_spec = gca_machine_resources_compat.MachineSpec( @@ -1772,6 +1883,13 @@ def _deploy_call( [autoscaling_metric_spec] ) + if reservation_affinity_type: + machine_spec.reservation_affinity = utils.get_reservation_affinity( + reservation_affinity_type, + reservation_affinity_key, + reservation_affinity_values, + ) + if tpu_topology is not None: machine_spec.tpu_topology = tpu_topology @@ -3536,6 +3654,10 @@ def deploy( disable_container_logging: bool = False, traffic_percentage: Optional[int] = 0, traffic_split: Optional[Dict[str, int]] = None, + reservation_affinity_type: Optional[str] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, + spot: bool = False, ) -> None: """Deploys a Model to the PrivateEndpoint. @@ -3637,6 +3759,20 @@ def deploy( map must be empty if the Endpoint is to not accept any traffic at the moment. Key for model being deployed is "0". Should not be provided if traffic_percentage is provided. + reservation_affinity_type (str): + Optional. The type of reservation affinity. + One of NO_RESERVATION, ANY_RESERVATION, SPECIFIC_RESERVATION, + SPECIFIC_THEN_ANY_RESERVATION, SPECIFIC_THEN_NO_RESERVATION + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. + Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' + spot (bool): + Optional. 
Whether to schedule the deployment workload on spot VMs. """ if self.network: @@ -3672,10 +3808,14 @@ def deploy( accelerator_type=accelerator_type, accelerator_count=accelerator_count, tpu_topology=tpu_topology, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, service_account=service_account, explanation_spec=explanation_spec, metadata=metadata, sync=sync, + spot=spot, disable_container_logging=disable_container_logging, ) @@ -4719,6 +4859,10 @@ def deploy( PrivateEndpoint.PrivateServiceConnectConfig ] = None, deployment_resource_pool: Optional[DeploymentResourcePool] = None, + reservation_affinity_type: Optional[str] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, + spot: bool = False, ) -> Union[Endpoint, PrivateEndpoint]: """Deploys model to endpoint. Endpoint will be created if unspecified. @@ -4834,6 +4978,20 @@ def deploy( are deployed to the same DeploymentResourcePool will be hosted in a shared model server. If provided, will override replica count arguments. + reservation_affinity_type (str): + Optional. The type of reservation affinity. + One of NO_RESERVATION, ANY_RESERVATION, SPECIFIC_RESERVATION, + SPECIFIC_THEN_ANY_RESERVATION, SPECIFIC_THEN_NO_RESERVATION + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. + Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' + spot (bool): + Optional. Whether to schedule the deployment workload on spot VMs. 
Returns: endpoint (Union[Endpoint, PrivateEndpoint]): @@ -4884,6 +5042,9 @@ def deploy( accelerator_type=accelerator_type, accelerator_count=accelerator_count, tpu_topology=tpu_topology, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, service_account=service_account, explanation_spec=explanation_spec, metadata=metadata, @@ -4894,6 +5055,7 @@ def deploy( deploy_request_timeout=deploy_request_timeout, autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization, autoscaling_target_accelerator_duty_cycle=autoscaling_target_accelerator_duty_cycle, + spot=spot, enable_access_logging=enable_access_logging, disable_container_logging=disable_container_logging, private_service_connect_config=private_service_connect_config, @@ -4913,6 +5075,9 @@ def _deploy( accelerator_type: Optional[str] = None, accelerator_count: Optional[int] = None, tpu_topology: Optional[str] = None, + reservation_affinity_type: Optional[str] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, service_account: Optional[str] = None, explanation_spec: Optional[aiplatform.explain.ExplanationSpec] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), @@ -4922,6 +5087,7 @@ def _deploy( deploy_request_timeout: Optional[float] = None, autoscaling_target_cpu_utilization: Optional[int] = None, autoscaling_target_accelerator_duty_cycle: Optional[int] = None, + spot: bool = False, enable_access_logging=False, disable_container_logging: bool = False, private_service_connect_config: Optional[ @@ -4980,6 +5146,18 @@ def _deploy( tpu_topology (str): Optional. The TPU topology to use for the DeployedModel. Requireid for CloudTPU multihost deployments. + reservation_affinity_type (str): + Optional. The type of reservation affinity. 
+ One of NO_RESERVATION, ANY_RESERVATION, SPECIFIC_RESERVATION, + SPECIFIC_THEN_ANY_RESERVATION, SPECIFIC_THEN_NO_RESERVATION + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. + Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' service_account (str): The service account that the DeployedModel's container runs as. Specify the email address of the service account. If this service account is not @@ -5023,6 +5201,8 @@ def _deploy( Optional. Target Accelerator Duty Cycle. Must also set accelerator_type and accelerator_count if specified. A default value of 60 will be used if not specified. + spot (bool): + Optional. Whether to schedule the deployment workload on spot VMs. enable_access_logging (bool): Whether to enable endpoint access logging. Defaults to False. 
disable_container_logging (bool): @@ -5081,12 +5261,16 @@ def _deploy( accelerator_type=accelerator_type, accelerator_count=accelerator_count, tpu_topology=tpu_topology, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, service_account=service_account, explanation_spec=explanation_spec, metadata=metadata, deploy_request_timeout=deploy_request_timeout, autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization, autoscaling_target_accelerator_duty_cycle=autoscaling_target_accelerator_duty_cycle, + spot=spot, enable_access_logging=enable_access_logging, disable_container_logging=disable_container_logging, deployment_resource_pool=deployment_resource_pool, diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py index a0a39454ea..5fdd8fcac0 100644 --- a/google/cloud/aiplatform/pipeline_jobs.py +++ b/google/cloud/aiplatform/pipeline_jobs.py @@ -604,6 +604,9 @@ def create_schedule( pipeline_job_schedule = aiplatform.PipelineJobSchedule( pipeline_job=self, display_name=display_name, + credentials=self.credentials, + project=self.project, + location=self.location, ) pipeline_job_schedule.create( diff --git a/google/cloud/aiplatform/tensorboard/tensorboard_resource.py b/google/cloud/aiplatform/tensorboard/tensorboard_resource.py index 302df9614d..c8da8e6786 100644 --- a/google/cloud/aiplatform/tensorboard/tensorboard_resource.py +++ b/google/cloud/aiplatform/tensorboard/tensorboard_resource.py @@ -858,6 +858,14 @@ def list( return tensorboard_runs + def get_tensorboard_time_series_id(self, display_name: str) -> str: + """Returns the TensorboardTimeSeries with the given display name.""" + if display_name not in self._time_series_display_name_to_id_mapping: + self._sync_time_series_display_name_to_id_mapping() + + time_series_id = self._time_series_display_name_to_id_mapping.get(display_name) + return 
time_series_id + def write_tensorboard_scalar_data( self, time_series_data: Dict[str, float], diff --git a/google/cloud/aiplatform/tensorboard/uploader_utils.py b/google/cloud/aiplatform/tensorboard/uploader_utils.py index 9b08f2f8b8..f8b3fef88e 100644 --- a/google/cloud/aiplatform/tensorboard/uploader_utils.py +++ b/google/cloud/aiplatform/tensorboard/uploader_utils.py @@ -18,7 +18,6 @@ """Shared utils for tensorboard log uploader.""" import abc import contextlib -import json import logging import re import time @@ -116,7 +115,7 @@ def batch_create_runs( """ created_runs = [] for run_name in run_names: - tb_run = self._create_or_get_run_resource(run_name) + tb_run = self._get_or_create_run_resource(run_name) created_runs.append(tb_run) if run_name not in self._run_name_to_run_resource_name: self._run_name_to_run_resource_name[run_name] = tb_run.resource_name @@ -191,11 +190,11 @@ def get_run_resource_name(self, run_name: str) -> str: Resource name of the run. """ if run_name not in self._run_name_to_run_resource_name: - tb_run = self._create_or_get_run_resource(run_name) + tb_run = self._get_or_create_run_resource(run_name) self._run_name_to_run_resource_name[run_name] = tb_run.resource_name return self._run_name_to_run_resource_name[run_name] - def _create_or_get_run_resource( + def _get_or_create_run_resource( self, run_name: str ) -> tensorboard_run.TensorboardRun: """Creates new experiment run and tensorboard run resources. 
@@ -266,7 +265,7 @@ def get_time_series_resource_name( Resource name of the time series """ if (run_name, tag_name) not in self._run_tag_name_to_time_series_name: - time_series = self._create_or_get_time_series( + time_series = self._get_or_create_time_series( self.get_run_resource_name(run_name), tag_name, time_series_resource_creator, @@ -276,7 +275,7 @@ def get_time_series_resource_name( ] = time_series.name return self._run_tag_name_to_time_series_name[(run_name, tag_name)] - def _create_or_get_time_series( + def _get_or_create_time_series( self, run_resource_name: str, tag_name: str, @@ -306,45 +305,27 @@ def _create_or_get_time_series( ValueError: More than one time series with the resource name was found. """ - time_series = time_series_resource_creator() - time_series.display_name = tag_name - try: - time_series = self._api.create_tensorboard_time_series( - parent=run_resource_name, tensorboard_time_series=time_series + run_name = run_resource_name.split("/")[-1] + run = self._get_or_create_run_resource(run_name) + time_series_id = run.get_tensorboard_time_series_id(tag_name) + time_series = self._api.get_tensorboard_time_series( + request=tensorboard_service.GetTensorboardTimeSeriesRequest( + name=run_resource_name + "/timeSeries/" + time_series_id ) - except exceptions.InvalidArgument as e: - # If the time series display name already exists then retrieve it - if "already exist" in e.message: - list_of_time_series = self._api.list_tensorboard_time_series( - request=tensorboard_service.ListTensorboardTimeSeriesRequest( - parent=run_resource_name, - filter="display_name = {}".format(json.dumps(str(tag_name))), - ) + ) + if not time_series: + time_series = time_series_resource_creator() + time_series.display_name = tag_name + try: + time_series = self._api.create_tensorboard_time_series( + parent=run_resource_name, tensorboard_time_series=time_series ) - num = 0 - time_series = None - - for ts in list_of_time_series: - num += 1 - if num > 1: - break - 
time_series = ts - - if not time_series: - raise ExistingResourceNotFoundError( - "Could not find time series resource with display name: {}".format( - tag_name - ) - ) - - if num != 1: - raise ValueError( - "More than one time series resource found with display_name: {}".format( - tag_name - ) + except exceptions.InvalidArgument as e: + raise ValueError( + "Could not find time series resource with display name: {}".format( + tag_name ) - else: - raise + ) from e return time_series @@ -367,6 +348,45 @@ def __init__(self, run_resource_id: str, api: TensorboardServiceClient): str, tensorboard_time_series.TensorboardTimeSeries ] = {} + def _get_run_resource(self) -> tensorboard_run.TensorboardRun: + """Gets or creates new experiment run and tensorboard run resources. + + The experiment run will be associated with the tensorboard run resource. + This will link all tensorboard run data to the associated experiment. + + Returns: + tb_run (tensorboard_run.TensorboardRun): + The TensorboardRun given the run_name. + + Raises: + ValueError: + run_resource_id is invalid. + """ + m = re.match( + "projects/(.*)/locations/(.*)/tensorboards/(.*)/experiments/(.*)/runs/(.*)", + self._run_resource_id, + ) + project = m[1] + location = m[2] + tensorboard = m[3] + experiment = m[4] + run_name = m[5] + experiment_run = experiment_run_resource.ExperimentRun.get( + project=project, location=location, run_name=run_name + ) + if not experiment_run: + experiment_run = experiment_run_resource.ExperimentRun.create( + project=project, + location=location, + run_name=run_name, + experiment=experiment, + tensorboard=tensorboard, + state=gca_execution.Execution.State.RUNNING, + ) + tb_run_artifact = experiment_run._backing_tensorboard_run + tb_run = tb_run_artifact.resource + return tb_run + def get_or_create( self, tag_name: str, @@ -389,56 +409,34 @@ def get_or_create( A new or existing tensorboard_time_series.TensorboardTimeSeries. 
Raises: - exceptions.InvalidArgument: + ValueError: The tag_name or time_series_resource_creator is an invalid argument to create_tensorboard_time_series api call. - ExistingResourceNotFoundError: - Could not find the resource given the tag name. - ValueError: - More than one time series with the resource name was found. """ if tag_name in self._tag_to_time_series_proto: return self._tag_to_time_series_proto[tag_name] - time_series = time_series_resource_creator() - time_series.display_name = tag_name - try: - time_series = self._api.create_tensorboard_time_series( - parent=self._run_resource_id, tensorboard_time_series=time_series + tb_run = self._get_run_resource() + time_series_id = tb_run.get_tensorboard_time_series_id(tag_name) + time_series = self._api.get_tensorboard_time_series( + request=tensorboard_service.GetTensorboardTimeSeriesRequest( + name=self._run_resource_id + "/timeSeries/" + time_series_id ) - except exceptions.InvalidArgument as e: - # If the time series display name already exists then retrieve it - if "already exist" in e.message: - list_of_time_series = self._api.list_tensorboard_time_series( - request=tensorboard_service.ListTensorboardTimeSeriesRequest( - parent=self._run_resource_id, - filter="display_name = {}".format(json.dumps(str(tag_name))), - ) - ) - num = 0 - time_series = None - - for ts in list_of_time_series: - num += 1 - if num > 1: - break - time_series = ts - - if not time_series: - raise ExistingResourceNotFoundError( - "Could not find time series resource with display name: {}".format( - tag_name - ) - ) + ) + if not time_series: + time_series = time_series_resource_creator() + time_series.display_name = tag_name - if num != 1: - raise ValueError( - "More than one time series resource found with display_name: {}".format( - tag_name - ) + try: + time_series = self._api.create_tensorboard_time_series( + parent=self._run_resource_id, tensorboard_time_series=time_series + ) + except exceptions.InvalidArgument as e: + raise 
ValueError( + "Could not find time series resource with display name: {}".format( + tag_name ) - else: - raise + ) from e self._tag_to_time_series_proto[tag_name] = time_series return time_series diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index bc415d0ee7..0202ba37ae 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -17,7 +17,7 @@ import datetime import time -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Dict, List, Literal, Optional, Sequence, Tuple, Union from google.protobuf import json_format import abc @@ -44,6 +44,7 @@ from google.cloud.aiplatform.compat.types import ( training_pipeline as gca_training_pipeline, study as gca_study_compat, + custom_job as gca_custom_job_compat, ) from google.cloud.aiplatform.utils import _timestamped_gcs_dir @@ -1403,6 +1404,11 @@ def _prepare_and_validate_run( reduction_server_replica_count: int = 0, reduction_server_machine_type: Optional[str] = None, tpu_topology: Optional[str] = None, + reservation_affinity_type: Optional[ + Literal["NO_RESERVATION", "ANY_RESERVATION", "SPECIFIC_RESERVATION"] + ] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, ) -> Tuple[worker_spec_utils._DistributedTrainingSpec, Optional[gca_model.Model]]: """Create worker pool specs and managed model as well validating the run. @@ -1450,6 +1456,23 @@ def _prepare_and_validate_run( tpu_topology (str): Optional. Only required if the machine type is a TPU v5 version. + reservation_affinity_type (str): + Optional. The type of reservation affinity. One of: + * "NO_RESERVATION" : No reservation is used. + * "ANY_RESERVATION" : Any reservation that matches machine spec + can be used. + * "SPECIFIC_RESERVATION" : A specific reservation must be use + used. 
See reservation_affinity_key and + reservation_affinity_values for how to specify the reservation. + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use + `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. + Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' Returns: Worker pools specs and managed model for run. @@ -1489,6 +1512,9 @@ def _prepare_and_validate_run( reduction_server_replica_count=reduction_server_replica_count, reduction_server_machine_type=reduction_server_machine_type, tpu_topology=tpu_topology, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, ).pool_specs ) @@ -1525,6 +1551,7 @@ def _prepare_training_task_inputs_and_output_dir( tensorboard: Optional[str] = None, disable_retries: bool = False, persistent_resource_id: Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, ) -> Tuple[Dict, str]: """Prepares training task inputs and output directory for custom job. @@ -1582,6 +1609,8 @@ def _prepare_training_task_inputs_and_output_dir( on-demand short-live machines. The network, CMEK, and node pool configs on the job should be consistent with those on the PersistentResource, otherwise, the job will be rejected. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. Returns: Training task inputs and Output directory for custom job. 
@@ -1612,12 +1641,18 @@ def _prepare_training_task_inputs_and_output_dir( if persistent_resource_id: training_task_inputs["persistent_resource_id"] = persistent_resource_id - if timeout or restart_job_on_worker_restart or disable_retries: + if ( + timeout + or restart_job_on_worker_restart + or disable_retries + or scheduling_strategy + ): timeout = f"{timeout}s" if timeout else None scheduling = { "timeout": timeout, "restart_job_on_worker_restart": restart_job_on_worker_restart, "disable_retries": disable_retries, + "strategy": scheduling_strategy, } training_task_inputs["scheduling"] = scheduling @@ -3005,6 +3040,12 @@ def run( disable_retries: bool = False, persistent_resource_id: Optional[str] = None, tpu_topology: Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, + reservation_affinity_type: Optional[ + Literal["NO_RESERVATION", "ANY_RESERVATION", "SPECIFIC_RESERVATION"] + ] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, ) -> Optional[models.Model]: """Runs the custom training job. @@ -3360,6 +3401,25 @@ def run( details on the TPU topology, refer to https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology must be a supported value for the TPU machine type. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. + reservation_affinity_type (str): + Optional. The type of reservation affinity. One of: + * "NO_RESERVATION" : No reservation is used. + * "ANY_RESERVATION" : Any reservation that matches machine spec + can be used. + * "SPECIFIC_RESERVATION" : A specific reservation must be use + used. See reservation_affinity_key and + reservation_affinity_values for how to specify the reservation. + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. 
+ To target a SPECIFIC_RESERVATION by name, use + `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. + Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' Returns: The trained Vertex AI model resource or None if the training @@ -3380,6 +3440,9 @@ def run( reduction_server_replica_count=reduction_server_replica_count, reduction_server_machine_type=reduction_server_machine_type, tpu_topology=tpu_topology, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, ) # make and copy package @@ -3417,13 +3480,16 @@ def run( enable_web_access=enable_web_access, enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, - reduction_server_container_uri=reduction_server_container_uri - if reduction_server_replica_count > 0 - else None, + reduction_server_container_uri=( + reduction_server_container_uri + if reduction_server_replica_count > 0 + else None + ), sync=sync, create_request_timeout=create_request_timeout, disable_retries=disable_retries, persistent_resource_id=persistent_resource_id, + scheduling_strategy=scheduling_strategy, ) def submit( @@ -3477,6 +3543,12 @@ def submit( disable_retries: bool = False, persistent_resource_id: Optional[str] = None, tpu_topology: Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, + reservation_affinity_type: Optional[ + Literal["NO_RESERVATION", "ANY_RESERVATION", "SPECIFIC_RESERVATION"] + ] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, ) -> Optional[models.Model]: """Submits the custom training job without blocking until 
completion. @@ -3777,6 +3849,25 @@ def submit( details on the TPU topology, refer to https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology must be a supported value for the TPU machine type. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. + reservation_affinity_type (str): + Optional. The type of reservation affinity. One of: + * "NO_RESERVATION" : No reservation is used. + * "ANY_RESERVATION" : Any reservation that matches machine spec + can be used. + * "SPECIFIC_RESERVATION" : A specific reservation must be use + used. See reservation_affinity_key and + reservation_affinity_values for how to specify the reservation. + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use + `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. 
+ Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' Returns: model: The trained Vertex AI Model resource or None if training did not @@ -3796,6 +3887,9 @@ def submit( reduction_server_replica_count=reduction_server_replica_count, reduction_server_machine_type=reduction_server_machine_type, tpu_topology=tpu_topology, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, ) # make and copy package @@ -3833,14 +3927,17 @@ def submit( enable_web_access=enable_web_access, enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, - reduction_server_container_uri=reduction_server_container_uri - if reduction_server_replica_count > 0 - else None, + reduction_server_container_uri=( + reduction_server_container_uri + if reduction_server_replica_count > 0 + else None + ), sync=sync, create_request_timeout=create_request_timeout, block=False, disable_retries=disable_retries, persistent_resource_id=persistent_resource_id, + scheduling_strategy=scheduling_strategy, ) @base.optional_sync(construct_object_on_arg="managed_model") @@ -3888,6 +3985,7 @@ def _run( block: Optional[bool] = True, disable_retries: bool = False, persistent_resource_id: Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, ) -> Optional[models.Model]: """Packages local script and launches training_job. @@ -4084,6 +4182,8 @@ def _run( on-demand short-live machines. The network, CMEK, and node pool configs on the job should be consistent with those on the PersistentResource, otherwise, the job will be rejected. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. 
Returns: model: The trained Vertex AI Model resource or None if training did not @@ -4138,6 +4238,7 @@ def _run( tensorboard=tensorboard, disable_retries=disable_retries, persistent_resource_id=persistent_resource_id, + scheduling_strategy=scheduling_strategy, ) model = self._run_job( @@ -4462,6 +4563,12 @@ def run( disable_retries: bool = False, persistent_resource_id: Optional[str] = None, tpu_topology: Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, + reservation_affinity_type: Optional[ + Literal["NO_RESERVATION", "ANY_RESERVATION", "SPECIFIC_RESERVATION"] + ] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, ) -> Optional[models.Model]: """Runs the custom training job. @@ -4755,6 +4862,25 @@ def run( details on the TPU topology, refer to https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology must be a supported value for the TPU machine type. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. + reservation_affinity_type (str): + Optional. The type of reservation affinity. One of: + * "NO_RESERVATION" : No reservation is used. + * "ANY_RESERVATION" : Any reservation that matches machine spec + can be used. + * "SPECIFIC_RESERVATION" : A specific reservation must be use + used. See reservation_affinity_key and + reservation_affinity_values for how to specify the reservation. + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use + `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. 
+ Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' Returns: model: The trained Vertex AI Model resource or None if training did not @@ -4780,6 +4906,9 @@ def run( reduction_server_replica_count=reduction_server_replica_count, reduction_server_machine_type=reduction_server_machine_type, tpu_topology=tpu_topology, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, ) return self._run( @@ -4811,13 +4940,16 @@ def run( enable_web_access=enable_web_access, enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, - reduction_server_container_uri=reduction_server_container_uri - if reduction_server_replica_count > 0 - else None, + reduction_server_container_uri=( + reduction_server_container_uri + if reduction_server_replica_count > 0 + else None + ), sync=sync, create_request_timeout=create_request_timeout, disable_retries=disable_retries, persistent_resource_id=persistent_resource_id, + scheduling_strategy=scheduling_strategy, ) def submit( @@ -4871,6 +5003,12 @@ def submit( disable_retries: bool = False, persistent_resource_id: Optional[str] = None, tpu_topology: Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, + reservation_affinity_type: Optional[ + Literal["NO_RESERVATION", "ANY_RESERVATION", "SPECIFIC_RESERVATION"] + ] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, ) -> Optional[models.Model]: """Submits the custom training job without blocking until completion. @@ -5164,6 +5302,25 @@ def submit( details on the TPU topology, refer to https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology must be a supported value for the TPU machine type. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. 
+ reservation_affinity_type (str): + Optional. The type of reservation affinity. One of: + * "NO_RESERVATION" : No reservation is used. + * "ANY_RESERVATION" : Any reservation that matches machine spec + can be used. + * "SPECIFIC_RESERVATION" : A specific reservation must be use + used. See reservation_affinity_key and + reservation_affinity_values for how to specify the reservation. + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use + `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. + Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' Returns: model: The trained Vertex AI Model resource or None if training did not @@ -5188,6 +5345,9 @@ def submit( reduction_server_replica_count=reduction_server_replica_count, reduction_server_machine_type=reduction_server_machine_type, tpu_topology=tpu_topology, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, ) return self._run( @@ -5219,14 +5379,17 @@ def submit( enable_web_access=enable_web_access, enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, - reduction_server_container_uri=reduction_server_container_uri - if reduction_server_replica_count > 0 - else None, + reduction_server_container_uri=( + reduction_server_container_uri + if reduction_server_replica_count > 0 + else None + ), sync=sync, create_request_timeout=create_request_timeout, block=False, disable_retries=disable_retries, persistent_resource_id=persistent_resource_id, + scheduling_strategy=scheduling_strategy, ) 
@base.optional_sync(construct_object_on_arg="managed_model") @@ -5273,6 +5436,7 @@ def _run( block: Optional[bool] = True, disable_retries: bool = False, persistent_resource_id: Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, ) -> Optional[models.Model]: """Packages local script and launches training_job. Args: @@ -5465,6 +5629,8 @@ def _run( on-demand short-live machines. The network, CMEK, and node pool configs on the job should be consistent with those on the PersistentResource, otherwise, the job will be rejected. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -5513,6 +5679,7 @@ def _run( tensorboard=tensorboard, disable_retries=disable_retries, persistent_resource_id=persistent_resource_id, + scheduling_strategy=scheduling_strategy, ) model = self._run_job( @@ -7537,6 +7704,12 @@ def run( disable_retries: bool = False, persistent_resource_id: Optional[str] = None, tpu_topology: Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, + reservation_affinity_type: Optional[ + Literal["NO_RESERVATION", "ANY_RESERVATION", "SPECIFIC_RESERVATION"] + ] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, ) -> Optional[models.Model]: """Runs the custom training job. @@ -7831,6 +8004,25 @@ def run( details on the TPU topology, refer to https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology must be a supported value for the TPU machine type. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. + reservation_affinity_type (str): + Optional. The type of reservation affinity. One of: + * "NO_RESERVATION" : No reservation is used. 
+ * "ANY_RESERVATION" : Any reservation that matches machine spec + can be used. + * "SPECIFIC_RESERVATION" : A specific reservation must be use + used. See reservation_affinity_key and + reservation_affinity_values for how to specify the reservation. + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use + `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. + Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' Returns: model: The trained Vertex AI Model resource or None if training did not @@ -7851,6 +8043,9 @@ def run( reduction_server_replica_count=reduction_server_replica_count, reduction_server_machine_type=reduction_server_machine_type, tpu_topology=tpu_topology, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, ) return self._run( @@ -7882,13 +8077,16 @@ def run( enable_web_access=enable_web_access, enable_dashboard_access=enable_dashboard_access, tensorboard=tensorboard, - reduction_server_container_uri=reduction_server_container_uri - if reduction_server_replica_count > 0 - else None, + reduction_server_container_uri=( + reduction_server_container_uri + if reduction_server_replica_count > 0 + else None + ), sync=sync, create_request_timeout=create_request_timeout, disable_retries=disable_retries, persistent_resource_id=persistent_resource_id, + scheduling_strategy=scheduling_strategy, ) @base.optional_sync(construct_object_on_arg="managed_model") @@ -7934,6 +8132,7 @@ def _run( create_request_timeout: Optional[float] = None, disable_retries: bool = False, persistent_resource_id: 
Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, ) -> Optional[models.Model]: """Packages local script and launches training_job. @@ -8111,6 +8310,8 @@ def _run( on-demand short-live machines. The network, CMEK, and node pool configs on the job should be consistent with those on the PersistentResource, otherwise, the job will be rejected. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -8159,6 +8360,7 @@ def _run( tensorboard=tensorboard, disable_retries=disable_retries, persistent_resource_id=persistent_resource_id, + scheduling_strategy=scheduling_strategy, ) model = self._run_job( diff --git a/google/cloud/aiplatform/utils/__init__.py b/google/cloud/aiplatform/utils/__init__.py index ea269dea71..1a1ea3168c 100644 --- a/google/cloud/aiplatform/utils/__init__.py +++ b/google/cloud/aiplatform/utils/__init__.py @@ -21,7 +21,7 @@ import pathlib import logging import re -from typing import Any, Callable, Dict, Optional, Type, TypeVar, Tuple +from typing import Any, Callable, Dict, Optional, Type, TypeVar, Tuple, List import uuid from google.protobuf import timestamp_pb2 @@ -94,6 +94,7 @@ from google.cloud.aiplatform.compat.types import ( accelerator_type as gca_accelerator_type, + reservation_affinity_v1 as gca_reservation_affinity_v1, ) VertexAiServiceClient = TypeVar( @@ -393,6 +394,52 @@ def extract_project_and_location_from_parent( return parent_resources.groupdict() if parent_resources else {} +def get_reservation_affinity( + reservation_affinity_type: str, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, +) -> gca_reservation_affinity_v1.ReservationAffinity: + """Given reservation affinity type and/or key, values, return a ReservationAffinity object. 
+ + Args: + reservation_affinity_type (str): + Required. The type of reservation affinity. + One of NO_RESERVATION, ANY_RESERVATION, SPECIFIC_RESERVATION, + SPECIFIC_THEN_ANY_RESERVATION, SPECIFIC_THEN_NO_RESERVATION + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. + reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. + Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' + + Returns: + gca_reservation_affinity_v1.ReservationAffinity + + Raises: + ValueError: + If reservation_affinity_key and reservation_affinity_values are not + specified when reservation_affinity_type is SPECIFIC_RESERVATION. + """ + if reservation_affinity_type == "SPECIFIC_RESERVATION": + if not reservation_affinity_key or not reservation_affinity_values: + raise ValueError( + "reservation_affinity_key and reservation_affinity_values must be " + "specified when reservation_affinity_type is SPECIFIC_RESERVATION." 
+ ) + return gca_reservation_affinity_v1.ReservationAffinity( + reservation_affinity_type=reservation_affinity_type, + key=reservation_affinity_key, + values=reservation_affinity_values, + ) + else: + return gca_reservation_affinity_v1.ReservationAffinity( + reservation_affinity_type=reservation_affinity_type, + ) + + class ClientWithOverride: class WrappedClient: """Wrapper class for client that creates client at API invocation diff --git a/google/cloud/aiplatform/utils/worker_spec_utils.py b/google/cloud/aiplatform/utils/worker_spec_utils.py index 65d509fdf3..fc201340c9 100644 --- a/google/cloud/aiplatform/utils/worker_spec_utils.py +++ b/google/cloud/aiplatform/utils/worker_spec_utils.py @@ -14,7 +14,7 @@ # limitations under the License. # -from typing import NamedTuple, Optional, Dict, Union, List +from typing import NamedTuple, Optional, Dict, Union, List, Literal from google.cloud.aiplatform import utils from google.cloud.aiplatform.compat.types import ( @@ -45,6 +45,9 @@ class _WorkerPoolSpec(NamedTuple): accelerator_type='NVIDIA_TESLA_K80', boot_disk_type='pd-ssd', boot_disk_size_gb=100, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, ) Note that container and python package specs are not stored with this spec. @@ -57,6 +60,11 @@ class _WorkerPoolSpec(NamedTuple): boot_disk_type: str = "pd-ssd" boot_disk_size_gb: int = 100 tpu_topology: Optional[str] = None + reservation_affinity_type: Optional[ + Literal["NO_RESERVATION", "ANY_RESERVATION", "SPECIFIC_RESERVATION"] + ] = None + reservation_affinity_key: Optional[str] = None + reservation_affinity_values: Optional[List[str]] = None def _get_accelerator_type(self) -> Optional[str]: """Validates accelerator_type and returns the name of the accelerator. 
@@ -101,6 +109,18 @@ def spec_dict(self) -> Dict[str, Union[int, str, Dict[str, Union[int, str]]]]: if self.tpu_topology: spec["machine_spec"]["tpu_topology"] = self.tpu_topology + if self.reservation_affinity_type: + spec["machine_spec"]["reservation_affinity"] = { + "reservation_affinity_type": self.reservation_affinity_type, + } + if self.reservation_affinity_type == "SPECIFIC_RESERVATION": + spec["machine_spec"]["reservation_affinity"][ + "key" + ] = self.reservation_affinity_key + spec["machine_spec"]["reservation_affinity"][ + "values" + ] = self.reservation_affinity_values + return spec @property @@ -190,6 +210,11 @@ def chief_worker_pool( reduction_server_replica_count: int = 0, reduction_server_machine_type: str = None, tpu_topology: str = None, + reservation_affinity_type: Optional[ + Literal["NO_RESERVATION", "ANY_RESERVATION", "SPECIFIC_RESERVATION"] + ] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, ) -> "_DistributedTrainingSpec": """Parametrizes Config to support only chief with worker replicas. @@ -223,6 +248,23 @@ def chief_worker_pool( TPU topology for the TPU type. This field is required for the TPU v5 versions. This field is only passed to the chief replica as TPU jobs only allow 1 replica. + reservation_affinity_type (str): + Optional. The type of reservation affinity. One of: + * "NO_RESERVATION" : No reservation is used. + * "ANY_RESERVATION" : Any reservation that matches machine spec + can be used. + * "SPECIFIC_RESERVATION" : A specific reservation must be use + used. See reservation_affinity_key and + reservation_affinity_values for how to specify the reservation. + reservation_affinity_key (str): + Optional. Corresponds to the label key of a reservation resource. + To target a SPECIFIC_RESERVATION by name, use + `compute.googleapis.com/reservation-name` as the key + and specify the name of your reservation as its value. 
+ reservation_affinity_values (List[str]): + Optional. Corresponds to the label values of a reservation resource. + This must be the full resource name of the reservation. + Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' Returns: _DistributedTrainingSpec representing one chief and n workers all of @@ -240,6 +282,9 @@ def chief_worker_pool( boot_disk_type=boot_disk_type, boot_disk_size_gb=boot_disk_size_gb, tpu_topology=tpu_topology, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, ) worker_spec = _WorkerPoolSpec( @@ -249,11 +294,17 @@ def chief_worker_pool( accelerator_type=accelerator_type, boot_disk_type=boot_disk_type, boot_disk_size_gb=boot_disk_size_gb, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, ) reduction_server_spec = _WorkerPoolSpec( replica_count=reduction_server_replica_count, machine_type=reduction_server_machine_type, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, ) return cls( diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py index 8372b02f53..2b1ba55395 100644 --- a/google/cloud/aiplatform/version.py +++ b/google/cloud/aiplatform/version.py @@ -15,4 +15,4 @@ # limitations under the License. # -__version__ = "1.61.0" +__version__ = "1.62.0" diff --git a/google/cloud/aiplatform/vertex_ray/__init__.py b/google/cloud/aiplatform/vertex_ray/__init__.py index 8e58f0e7da..112ac77ab8 100644 --- a/google/cloud/aiplatform/vertex_ray/__init__.py +++ b/google/cloud/aiplatform/vertex_ray/__init__.py @@ -38,6 +38,7 @@ from google.cloud.aiplatform.vertex_ray.util.resources import ( Resources, NodeImages, + PscIConfig, ) from google.cloud.aiplatform.vertex_ray.dashboard_sdk import ( @@ -61,4 +62,5 @@ "update_ray_cluster", "Resources", "NodeImages", + "PscIConfig", ) diff --git a/google/cloud/aiplatform/vertex_ray/cluster_init.py b/google/cloud/aiplatform/vertex_ray/cluster_init.py index 1894847edc..fd278335e4 100644 --- a/google/cloud/aiplatform/vertex_ray/cluster_init.py +++ b/google/cloud/aiplatform/vertex_ray/cluster_init.py @@ -23,17 +23,20 @@ from google.cloud.aiplatform import initializer from google.cloud.aiplatform import utils from google.cloud.aiplatform.utils import resource_manager_utils -from google.cloud.aiplatform_v1.types import persistent_resource_service +from google.cloud.aiplatform_v1beta1.types import persistent_resource_service -from google.cloud.aiplatform_v1.types.persistent_resource import ( +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( PersistentResource, + RayLogsSpec, RaySpec, RayMetricSpec, ResourcePool, ResourceRuntimeSpec, ServiceAccountSpec, ) - +from google.cloud.aiplatform_v1beta1.types.service_networking import ( + PscInterfaceConfig, +) from google.cloud.aiplatform.vertex_ray.util import ( _gapic_utils, _validation_utils, @@ -56,6 +59,8 @@ def 
create_ray_cluster( worker_node_types: Optional[List[resources.Resources]] = [resources.Resources()], custom_images: Optional[resources.NodeImages] = None, enable_metrics_collection: Optional[bool] = True, + enable_logging: Optional[bool] = True, + psc_interface_config: Optional[resources.PscIConfig] = None, labels: Optional[Dict[str, str]] = None, ) -> str: """Create a ray cluster on the Vertex AI. @@ -119,6 +124,8 @@ def create_ray_cluster( head/worker_node_type(s). Note that configuring `Resources.custom_image` will override `custom_images` here. Allowlist only. enable_metrics_collection: Enable Ray metrics collection for visualization. + enable_logging: Enable exporting Ray logs to Cloud Logging. + psc_interface_config: PSC-I config. labels: The labels with user-defined metadata to organize Ray cluster. @@ -258,10 +265,17 @@ def create_ray_cluster( i += 1 resource_pools = [resource_pool_0] + worker_pools - disabled = not enable_metrics_collection - ray_metric_spec = RayMetricSpec(disabled=disabled) + + metrics_collection_disabled = not enable_metrics_collection + ray_metric_spec = RayMetricSpec(disabled=metrics_collection_disabled) + + logging_disabled = not enable_logging + ray_logs_spec = RayLogsSpec(disabled=logging_disabled) + ray_spec = RaySpec( - resource_pool_images=resource_pool_images, ray_metric_spec=ray_metric_spec + resource_pool_images=resource_pool_images, + ray_metric_spec=ray_metric_spec, + ray_logs_spec=ray_logs_spec, ) if service_account: service_account_spec = ServiceAccountSpec( @@ -274,11 +288,18 @@ def create_ray_cluster( ) else: resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec) + if psc_interface_config: + gapic_psc_interface_config = PscInterfaceConfig( + network_attachment=psc_interface_config.network_attachment, + ) + else: + gapic_psc_interface_config = None persistent_resource = PersistentResource( resource_pools=resource_pools, network=network, labels=labels, resource_runtime_spec=resource_runtime_spec, + 
psc_interface_config=gapic_psc_interface_config, ) location = initializer.global_config.location diff --git a/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py b/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py index bfedef2db3..87f6f824b1 100644 --- a/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py +++ b/google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py @@ -28,12 +28,13 @@ from google.cloud.aiplatform.vertex_ray.util import _validation_utils from google.cloud.aiplatform.vertex_ray.util.resources import ( Cluster, + PscIConfig, Resources, ) -from google.cloud.aiplatform_v1.types.persistent_resource import ( +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( PersistentResource, ) -from google.cloud.aiplatform_v1.types.persistent_resource_service import ( +from google.cloud.aiplatform_v1beta1.types.persistent_resource_service import ( GetPersistentResourceRequest, ) @@ -47,7 +48,7 @@ def create_persistent_resource_client(): return initializer.global_config.create_client( client_class=PersistentResourceClientWithOverride, appended_gapic_version="vertex_ray", - ) + ).select_version("v1beta1") def polling_delay(num_attempts: int, time_scale: float) -> datetime.timedelta: @@ -159,6 +160,10 @@ def persistent_resource_to_cluster( % persistent_resource.name, ) return + if persistent_resource.psc_interface_config: + cluster.psc_interface_config = PscIConfig( + network_attachment=persistent_resource.psc_interface_config.network_attachment + ) resource_pools = persistent_resource.resource_pools head_resource_pool = resource_pools[0] @@ -192,6 +197,12 @@ def persistent_resource_to_cluster( ray_version = None cluster.python_version = python_version cluster.ray_version = ray_version + cluster.ray_metric_enabled = not ( + persistent_resource.resource_runtime_spec.ray_spec.ray_metric_spec.disabled + ) + cluster.ray_logs_enabled = not ( + persistent_resource.resource_runtime_spec.ray_spec.ray_logs_spec.disabled + ) accelerator_type = 
head_resource_pool.machine_spec.accelerator_type if accelerator_type.value != 0: diff --git a/google/cloud/aiplatform/vertex_ray/util/resources.py b/google/cloud/aiplatform/vertex_ray/util/resources.py index 28f28f68fd..3e865f34e3 100644 --- a/google/cloud/aiplatform/vertex_ray/util/resources.py +++ b/google/cloud/aiplatform/vertex_ray/util/resources.py @@ -16,7 +16,7 @@ # import dataclasses from typing import Dict, List, Optional -from google.cloud.aiplatform_v1.types import PersistentResource +from google.cloud.aiplatform_v1beta1.types import PersistentResource @dataclasses.dataclass @@ -68,6 +68,27 @@ class NodeImages: worker: str = None +@dataclasses.dataclass +class PscIConfig: + """PSC-I config. + + Attributes: + network_attachment: Optional. The name or full name of the Compute Engine + `network attachment ` + to attach to the resource. It has a format: + ``projects/{project}/regions/{region}/networkAttachments/{networkAttachment}``. + Where {project} is a project number, as in ``12345``, and + {networkAttachment} is a network attachment name. To specify + this field, you must have already [created a network + attachment] + (https://cloud.google.com/vpc/docs/create-manage-network-attachments#create-network-attachments). + This field is only used for resources using PSC-I. Make sure you do not + specify the network here for VPC peering. + """ + + network_attachment: str = None + + @dataclasses.dataclass class Cluster: """Ray cluster (output only). 
@@ -111,6 +132,9 @@ class Cluster: head_node_type: Resources = None worker_node_types: List[Resources] = None dashboard_address: str = None + ray_metric_enabled: bool = True + ray_logs_enabled: bool = True + psc_interface_config: PscIConfig = None labels: Dict[str, str] = None diff --git a/google/cloud/aiplatform_v1/__init__.py b/google/cloud/aiplatform_v1/__init__.py index 723a7b5aca..033063ce65 100644 --- a/google/cloud/aiplatform_v1/__init__.py +++ b/google/cloud/aiplatform_v1/__init__.py @@ -28,6 +28,8 @@ ) from .services.endpoint_service import EndpointServiceClient from .services.endpoint_service import EndpointServiceAsyncClient +from .services.evaluation_service import EvaluationServiceClient +from .services.evaluation_service import EvaluationServiceAsyncClient from .services.feature_online_store_admin_service import ( FeatureOnlineStoreAdminServiceClient, ) @@ -200,6 +202,108 @@ from .types.evaluated_annotation import ErrorAnalysisAnnotation from .types.evaluated_annotation import EvaluatedAnnotation from .types.evaluated_annotation import EvaluatedAnnotationExplanation +from .types.evaluation_service import BleuInput +from .types.evaluation_service import BleuInstance +from .types.evaluation_service import BleuMetricValue +from .types.evaluation_service import BleuResults +from .types.evaluation_service import BleuSpec +from .types.evaluation_service import CoherenceInput +from .types.evaluation_service import CoherenceInstance +from .types.evaluation_service import CoherenceResult +from .types.evaluation_service import CoherenceSpec +from .types.evaluation_service import EvaluateInstancesRequest +from .types.evaluation_service import EvaluateInstancesResponse +from .types.evaluation_service import ExactMatchInput +from .types.evaluation_service import ExactMatchInstance +from .types.evaluation_service import ExactMatchMetricValue +from .types.evaluation_service import ExactMatchResults +from .types.evaluation_service import ExactMatchSpec +from 
.types.evaluation_service import FluencyInput +from .types.evaluation_service import FluencyInstance +from .types.evaluation_service import FluencyResult +from .types.evaluation_service import FluencySpec +from .types.evaluation_service import FulfillmentInput +from .types.evaluation_service import FulfillmentInstance +from .types.evaluation_service import FulfillmentResult +from .types.evaluation_service import FulfillmentSpec +from .types.evaluation_service import GroundednessInput +from .types.evaluation_service import GroundednessInstance +from .types.evaluation_service import GroundednessResult +from .types.evaluation_service import GroundednessSpec +from .types.evaluation_service import PairwiseMetricInput +from .types.evaluation_service import PairwiseMetricInstance +from .types.evaluation_service import PairwiseMetricResult +from .types.evaluation_service import PairwiseMetricSpec +from .types.evaluation_service import PairwiseQuestionAnsweringQualityInput +from .types.evaluation_service import PairwiseQuestionAnsweringQualityInstance +from .types.evaluation_service import PairwiseQuestionAnsweringQualityResult +from .types.evaluation_service import PairwiseQuestionAnsweringQualitySpec +from .types.evaluation_service import PairwiseSummarizationQualityInput +from .types.evaluation_service import PairwiseSummarizationQualityInstance +from .types.evaluation_service import PairwiseSummarizationQualityResult +from .types.evaluation_service import PairwiseSummarizationQualitySpec +from .types.evaluation_service import PointwiseMetricInput +from .types.evaluation_service import PointwiseMetricInstance +from .types.evaluation_service import PointwiseMetricResult +from .types.evaluation_service import PointwiseMetricSpec +from .types.evaluation_service import QuestionAnsweringCorrectnessInput +from .types.evaluation_service import QuestionAnsweringCorrectnessInstance +from .types.evaluation_service import QuestionAnsweringCorrectnessResult +from 
.types.evaluation_service import QuestionAnsweringCorrectnessSpec +from .types.evaluation_service import QuestionAnsweringHelpfulnessInput +from .types.evaluation_service import QuestionAnsweringHelpfulnessInstance +from .types.evaluation_service import QuestionAnsweringHelpfulnessResult +from .types.evaluation_service import QuestionAnsweringHelpfulnessSpec +from .types.evaluation_service import QuestionAnsweringQualityInput +from .types.evaluation_service import QuestionAnsweringQualityInstance +from .types.evaluation_service import QuestionAnsweringQualityResult +from .types.evaluation_service import QuestionAnsweringQualitySpec +from .types.evaluation_service import QuestionAnsweringRelevanceInput +from .types.evaluation_service import QuestionAnsweringRelevanceInstance +from .types.evaluation_service import QuestionAnsweringRelevanceResult +from .types.evaluation_service import QuestionAnsweringRelevanceSpec +from .types.evaluation_service import RougeInput +from .types.evaluation_service import RougeInstance +from .types.evaluation_service import RougeMetricValue +from .types.evaluation_service import RougeResults +from .types.evaluation_service import RougeSpec +from .types.evaluation_service import SafetyInput +from .types.evaluation_service import SafetyInstance +from .types.evaluation_service import SafetyResult +from .types.evaluation_service import SafetySpec +from .types.evaluation_service import SummarizationHelpfulnessInput +from .types.evaluation_service import SummarizationHelpfulnessInstance +from .types.evaluation_service import SummarizationHelpfulnessResult +from .types.evaluation_service import SummarizationHelpfulnessSpec +from .types.evaluation_service import SummarizationQualityInput +from .types.evaluation_service import SummarizationQualityInstance +from .types.evaluation_service import SummarizationQualityResult +from .types.evaluation_service import SummarizationQualitySpec +from .types.evaluation_service import 
SummarizationVerbosityInput +from .types.evaluation_service import SummarizationVerbosityInstance +from .types.evaluation_service import SummarizationVerbosityResult +from .types.evaluation_service import SummarizationVerbositySpec +from .types.evaluation_service import ToolCallValidInput +from .types.evaluation_service import ToolCallValidInstance +from .types.evaluation_service import ToolCallValidMetricValue +from .types.evaluation_service import ToolCallValidResults +from .types.evaluation_service import ToolCallValidSpec +from .types.evaluation_service import ToolNameMatchInput +from .types.evaluation_service import ToolNameMatchInstance +from .types.evaluation_service import ToolNameMatchMetricValue +from .types.evaluation_service import ToolNameMatchResults +from .types.evaluation_service import ToolNameMatchSpec +from .types.evaluation_service import ToolParameterKeyMatchInput +from .types.evaluation_service import ToolParameterKeyMatchInstance +from .types.evaluation_service import ToolParameterKeyMatchMetricValue +from .types.evaluation_service import ToolParameterKeyMatchResults +from .types.evaluation_service import ToolParameterKeyMatchSpec +from .types.evaluation_service import ToolParameterKVMatchInput +from .types.evaluation_service import ToolParameterKVMatchInstance +from .types.evaluation_service import ToolParameterKVMatchMetricValue +from .types.evaluation_service import ToolParameterKVMatchResults +from .types.evaluation_service import ToolParameterKVMatchSpec +from .types.evaluation_service import PairwiseChoice from .types.event import Event from .types.execution import Execution from .types.explanation import Attribution @@ -671,6 +775,7 @@ from .types.prediction_service import StreamingRawPredictResponse from .types.prediction_service import StreamRawPredictRequest from .types.publisher_model import PublisherModel +from .types.reservation_affinity import ReservationAffinity from .types.saved_query import SavedQuery from .types.schedule 
import Schedule from .types.schedule_service import CreateScheduleRequest @@ -812,6 +917,7 @@ "DatasetServiceAsyncClient", "DeploymentResourcePoolServiceAsyncClient", "EndpointServiceAsyncClient", + "EvaluationServiceAsyncClient", "FeatureOnlineStoreAdminServiceAsyncClient", "FeatureOnlineStoreServiceAsyncClient", "FeatureRegistryServiceAsyncClient", @@ -881,6 +987,11 @@ "BatchReadTensorboardTimeSeriesDataResponse", "BigQueryDestination", "BigQuerySource", + "BleuInput", + "BleuInstance", + "BleuMetricValue", + "BleuResults", + "BleuSpec", "Blob", "BlurBaselineConfig", "BoolArray", @@ -898,6 +1009,10 @@ "CheckTrialEarlyStoppingStateResponse", "Citation", "CitationMetadata", + "CoherenceInput", + "CoherenceInstance", + "CoherenceResult", + "CoherenceSpec", "CompleteTrialRequest", "CompletionStats", "ComputeTokensRequest", @@ -1049,9 +1164,17 @@ "EntityType", "EnvVar", "ErrorAnalysisAnnotation", + "EvaluateInstancesRequest", + "EvaluateInstancesResponse", "EvaluatedAnnotation", "EvaluatedAnnotationExplanation", + "EvaluationServiceClient", "Event", + "ExactMatchInput", + "ExactMatchInstance", + "ExactMatchMetricValue", + "ExactMatchResults", + "ExactMatchSpec", "Examples", "ExamplesOverride", "ExamplesRestrictionsNamespace", @@ -1104,7 +1227,15 @@ "FilterSplit", "FindNeighborsRequest", "FindNeighborsResponse", + "FluencyInput", + "FluencyInstance", + "FluencyResult", + "FluencySpec", "FractionSplit", + "FulfillmentInput", + "FulfillmentInstance", + "FulfillmentResult", + "FulfillmentSpec", "FunctionCall", "FunctionCallingConfig", "FunctionDeclaration", @@ -1163,6 +1294,10 @@ "GetTrialRequest", "GetTuningJobRequest", "GoogleSearchRetrieval", + "GroundednessInput", + "GroundednessInstance", + "GroundednessResult", + "GroundednessSpec", "GroundingChunk", "GroundingMetadata", "GroundingSupport", @@ -1344,6 +1479,19 @@ "NotebookRuntimeTemplateRef", "NotebookRuntimeType", "NotebookServiceClient", + "PairwiseChoice", + "PairwiseMetricInput", + "PairwiseMetricInstance", + 
"PairwiseMetricResult", + "PairwiseMetricSpec", + "PairwiseQuestionAnsweringQualityInput", + "PairwiseQuestionAnsweringQualityInstance", + "PairwiseQuestionAnsweringQualityResult", + "PairwiseQuestionAnsweringQualitySpec", + "PairwiseSummarizationQualityInput", + "PairwiseSummarizationQualityInstance", + "PairwiseSummarizationQualityResult", + "PairwiseSummarizationQualitySpec", "Part", "PauseModelDeploymentMonitoringJobRequest", "PauseScheduleRequest", @@ -1358,6 +1506,10 @@ "PipelineTaskDetail", "PipelineTaskExecutorDetail", "PipelineTemplateMetadata", + "PointwiseMetricInput", + "PointwiseMetricInstance", + "PointwiseMetricResult", + "PointwiseMetricSpec", "Port", "PredefinedSplit", "PredictRequest", @@ -1387,6 +1539,22 @@ "QueryDeployedModelsRequest", "QueryDeployedModelsResponse", "QueryExecutionInputsAndOutputsRequest", + "QuestionAnsweringCorrectnessInput", + "QuestionAnsweringCorrectnessInstance", + "QuestionAnsweringCorrectnessResult", + "QuestionAnsweringCorrectnessSpec", + "QuestionAnsweringHelpfulnessInput", + "QuestionAnsweringHelpfulnessInstance", + "QuestionAnsweringHelpfulnessResult", + "QuestionAnsweringHelpfulnessSpec", + "QuestionAnsweringQualityInput", + "QuestionAnsweringQualityInstance", + "QuestionAnsweringQualityResult", + "QuestionAnsweringQualitySpec", + "QuestionAnsweringRelevanceInput", + "QuestionAnsweringRelevanceInstance", + "QuestionAnsweringRelevanceResult", + "QuestionAnsweringRelevanceSpec", "RawPredictRequest", "RayLogsSpec", "RayMetricSpec", @@ -1409,6 +1577,7 @@ "RemoveContextChildrenResponse", "RemoveDatapointsRequest", "RemoveDatapointsResponse", + "ReservationAffinity", "ResourcePool", "ResourceRuntime", "ResourceRuntimeSpec", @@ -1418,8 +1587,17 @@ "ResumeModelDeploymentMonitoringJobRequest", "ResumeScheduleRequest", "Retrieval", + "RougeInput", + "RougeInstance", + "RougeMetricValue", + "RougeResults", + "RougeSpec", + "SafetyInput", + "SafetyInstance", "SafetyRating", + "SafetyResult", "SafetySetting", + "SafetySpec", 
"SampleConfig", "SampledShapleyAttribution", "SamplingStrategy", @@ -1470,6 +1648,18 @@ "SuggestTrialsMetadata", "SuggestTrialsRequest", "SuggestTrialsResponse", + "SummarizationHelpfulnessInput", + "SummarizationHelpfulnessInstance", + "SummarizationHelpfulnessResult", + "SummarizationHelpfulnessSpec", + "SummarizationQualityInput", + "SummarizationQualityInstance", + "SummarizationQualityResult", + "SummarizationQualitySpec", + "SummarizationVerbosityInput", + "SummarizationVerbosityInstance", + "SummarizationVerbosityResult", + "SummarizationVerbositySpec", "SupervisedHyperParameters", "SupervisedTuningDataStats", "SupervisedTuningDatasetDistribution", @@ -1492,7 +1682,27 @@ "TimestampSplit", "TokensInfo", "Tool", + "ToolCallValidInput", + "ToolCallValidInstance", + "ToolCallValidMetricValue", + "ToolCallValidResults", + "ToolCallValidSpec", "ToolConfig", + "ToolNameMatchInput", + "ToolNameMatchInstance", + "ToolNameMatchMetricValue", + "ToolNameMatchResults", + "ToolNameMatchSpec", + "ToolParameterKVMatchInput", + "ToolParameterKVMatchInstance", + "ToolParameterKVMatchMetricValue", + "ToolParameterKVMatchResults", + "ToolParameterKVMatchSpec", + "ToolParameterKeyMatchInput", + "ToolParameterKeyMatchInstance", + "ToolParameterKeyMatchMetricValue", + "ToolParameterKeyMatchResults", + "ToolParameterKeyMatchSpec", "TrainingConfig", "TrainingPipeline", "Trial", diff --git a/google/cloud/aiplatform_v1/gapic_metadata.json b/google/cloud/aiplatform_v1/gapic_metadata.json index cb891e9b06..c404c1c7a8 100644 --- a/google/cloud/aiplatform_v1/gapic_metadata.json +++ b/google/cloud/aiplatform_v1/gapic_metadata.json @@ -557,6 +557,40 @@ } } }, + "EvaluationService": { + "clients": { + "grpc": { + "libraryClient": "EvaluationServiceClient", + "rpcs": { + "EvaluateInstances": { + "methods": [ + "evaluate_instances" + ] + } + } + }, + "grpc-async": { + "libraryClient": "EvaluationServiceAsyncClient", + "rpcs": { + "EvaluateInstances": { + "methods": [ + "evaluate_instances" + ] 
+ } + } + }, + "rest": { + "libraryClient": "EvaluationServiceClient", + "rpcs": { + "EvaluateInstances": { + "methods": [ + "evaluate_instances" + ] + } + } + } + } + }, "FeatureOnlineStoreAdminService": { "clients": { "grpc": { diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform_v1/gapic_version.py +++ b/google/cloud/aiplatform_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py index 886029a303..e34527d4c5 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py @@ -777,6 +777,8 @@ async def sample_list_datasets(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1799,6 +1801,8 @@ async def sample_list_dataset_versions(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2043,6 +2047,8 @@ async def sample_list_data_items(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2141,6 +2147,8 @@ async def sample_search_data_items(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2261,6 +2269,8 @@ async def sample_list_saved_queries(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2617,6 +2627,8 @@ async def sample_list_annotations(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, 
metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/client.py b/google/cloud/aiplatform_v1/services/dataset_service/client.py index 236ffda044..80600fe010 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/client.py @@ -1292,6 +1292,8 @@ def sample_list_datasets(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2290,6 +2292,8 @@ def sample_list_dataset_versions(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2528,6 +2532,8 @@ def sample_list_data_items(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2624,6 +2630,8 @@ def sample_search_data_items(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2741,6 +2749,8 @@ def sample_list_saved_queries(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3088,6 +3098,8 @@ def sample_list_annotations(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1/services/dataset_service/pagers.py index 393a97cae9..fd0eab5fa9 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import annotation from google.cloud.aiplatform_v1.types import data_item from google.cloud.aiplatform_v1.types import dataset @@ -56,6 +69,8 @@ def __init__( request: dataset_service.ListDatasetsRequest, response: dataset_service.ListDatasetsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -67,12 +82,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListDatasetsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListDatasetsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -83,7 +103,12 @@ def pages(self) -> Iterator[dataset_service.ListDatasetsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[dataset.Dataset]: @@ -118,6 +143,8 @@ def __init__( request: dataset_service.ListDatasetsRequest, response: dataset_service.ListDatasetsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -129,12 +156,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListDatasetsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListDatasetsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages(self) -> AsyncIterator[dataset_service.ListDatasetsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[dataset.Dataset]: @@ -184,6 +221,8 @@ def __init__( request: dataset_service.ListDatasetVersionsRequest, response: dataset_service.ListDatasetVersionsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -195,12 +234,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListDatasetVersionsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListDatasetVersionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -211,7 +255,12 @@ def pages(self) -> Iterator[dataset_service.ListDatasetVersionsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[dataset_version.DatasetVersion]: @@ -246,6 +295,8 @@ def __init__( request: dataset_service.ListDatasetVersionsRequest, response: dataset_service.ListDatasetVersionsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -257,12 +308,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListDatasetVersionsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListDatasetVersionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -273,7 +329,12 @@ async def pages(self) -> AsyncIterator[dataset_service.ListDatasetVersionsRespon yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[dataset_version.DatasetVersion]: @@ -312,6 +373,8 @@ def __init__( request: dataset_service.ListDataItemsRequest, response: dataset_service.ListDataItemsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -323,12 +386,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListDataItemsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListDataItemsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -339,7 +407,12 @@ def pages(self) -> Iterator[dataset_service.ListDataItemsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[data_item.DataItem]: @@ -374,6 +447,8 @@ def __init__( request: dataset_service.ListDataItemsRequest, response: dataset_service.ListDataItemsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -385,12 +460,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListDataItemsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListDataItemsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -401,7 +481,12 @@ async def pages(self) -> AsyncIterator[dataset_service.ListDataItemsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[data_item.DataItem]: @@ -440,6 +525,8 @@ def __init__( request: dataset_service.SearchDataItemsRequest, response: dataset_service.SearchDataItemsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -451,12 +538,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.SearchDataItemsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.SearchDataItemsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -467,7 +559,12 @@ def pages(self) -> Iterator[dataset_service.SearchDataItemsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[dataset_service.DataItemView]: @@ -502,6 +599,8 @@ def __init__( request: dataset_service.SearchDataItemsRequest, response: dataset_service.SearchDataItemsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -513,12 +612,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.SearchDataItemsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.SearchDataItemsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -529,7 +633,12 @@ async def pages(self) -> AsyncIterator[dataset_service.SearchDataItemsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[dataset_service.DataItemView]: @@ -568,6 +677,8 @@ def __init__( request: dataset_service.ListSavedQueriesRequest, response: dataset_service.ListSavedQueriesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -579,12 +690,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListSavedQueriesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListSavedQueriesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -595,7 +711,12 @@ def pages(self) -> Iterator[dataset_service.ListSavedQueriesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[saved_query.SavedQuery]: @@ -630,6 +751,8 @@ def __init__( request: dataset_service.ListSavedQueriesRequest, response: dataset_service.ListSavedQueriesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -641,12 +764,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListSavedQueriesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListSavedQueriesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -657,7 +785,12 @@ async def pages(self) -> AsyncIterator[dataset_service.ListSavedQueriesResponse] yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[saved_query.SavedQuery]: @@ -696,6 +829,8 @@ def __init__( request: dataset_service.ListAnnotationsRequest, response: dataset_service.ListAnnotationsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -707,12 +842,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListAnnotationsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListAnnotationsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -723,7 +863,12 @@ def pages(self) -> Iterator[dataset_service.ListAnnotationsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[annotation.Annotation]: @@ -758,6 +903,8 @@ def __init__( request: dataset_service.ListAnnotationsRequest, response: dataset_service.ListAnnotationsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -769,12 +916,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListAnnotationsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListAnnotationsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -785,7 +937,12 @@ async def pages(self) -> AsyncIterator[dataset_service.ListAnnotationsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[annotation.Annotation]: diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/rest.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/rest.py index e2be118dc1..3bfd085692 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/rest.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/rest.py @@ -2877,6 +2877,11 @@ def __call__( "uri": "/v1/{parent=projects/*/locations/*}/datasets", "body": "dataset", }, + { + "method": "post", + "uri": "/v1/datasets", + "body": "dataset", + }, ] request, metadata = self._interceptor.pre_create_dataset(request, metadata) pb_request = dataset_service.CreateDatasetRequest.pb(request) @@ -2971,6 +2976,11 @@ def __call__( "uri": "/v1/{parent=projects/*/locations/*/datasets/*}/datasetVersions", "body": "dataset_version", }, + { + "method": "post", + "uri": "/v1/{parent=datasets/*}/datasetVersions", + "body": "dataset_version", + }, ] request, metadata = self._interceptor.pre_create_dataset_version( request, metadata @@ -3066,6 +3076,10 @@ def __call__( "method": "delete", "uri": "/v1/{name=projects/*/locations/*/datasets/*}", }, + { + "method": "delete", + "uri": "/v1/{name=datasets/*}", + }, ] request, metadata = 
self._interceptor.pre_delete_dataset(request, metadata) pb_request = dataset_service.DeleteDatasetRequest.pb(request) @@ -3153,6 +3167,10 @@ def __call__( "method": "delete", "uri": "/v1/{name=projects/*/locations/*/datasets/*/datasetVersions/*}", }, + { + "method": "delete", + "uri": "/v1/{name=datasets/*/datasetVersions/*}", + }, ] request, metadata = self._interceptor.pre_delete_dataset_version( request, metadata @@ -3514,6 +3532,10 @@ def __call__( "method": "get", "uri": "/v1/{name=projects/*/locations/*/datasets/*}", }, + { + "method": "get", + "uri": "/v1/{name=datasets/*}", + }, ] request, metadata = self._interceptor.pre_get_dataset(request, metadata) pb_request = dataset_service.GetDatasetRequest.pb(request) @@ -3600,6 +3622,10 @@ def __call__( "method": "get", "uri": "/v1/{name=projects/*/locations/*/datasets/*/datasetVersions/*}", }, + { + "method": "get", + "uri": "/v1/{name=datasets/*/datasetVersions/*}", + }, ] request, metadata = self._interceptor.pre_get_dataset_version( request, metadata @@ -3962,6 +3988,10 @@ def __call__( "method": "get", "uri": "/v1/{parent=projects/*/locations/*}/datasets", }, + { + "method": "get", + "uri": "/v1/datasets", + }, ] request, metadata = self._interceptor.pre_list_datasets(request, metadata) pb_request = dataset_service.ListDatasetsRequest.pb(request) @@ -4050,6 +4080,10 @@ def __call__( "method": "get", "uri": "/v1/{parent=projects/*/locations/*/datasets/*}/datasetVersions", }, + { + "method": "get", + "uri": "/v1/{parent=datasets/*}/datasetVersions", + }, ] request, metadata = self._interceptor.pre_list_dataset_versions( request, metadata @@ -4231,6 +4265,10 @@ def __call__( "method": "get", "uri": "/v1/{name=projects/*/locations/*/datasets/*/datasetVersions/*}:restore", }, + { + "method": "get", + "uri": "/v1/{name=datasets/*/datasetVersions/*}:restore", + }, ] request, metadata = self._interceptor.pre_restore_dataset_version( request, metadata @@ -4412,6 +4450,11 @@ def __call__( "uri": 
"/v1/{dataset.name=projects/*/locations/*/datasets/*}", "body": "dataset", }, + { + "method": "patch", + "uri": "/v1/{dataset.name=datasets/*}", + "body": "dataset", + }, ] request, metadata = self._interceptor.pre_update_dataset(request, metadata) pb_request = dataset_service.UpdateDatasetRequest.pb(request) @@ -4507,6 +4550,11 @@ def __call__( "uri": "/v1/{dataset_version.name=projects/*/locations/*/datasets/*/datasetVersions/*}", "body": "dataset_version", }, + { + "method": "patch", + "uri": "/v1/{dataset_version.name=datasets/*/datasetVersions/*}", + "body": "dataset_version", + }, ] request, metadata = self._interceptor.pre_update_dataset_version( request, metadata diff --git a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/async_client.py b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/async_client.py index 106231dc60..c7080a0057 100644 --- a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/async_client.py @@ -96,6 +96,12 @@ class DeploymentResourcePoolServiceAsyncClient: parse_model_path = staticmethod( DeploymentResourcePoolServiceClient.parse_model_path ) + reservation_path = staticmethod( + DeploymentResourcePoolServiceClient.reservation_path + ) + parse_reservation_path = staticmethod( + DeploymentResourcePoolServiceClient.parse_reservation_path + ) common_billing_account_path = staticmethod( DeploymentResourcePoolServiceClient.common_billing_account_path ) @@ -708,6 +714,8 @@ async def sample_list_deployment_resource_pools(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1136,6 +1144,8 @@ async def sample_query_deployed_models(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/client.py 
b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/client.py index ae39a3fa3d..9f08d2a6bd 100644 --- a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/client.py +++ b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/client.py @@ -269,6 +269,28 @@ def parse_model_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def common_billing_account_path( billing_account: str, @@ -1162,6 +1184,8 @@ def sample_list_deployment_resource_pools(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1585,6 +1609,8 @@ def sample_query_deployed_models(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/pagers.py b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/pagers.py index 4843247d4d..84c391309c 100644 --- a/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/deployment_resource_pool_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import deployment_resource_pool from google.cloud.aiplatform_v1.types import deployment_resource_pool_service from google.cloud.aiplatform_v1.types import endpoint @@ -55,6 +68,8 @@ def __init__( request: deployment_resource_pool_service.ListDeploymentResourcePoolsRequest, response: deployment_resource_pool_service.ListDeploymentResourcePoolsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -66,6 +81,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListDeploymentResourcePoolsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -74,6 +92,8 @@ def __init__( deployment_resource_pool_service.ListDeploymentResourcePoolsRequest(request) ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -86,7 +106,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[deployment_resource_pool.DeploymentResourcePool]: @@ -126,6 +151,8 @@ def __init__( request: deployment_resource_pool_service.ListDeploymentResourcePoolsRequest, response: deployment_resource_pool_service.ListDeploymentResourcePoolsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -137,6 +164,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListDeploymentResourcePoolsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -145,6 +175,8 @@ def __init__( deployment_resource_pool_service.ListDeploymentResourcePoolsRequest(request) ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -159,7 +191,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__( @@ -202,6 +239,8 @@ def __init__( request: deployment_resource_pool_service.QueryDeployedModelsRequest, response: deployment_resource_pool_service.QueryDeployedModelsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -213,6 +252,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.QueryDeployedModelsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -221,6 +263,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -233,7 +277,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[endpoint.DeployedModel]: @@ -270,6 +319,8 @@ def __init__( request: deployment_resource_pool_service.QueryDeployedModelsRequest, response: deployment_resource_pool_service.QueryDeployedModelsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -281,6 +332,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.QueryDeployedModelsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -289,6 +343,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -301,7 +357,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[endpoint.DeployedModel]: diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py index cc6d477249..145b365329 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py @@ -95,6 +95,8 @@ class EndpointServiceAsyncClient: ) network_path = staticmethod(EndpointServiceClient.network_path) parse_network_path = staticmethod(EndpointServiceClient.parse_network_path) + reservation_path = staticmethod(EndpointServiceClient.reservation_path) + parse_reservation_path = staticmethod(EndpointServiceClient.parse_reservation_path) common_billing_account_path = staticmethod( EndpointServiceClient.common_billing_account_path ) @@ -669,6 +671,8 @@ async def sample_list_endpoints(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1/services/endpoint_service/client.py index 8b1a3173c7..648563edc4 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/client.py @@ -303,6 +303,28 @@ def parse_network_path(path: str) -> Dict[str, str]: ) return m.groupdict() if 
m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def common_billing_account_path( billing_account: str, @@ -1148,6 +1170,8 @@ def sample_list_endpoints(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py index a93cac9443..1b9c72c5f8 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import endpoint from google.cloud.aiplatform_v1.types import endpoint_service @@ -52,6 +65,8 @@ def __init__( request: endpoint_service.ListEndpointsRequest, response: endpoint_service.ListEndpointsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListEndpointsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = endpoint_service.ListEndpointsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[endpoint_service.ListEndpointsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[endpoint.Endpoint]: @@ -114,6 +139,8 @@ def __init__( request: endpoint_service.ListEndpointsRequest, response: endpoint_service.ListEndpointsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -125,12 +152,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListEndpointsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = endpoint_service.ListEndpointsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -141,7 +173,12 @@ async def pages(self) -> AsyncIterator[endpoint_service.ListEndpointsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[endpoint.Endpoint]: diff --git a/google/cloud/aiplatform_v1/services/evaluation_service/__init__.py b/google/cloud/aiplatform_v1/services/evaluation_service/__init__.py new file mode 100644 index 0000000000..b09b156cce --- /dev/null +++ b/google/cloud/aiplatform_v1/services/evaluation_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from .client import EvaluationServiceClient +from .async_client import EvaluationServiceAsyncClient + +__all__ = ( + "EvaluationServiceClient", + "EvaluationServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1/services/evaluation_service/async_client.py b/google/cloud/aiplatform_v1/services/evaluation_service/async_client.py new file mode 100644 index 0000000000..fc8b70f38e --- /dev/null +++ b/google/cloud/aiplatform_v1/services/evaluation_service/async_client.py @@ -0,0 +1,1078 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from collections import OrderedDict +import functools +import re +from typing import ( + Dict, + Callable, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from google.cloud.aiplatform_v1 import gapic_version as package_version + +from google.api_core.client_options import ClientOptions +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry_async as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + + +try: + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore + +from google.cloud.aiplatform_v1.types import evaluation_service +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from .transports.base import EvaluationServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import EvaluationServiceGrpcAsyncIOTransport +from .client import EvaluationServiceClient + + +class EvaluationServiceAsyncClient: + """Vertex AI Online Evaluation Service.""" + + _client: EvaluationServiceClient + + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. 
+ DEFAULT_ENDPOINT = EvaluationServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = EvaluationServiceClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = EvaluationServiceClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = EvaluationServiceClient._DEFAULT_UNIVERSE + + common_billing_account_path = staticmethod( + EvaluationServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + EvaluationServiceClient.parse_common_billing_account_path + ) + common_folder_path = staticmethod(EvaluationServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + EvaluationServiceClient.parse_common_folder_path + ) + common_organization_path = staticmethod( + EvaluationServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + EvaluationServiceClient.parse_common_organization_path + ) + common_project_path = staticmethod(EvaluationServiceClient.common_project_path) + parse_common_project_path = staticmethod( + EvaluationServiceClient.parse_common_project_path + ) + common_location_path = staticmethod(EvaluationServiceClient.common_location_path) + parse_common_location_path = staticmethod( + EvaluationServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + EvaluationServiceAsyncClient: The constructed client. + """ + return EvaluationServiceClient.from_service_account_info.__func__(EvaluationServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. 
+ + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + EvaluationServiceAsyncClient: The constructed client. + """ + return EvaluationServiceClient.from_service_account_file.__func__(EvaluationServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. 
+ """ + return EvaluationServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> EvaluationServiceTransport: + """Returns the transport used by the client instance. + + Returns: + EvaluationServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + + get_transport_class = functools.partial( + type(EvaluationServiceClient).get_transport_class, type(EvaluationServiceClient) + ) + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[ + str, + EvaluationServiceTransport, + Callable[..., EvaluationServiceTransport], + ] + ] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the evaluation service async client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,EvaluationServiceTransport,Callable[..., EvaluationServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the EvaluationServiceTransport constructor. 
+ If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. 
+ """ + self._client = EvaluationServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def evaluate_instances( + self, + request: Optional[ + Union[evaluation_service.EvaluateInstancesRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> evaluation_service.EvaluateInstancesResponse: + r"""Evaluates instances based on a given metric. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + async def sample_evaluate_instances(): + # Create a client + client = aiplatform_v1.EvaluationServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.EvaluateInstancesRequest( + location="location_value", + ) + + # Make the request + response = await client.evaluate_instances(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1.types.EvaluateInstancesRequest, dict]]): + The request object. Request message for + EvaluationService.EvaluateInstances. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.EvaluateInstancesResponse: + Response message for + EvaluationService.EvaluateInstances. 
+ + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, evaluation_service.EvaluateInstancesRequest): + request = evaluation_service.EvaluateInstancesRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.evaluate_instances + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("location", request.location),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. 
+ # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_operations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. 
+ if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_operation( + self, + request: Optional[operations_pb2.DeleteOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes a long-running operation. + + This method indicates that the client is no longer interested + in the operation result. It does not cancel the operation. + If the server doesn't support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.DeleteOperationRequest`): + The request object. Request message for + `DeleteOperation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + None + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. 
+ if isinstance(request, dict): + request = operations_pb2.DeleteOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def cancel_operation( + self, + request: Optional[operations_pb2.CancelOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Starts asynchronous cancellation on a long-running operation. + + The server makes a best effort to cancel the operation, but success + is not guaranteed. If the server doesn't support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.CancelOperationRequest`): + The request object. Request message for + `CancelOperation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + None + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. 
+ if isinstance(request, dict): + request = operations_pb2.CancelOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def wait_operation( + self, + request: Optional[operations_pb2.WaitOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Waits until the specified long-running operation is done or reaches at most + a specified timeout, returning the latest state. + + If the operation is already done, the latest state is immediately returned. + If the timeout specified is greater than the default HTTP/RPC timeout, the HTTP/RPC + timeout is used. If the server does not support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.WaitOperationRequest`): + The request object. Request message for + `WaitOperation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. 
+ # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.WaitOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.wait_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def set_iam_policy( + self, + request: Optional[iam_policy_pb2.SetIamPolicyRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + r"""Sets the IAM access control policy on the specified function. + + Replaces any existing policy. + + Args: + request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`): + The request object. Request message for `SetIamPolicy` + method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.policy_pb2.Policy: + Defines an Identity and Access Management (IAM) policy. + It is used to specify access control policies for Cloud + Platform resources. + A ``Policy`` is a collection of ``bindings``. A + ``binding`` binds one or more ``members`` to a single + ``role``. 
Members can be user accounts, service + accounts, Google groups, and domains (such as G Suite). + A ``role`` is a named list of permissions (defined by + IAM or configured by users). A ``binding`` can + optionally specify a ``condition``, which is a logic + expression that further constrains the role binding + based on attributes about the request and/or target + resource. + + **JSON Example** + + :: + + { + "bindings": [ + { + "role": "roles/resourcemanager.organizationAdmin", + "members": [ + "user:mike@example.com", + "group:admins@example.com", + "domain:google.com", + "serviceAccount:my-project-id@appspot.gserviceaccount.com" + ] + }, + { + "role": "roles/resourcemanager.organizationViewer", + "members": ["user:eve@example.com"], + "condition": { + "title": "expirable access", + "description": "Does not grant access after Sep 2020", + "expression": "request.time < + timestamp('2020-10-01T00:00:00.000Z')", + } + } + ] + } + + **YAML Example** + + :: + + bindings: + - members: + - user:mike@example.com + - group:admins@example.com + - domain:google.com + - serviceAccount:my-project-id@appspot.gserviceaccount.com + role: roles/resourcemanager.organizationAdmin + - members: + - user:eve@example.com + role: roles/resourcemanager.organizationViewer + condition: + title: expirable access + description: Does not grant access after Sep 2020 + expression: request.time < timestamp('2020-10-01T00:00:00.000Z') + + For a description of IAM and its features, see the `IAM + developer's + guide `__. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.SetIamPolicyRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.set_iam_policy, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_iam_policy( + self, + request: Optional[iam_policy_pb2.GetIamPolicyRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + r"""Gets the IAM access control policy for a function. + + Returns an empty policy if the function exists and does not have a + policy set. + + Args: + request (:class:`~.iam_policy_pb2.GetIamPolicyRequest`): + The request object. Request message for `GetIamPolicy` + method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if + any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.policy_pb2.Policy: + Defines an Identity and Access Management (IAM) policy. + It is used to specify access control policies for Cloud + Platform resources. + A ``Policy`` is a collection of ``bindings``. A + ``binding`` binds one or more ``members`` to a single + ``role``. Members can be user accounts, service + accounts, Google groups, and domains (such as G Suite). + A ``role`` is a named list of permissions (defined by + IAM or configured by users). 
A ``binding`` can + optionally specify a ``condition``, which is a logic + expression that further constrains the role binding + based on attributes about the request and/or target + resource. + + **JSON Example** + + :: + + { + "bindings": [ + { + "role": "roles/resourcemanager.organizationAdmin", + "members": [ + "user:mike@example.com", + "group:admins@example.com", + "domain:google.com", + "serviceAccount:my-project-id@appspot.gserviceaccount.com" + ] + }, + { + "role": "roles/resourcemanager.organizationViewer", + "members": ["user:eve@example.com"], + "condition": { + "title": "expirable access", + "description": "Does not grant access after Sep 2020", + "expression": "request.time < + timestamp('2020-10-01T00:00:00.000Z')", + } + } + ] + } + + **YAML Example** + + :: + + bindings: + - members: + - user:mike@example.com + - group:admins@example.com + - domain:google.com + - serviceAccount:my-project-id@appspot.gserviceaccount.com + role: roles/resourcemanager.organizationAdmin + - members: + - user:eve@example.com + role: roles/resourcemanager.organizationViewer + condition: + title: expirable access + description: Does not grant access after Sep 2020 + expression: request.time < timestamp('2020-10-01T00:00:00.000Z') + + For a description of IAM and its features, see the `IAM + developer's + guide `__. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.GetIamPolicyRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_iam_policy, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. 
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def test_iam_permissions( + self, + request: Optional[iam_policy_pb2.TestIamPermissionsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> iam_policy_pb2.TestIamPermissionsResponse: + r"""Tests the specified IAM permissions against the IAM access control + policy for a function. + + If the function does not exist, this will return an empty set + of permissions, not a NOT_FOUND error. + + Args: + request (:class:`~.iam_policy_pb2.TestIamPermissionsRequest`): + The request object. Request message for + `TestIamPermissions` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.iam_policy_pb2.TestIamPermissionsResponse: + Response message for ``TestIamPermissions`` method. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.TestIamPermissionsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.test_iam_permissions, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_location( + self, + request: Optional[locations_pb2.GetLocationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.Location: + r"""Gets information about a location. + + Args: + request (:class:`~.location_pb2.GetLocationRequest`): + The request object. Request message for + `GetLocation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.location_pb2.Location: + Location object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = locations_pb2.GetLocationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_location, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. 
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_locations( + self, + request: Optional[locations_pb2.ListLocationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.ListLocationsResponse: + r"""Lists information about the supported locations for this service. + + Args: + request (:class:`~.location_pb2.ListLocationsRequest`): + The request object. Request message for + `ListLocations` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.location_pb2.ListLocationsResponse: + Response message for ``ListLocations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = locations_pb2.ListLocationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_locations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. 
+ self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "EvaluationServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("EvaluationServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/evaluation_service/client.py b/google/cloud/aiplatform_v1/services/evaluation_service/client.py new file mode 100644 index 0000000000..658041afae --- /dev/null +++ b/google/cloud/aiplatform_v1/services/evaluation_service/client.py @@ -0,0 +1,1483 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from collections import OrderedDict +import os +import re +from typing import ( + Dict, + Callable, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) +import warnings + +from google.cloud.aiplatform_v1 import gapic_version as package_version + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +from google.cloud.aiplatform_v1.types import evaluation_service +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from .transports.base import EvaluationServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import EvaluationServiceGrpcTransport +from .transports.grpc_asyncio import EvaluationServiceGrpcAsyncIOTransport +from .transports.rest import EvaluationServiceRestTransport + + +class EvaluationServiceClientMeta(type): + """Metaclass for the EvaluationService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. 
+ """ + + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[EvaluationServiceTransport]] + _transport_registry["grpc"] = EvaluationServiceGrpcTransport + _transport_registry["grpc_asyncio"] = EvaluationServiceGrpcAsyncIOTransport + _transport_registry["rest"] = EvaluationServiceRestTransport + + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[EvaluationServiceTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class EvaluationServiceClient(metaclass=EvaluationServiceClientMeta): + """Vertex AI Online Evaluation Service.""" + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. 
+ DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + _DEFAULT_ENDPOINT_TEMPLATE = "aiplatform.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + EvaluationServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + EvaluationServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> EvaluationServiceTransport: + """Returns the transport used by the client instance. + + Returns: + EvaluationServiceTransport: The transport used by the client + instance. 
+ """ + return self._transport + + @staticmethod + def common_billing_account_path( + billing_account: str, + ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path( + folder: str, + ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format( + folder=folder, + ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path( + organization: str, + ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format( + organization=organization, + ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path( + project: str, + ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format( + project=project, + ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path( + project: str, + location: str, + ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + 
) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Deprecated. Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. 
Use the api_endpoint property instead.", + DeprecationWarning, + ) + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. 
+ """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + @staticmethod + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + @staticmethod + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". 
+ + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = EvaluationServiceClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = EvaluationServiceClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = EvaluationServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = EvaluationServiceClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + @staticmethod + def _compare_universes( + client_universe: str, credentials: ga_credentials.Credentials + ) -> bool: + """Returns True iff the universe domains used by the client and credentials match. + + Args: + client_universe (str): The universe domain configured via the client options. + credentials (ga_credentials.Credentials): The credentials being used in the client. 
+ + Returns: + bool: True iff client_universe matches the universe in credentials. + + Raises: + ValueError: when client_universe does not match the universe in credentials. + """ + + default_universe = EvaluationServiceClient._DEFAULT_UNIVERSE + credentials_universe = getattr(credentials, "universe_domain", default_universe) + + if client_universe != credentials_universe: + raise ValueError( + "The configured universe domain " + f"({client_universe}) does not match the universe domain " + f"found in the credentials ({credentials_universe}). " + "If you haven't configured the universe domain explicitly, " + f"`{default_universe}` is the default." + ) + return True + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + self._is_universe_domain_valid = ( + self._is_universe_domain_valid + or EvaluationServiceClient._compare_universes( + self.universe_domain, self.transport._credentials + ) + ) + return self._is_universe_domain_valid + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. 
        """
        return self._universe_domain

    def __init__(
        self,
        *,
        credentials: Optional[ga_credentials.Credentials] = None,
        transport: Optional[
            Union[
                str,
                EvaluationServiceTransport,
                Callable[..., EvaluationServiceTransport],
            ]
        ] = None,
        client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None,
        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
    ) -> None:
        """Instantiates the evaluation service client.

        Args:
            credentials (Optional[google.auth.credentials.Credentials]): The
                authorization credentials to attach to requests. These
                credentials identify the application to the service; if none
                are specified, the client will attempt to ascertain the
                credentials from the environment.
            transport (Optional[Union[str,EvaluationServiceTransport,Callable[..., EvaluationServiceTransport]]]):
                The transport to use, or a Callable that constructs and returns a new transport.
                If a Callable is given, it will be called with the same set of initialization
                arguments as used in the EvaluationServiceTransport constructor.
                If set to None, a transport is chosen automatically.
            client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]):
                Custom options for the client.

                1. The ``api_endpoint`` property can be used to override the
                default endpoint provided by the client when ``transport`` is
                not explicitly provided. Only if this property is not set and
                ``transport`` was not explicitly provided, the endpoint is
                determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment
                variable, which have one of the following values:
                "always" (always use the default mTLS endpoint), "never" (always
                use the default regular endpoint) and "auto" (auto-switch to the
                default mTLS endpoint if client certificate is present; this is
                the default value).

                2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
                is "true", then the ``client_cert_source`` property can be used
                to provide a client certificate for mTLS transport. If
                not provided, the default SSL client certificate will be used if
                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
                set, no client certificate will be used.

                3. The ``universe_domain`` property can be used to override the
                default "googleapis.com" universe. Note that the ``api_endpoint``
                property still takes precedence; and ``universe_domain`` is
                currently not supported for mTLS.

            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
                The client info used to send a user-agent string along with
                API requests. If ``None``, then default info will be used.
                Generally, you only need to set this if you're developing
                your own client library.

        Raises:
            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
                creation failed for any reason.
        """
        # Normalize client_options: a plain dict or None are both accepted
        # and converted to a ClientOptions instance.
        self._client_options = client_options
        if isinstance(self._client_options, dict):
            self._client_options = client_options_lib.from_dict(self._client_options)
        if self._client_options is None:
            self._client_options = client_options_lib.ClientOptions()
        self._client_options = cast(
            client_options_lib.ClientOptions, self._client_options
        )

        universe_domain_opt = getattr(self._client_options, "universe_domain", None)

        # Resolve env-derived settings first; endpoint/cert decisions below
        # depend on them.
        (
            self._use_client_cert,
            self._use_mtls_endpoint,
            self._universe_domain_env,
        ) = EvaluationServiceClient._read_environment_variables()
        self._client_cert_source = EvaluationServiceClient._get_client_cert_source(
            self._client_options.client_cert_source, self._use_client_cert
        )
        self._universe_domain = EvaluationServiceClient._get_universe_domain(
            universe_domain_opt, self._universe_domain_env
        )
        self._api_endpoint = None  # updated below, depending on `transport`

        # Initialize the universe domain validation.
        self._is_universe_domain_valid = False

        api_key_value = getattr(self._client_options, "api_key", None)
        if api_key_value and credentials:
            raise ValueError(
                "client_options.api_key and credentials are mutually exclusive"
            )

        # Save or instantiate the transport.
        # Ordinarily, we provide the transport, but allowing a custom transport
        # instance provides an extensibility point for unusual situations.
        transport_provided = isinstance(transport, EvaluationServiceTransport)
        if transport_provided:
            # transport is a EvaluationServiceTransport instance.
            if credentials or self._client_options.credentials_file or api_key_value:
                raise ValueError(
                    "When providing a transport instance, "
                    "provide its credentials directly."
                )
            if self._client_options.scopes:
                raise ValueError(
                    "When providing a transport instance, provide its scopes "
                    "directly."
                )
            self._transport = cast(EvaluationServiceTransport, transport)
            self._api_endpoint = self._transport.host

        # A pre-built transport fixes the endpoint; otherwise derive it from
        # options, cert source, universe and the mTLS policy.
        self._api_endpoint = (
            self._api_endpoint
            or EvaluationServiceClient._get_api_endpoint(
                self._client_options.api_endpoint,
                self._client_cert_source,
                self._universe_domain,
                self._use_mtls_endpoint,
            )
        )

        if not transport_provided:
            import google.auth._default  # type: ignore

            if api_key_value and hasattr(
                google.auth._default, "get_api_key_credentials"
            ):
                credentials = google.auth._default.get_api_key_credentials(
                    api_key_value
                )

            # `transport` may be a registry label (str/None) or a factory
            # callable; both paths end in a callable that builds a transport.
            transport_init: Union[
                Type[EvaluationServiceTransport],
                Callable[..., EvaluationServiceTransport],
            ] = (
                type(self).get_transport_class(transport)
                if isinstance(transport, str) or transport is None
                else cast(Callable[..., EvaluationServiceTransport], transport)
            )
            # initialize with the provided callable or the passed in class
            self._transport = transport_init(
                credentials=credentials,
                credentials_file=self._client_options.credentials_file,
                host=self._api_endpoint,
                scopes=self._client_options.scopes,
client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=self._client_options.api_audience, + ) + + def evaluate_instances( + self, + request: Optional[ + Union[evaluation_service.EvaluateInstancesRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> evaluation_service.EvaluateInstancesResponse: + r"""Evaluates instances based on a given metric. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + def sample_evaluate_instances(): + # Create a client + client = aiplatform_v1.EvaluationServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.EvaluateInstancesRequest( + location="location_value", + ) + + # Make the request + response = client.evaluate_instances(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1.types.EvaluateInstancesRequest, dict]): + The request object. Request message for + EvaluationService.EvaluateInstances. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.EvaluateInstancesResponse: + Response message for + EvaluationService.EvaluateInstances. 
+ + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, evaluation_service.EvaluateInstancesRequest): + request = evaluation_service.EvaluateInstancesRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.evaluate_instances] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("location", request.location),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "EvaluationServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.list_operations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. 
+ # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.get_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_operation( + self, + request: Optional[operations_pb2.DeleteOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes a long-running operation. + + This method indicates that the client is no longer interested + in the operation result. It does not cancel the operation. + If the server doesn't support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.DeleteOperationRequest`): + The request object. Request message for + `DeleteOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + None + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. 
+ if isinstance(request, dict): + request = operations_pb2.DeleteOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.delete_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def cancel_operation( + self, + request: Optional[operations_pb2.CancelOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Starts asynchronous cancellation on a long-running operation. + + The server makes a best effort to cancel the operation, but success + is not guaranteed. If the server doesn't support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.CancelOperationRequest`): + The request object. Request message for + `CancelOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + None + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.CancelOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method.wrap_method( + self._transport.cancel_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def wait_operation( + self, + request: Optional[operations_pb2.WaitOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + r"""Waits until the specified long-running operation is done or reaches at most + a specified timeout, returning the latest state. + + If the operation is already done, the latest state is immediately returned. + If the timeout specified is greater than the default HTTP/RPC timeout, the HTTP/RPC + timeout is used. If the server does not support this method, it returns + `google.rpc.Code.UNIMPLEMENTED`. + + Args: + request (:class:`~.operations_pb2.WaitOperationRequest`): + The request object. Request message for + `WaitOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. 
+ if isinstance(request, dict): + request = operations_pb2.WaitOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.wait_operation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def set_iam_policy( + self, + request: Optional[iam_policy_pb2.SetIamPolicyRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + r"""Sets the IAM access control policy on the specified function. + + Replaces any existing policy. + + Args: + request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`): + The request object. Request message for `SetIamPolicy` + method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.policy_pb2.Policy: + Defines an Identity and Access Management (IAM) policy. + It is used to specify access control policies for Cloud + Platform resources. + A ``Policy`` is a collection of ``bindings``. A + ``binding`` binds one or more ``members`` to a single + ``role``. Members can be user accounts, service + accounts, Google groups, and domains (such as G Suite). + A ``role`` is a named list of permissions (defined by + IAM or configured by users). 
A ``binding`` can + optionally specify a ``condition``, which is a logic + expression that further constrains the role binding + based on attributes about the request and/or target + resource. + + **JSON Example** + + :: + + { + "bindings": [ + { + "role": "roles/resourcemanager.organizationAdmin", + "members": [ + "user:mike@example.com", + "group:admins@example.com", + "domain:google.com", + "serviceAccount:my-project-id@appspot.gserviceaccount.com" + ] + }, + { + "role": "roles/resourcemanager.organizationViewer", + "members": ["user:eve@example.com"], + "condition": { + "title": "expirable access", + "description": "Does not grant access after Sep 2020", + "expression": "request.time < + timestamp('2020-10-01T00:00:00.000Z')", + } + } + ] + } + + **YAML Example** + + :: + + bindings: + - members: + - user:mike@example.com + - group:admins@example.com + - domain:google.com + - serviceAccount:my-project-id@appspot.gserviceaccount.com + role: roles/resourcemanager.organizationAdmin + - members: + - user:eve@example.com + role: roles/resourcemanager.organizationViewer + condition: + title: expirable access + description: Does not grant access after Sep 2020 + expression: request.time < timestamp('2020-10-01T00:00:00.000Z') + + For a description of IAM and its features, see the `IAM + developer's + guide `__. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.SetIamPolicyRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.set_iam_policy, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. 
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_iam_policy( + self, + request: Optional[iam_policy_pb2.GetIamPolicyRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + r"""Gets the IAM access control policy for a function. + + Returns an empty policy if the function exists and does not have a + policy set. + + Args: + request (:class:`~.iam_policy_pb2.GetIamPolicyRequest`): + The request object. Request message for `GetIamPolicy` + method. + retry (google.api_core.retry.Retry): Designation of what errors, if + any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.policy_pb2.Policy: + Defines an Identity and Access Management (IAM) policy. + It is used to specify access control policies for Cloud + Platform resources. + A ``Policy`` is a collection of ``bindings``. A + ``binding`` binds one or more ``members`` to a single + ``role``. Members can be user accounts, service + accounts, Google groups, and domains (such as G Suite). + A ``role`` is a named list of permissions (defined by + IAM or configured by users). A ``binding`` can + optionally specify a ``condition``, which is a logic + expression that further constrains the role binding + based on attributes about the request and/or target + resource. 
+ + **JSON Example** + + :: + + { + "bindings": [ + { + "role": "roles/resourcemanager.organizationAdmin", + "members": [ + "user:mike@example.com", + "group:admins@example.com", + "domain:google.com", + "serviceAccount:my-project-id@appspot.gserviceaccount.com" + ] + }, + { + "role": "roles/resourcemanager.organizationViewer", + "members": ["user:eve@example.com"], + "condition": { + "title": "expirable access", + "description": "Does not grant access after Sep 2020", + "expression": "request.time < + timestamp('2020-10-01T00:00:00.000Z')", + } + } + ] + } + + **YAML Example** + + :: + + bindings: + - members: + - user:mike@example.com + - group:admins@example.com + - domain:google.com + - serviceAccount:my-project-id@appspot.gserviceaccount.com + role: roles/resourcemanager.organizationAdmin + - members: + - user:eve@example.com + role: roles/resourcemanager.organizationViewer + condition: + title: expirable access + description: Does not grant access after Sep 2020 + expression: request.time < timestamp('2020-10-01T00:00:00.000Z') + + For a description of IAM and its features, see the `IAM + developer's + guide `__. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.GetIamPolicyRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.get_iam_policy, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. 
+ response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def test_iam_permissions( + self, + request: Optional[iam_policy_pb2.TestIamPermissionsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> iam_policy_pb2.TestIamPermissionsResponse: + r"""Tests the specified IAM permissions against the IAM access control + policy for a function. + + If the function does not exist, this will return an empty set + of permissions, not a NOT_FOUND error. + + Args: + request (:class:`~.iam_policy_pb2.TestIamPermissionsRequest`): + The request object. Request message for + `TestIamPermissions` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.iam_policy_pb2.TestIamPermissionsResponse: + Response message for ``TestIamPermissions`` method. + """ + # Create or coerce a protobuf request object. + + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = iam_policy_pb2.TestIamPermissionsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.test_iam_permissions, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. 
+ response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_location( + self, + request: Optional[locations_pb2.GetLocationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.Location: + r"""Gets information about a location. + + Args: + request (:class:`~.location_pb2.GetLocationRequest`): + The request object. Request message for + `GetLocation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.location_pb2.Location: + Location object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = locations_pb2.GetLocationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.get_location, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + def list_locations( + self, + request: Optional[locations_pb2.ListLocationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.ListLocationsResponse: + r"""Lists information about the supported locations for this service. + + Args: + request (:class:`~.location_pb2.ListLocationsRequest`): + The request object. Request message for + `ListLocations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + Returns: + ~.location_pb2.ListLocationsResponse: + Response message for ``ListLocations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = locations_pb2.ListLocationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method.wrap_method( + self._transport.list_locations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("EvaluationServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/evaluation_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/evaluation_service/transports/__init__.py new file mode 100644 index 0000000000..60197b61ce --- /dev/null +++ b/google/cloud/aiplatform_v1/services/evaluation_service/transports/__init__.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import EvaluationServiceTransport +from .grpc import EvaluationServiceGrpcTransport +from .grpc_asyncio import EvaluationServiceGrpcAsyncIOTransport +from .rest import EvaluationServiceRestTransport +from .rest import EvaluationServiceRestInterceptor + + +# Compile a registry of transports. 
+_transport_registry = OrderedDict() # type: Dict[str, Type[EvaluationServiceTransport]] +_transport_registry["grpc"] = EvaluationServiceGrpcTransport +_transport_registry["grpc_asyncio"] = EvaluationServiceGrpcAsyncIOTransport +_transport_registry["rest"] = EvaluationServiceRestTransport + +__all__ = ( + "EvaluationServiceTransport", + "EvaluationServiceGrpcTransport", + "EvaluationServiceGrpcAsyncIOTransport", + "EvaluationServiceRestTransport", + "EvaluationServiceRestInterceptor", +) diff --git a/google/cloud/aiplatform_v1/services/evaluation_service/transports/base.py b/google/cloud/aiplatform_v1/services/evaluation_service/transports/base.py new file mode 100644 index 0000000000..46b78f9107 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/evaluation_service/transports/base.py @@ -0,0 +1,262 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +from google.cloud.aiplatform_v1 import gapic_version as package_version + +import google.auth # type: ignore +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.cloud.aiplatform_v1.types import evaluation_service +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +class EvaluationServiceTransport(abc.ABC): + """Abstract transport class for EvaluationService.""" + + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + + DEFAULT_HOST: str = "aiplatform.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'aiplatform.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. 
+ credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + if not hasattr(self, "_ignore_credentials"): + self._ignore_credentials: bool = False + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + elif credentials is None and not self._ignore_credentials: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. 
+ if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + @property + def host(self): + return self._host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.evaluate_instances: gapic_v1.method.wrap_method( + self.evaluate_instances, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def evaluate_instances( + self, + ) -> Callable[ + [evaluation_service.EvaluateInstancesRequest], + Union[ + evaluation_service.EvaluateInstancesResponse, + Awaitable[evaluation_service.EvaluateInstancesResponse], + ], + ]: + raise NotImplementedError() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Union[ + operations_pb2.ListOperationsResponse, + Awaitable[operations_pb2.ListOperationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def cancel_operation( + self, + ) -> Callable[[operations_pb2.CancelOperationRequest], None,]: + raise NotImplementedError() + + @property + def delete_operation( + self, + ) -> Callable[[operations_pb2.DeleteOperationRequest], None,]: + raise NotImplementedError() + + @property 
+ def wait_operation( + self, + ) -> Callable[ + [operations_pb2.WaitOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def set_iam_policy( + self, + ) -> Callable[ + [iam_policy_pb2.SetIamPolicyRequest], + Union[policy_pb2.Policy, Awaitable[policy_pb2.Policy]], + ]: + raise NotImplementedError() + + @property + def get_iam_policy( + self, + ) -> Callable[ + [iam_policy_pb2.GetIamPolicyRequest], + Union[policy_pb2.Policy, Awaitable[policy_pb2.Policy]], + ]: + raise NotImplementedError() + + @property + def test_iam_permissions( + self, + ) -> Callable[ + [iam_policy_pb2.TestIamPermissionsRequest], + Union[ + iam_policy_pb2.TestIamPermissionsResponse, + Awaitable[iam_policy_pb2.TestIamPermissionsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_location( + self, + ) -> Callable[ + [locations_pb2.GetLocationRequest], + Union[locations_pb2.Location, Awaitable[locations_pb2.Location]], + ]: + raise NotImplementedError() + + @property + def list_locations( + self, + ) -> Callable[ + [locations_pb2.ListLocationsRequest], + Union[ + locations_pb2.ListLocationsResponse, + Awaitable[locations_pb2.ListLocationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ("EvaluationServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/evaluation_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/evaluation_service/transports/grpc.py new file mode 100644 index 0000000000..d9c40f3c91 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/evaluation_service/transports/grpc.py @@ -0,0 +1,482 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import grpc_helpers +from google.api_core import gapic_v1 +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1.types import evaluation_service +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from .base import EvaluationServiceTransport, DEFAULT_CLIENT_INFO + + +class EvaluationServiceGrpcTransport(EvaluationServiceTransport): + """gRPC backend transport for EvaluationService. + + Vertex AI Online Evaluation Service. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. 
+ """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'aiplatform.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. 
+ api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
+ """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, grpc.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. 
+ credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
+ """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service.""" + return self._grpc_channel + + @property + def evaluate_instances( + self, + ) -> Callable[ + [evaluation_service.EvaluateInstancesRequest], + evaluation_service.EvaluateInstancesResponse, + ]: + r"""Return a callable for the evaluate instances method over gRPC. + + Evaluates instances based on a given metric. + + Returns: + Callable[[~.EvaluateInstancesRequest], + ~.EvaluateInstancesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "evaluate_instances" not in self._stubs: + self._stubs["evaluate_instances"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EvaluationService/EvaluateInstances", + request_serializer=evaluation_service.EvaluateInstancesRequest.serialize, + response_deserializer=evaluation_service.EvaluateInstancesResponse.deserialize, + ) + return self._stubs["evaluate_instances"] + + def close(self): + self.grpc_channel.close() + + @property + def delete_operation( + self, + ) -> Callable[[operations_pb2.DeleteOperationRequest], None]: + r"""Return a callable for the delete_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+        if "delete_operation" not in self._stubs:
+            self._stubs["delete_operation"] = self.grpc_channel.unary_unary(
+                "/google.longrunning.Operations/DeleteOperation",
+                request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString,
+                response_deserializer=None,
+            )
+        return self._stubs["delete_operation"]
+
+    @property
+    def cancel_operation(
+        self,
+    ) -> Callable[[operations_pb2.CancelOperationRequest], None]:
+        r"""Return a callable for the cancel_operation method over gRPC."""
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "cancel_operation" not in self._stubs:
+            self._stubs["cancel_operation"] = self.grpc_channel.unary_unary(
+                "/google.longrunning.Operations/CancelOperation",
+                request_serializer=operations_pb2.CancelOperationRequest.SerializeToString,
+                response_deserializer=None,
+            )
+        return self._stubs["cancel_operation"]
+
+    @property
+    def wait_operation(
+        self,
+    ) -> Callable[[operations_pb2.WaitOperationRequest], None]:
+        r"""Return a callable for the wait_operation method over gRPC."""
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "wait_operation" not in self._stubs:
+            self._stubs["wait_operation"] = self.grpc_channel.unary_unary(
+                "/google.longrunning.Operations/WaitOperation",
+                request_serializer=operations_pb2.WaitOperationRequest.SerializeToString,
+                response_deserializer=None,
+            )
+        return self._stubs["wait_operation"]
+
+    @property
+    def get_operation(
+        self,
+    ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]:
+        r"""Return a callable for the get_operation method over gRPC."""
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def list_locations( + self, + ) -> Callable[ + [locations_pb2.ListLocationsRequest], locations_pb2.ListLocationsResponse + ]: + r"""Return a callable for the list locations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "list_locations" not in self._stubs: + self._stubs["list_locations"] = self.grpc_channel.unary_unary( + "/google.cloud.location.Locations/ListLocations", + request_serializer=locations_pb2.ListLocationsRequest.SerializeToString, + response_deserializer=locations_pb2.ListLocationsResponse.FromString, + ) + return self._stubs["list_locations"] + + @property + def get_location( + self, + ) -> Callable[[locations_pb2.GetLocationRequest], locations_pb2.Location]: + r"""Return a callable for the list locations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_location" not in self._stubs: + self._stubs["get_location"] = self.grpc_channel.unary_unary( + "/google.cloud.location.Locations/GetLocation", + request_serializer=locations_pb2.GetLocationRequest.SerializeToString, + response_deserializer=locations_pb2.Location.FromString, + ) + return self._stubs["get_location"] + + @property + def set_iam_policy( + self, + ) -> Callable[[iam_policy_pb2.SetIamPolicyRequest], policy_pb2.Policy]: + r"""Return a callable for the set iam policy method over gRPC. + Sets the IAM access control policy on the specified + function. Replaces any existing policy. + Returns: + Callable[[~.SetIamPolicyRequest], + ~.Policy]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "set_iam_policy" not in self._stubs: + self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/SetIamPolicy", + request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString, + response_deserializer=policy_pb2.Policy.FromString, + ) + return self._stubs["set_iam_policy"] + + @property + def get_iam_policy( + self, + ) -> Callable[[iam_policy_pb2.GetIamPolicyRequest], policy_pb2.Policy]: + r"""Return a callable for the get iam policy method over gRPC. + Gets the IAM access control policy for a function. + Returns an empty policy if the function exists and does + not have a policy set. + Returns: + Callable[[~.GetIamPolicyRequest], + ~.Policy]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_iam_policy" not in self._stubs: + self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/GetIamPolicy", + request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString, + response_deserializer=policy_pb2.Policy.FromString, + ) + return self._stubs["get_iam_policy"] + + @property + def test_iam_permissions( + self, + ) -> Callable[ + [iam_policy_pb2.TestIamPermissionsRequest], + iam_policy_pb2.TestIamPermissionsResponse, + ]: + r"""Return a callable for the test iam permissions method over gRPC. + Tests the specified permissions against the IAM access control + policy for a function. If the function does not exist, this will + return an empty set of permissions, not a NOT_FOUND error. + Returns: + Callable[[~.TestIamPermissionsRequest], + ~.TestIamPermissionsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "test_iam_permissions" not in self._stubs: + self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/TestIamPermissions", + request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString, + response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString, + ) + return self._stubs["test_iam_permissions"] + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ("EvaluationServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/evaluation_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/evaluation_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..4a0cd27584 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/evaluation_service/transports/grpc_asyncio.py @@ -0,0 +1,492 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers_async +from google.api_core import exceptions as core_exceptions +from google.api_core import retry_async as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1.types import evaluation_service +from google.cloud.location import locations_pb2 # type: ignore +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from .base import EvaluationServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import EvaluationServiceGrpcTransport + + +class EvaluationServiceGrpcAsyncIOTransport(EvaluationServiceTransport): + """gRPC AsyncIO backend transport for EvaluationService. + + Vertex AI Online Evaluation Service. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. 
These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'aiplatform.googleapis.com'). 
+ credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. 
It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, aio.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. 
+ if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. 
+ return self._grpc_channel + + @property + def evaluate_instances( + self, + ) -> Callable[ + [evaluation_service.EvaluateInstancesRequest], + Awaitable[evaluation_service.EvaluateInstancesResponse], + ]: + r"""Return a callable for the evaluate instances method over gRPC. + + Evaluates instances based on a given metric. + + Returns: + Callable[[~.EvaluateInstancesRequest], + Awaitable[~.EvaluateInstancesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "evaluate_instances" not in self._stubs: + self._stubs["evaluate_instances"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EvaluationService/EvaluateInstances", + request_serializer=evaluation_service.EvaluateInstancesRequest.serialize, + response_deserializer=evaluation_service.EvaluateInstancesResponse.deserialize, + ) + return self._stubs["evaluate_instances"] + + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.evaluate_instances: gapic_v1.method_async.wrap_method( + self.evaluate_instances, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + return self.grpc_channel.close() + + @property + def delete_operation( + self, + ) -> Callable[[operations_pb2.DeleteOperationRequest], None]: + r"""Return a callable for the delete_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "delete_operation" not in self._stubs: + self._stubs["delete_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/DeleteOperation", + request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString, + response_deserializer=None, + ) + return self._stubs["delete_operation"] + + @property + def cancel_operation( + self, + ) -> Callable[[operations_pb2.CancelOperationRequest], None]: + r"""Return a callable for the cancel_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "cancel_operation" not in self._stubs: + self._stubs["cancel_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/CancelOperation", + request_serializer=operations_pb2.CancelOperationRequest.SerializeToString, + response_deserializer=None, + ) + return self._stubs["cancel_operation"] + + @property + def wait_operation( + self, + ) -> Callable[[operations_pb2.WaitOperationRequest], None]: + r"""Return a callable for the wait_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_operation" not in self._stubs: + self._stubs["wait_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/WaitOperation", + request_serializer=operations_pb2.WaitOperationRequest.SerializeToString, + response_deserializer=None, + ) + return self._stubs["wait_operation"] + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self.grpc_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def list_locations( + self, + ) -> Callable[ + [locations_pb2.ListLocationsRequest], locations_pb2.ListLocationsResponse + ]: + r"""Return a callable for the list locations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "list_locations" not in self._stubs: + self._stubs["list_locations"] = self.grpc_channel.unary_unary( + "/google.cloud.location.Locations/ListLocations", + request_serializer=locations_pb2.ListLocationsRequest.SerializeToString, + response_deserializer=locations_pb2.ListLocationsResponse.FromString, + ) + return self._stubs["list_locations"] + + @property + def get_location( + self, + ) -> Callable[[locations_pb2.GetLocationRequest], locations_pb2.Location]: + r"""Return a callable for the list locations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_location" not in self._stubs: + self._stubs["get_location"] = self.grpc_channel.unary_unary( + "/google.cloud.location.Locations/GetLocation", + request_serializer=locations_pb2.GetLocationRequest.SerializeToString, + response_deserializer=locations_pb2.Location.FromString, + ) + return self._stubs["get_location"] + + @property + def set_iam_policy( + self, + ) -> Callable[[iam_policy_pb2.SetIamPolicyRequest], policy_pb2.Policy]: + r"""Return a callable for the set iam policy method over gRPC. + Sets the IAM access control policy on the specified + function. Replaces any existing policy. + Returns: + Callable[[~.SetIamPolicyRequest], + ~.Policy]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "set_iam_policy" not in self._stubs: + self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/SetIamPolicy", + request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString, + response_deserializer=policy_pb2.Policy.FromString, + ) + return self._stubs["set_iam_policy"] + + @property + def get_iam_policy( + self, + ) -> Callable[[iam_policy_pb2.GetIamPolicyRequest], policy_pb2.Policy]: + r"""Return a callable for the get iam policy method over gRPC. + Gets the IAM access control policy for a function. + Returns an empty policy if the function exists and does + not have a policy set. + Returns: + Callable[[~.GetIamPolicyRequest], + ~.Policy]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_iam_policy" not in self._stubs: + self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/GetIamPolicy", + request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString, + response_deserializer=policy_pb2.Policy.FromString, + ) + return self._stubs["get_iam_policy"] + + @property + def test_iam_permissions( + self, + ) -> Callable[ + [iam_policy_pb2.TestIamPermissionsRequest], + iam_policy_pb2.TestIamPermissionsResponse, + ]: + r"""Return a callable for the test iam permissions method over gRPC. + Tests the specified permissions against the IAM access control + policy for a function. If the function does not exist, this will + return an empty set of permissions, not a NOT_FOUND error. + Returns: + Callable[[~.TestIamPermissionsRequest], + ~.TestIamPermissionsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "test_iam_permissions" not in self._stubs: + self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary( + "/google.iam.v1.IAMPolicy/TestIamPermissions", + request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString, + response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString, + ) + return self._stubs["test_iam_permissions"] + + +__all__ = ("EvaluationServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/evaluation_service/transports/rest.py b/google/cloud/aiplatform_v1/services/evaluation_service/transports/rest.py new file mode 100644 index 0000000000..74a2e41185 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/evaluation_service/transports/rest.py @@ -0,0 +1,3137 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from google.auth.transport.requests import AuthorizedSession # type: ignore +import json # type: ignore +import grpc # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core import rest_helpers +from google.api_core import rest_streaming +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.cloud.location import locations_pb2 # type: ignore +from requests import __version__ as requests_version +import dataclasses +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + + +from google.cloud.aiplatform_v1.types import evaluation_service +from google.longrunning import operations_pb2 # type: ignore + +from .base import ( + EvaluationServiceTransport, + DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO, +) + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class EvaluationServiceRestInterceptor: + """Interceptor for EvaluationService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. 
+ Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the EvaluationServiceRestTransport. + + .. code-block:: python + class MyCustomEvaluationServiceInterceptor(EvaluationServiceRestInterceptor): + def pre_evaluate_instances(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_evaluate_instances(self, response): + logging.log(f"Received response: {response}") + return response + + transport = EvaluationServiceRestTransport(interceptor=MyCustomEvaluationServiceInterceptor()) + client = EvaluationServiceClient(transport=transport) + + + """ + + def pre_evaluate_instances( + self, + request: evaluation_service.EvaluateInstancesRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[evaluation_service.EvaluateInstancesRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for evaluate_instances + + Override in a subclass to manipulate the request or metadata + before they are sent to the EvaluationService server. + """ + return request, metadata + + def post_evaluate_instances( + self, response: evaluation_service.EvaluateInstancesResponse + ) -> evaluation_service.EvaluateInstancesResponse: + """Post-rpc interceptor for evaluate_instances + + Override in a subclass to manipulate the response + after it is returned by the EvaluationService server but before + it is returned to user code. + """ + return response + + def pre_get_location( + self, + request: locations_pb2.GetLocationRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[locations_pb2.GetLocationRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for get_location + + Override in a subclass to manipulate the request or metadata + before they are sent to the EvaluationService server. 
+ """ + return request, metadata + + def post_get_location( + self, response: locations_pb2.Location + ) -> locations_pb2.Location: + """Post-rpc interceptor for get_location + + Override in a subclass to manipulate the response + after it is returned by the EvaluationService server but before + it is returned to user code. + """ + return response + + def pre_list_locations( + self, + request: locations_pb2.ListLocationsRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[locations_pb2.ListLocationsRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for list_locations + + Override in a subclass to manipulate the request or metadata + before they are sent to the EvaluationService server. + """ + return request, metadata + + def post_list_locations( + self, response: locations_pb2.ListLocationsResponse + ) -> locations_pb2.ListLocationsResponse: + """Post-rpc interceptor for list_locations + + Override in a subclass to manipulate the response + after it is returned by the EvaluationService server but before + it is returned to user code. + """ + return response + + def pre_get_iam_policy( + self, + request: iam_policy_pb2.GetIamPolicyRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[iam_policy_pb2.GetIamPolicyRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for get_iam_policy + + Override in a subclass to manipulate the request or metadata + before they are sent to the EvaluationService server. + """ + return request, metadata + + def post_get_iam_policy(self, response: policy_pb2.Policy) -> policy_pb2.Policy: + """Post-rpc interceptor for get_iam_policy + + Override in a subclass to manipulate the response + after it is returned by the EvaluationService server but before + it is returned to user code. 
+ """ + return response + + def pre_set_iam_policy( + self, + request: iam_policy_pb2.SetIamPolicyRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[iam_policy_pb2.SetIamPolicyRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for set_iam_policy + + Override in a subclass to manipulate the request or metadata + before they are sent to the EvaluationService server. + """ + return request, metadata + + def post_set_iam_policy(self, response: policy_pb2.Policy) -> policy_pb2.Policy: + """Post-rpc interceptor for set_iam_policy + + Override in a subclass to manipulate the response + after it is returned by the EvaluationService server but before + it is returned to user code. + """ + return response + + def pre_test_iam_permissions( + self, + request: iam_policy_pb2.TestIamPermissionsRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[iam_policy_pb2.TestIamPermissionsRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for test_iam_permissions + + Override in a subclass to manipulate the request or metadata + before they are sent to the EvaluationService server. + """ + return request, metadata + + def post_test_iam_permissions( + self, response: iam_policy_pb2.TestIamPermissionsResponse + ) -> iam_policy_pb2.TestIamPermissionsResponse: + """Post-rpc interceptor for test_iam_permissions + + Override in a subclass to manipulate the response + after it is returned by the EvaluationService server but before + it is returned to user code. + """ + return response + + def pre_cancel_operation( + self, + request: operations_pb2.CancelOperationRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[operations_pb2.CancelOperationRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for cancel_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the EvaluationService server. 
+ """ + return request, metadata + + def post_cancel_operation(self, response: None) -> None: + """Post-rpc interceptor for cancel_operation + + Override in a subclass to manipulate the response + after it is returned by the EvaluationService server but before + it is returned to user code. + """ + return response + + def pre_delete_operation( + self, + request: operations_pb2.DeleteOperationRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[operations_pb2.DeleteOperationRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for delete_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the EvaluationService server. + """ + return request, metadata + + def post_delete_operation(self, response: None) -> None: + """Post-rpc interceptor for delete_operation + + Override in a subclass to manipulate the response + after it is returned by the EvaluationService server but before + it is returned to user code. + """ + return response + + def pre_get_operation( + self, + request: operations_pb2.GetOperationRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[operations_pb2.GetOperationRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for get_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the EvaluationService server. + """ + return request, metadata + + def post_get_operation( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for get_operation + + Override in a subclass to manipulate the response + after it is returned by the EvaluationService server but before + it is returned to user code. 
+ """ + return response + + def pre_list_operations( + self, + request: operations_pb2.ListOperationsRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[operations_pb2.ListOperationsRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for list_operations + + Override in a subclass to manipulate the request or metadata + before they are sent to the EvaluationService server. + """ + return request, metadata + + def post_list_operations( + self, response: operations_pb2.ListOperationsResponse + ) -> operations_pb2.ListOperationsResponse: + """Post-rpc interceptor for list_operations + + Override in a subclass to manipulate the response + after it is returned by the EvaluationService server but before + it is returned to user code. + """ + return response + + def pre_wait_operation( + self, + request: operations_pb2.WaitOperationRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[operations_pb2.WaitOperationRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for wait_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the EvaluationService server. + """ + return request, metadata + + def post_wait_operation( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for wait_operation + + Override in a subclass to manipulate the response + after it is returned by the EvaluationService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class EvaluationServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: EvaluationServiceRestInterceptor + + +class EvaluationServiceRestTransport(EvaluationServiceTransport): + """REST backend transport for EvaluationService. + + Vertex AI Online Evaluation Service. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. 
+ + It sends JSON representations of protocol buffers over HTTP/1.1 + + """ + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[EvaluationServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'aiplatform.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. 
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
+            url_scheme: the protocol scheme for the API endpoint.  Normally
+                "https", but for testing or local servers,
+                "http" can be specified.
+        """
+        # Run the base constructor
+        # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc.
+        # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
+        # credentials object
+        maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host)
+        if maybe_url_match is None:
+            raise ValueError(
+                f"Unexpected hostname structure: {host}"
+            )  # pragma: NO COVER
+
+        url_match_items = maybe_url_match.groupdict()
+
+        host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host
+
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            client_info=client_info,
+            always_use_jwt_access=always_use_jwt_access,
+            api_audience=api_audience,
+        )
+        self._session = AuthorizedSession(
+            self._credentials, default_host=self.DEFAULT_HOST
+        )
+        if client_cert_source_for_mtls:
+            self._session.configure_mtls_channel(client_cert_source_for_mtls)
+        self._interceptor = interceptor or EvaluationServiceRestInterceptor()
+        self._prep_wrapped_messages(client_info)
+
+    class _EvaluateInstances(EvaluationServiceRestStub):
+        def __hash__(self):
+            return hash("EvaluateInstances")
+
+        __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {}
+
+        @classmethod
+        def _get_unset_required_fields(cls, message_dict):
+            return {
+                k: v
+                for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items()
+                if k not in message_dict
+            }
+
+        def __call__(
+            self,
+            request: evaluation_service.EvaluateInstancesRequest,
+            *,
+            retry: OptionalRetry = gapic_v1.method.DEFAULT,
+            timeout: Optional[float] = None,
+            metadata: Sequence[Tuple[str, str]] = (),
+        ) -> evaluation_service.EvaluateInstancesResponse:
+            r"""Call the evaluate instances method over HTTP.
+ + Args: + request (~.evaluation_service.EvaluateInstancesRequest): + The request object. Request message for + EvaluationService.EvaluateInstances. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.evaluation_service.EvaluateInstancesResponse: + Response message for + EvaluationService.EvaluateInstances. + + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{location=projects/*/locations/*}:evaluateInstances", + "body": "*", + }, + ] + request, metadata = self._interceptor.pre_evaluate_instances( + request, metadata + ) + pb_request = evaluation_service.EvaluateInstancesRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = evaluation_service.EvaluateInstancesResponse() + pb_resp = evaluation_service.EvaluateInstancesResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_evaluate_instances(resp) + return resp + + @property + def evaluate_instances( + self, + ) -> Callable[ + [evaluation_service.EvaluateInstancesRequest], + evaluation_service.EvaluateInstancesResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._EvaluateInstances(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_location(self): + return self._GetLocation(self._session, self._host, self._interceptor) # type: ignore + + class _GetLocation(EvaluationServiceRestStub): + def __call__( + self, + request: locations_pb2.GetLocationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.Location: + + r"""Call the get location method over HTTP. + + Args: + request (locations_pb2.GetLocationRequest): + The request object for GetLocation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + locations_pb2.Location: Response from GetLocation method. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*}", + }, + ] + + request, metadata = self._interceptor.pre_get_location(request, metadata) + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + resp = locations_pb2.Location() + resp = json_format.Parse(response.content.decode("utf-8"), resp) + resp = self._interceptor.post_get_location(resp) + return resp + + @property + def list_locations(self): + return self._ListLocations(self._session, self._host, self._interceptor) # type: ignore + + class _ListLocations(EvaluationServiceRestStub): + def __call__( + self, + request: locations_pb2.ListLocationsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> locations_pb2.ListLocationsResponse: + + r"""Call the list locations method over HTTP. + + Args: + request (locations_pb2.ListLocationsRequest): + The request object for ListLocations method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + locations_pb2.ListLocationsResponse: Response from ListLocations method. + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/ui/{name=projects/*}/locations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*}/locations", + }, + ] + + request, metadata = self._interceptor.pre_list_locations(request, metadata) + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + resp = locations_pb2.ListLocationsResponse() + resp = json_format.Parse(response.content.decode("utf-8"), resp) + resp = self._interceptor.post_list_locations(resp) + return resp + + @property + def get_iam_policy(self): + return self._GetIamPolicy(self._session, self._host, self._interceptor) # type: ignore + + class _GetIamPolicy(EvaluationServiceRestStub): + def __call__( + self, + request: iam_policy_pb2.GetIamPolicyRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + + r"""Call the get iam policy method over HTTP. 
+ + Args: + request (iam_policy_pb2.GetIamPolicyRequest): + The request object for GetIamPolicy method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + policy_pb2.Policy: Response from GetIamPolicy method. + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{resource=projects/*/locations/*/featurestores/*}:getIamPolicy", + }, + { + "method": "post", + "uri": "/v1/{resource=projects/*/locations/*/featurestores/*/entityTypes/*}:getIamPolicy", + }, + { + "method": "post", + "uri": "/v1/{resource=projects/*/locations/*/models/*}:getIamPolicy", + }, + { + "method": "post", + "uri": "/v1/{resource=projects/*/locations/*/notebookRuntimeTemplates/*}:getIamPolicy", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/featurestores/*}:getIamPolicy", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/featurestores/*/entityTypes/*}:getIamPolicy", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/models/*}:getIamPolicy", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/endpoints/*}:getIamPolicy", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/notebookRuntimeTemplates/*}:getIamPolicy", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/publishers/*/models/*}:getIamPolicy", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/featureOnlineStores/*}:getIamPolicy", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/featureOnlineStores/*/featureViews/*}:getIamPolicy", + }, + ] + + request, metadata = self._interceptor.pre_get_iam_policy(request, metadata) + request_kwargs = json_format.MessageToDict(request) + transcoded_request = 
path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + resp = policy_pb2.Policy() + resp = json_format.Parse(response.content.decode("utf-8"), resp) + resp = self._interceptor.post_get_iam_policy(resp) + return resp + + @property + def set_iam_policy(self): + return self._SetIamPolicy(self._session, self._host, self._interceptor) # type: ignore + + class _SetIamPolicy(EvaluationServiceRestStub): + def __call__( + self, + request: iam_policy_pb2.SetIamPolicyRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> policy_pb2.Policy: + + r"""Call the set iam policy method over HTTP. + + Args: + request (iam_policy_pb2.SetIamPolicyRequest): + The request object for SetIamPolicy method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + policy_pb2.Policy: Response from SetIamPolicy method. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{resource=projects/*/locations/*/featurestores/*}:setIamPolicy", + "body": "*", + }, + { + "method": "post", + "uri": "/v1/{resource=projects/*/locations/*/featurestores/*/entityTypes/*}:setIamPolicy", + "body": "*", + }, + { + "method": "post", + "uri": "/v1/{resource=projects/*/locations/*/models/*}:setIamPolicy", + "body": "*", + }, + { + "method": "post", + "uri": "/v1/{resource=projects/*/locations/*/notebookRuntimeTemplates/*}:setIamPolicy", + "body": "*", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/featurestores/*}:setIamPolicy", + "body": "*", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/featurestores/*/entityTypes/*}:setIamPolicy", + "body": "*", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/models/*}:setIamPolicy", + "body": "*", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/endpoints/*}:setIamPolicy", + "body": "*", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/notebookRuntimeTemplates/*}:setIamPolicy", + "body": "*", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/featureOnlineStores/*}:setIamPolicy", + "body": "*", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/featureOnlineStores/*/featureViews/*}:setIamPolicy", + "body": "*", + }, + ] + + request, metadata = self._interceptor.pre_set_iam_policy(request, metadata) + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + body = json.dumps(transcoded_request["body"]) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = 
"application/json" + + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + resp = policy_pb2.Policy() + resp = json_format.Parse(response.content.decode("utf-8"), resp) + resp = self._interceptor.post_set_iam_policy(resp) + return resp + + @property + def test_iam_permissions(self): + return self._TestIamPermissions(self._session, self._host, self._interceptor) # type: ignore + + class _TestIamPermissions(EvaluationServiceRestStub): + def __call__( + self, + request: iam_policy_pb2.TestIamPermissionsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> iam_policy_pb2.TestIamPermissionsResponse: + + r"""Call the test iam permissions method over HTTP. + + Args: + request (iam_policy_pb2.TestIamPermissionsRequest): + The request object for TestIamPermissions method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + iam_policy_pb2.TestIamPermissionsResponse: Response from TestIamPermissions method. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{resource=projects/*/locations/*/featurestores/*}:testIamPermissions", + }, + { + "method": "post", + "uri": "/v1/{resource=projects/*/locations/*/featurestores/*/entityTypes/*}:testIamPermissions", + }, + { + "method": "post", + "uri": "/v1/{resource=projects/*/locations/*/models/*}:testIamPermissions", + }, + { + "method": "post", + "uri": "/v1/{resource=projects/*/locations/*/notebookRuntimeTemplates/*}:testIamPermissions", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/featurestores/*}:testIamPermissions", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/featurestores/*/entityTypes/*}:testIamPermissions", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/models/*}:testIamPermissions", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/endpoints/*}:testIamPermissions", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/notebookRuntimeTemplates/*}:testIamPermissions", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/featureOnlineStores/*}:testIamPermissions", + }, + { + "method": "post", + "uri": "/ui/{resource=projects/*/locations/*/featureOnlineStores/*/featureViews/*}:testIamPermissions", + }, + ] + + request, metadata = self._interceptor.pre_test_iam_permissions( + request, metadata + ) + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + 
headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + resp = iam_policy_pb2.TestIamPermissionsResponse() + resp = json_format.Parse(response.content.decode("utf-8"), resp) + resp = self._interceptor.post_test_iam_permissions(resp) + return resp + + @property + def cancel_operation(self): + return self._CancelOperation(self._session, self._host, self._interceptor) # type: ignore + + class _CancelOperation(EvaluationServiceRestStub): + def __call__( + self, + request: operations_pb2.CancelOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + + r"""Call the cancel operation method over HTTP. + + Args: + request (operations_pb2.CancelOperationRequest): + The request object for CancelOperation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/agents/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/apps/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/dataItems/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/savedQueries/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/annotationSpecs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/dataItems/*/annotations/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/deploymentResourcePools/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/edgeDevices/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/endpoints/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/extensionControllers/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/extensions/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*/entityTypes/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*/entityTypes/*/features/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/customJobs/*/operations/*}:cancel", + }, + { + 
"method": "post", + "uri": "/ui/{name=projects/*/locations/*/dataLabelingJobs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/tuningJobs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/indexes/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/indexEndpoints/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/artifacts/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/contexts/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/executions/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/modelDeploymentMonitoringJobs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/modelMonitors/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/migratableResources/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/models/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/models/*/evaluations/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/notebookExecutionJobs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/notebookRuntimes/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/notebookRuntimeTemplates/*/operations/*}:cancel", + }, + { + "method": "post", + 
"uri": "/ui/{name=projects/*/locations/*/persistentResources/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/studies/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/studies/*/trials/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/trainingPipelines/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/pipelineJobs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/schedules/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/specialistPools/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/timeSeries/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/dataItems/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/savedQueries/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/annotationSpecs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/dataItems/*/annotations/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": 
"/v1/{name=projects/*/locations/*/deploymentResourcePools/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/endpoints/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/entityTypes/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/entityTypes/*/features/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/customJobs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/dataLabelingJobs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/tuningJobs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/indexes/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/indexEndpoints/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/artifacts/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/contexts/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/executions/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/modelDeploymentMonitoringJobs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/migratableResources/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": 
"/v1/{name=projects/*/locations/*/models/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/models/*/evaluations/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/notebookExecutionJobs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/notebookRuntimes/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/notebookRuntimeTemplates/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/persistentResources/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/studies/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/studies/*/trials/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/trainingPipelines/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/pipelineJobs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/schedules/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/specialistPools/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/operations/*}:cancel", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/timeSeries/*/operations/*}:cancel", + }, + ] + + request, metadata = self._interceptor.pre_cancel_operation( + request, metadata + ) + request_kwargs = json_format.MessageToDict(request) + 
transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + return self._interceptor.post_cancel_operation(None) + + @property + def delete_operation(self): + return self._DeleteOperation(self._session, self._host, self._interceptor) # type: ignore + + class _DeleteOperation(EvaluationServiceRestStub): + def __call__( + self, + request: operations_pb2.DeleteOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + + r"""Call the delete operation method over HTTP. + + Args: + request (operations_pb2.DeleteOperationRequest): + The request object for DeleteOperation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/agents/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/apps/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/dataItems/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/savedQueries/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/annotationSpecs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/dataItems/*/annotations/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/deploymentResourcePools/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/edgeDevices/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/endpoints/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/extensionControllers/*}/operations", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/extensions/*}/operations", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*/entityTypes/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*/entityTypes/*/features/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/customJobs/*/operations/*}", + }, + { + "method": "delete", + "uri": 
"/ui/{name=projects/*/locations/*/dataLabelingJobs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/indexes/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/indexEndpoints/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/artifacts/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/contexts/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/executions/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/modelDeploymentMonitoringJobs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/modelMonitors/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/migratableResources/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/models/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/models/*/evaluations/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/notebookExecutionJobs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/notebookRuntimes/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/notebookRuntimeTemplates/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/persistentResources/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/studies/*/operations/*}", + }, + { + "method": "delete", + 
"uri": "/ui/{name=projects/*/locations/*/studies/*/trials/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/trainingPipelines/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/pipelineJobs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/schedules/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/specialistPools/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/timeSeries/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/featureOnlineStores/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/featureGroups/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/featureGroups/*/features/*/operations/*}", + }, + { + "method": "delete", + "uri": "/ui/{name=projects/*/locations/*/featureOnlineStores/*/featureViews/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/dataItems/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/savedQueries/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/annotationSpecs/*/operations/*}", + }, + { 
+ "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/dataItems/*/annotations/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/deploymentResourcePools/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/endpoints/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/entityTypes/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/entityTypes/*/features/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/customJobs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/dataLabelingJobs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/indexes/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/indexEndpoints/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/artifacts/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/contexts/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/executions/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/modelDeploymentMonitoringJobs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/migratableResources/*/operations/*}", + }, + { + "method": "delete", + "uri": 
"/v1/{name=projects/*/locations/*/models/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/models/*/evaluations/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/notebookExecutionJobs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/notebookRuntimes/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/notebookRuntimeTemplates/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/studies/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/studies/*/trials/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/trainingPipelines/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/persistentResources/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/pipelineJobs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/schedules/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/specialistPools/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/timeSeries/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/featureOnlineStores/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/featureGroups/*/operations/*}", + }, + { + "method": "delete", + 
"uri": "/v1/{name=projects/*/locations/*/featureGroups/*/features/*/operations/*}", + }, + { + "method": "delete", + "uri": "/v1/{name=projects/*/locations/*/featureOnlineStores/*/featureViews/*/operations/*}", + }, + ] + + request, metadata = self._interceptor.pre_delete_operation( + request, metadata + ) + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + return self._interceptor.post_delete_operation(None) + + @property + def get_operation(self): + return self._GetOperation(self._session, self._host, self._interceptor) # type: ignore + + class _GetOperation(EvaluationServiceRestStub): + def __call__( + self, + request: operations_pb2.GetOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + + r"""Call the get operation method over HTTP. + + Args: + request (operations_pb2.GetOperationRequest): + The request object for GetOperation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + operations_pb2.Operation: Response from GetOperation method. + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/agents/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/apps/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/dataItems/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/savedQueries/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/annotationSpecs/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/dataItems/*/annotations/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/deploymentResourcePools/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/edgeDeploymentJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/edgeDevices/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/endpoints/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/extensionControllers/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/extensions/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*/entityTypes/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*/entityTypes/*/features/*/operations/*}", + }, + { + "method": "get", + "uri": 
"/ui/{name=projects/*/locations/*/customJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/dataLabelingJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/tuningJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/indexes/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/indexEndpoints/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/artifacts/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/contexts/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/executions/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/modelDeploymentMonitoringJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/modelMonitors/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/migratableResources/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/models/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/models/*/evaluations/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/notebookExecutionJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/notebookRuntimes/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/notebookRuntimeTemplates/*/operations/*}", + }, + { + "method": "get", + "uri": 
"/ui/{name=projects/*/locations/*/persistentResources/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/studies/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/studies/*/trials/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/trainingPipelines/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/pipelineJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/schedules/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/specialistPools/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/timeSeries/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/featureOnlineStores/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/featureOnlineStores/*/featureViews/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/featureGroups/*/operations/*}", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/featureGroups/*/features/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/dataItems/*/operations/*}", + }, + { + "method": "get", + "uri": 
"/v1/{name=projects/*/locations/*/datasets/*/savedQueries/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/annotationSpecs/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/dataItems/*/annotations/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/deploymentResourcePools/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/endpoints/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/entityTypes/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/entityTypes/*/features/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/customJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/dataLabelingJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/tuningJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/indexes/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/indexEndpoints/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/artifacts/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/contexts/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/executions/*/operations/*}", + }, + { + "method": "get", + "uri": 
"/v1/{name=projects/*/locations/*/modelDeploymentMonitoringJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/migratableResources/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/models/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/models/*/evaluations/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/notebookExecutionJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/notebookRuntimes/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/notebookRuntimeTemplates/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/studies/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/studies/*/trials/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/trainingPipelines/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/persistentResources/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/pipelineJobs/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/schedules/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/specialistPools/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/timeSeries/*/operations/*}", + }, + { + "method": "get", + "uri": 
"/v1/{name=projects/*/locations/*/featureOnlineStores/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/featureOnlineStores/*/featureViews/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/featureGroups/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/featureGroups/*/features/*/operations/*}", + }, + ] + + request, metadata = self._interceptor.pre_get_operation(request, metadata) + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + resp = operations_pb2.Operation() + resp = json_format.Parse(response.content.decode("utf-8"), resp) + resp = self._interceptor.post_get_operation(resp) + return resp + + @property + def list_operations(self): + return self._ListOperations(self._session, self._host, self._interceptor) # type: ignore + + class _ListOperations(EvaluationServiceRestStub): + def __call__( + self, + request: operations_pb2.ListOperationsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.ListOperationsResponse: + + r"""Call the list operations method over HTTP. 
+ + Args: + request (operations_pb2.ListOperationsRequest): + The request object for ListOperations method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + operations_pb2.ListOperationsResponse: Response from ListOperations method. + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/agents/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/apps/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/datasets/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/dataItems/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/savedQueries/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/annotationSpecs/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/dataItems/*/annotations/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/deploymentResourcePools/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/edgeDevices/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/endpoints/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/extensionControllers/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/extensions/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*}/operations", + }, + { + "method": "get", + "uri": 
"/ui/{name=projects/*/locations/*/featurestores/*/entityTypes/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*/entityTypes/*/features/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/customJobs/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/dataLabelingJobs/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/hyperparameterTuningJobs/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/tuningJobs/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/indexes/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/indexEndpoints/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/artifacts/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/contexts/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/executions/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/modelDeploymentMonitoringJobs/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/modelMonitors/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/migratableResources/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/models/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/models/*/evaluations/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/notebookExecutionJobs/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/notebookRuntimes/*}/operations", + 
}, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/notebookRuntimeTemplates/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/studies/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/studies/*/trials/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/trainingPipelines/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/persistentResources/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/pipelineJobs/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/schedules/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/specialistPools/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/timeSeries/*}/operations", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/featureOnlineStores/*/operations/*}:wait", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/featureOnlineStores/*/featureViews/*/operations/*}:wait", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/featureGroups/*/operations/*}:wait", + }, + { + "method": "get", + "uri": "/ui/{name=projects/*/locations/*/featureGroups/*/features/*/operations/*}:wait", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/datasets/*}/operations", + }, + { + "method": "get", + "uri": 
"/v1/{name=projects/*/locations/*/datasets/*/dataItems/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/savedQueries/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/annotationSpecs/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/dataItems/*/annotations/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/deploymentResourcePools/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/endpoints/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/entityTypes/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/entityTypes/*/features/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/customJobs/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/dataLabelingJobs/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/hyperparameterTuningJobs/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/tuningJobs/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/indexes/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/indexEndpoints/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/artifacts/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/contexts/*}/operations", + }, + { + "method": "get", + "uri": 
"/v1/{name=projects/*/locations/*/metadataStores/*/executions/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/modelDeploymentMonitoringJobs/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/migratableResources/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/models/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/models/*/evaluations/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/notebookExecutionJobs/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/notebookRuntimes/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/notebookRuntimeTemplates/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/studies/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/studies/*/trials/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/trainingPipelines/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/persistentResources/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/pipelineJobs/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/schedules/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/specialistPools/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*}/operations", + }, + { + "method": "get", + "uri": 
"/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/timeSeries/*}/operations", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/featureOnlineStores/*/operations/*}:wait", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/featureOnlineStores/*/featureViews/*/operations/*}:wait", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/featureGroups/*/operations/*}:wait", + }, + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/featureGroups/*/features/*/operations/*}:wait", + }, + ] + + request, metadata = self._interceptor.pre_list_operations(request, metadata) + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + resp = operations_pb2.ListOperationsResponse() + resp = json_format.Parse(response.content.decode("utf-8"), resp) + resp = self._interceptor.post_list_operations(resp) + return resp + + @property + def wait_operation(self): + return self._WaitOperation(self._session, self._host, self._interceptor) # type: ignore + + class _WaitOperation(EvaluationServiceRestStub): + def __call__( + self, + request: operations_pb2.WaitOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: + + r"""Call the wait operation method over HTTP. + + Args: + request (operations_pb2.WaitOperationRequest): + The request object for WaitOperation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + operations_pb2.Operation: Response from WaitOperation method. 
+ """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/agents/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/apps/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/dataItems/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/savedQueries/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/annotationSpecs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/datasets/*/dataItems/*/annotations/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/deploymentResourcePools/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/edgeDevices/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/endpoints/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/extensionControllers/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/extensions/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*/entityTypes/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/featurestores/*/entityTypes/*/features/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/customJobs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": 
"/ui/{name=projects/*/locations/*/dataLabelingJobs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/tuningJobs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/indexes/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/indexEndpoints/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/artifacts/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/contexts/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/metadataStores/*/executions/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/modelDeploymentMonitoringJobs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/modelMonitors/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/migratableResources/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/models/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/models/*/evaluations/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/notebookExecutionJobs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/notebookRuntimes/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/notebookRuntimeTemplates/*/operations/*}:wait", + }, + { + "method": "post", + "uri": 
"/ui/{name=projects/*/locations/*/studies/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/studies/*/trials/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/trainingPipelines/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/persistentResources/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/pipelineJobs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/schedules/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/specialistPools/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/timeSeries/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/featureOnlineStores/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/featureOnlineStores/*/featureViews/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/featureGroups/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/ui/{name=projects/*/locations/*/featureGroups/*/features/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/operations/*}:wait", + }, + { + "method": "post", + "uri": 
"/v1/{name=projects/*/locations/*/datasets/*/dataItems/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/savedQueries/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/annotationSpecs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/datasets/*/dataItems/*/annotations/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/deploymentResourcePools/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/endpoints/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/entityTypes/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/featurestores/*/entityTypes/*/features/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/customJobs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/dataLabelingJobs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/hyperparameterTuningJobs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/indexes/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/indexEndpoints/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/artifacts/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/metadataStores/*/contexts/*/operations/*}:wait", + }, + { + "method": "post", + "uri": 
"/v1/{name=projects/*/locations/*/metadataStores/*/executions/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/modelDeploymentMonitoringJobs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/migratableResources/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/models/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/models/*/evaluations/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/notebookExecutionJobs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/notebookRuntimes/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/notebookRuntimeTemplates/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/studies/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/studies/*/trials/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/trainingPipelines/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/persistentResources/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/pipelineJobs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/schedules/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/specialistPools/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/operations/*}:wait", + }, + { + "method": "post", + "uri": 
"/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/tensorboards/*/experiments/*/runs/*/timeSeries/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/featureOnlineStores/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/featureOnlineStores/*/featureViews/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/featureGroups/*/operations/*}:wait", + }, + { + "method": "post", + "uri": "/v1/{name=projects/*/locations/*/featureGroups/*/features/*/operations/*}:wait", + }, + ] + + request, metadata = self._interceptor.pre_wait_operation(request, metadata) + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + + # Send the request + headers = dict(metadata) + headers["Content-Type"] = "application/json" + + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + resp = operations_pb2.Operation() + resp = json_format.Parse(response.content.decode("utf-8"), resp) + resp = self._interceptor.post_wait_operation(resp) + return resp + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("EvaluationServiceRestTransport",) diff --git a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/async_client.py b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/async_client.py index a5a31c7dca..8c81aa05f8 100644 --- a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/async_client.py @@ -722,6 +722,8 @@ async def sample_list_feature_online_stores(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1437,6 +1439,8 @@ async def sample_list_feature_views(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2086,6 +2090,8 @@ async def sample_list_feature_view_syncs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/client.py b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/client.py index 5ab91a3260..fa29be3860 100644 --- a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/client.py +++ b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/client.py @@ -1174,6 +1174,8 @@ def sample_list_feature_online_stores(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1878,6 +1880,8 @@ def sample_list_feature_views(): method=rpc, request=request, response=response, + retry=retry, + 
timeout=timeout, metadata=metadata, ) @@ -2512,6 +2516,8 @@ def sample_list_feature_view_syncs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/pagers.py b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/pagers.py index ec74f1b967..0ceeaf3ff5 100644 --- a/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/feature_online_store_admin_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import feature_online_store from google.cloud.aiplatform_v1.types import feature_online_store_admin_service from google.cloud.aiplatform_v1.types import feature_view @@ -56,6 +69,8 @@ def __init__( request: feature_online_store_admin_service.ListFeatureOnlineStoresRequest, response: feature_online_store_admin_service.ListFeatureOnlineStoresResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -67,6 +82,9 @@ def __init__( The initial request object. 
response (google.cloud.aiplatform_v1.types.ListFeatureOnlineStoresResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. """ @@ -75,6 +93,8 @@ def __init__( feature_online_store_admin_service.ListFeatureOnlineStoresRequest(request) ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -87,7 +107,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature_online_store.FeatureOnlineStore]: @@ -127,6 +152,8 @@ def __init__( request: feature_online_store_admin_service.ListFeatureOnlineStoresRequest, response: feature_online_store_admin_service.ListFeatureOnlineStoresResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -138,6 +165,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListFeatureOnlineStoresResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -146,6 +176,8 @@ def __init__( feature_online_store_admin_service.ListFeatureOnlineStoresRequest(request) ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -160,7 +192,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature_online_store.FeatureOnlineStore]: @@ -201,6 +238,8 @@ def __init__( request: feature_online_store_admin_service.ListFeatureViewsRequest, response: feature_online_store_admin_service.ListFeatureViewsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -212,6 +251,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListFeatureViewsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -220,6 +262,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -232,7 +276,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature_view.FeatureView]: @@ -269,6 +318,8 @@ def __init__( request: feature_online_store_admin_service.ListFeatureViewsRequest, response: feature_online_store_admin_service.ListFeatureViewsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -280,6 +331,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListFeatureViewsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -288,6 +342,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -300,7 +356,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature_view.FeatureView]: @@ -341,6 +402,8 @@ def __init__( request: feature_online_store_admin_service.ListFeatureViewSyncsRequest, response: feature_online_store_admin_service.ListFeatureViewSyncsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -352,6 +415,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListFeatureViewSyncsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -360,6 +426,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -372,7 +440,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature_view_sync.FeatureViewSync]: @@ -410,6 +483,8 @@ def __init__( request: feature_online_store_admin_service.ListFeatureViewSyncsRequest, response: feature_online_store_admin_service.ListFeatureViewSyncsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -421,6 +496,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListFeatureViewSyncsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -429,6 +507,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -441,7 +521,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature_view_sync.FeatureViewSync]: diff --git a/google/cloud/aiplatform_v1/services/feature_registry_service/async_client.py b/google/cloud/aiplatform_v1/services/feature_registry_service/async_client.py index 68ea8173c6..b59f6c7e2c 100644 --- a/google/cloud/aiplatform_v1/services/feature_registry_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/feature_registry_service/async_client.py @@ -663,6 +663,8 @@ async def sample_list_feature_groups(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1352,6 +1354,8 @@ async def sample_list_features(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/feature_registry_service/client.py b/google/cloud/aiplatform_v1/services/feature_registry_service/client.py index b315a0a05f..4487adc130 100644 --- a/google/cloud/aiplatform_v1/services/feature_registry_service/client.py +++ b/google/cloud/aiplatform_v1/services/feature_registry_service/client.py @@ -1095,6 +1095,8 @@ def sample_list_feature_groups(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1769,6 +1771,8 @@ def sample_list_features(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, 
) diff --git a/google/cloud/aiplatform_v1/services/feature_registry_service/pagers.py b/google/cloud/aiplatform_v1/services/feature_registry_service/pagers.py index 6aef71018f..1b301fa296 100644 --- a/google/cloud/aiplatform_v1/services/feature_registry_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/feature_registry_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import feature from google.cloud.aiplatform_v1.types import feature_group from google.cloud.aiplatform_v1.types import feature_registry_service @@ -54,6 +67,8 @@ def __init__( request: feature_registry_service.ListFeatureGroupsRequest, response: feature_registry_service.ListFeatureGroupsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -65,12 +80,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListFeatureGroupsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. 
metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. """ self._method = method self._request = feature_registry_service.ListFeatureGroupsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -81,7 +101,12 @@ def pages(self) -> Iterator[feature_registry_service.ListFeatureGroupsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature_group.FeatureGroup]: @@ -118,6 +143,8 @@ def __init__( request: feature_registry_service.ListFeatureGroupsRequest, response: feature_registry_service.ListFeatureGroupsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -129,12 +156,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListFeatureGroupsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = feature_registry_service.ListFeatureGroupsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -147,7 +179,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature_group.FeatureGroup]: @@ -186,6 +223,8 @@ def __init__( request: featurestore_service.ListFeaturesRequest, response: featurestore_service.ListFeaturesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -197,12 +236,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListFeaturesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListFeaturesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -213,7 +257,12 @@ def pages(self) -> Iterator[featurestore_service.ListFeaturesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature.Feature]: @@ -248,6 +297,8 @@ def __init__( request: featurestore_service.ListFeaturesRequest, response: featurestore_service.ListFeaturesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -259,12 +310,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListFeaturesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListFeaturesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -275,7 +331,12 @@ async def pages(self) -> AsyncIterator[featurestore_service.ListFeaturesResponse yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature.Feature]: diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py b/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py index a6fc5dd664..f5acd99130 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py @@ -670,6 +670,8 @@ async def sample_list_featurestores(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1350,6 +1352,8 @@ async def sample_list_entity_types(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2176,6 +2180,8 @@ async def sample_list_features(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3229,6 +3235,8 @@ async def sample_search_features(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1/services/featurestore_service/client.py index 6a5f9217fd..24bc029510 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/client.py 
+++ b/google/cloud/aiplatform_v1/services/featurestore_service/client.py @@ -1122,6 +1122,8 @@ def sample_list_featurestores(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1787,6 +1789,8 @@ def sample_list_entity_types(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2595,6 +2599,8 @@ def sample_list_features(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3629,6 +3635,8 @@ def sample_search_features(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/pagers.py b/google/cloud/aiplatform_v1/services/featurestore_service/pagers.py index 1db6424db4..8759fb8d55 100644 --- a/google/cloud/aiplatform_v1/services/featurestore_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/featurestore_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import entity_type from google.cloud.aiplatform_v1.types import feature from google.cloud.aiplatform_v1.types import featurestore @@ -54,6 +67,8 @@ def __init__( request: featurestore_service.ListFeaturestoresRequest, response: featurestore_service.ListFeaturestoresResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -65,12 +80,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListFeaturestoresResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListFeaturestoresRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -81,7 +101,12 @@ def pages(self) -> Iterator[featurestore_service.ListFeaturestoresResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[featurestore.Featurestore]: @@ -118,6 +143,8 @@ def __init__( request: featurestore_service.ListFeaturestoresRequest, response: featurestore_service.ListFeaturestoresResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -129,12 +156,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListFeaturestoresResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListFeaturestoresRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -147,7 +179,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[featurestore.Featurestore]: @@ -186,6 +223,8 @@ def __init__( request: featurestore_service.ListEntityTypesRequest, response: featurestore_service.ListEntityTypesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -197,12 +236,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListEntityTypesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListEntityTypesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -213,7 +257,12 @@ def pages(self) -> Iterator[featurestore_service.ListEntityTypesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[entity_type.EntityType]: @@ -248,6 +297,8 @@ def __init__( request: featurestore_service.ListEntityTypesRequest, response: featurestore_service.ListEntityTypesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -259,12 +310,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListEntityTypesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListEntityTypesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -277,7 +333,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[entity_type.EntityType]: @@ -316,6 +377,8 @@ def __init__( request: featurestore_service.ListFeaturesRequest, response: featurestore_service.ListFeaturesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -327,12 +390,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListFeaturesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListFeaturesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -343,7 +411,12 @@ def pages(self) -> Iterator[featurestore_service.ListFeaturesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature.Feature]: @@ -378,6 +451,8 @@ def __init__( request: featurestore_service.ListFeaturesRequest, response: featurestore_service.ListFeaturesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -389,12 +464,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListFeaturesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListFeaturesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -405,7 +485,12 @@ async def pages(self) -> AsyncIterator[featurestore_service.ListFeaturesResponse yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature.Feature]: @@ -444,6 +529,8 @@ def __init__( request: featurestore_service.SearchFeaturesRequest, response: featurestore_service.SearchFeaturesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -455,12 +542,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.SearchFeaturesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.SearchFeaturesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -471,7 +563,12 @@ def pages(self) -> Iterator[featurestore_service.SearchFeaturesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature.Feature]: @@ -506,6 +603,8 @@ def __init__( request: featurestore_service.SearchFeaturesRequest, response: featurestore_service.SearchFeaturesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -517,12 +616,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.SearchFeaturesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.SearchFeaturesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -533,7 +637,12 @@ async def pages(self) -> AsyncIterator[featurestore_service.SearchFeaturesRespon yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature.Feature]: diff --git a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/async_client.py b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/async_client.py index d8281e49b1..2d0b1d7add 100644 --- a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/async_client.py @@ -623,6 +623,8 @@ async def sample_list_tuning_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/client.py b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/client.py index 2aae977135..1f3eae6484 100644 --- a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/client.py +++ b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/client.py @@ -1095,6 +1095,8 @@ def sample_list_tuning_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/pagers.py b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/pagers.py index 7f621423f6..4597735bb5 100644 --- 
a/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/gen_ai_tuning_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import genai_tuning_service from google.cloud.aiplatform_v1.types import tuning_job @@ -52,6 +65,8 @@ def __init__( request: genai_tuning_service.ListTuningJobsRequest, response: genai_tuning_service.ListTuningJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTuningJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = genai_tuning_service.ListTuningJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[genai_tuning_service.ListTuningJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[tuning_job.TuningJob]: @@ -114,6 +139,8 @@ def __init__( request: genai_tuning_service.ListTuningJobsRequest, response: genai_tuning_service.ListTuningJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -125,12 +152,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTuningJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = genai_tuning_service.ListTuningJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -141,7 +173,12 @@ async def pages(self) -> AsyncIterator[genai_tuning_service.ListTuningJobsRespon yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[tuning_job.TuningJob]: diff --git a/google/cloud/aiplatform_v1/services/index_endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/index_endpoint_service/async_client.py index 158363ffd6..892a64cd76 100644 --- a/google/cloud/aiplatform_v1/services/index_endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/index_endpoint_service/async_client.py @@ -83,6 +83,10 @@ class IndexEndpointServiceAsyncClient: parse_index_endpoint_path = staticmethod( IndexEndpointServiceClient.parse_index_endpoint_path ) + reservation_path = staticmethod(IndexEndpointServiceClient.reservation_path) + parse_reservation_path = staticmethod( + IndexEndpointServiceClient.parse_reservation_path + ) common_billing_account_path = staticmethod( IndexEndpointServiceClient.common_billing_account_path ) @@ -643,6 +647,8 @@ async def sample_list_index_endpoints(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/index_endpoint_service/client.py b/google/cloud/aiplatform_v1/services/index_endpoint_service/client.py index 03f7d45248..51004c47d2 100644 --- a/google/cloud/aiplatform_v1/services/index_endpoint_service/client.py +++ 
b/google/cloud/aiplatform_v1/services/index_endpoint_service/client.py @@ -240,6 +240,28 @@ def parse_index_endpoint_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def common_billing_account_path( billing_account: str, @@ -1071,6 +1093,8 @@ def sample_list_index_endpoints(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/index_endpoint_service/pagers.py b/google/cloud/aiplatform_v1/services/index_endpoint_service/pagers.py index d6393df11a..2532ea671a 100644 --- a/google/cloud/aiplatform_v1/services/index_endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/index_endpoint_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import index_endpoint from google.cloud.aiplatform_v1.types import index_endpoint_service @@ -52,6 +65,8 @@ def __init__( request: index_endpoint_service.ListIndexEndpointsRequest, response: index_endpoint_service.ListIndexEndpointsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListIndexEndpointsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = index_endpoint_service.ListIndexEndpointsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[index_endpoint_service.ListIndexEndpointsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[index_endpoint.IndexEndpoint]: @@ -116,6 +141,8 @@ def __init__( request: index_endpoint_service.ListIndexEndpointsRequest, response: index_endpoint_service.ListIndexEndpointsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -127,12 +154,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListIndexEndpointsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = index_endpoint_service.ListIndexEndpointsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[index_endpoint.IndexEndpoint]: diff --git a/google/cloud/aiplatform_v1/services/index_service/async_client.py b/google/cloud/aiplatform_v1/services/index_service/async_client.py index 68c2add4aa..a364052620 100644 --- a/google/cloud/aiplatform_v1/services/index_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/index_service/async_client.py @@ -630,6 +630,8 @@ async def sample_list_indexes(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/index_service/client.py b/google/cloud/aiplatform_v1/services/index_service/client.py index dd5abfd2f6..e9018601ec 100644 --- a/google/cloud/aiplatform_v1/services/index_service/client.py +++ b/google/cloud/aiplatform_v1/services/index_service/client.py @@ -1057,6 +1057,8 @@ def sample_list_indexes(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/index_service/pagers.py b/google/cloud/aiplatform_v1/services/index_service/pagers.py index 7ecff8d9fa..4cc04cabc3 100644 --- a/google/cloud/aiplatform_v1/services/index_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/index_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language 
governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import index from google.cloud.aiplatform_v1.types import index_service @@ -52,6 +65,8 @@ def __init__( request: index_service.ListIndexesRequest, response: index_service.ListIndexesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListIndexesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = index_service.ListIndexesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[index_service.ListIndexesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[index.Index]: @@ -114,6 +139,8 @@ def __init__( request: index_service.ListIndexesRequest, response: index_service.ListIndexesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -125,12 +152,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListIndexesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = index_service.ListIndexesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -141,7 +173,12 @@ async def pages(self) -> AsyncIterator[index_service.ListIndexesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[index.Index]: diff --git a/google/cloud/aiplatform_v1/services/job_service/async_client.py b/google/cloud/aiplatform_v1/services/job_service/async_client.py index 527d2fdd41..b833d7cc2f 100644 --- a/google/cloud/aiplatform_v1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/job_service/async_client.py @@ -151,6 +151,8 @@ class JobServiceAsyncClient: parse_persistent_resource_path = staticmethod( JobServiceClient.parse_persistent_resource_path ) + reservation_path = staticmethod(JobServiceClient.reservation_path) + parse_reservation_path = staticmethod(JobServiceClient.parse_reservation_path) tensorboard_path = staticmethod(JobServiceClient.tensorboard_path) parse_tensorboard_path = staticmethod(JobServiceClient.parse_tensorboard_path) trial_path = staticmethod(JobServiceClient.trial_path) @@ -694,6 +696,8 @@ async def sample_list_custom_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1281,6 +1285,8 @@ async def sample_list_data_labeling_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1873,6 +1879,8 @@ async def sample_list_hyperparameter_tuning_jobs(): method=rpc, request=request, response=response, + 
retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2462,6 +2470,8 @@ async def sample_list_nas_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2925,6 +2935,8 @@ async def sample_list_nas_trial_details(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3294,6 +3306,8 @@ async def sample_list_batch_prediction_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3813,6 +3827,8 @@ async def sample_search_model_deployment_monitoring_stats_anomalies(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4049,6 +4065,8 @@ async def sample_list_model_deployment_monitoring_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/job_service/client.py b/google/cloud/aiplatform_v1/services/job_service/client.py index 4a15f944cb..4dcbed90e4 100644 --- a/google/cloud/aiplatform_v1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1/services/job_service/client.py @@ -529,6 +529,28 @@ def parse_persistent_resource_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def 
tensorboard_path( project: str, @@ -1390,6 +1412,8 @@ def sample_list_custom_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1962,6 +1986,8 @@ def sample_list_data_labeling_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2545,6 +2571,8 @@ def sample_list_hyperparameter_tuning_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3123,6 +3151,8 @@ def sample_list_nas_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3574,6 +3604,8 @@ def sample_list_nas_trial_details(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3938,6 +3970,8 @@ def sample_list_batch_prediction_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4457,6 +4491,8 @@ def sample_search_model_deployment_monitoring_stats_anomalies(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4691,6 +4727,8 @@ def sample_list_model_deployment_monitoring_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/job_service/pagers.py b/google/cloud/aiplatform_v1/services/job_service/pagers.py index d9d92d86fd..d02340c10a 100644 --- a/google/cloud/aiplatform_v1/services/job_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/job_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import batch_prediction_job from google.cloud.aiplatform_v1.types import custom_job from google.cloud.aiplatform_v1.types import data_labeling_job @@ -60,6 +73,8 @@ def __init__( request: job_service.ListCustomJobsRequest, response: job_service.ListCustomJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -71,12 +86,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListCustomJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListCustomJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -87,7 +107,12 @@ def pages(self) -> Iterator[job_service.ListCustomJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[custom_job.CustomJob]: @@ -122,6 +147,8 @@ def __init__( request: job_service.ListCustomJobsRequest, response: job_service.ListCustomJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -133,12 +160,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListCustomJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListCustomJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -149,7 +181,12 @@ async def pages(self) -> AsyncIterator[job_service.ListCustomJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[custom_job.CustomJob]: @@ -188,6 +225,8 @@ def __init__( request: job_service.ListDataLabelingJobsRequest, response: job_service.ListDataLabelingJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -199,12 +238,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListDataLabelingJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListDataLabelingJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -215,7 +259,12 @@ def pages(self) -> Iterator[job_service.ListDataLabelingJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[data_labeling_job.DataLabelingJob]: @@ -250,6 +299,8 @@ def __init__( request: job_service.ListDataLabelingJobsRequest, response: job_service.ListDataLabelingJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -261,12 +312,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListDataLabelingJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListDataLabelingJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -277,7 +333,12 @@ async def pages(self) -> AsyncIterator[job_service.ListDataLabelingJobsResponse] yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[data_labeling_job.DataLabelingJob]: @@ -316,6 +377,8 @@ def __init__( request: job_service.ListHyperparameterTuningJobsRequest, response: job_service.ListHyperparameterTuningJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -327,12 +390,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListHyperparameterTuningJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListHyperparameterTuningJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -343,7 +411,12 @@ def pages(self) -> Iterator[job_service.ListHyperparameterTuningJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[hyperparameter_tuning_job.HyperparameterTuningJob]: @@ -380,6 +453,8 @@ def __init__( request: job_service.ListHyperparameterTuningJobsRequest, response: job_service.ListHyperparameterTuningJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -391,12 +466,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListHyperparameterTuningJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListHyperparameterTuningJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -409,7 +489,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__( @@ -450,6 +535,8 @@ def __init__( request: job_service.ListNasJobsRequest, response: job_service.ListNasJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -461,12 +548,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListNasJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListNasJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -477,7 +569,12 @@ def pages(self) -> Iterator[job_service.ListNasJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[nas_job.NasJob]: @@ -512,6 +609,8 @@ def __init__( request: job_service.ListNasJobsRequest, response: job_service.ListNasJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -523,12 +622,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListNasJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListNasJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -539,7 +643,12 @@ async def pages(self) -> AsyncIterator[job_service.ListNasJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[nas_job.NasJob]: @@ -578,6 +687,8 @@ def __init__( request: job_service.ListNasTrialDetailsRequest, response: job_service.ListNasTrialDetailsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -589,12 +700,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListNasTrialDetailsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListNasTrialDetailsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -605,7 +721,12 @@ def pages(self) -> Iterator[job_service.ListNasTrialDetailsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[nas_job.NasTrialDetail]: @@ -640,6 +761,8 @@ def __init__( request: job_service.ListNasTrialDetailsRequest, response: job_service.ListNasTrialDetailsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -651,12 +774,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListNasTrialDetailsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListNasTrialDetailsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -667,7 +795,12 @@ async def pages(self) -> AsyncIterator[job_service.ListNasTrialDetailsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[nas_job.NasTrialDetail]: @@ -706,6 +839,8 @@ def __init__( request: job_service.ListBatchPredictionJobsRequest, response: job_service.ListBatchPredictionJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -717,12 +852,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListBatchPredictionJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListBatchPredictionJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -733,7 +873,12 @@ def pages(self) -> Iterator[job_service.ListBatchPredictionJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[batch_prediction_job.BatchPredictionJob]: @@ -768,6 +913,8 @@ def __init__( request: job_service.ListBatchPredictionJobsRequest, response: job_service.ListBatchPredictionJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -779,12 +926,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListBatchPredictionJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListBatchPredictionJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -795,7 +947,12 @@ async def pages(self) -> AsyncIterator[job_service.ListBatchPredictionJobsRespon yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[batch_prediction_job.BatchPredictionJob]: @@ -836,6 +993,8 @@ def __init__( request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, response: job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -847,6 +1006,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.SearchModelDeploymentMonitoringStatsAnomaliesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -855,6 +1017,8 @@ def __init__( job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(request) ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -867,7 +1031,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__( @@ -909,6 +1078,8 @@ def __init__( request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, response: job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -920,6 +1091,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.SearchModelDeploymentMonitoringStatsAnomaliesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -928,6 +1102,8 @@ def __init__( job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(request) ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -942,7 +1118,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__( @@ -985,6 +1166,8 @@ def __init__( request: job_service.ListModelDeploymentMonitoringJobsRequest, response: job_service.ListModelDeploymentMonitoringJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -996,12 +1179,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListModelDeploymentMonitoringJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListModelDeploymentMonitoringJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -1012,7 +1200,12 @@ def pages(self) -> Iterator[job_service.ListModelDeploymentMonitoringJobsRespons yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__( @@ -1051,6 +1244,8 @@ def __init__( request: job_service.ListModelDeploymentMonitoringJobsRequest, response: job_service.ListModelDeploymentMonitoringJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -1062,12 +1257,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListModelDeploymentMonitoringJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListModelDeploymentMonitoringJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -1080,7 +1280,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__( diff --git a/google/cloud/aiplatform_v1/services/llm_utility_service/transports/rest.py b/google/cloud/aiplatform_v1/services/llm_utility_service/transports/rest.py index f39544de09..36753f6823 100644 --- a/google/cloud/aiplatform_v1/services/llm_utility_service/transports/rest.py +++ b/google/cloud/aiplatform_v1/services/llm_utility_service/transports/rest.py @@ -513,6 +513,16 @@ def __call__( "uri": "/v1/{endpoint=projects/*/locations/*/publishers/*/models/*}:computeTokens", "body": "*", }, + { + "method": "post", + "uri": "/v1/{endpoint=endpoints/*}:computeTokens", + "body": "*", + }, + { + "method": "post", + "uri": "/v1/{endpoint=publishers/*/models/*}:computeTokens", + "body": "*", + }, ] request, metadata = self._interceptor.pre_compute_tokens(request, metadata) pb_request = llm_utility_service.ComputeTokensRequest.pb(request) @@ -610,6 +620,16 @@ def __call__( "uri": "/v1/{endpoint=projects/*/locations/*/publishers/*/models/*}:countTokens", "body": "*", }, + { + "method": "post", + "uri": "/v1/{endpoint=endpoints/*}:countTokens", + "body": "*", + }, + { + "method": "post", + "uri": "/v1/{endpoint=publishers/*/models/*}:countTokens", + "body": "*", + }, ] request, metadata = self._interceptor.pre_count_tokens(request, metadata) pb_request = prediction_service.CountTokensRequest.pb(request) diff --git 
a/google/cloud/aiplatform_v1/services/metadata_service/async_client.py b/google/cloud/aiplatform_v1/services/metadata_service/async_client.py index 3ed110fd2d..088e12282f 100644 --- a/google/cloud/aiplatform_v1/services/metadata_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/metadata_service/async_client.py @@ -671,6 +671,8 @@ async def sample_list_metadata_stores(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1159,6 +1161,8 @@ async def sample_list_artifacts(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1885,6 +1889,8 @@ async def sample_list_contexts(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3113,6 +3119,8 @@ async def sample_list_executions(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4091,6 +4099,8 @@ async def sample_list_metadata_schemas(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/metadata_service/client.py b/google/cloud/aiplatform_v1/services/metadata_service/client.py index 173f4eed16..2d0005bfca 100644 --- a/google/cloud/aiplatform_v1/services/metadata_service/client.py +++ b/google/cloud/aiplatform_v1/services/metadata_service/client.py @@ -1165,6 +1165,8 @@ def sample_list_metadata_stores(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1641,6 +1643,8 @@ def sample_list_artifacts(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2349,6 +2353,8 @@ def sample_list_contexts(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3551,6 +3557,8 @@ def sample_list_executions(): method=rpc, 
request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4507,6 +4515,8 @@ def sample_list_metadata_schemas(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/metadata_service/pagers.py b/google/cloud/aiplatform_v1/services/metadata_service/pagers.py index 42d9380760..68963a5d4a 100644 --- a/google/cloud/aiplatform_v1/services/metadata_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/metadata_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import artifact from google.cloud.aiplatform_v1.types import context from google.cloud.aiplatform_v1.types import execution @@ -56,6 +69,8 @@ def __init__( request: metadata_service.ListMetadataStoresRequest, response: metadata_service.ListMetadataStoresResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -67,12 +82,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListMetadataStoresResponse): The initial response object. 
+ retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. """ self._method = method self._request = metadata_service.ListMetadataStoresRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -83,7 +103,12 @@ def pages(self) -> Iterator[metadata_service.ListMetadataStoresResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[metadata_store.MetadataStore]: @@ -118,6 +143,8 @@ def __init__( request: metadata_service.ListMetadataStoresRequest, response: metadata_service.ListMetadataStoresResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -129,12 +156,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListMetadataStoresResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListMetadataStoresRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages(self) -> AsyncIterator[metadata_service.ListMetadataStoresRespon yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[metadata_store.MetadataStore]: @@ -184,6 +221,8 @@ def __init__( request: metadata_service.ListArtifactsRequest, response: metadata_service.ListArtifactsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -195,12 +234,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListArtifactsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListArtifactsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -211,7 +255,12 @@ def pages(self) -> Iterator[metadata_service.ListArtifactsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[artifact.Artifact]: @@ -246,6 +295,8 @@ def __init__( request: metadata_service.ListArtifactsRequest, response: metadata_service.ListArtifactsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -257,12 +308,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListArtifactsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListArtifactsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -273,7 +329,12 @@ async def pages(self) -> AsyncIterator[metadata_service.ListArtifactsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[artifact.Artifact]: @@ -312,6 +373,8 @@ def __init__( request: metadata_service.ListContextsRequest, response: metadata_service.ListContextsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -323,12 +386,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListContextsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListContextsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -339,7 +407,12 @@ def pages(self) -> Iterator[metadata_service.ListContextsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[context.Context]: @@ -374,6 +447,8 @@ def __init__( request: metadata_service.ListContextsRequest, response: metadata_service.ListContextsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -385,12 +460,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListContextsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListContextsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -401,7 +481,12 @@ async def pages(self) -> AsyncIterator[metadata_service.ListContextsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[context.Context]: @@ -440,6 +525,8 @@ def __init__( request: metadata_service.ListExecutionsRequest, response: metadata_service.ListExecutionsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -451,12 +538,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListExecutionsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListExecutionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -467,7 +559,12 @@ def pages(self) -> Iterator[metadata_service.ListExecutionsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[execution.Execution]: @@ -502,6 +599,8 @@ def __init__( request: metadata_service.ListExecutionsRequest, response: metadata_service.ListExecutionsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -513,12 +612,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListExecutionsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListExecutionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -529,7 +633,12 @@ async def pages(self) -> AsyncIterator[metadata_service.ListExecutionsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[execution.Execution]: @@ -568,6 +677,8 @@ def __init__( request: metadata_service.ListMetadataSchemasRequest, response: metadata_service.ListMetadataSchemasResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -579,12 +690,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListMetadataSchemasResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListMetadataSchemasRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -595,7 +711,12 @@ def pages(self) -> Iterator[metadata_service.ListMetadataSchemasResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[metadata_schema.MetadataSchema]: @@ -630,6 +751,8 @@ def __init__( request: metadata_service.ListMetadataSchemasRequest, response: metadata_service.ListMetadataSchemasResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -641,12 +764,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListMetadataSchemasResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListMetadataSchemasRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -659,7 +787,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[metadata_schema.MetadataSchema]: diff --git a/google/cloud/aiplatform_v1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1/services/migration_service/async_client.py index 49cf7cba3b..9bea0917ad 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/async_client.py @@ -403,6 +403,8 @@ async def sample_search_migratable_resources(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index 42be200245..e996115527 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -238,40 +238,40 @@ def parse_dataset_path(path: str) -> Dict[str, str]: @staticmethod def dataset_path( project: str, + location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( + return "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: 
"""Parses a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod def dataset_path( project: str, - location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( + return "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod @@ -935,6 +935,8 @@ def sample_search_migratable_resources(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/migration_service/pagers.py b/google/cloud/aiplatform_v1/services/migration_service/pagers.py index ff3167bbe8..b66546b7ca 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/migration_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import migratable_resource from google.cloud.aiplatform_v1.types import migration_service @@ -52,6 +65,8 @@ def __init__( request: migration_service.SearchMigratableResourcesRequest, response: migration_service.SearchMigratableResourcesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.SearchMigratableResourcesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = migration_service.SearchMigratableResourcesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[migration_service.SearchMigratableResourcesResponse] yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[migratable_resource.MigratableResource]: @@ -116,6 +141,8 @@ def __init__( request: migration_service.SearchMigratableResourcesRequest, response: migration_service.SearchMigratableResourcesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -127,12 +154,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.SearchMigratableResourcesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = migration_service.SearchMigratableResourcesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[migratable_resource.MigratableResource]: diff --git a/google/cloud/aiplatform_v1/services/model_garden_service/async_client.py b/google/cloud/aiplatform_v1/services/model_garden_service/async_client.py index e8eb6231b6..2bc76e5800 100644 --- a/google/cloud/aiplatform_v1/services/model_garden_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/model_garden_service/async_client.py @@ -72,6 +72,10 @@ class ModelGardenServiceAsyncClient: parse_publisher_model_path = staticmethod( ModelGardenServiceClient.parse_publisher_model_path ) + reservation_path = staticmethod(ModelGardenServiceClient.reservation_path) + parse_reservation_path = staticmethod( + ModelGardenServiceClient.parse_reservation_path + ) common_billing_account_path = staticmethod( ModelGardenServiceClient.common_billing_account_path ) diff --git a/google/cloud/aiplatform_v1/services/model_garden_service/client.py b/google/cloud/aiplatform_v1/services/model_garden_service/client.py index 1410315708..0118a152e0 100644 --- a/google/cloud/aiplatform_v1/services/model_garden_service/client.py +++ b/google/cloud/aiplatform_v1/services/model_garden_service/client.py @@ -204,6 +204,28 @@ def parse_publisher_model_path(path: str) -> Dict[str, str]: m = re.match(r"^publishers/(?P.+?)/models/(?P.+?)$", path) return m.groupdict() if m else {} + 
@staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def common_billing_account_path( billing_account: str, diff --git a/google/cloud/aiplatform_v1/services/model_service/async_client.py b/google/cloud/aiplatform_v1/services/model_service/async_client.py index 97e5ffd135..5257f7e7ac 100644 --- a/google/cloud/aiplatform_v1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/model_service/async_client.py @@ -654,6 +654,8 @@ async def sample_list_models(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -773,6 +775,8 @@ async def sample_list_model_versions(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2321,6 +2325,8 @@ async def sample_list_model_evaluations(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2554,6 +2560,8 @@ async def sample_list_model_evaluation_slices(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/model_service/client.py b/google/cloud/aiplatform_v1/services/model_service/client.py index da2d0a487c..a8eb86198f 100644 --- a/google/cloud/aiplatform_v1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1/services/model_service/client.py @@ 
-1161,6 +1161,8 @@ def sample_list_models(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1277,6 +1279,8 @@ def sample_list_model_versions(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2795,6 +2799,8 @@ def sample_list_model_evaluations(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3026,6 +3032,8 @@ def sample_list_model_evaluation_slices(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/model_service/pagers.py b/google/cloud/aiplatform_v1/services/model_service/pagers.py index 83c638d3f4..bcbb42ce5a 100644 --- a/google/cloud/aiplatform_v1/services/model_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/model_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import model from google.cloud.aiplatform_v1.types import model_evaluation from google.cloud.aiplatform_v1.types import model_evaluation_slice @@ -54,6 +67,8 @@ def __init__( request: model_service.ListModelsRequest, response: model_service.ListModelsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -65,12 +80,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListModelsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -81,7 +101,12 @@ def pages(self) -> Iterator[model_service.ListModelsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[model.Model]: @@ -116,6 +141,8 @@ def __init__( request: model_service.ListModelsRequest, response: model_service.ListModelsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -127,12 +154,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListModelsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -143,7 +175,12 @@ async def pages(self) -> AsyncIterator[model_service.ListModelsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[model.Model]: @@ -182,6 +219,8 @@ def __init__( request: model_service.ListModelVersionsRequest, response: model_service.ListModelVersionsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -193,12 +232,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListModelVersionsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelVersionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -209,7 +253,12 @@ def pages(self) -> Iterator[model_service.ListModelVersionsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[model.Model]: @@ -244,6 +293,8 @@ def __init__( request: model_service.ListModelVersionsRequest, response: model_service.ListModelVersionsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -255,12 +306,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListModelVersionsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelVersionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -271,7 +327,12 @@ async def pages(self) -> AsyncIterator[model_service.ListModelVersionsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[model.Model]: @@ -310,6 +371,8 @@ def __init__( request: model_service.ListModelEvaluationsRequest, response: model_service.ListModelEvaluationsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -321,12 +384,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListModelEvaluationsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelEvaluationsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -337,7 +405,12 @@ def pages(self) -> Iterator[model_service.ListModelEvaluationsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[model_evaluation.ModelEvaluation]: @@ -372,6 +445,8 @@ def __init__( request: model_service.ListModelEvaluationsRequest, response: model_service.ListModelEvaluationsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -383,12 +458,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListModelEvaluationsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelEvaluationsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -399,7 +479,12 @@ async def pages(self) -> AsyncIterator[model_service.ListModelEvaluationsRespons yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[model_evaluation.ModelEvaluation]: @@ -438,6 +523,8 @@ def __init__( request: model_service.ListModelEvaluationSlicesRequest, response: model_service.ListModelEvaluationSlicesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -449,12 +536,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListModelEvaluationSlicesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelEvaluationSlicesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -465,7 +557,12 @@ def pages(self) -> Iterator[model_service.ListModelEvaluationSlicesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[model_evaluation_slice.ModelEvaluationSlice]: @@ -502,6 +599,8 @@ def __init__( request: model_service.ListModelEvaluationSlicesRequest, response: model_service.ListModelEvaluationSlicesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -513,12 +612,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListModelEvaluationSlicesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelEvaluationSlicesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -531,7 +635,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[model_evaluation_slice.ModelEvaluationSlice]: diff --git a/google/cloud/aiplatform_v1/services/notebook_service/async_client.py b/google/cloud/aiplatform_v1/services/notebook_service/async_client.py index a4494a62c6..c79e273223 100644 --- a/google/cloud/aiplatform_v1/services/notebook_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/notebook_service/async_client.py @@ -108,6 +108,8 @@ class NotebookServiceAsyncClient: parse_notebook_runtime_template_path = staticmethod( NotebookServiceClient.parse_notebook_runtime_template_path ) + reservation_path = staticmethod(NotebookServiceClient.reservation_path) + parse_reservation_path = staticmethod(NotebookServiceClient.parse_reservation_path) schedule_path = staticmethod(NotebookServiceClient.schedule_path) parse_schedule_path = staticmethod(NotebookServiceClient.parse_schedule_path) subnetwork_path = staticmethod(NotebookServiceClient.subnetwork_path) @@ -691,6 +693,8 @@ async def sample_list_notebook_runtime_templates(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1369,6 +1373,8 @@ async def sample_list_notebook_runtimes(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2143,6 +2149,8 @@ async def 
sample_list_notebook_execution_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/notebook_service/client.py b/google/cloud/aiplatform_v1/services/notebook_service/client.py index 6dc1907ace..d5017ffe7c 100644 --- a/google/cloud/aiplatform_v1/services/notebook_service/client.py +++ b/google/cloud/aiplatform_v1/services/notebook_service/client.py @@ -294,6 +294,28 @@ def parse_notebook_runtime_template_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def schedule_path( project: str, @@ -1194,6 +1216,8 @@ def sample_list_notebook_runtime_templates(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1861,6 +1885,8 @@ def sample_list_notebook_runtimes(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2623,6 +2649,8 @@ def sample_list_notebook_execution_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/notebook_service/pagers.py b/google/cloud/aiplatform_v1/services/notebook_service/pagers.py index fe23fb07b1..95ad25bf9c 100644 --- 
a/google/cloud/aiplatform_v1/services/notebook_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/notebook_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import notebook_execution_job from google.cloud.aiplatform_v1.types import notebook_runtime from google.cloud.aiplatform_v1.types import notebook_service @@ -53,6 +66,8 @@ def __init__( request: notebook_service.ListNotebookRuntimeTemplatesRequest, response: notebook_service.ListNotebookRuntimeTemplatesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -64,12 +79,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListNotebookRuntimeTemplatesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = notebook_service.ListNotebookRuntimeTemplatesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -80,7 +100,12 @@ def pages(self) -> Iterator[notebook_service.ListNotebookRuntimeTemplatesRespons yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[notebook_runtime.NotebookRuntimeTemplate]: @@ -117,6 +142,8 @@ def __init__( request: notebook_service.ListNotebookRuntimeTemplatesRequest, response: notebook_service.ListNotebookRuntimeTemplatesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -128,12 +155,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListNotebookRuntimeTemplatesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = notebook_service.ListNotebookRuntimeTemplatesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -146,7 +178,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[notebook_runtime.NotebookRuntimeTemplate]: @@ -185,6 +222,8 @@ def __init__( request: notebook_service.ListNotebookRuntimesRequest, response: notebook_service.ListNotebookRuntimesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -196,12 +235,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListNotebookRuntimesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = notebook_service.ListNotebookRuntimesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -212,7 +256,12 @@ def pages(self) -> Iterator[notebook_service.ListNotebookRuntimesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[notebook_runtime.NotebookRuntime]: @@ -247,6 +296,8 @@ def __init__( request: notebook_service.ListNotebookRuntimesRequest, response: notebook_service.ListNotebookRuntimesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -258,12 +309,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListNotebookRuntimesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = notebook_service.ListNotebookRuntimesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -276,7 +332,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[notebook_runtime.NotebookRuntime]: @@ -315,6 +376,8 @@ def __init__( request: notebook_service.ListNotebookExecutionJobsRequest, response: notebook_service.ListNotebookExecutionJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -326,12 +389,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListNotebookExecutionJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = notebook_service.ListNotebookExecutionJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -342,7 +410,12 @@ def pages(self) -> Iterator[notebook_service.ListNotebookExecutionJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[notebook_execution_job.NotebookExecutionJob]: @@ -379,6 +452,8 @@ def __init__( request: notebook_service.ListNotebookExecutionJobsRequest, response: notebook_service.ListNotebookExecutionJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -390,12 +465,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListNotebookExecutionJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = notebook_service.ListNotebookExecutionJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -408,7 +488,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[notebook_execution_job.NotebookExecutionJob]: diff --git a/google/cloud/aiplatform_v1/services/persistent_resource_service/async_client.py b/google/cloud/aiplatform_v1/services/persistent_resource_service/async_client.py index 312b809f5d..512eb0c3a0 100644 --- a/google/cloud/aiplatform_v1/services/persistent_resource_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/persistent_resource_service/async_client.py @@ -93,6 +93,10 @@ class PersistentResourceServiceAsyncClient: parse_persistent_resource_path = staticmethod( PersistentResourceServiceClient.parse_persistent_resource_path ) + reservation_path = staticmethod(PersistentResourceServiceClient.reservation_path) + parse_reservation_path = staticmethod( + PersistentResourceServiceClient.parse_reservation_path + ) common_billing_account_path = staticmethod( PersistentResourceServiceClient.common_billing_account_path ) @@ -688,6 +692,8 @@ async def sample_list_persistent_resources(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/persistent_resource_service/client.py b/google/cloud/aiplatform_v1/services/persistent_resource_service/client.py index b2bbaa0a65..4a03ec9984 100644 --- 
a/google/cloud/aiplatform_v1/services/persistent_resource_service/client.py +++ b/google/cloud/aiplatform_v1/services/persistent_resource_service/client.py @@ -241,6 +241,28 @@ def parse_persistent_resource_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def common_billing_account_path( billing_account: str, @@ -1109,6 +1131,8 @@ def sample_list_persistent_resources(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/persistent_resource_service/pagers.py b/google/cloud/aiplatform_v1/services/persistent_resource_service/pagers.py index fbcfeaee1f..ca9a084315 100644 --- a/google/cloud/aiplatform_v1/services/persistent_resource_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/persistent_resource_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import persistent_resource from google.cloud.aiplatform_v1.types import persistent_resource_service @@ -54,6 +67,8 @@ def __init__( request: persistent_resource_service.ListPersistentResourcesRequest, response: persistent_resource_service.ListPersistentResourcesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -65,6 +80,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListPersistentResourcesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -73,6 +91,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -85,7 +105,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[persistent_resource.PersistentResource]: @@ -122,6 +147,8 @@ def __init__( request: persistent_resource_service.ListPersistentResourcesRequest, response: persistent_resource_service.ListPersistentResourcesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -133,6 +160,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListPersistentResourcesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -141,6 +171,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -153,7 +185,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[persistent_resource.PersistentResource]: diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py index 422d4e6917..9f4138e3c7 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py @@ -660,6 +660,8 @@ async def sample_list_training_pipelines(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1258,6 +1260,8 @@ async def sample_list_pipeline_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1/services/pipeline_service/client.py index caa4f73d57..53a94394e9 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/client.py @@ -1229,6 +1229,8 @@ def sample_list_training_pipelines(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1812,6 +1814,8 @@ def sample_list_pipeline_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git 
a/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py index 79c3b471d7..08c4883bb9 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import pipeline_job from google.cloud.aiplatform_v1.types import pipeline_service from google.cloud.aiplatform_v1.types import training_pipeline @@ -53,6 +66,8 @@ def __init__( request: pipeline_service.ListTrainingPipelinesRequest, response: pipeline_service.ListTrainingPipelinesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -64,12 +79,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTrainingPipelinesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = pipeline_service.ListTrainingPipelinesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -80,7 +100,12 @@ def pages(self) -> Iterator[pipeline_service.ListTrainingPipelinesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[training_pipeline.TrainingPipeline]: @@ -117,6 +142,8 @@ def __init__( request: pipeline_service.ListTrainingPipelinesRequest, response: pipeline_service.ListTrainingPipelinesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -128,12 +155,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTrainingPipelinesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = pipeline_service.ListTrainingPipelinesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -146,7 +178,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[training_pipeline.TrainingPipeline]: @@ -185,6 +222,8 @@ def __init__( request: pipeline_service.ListPipelineJobsRequest, response: pipeline_service.ListPipelineJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -196,12 +235,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListPipelineJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = pipeline_service.ListPipelineJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -212,7 +256,12 @@ def pages(self) -> Iterator[pipeline_service.ListPipelineJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[pipeline_job.PipelineJob]: @@ -247,6 +296,8 @@ def __init__( request: pipeline_service.ListPipelineJobsRequest, response: pipeline_service.ListPipelineJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -258,12 +309,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListPipelineJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = pipeline_service.ListPipelineJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -274,7 +330,12 @@ async def pages(self) -> AsyncIterator[pipeline_service.ListPipelineJobsResponse yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[pipeline_job.PipelineJob]: diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/rest.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/rest.py index b9f56f63de..99ef8796ec 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/rest.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/rest.py @@ -1016,6 +1016,16 @@ def __call__( "uri": "/v1/{model=projects/*/locations/*/publishers/*/models/*}:generateContent", "body": "*", }, + { + "method": "post", + "uri": "/v1/{model=endpoints/*}:generateContent", + "body": "*", + }, + { + "method": "post", + "uri": "/v1/{model=publishers/*/models/*}:generateContent", + "body": "*", + }, ] request, metadata = self._interceptor.pre_generate_content( request, metadata @@ -1502,6 +1512,16 @@ def __call__( "uri": "/v1/{model=projects/*/locations/*/publishers/*/models/*}:streamGenerateContent", "body": "*", }, + { + "method": "post", + "uri": "/v1/{model=endpoints/*}:streamGenerateContent", + "body": "*", + }, + { + "method": "post", + "uri": "/v1/{model=publishers/*/models/*}:streamGenerateContent", + "body": "*", + }, ] request, metadata = self._interceptor.pre_stream_generate_content( request, metadata diff --git 
a/google/cloud/aiplatform_v1/services/schedule_service/async_client.py b/google/cloud/aiplatform_v1/services/schedule_service/async_client.py index 9d573e9fd4..7fbabd6442 100644 --- a/google/cloud/aiplatform_v1/services/schedule_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/schedule_service/async_client.py @@ -764,6 +764,8 @@ async def sample_list_schedules(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/schedule_service/client.py b/google/cloud/aiplatform_v1/services/schedule_service/client.py index d49bcd201d..fbce613dcf 100644 --- a/google/cloud/aiplatform_v1/services/schedule_service/client.py +++ b/google/cloud/aiplatform_v1/services/schedule_service/client.py @@ -1292,6 +1292,8 @@ def sample_list_schedules(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/schedule_service/pagers.py b/google/cloud/aiplatform_v1/services/schedule_service/pagers.py index 5794f8ea95..2ceeb86850 100644 --- a/google/cloud/aiplatform_v1/services/schedule_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/schedule_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import schedule from google.cloud.aiplatform_v1.types import schedule_service @@ -52,6 +65,8 @@ def __init__( request: schedule_service.ListSchedulesRequest, response: schedule_service.ListSchedulesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListSchedulesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = schedule_service.ListSchedulesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[schedule_service.ListSchedulesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[schedule.Schedule]: @@ -114,6 +139,8 @@ def __init__( request: schedule_service.ListSchedulesRequest, response: schedule_service.ListSchedulesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -125,12 +152,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListSchedulesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = schedule_service.ListSchedulesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -141,7 +173,12 @@ async def pages(self) -> AsyncIterator[schedule_service.ListSchedulesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[schedule.Schedule]: diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py index 95107d2fc5..a79c5d6367 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py @@ -664,6 +664,8 @@ async def sample_list_specialist_pools(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py index b8b8788fab..f1836322fc 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py @@ -1070,6 +1070,8 @@ def sample_list_specialist_pools(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py index a23136756d..9661ffae3d 100644 --- 
a/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import specialist_pool from google.cloud.aiplatform_v1.types import specialist_pool_service @@ -52,6 +65,8 @@ def __init__( request: specialist_pool_service.ListSpecialistPoolsRequest, response: specialist_pool_service.ListSpecialistPoolsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListSpecialistPoolsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = specialist_pool_service.ListSpecialistPoolsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[specialist_pool_service.ListSpecialistPoolsResponse] yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[specialist_pool.SpecialistPool]: @@ -116,6 +141,8 @@ def __init__( request: specialist_pool_service.ListSpecialistPoolsRequest, response: specialist_pool_service.ListSpecialistPoolsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -127,12 +154,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListSpecialistPoolsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = specialist_pool_service.ListSpecialistPoolsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[specialist_pool.SpecialistPool]: diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py b/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py index cd55f4e543..6b6fb89692 100644 --- a/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/tensorboard_service/async_client.py @@ -816,6 +816,8 @@ async def sample_list_tensorboards(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1676,6 +1678,8 @@ async def sample_list_tensorboard_experiments(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2447,6 +2451,8 @@ async def sample_list_tensorboard_runs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3217,6 +3223,8 @@ async def sample_list_tensorboard_time_series(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4105,6 +4113,8 @@ async def sample_export_tensorboard_time_series_data(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/client.py 
b/google/cloud/aiplatform_v1/services/tensorboard_service/client.py index 1462c2fbde..051f780c64 100644 --- a/google/cloud/aiplatform_v1/services/tensorboard_service/client.py +++ b/google/cloud/aiplatform_v1/services/tensorboard_service/client.py @@ -1284,6 +1284,8 @@ def sample_list_tensorboards(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2131,6 +2133,8 @@ def sample_list_tensorboard_experiments(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2888,6 +2892,8 @@ def sample_list_tensorboard_runs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3650,6 +3656,8 @@ def sample_list_tensorboard_time_series(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4531,6 +4539,8 @@ def sample_export_tensorboard_time_series_data(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/tensorboard_service/pagers.py b/google/cloud/aiplatform_v1/services/tensorboard_service/pagers.py index 18e304a294..5e59dd121e 100644 --- a/google/cloud/aiplatform_v1/services/tensorboard_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/tensorboard_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import tensorboard from google.cloud.aiplatform_v1.types import tensorboard_data from google.cloud.aiplatform_v1.types import tensorboard_experiment @@ -56,6 +69,8 @@ def __init__( request: tensorboard_service.ListTensorboardsRequest, response: tensorboard_service.ListTensorboardsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -67,12 +82,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTensorboardsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -83,7 +103,12 @@ def pages(self) -> Iterator[tensorboard_service.ListTensorboardsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[tensorboard.Tensorboard]: @@ -118,6 +143,8 @@ def __init__( request: tensorboard_service.ListTensorboardsRequest, response: tensorboard_service.ListTensorboardsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -129,12 +156,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTensorboardsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -147,7 +179,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[tensorboard.Tensorboard]: @@ -186,6 +223,8 @@ def __init__( request: tensorboard_service.ListTensorboardExperimentsRequest, response: tensorboard_service.ListTensorboardExperimentsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -197,12 +236,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTensorboardExperimentsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardExperimentsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -213,7 +257,12 @@ def pages(self) -> Iterator[tensorboard_service.ListTensorboardExperimentsRespon yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[tensorboard_experiment.TensorboardExperiment]: @@ -250,6 +299,8 @@ def __init__( request: tensorboard_service.ListTensorboardExperimentsRequest, response: tensorboard_service.ListTensorboardExperimentsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -261,12 +312,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTensorboardExperimentsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardExperimentsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -279,7 +335,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[tensorboard_experiment.TensorboardExperiment]: @@ -318,6 +379,8 @@ def __init__( request: tensorboard_service.ListTensorboardRunsRequest, response: tensorboard_service.ListTensorboardRunsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -329,12 +392,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTensorboardRunsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardRunsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -345,7 +413,12 @@ def pages(self) -> Iterator[tensorboard_service.ListTensorboardRunsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[tensorboard_run.TensorboardRun]: @@ -382,6 +455,8 @@ def __init__( request: tensorboard_service.ListTensorboardRunsRequest, response: tensorboard_service.ListTensorboardRunsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -393,12 +468,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTensorboardRunsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardRunsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -411,7 +491,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[tensorboard_run.TensorboardRun]: @@ -450,6 +535,8 @@ def __init__( request: tensorboard_service.ListTensorboardTimeSeriesRequest, response: tensorboard_service.ListTensorboardTimeSeriesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -461,12 +548,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTensorboardTimeSeriesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardTimeSeriesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -477,7 +569,12 @@ def pages(self) -> Iterator[tensorboard_service.ListTensorboardTimeSeriesRespons yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[tensorboard_time_series.TensorboardTimeSeries]: @@ -514,6 +611,8 @@ def __init__( request: tensorboard_service.ListTensorboardTimeSeriesRequest, response: tensorboard_service.ListTensorboardTimeSeriesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -525,12 +624,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTensorboardTimeSeriesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardTimeSeriesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -543,7 +647,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[tensorboard_time_series.TensorboardTimeSeries]: @@ -584,6 +693,8 @@ def __init__( request: tensorboard_service.ExportTensorboardTimeSeriesDataRequest, response: tensorboard_service.ExportTensorboardTimeSeriesDataResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -595,6 +706,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ExportTensorboardTimeSeriesDataResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -603,6 +717,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -615,7 +731,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[tensorboard_data.TimeSeriesDataPoint]: @@ -652,6 +773,8 @@ def __init__( request: tensorboard_service.ExportTensorboardTimeSeriesDataRequest, response: tensorboard_service.ExportTensorboardTimeSeriesDataResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -663,6 +786,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ExportTensorboardTimeSeriesDataResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -671,6 +797,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -683,7 +811,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[tensorboard_data.TimeSeriesDataPoint]: diff --git a/google/cloud/aiplatform_v1/services/vizier_service/async_client.py b/google/cloud/aiplatform_v1/services/vizier_service/async_client.py index b4fccfc04e..0d86e4059a 100644 --- a/google/cloud/aiplatform_v1/services/vizier_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/vizier_service/async_client.py @@ -619,6 +619,8 @@ async def sample_list_studies(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1272,6 +1274,8 @@ async def sample_list_trials(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/vizier_service/client.py b/google/cloud/aiplatform_v1/services/vizier_service/client.py index 73acfbc683..93db10900b 100644 --- a/google/cloud/aiplatform_v1/services/vizier_service/client.py +++ b/google/cloud/aiplatform_v1/services/vizier_service/client.py @@ -1069,6 +1069,8 @@ def sample_list_studies(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1705,6 +1707,8 @@ def sample_list_trials(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1/services/vizier_service/pagers.py 
b/google/cloud/aiplatform_v1/services/vizier_service/pagers.py index 145868ca3d..cb69c49782 100644 --- a/google/cloud/aiplatform_v1/services/vizier_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/vizier_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1.types import study from google.cloud.aiplatform_v1.types import vizier_service @@ -52,6 +65,8 @@ def __init__( request: vizier_service.ListStudiesRequest, response: vizier_service.ListStudiesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListStudiesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = vizier_service.ListStudiesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[vizier_service.ListStudiesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[study.Study]: @@ -114,6 +139,8 @@ def __init__( request: vizier_service.ListStudiesRequest, response: vizier_service.ListStudiesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -125,12 +152,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListStudiesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = vizier_service.ListStudiesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -141,7 +173,12 @@ async def pages(self) -> AsyncIterator[vizier_service.ListStudiesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[study.Study]: @@ -180,6 +217,8 @@ def __init__( request: vizier_service.ListTrialsRequest, response: vizier_service.ListTrialsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -191,12 +230,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTrialsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = vizier_service.ListTrialsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -207,7 +251,12 @@ def pages(self) -> Iterator[vizier_service.ListTrialsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[study.Trial]: @@ -242,6 +291,8 @@ def __init__( request: vizier_service.ListTrialsRequest, response: vizier_service.ListTrialsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -253,12 +304,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1.types.ListTrialsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = vizier_service.ListTrialsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -269,7 +325,12 @@ async def pages(self) -> AsyncIterator[vizier_service.ListTrialsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[study.Trial]: diff --git a/google/cloud/aiplatform_v1/types/__init__.py b/google/cloud/aiplatform_v1/types/__init__.py index 42ee8feec0..7dd320b3ae 100644 --- a/google/cloud/aiplatform_v1/types/__init__.py +++ b/google/cloud/aiplatform_v1/types/__init__.py @@ -174,6 +174,110 @@ EvaluatedAnnotation, EvaluatedAnnotationExplanation, ) +from .evaluation_service import ( + BleuInput, + BleuInstance, + BleuMetricValue, + BleuResults, + BleuSpec, + CoherenceInput, + CoherenceInstance, + CoherenceResult, + CoherenceSpec, + EvaluateInstancesRequest, + EvaluateInstancesResponse, + ExactMatchInput, + ExactMatchInstance, + ExactMatchMetricValue, + ExactMatchResults, + ExactMatchSpec, + FluencyInput, + FluencyInstance, + FluencyResult, + FluencySpec, + FulfillmentInput, + FulfillmentInstance, + FulfillmentResult, + FulfillmentSpec, + GroundednessInput, + GroundednessInstance, + GroundednessResult, + GroundednessSpec, + PairwiseMetricInput, + PairwiseMetricInstance, + PairwiseMetricResult, + PairwiseMetricSpec, + PairwiseQuestionAnsweringQualityInput, + PairwiseQuestionAnsweringQualityInstance, + PairwiseQuestionAnsweringQualityResult, + PairwiseQuestionAnsweringQualitySpec, + PairwiseSummarizationQualityInput, + PairwiseSummarizationQualityInstance, + 
PairwiseSummarizationQualityResult, + PairwiseSummarizationQualitySpec, + PointwiseMetricInput, + PointwiseMetricInstance, + PointwiseMetricResult, + PointwiseMetricSpec, + QuestionAnsweringCorrectnessInput, + QuestionAnsweringCorrectnessInstance, + QuestionAnsweringCorrectnessResult, + QuestionAnsweringCorrectnessSpec, + QuestionAnsweringHelpfulnessInput, + QuestionAnsweringHelpfulnessInstance, + QuestionAnsweringHelpfulnessResult, + QuestionAnsweringHelpfulnessSpec, + QuestionAnsweringQualityInput, + QuestionAnsweringQualityInstance, + QuestionAnsweringQualityResult, + QuestionAnsweringQualitySpec, + QuestionAnsweringRelevanceInput, + QuestionAnsweringRelevanceInstance, + QuestionAnsweringRelevanceResult, + QuestionAnsweringRelevanceSpec, + RougeInput, + RougeInstance, + RougeMetricValue, + RougeResults, + RougeSpec, + SafetyInput, + SafetyInstance, + SafetyResult, + SafetySpec, + SummarizationHelpfulnessInput, + SummarizationHelpfulnessInstance, + SummarizationHelpfulnessResult, + SummarizationHelpfulnessSpec, + SummarizationQualityInput, + SummarizationQualityInstance, + SummarizationQualityResult, + SummarizationQualitySpec, + SummarizationVerbosityInput, + SummarizationVerbosityInstance, + SummarizationVerbosityResult, + SummarizationVerbositySpec, + ToolCallValidInput, + ToolCallValidInstance, + ToolCallValidMetricValue, + ToolCallValidResults, + ToolCallValidSpec, + ToolNameMatchInput, + ToolNameMatchInstance, + ToolNameMatchMetricValue, + ToolNameMatchResults, + ToolNameMatchSpec, + ToolParameterKeyMatchInput, + ToolParameterKeyMatchInstance, + ToolParameterKeyMatchMetricValue, + ToolParameterKeyMatchResults, + ToolParameterKeyMatchSpec, + ToolParameterKVMatchInput, + ToolParameterKVMatchInstance, + ToolParameterKVMatchMetricValue, + ToolParameterKVMatchResults, + ToolParameterKVMatchSpec, + PairwiseChoice, +) from .event import ( Event, ) @@ -757,6 +861,9 @@ from .publisher_model import ( PublisherModel, ) +from .reservation_affinity import ( + 
ReservationAffinity, +) from .saved_query import ( SavedQuery, ) @@ -1050,6 +1157,108 @@ "ErrorAnalysisAnnotation", "EvaluatedAnnotation", "EvaluatedAnnotationExplanation", + "BleuInput", + "BleuInstance", + "BleuMetricValue", + "BleuResults", + "BleuSpec", + "CoherenceInput", + "CoherenceInstance", + "CoherenceResult", + "CoherenceSpec", + "EvaluateInstancesRequest", + "EvaluateInstancesResponse", + "ExactMatchInput", + "ExactMatchInstance", + "ExactMatchMetricValue", + "ExactMatchResults", + "ExactMatchSpec", + "FluencyInput", + "FluencyInstance", + "FluencyResult", + "FluencySpec", + "FulfillmentInput", + "FulfillmentInstance", + "FulfillmentResult", + "FulfillmentSpec", + "GroundednessInput", + "GroundednessInstance", + "GroundednessResult", + "GroundednessSpec", + "PairwiseMetricInput", + "PairwiseMetricInstance", + "PairwiseMetricResult", + "PairwiseMetricSpec", + "PairwiseQuestionAnsweringQualityInput", + "PairwiseQuestionAnsweringQualityInstance", + "PairwiseQuestionAnsweringQualityResult", + "PairwiseQuestionAnsweringQualitySpec", + "PairwiseSummarizationQualityInput", + "PairwiseSummarizationQualityInstance", + "PairwiseSummarizationQualityResult", + "PairwiseSummarizationQualitySpec", + "PointwiseMetricInput", + "PointwiseMetricInstance", + "PointwiseMetricResult", + "PointwiseMetricSpec", + "QuestionAnsweringCorrectnessInput", + "QuestionAnsweringCorrectnessInstance", + "QuestionAnsweringCorrectnessResult", + "QuestionAnsweringCorrectnessSpec", + "QuestionAnsweringHelpfulnessInput", + "QuestionAnsweringHelpfulnessInstance", + "QuestionAnsweringHelpfulnessResult", + "QuestionAnsweringHelpfulnessSpec", + "QuestionAnsweringQualityInput", + "QuestionAnsweringQualityInstance", + "QuestionAnsweringQualityResult", + "QuestionAnsweringQualitySpec", + "QuestionAnsweringRelevanceInput", + "QuestionAnsweringRelevanceInstance", + "QuestionAnsweringRelevanceResult", + "QuestionAnsweringRelevanceSpec", + "RougeInput", + "RougeInstance", + "RougeMetricValue", + 
"RougeResults", + "RougeSpec", + "SafetyInput", + "SafetyInstance", + "SafetyResult", + "SafetySpec", + "SummarizationHelpfulnessInput", + "SummarizationHelpfulnessInstance", + "SummarizationHelpfulnessResult", + "SummarizationHelpfulnessSpec", + "SummarizationQualityInput", + "SummarizationQualityInstance", + "SummarizationQualityResult", + "SummarizationQualitySpec", + "SummarizationVerbosityInput", + "SummarizationVerbosityInstance", + "SummarizationVerbosityResult", + "SummarizationVerbositySpec", + "ToolCallValidInput", + "ToolCallValidInstance", + "ToolCallValidMetricValue", + "ToolCallValidResults", + "ToolCallValidSpec", + "ToolNameMatchInput", + "ToolNameMatchInstance", + "ToolNameMatchMetricValue", + "ToolNameMatchResults", + "ToolNameMatchSpec", + "ToolParameterKeyMatchInput", + "ToolParameterKeyMatchInstance", + "ToolParameterKeyMatchMetricValue", + "ToolParameterKeyMatchResults", + "ToolParameterKeyMatchSpec", + "ToolParameterKVMatchInput", + "ToolParameterKVMatchInstance", + "ToolParameterKVMatchMetricValue", + "ToolParameterKVMatchResults", + "ToolParameterKVMatchSpec", + "PairwiseChoice", "Event", "Execution", "Attribution", @@ -1509,6 +1718,7 @@ "StreamingRawPredictResponse", "StreamRawPredictRequest", "PublisherModel", + "ReservationAffinity", "SavedQuery", "Schedule", "CreateScheduleRequest", diff --git a/google/cloud/aiplatform_v1/types/accelerator_type.py b/google/cloud/aiplatform_v1/types/accelerator_type.py index 041207152f..bf6749dff9 100644 --- a/google/cloud/aiplatform_v1/types/accelerator_type.py +++ b/google/cloud/aiplatform_v1/types/accelerator_type.py @@ -36,7 +36,9 @@ class AcceleratorType(proto.Enum): Unspecified accelerator type, which means no accelerator. NVIDIA_TESLA_K80 (1): - Nvidia Tesla K80 GPU. + Deprecated: Nvidia Tesla K80 GPU has reached + end of support, see + https://cloud.google.com/compute/docs/eol/k80-eol. NVIDIA_TESLA_P100 (2): Nvidia Tesla P100 GPU. 
NVIDIA_TESLA_V100 (3): diff --git a/google/cloud/aiplatform_v1/types/custom_job.py b/google/cloud/aiplatform_v1/types/custom_job.py index 805fa504f9..80f5be068b 100644 --- a/google/cloud/aiplatform_v1/types/custom_job.py +++ b/google/cloud/aiplatform_v1/types/custom_job.py @@ -563,10 +563,18 @@ class Strategy(proto.Enum): LOW_COST (2): Low cost by making potential use of spot resources. + STANDARD (3): + Standard provisioning strategy uses regular + on-demand resources. + SPOT (4): + Spot provisioning strategy uses spot + resources. """ STRATEGY_UNSPECIFIED = 0 ON_DEMAND = 1 LOW_COST = 2 + STANDARD = 3 + SPOT = 4 timeout: duration_pb2.Duration = proto.Field( proto.MESSAGE, diff --git a/google/cloud/aiplatform_v1/types/evaluation_service.py b/google/cloud/aiplatform_v1/types/evaluation_service.py new file mode 100644 index 0000000000..6b7d976bff --- /dev/null +++ b/google/cloud/aiplatform_v1/types/evaluation_service.py @@ -0,0 +1,3217 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "PairwiseChoice", + "EvaluateInstancesRequest", + "EvaluateInstancesResponse", + "ExactMatchInput", + "ExactMatchInstance", + "ExactMatchSpec", + "ExactMatchResults", + "ExactMatchMetricValue", + "BleuInput", + "BleuInstance", + "BleuSpec", + "BleuResults", + "BleuMetricValue", + "RougeInput", + "RougeInstance", + "RougeSpec", + "RougeResults", + "RougeMetricValue", + "CoherenceInput", + "CoherenceInstance", + "CoherenceSpec", + "CoherenceResult", + "FluencyInput", + "FluencyInstance", + "FluencySpec", + "FluencyResult", + "SafetyInput", + "SafetyInstance", + "SafetySpec", + "SafetyResult", + "GroundednessInput", + "GroundednessInstance", + "GroundednessSpec", + "GroundednessResult", + "FulfillmentInput", + "FulfillmentInstance", + "FulfillmentSpec", + "FulfillmentResult", + "SummarizationQualityInput", + "SummarizationQualityInstance", + "SummarizationQualitySpec", + "SummarizationQualityResult", + "PairwiseSummarizationQualityInput", + "PairwiseSummarizationQualityInstance", + "PairwiseSummarizationQualitySpec", + "PairwiseSummarizationQualityResult", + "SummarizationHelpfulnessInput", + "SummarizationHelpfulnessInstance", + "SummarizationHelpfulnessSpec", + "SummarizationHelpfulnessResult", + "SummarizationVerbosityInput", + "SummarizationVerbosityInstance", + "SummarizationVerbositySpec", + "SummarizationVerbosityResult", + "QuestionAnsweringQualityInput", + "QuestionAnsweringQualityInstance", + "QuestionAnsweringQualitySpec", + "QuestionAnsweringQualityResult", + "PairwiseQuestionAnsweringQualityInput", + "PairwiseQuestionAnsweringQualityInstance", + "PairwiseQuestionAnsweringQualitySpec", + "PairwiseQuestionAnsweringQualityResult", + "QuestionAnsweringRelevanceInput", + "QuestionAnsweringRelevanceInstance", + "QuestionAnsweringRelevanceSpec", + 
"QuestionAnsweringRelevanceResult", + "QuestionAnsweringHelpfulnessInput", + "QuestionAnsweringHelpfulnessInstance", + "QuestionAnsweringHelpfulnessSpec", + "QuestionAnsweringHelpfulnessResult", + "QuestionAnsweringCorrectnessInput", + "QuestionAnsweringCorrectnessInstance", + "QuestionAnsweringCorrectnessSpec", + "QuestionAnsweringCorrectnessResult", + "PointwiseMetricInput", + "PointwiseMetricInstance", + "PointwiseMetricSpec", + "PointwiseMetricResult", + "PairwiseMetricInput", + "PairwiseMetricInstance", + "PairwiseMetricSpec", + "PairwiseMetricResult", + "ToolCallValidInput", + "ToolCallValidSpec", + "ToolCallValidInstance", + "ToolCallValidResults", + "ToolCallValidMetricValue", + "ToolNameMatchInput", + "ToolNameMatchSpec", + "ToolNameMatchInstance", + "ToolNameMatchResults", + "ToolNameMatchMetricValue", + "ToolParameterKeyMatchInput", + "ToolParameterKeyMatchSpec", + "ToolParameterKeyMatchInstance", + "ToolParameterKeyMatchResults", + "ToolParameterKeyMatchMetricValue", + "ToolParameterKVMatchInput", + "ToolParameterKVMatchSpec", + "ToolParameterKVMatchInstance", + "ToolParameterKVMatchResults", + "ToolParameterKVMatchMetricValue", + }, +) + + +class PairwiseChoice(proto.Enum): + r"""Pairwise prediction autorater preference. + + Values: + PAIRWISE_CHOICE_UNSPECIFIED (0): + Unspecified prediction choice. + BASELINE (1): + Baseline prediction wins + CANDIDATE (2): + Candidate prediction wins + TIE (3): + Winner cannot be determined + """ + PAIRWISE_CHOICE_UNSPECIFIED = 0 + BASELINE = 1 + CANDIDATE = 2 + TIE = 3 + + +class EvaluateInstancesRequest(proto.Message): + r"""Request message for EvaluationService.EvaluateInstances. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. 
_oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + exact_match_input (google.cloud.aiplatform_v1.types.ExactMatchInput): + Auto metric instances. + Instances and metric spec for exact match + metric. + + This field is a member of `oneof`_ ``metric_inputs``. + bleu_input (google.cloud.aiplatform_v1.types.BleuInput): + Instances and metric spec for bleu metric. + + This field is a member of `oneof`_ ``metric_inputs``. + rouge_input (google.cloud.aiplatform_v1.types.RougeInput): + Instances and metric spec for rouge metric. + + This field is a member of `oneof`_ ``metric_inputs``. + fluency_input (google.cloud.aiplatform_v1.types.FluencyInput): + LLM-based metric instance. + General text generation metrics, applicable to + other categories. Input for fluency metric. + + This field is a member of `oneof`_ ``metric_inputs``. + coherence_input (google.cloud.aiplatform_v1.types.CoherenceInput): + Input for coherence metric. + + This field is a member of `oneof`_ ``metric_inputs``. + safety_input (google.cloud.aiplatform_v1.types.SafetyInput): + Input for safety metric. + + This field is a member of `oneof`_ ``metric_inputs``. + groundedness_input (google.cloud.aiplatform_v1.types.GroundednessInput): + Input for groundedness metric. + + This field is a member of `oneof`_ ``metric_inputs``. + fulfillment_input (google.cloud.aiplatform_v1.types.FulfillmentInput): + Input for fulfillment metric. + + This field is a member of `oneof`_ ``metric_inputs``. + summarization_quality_input (google.cloud.aiplatform_v1.types.SummarizationQualityInput): + Input for summarization quality metric. + + This field is a member of `oneof`_ ``metric_inputs``. + pairwise_summarization_quality_input (google.cloud.aiplatform_v1.types.PairwiseSummarizationQualityInput): + Input for pairwise summarization quality + metric. + + This field is a member of `oneof`_ ``metric_inputs``. 
+ summarization_helpfulness_input (google.cloud.aiplatform_v1.types.SummarizationHelpfulnessInput): + Input for summarization helpfulness metric. + + This field is a member of `oneof`_ ``metric_inputs``. + summarization_verbosity_input (google.cloud.aiplatform_v1.types.SummarizationVerbosityInput): + Input for summarization verbosity metric. + + This field is a member of `oneof`_ ``metric_inputs``. + question_answering_quality_input (google.cloud.aiplatform_v1.types.QuestionAnsweringQualityInput): + Input for question answering quality metric. + + This field is a member of `oneof`_ ``metric_inputs``. + pairwise_question_answering_quality_input (google.cloud.aiplatform_v1.types.PairwiseQuestionAnsweringQualityInput): + Input for pairwise question answering quality + metric. + + This field is a member of `oneof`_ ``metric_inputs``. + question_answering_relevance_input (google.cloud.aiplatform_v1.types.QuestionAnsweringRelevanceInput): + Input for question answering relevance + metric. + + This field is a member of `oneof`_ ``metric_inputs``. + question_answering_helpfulness_input (google.cloud.aiplatform_v1.types.QuestionAnsweringHelpfulnessInput): + Input for question answering helpfulness + metric. + + This field is a member of `oneof`_ ``metric_inputs``. + question_answering_correctness_input (google.cloud.aiplatform_v1.types.QuestionAnsweringCorrectnessInput): + Input for question answering correctness + metric. + + This field is a member of `oneof`_ ``metric_inputs``. + pointwise_metric_input (google.cloud.aiplatform_v1.types.PointwiseMetricInput): + Input for pointwise metric. + + This field is a member of `oneof`_ ``metric_inputs``. + pairwise_metric_input (google.cloud.aiplatform_v1.types.PairwiseMetricInput): + Input for pairwise metric. + + This field is a member of `oneof`_ ``metric_inputs``. + tool_call_valid_input (google.cloud.aiplatform_v1.types.ToolCallValidInput): + Tool call metric instances. + Input for tool call valid metric. 
+ + This field is a member of `oneof`_ ``metric_inputs``. + tool_name_match_input (google.cloud.aiplatform_v1.types.ToolNameMatchInput): + Input for tool name match metric. + + This field is a member of `oneof`_ ``metric_inputs``. + tool_parameter_key_match_input (google.cloud.aiplatform_v1.types.ToolParameterKeyMatchInput): + Input for tool parameter key match metric. + + This field is a member of `oneof`_ ``metric_inputs``. + tool_parameter_kv_match_input (google.cloud.aiplatform_v1.types.ToolParameterKVMatchInput): + Input for tool parameter key value match + metric. + + This field is a member of `oneof`_ ``metric_inputs``. + location (str): + Required. The resource name of the Location to evaluate the + instances. Format: + ``projects/{project}/locations/{location}`` + """ + + exact_match_input: "ExactMatchInput" = proto.Field( + proto.MESSAGE, + number=2, + oneof="metric_inputs", + message="ExactMatchInput", + ) + bleu_input: "BleuInput" = proto.Field( + proto.MESSAGE, + number=3, + oneof="metric_inputs", + message="BleuInput", + ) + rouge_input: "RougeInput" = proto.Field( + proto.MESSAGE, + number=4, + oneof="metric_inputs", + message="RougeInput", + ) + fluency_input: "FluencyInput" = proto.Field( + proto.MESSAGE, + number=5, + oneof="metric_inputs", + message="FluencyInput", + ) + coherence_input: "CoherenceInput" = proto.Field( + proto.MESSAGE, + number=6, + oneof="metric_inputs", + message="CoherenceInput", + ) + safety_input: "SafetyInput" = proto.Field( + proto.MESSAGE, + number=8, + oneof="metric_inputs", + message="SafetyInput", + ) + groundedness_input: "GroundednessInput" = proto.Field( + proto.MESSAGE, + number=9, + oneof="metric_inputs", + message="GroundednessInput", + ) + fulfillment_input: "FulfillmentInput" = proto.Field( + proto.MESSAGE, + number=12, + oneof="metric_inputs", + message="FulfillmentInput", + ) + summarization_quality_input: "SummarizationQualityInput" = proto.Field( + proto.MESSAGE, + number=7, + oneof="metric_inputs", + 
message="SummarizationQualityInput", + ) + pairwise_summarization_quality_input: "PairwiseSummarizationQualityInput" = ( + proto.Field( + proto.MESSAGE, + number=23, + oneof="metric_inputs", + message="PairwiseSummarizationQualityInput", + ) + ) + summarization_helpfulness_input: "SummarizationHelpfulnessInput" = proto.Field( + proto.MESSAGE, + number=14, + oneof="metric_inputs", + message="SummarizationHelpfulnessInput", + ) + summarization_verbosity_input: "SummarizationVerbosityInput" = proto.Field( + proto.MESSAGE, + number=15, + oneof="metric_inputs", + message="SummarizationVerbosityInput", + ) + question_answering_quality_input: "QuestionAnsweringQualityInput" = proto.Field( + proto.MESSAGE, + number=10, + oneof="metric_inputs", + message="QuestionAnsweringQualityInput", + ) + pairwise_question_answering_quality_input: "PairwiseQuestionAnsweringQualityInput" = proto.Field( + proto.MESSAGE, + number=24, + oneof="metric_inputs", + message="PairwiseQuestionAnsweringQualityInput", + ) + question_answering_relevance_input: "QuestionAnsweringRelevanceInput" = proto.Field( + proto.MESSAGE, + number=16, + oneof="metric_inputs", + message="QuestionAnsweringRelevanceInput", + ) + question_answering_helpfulness_input: "QuestionAnsweringHelpfulnessInput" = ( + proto.Field( + proto.MESSAGE, + number=17, + oneof="metric_inputs", + message="QuestionAnsweringHelpfulnessInput", + ) + ) + question_answering_correctness_input: "QuestionAnsweringCorrectnessInput" = ( + proto.Field( + proto.MESSAGE, + number=18, + oneof="metric_inputs", + message="QuestionAnsweringCorrectnessInput", + ) + ) + pointwise_metric_input: "PointwiseMetricInput" = proto.Field( + proto.MESSAGE, + number=28, + oneof="metric_inputs", + message="PointwiseMetricInput", + ) + pairwise_metric_input: "PairwiseMetricInput" = proto.Field( + proto.MESSAGE, + number=29, + oneof="metric_inputs", + message="PairwiseMetricInput", + ) + tool_call_valid_input: "ToolCallValidInput" = proto.Field( + proto.MESSAGE, + 
number=19, + oneof="metric_inputs", + message="ToolCallValidInput", + ) + tool_name_match_input: "ToolNameMatchInput" = proto.Field( + proto.MESSAGE, + number=20, + oneof="metric_inputs", + message="ToolNameMatchInput", + ) + tool_parameter_key_match_input: "ToolParameterKeyMatchInput" = proto.Field( + proto.MESSAGE, + number=21, + oneof="metric_inputs", + message="ToolParameterKeyMatchInput", + ) + tool_parameter_kv_match_input: "ToolParameterKVMatchInput" = proto.Field( + proto.MESSAGE, + number=22, + oneof="metric_inputs", + message="ToolParameterKVMatchInput", + ) + location: str = proto.Field( + proto.STRING, + number=1, + ) + + +class EvaluateInstancesResponse(proto.Message): + r"""Response message for EvaluationService.EvaluateInstances. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + exact_match_results (google.cloud.aiplatform_v1.types.ExactMatchResults): + Auto metric evaluation results. + Results for exact match metric. + + This field is a member of `oneof`_ ``evaluation_results``. + bleu_results (google.cloud.aiplatform_v1.types.BleuResults): + Results for bleu metric. + + This field is a member of `oneof`_ ``evaluation_results``. + rouge_results (google.cloud.aiplatform_v1.types.RougeResults): + Results for rouge metric. + + This field is a member of `oneof`_ ``evaluation_results``. + fluency_result (google.cloud.aiplatform_v1.types.FluencyResult): + LLM-based metric evaluation result. + General text generation metrics, applicable to + other categories. Result for fluency metric. + + This field is a member of `oneof`_ ``evaluation_results``. + coherence_result (google.cloud.aiplatform_v1.types.CoherenceResult): + Result for coherence metric. 
+ + This field is a member of `oneof`_ ``evaluation_results``. + safety_result (google.cloud.aiplatform_v1.types.SafetyResult): + Result for safety metric. + + This field is a member of `oneof`_ ``evaluation_results``. + groundedness_result (google.cloud.aiplatform_v1.types.GroundednessResult): + Result for groundedness metric. + + This field is a member of `oneof`_ ``evaluation_results``. + fulfillment_result (google.cloud.aiplatform_v1.types.FulfillmentResult): + Result for fulfillment metric. + + This field is a member of `oneof`_ ``evaluation_results``. + summarization_quality_result (google.cloud.aiplatform_v1.types.SummarizationQualityResult): + Summarization only metrics. + Result for summarization quality metric. + + This field is a member of `oneof`_ ``evaluation_results``. + pairwise_summarization_quality_result (google.cloud.aiplatform_v1.types.PairwiseSummarizationQualityResult): + Result for pairwise summarization quality + metric. + + This field is a member of `oneof`_ ``evaluation_results``. + summarization_helpfulness_result (google.cloud.aiplatform_v1.types.SummarizationHelpfulnessResult): + Result for summarization helpfulness metric. + + This field is a member of `oneof`_ ``evaluation_results``. + summarization_verbosity_result (google.cloud.aiplatform_v1.types.SummarizationVerbosityResult): + Result for summarization verbosity metric. + + This field is a member of `oneof`_ ``evaluation_results``. + question_answering_quality_result (google.cloud.aiplatform_v1.types.QuestionAnsweringQualityResult): + Question answering only metrics. + Result for question answering quality metric. + + This field is a member of `oneof`_ ``evaluation_results``. + pairwise_question_answering_quality_result (google.cloud.aiplatform_v1.types.PairwiseQuestionAnsweringQualityResult): + Result for pairwise question answering + quality metric. + + This field is a member of `oneof`_ ``evaluation_results``. 
+ question_answering_relevance_result (google.cloud.aiplatform_v1.types.QuestionAnsweringRelevanceResult): + Result for question answering relevance + metric. + + This field is a member of `oneof`_ ``evaluation_results``. + question_answering_helpfulness_result (google.cloud.aiplatform_v1.types.QuestionAnsweringHelpfulnessResult): + Result for question answering helpfulness + metric. + + This field is a member of `oneof`_ ``evaluation_results``. + question_answering_correctness_result (google.cloud.aiplatform_v1.types.QuestionAnsweringCorrectnessResult): + Result for question answering correctness + metric. + + This field is a member of `oneof`_ ``evaluation_results``. + pointwise_metric_result (google.cloud.aiplatform_v1.types.PointwiseMetricResult): + Generic metrics. + Result for pointwise metric. + + This field is a member of `oneof`_ ``evaluation_results``. + pairwise_metric_result (google.cloud.aiplatform_v1.types.PairwiseMetricResult): + Result for pairwise metric. + + This field is a member of `oneof`_ ``evaluation_results``. + tool_call_valid_results (google.cloud.aiplatform_v1.types.ToolCallValidResults): + Tool call metrics. + Results for tool call valid metric. + + This field is a member of `oneof`_ ``evaluation_results``. + tool_name_match_results (google.cloud.aiplatform_v1.types.ToolNameMatchResults): + Results for tool name match metric. + + This field is a member of `oneof`_ ``evaluation_results``. + tool_parameter_key_match_results (google.cloud.aiplatform_v1.types.ToolParameterKeyMatchResults): + Results for tool parameter key match metric. + + This field is a member of `oneof`_ ``evaluation_results``. + tool_parameter_kv_match_results (google.cloud.aiplatform_v1.types.ToolParameterKVMatchResults): + Results for tool parameter key value match + metric. + + This field is a member of `oneof`_ ``evaluation_results``. 
+ """ + + exact_match_results: "ExactMatchResults" = proto.Field( + proto.MESSAGE, + number=1, + oneof="evaluation_results", + message="ExactMatchResults", + ) + bleu_results: "BleuResults" = proto.Field( + proto.MESSAGE, + number=2, + oneof="evaluation_results", + message="BleuResults", + ) + rouge_results: "RougeResults" = proto.Field( + proto.MESSAGE, + number=3, + oneof="evaluation_results", + message="RougeResults", + ) + fluency_result: "FluencyResult" = proto.Field( + proto.MESSAGE, + number=4, + oneof="evaluation_results", + message="FluencyResult", + ) + coherence_result: "CoherenceResult" = proto.Field( + proto.MESSAGE, + number=5, + oneof="evaluation_results", + message="CoherenceResult", + ) + safety_result: "SafetyResult" = proto.Field( + proto.MESSAGE, + number=7, + oneof="evaluation_results", + message="SafetyResult", + ) + groundedness_result: "GroundednessResult" = proto.Field( + proto.MESSAGE, + number=8, + oneof="evaluation_results", + message="GroundednessResult", + ) + fulfillment_result: "FulfillmentResult" = proto.Field( + proto.MESSAGE, + number=11, + oneof="evaluation_results", + message="FulfillmentResult", + ) + summarization_quality_result: "SummarizationQualityResult" = proto.Field( + proto.MESSAGE, + number=6, + oneof="evaluation_results", + message="SummarizationQualityResult", + ) + pairwise_summarization_quality_result: "PairwiseSummarizationQualityResult" = ( + proto.Field( + proto.MESSAGE, + number=22, + oneof="evaluation_results", + message="PairwiseSummarizationQualityResult", + ) + ) + summarization_helpfulness_result: "SummarizationHelpfulnessResult" = proto.Field( + proto.MESSAGE, + number=13, + oneof="evaluation_results", + message="SummarizationHelpfulnessResult", + ) + summarization_verbosity_result: "SummarizationVerbosityResult" = proto.Field( + proto.MESSAGE, + number=14, + oneof="evaluation_results", + message="SummarizationVerbosityResult", + ) + question_answering_quality_result: "QuestionAnsweringQualityResult" = 
proto.Field( + proto.MESSAGE, + number=9, + oneof="evaluation_results", + message="QuestionAnsweringQualityResult", + ) + pairwise_question_answering_quality_result: "PairwiseQuestionAnsweringQualityResult" = proto.Field( + proto.MESSAGE, + number=23, + oneof="evaluation_results", + message="PairwiseQuestionAnsweringQualityResult", + ) + question_answering_relevance_result: "QuestionAnsweringRelevanceResult" = ( + proto.Field( + proto.MESSAGE, + number=15, + oneof="evaluation_results", + message="QuestionAnsweringRelevanceResult", + ) + ) + question_answering_helpfulness_result: "QuestionAnsweringHelpfulnessResult" = ( + proto.Field( + proto.MESSAGE, + number=16, + oneof="evaluation_results", + message="QuestionAnsweringHelpfulnessResult", + ) + ) + question_answering_correctness_result: "QuestionAnsweringCorrectnessResult" = ( + proto.Field( + proto.MESSAGE, + number=17, + oneof="evaluation_results", + message="QuestionAnsweringCorrectnessResult", + ) + ) + pointwise_metric_result: "PointwiseMetricResult" = proto.Field( + proto.MESSAGE, + number=27, + oneof="evaluation_results", + message="PointwiseMetricResult", + ) + pairwise_metric_result: "PairwiseMetricResult" = proto.Field( + proto.MESSAGE, + number=28, + oneof="evaluation_results", + message="PairwiseMetricResult", + ) + tool_call_valid_results: "ToolCallValidResults" = proto.Field( + proto.MESSAGE, + number=18, + oneof="evaluation_results", + message="ToolCallValidResults", + ) + tool_name_match_results: "ToolNameMatchResults" = proto.Field( + proto.MESSAGE, + number=19, + oneof="evaluation_results", + message="ToolNameMatchResults", + ) + tool_parameter_key_match_results: "ToolParameterKeyMatchResults" = proto.Field( + proto.MESSAGE, + number=20, + oneof="evaluation_results", + message="ToolParameterKeyMatchResults", + ) + tool_parameter_kv_match_results: "ToolParameterKVMatchResults" = proto.Field( + proto.MESSAGE, + number=21, + oneof="evaluation_results", + message="ToolParameterKVMatchResults", + ) + 
+ +class ExactMatchInput(proto.Message): + r"""Input for exact match metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.ExactMatchSpec): + Required. Spec for exact match metric. + instances (MutableSequence[google.cloud.aiplatform_v1.types.ExactMatchInstance]): + Required. Repeated exact match instances. + """ + + metric_spec: "ExactMatchSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="ExactMatchSpec", + ) + instances: MutableSequence["ExactMatchInstance"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="ExactMatchInstance", + ) + + +class ExactMatchInstance(proto.Message): + r"""Spec for exact match instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Required. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + + +class ExactMatchSpec(proto.Message): + r"""Spec for exact match metric - returns 1 if prediction and + reference exactly matches, otherwise 0. + + """ + + +class ExactMatchResults(proto.Message): + r"""Results for exact match metric. + + Attributes: + exact_match_metric_values (MutableSequence[google.cloud.aiplatform_v1.types.ExactMatchMetricValue]): + Output only. Exact match metric values. + """ + + exact_match_metric_values: MutableSequence[ + "ExactMatchMetricValue" + ] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="ExactMatchMetricValue", + ) + + +class ExactMatchMetricValue(proto.Message): + r"""Exact match metric value for an instance. + + .. 
_oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Exact match score. + + This field is a member of `oneof`_ ``_score``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + + +class BleuInput(proto.Message): + r"""Input for bleu metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.BleuSpec): + Required. Spec for bleu score metric. + instances (MutableSequence[google.cloud.aiplatform_v1.types.BleuInstance]): + Required. Repeated bleu instances. + """ + + metric_spec: "BleuSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="BleuSpec", + ) + instances: MutableSequence["BleuInstance"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="BleuInstance", + ) + + +class BleuInstance(proto.Message): + r"""Spec for bleu instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Required. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + + +class BleuSpec(proto.Message): + r"""Spec for bleu score metric - calculates the precision of + n-grams in the prediction as compared to reference - returns a + score ranging between 0 to 1. + + Attributes: + use_effective_order (bool): + Optional. Whether to use_effective_order to compute bleu + score. + """ + + use_effective_order: bool = proto.Field( + proto.BOOL, + number=1, + ) + + +class BleuResults(proto.Message): + r"""Results for bleu metric. 
+ + Attributes: + bleu_metric_values (MutableSequence[google.cloud.aiplatform_v1.types.BleuMetricValue]): + Output only. Bleu metric values. + """ + + bleu_metric_values: MutableSequence["BleuMetricValue"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="BleuMetricValue", + ) + + +class BleuMetricValue(proto.Message): + r"""Bleu metric value for an instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Bleu score. + + This field is a member of `oneof`_ ``_score``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + + +class RougeInput(proto.Message): + r"""Input for rouge metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.RougeSpec): + Required. Spec for rouge score metric. + instances (MutableSequence[google.cloud.aiplatform_v1.types.RougeInstance]): + Required. Repeated rouge instances. + """ + + metric_spec: "RougeSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="RougeSpec", + ) + instances: MutableSequence["RougeInstance"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="RougeInstance", + ) + + +class RougeInstance(proto.Message): + r"""Spec for rouge instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Required. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. 
+ """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + + +class RougeSpec(proto.Message): + r"""Spec for rouge score metric - calculates the recall of + n-grams in prediction as compared to reference - returns a score + ranging between 0 and 1. + + Attributes: + rouge_type (str): + Optional. Supported rouge types are rougen[1-9], rougeL, and + rougeLsum. + use_stemmer (bool): + Optional. Whether to use stemmer to compute + rouge score. + split_summaries (bool): + Optional. Whether to split summaries while + using rougeLsum. + """ + + rouge_type: str = proto.Field( + proto.STRING, + number=1, + ) + use_stemmer: bool = proto.Field( + proto.BOOL, + number=2, + ) + split_summaries: bool = proto.Field( + proto.BOOL, + number=3, + ) + + +class RougeResults(proto.Message): + r"""Results for rouge metric. + + Attributes: + rouge_metric_values (MutableSequence[google.cloud.aiplatform_v1.types.RougeMetricValue]): + Output only. Rouge metric values. + """ + + rouge_metric_values: MutableSequence["RougeMetricValue"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="RougeMetricValue", + ) + + +class RougeMetricValue(proto.Message): + r"""Rouge metric value for an instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Rouge score. + + This field is a member of `oneof`_ ``_score``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + + +class CoherenceInput(proto.Message): + r"""Input for coherence metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.CoherenceSpec): + Required. Spec for coherence score metric. + instance (google.cloud.aiplatform_v1.types.CoherenceInstance): + Required. Coherence instance. 
+ """ + + metric_spec: "CoherenceSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="CoherenceSpec", + ) + instance: "CoherenceInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="CoherenceInstance", + ) + + +class CoherenceInstance(proto.Message): + r"""Spec for coherence instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + + +class CoherenceSpec(proto.Message): + r"""Spec for coherence score metric. + + Attributes: + version (int): + Optional. Which version to use for + evaluation. + """ + + version: int = proto.Field( + proto.INT32, + number=1, + ) + + +class CoherenceResult(proto.Message): + r"""Spec for coherence result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Coherence score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for coherence score. + confidence (float): + Output only. Confidence for coherence score. + + This field is a member of `oneof`_ ``_confidence``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class FluencyInput(proto.Message): + r"""Input for fluency metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.FluencySpec): + Required. Spec for fluency score metric. + instance (google.cloud.aiplatform_v1.types.FluencyInstance): + Required. Fluency instance. 
+ """ + + metric_spec: "FluencySpec" = proto.Field( + proto.MESSAGE, + number=1, + message="FluencySpec", + ) + instance: "FluencyInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="FluencyInstance", + ) + + +class FluencyInstance(proto.Message): + r"""Spec for fluency instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + + +class FluencySpec(proto.Message): + r"""Spec for fluency score metric. + + Attributes: + version (int): + Optional. Which version to use for + evaluation. + """ + + version: int = proto.Field( + proto.INT32, + number=1, + ) + + +class FluencyResult(proto.Message): + r"""Spec for fluency result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Fluency score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for fluency score. + confidence (float): + Output only. Confidence for fluency score. + + This field is a member of `oneof`_ ``_confidence``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class SafetyInput(proto.Message): + r"""Input for safety metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.SafetySpec): + Required. Spec for safety metric. + instance (google.cloud.aiplatform_v1.types.SafetyInstance): + Required. Safety instance. 
+ """ + + metric_spec: "SafetySpec" = proto.Field( + proto.MESSAGE, + number=1, + message="SafetySpec", + ) + instance: "SafetyInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="SafetyInstance", + ) + + +class SafetyInstance(proto.Message): + r"""Spec for safety instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + + +class SafetySpec(proto.Message): + r"""Spec for safety metric. + + Attributes: + version (int): + Optional. Which version to use for + evaluation. + """ + + version: int = proto.Field( + proto.INT32, + number=1, + ) + + +class SafetyResult(proto.Message): + r"""Spec for safety result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Safety score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for safety score. + confidence (float): + Output only. Confidence for safety score. + + This field is a member of `oneof`_ ``_confidence``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class GroundednessInput(proto.Message): + r"""Input for groundedness metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.GroundednessSpec): + Required. Spec for groundedness metric. + instance (google.cloud.aiplatform_v1.types.GroundednessInstance): + Required. Groundedness instance. 
+ """ + + metric_spec: "GroundednessSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="GroundednessSpec", + ) + instance: "GroundednessInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="GroundednessInstance", + ) + + +class GroundednessInstance(proto.Message): + r"""Spec for groundedness instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + context (str): + Required. Background information provided in + context used to compare against the prediction. + + This field is a member of `oneof`_ ``_context``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + context: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + + +class GroundednessSpec(proto.Message): + r"""Spec for groundedness metric. + + Attributes: + version (int): + Optional. Which version to use for + evaluation. + """ + + version: int = proto.Field( + proto.INT32, + number=1, + ) + + +class GroundednessResult(proto.Message): + r"""Spec for groundedness result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Groundedness score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for groundedness + score. + confidence (float): + Output only. Confidence for groundedness + score. + + This field is a member of `oneof`_ ``_confidence``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class FulfillmentInput(proto.Message): + r"""Input for fulfillment metric. 
+ + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.FulfillmentSpec): + Required. Spec for fulfillment score metric. + instance (google.cloud.aiplatform_v1.types.FulfillmentInstance): + Required. Fulfillment instance. + """ + + metric_spec: "FulfillmentSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="FulfillmentSpec", + ) + instance: "FulfillmentInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="FulfillmentInstance", + ) + + +class FulfillmentInstance(proto.Message): + r"""Spec for fulfillment instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + instruction (str): + Required. Inference instruction prompt to + compare prediction with. + + This field is a member of `oneof`_ ``_instruction``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + instruction: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + + +class FulfillmentSpec(proto.Message): + r"""Spec for fulfillment metric. + + Attributes: + version (int): + Optional. Which version to use for + evaluation. + """ + + version: int = proto.Field( + proto.INT32, + number=1, + ) + + +class FulfillmentResult(proto.Message): + r"""Spec for fulfillment result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Fulfillment score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for fulfillment + score. + confidence (float): + Output only. Confidence for fulfillment + score. + + This field is a member of `oneof`_ ``_confidence``. 
+ """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class SummarizationQualityInput(proto.Message): + r"""Input for summarization quality metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.SummarizationQualitySpec): + Required. Spec for summarization quality + score metric. + instance (google.cloud.aiplatform_v1.types.SummarizationQualityInstance): + Required. Summarization quality instance. + """ + + metric_spec: "SummarizationQualitySpec" = proto.Field( + proto.MESSAGE, + number=1, + message="SummarizationQualitySpec", + ) + instance: "SummarizationQualityInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="SummarizationQualityInstance", + ) + + +class SummarizationQualityInstance(proto.Message): + r"""Spec for summarization quality instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Optional. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + context (str): + Required. Text to be summarized. + + This field is a member of `oneof`_ ``_context``. + instruction (str): + Required. Summarization prompt for LLM. + + This field is a member of `oneof`_ ``_instruction``. 
+ """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + context: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + instruction: str = proto.Field( + proto.STRING, + number=4, + optional=True, + ) + + +class SummarizationQualitySpec(proto.Message): + r"""Spec for summarization quality score metric. + + Attributes: + use_reference (bool): + Optional. Whether to use instance.reference + to compute summarization quality. + version (int): + Optional. Which version to use for + evaluation. + """ + + use_reference: bool = proto.Field( + proto.BOOL, + number=1, + ) + version: int = proto.Field( + proto.INT32, + number=2, + ) + + +class SummarizationQualityResult(proto.Message): + r"""Spec for summarization quality result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Summarization Quality score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for summarization + quality score. + confidence (float): + Output only. Confidence for summarization + quality score. + + This field is a member of `oneof`_ ``_confidence``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class PairwiseSummarizationQualityInput(proto.Message): + r"""Input for pairwise summarization quality metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.PairwiseSummarizationQualitySpec): + Required. Spec for pairwise summarization + quality score metric. + instance (google.cloud.aiplatform_v1.types.PairwiseSummarizationQualityInstance): + Required. Pairwise summarization quality + instance. 
+ """ + + metric_spec: "PairwiseSummarizationQualitySpec" = proto.Field( + proto.MESSAGE, + number=1, + message="PairwiseSummarizationQualitySpec", + ) + instance: "PairwiseSummarizationQualityInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="PairwiseSummarizationQualityInstance", + ) + + +class PairwiseSummarizationQualityInstance(proto.Message): + r"""Spec for pairwise summarization quality instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the candidate model. + + This field is a member of `oneof`_ ``_prediction``. + baseline_prediction (str): + Required. Output of the baseline model. + + This field is a member of `oneof`_ ``_baseline_prediction``. + reference (str): + Optional. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + context (str): + Required. Text to be summarized. + + This field is a member of `oneof`_ ``_context``. + instruction (str): + Required. Summarization prompt for LLM. + + This field is a member of `oneof`_ ``_instruction``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + baseline_prediction: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + context: str = proto.Field( + proto.STRING, + number=4, + optional=True, + ) + instruction: str = proto.Field( + proto.STRING, + number=5, + optional=True, + ) + + +class PairwiseSummarizationQualitySpec(proto.Message): + r"""Spec for pairwise summarization quality score metric. + + Attributes: + use_reference (bool): + Optional. Whether to use instance.reference + to compute pairwise summarization quality. + version (int): + Optional. Which version to use for + evaluation. 
+ """ + + use_reference: bool = proto.Field( + proto.BOOL, + number=1, + ) + version: int = proto.Field( + proto.INT32, + number=2, + ) + + +class PairwiseSummarizationQualityResult(proto.Message): + r"""Spec for pairwise summarization quality result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + pairwise_choice (google.cloud.aiplatform_v1.types.PairwiseChoice): + Output only. Pairwise summarization + prediction choice. + explanation (str): + Output only. Explanation for summarization + quality score. + confidence (float): + Output only. Confidence for summarization + quality score. + + This field is a member of `oneof`_ ``_confidence``. + """ + + pairwise_choice: "PairwiseChoice" = proto.Field( + proto.ENUM, + number=1, + enum="PairwiseChoice", + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class SummarizationHelpfulnessInput(proto.Message): + r"""Input for summarization helpfulness metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.SummarizationHelpfulnessSpec): + Required. Spec for summarization helpfulness + score metric. + instance (google.cloud.aiplatform_v1.types.SummarizationHelpfulnessInstance): + Required. Summarization helpfulness instance. + """ + + metric_spec: "SummarizationHelpfulnessSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="SummarizationHelpfulnessSpec", + ) + instance: "SummarizationHelpfulnessInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="SummarizationHelpfulnessInstance", + ) + + +class SummarizationHelpfulnessInstance(proto.Message): + r"""Spec for summarization helpfulness instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. 
+ + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Optional. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + context (str): + Required. Text to be summarized. + + This field is a member of `oneof`_ ``_context``. + instruction (str): + Optional. Summarization prompt for LLM. + + This field is a member of `oneof`_ ``_instruction``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + context: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + instruction: str = proto.Field( + proto.STRING, + number=4, + optional=True, + ) + + +class SummarizationHelpfulnessSpec(proto.Message): + r"""Spec for summarization helpfulness score metric. + + Attributes: + use_reference (bool): + Optional. Whether to use instance.reference + to compute summarization helpfulness. + version (int): + Optional. Which version to use for + evaluation. + """ + + use_reference: bool = proto.Field( + proto.BOOL, + number=1, + ) + version: int = proto.Field( + proto.INT32, + number=2, + ) + + +class SummarizationHelpfulnessResult(proto.Message): + r"""Spec for summarization helpfulness result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Summarization Helpfulness score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for summarization + helpfulness score. + confidence (float): + Output only. Confidence for summarization + helpfulness score. + + This field is a member of `oneof`_ ``_confidence``. 
+ """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class SummarizationVerbosityInput(proto.Message): + r"""Input for summarization verbosity metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.SummarizationVerbositySpec): + Required. Spec for summarization verbosity + score metric. + instance (google.cloud.aiplatform_v1.types.SummarizationVerbosityInstance): + Required. Summarization verbosity instance. + """ + + metric_spec: "SummarizationVerbositySpec" = proto.Field( + proto.MESSAGE, + number=1, + message="SummarizationVerbositySpec", + ) + instance: "SummarizationVerbosityInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="SummarizationVerbosityInstance", + ) + + +class SummarizationVerbosityInstance(proto.Message): + r"""Spec for summarization verbosity instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Optional. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + context (str): + Required. Text to be summarized. + + This field is a member of `oneof`_ ``_context``. + instruction (str): + Optional. Summarization prompt for LLM. + + This field is a member of `oneof`_ ``_instruction``. 
+ """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + context: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + instruction: str = proto.Field( + proto.STRING, + number=4, + optional=True, + ) + + +class SummarizationVerbositySpec(proto.Message): + r"""Spec for summarization verbosity score metric. + + Attributes: + use_reference (bool): + Optional. Whether to use instance.reference + to compute summarization verbosity. + version (int): + Optional. Which version to use for + evaluation. + """ + + use_reference: bool = proto.Field( + proto.BOOL, + number=1, + ) + version: int = proto.Field( + proto.INT32, + number=2, + ) + + +class SummarizationVerbosityResult(proto.Message): + r"""Spec for summarization verbosity result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Summarization Verbosity score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for summarization + verbosity score. + confidence (float): + Output only. Confidence for summarization + verbosity score. + + This field is a member of `oneof`_ ``_confidence``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class QuestionAnsweringQualityInput(proto.Message): + r"""Input for question answering quality metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.QuestionAnsweringQualitySpec): + Required. Spec for question answering quality + score metric. + instance (google.cloud.aiplatform_v1.types.QuestionAnsweringQualityInstance): + Required. Question answering quality + instance. 
+ """ + + metric_spec: "QuestionAnsweringQualitySpec" = proto.Field( + proto.MESSAGE, + number=1, + message="QuestionAnsweringQualitySpec", + ) + instance: "QuestionAnsweringQualityInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="QuestionAnsweringQualityInstance", + ) + + +class QuestionAnsweringQualityInstance(proto.Message): + r"""Spec for question answering quality instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Optional. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + context (str): + Required. Text to answer the question. + + This field is a member of `oneof`_ ``_context``. + instruction (str): + Required. Question Answering prompt for LLM. + + This field is a member of `oneof`_ ``_instruction``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + context: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + instruction: str = proto.Field( + proto.STRING, + number=4, + optional=True, + ) + + +class QuestionAnsweringQualitySpec(proto.Message): + r"""Spec for question answering quality score metric. + + Attributes: + use_reference (bool): + Optional. Whether to use instance.reference + to compute question answering quality. + version (int): + Optional. Which version to use for + evaluation. + """ + + use_reference: bool = proto.Field( + proto.BOOL, + number=1, + ) + version: int = proto.Field( + proto.INT32, + number=2, + ) + + +class QuestionAnsweringQualityResult(proto.Message): + r"""Spec for question answering quality result. + + .. 
_oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Question Answering Quality + score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for question + answering quality score. + confidence (float): + Output only. Confidence for question + answering quality score. + + This field is a member of `oneof`_ ``_confidence``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class PairwiseQuestionAnsweringQualityInput(proto.Message): + r"""Input for pairwise question answering quality metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.PairwiseQuestionAnsweringQualitySpec): + Required. Spec for pairwise question + answering quality score metric. + instance (google.cloud.aiplatform_v1.types.PairwiseQuestionAnsweringQualityInstance): + Required. Pairwise question answering quality + instance. + """ + + metric_spec: "PairwiseQuestionAnsweringQualitySpec" = proto.Field( + proto.MESSAGE, + number=1, + message="PairwiseQuestionAnsweringQualitySpec", + ) + instance: "PairwiseQuestionAnsweringQualityInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="PairwiseQuestionAnsweringQualityInstance", + ) + + +class PairwiseQuestionAnsweringQualityInstance(proto.Message): + r"""Spec for pairwise question answering quality instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the candidate model. + + This field is a member of `oneof`_ ``_prediction``. + baseline_prediction (str): + Required. Output of the baseline model. + + This field is a member of `oneof`_ ``_baseline_prediction``. 
+ reference (str): + Optional. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + context (str): + Required. Text to answer the question. + + This field is a member of `oneof`_ ``_context``. + instruction (str): + Required. Question Answering prompt for LLM. + + This field is a member of `oneof`_ ``_instruction``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + baseline_prediction: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + context: str = proto.Field( + proto.STRING, + number=4, + optional=True, + ) + instruction: str = proto.Field( + proto.STRING, + number=5, + optional=True, + ) + + +class PairwiseQuestionAnsweringQualitySpec(proto.Message): + r"""Spec for pairwise question answering quality score metric. + + Attributes: + use_reference (bool): + Optional. Whether to use instance.reference + to compute question answering quality. + version (int): + Optional. Which version to use for + evaluation. + """ + + use_reference: bool = proto.Field( + proto.BOOL, + number=1, + ) + version: int = proto.Field( + proto.INT32, + number=2, + ) + + +class PairwiseQuestionAnsweringQualityResult(proto.Message): + r"""Spec for pairwise question answering quality result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + pairwise_choice (google.cloud.aiplatform_v1.types.PairwiseChoice): + Output only. Pairwise question answering + prediction choice. + explanation (str): + Output only. Explanation for question + answering quality score. + confidence (float): + Output only. Confidence for question + answering quality score. + + This field is a member of `oneof`_ ``_confidence``. 
+ """ + + pairwise_choice: "PairwiseChoice" = proto.Field( + proto.ENUM, + number=1, + enum="PairwiseChoice", + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class QuestionAnsweringRelevanceInput(proto.Message): + r"""Input for question answering relevance metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.QuestionAnsweringRelevanceSpec): + Required. Spec for question answering + relevance score metric. + instance (google.cloud.aiplatform_v1.types.QuestionAnsweringRelevanceInstance): + Required. Question answering relevance + instance. + """ + + metric_spec: "QuestionAnsweringRelevanceSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="QuestionAnsweringRelevanceSpec", + ) + instance: "QuestionAnsweringRelevanceInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="QuestionAnsweringRelevanceInstance", + ) + + +class QuestionAnsweringRelevanceInstance(proto.Message): + r"""Spec for question answering relevance instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Optional. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + context (str): + Optional. Text provided as context to answer + the question. + + This field is a member of `oneof`_ ``_context``. + instruction (str): + Required. The question asked and other + instruction in the inference prompt. + + This field is a member of `oneof`_ ``_instruction``. 
+ """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + context: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + instruction: str = proto.Field( + proto.STRING, + number=4, + optional=True, + ) + + +class QuestionAnsweringRelevanceSpec(proto.Message): + r"""Spec for question answering relevance metric. + + Attributes: + use_reference (bool): + Optional. Whether to use instance.reference + to compute question answering relevance. + version (int): + Optional. Which version to use for + evaluation. + """ + + use_reference: bool = proto.Field( + proto.BOOL, + number=1, + ) + version: int = proto.Field( + proto.INT32, + number=2, + ) + + +class QuestionAnsweringRelevanceResult(proto.Message): + r"""Spec for question answering relevance result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Question Answering Relevance + score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for question + answering relevance score. + confidence (float): + Output only. Confidence for question + answering relevance score. + + This field is a member of `oneof`_ ``_confidence``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class QuestionAnsweringHelpfulnessInput(proto.Message): + r"""Input for question answering helpfulness metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.QuestionAnsweringHelpfulnessSpec): + Required. Spec for question answering + helpfulness score metric. + instance (google.cloud.aiplatform_v1.types.QuestionAnsweringHelpfulnessInstance): + Required. 
Question answering helpfulness + instance. + """ + + metric_spec: "QuestionAnsweringHelpfulnessSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="QuestionAnsweringHelpfulnessSpec", + ) + instance: "QuestionAnsweringHelpfulnessInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="QuestionAnsweringHelpfulnessInstance", + ) + + +class QuestionAnsweringHelpfulnessInstance(proto.Message): + r"""Spec for question answering helpfulness instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Optional. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + context (str): + Optional. Text provided as context to answer + the question. + + This field is a member of `oneof`_ ``_context``. + instruction (str): + Required. The question asked and other + instruction in the inference prompt. + + This field is a member of `oneof`_ ``_instruction``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + context: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + instruction: str = proto.Field( + proto.STRING, + number=4, + optional=True, + ) + + +class QuestionAnsweringHelpfulnessSpec(proto.Message): + r"""Spec for question answering helpfulness metric. + + Attributes: + use_reference (bool): + Optional. Whether to use instance.reference + to compute question answering helpfulness. + version (int): + Optional. Which version to use for + evaluation. 
+ """ + + use_reference: bool = proto.Field( + proto.BOOL, + number=1, + ) + version: int = proto.Field( + proto.INT32, + number=2, + ) + + +class QuestionAnsweringHelpfulnessResult(proto.Message): + r"""Spec for question answering helpfulness result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Question Answering Helpfulness + score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for question + answering helpfulness score. + confidence (float): + Output only. Confidence for question + answering helpfulness score. + + This field is a member of `oneof`_ ``_confidence``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class QuestionAnsweringCorrectnessInput(proto.Message): + r"""Input for question answering correctness metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.QuestionAnsweringCorrectnessSpec): + Required. Spec for question answering + correctness score metric. + instance (google.cloud.aiplatform_v1.types.QuestionAnsweringCorrectnessInstance): + Required. Question answering correctness + instance. + """ + + metric_spec: "QuestionAnsweringCorrectnessSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="QuestionAnsweringCorrectnessSpec", + ) + instance: "QuestionAnsweringCorrectnessInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="QuestionAnsweringCorrectnessInstance", + ) + + +class QuestionAnsweringCorrectnessInstance(proto.Message): + r"""Spec for question answering correctness instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. 
Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Optional. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + context (str): + Optional. Text provided as context to answer + the question. + + This field is a member of `oneof`_ ``_context``. + instruction (str): + Required. The question asked and other + instruction in the inference prompt. + + This field is a member of `oneof`_ ``_instruction``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + context: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + instruction: str = proto.Field( + proto.STRING, + number=4, + optional=True, + ) + + +class QuestionAnsweringCorrectnessSpec(proto.Message): + r"""Spec for question answering correctness metric. + + Attributes: + use_reference (bool): + Optional. Whether to use instance.reference + to compute question answering correctness. + version (int): + Optional. Which version to use for + evaluation. + """ + + use_reference: bool = proto.Field( + proto.BOOL, + number=1, + ) + version: int = proto.Field( + proto.INT32, + number=2, + ) + + +class QuestionAnsweringCorrectnessResult(proto.Message): + r"""Spec for question answering correctness result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Question Answering Correctness + score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for question + answering correctness score. + confidence (float): + Output only. Confidence for question + answering correctness score. + + This field is a member of `oneof`_ ``_confidence``. 
+ """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + confidence: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + + +class PointwiseMetricInput(proto.Message): + r"""Input for pointwise metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.PointwiseMetricSpec): + Required. Spec for pointwise metric. + instance (google.cloud.aiplatform_v1.types.PointwiseMetricInstance): + Required. Pointwise metric instance. + """ + + metric_spec: "PointwiseMetricSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="PointwiseMetricSpec", + ) + instance: "PointwiseMetricInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="PointwiseMetricInstance", + ) + + +class PointwiseMetricInstance(proto.Message): + r"""Pointwise metric instance. Usually one instance corresponds + to one row in an evaluation dataset. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + json_instance (str): + Instance specified as a json string. String key-value pairs + are expected in the json_instance to render + PointwiseMetricSpec.instance_prompt_template. + + This field is a member of `oneof`_ ``instance``. + """ + + json_instance: str = proto.Field( + proto.STRING, + number=1, + oneof="instance", + ) + + +class PointwiseMetricSpec(proto.Message): + r"""Spec for pointwise metric. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + metric_prompt_template (str): + Required. Metric prompt template for + pointwise metric. + + This field is a member of `oneof`_ ``_metric_prompt_template``. + """ + + metric_prompt_template: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + + +class PointwiseMetricResult(proto.Message): + r"""Spec for pointwise metric result. + + .. 
_oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Pointwise metric score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for pointwise metric + score. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + + +class PairwiseMetricInput(proto.Message): + r"""Input for pairwise metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.PairwiseMetricSpec): + Required. Spec for pairwise metric. + instance (google.cloud.aiplatform_v1.types.PairwiseMetricInstance): + Required. Pairwise metric instance. + """ + + metric_spec: "PairwiseMetricSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="PairwiseMetricSpec", + ) + instance: "PairwiseMetricInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="PairwiseMetricInstance", + ) + + +class PairwiseMetricInstance(proto.Message): + r"""Pairwise metric instance. Usually one instance corresponds to + one row in an evaluation dataset. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + json_instance (str): + Instance specified as a json string. String key-value pairs + are expected in the json_instance to render + PairwiseMetricSpec.instance_prompt_template. + + This field is a member of `oneof`_ ``instance``. + """ + + json_instance: str = proto.Field( + proto.STRING, + number=1, + oneof="instance", + ) + + +class PairwiseMetricSpec(proto.Message): + r"""Spec for pairwise metric. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + metric_prompt_template (str): + Required. Metric prompt template for pairwise + metric. + + This field is a member of `oneof`_ ``_metric_prompt_template``. 
+ """ + + metric_prompt_template: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + + +class PairwiseMetricResult(proto.Message): + r"""Spec for pairwise metric result. + + Attributes: + pairwise_choice (google.cloud.aiplatform_v1.types.PairwiseChoice): + Output only. Pairwise metric choice. + explanation (str): + Output only. Explanation for pairwise metric + score. + """ + + pairwise_choice: "PairwiseChoice" = proto.Field( + proto.ENUM, + number=1, + enum="PairwiseChoice", + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + + +class ToolCallValidInput(proto.Message): + r"""Input for tool call valid metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.ToolCallValidSpec): + Required. Spec for tool call valid metric. + instances (MutableSequence[google.cloud.aiplatform_v1.types.ToolCallValidInstance]): + Required. Repeated tool call valid instances. + """ + + metric_spec: "ToolCallValidSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="ToolCallValidSpec", + ) + instances: MutableSequence["ToolCallValidInstance"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="ToolCallValidInstance", + ) + + +class ToolCallValidSpec(proto.Message): + r"""Spec for tool call valid metric.""" + + +class ToolCallValidInstance(proto.Message): + r"""Spec for tool call valid instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Required. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. 
+ """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + + +class ToolCallValidResults(proto.Message): + r"""Results for tool call valid metric. + + Attributes: + tool_call_valid_metric_values (MutableSequence[google.cloud.aiplatform_v1.types.ToolCallValidMetricValue]): + Output only. Tool call valid metric values. + """ + + tool_call_valid_metric_values: MutableSequence[ + "ToolCallValidMetricValue" + ] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="ToolCallValidMetricValue", + ) + + +class ToolCallValidMetricValue(proto.Message): + r"""Tool call valid metric value for an instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Tool call valid score. + + This field is a member of `oneof`_ ``_score``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + + +class ToolNameMatchInput(proto.Message): + r"""Input for tool name match metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.ToolNameMatchSpec): + Required. Spec for tool name match metric. + instances (MutableSequence[google.cloud.aiplatform_v1.types.ToolNameMatchInstance]): + Required. Repeated tool name match instances. + """ + + metric_spec: "ToolNameMatchSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="ToolNameMatchSpec", + ) + instances: MutableSequence["ToolNameMatchInstance"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="ToolNameMatchInstance", + ) + + +class ToolNameMatchSpec(proto.Message): + r"""Spec for tool name match metric.""" + + +class ToolNameMatchInstance(proto.Message): + r"""Spec for tool name match instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. 
Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Required. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + + +class ToolNameMatchResults(proto.Message): + r"""Results for tool name match metric. + + Attributes: + tool_name_match_metric_values (MutableSequence[google.cloud.aiplatform_v1.types.ToolNameMatchMetricValue]): + Output only. Tool name match metric values. + """ + + tool_name_match_metric_values: MutableSequence[ + "ToolNameMatchMetricValue" + ] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="ToolNameMatchMetricValue", + ) + + +class ToolNameMatchMetricValue(proto.Message): + r"""Tool name match metric value for an instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Tool name match score. + + This field is a member of `oneof`_ ``_score``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + + +class ToolParameterKeyMatchInput(proto.Message): + r"""Input for tool parameter key match metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.ToolParameterKeyMatchSpec): + Required. Spec for tool parameter key match + metric. + instances (MutableSequence[google.cloud.aiplatform_v1.types.ToolParameterKeyMatchInstance]): + Required. Repeated tool parameter key match + instances. 
+ """ + + metric_spec: "ToolParameterKeyMatchSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="ToolParameterKeyMatchSpec", + ) + instances: MutableSequence["ToolParameterKeyMatchInstance"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="ToolParameterKeyMatchInstance", + ) + + +class ToolParameterKeyMatchSpec(proto.Message): + r"""Spec for tool parameter key match metric.""" + + +class ToolParameterKeyMatchInstance(proto.Message): + r"""Spec for tool parameter key match instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Required. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + + +class ToolParameterKeyMatchResults(proto.Message): + r"""Results for tool parameter key match metric. + + Attributes: + tool_parameter_key_match_metric_values (MutableSequence[google.cloud.aiplatform_v1.types.ToolParameterKeyMatchMetricValue]): + Output only. Tool parameter key match metric + values. + """ + + tool_parameter_key_match_metric_values: MutableSequence[ + "ToolParameterKeyMatchMetricValue" + ] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="ToolParameterKeyMatchMetricValue", + ) + + +class ToolParameterKeyMatchMetricValue(proto.Message): + r"""Tool parameter key match metric value for an instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Tool parameter key match score. + + This field is a member of `oneof`_ ``_score``. 
+ """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + + +class ToolParameterKVMatchInput(proto.Message): + r"""Input for tool parameter key value match metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1.types.ToolParameterKVMatchSpec): + Required. Spec for tool parameter key value + match metric. + instances (MutableSequence[google.cloud.aiplatform_v1.types.ToolParameterKVMatchInstance]): + Required. Repeated tool parameter key value + match instances. + """ + + metric_spec: "ToolParameterKVMatchSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="ToolParameterKVMatchSpec", + ) + instances: MutableSequence["ToolParameterKVMatchInstance"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="ToolParameterKVMatchInstance", + ) + + +class ToolParameterKVMatchSpec(proto.Message): + r"""Spec for tool parameter key value match metric. + + Attributes: + use_strict_string_match (bool): + Optional. Whether to use STRICT string match + on parameter values. + """ + + use_strict_string_match: bool = proto.Field( + proto.BOOL, + number=1, + ) + + +class ToolParameterKVMatchInstance(proto.Message): + r"""Spec for tool parameter key value match instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prediction (str): + Required. Output of the evaluated model. + + This field is a member of `oneof`_ ``_prediction``. + reference (str): + Required. Ground truth used to compare + against the prediction. + + This field is a member of `oneof`_ ``_reference``. + """ + + prediction: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + reference: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + + +class ToolParameterKVMatchResults(proto.Message): + r"""Results for tool parameter key value match metric. 
+ + Attributes: + tool_parameter_kv_match_metric_values (MutableSequence[google.cloud.aiplatform_v1.types.ToolParameterKVMatchMetricValue]): + Output only. Tool parameter key value match + metric values. + """ + + tool_parameter_kv_match_metric_values: MutableSequence[ + "ToolParameterKVMatchMetricValue" + ] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="ToolParameterKVMatchMetricValue", + ) + + +class ToolParameterKVMatchMetricValue(proto.Message): + r"""Tool parameter key value match metric value for an instance. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Tool parameter key value match + score. + + This field is a member of `oneof`_ ``_score``. + """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/machine_resources.py b/google/cloud/aiplatform_v1/types/machine_resources.py index 3eb1ba0ad5..af3dcfab6d 100644 --- a/google/cloud/aiplatform_v1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1/types/machine_resources.py @@ -20,6 +20,9 @@ import proto # type: ignore from google.cloud.aiplatform_v1.types import accelerator_type as gca_accelerator_type +from google.cloud.aiplatform_v1.types import ( + reservation_affinity as gca_reservation_affinity, +) __protobuf__ = proto.module( @@ -71,6 +74,10 @@ class MachineSpec(proto.Message): Immutable. The topology of the TPUs. Corresponds to the TPU topologies available from GKE. (Example: tpu_topology: "2x2x1"). + reservation_affinity (google.cloud.aiplatform_v1.types.ReservationAffinity): + Optional. Immutable. Configuration + controlling how this resource pool consumes + reservation. 
""" machine_type: str = proto.Field( @@ -90,6 +97,11 @@ class MachineSpec(proto.Message): proto.STRING, number=4, ) + reservation_affinity: gca_reservation_affinity.ReservationAffinity = proto.Field( + proto.MESSAGE, + number=5, + message=gca_reservation_affinity.ReservationAffinity, + ) class DedicatedResources(proto.Message): @@ -157,6 +169,9 @@ class DedicatedResources(proto.Message): and [autoscaling_metric_specs.target][google.cloud.aiplatform.v1.AutoscalingMetricSpec.target] to ``80``. + spot (bool): + Optional. If true, schedule the deployment workload on `spot + VMs `__. """ machine_spec: "MachineSpec" = proto.Field( @@ -179,6 +194,10 @@ class DedicatedResources(proto.Message): number=4, message="AutoscalingMetricSpec", ) + spot: bool = proto.Field( + proto.BOOL, + number=5, + ) class AutomaticResources(proto.Message): diff --git a/google/cloud/aiplatform_v1/types/prediction_service.py b/google/cloud/aiplatform_v1/types/prediction_service.py index be63fe9d54..ca2a09ee2d 100644 --- a/google/cloud/aiplatform_v1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1/types/prediction_service.py @@ -915,7 +915,9 @@ class UsageMetadata(proto.Message): Attributes: prompt_token_count (int): - Number of tokens in the request. + Number of tokens in the request. When ``cached_content`` is + set, this is still the total effective prompt size meaning + this includes the number of tokens in the cached content. candidates_token_count (int): Number of tokens in the response(s). total_token_count (int): diff --git a/google/cloud/aiplatform_v1/types/reservation_affinity.py b/google/cloud/aiplatform_v1/types/reservation_affinity.py new file mode 100644 index 0000000000..6186e6e040 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/reservation_affinity.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "ReservationAffinity", + }, +) + + +class ReservationAffinity(proto.Message): + r"""A ReservationAffinity can be used to configure a Vertex AI + resource (e.g., a DeployedModel) to draw its Compute Engine + resources from a Shared Reservation, or exclusively from + on-demand capacity. + + Attributes: + reservation_affinity_type (google.cloud.aiplatform_v1.types.ReservationAffinity.Type): + Required. Specifies the reservation affinity + type. + key (str): + Optional. Corresponds to the label key of a reservation + resource. To target a SPECIFIC_RESERVATION by name, use + ``compute.googleapis.com/reservation-name`` as the key and + specify the name of your reservation as its value. + values (MutableSequence[str]): + Optional. Corresponds to the label values of + a reservation resource. This must be the full + resource name of the reservation. + """ + + class Type(proto.Enum): + r"""Identifies a type of reservation affinity. + + Values: + TYPE_UNSPECIFIED (0): + Default value. This should not be used. + NO_RESERVATION (1): + Do not consume from any reserved capacity, + only use on-demand. + ANY_RESERVATION (2): + Consume any reservation available, falling + back to on-demand. + SPECIFIC_RESERVATION (3): + Consume from a specific reservation. When chosen, the + reservation must be identified via the ``key`` and + ``values`` fields. 
+ """ + TYPE_UNSPECIFIED = 0 + NO_RESERVATION = 1 + ANY_RESERVATION = 2 + SPECIFIC_RESERVATION = 3 + + reservation_affinity_type: Type = proto.Field( + proto.ENUM, + number=1, + enum=Type, + ) + key: str = proto.Field( + proto.STRING, + number=2, + ) + values: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=3, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 28554ac2a3..9f1bacd88b 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -252,6 +252,10 @@ from .types.evaluation_service import GroundednessInstance from .types.evaluation_service import GroundednessResult from .types.evaluation_service import GroundednessSpec +from .types.evaluation_service import PairwiseMetricInput +from .types.evaluation_service import PairwiseMetricInstance +from .types.evaluation_service import PairwiseMetricResult +from .types.evaluation_service import PairwiseMetricSpec from .types.evaluation_service import PairwiseQuestionAnsweringQualityInput from .types.evaluation_service import PairwiseQuestionAnsweringQualityInstance from .types.evaluation_service import PairwiseQuestionAnsweringQualityResult @@ -260,6 +264,10 @@ from .types.evaluation_service import PairwiseSummarizationQualityInstance from .types.evaluation_service import PairwiseSummarizationQualityResult from .types.evaluation_service import PairwiseSummarizationQualitySpec +from .types.evaluation_service import PointwiseMetricInput +from .types.evaluation_service import PointwiseMetricInstance +from .types.evaluation_service import PointwiseMetricResult +from .types.evaluation_service import PointwiseMetricSpec from .types.evaluation_service import QuestionAnsweringCorrectnessInput from .types.evaluation_service import QuestionAnsweringCorrectnessInstance from .types.evaluation_service import QuestionAnsweringCorrectnessResult @@ 
-872,6 +880,7 @@ from .types.reasoning_engine_service import ListReasoningEnginesResponse from .types.reasoning_engine_service import UpdateReasoningEngineOperationMetadata from .types.reasoning_engine_service import UpdateReasoningEngineRequest +from .types.reservation_affinity import ReservationAffinity from .types.saved_query import SavedQuery from .types.schedule import Schedule from .types.schedule_service import CreateScheduleRequest @@ -1702,6 +1711,10 @@ "NotebookRuntimeType", "NotebookServiceClient", "PairwiseChoice", + "PairwiseMetricInput", + "PairwiseMetricInstance", + "PairwiseMetricResult", + "PairwiseMetricSpec", "PairwiseQuestionAnsweringQualityInput", "PairwiseQuestionAnsweringQualityInstance", "PairwiseQuestionAnsweringQualityResult", @@ -1724,6 +1737,10 @@ "PipelineTaskDetail", "PipelineTaskExecutorDetail", "PipelineTemplateMetadata", + "PointwiseMetricInput", + "PointwiseMetricInstance", + "PointwiseMetricResult", + "PointwiseMetricSpec", "Port", "PredefinedSplit", "PredictLongRunningMetadata", @@ -1808,6 +1825,7 @@ "RemoveContextChildrenResponse", "RemoveDatapointsRequest", "RemoveDatapointsResponse", + "ReservationAffinity", "ResourcePool", "ResourceRuntime", "ResourceRuntimeSpec", diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py index 2d6a83d76e..af533f5c38 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.61.0" # {x-release-please-version} +__version__ = "1.62.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py index 021b74969d..fb892f0fd6 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py @@ -777,6 +777,8 @@ async def sample_list_datasets(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1799,6 +1801,8 @@ async def sample_list_dataset_versions(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2043,6 +2047,8 @@ async def sample_list_data_items(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2141,6 +2147,8 @@ async def sample_search_data_items(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2261,6 +2269,8 @@ async def sample_list_saved_queries(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2617,6 +2627,8 @@ async def sample_list_annotations(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index 64e4ec3d9b..99fceffda5 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -1292,6 +1292,8 @@ def sample_list_datasets(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2290,6 +2292,8 @@ def sample_list_dataset_versions(): method=rpc, 
request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2528,6 +2532,8 @@ def sample_list_data_items(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2624,6 +2630,8 @@ def sample_search_data_items(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2741,6 +2749,8 @@ def sample_list_saved_queries(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3088,6 +3098,8 @@ def sample_list_annotations(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py index fc6367e7d2..88fb1099d7 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import annotation from google.cloud.aiplatform_v1beta1.types import data_item from google.cloud.aiplatform_v1beta1.types import dataset @@ -56,6 +69,8 @@ def __init__( request: dataset_service.ListDatasetsRequest, response: dataset_service.ListDatasetsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -67,12 +82,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListDatasetsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListDatasetsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -83,7 +103,12 @@ def pages(self) -> Iterator[dataset_service.ListDatasetsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[dataset.Dataset]: @@ -118,6 +143,8 @@ def __init__( request: dataset_service.ListDatasetsRequest, response: dataset_service.ListDatasetsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -129,12 +156,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListDatasetsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListDatasetsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages(self) -> AsyncIterator[dataset_service.ListDatasetsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[dataset.Dataset]: @@ -184,6 +221,8 @@ def __init__( request: dataset_service.ListDatasetVersionsRequest, response: dataset_service.ListDatasetVersionsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -195,12 +234,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListDatasetVersionsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListDatasetVersionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -211,7 +255,12 @@ def pages(self) -> Iterator[dataset_service.ListDatasetVersionsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[dataset_version.DatasetVersion]: @@ -246,6 +295,8 @@ def __init__( request: dataset_service.ListDatasetVersionsRequest, response: dataset_service.ListDatasetVersionsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -257,12 +308,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListDatasetVersionsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListDatasetVersionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -273,7 +329,12 @@ async def pages(self) -> AsyncIterator[dataset_service.ListDatasetVersionsRespon yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[dataset_version.DatasetVersion]: @@ -312,6 +373,8 @@ def __init__( request: dataset_service.ListDataItemsRequest, response: dataset_service.ListDataItemsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -323,12 +386,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListDataItemsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListDataItemsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -339,7 +407,12 @@ def pages(self) -> Iterator[dataset_service.ListDataItemsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[data_item.DataItem]: @@ -374,6 +447,8 @@ def __init__( request: dataset_service.ListDataItemsRequest, response: dataset_service.ListDataItemsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -385,12 +460,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListDataItemsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListDataItemsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -401,7 +481,12 @@ async def pages(self) -> AsyncIterator[dataset_service.ListDataItemsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[data_item.DataItem]: @@ -440,6 +525,8 @@ def __init__( request: dataset_service.SearchDataItemsRequest, response: dataset_service.SearchDataItemsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -451,12 +538,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.SearchDataItemsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.SearchDataItemsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -467,7 +559,12 @@ def pages(self) -> Iterator[dataset_service.SearchDataItemsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[dataset_service.DataItemView]: @@ -502,6 +599,8 @@ def __init__( request: dataset_service.SearchDataItemsRequest, response: dataset_service.SearchDataItemsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -513,12 +612,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.SearchDataItemsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.SearchDataItemsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -529,7 +633,12 @@ async def pages(self) -> AsyncIterator[dataset_service.SearchDataItemsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[dataset_service.DataItemView]: @@ -568,6 +677,8 @@ def __init__( request: dataset_service.ListSavedQueriesRequest, response: dataset_service.ListSavedQueriesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -579,12 +690,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListSavedQueriesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListSavedQueriesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -595,7 +711,12 @@ def pages(self) -> Iterator[dataset_service.ListSavedQueriesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[saved_query.SavedQuery]: @@ -630,6 +751,8 @@ def __init__( request: dataset_service.ListSavedQueriesRequest, response: dataset_service.ListSavedQueriesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -641,12 +764,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListSavedQueriesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListSavedQueriesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -657,7 +785,12 @@ async def pages(self) -> AsyncIterator[dataset_service.ListSavedQueriesResponse] yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[saved_query.SavedQuery]: @@ -696,6 +829,8 @@ def __init__( request: dataset_service.ListAnnotationsRequest, response: dataset_service.ListAnnotationsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -707,12 +842,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListAnnotationsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListAnnotationsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -723,7 +863,12 @@ def pages(self) -> Iterator[dataset_service.ListAnnotationsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[annotation.Annotation]: @@ -758,6 +903,8 @@ def __init__( request: dataset_service.ListAnnotationsRequest, response: dataset_service.ListAnnotationsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -769,12 +916,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListAnnotationsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListAnnotationsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -785,7 +937,12 @@ async def pages(self) -> AsyncIterator[dataset_service.ListAnnotationsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[annotation.Annotation]: diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/rest.py index 39cbeaa6b4..45ef2d60cc 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/rest.py @@ -3093,6 +3093,11 @@ def __call__( "uri": "/v1beta1/{parent=projects/*/locations/*}/datasets", "body": "dataset", }, + { + "method": "post", + "uri": "/v1beta1/datasets", + "body": "dataset", + }, ] request, metadata = self._interceptor.pre_create_dataset(request, metadata) pb_request = dataset_service.CreateDatasetRequest.pb(request) @@ -3187,6 +3192,11 @@ def __call__( "uri": "/v1beta1/{parent=projects/*/locations/*/datasets/*}/datasetVersions", "body": "dataset_version", }, + { + "method": "post", + "uri": "/v1beta1/{parent=datasets/*}/datasetVersions", + "body": "dataset_version", + }, ] request, metadata = self._interceptor.pre_create_dataset_version( request, metadata @@ -3282,6 +3292,10 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/datasets/*}", }, + { + "method": "delete", + "uri": "/v1beta1/{name=datasets/*}", + }, ] request, 
metadata = self._interceptor.pre_delete_dataset(request, metadata) pb_request = dataset_service.DeleteDatasetRequest.pb(request) @@ -3369,6 +3383,10 @@ def __call__( "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/datasets/*/datasetVersions/*}", }, + { + "method": "delete", + "uri": "/v1beta1/{name=datasets/*/datasetVersions/*}", + }, ] request, metadata = self._interceptor.pre_delete_dataset_version( request, metadata @@ -3730,6 +3748,10 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/datasets/*}", }, + { + "method": "get", + "uri": "/v1beta1/{name=datasets/*}", + }, ] request, metadata = self._interceptor.pre_get_dataset(request, metadata) pb_request = dataset_service.GetDatasetRequest.pb(request) @@ -3816,6 +3838,10 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/datasets/*/datasetVersions/*}", }, + { + "method": "get", + "uri": "/v1beta1/{name=datasets/*/datasetVersions/*}", + }, ] request, metadata = self._interceptor.pre_get_dataset_version( request, metadata @@ -4178,6 +4204,10 @@ def __call__( "method": "get", "uri": "/v1beta1/{parent=projects/*/locations/*}/datasets", }, + { + "method": "get", + "uri": "/v1beta1/datasets", + }, ] request, metadata = self._interceptor.pre_list_datasets(request, metadata) pb_request = dataset_service.ListDatasetsRequest.pb(request) @@ -4266,6 +4296,10 @@ def __call__( "method": "get", "uri": "/v1beta1/{parent=projects/*/locations/*/datasets/*}/datasetVersions", }, + { + "method": "get", + "uri": "/v1beta1/{parent=datasets/*}/datasetVersions", + }, ] request, metadata = self._interceptor.pre_list_dataset_versions( request, metadata @@ -4447,6 +4481,10 @@ def __call__( "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/datasets/*/datasetVersions/*}:restore", }, + { + "method": "get", + "uri": "/v1beta1/{name=datasets/*/datasetVersions/*}:restore", + }, ] request, metadata = self._interceptor.pre_restore_dataset_version( request, 
metadata @@ -4628,6 +4666,11 @@ def __call__( "uri": "/v1beta1/{dataset.name=projects/*/locations/*/datasets/*}", "body": "dataset", }, + { + "method": "patch", + "uri": "/v1beta1/{dataset.name=datasets/*}", + "body": "dataset", + }, ] request, metadata = self._interceptor.pre_update_dataset(request, metadata) pb_request = dataset_service.UpdateDatasetRequest.pb(request) @@ -4723,6 +4766,11 @@ def __call__( "uri": "/v1beta1/{dataset_version.name=projects/*/locations/*/datasets/*/datasetVersions/*}", "body": "dataset_version", }, + { + "method": "patch", + "uri": "/v1beta1/{dataset_version.name=datasets/*/datasetVersions/*}", + "body": "dataset_version", + }, ] request, metadata = self._interceptor.pre_update_dataset_version( request, metadata diff --git a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/async_client.py index 338e7736be..ccfde0d7a5 100644 --- a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/async_client.py @@ -98,6 +98,12 @@ class DeploymentResourcePoolServiceAsyncClient: parse_model_path = staticmethod( DeploymentResourcePoolServiceClient.parse_model_path ) + reservation_path = staticmethod( + DeploymentResourcePoolServiceClient.reservation_path + ) + parse_reservation_path = staticmethod( + DeploymentResourcePoolServiceClient.parse_reservation_path + ) common_billing_account_path = staticmethod( DeploymentResourcePoolServiceClient.common_billing_account_path ) @@ -710,6 +716,8 @@ async def sample_list_deployment_resource_pools(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1138,6 +1146,8 @@ async def sample_query_deployed_models(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git 
a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/client.py index 3f9ca5953e..ab05789133 100644 --- a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/client.py @@ -271,6 +271,28 @@ def parse_model_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def common_billing_account_path( billing_account: str, @@ -1164,6 +1186,8 @@ def sample_list_deployment_resource_pools(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1587,6 +1611,8 @@ def sample_query_deployed_models(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/pagers.py index b3709bed5a..cac8cf3ff6 100644 --- a/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/deployment_resource_pool_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language 
governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import deployment_resource_pool from google.cloud.aiplatform_v1beta1.types import deployment_resource_pool_service from google.cloud.aiplatform_v1beta1.types import endpoint @@ -55,6 +68,8 @@ def __init__( request: deployment_resource_pool_service.ListDeploymentResourcePoolsRequest, response: deployment_resource_pool_service.ListDeploymentResourcePoolsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -66,6 +81,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListDeploymentResourcePoolsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -74,6 +92,8 @@ def __init__( deployment_resource_pool_service.ListDeploymentResourcePoolsRequest(request) ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -86,7 +106,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[deployment_resource_pool.DeploymentResourcePool]: @@ -126,6 +151,8 @@ def __init__( request: deployment_resource_pool_service.ListDeploymentResourcePoolsRequest, response: deployment_resource_pool_service.ListDeploymentResourcePoolsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -137,6 +164,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListDeploymentResourcePoolsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -145,6 +175,8 @@ def __init__( deployment_resource_pool_service.ListDeploymentResourcePoolsRequest(request) ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -159,7 +191,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__( @@ -202,6 +239,8 @@ def __init__( request: deployment_resource_pool_service.QueryDeployedModelsRequest, response: deployment_resource_pool_service.QueryDeployedModelsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -213,6 +252,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.QueryDeployedModelsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -221,6 +263,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -233,7 +277,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[endpoint.DeployedModel]: @@ -270,6 +319,8 @@ def __init__( request: deployment_resource_pool_service.QueryDeployedModelsRequest, response: deployment_resource_pool_service.QueryDeployedModelsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -281,6 +332,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.QueryDeployedModelsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -289,6 +343,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -301,7 +357,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[endpoint.DeployedModel]: diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py index da74fc673f..c2a8688e66 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py @@ -95,6 +95,8 @@ class EndpointServiceAsyncClient: ) network_path = staticmethod(EndpointServiceClient.network_path) parse_network_path = staticmethod(EndpointServiceClient.parse_network_path) + reservation_path = staticmethod(EndpointServiceClient.reservation_path) + parse_reservation_path = staticmethod(EndpointServiceClient.parse_reservation_path) common_billing_account_path = staticmethod( EndpointServiceClient.common_billing_account_path ) @@ -669,6 +671,8 @@ async def sample_list_endpoints(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index adbc691cbd..6065d18176 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -303,6 +303,28 @@ def parse_network_path(path: str) -> 
Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def common_billing_account_path( billing_account: str, @@ -1148,6 +1170,8 @@ def sample_list_endpoints(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py index 516757a8e0..d3ee5ef502 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint_service @@ -52,6 +65,8 @@ def __init__( request: endpoint_service.ListEndpointsRequest, response: endpoint_service.ListEndpointsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListEndpointsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = endpoint_service.ListEndpointsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[endpoint_service.ListEndpointsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[endpoint.Endpoint]: @@ -114,6 +139,8 @@ def __init__( request: endpoint_service.ListEndpointsRequest, response: endpoint_service.ListEndpointsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -125,12 +152,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListEndpointsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = endpoint_service.ListEndpointsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -141,7 +173,12 @@ async def pages(self) -> AsyncIterator[endpoint_service.ListEndpointsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[endpoint.Endpoint]: diff --git a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/async_client.py index a3eb3c4c23..4989cabb59 100644 --- a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/async_client.py @@ -657,6 +657,8 @@ async def sample_list_extensions(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/client.py b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/client.py index c9b17ad38c..cf4575a434 100644 --- a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/client.py @@ -1101,6 +1101,8 @@ def sample_list_extensions(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/pagers.py 
index 0f9616c706..10fbd99851 100644 --- a/google/cloud/aiplatform_v1beta1/services/extension_registry_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/extension_registry_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import extension from google.cloud.aiplatform_v1beta1.types import extension_registry_service @@ -52,6 +65,8 @@ def __init__( request: extension_registry_service.ListExtensionsRequest, response: extension_registry_service.ListExtensionsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListExtensionsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = extension_registry_service.ListExtensionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[extension_registry_service.ListExtensionsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[extension.Extension]: @@ -116,6 +141,8 @@ def __init__( request: extension_registry_service.ListExtensionsRequest, response: extension_registry_service.ListExtensionsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -127,12 +154,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListExtensionsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = extension_registry_service.ListExtensionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[extension.Extension]: diff --git a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/async_client.py index 87e702095d..8d55d788c1 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/async_client.py @@ -722,6 +722,8 @@ async def sample_list_feature_online_stores(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1437,6 +1439,8 @@ async def sample_list_feature_views(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2086,6 +2090,8 @@ async def sample_list_feature_view_syncs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/client.py b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/client.py index 0ccf5d784d..c6af1e42b1 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/client.py +++ 
b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/client.py @@ -1174,6 +1174,8 @@ def sample_list_feature_online_stores(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1878,6 +1880,8 @@ def sample_list_feature_views(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2512,6 +2516,8 @@ def sample_list_feature_view_syncs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/pagers.py index 66a0b6e231..26898d0704 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_online_store_admin_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import feature_online_store from google.cloud.aiplatform_v1beta1.types import feature_online_store_admin_service from google.cloud.aiplatform_v1beta1.types import feature_view @@ -56,6 +69,8 @@ def __init__( request: feature_online_store_admin_service.ListFeatureOnlineStoresRequest, response: feature_online_store_admin_service.ListFeatureOnlineStoresResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -67,6 +82,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeatureOnlineStoresResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -75,6 +93,8 @@ def __init__( feature_online_store_admin_service.ListFeatureOnlineStoresRequest(request) ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -87,7 +107,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature_online_store.FeatureOnlineStore]: @@ -127,6 +152,8 @@ def __init__( request: feature_online_store_admin_service.ListFeatureOnlineStoresRequest, response: feature_online_store_admin_service.ListFeatureOnlineStoresResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -138,6 +165,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeatureOnlineStoresResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -146,6 +176,8 @@ def __init__( feature_online_store_admin_service.ListFeatureOnlineStoresRequest(request) ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -160,7 +192,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature_online_store.FeatureOnlineStore]: @@ -201,6 +238,8 @@ def __init__( request: feature_online_store_admin_service.ListFeatureViewsRequest, response: feature_online_store_admin_service.ListFeatureViewsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -212,6 +251,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeatureViewsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -220,6 +262,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -232,7 +276,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature_view.FeatureView]: @@ -269,6 +318,8 @@ def __init__( request: feature_online_store_admin_service.ListFeatureViewsRequest, response: feature_online_store_admin_service.ListFeatureViewsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -280,6 +331,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeatureViewsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -288,6 +342,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -300,7 +356,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature_view.FeatureView]: @@ -341,6 +402,8 @@ def __init__( request: feature_online_store_admin_service.ListFeatureViewSyncsRequest, response: feature_online_store_admin_service.ListFeatureViewSyncsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -352,6 +415,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeatureViewSyncsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -360,6 +426,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -372,7 +440,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature_view_sync.FeatureViewSync]: @@ -410,6 +483,8 @@ def __init__( request: feature_online_store_admin_service.ListFeatureViewSyncsRequest, response: feature_online_store_admin_service.ListFeatureViewSyncsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -421,6 +496,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeatureViewSyncsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -429,6 +507,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -441,7 +521,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature_view_sync.FeatureViewSync]: diff --git a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/async_client.py index 511a4efafe..abaaa8179a 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/async_client.py @@ -665,6 +665,8 @@ async def sample_list_feature_groups(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1354,6 +1356,8 @@ async def sample_list_features(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/client.py b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/client.py index 63d0d404e7..fbb5637054 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/client.py @@ -1097,6 +1097,8 @@ def sample_list_feature_groups(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1771,6 +1773,8 @@ def sample_list_features(): method=rpc, request=request, response=response, + 
retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/pagers.py index 73148b6618..ef352e3ee9 100644 --- a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import feature from google.cloud.aiplatform_v1beta1.types import feature_group from google.cloud.aiplatform_v1beta1.types import feature_registry_service @@ -54,6 +67,8 @@ def __init__( request: feature_registry_service.ListFeatureGroupsRequest, response: feature_registry_service.ListFeatureGroupsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -65,12 +80,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeatureGroupsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. 
+ timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. """ self._method = method self._request = feature_registry_service.ListFeatureGroupsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -81,7 +101,12 @@ def pages(self) -> Iterator[feature_registry_service.ListFeatureGroupsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature_group.FeatureGroup]: @@ -118,6 +143,8 @@ def __init__( request: feature_registry_service.ListFeatureGroupsRequest, response: feature_registry_service.ListFeatureGroupsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -129,12 +156,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeatureGroupsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = feature_registry_service.ListFeatureGroupsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -147,7 +179,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature_group.FeatureGroup]: @@ -186,6 +223,8 @@ def __init__( request: featurestore_service.ListFeaturesRequest, response: featurestore_service.ListFeaturesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -197,12 +236,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeaturesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListFeaturesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -213,7 +257,12 @@ def pages(self) -> Iterator[featurestore_service.ListFeaturesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature.Feature]: @@ -248,6 +297,8 @@ def __init__( request: featurestore_service.ListFeaturesRequest, response: featurestore_service.ListFeaturesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -259,12 +310,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeaturesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListFeaturesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -275,7 +331,12 @@ async def pages(self) -> AsyncIterator[featurestore_service.ListFeaturesResponse yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature.Feature]: diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py index a3dacd250f..f1709909f4 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/async_client.py @@ -671,6 +671,8 @@ async def sample_list_featurestores(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1351,6 +1353,8 @@ async def sample_list_entity_types(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2177,6 +2181,8 @@ async def sample_list_features(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3230,6 +3236,8 @@ async def sample_search_features(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py index b6257dc5a8..2621bff93d 100644 --- 
a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py @@ -1123,6 +1123,8 @@ def sample_list_featurestores(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1788,6 +1790,8 @@ def sample_list_entity_types(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2596,6 +2600,8 @@ def sample_list_features(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3630,6 +3636,8 @@ def sample_search_features(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/pagers.py index ed991c1e3b..4e41316417 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import entity_type from google.cloud.aiplatform_v1beta1.types import feature from google.cloud.aiplatform_v1beta1.types import featurestore @@ -54,6 +67,8 @@ def __init__( request: featurestore_service.ListFeaturestoresRequest, response: featurestore_service.ListFeaturestoresResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -65,12 +80,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeaturestoresResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListFeaturestoresRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -81,7 +101,12 @@ def pages(self) -> Iterator[featurestore_service.ListFeaturestoresResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[featurestore.Featurestore]: @@ -118,6 +143,8 @@ def __init__( request: featurestore_service.ListFeaturestoresRequest, response: featurestore_service.ListFeaturestoresResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -129,12 +156,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeaturestoresResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListFeaturestoresRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -147,7 +179,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[featurestore.Featurestore]: @@ -186,6 +223,8 @@ def __init__( request: featurestore_service.ListEntityTypesRequest, response: featurestore_service.ListEntityTypesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -197,12 +236,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListEntityTypesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListEntityTypesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -213,7 +257,12 @@ def pages(self) -> Iterator[featurestore_service.ListEntityTypesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[entity_type.EntityType]: @@ -248,6 +297,8 @@ def __init__( request: featurestore_service.ListEntityTypesRequest, response: featurestore_service.ListEntityTypesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -259,12 +310,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListEntityTypesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListEntityTypesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -277,7 +333,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[entity_type.EntityType]: @@ -316,6 +377,8 @@ def __init__( request: featurestore_service.ListFeaturesRequest, response: featurestore_service.ListFeaturesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -327,12 +390,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeaturesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListFeaturesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -343,7 +411,12 @@ def pages(self) -> Iterator[featurestore_service.ListFeaturesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature.Feature]: @@ -378,6 +451,8 @@ def __init__( request: featurestore_service.ListFeaturesRequest, response: featurestore_service.ListFeaturesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -389,12 +464,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListFeaturesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.ListFeaturesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -405,7 +485,12 @@ async def pages(self) -> AsyncIterator[featurestore_service.ListFeaturesResponse yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature.Feature]: @@ -444,6 +529,8 @@ def __init__( request: featurestore_service.SearchFeaturesRequest, response: featurestore_service.SearchFeaturesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -455,12 +542,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.SearchFeaturesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.SearchFeaturesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -471,7 +563,12 @@ def pages(self) -> Iterator[featurestore_service.SearchFeaturesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[feature.Feature]: @@ -506,6 +603,8 @@ def __init__( request: featurestore_service.SearchFeaturesRequest, response: featurestore_service.SearchFeaturesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -517,12 +616,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.SearchFeaturesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = featurestore_service.SearchFeaturesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -533,7 +637,12 @@ async def pages(self) -> AsyncIterator[featurestore_service.SearchFeaturesRespon yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[feature.Feature]: diff --git a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/async_client.py index 084086794d..15d032acc4 100644 --- a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/async_client.py @@ -837,6 +837,8 @@ async def sample_list_cached_contents(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/client.py b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/client.py index 3168996e6f..449c20295d 100644 --- a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/client.py @@ -1260,6 +1260,8 @@ def sample_list_cached_contents(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/pagers.py index bcbbef757b..91e1305a06 100644 --- 
a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import cached_content from google.cloud.aiplatform_v1beta1.types import gen_ai_cache_service @@ -52,6 +65,8 @@ def __init__( request: gen_ai_cache_service.ListCachedContentsRequest, response: gen_ai_cache_service.ListCachedContentsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListCachedContentsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = gen_ai_cache_service.ListCachedContentsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[gen_ai_cache_service.ListCachedContentsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[cached_content.CachedContent]: @@ -116,6 +141,8 @@ def __init__( request: gen_ai_cache_service.ListCachedContentsRequest, response: gen_ai_cache_service.ListCachedContentsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -127,12 +154,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListCachedContentsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = gen_ai_cache_service.ListCachedContentsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[cached_content.CachedContent]: diff --git a/google/cloud/aiplatform_v1beta1/services/gen_ai_tuning_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/gen_ai_tuning_service/async_client.py index 21e7236e3a..7325dd4f2b 100644 --- a/google/cloud/aiplatform_v1beta1/services/gen_ai_tuning_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/gen_ai_tuning_service/async_client.py @@ -627,6 +627,8 @@ async def sample_list_tuning_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/gen_ai_tuning_service/client.py b/google/cloud/aiplatform_v1beta1/services/gen_ai_tuning_service/client.py index 6e8d0c3803..c6826ab1fa 100644 --- a/google/cloud/aiplatform_v1beta1/services/gen_ai_tuning_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/gen_ai_tuning_service/client.py @@ -1117,6 +1117,8 @@ def sample_list_tuning_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/gen_ai_tuning_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/gen_ai_tuning_service/pagers.py index a6b35d59de..46617ea2b6 100644 --- 
a/google/cloud/aiplatform_v1beta1/services/gen_ai_tuning_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/gen_ai_tuning_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import genai_tuning_service from google.cloud.aiplatform_v1beta1.types import tuning_job @@ -52,6 +65,8 @@ def __init__( request: genai_tuning_service.ListTuningJobsRequest, response: genai_tuning_service.ListTuningJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTuningJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = genai_tuning_service.ListTuningJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[genai_tuning_service.ListTuningJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[tuning_job.TuningJob]: @@ -114,6 +139,8 @@ def __init__( request: genai_tuning_service.ListTuningJobsRequest, response: genai_tuning_service.ListTuningJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -125,12 +152,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTuningJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = genai_tuning_service.ListTuningJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -141,7 +173,12 @@ async def pages(self) -> AsyncIterator[genai_tuning_service.ListTuningJobsRespon yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[tuning_job.TuningJob]: diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py index 627acaa655..887b448e85 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/async_client.py @@ -83,6 +83,10 @@ class IndexEndpointServiceAsyncClient: parse_index_endpoint_path = staticmethod( IndexEndpointServiceClient.parse_index_endpoint_path ) + reservation_path = staticmethod(IndexEndpointServiceClient.reservation_path) + parse_reservation_path = staticmethod( + IndexEndpointServiceClient.parse_reservation_path + ) common_billing_account_path = staticmethod( IndexEndpointServiceClient.common_billing_account_path ) @@ -643,6 +647,8 @@ async def sample_list_index_endpoints(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py index d7003433e4..f859e11ebe 100644 --- 
a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py @@ -240,6 +240,28 @@ def parse_index_endpoint_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P<project_id_or_number>.+?)/zones/(?P<zone>.+?)/reservations/(?P<reservation_name>.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def common_billing_account_path( billing_account: str, @@ -1071,6 +1093,8 @@ def sample_list_index_endpoints(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/pagers.py index eec12dfe9b..39506c02eb 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License.
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import index_endpoint from google.cloud.aiplatform_v1beta1.types import index_endpoint_service @@ -52,6 +65,8 @@ def __init__( request: index_endpoint_service.ListIndexEndpointsRequest, response: index_endpoint_service.ListIndexEndpointsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListIndexEndpointsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = index_endpoint_service.ListIndexEndpointsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[index_endpoint_service.ListIndexEndpointsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[index_endpoint.IndexEndpoint]: @@ -116,6 +141,8 @@ def __init__( request: index_endpoint_service.ListIndexEndpointsRequest, response: index_endpoint_service.ListIndexEndpointsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -127,12 +154,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListIndexEndpointsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = index_endpoint_service.ListIndexEndpointsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[index_endpoint.IndexEndpoint]: diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py index a27de9bca4..6e6b0e075e 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/async_client.py @@ -630,6 +630,8 @@ async def sample_list_indexes(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_service/client.py index 418df013ea..45a9060d2e 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/client.py @@ -1057,6 +1057,8 @@ def sample_list_indexes(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/index_service/pagers.py index e3687ab4a7..ac90df799b 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/pagers.py @@ 
-13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import index from google.cloud.aiplatform_v1beta1.types import index_service @@ -52,6 +65,8 @@ def __init__( request: index_service.ListIndexesRequest, response: index_service.ListIndexesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListIndexesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = index_service.ListIndexesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[index_service.ListIndexesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[index.Index]: @@ -114,6 +139,8 @@ def __init__( request: index_service.ListIndexesRequest, response: index_service.ListIndexesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -125,12 +152,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListIndexesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = index_service.ListIndexesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -141,7 +173,12 @@ async def pages(self) -> AsyncIterator[index_service.ListIndexesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[index.Index]: diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py index 2d3b7402ce..5b8d1bca99 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -153,6 +153,8 @@ class JobServiceAsyncClient: parse_persistent_resource_path = staticmethod( JobServiceClient.parse_persistent_resource_path ) + reservation_path = staticmethod(JobServiceClient.reservation_path) + parse_reservation_path = staticmethod(JobServiceClient.parse_reservation_path) tensorboard_path = staticmethod(JobServiceClient.tensorboard_path) parse_tensorboard_path = staticmethod(JobServiceClient.parse_tensorboard_path) trial_path = staticmethod(JobServiceClient.trial_path) @@ -696,6 +698,8 @@ async def sample_list_custom_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1283,6 +1287,8 @@ async def sample_list_data_labeling_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1875,6 +1881,8 @@ async def sample_list_hyperparameter_tuning_jobs(): method=rpc, request=request, 
response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2464,6 +2472,8 @@ async def sample_list_nas_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2927,6 +2937,8 @@ async def sample_list_nas_trial_details(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3296,6 +3308,8 @@ async def sample_list_batch_prediction_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3815,6 +3829,8 @@ async def sample_search_model_deployment_monitoring_stats_anomalies(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4051,6 +4067,8 @@ async def sample_list_model_deployment_monitoring_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index f15e57ca85..7cbc5bc93a 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -531,6 +531,28 @@ def parse_persistent_resource_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P<project_id_or_number>.+?)/zones/(?P<zone>.+?)/reservations/(?P<reservation_name>.+?)$", + path, + ) + return
m.groupdict() if m else {} + @staticmethod def tensorboard_path( project: str, @@ -1392,6 +1414,8 @@ def sample_list_custom_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1964,6 +1988,8 @@ def sample_list_data_labeling_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2547,6 +2573,8 @@ def sample_list_hyperparameter_tuning_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3125,6 +3153,8 @@ def sample_list_nas_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3576,6 +3606,8 @@ def sample_list_nas_trial_details(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3940,6 +3972,8 @@ def sample_list_batch_prediction_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4459,6 +4493,8 @@ def sample_search_model_deployment_monitoring_stats_anomalies(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4693,6 +4729,8 @@ def sample_list_model_deployment_monitoring_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py index d76d51b6b0..77e0fb6999 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job @@ -60,6 +73,8 @@ def __init__( request: job_service.ListCustomJobsRequest, response: job_service.ListCustomJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -71,12 +86,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListCustomJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListCustomJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -87,7 +107,12 @@ def pages(self) -> Iterator[job_service.ListCustomJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[custom_job.CustomJob]: @@ -122,6 +147,8 @@ def __init__( request: job_service.ListCustomJobsRequest, response: job_service.ListCustomJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -133,12 +160,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListCustomJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListCustomJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -149,7 +181,12 @@ async def pages(self) -> AsyncIterator[job_service.ListCustomJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[custom_job.CustomJob]: @@ -188,6 +225,8 @@ def __init__( request: job_service.ListDataLabelingJobsRequest, response: job_service.ListDataLabelingJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -199,12 +238,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListDataLabelingJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -215,7 +259,12 @@ def pages(self) -> Iterator[job_service.ListDataLabelingJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[data_labeling_job.DataLabelingJob]: @@ -250,6 +299,8 @@ def __init__( request: job_service.ListDataLabelingJobsRequest, response: job_service.ListDataLabelingJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -261,12 +312,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListDataLabelingJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -277,7 +333,12 @@ async def pages(self) -> AsyncIterator[job_service.ListDataLabelingJobsResponse] yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[data_labeling_job.DataLabelingJob]: @@ -316,6 +377,8 @@ def __init__( request: job_service.ListHyperparameterTuningJobsRequest, response: job_service.ListHyperparameterTuningJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -327,12 +390,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListHyperparameterTuningJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListHyperparameterTuningJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -343,7 +411,12 @@ def pages(self) -> Iterator[job_service.ListHyperparameterTuningJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[hyperparameter_tuning_job.HyperparameterTuningJob]: @@ -380,6 +453,8 @@ def __init__( request: job_service.ListHyperparameterTuningJobsRequest, response: job_service.ListHyperparameterTuningJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -391,12 +466,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListHyperparameterTuningJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListHyperparameterTuningJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -409,7 +489,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__( @@ -450,6 +535,8 @@ def __init__( request: job_service.ListNasJobsRequest, response: job_service.ListNasJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -461,12 +548,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListNasJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListNasJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -477,7 +569,12 @@ def pages(self) -> Iterator[job_service.ListNasJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[nas_job.NasJob]: @@ -512,6 +609,8 @@ def __init__( request: job_service.ListNasJobsRequest, response: job_service.ListNasJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -523,12 +622,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListNasJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListNasJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -539,7 +643,12 @@ async def pages(self) -> AsyncIterator[job_service.ListNasJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[nas_job.NasJob]: @@ -578,6 +687,8 @@ def __init__( request: job_service.ListNasTrialDetailsRequest, response: job_service.ListNasTrialDetailsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -589,12 +700,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListNasTrialDetailsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListNasTrialDetailsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -605,7 +721,12 @@ def pages(self) -> Iterator[job_service.ListNasTrialDetailsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[nas_job.NasTrialDetail]: @@ -640,6 +761,8 @@ def __init__( request: job_service.ListNasTrialDetailsRequest, response: job_service.ListNasTrialDetailsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -651,12 +774,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListNasTrialDetailsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListNasTrialDetailsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -667,7 +795,12 @@ async def pages(self) -> AsyncIterator[job_service.ListNasTrialDetailsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[nas_job.NasTrialDetail]: @@ -706,6 +839,8 @@ def __init__( request: job_service.ListBatchPredictionJobsRequest, response: job_service.ListBatchPredictionJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -717,12 +852,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListBatchPredictionJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListBatchPredictionJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -733,7 +873,12 @@ def pages(self) -> Iterator[job_service.ListBatchPredictionJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[batch_prediction_job.BatchPredictionJob]: @@ -768,6 +913,8 @@ def __init__( request: job_service.ListBatchPredictionJobsRequest, response: job_service.ListBatchPredictionJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -779,12 +926,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListBatchPredictionJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListBatchPredictionJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -795,7 +947,12 @@ async def pages(self) -> AsyncIterator[job_service.ListBatchPredictionJobsRespon yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[batch_prediction_job.BatchPredictionJob]: @@ -836,6 +993,8 @@ def __init__( request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, response: job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -847,6 +1006,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -855,6 +1017,8 @@ def __init__( job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(request) ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -867,7 +1031,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__( @@ -909,6 +1078,8 @@ def __init__( request: job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest, response: job_service.SearchModelDeploymentMonitoringStatsAnomaliesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -920,6 +1091,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.SearchModelDeploymentMonitoringStatsAnomaliesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -928,6 +1102,8 @@ def __init__( job_service.SearchModelDeploymentMonitoringStatsAnomaliesRequest(request) ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -942,7 +1118,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__( @@ -985,6 +1166,8 @@ def __init__( request: job_service.ListModelDeploymentMonitoringJobsRequest, response: job_service.ListModelDeploymentMonitoringJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -996,12 +1179,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListModelDeploymentMonitoringJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -1012,7 +1200,12 @@ def pages(self) -> Iterator[job_service.ListModelDeploymentMonitoringJobsRespons yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__( @@ -1051,6 +1244,8 @@ def __init__( request: job_service.ListModelDeploymentMonitoringJobsRequest, response: job_service.ListModelDeploymentMonitoringJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -1062,12 +1257,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelDeploymentMonitoringJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListModelDeploymentMonitoringJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -1080,7 +1280,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__( diff --git a/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/rest.py index 48037054d7..1f99d082c4 100644 --- a/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/llm_utility_service/transports/rest.py @@ -481,6 +481,16 @@ def __call__( "uri": "/v1beta1/{endpoint=projects/*/locations/*/publishers/*/models/*}:computeTokens", "body": "*", }, + { + "method": "post", + "uri": "/v1beta1/{endpoint=endpoints/*}:computeTokens", + "body": "*", + }, + { + "method": "post", + "uri": "/v1beta1/{endpoint=publishers/*/models/*}:computeTokens", + "body": "*", + }, ] request, metadata = self._interceptor.pre_compute_tokens(request, metadata) pb_request = llm_utility_service.ComputeTokensRequest.pb(request) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py index 84204a5992..a83cc9087a 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py @@ -671,6 +671,8 @@ async def sample_list_metadata_stores(): method=rpc, request=request, response=response, + 
retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1159,6 +1161,8 @@ async def sample_list_artifacts(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1885,6 +1889,8 @@ async def sample_list_contexts(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3113,6 +3119,8 @@ async def sample_list_executions(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4091,6 +4099,8 @@ async def sample_list_metadata_schemas(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py index bcf6bc6deb..294b7e7ee0 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py @@ -1165,6 +1165,8 @@ def sample_list_metadata_stores(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1641,6 +1643,8 @@ def sample_list_artifacts(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2349,6 +2353,8 @@ def sample_list_contexts(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3551,6 +3557,8 @@ def sample_list_executions(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4507,6 +4515,8 @@ def sample_list_metadata_schemas(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/pagers.py index e844867ed1..0b9886e0ce 
100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import artifact from google.cloud.aiplatform_v1beta1.types import context from google.cloud.aiplatform_v1beta1.types import execution @@ -56,6 +69,8 @@ def __init__( request: metadata_service.ListMetadataStoresRequest, response: metadata_service.ListMetadataStoresResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -67,12 +82,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListMetadataStoresResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListMetadataStoresRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -83,7 +103,12 @@ def pages(self) -> Iterator[metadata_service.ListMetadataStoresResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[metadata_store.MetadataStore]: @@ -118,6 +143,8 @@ def __init__( request: metadata_service.ListMetadataStoresRequest, response: metadata_service.ListMetadataStoresResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -129,12 +156,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListMetadataStoresResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListMetadataStoresRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages(self) -> AsyncIterator[metadata_service.ListMetadataStoresRespon yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[metadata_store.MetadataStore]: @@ -184,6 +221,8 @@ def __init__( request: metadata_service.ListArtifactsRequest, response: metadata_service.ListArtifactsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -195,12 +234,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListArtifactsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListArtifactsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -211,7 +255,12 @@ def pages(self) -> Iterator[metadata_service.ListArtifactsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[artifact.Artifact]: @@ -246,6 +295,8 @@ def __init__( request: metadata_service.ListArtifactsRequest, response: metadata_service.ListArtifactsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -257,12 +308,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListArtifactsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListArtifactsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -273,7 +329,12 @@ async def pages(self) -> AsyncIterator[metadata_service.ListArtifactsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[artifact.Artifact]: @@ -312,6 +373,8 @@ def __init__( request: metadata_service.ListContextsRequest, response: metadata_service.ListContextsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -323,12 +386,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListContextsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListContextsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -339,7 +407,12 @@ def pages(self) -> Iterator[metadata_service.ListContextsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[context.Context]: @@ -374,6 +447,8 @@ def __init__( request: metadata_service.ListContextsRequest, response: metadata_service.ListContextsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -385,12 +460,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListContextsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListContextsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -401,7 +481,12 @@ async def pages(self) -> AsyncIterator[metadata_service.ListContextsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[context.Context]: @@ -440,6 +525,8 @@ def __init__( request: metadata_service.ListExecutionsRequest, response: metadata_service.ListExecutionsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -451,12 +538,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListExecutionsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListExecutionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -467,7 +559,12 @@ def pages(self) -> Iterator[metadata_service.ListExecutionsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[execution.Execution]: @@ -502,6 +599,8 @@ def __init__( request: metadata_service.ListExecutionsRequest, response: metadata_service.ListExecutionsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -513,12 +612,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListExecutionsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListExecutionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -529,7 +633,12 @@ async def pages(self) -> AsyncIterator[metadata_service.ListExecutionsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[execution.Execution]: @@ -568,6 +677,8 @@ def __init__( request: metadata_service.ListMetadataSchemasRequest, response: metadata_service.ListMetadataSchemasResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -579,12 +690,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListMetadataSchemasResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListMetadataSchemasRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -595,7 +711,12 @@ def pages(self) -> Iterator[metadata_service.ListMetadataSchemasResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[metadata_schema.MetadataSchema]: @@ -630,6 +751,8 @@ def __init__( request: metadata_service.ListMetadataSchemasRequest, response: metadata_service.ListMetadataSchemasResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -641,12 +764,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListMetadataSchemasResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = metadata_service.ListMetadataSchemasRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -659,7 +787,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[metadata_schema.MetadataSchema]: diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py index a7552e626e..da735b908f 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py @@ -403,6 +403,8 @@ async def sample_search_migratable_resources(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py index 007273fbb2..07901f6e15 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -216,40 +216,40 @@ def parse_annotated_dataset_path(path: str) -> Dict[str, str]: @staticmethod def dataset_path( project: str, + location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( + return "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) @staticmethod def 
parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod def dataset_path( project: str, - location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( + return "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod @@ -935,6 +935,8 @@ def sample_search_migratable_resources(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py index bc0169be00..73692a5dad 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import migratable_resource from google.cloud.aiplatform_v1beta1.types import migration_service @@ -52,6 +65,8 @@ def __init__( request: migration_service.SearchMigratableResourcesRequest, response: migration_service.SearchMigratableResourcesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.SearchMigratableResourcesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = migration_service.SearchMigratableResourcesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[migration_service.SearchMigratableResourcesResponse] yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[migratable_resource.MigratableResource]: @@ -116,6 +141,8 @@ def __init__( request: migration_service.SearchMigratableResourcesRequest, response: migration_service.SearchMigratableResourcesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -127,12 +154,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.SearchMigratableResourcesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = migration_service.SearchMigratableResourcesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[migratable_resource.MigratableResource]: diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/async_client.py index 95e6fedb59..73481da6fb 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_garden_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/async_client.py @@ -73,6 +73,10 @@ class ModelGardenServiceAsyncClient: parse_publisher_model_path = staticmethod( ModelGardenServiceClient.parse_publisher_model_path ) + reservation_path = staticmethod(ModelGardenServiceClient.reservation_path) + parse_reservation_path = staticmethod( + ModelGardenServiceClient.parse_reservation_path + ) common_billing_account_path = staticmethod( ModelGardenServiceClient.common_billing_account_path ) @@ -492,6 +496,8 @@ async def sample_list_publisher_models(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/client.py index c712ce578d..c0ea41be31 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_garden_service/client.py +++ 
b/google/cloud/aiplatform_v1beta1/services/model_garden_service/client.py @@ -205,6 +205,28 @@ def parse_publisher_model_path(path: str) -> Dict[str, str]: m = re.match(r"^publishers/(?P.+?)/models/(?P.+?)$", path) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def common_billing_account_path( billing_account: str, @@ -898,6 +920,8 @@ def sample_list_publisher_models(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_garden_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/model_garden_service/pagers.py index 9086199d4f..df4ff9eac9 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_garden_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/model_garden_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import model_garden_service from google.cloud.aiplatform_v1beta1.types import publisher_model @@ -52,6 +65,8 @@ def __init__( request: model_garden_service.ListPublisherModelsRequest, response: model_garden_service.ListPublisherModelsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListPublisherModelsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_garden_service.ListPublisherModelsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[model_garden_service.ListPublisherModelsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[publisher_model.PublisherModel]: @@ -116,6 +141,8 @@ def __init__( request: model_garden_service.ListPublisherModelsRequest, response: model_garden_service.ListPublisherModelsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -127,12 +154,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListPublisherModelsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_garden_service.ListPublisherModelsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[publisher_model.PublisherModel]: diff --git a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/async_client.py index c6e7bb54d5..e0a7c79050 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/async_client.py @@ -109,6 +109,10 @@ class ModelMonitoringServiceAsyncClient: parse_model_monitoring_job_path = staticmethod( ModelMonitoringServiceClient.parse_model_monitoring_job_path ) + reservation_path = staticmethod(ModelMonitoringServiceClient.reservation_path) + parse_reservation_path = staticmethod( + ModelMonitoringServiceClient.parse_reservation_path + ) schedule_path = staticmethod(ModelMonitoringServiceClient.schedule_path) parse_schedule_path = staticmethod(ModelMonitoringServiceClient.parse_schedule_path) common_billing_account_path = staticmethod( @@ -808,6 +812,8 @@ async def sample_list_model_monitors(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1302,6 +1308,8 @@ async def sample_list_model_monitoring_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1563,6 +1571,8 @@ 
async def sample_search_model_monitoring_stats(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1690,6 +1700,8 @@ async def sample_search_model_monitoring_alerts(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/client.py index 3bf2401174..16aad8593c 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/client.py @@ -340,6 +340,28 @@ def parse_model_monitoring_job_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def schedule_path( project: str, @@ -1327,6 +1349,8 @@ def sample_list_model_monitors(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1813,6 +1837,8 @@ def sample_list_model_monitoring_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2072,6 +2098,8 @@ def sample_search_model_monitoring_stats(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ 
-2198,6 +2226,8 @@ def sample_search_model_monitoring_alerts(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/pagers.py index 35379d323c..4531b433d7 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/model_monitoring_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import model_monitor from google.cloud.aiplatform_v1beta1.types import model_monitoring_alert from google.cloud.aiplatform_v1beta1.types import model_monitoring_job @@ -55,6 +68,8 @@ def __init__( request: model_monitoring_service.ListModelMonitorsRequest, response: model_monitoring_service.ListModelMonitorsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -66,12 +81,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelMonitorsResponse): The initial response object. 
+ retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. """ self._method = method self._request = model_monitoring_service.ListModelMonitorsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -82,7 +102,12 @@ def pages(self) -> Iterator[model_monitoring_service.ListModelMonitorsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[model_monitor.ModelMonitor]: @@ -119,6 +144,8 @@ def __init__( request: model_monitoring_service.ListModelMonitorsRequest, response: model_monitoring_service.ListModelMonitorsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -130,12 +157,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelMonitorsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_monitoring_service.ListModelMonitorsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -148,7 +180,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[model_monitor.ModelMonitor]: @@ -187,6 +224,8 @@ def __init__( request: model_monitoring_service.ListModelMonitoringJobsRequest, response: model_monitoring_service.ListModelMonitoringJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -198,12 +237,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelMonitoringJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_monitoring_service.ListModelMonitoringJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -216,7 +260,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[model_monitoring_job.ModelMonitoringJob]: @@ -253,6 +302,8 @@ def __init__( request: model_monitoring_service.ListModelMonitoringJobsRequest, response: model_monitoring_service.ListModelMonitoringJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -264,12 +315,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelMonitoringJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_monitoring_service.ListModelMonitoringJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -282,7 +338,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[model_monitoring_job.ModelMonitoringJob]: @@ -323,6 +384,8 @@ def __init__( request: model_monitoring_service.SearchModelMonitoringStatsRequest, response: model_monitoring_service.SearchModelMonitoringStatsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -334,6 +397,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.SearchModelMonitoringStatsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -342,6 +408,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -354,7 +422,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[model_monitoring_stats.ModelMonitoringStats]: @@ -391,6 +464,8 @@ def __init__( request: model_monitoring_service.SearchModelMonitoringStatsRequest, response: model_monitoring_service.SearchModelMonitoringStatsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -402,6 +477,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.SearchModelMonitoringStatsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -410,6 +488,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -422,7 +502,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[model_monitoring_stats.ModelMonitoringStats]: @@ -463,6 +548,8 @@ def __init__( request: model_monitoring_service.SearchModelMonitoringAlertsRequest, response: model_monitoring_service.SearchModelMonitoringAlertsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -474,6 +561,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.SearchModelMonitoringAlertsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -482,6 +572,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -494,7 +586,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[model_monitoring_alert.ModelMonitoringAlert]: @@ -531,6 +628,8 @@ def __init__( request: model_monitoring_service.SearchModelMonitoringAlertsRequest, response: model_monitoring_service.SearchModelMonitoringAlertsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -542,6 +641,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.SearchModelMonitoringAlertsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -550,6 +652,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -562,7 +666,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[model_monitoring_alert.ModelMonitoringAlert]: diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py index 25f3665fe3..a3e1301379 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py @@ -654,6 +654,8 @@ async def sample_list_models(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -773,6 +775,8 @@ async def sample_list_model_versions(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2322,6 +2326,8 @@ async def sample_list_model_evaluations(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2555,6 +2561,8 @@ async def sample_list_model_evaluation_slices(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index 73e8054285..78bd024c4a 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -1141,6 
+1141,8 @@ def sample_list_models(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1257,6 +1259,8 @@ def sample_list_model_versions(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2776,6 +2780,8 @@ def sample_list_model_evaluations(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3007,6 +3013,8 @@ def sample_list_model_evaluation_slices(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py index 4789b83c9e..19e04948b7 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import model_evaluation from google.cloud.aiplatform_v1beta1.types import model_evaluation_slice @@ -54,6 +67,8 @@ def __init__( request: model_service.ListModelsRequest, response: model_service.ListModelsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -65,12 +80,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -81,7 +101,12 @@ def pages(self) -> Iterator[model_service.ListModelsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[model.Model]: @@ -116,6 +141,8 @@ def __init__( request: model_service.ListModelsRequest, response: model_service.ListModelsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -127,12 +154,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -143,7 +175,12 @@ async def pages(self) -> AsyncIterator[model_service.ListModelsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[model.Model]: @@ -182,6 +219,8 @@ def __init__( request: model_service.ListModelVersionsRequest, response: model_service.ListModelVersionsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -193,12 +232,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelVersionsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelVersionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -209,7 +253,12 @@ def pages(self) -> Iterator[model_service.ListModelVersionsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[model.Model]: @@ -244,6 +293,8 @@ def __init__( request: model_service.ListModelVersionsRequest, response: model_service.ListModelVersionsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -255,12 +306,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelVersionsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelVersionsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -271,7 +327,12 @@ async def pages(self) -> AsyncIterator[model_service.ListModelVersionsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[model.Model]: @@ -310,6 +371,8 @@ def __init__( request: model_service.ListModelEvaluationsRequest, response: model_service.ListModelEvaluationsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -321,12 +384,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelEvaluationsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -337,7 +405,12 @@ def pages(self) -> Iterator[model_service.ListModelEvaluationsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[model_evaluation.ModelEvaluation]: @@ -372,6 +445,8 @@ def __init__( request: model_service.ListModelEvaluationsRequest, response: model_service.ListModelEvaluationsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -383,12 +458,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelEvaluationsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -399,7 +479,12 @@ async def pages(self) -> AsyncIterator[model_service.ListModelEvaluationsRespons yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[model_evaluation.ModelEvaluation]: @@ -438,6 +523,8 @@ def __init__( request: model_service.ListModelEvaluationSlicesRequest, response: model_service.ListModelEvaluationSlicesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -449,12 +536,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationSlicesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelEvaluationSlicesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -465,7 +557,12 @@ def pages(self) -> Iterator[model_service.ListModelEvaluationSlicesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[model_evaluation_slice.ModelEvaluationSlice]: @@ -502,6 +599,8 @@ def __init__( request: model_service.ListModelEvaluationSlicesRequest, response: model_service.ListModelEvaluationSlicesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -513,12 +612,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationSlicesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelEvaluationSlicesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -531,7 +635,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[model_evaluation_slice.ModelEvaluationSlice]: diff --git a/google/cloud/aiplatform_v1beta1/services/notebook_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/notebook_service/async_client.py index 1800db5caf..1fee5a93aa 100644 --- a/google/cloud/aiplatform_v1beta1/services/notebook_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/notebook_service/async_client.py @@ -110,6 +110,8 @@ class NotebookServiceAsyncClient: parse_notebook_runtime_template_path = staticmethod( NotebookServiceClient.parse_notebook_runtime_template_path ) + reservation_path = staticmethod(NotebookServiceClient.reservation_path) + parse_reservation_path = staticmethod(NotebookServiceClient.parse_reservation_path) schedule_path = staticmethod(NotebookServiceClient.schedule_path) parse_schedule_path = staticmethod(NotebookServiceClient.parse_schedule_path) subnetwork_path = staticmethod(NotebookServiceClient.subnetwork_path) @@ -693,6 +695,8 @@ async def sample_list_notebook_runtime_templates(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1371,6 +1375,8 @@ async def sample_list_notebook_runtimes(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2145,6 +2151,8 @@ async def 
sample_list_notebook_execution_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/notebook_service/client.py b/google/cloud/aiplatform_v1beta1/services/notebook_service/client.py index 68f7414072..c1065d3d6e 100644 --- a/google/cloud/aiplatform_v1beta1/services/notebook_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/notebook_service/client.py @@ -296,6 +296,28 @@ def parse_notebook_runtime_template_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def schedule_path( project: str, @@ -1196,6 +1218,8 @@ def sample_list_notebook_runtime_templates(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1863,6 +1887,8 @@ def sample_list_notebook_runtimes(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2625,6 +2651,8 @@ def sample_list_notebook_execution_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/notebook_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/notebook_service/pagers.py index 1f9a2dc346..9ce5b53212 100644 --- 
a/google/cloud/aiplatform_v1beta1/services/notebook_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/notebook_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import notebook_execution_job from google.cloud.aiplatform_v1beta1.types import notebook_runtime from google.cloud.aiplatform_v1beta1.types import notebook_service @@ -53,6 +66,8 @@ def __init__( request: notebook_service.ListNotebookRuntimeTemplatesRequest, response: notebook_service.ListNotebookRuntimeTemplatesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -64,12 +79,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListNotebookRuntimeTemplatesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = notebook_service.ListNotebookRuntimeTemplatesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -80,7 +100,12 @@ def pages(self) -> Iterator[notebook_service.ListNotebookRuntimeTemplatesRespons yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[notebook_runtime.NotebookRuntimeTemplate]: @@ -117,6 +142,8 @@ def __init__( request: notebook_service.ListNotebookRuntimeTemplatesRequest, response: notebook_service.ListNotebookRuntimeTemplatesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -128,12 +155,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListNotebookRuntimeTemplatesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = notebook_service.ListNotebookRuntimeTemplatesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -146,7 +178,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[notebook_runtime.NotebookRuntimeTemplate]: @@ -185,6 +222,8 @@ def __init__( request: notebook_service.ListNotebookRuntimesRequest, response: notebook_service.ListNotebookRuntimesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -196,12 +235,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListNotebookRuntimesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = notebook_service.ListNotebookRuntimesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -212,7 +256,12 @@ def pages(self) -> Iterator[notebook_service.ListNotebookRuntimesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[notebook_runtime.NotebookRuntime]: @@ -247,6 +296,8 @@ def __init__( request: notebook_service.ListNotebookRuntimesRequest, response: notebook_service.ListNotebookRuntimesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -258,12 +309,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListNotebookRuntimesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = notebook_service.ListNotebookRuntimesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -276,7 +332,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[notebook_runtime.NotebookRuntime]: @@ -315,6 +376,8 @@ def __init__( request: notebook_service.ListNotebookExecutionJobsRequest, response: notebook_service.ListNotebookExecutionJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -326,12 +389,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = notebook_service.ListNotebookExecutionJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -342,7 +410,12 @@ def pages(self) -> Iterator[notebook_service.ListNotebookExecutionJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[notebook_execution_job.NotebookExecutionJob]: @@ -379,6 +452,8 @@ def __init__( request: notebook_service.ListNotebookExecutionJobsRequest, response: notebook_service.ListNotebookExecutionJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -390,12 +465,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListNotebookExecutionJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = notebook_service.ListNotebookExecutionJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -408,7 +488,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[notebook_execution_job.NotebookExecutionJob]: diff --git a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/async_client.py index 054360a275..8c530b8536 100644 --- a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/async_client.py @@ -106,6 +106,10 @@ class PersistentResourceServiceAsyncClient: parse_persistent_resource_path = staticmethod( PersistentResourceServiceClient.parse_persistent_resource_path ) + reservation_path = staticmethod(PersistentResourceServiceClient.reservation_path) + parse_reservation_path = staticmethod( + PersistentResourceServiceClient.parse_reservation_path + ) common_billing_account_path = staticmethod( PersistentResourceServiceClient.common_billing_account_path ) @@ -701,6 +705,8 @@ async def sample_list_persistent_resources(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/client.py b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/client.py index 90b8ee19c1..3bc84d35ed 100644 --- 
a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/client.py @@ -286,6 +286,28 @@ def parse_persistent_resource_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def common_billing_account_path( billing_account: str, @@ -1154,6 +1176,8 @@ def sample_list_persistent_resources(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/pagers.py index 1365b69727..ec0d30ef5a 100644 --- a/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/persistent_resource_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import persistent_resource from google.cloud.aiplatform_v1beta1.types import persistent_resource_service @@ -54,6 +67,8 @@ def __init__( request: persistent_resource_service.ListPersistentResourcesRequest, response: persistent_resource_service.ListPersistentResourcesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -65,6 +80,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListPersistentResourcesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -73,6 +91,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -85,7 +105,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[persistent_resource.PersistentResource]: @@ -122,6 +147,8 @@ def __init__( request: persistent_resource_service.ListPersistentResourcesRequest, response: persistent_resource_service.ListPersistentResourcesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -133,6 +160,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListPersistentResourcesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -141,6 +171,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -153,7 +185,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[persistent_resource.PersistentResource]: diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py index 1b88c1eb48..c67ef2a73a 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py @@ -662,6 +662,8 @@ async def sample_list_training_pipelines(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1260,6 +1262,8 @@ async def sample_list_pipeline_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index 4d20c66776..27b99126ed 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -1231,6 +1231,8 @@ def sample_list_training_pipelines(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1814,6 +1816,8 @@ def sample_list_pipeline_jobs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) 
diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py index 3e64373b12..6d60b4ed93 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import pipeline_job from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline @@ -53,6 +66,8 @@ def __init__( request: pipeline_service.ListTrainingPipelinesRequest, response: pipeline_service.ListTrainingPipelinesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -64,12 +79,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTrainingPipelinesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. 
metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. """ self._method = method self._request = pipeline_service.ListTrainingPipelinesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -80,7 +100,12 @@ def pages(self) -> Iterator[pipeline_service.ListTrainingPipelinesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[training_pipeline.TrainingPipeline]: @@ -117,6 +142,8 @@ def __init__( request: pipeline_service.ListTrainingPipelinesRequest, response: pipeline_service.ListTrainingPipelinesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -128,12 +155,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTrainingPipelinesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = pipeline_service.ListTrainingPipelinesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -146,7 +178,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[training_pipeline.TrainingPipeline]: @@ -185,6 +222,8 @@ def __init__( request: pipeline_service.ListPipelineJobsRequest, response: pipeline_service.ListPipelineJobsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -196,12 +235,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListPipelineJobsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = pipeline_service.ListPipelineJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -212,7 +256,12 @@ def pages(self) -> Iterator[pipeline_service.ListPipelineJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[pipeline_job.PipelineJob]: @@ -247,6 +296,8 @@ def __init__( request: pipeline_service.ListPipelineJobsRequest, response: pipeline_service.ListPipelineJobsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -258,12 +309,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListPipelineJobsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = pipeline_service.ListPipelineJobsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -274,7 +330,12 @@ async def pages(self) -> AsyncIterator[pipeline_service.ListPipelineJobsResponse yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[pipeline_job.PipelineJob]: diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/rest.py index 33c179b05f..8af19f63a1 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/rest.py @@ -932,6 +932,16 @@ def __call__( "uri": "/v1beta1/{endpoint=projects/*/locations/*/publishers/*/models/*}:countTokens", "body": "*", }, + { + "method": "post", + "uri": "/v1beta1/{endpoint=endpoints/*}:countTokens", + "body": "*", + }, + { + "method": "post", + "uri": "/v1beta1/{endpoint=publishers/*/models/*}:countTokens", + "body": "*", + }, ] request, metadata = self._interceptor.pre_count_tokens(request, metadata) pb_request = prediction_service.CountTokensRequest.pb(request) @@ -1318,6 +1328,16 @@ def __call__( "uri": "/v1beta1/{model=projects/*/locations/*/publishers/*/models/*}:generateContent", "body": "*", }, + { + "method": "post", + "uri": "/v1beta1/{model=endpoints/*}:generateContent", + "body": "*", + }, + { + "method": "post", + "uri": "/v1beta1/{model=publishers/*/models/*}:generateContent", + "body": "*", + }, ] request, metadata = 
self._interceptor.pre_generate_content( request, metadata @@ -1804,6 +1824,16 @@ def __call__( "uri": "/v1beta1/{model=projects/*/locations/*/publishers/*/models/*}:streamGenerateContent", "body": "*", }, + { + "method": "post", + "uri": "/v1beta1/{model=endpoints/*}:streamGenerateContent", + "body": "*", + }, + { + "method": "post", + "uri": "/v1beta1/{model=publishers/*/models/*}:streamGenerateContent", + "body": "*", + }, ] request, metadata = self._interceptor.pre_stream_generate_content( request, metadata diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/async_client.py index 3b776f2451..f15cd84301 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/async_client.py @@ -650,6 +650,8 @@ async def sample_list_reasoning_engines(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/client.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/client.py index cb50b27320..4e18d08e86 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/client.py @@ -1056,6 +1056,8 @@ def sample_list_reasoning_engines(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/pagers.py index fa79ea46b9..494e469ba0 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_service/pagers.py @@ -13,6 +13,9 @@ # See the 
License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import reasoning_engine from google.cloud.aiplatform_v1beta1.types import reasoning_engine_service @@ -52,6 +65,8 @@ def __init__( request: reasoning_engine_service.ListReasoningEnginesRequest, response: reasoning_engine_service.ListReasoningEnginesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListReasoningEnginesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = reasoning_engine_service.ListReasoningEnginesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[reasoning_engine_service.ListReasoningEnginesRespons yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[reasoning_engine.ReasoningEngine]: @@ -116,6 +141,8 @@ def __init__( request: reasoning_engine_service.ListReasoningEnginesRequest, response: reasoning_engine_service.ListReasoningEnginesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -127,12 +154,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListReasoningEnginesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = reasoning_engine_service.ListReasoningEnginesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[reasoning_engine.ReasoningEngine]: diff --git a/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py index 542ad106bc..13c3a500cf 100644 --- a/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/schedule_service/async_client.py @@ -127,6 +127,8 @@ class ScheduleServiceAsyncClient: parse_pipeline_job_path = staticmethod( ScheduleServiceClient.parse_pipeline_job_path ) + reservation_path = staticmethod(ScheduleServiceClient.reservation_path) + parse_reservation_path = staticmethod(ScheduleServiceClient.parse_reservation_path) schedule_path = staticmethod(ScheduleServiceClient.schedule_path) parse_schedule_path = staticmethod(ScheduleServiceClient.parse_schedule_path) common_billing_account_path = staticmethod( @@ -798,6 +800,8 @@ async def sample_list_schedules(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py b/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py index dd9c6cb990..6bdbba5567 100644 --- a/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py +++ 
b/google/cloud/aiplatform_v1beta1/services/schedule_service/client.py @@ -491,6 +491,28 @@ def parse_pipeline_job_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def reservation_path( + project_id_or_number: str, + zone: str, + reservation_name: str, + ) -> str: + """Returns a fully-qualified reservation string.""" + return "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + + @staticmethod + def parse_reservation_path(path: str) -> Dict[str, str]: + """Parses a reservation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/zones/(?P.+?)/reservations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def schedule_path( project: str, @@ -1450,6 +1472,8 @@ def sample_list_schedules(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/schedule_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/schedule_service/pagers.py index b114c40a3c..f71ce81012 100644 --- a/google/cloud/aiplatform_v1beta1/services/schedule_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/schedule_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import schedule from google.cloud.aiplatform_v1beta1.types import schedule_service @@ -52,6 +65,8 @@ def __init__( request: schedule_service.ListSchedulesRequest, response: schedule_service.ListSchedulesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListSchedulesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = schedule_service.ListSchedulesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[schedule_service.ListSchedulesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[schedule.Schedule]: @@ -114,6 +139,8 @@ def __init__( request: schedule_service.ListSchedulesRequest, response: schedule_service.ListSchedulesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -125,12 +152,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListSchedulesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = schedule_service.ListSchedulesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -141,7 +173,12 @@ async def pages(self) -> AsyncIterator[schedule_service.ListSchedulesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[schedule.Schedule]: diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py index 83a445a311..91fe7e1844 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py @@ -664,6 +664,8 @@ async def sample_list_specialist_pools(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index 06b0f507a1..c158526065 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -1070,6 +1070,8 @@ def sample_list_specialist_pools(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py index 
3bd3bddb81..b78d12a6a5 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import specialist_pool from google.cloud.aiplatform_v1beta1.types import specialist_pool_service @@ -52,6 +65,8 @@ def __init__( request: specialist_pool_service.ListSpecialistPoolsRequest, response: specialist_pool_service.ListSpecialistPoolsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListSpecialistPoolsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = specialist_pool_service.ListSpecialistPoolsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[specialist_pool_service.ListSpecialistPoolsResponse] yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[specialist_pool.SpecialistPool]: @@ -116,6 +141,8 @@ def __init__( request: specialist_pool_service.ListSpecialistPoolsRequest, response: specialist_pool_service.ListSpecialistPoolsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -127,12 +154,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListSpecialistPoolsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = specialist_pool_service.ListSpecialistPoolsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[specialist_pool.SpecialistPool]: diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py index e598b869f0..2b7252bba8 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/async_client.py @@ -816,6 +816,8 @@ async def sample_list_tensorboards(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1676,6 +1678,8 @@ async def sample_list_tensorboard_experiments(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2447,6 +2451,8 @@ async def sample_list_tensorboard_runs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3217,6 +3223,8 @@ async def sample_list_tensorboard_time_series(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4105,6 +4113,8 @@ async def sample_export_tensorboard_time_series_data(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git 
a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py index 32fe9b9bf5..48438b7542 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py @@ -1284,6 +1284,8 @@ def sample_list_tensorboards(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2131,6 +2133,8 @@ def sample_list_tensorboard_experiments(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -2888,6 +2892,8 @@ def sample_list_tensorboard_runs(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -3650,6 +3656,8 @@ def sample_list_tensorboard_time_series(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -4531,6 +4539,8 @@ def sample_export_tensorboard_time_series_data(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/pagers.py index ab65964c69..425e7088c4 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import tensorboard from google.cloud.aiplatform_v1beta1.types import tensorboard_data from google.cloud.aiplatform_v1beta1.types import tensorboard_experiment @@ -56,6 +69,8 @@ def __init__( request: tensorboard_service.ListTensorboardsRequest, response: tensorboard_service.ListTensorboardsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -67,12 +82,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTensorboardsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -83,7 +103,12 @@ def pages(self) -> Iterator[tensorboard_service.ListTensorboardsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[tensorboard.Tensorboard]: @@ -118,6 +143,8 @@ def __init__( request: tensorboard_service.ListTensorboardsRequest, response: tensorboard_service.ListTensorboardsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -129,12 +156,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTensorboardsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -147,7 +179,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[tensorboard.Tensorboard]: @@ -186,6 +223,8 @@ def __init__( request: tensorboard_service.ListTensorboardExperimentsRequest, response: tensorboard_service.ListTensorboardExperimentsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -197,12 +236,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTensorboardExperimentsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardExperimentsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -213,7 +257,12 @@ def pages(self) -> Iterator[tensorboard_service.ListTensorboardExperimentsRespon yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[tensorboard_experiment.TensorboardExperiment]: @@ -250,6 +299,8 @@ def __init__( request: tensorboard_service.ListTensorboardExperimentsRequest, response: tensorboard_service.ListTensorboardExperimentsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -261,12 +312,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTensorboardExperimentsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardExperimentsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -279,7 +335,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[tensorboard_experiment.TensorboardExperiment]: @@ -318,6 +379,8 @@ def __init__( request: tensorboard_service.ListTensorboardRunsRequest, response: tensorboard_service.ListTensorboardRunsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -329,12 +392,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTensorboardRunsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardRunsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -345,7 +413,12 @@ def pages(self) -> Iterator[tensorboard_service.ListTensorboardRunsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[tensorboard_run.TensorboardRun]: @@ -382,6 +455,8 @@ def __init__( request: tensorboard_service.ListTensorboardRunsRequest, response: tensorboard_service.ListTensorboardRunsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -393,12 +468,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTensorboardRunsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardRunsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -411,7 +491,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[tensorboard_run.TensorboardRun]: @@ -450,6 +535,8 @@ def __init__( request: tensorboard_service.ListTensorboardTimeSeriesRequest, response: tensorboard_service.ListTensorboardTimeSeriesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -461,12 +548,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTensorboardTimeSeriesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardTimeSeriesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -477,7 +569,12 @@ def pages(self) -> Iterator[tensorboard_service.ListTensorboardTimeSeriesRespons yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[tensorboard_time_series.TensorboardTimeSeries]: @@ -514,6 +611,8 @@ def __init__( request: tensorboard_service.ListTensorboardTimeSeriesRequest, response: tensorboard_service.ListTensorboardTimeSeriesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -525,12 +624,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTensorboardTimeSeriesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = tensorboard_service.ListTensorboardTimeSeriesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -543,7 +647,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[tensorboard_time_series.TensorboardTimeSeries]: @@ -584,6 +693,8 @@ def __init__( request: tensorboard_service.ExportTensorboardTimeSeriesDataRequest, response: tensorboard_service.ExportTensorboardTimeSeriesDataResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -595,6 +706,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ExportTensorboardTimeSeriesDataResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -603,6 +717,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -615,7 +731,12 @@ def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[tensorboard_data.TimeSeriesDataPoint]: @@ -652,6 +773,8 @@ def __init__( request: tensorboard_service.ExportTensorboardTimeSeriesDataRequest, response: tensorboard_service.ExportTensorboardTimeSeriesDataResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -663,6 +786,9 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ExportTensorboardTimeSeriesDataResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" @@ -671,6 +797,8 @@ def __init__( request ) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -683,7 +811,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[tensorboard_data.TimeSeriesDataPoint]: diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py index b1d6c1a9b1..361809c22b 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/async_client.py @@ -646,6 +646,8 @@ async def sample_list_rag_corpora(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1279,6 +1281,8 @@ async def sample_list_rag_files(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py index 950e2b4116..021b4e50e3 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/client.py @@ -1134,6 +1134,8 @@ def sample_list_rag_corpora(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1752,6 +1754,8 @@ def sample_list_rag_files(): method=rpc, request=request, response=response, + retry=retry, + 
timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/pagers.py index ff69b42cc4..da6eb3df02 100644 --- a/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/vertex_rag_data_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import vertex_rag_data from google.cloud.aiplatform_v1beta1.types import vertex_rag_data_service @@ -52,6 +65,8 @@ def __init__( request: vertex_rag_data_service.ListRagCorporaRequest, response: vertex_rag_data_service.ListRagCorporaResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListRagCorporaResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. 
metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. """ self._method = method self._request = vertex_rag_data_service.ListRagCorporaRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[vertex_rag_data_service.ListRagCorporaResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[vertex_rag_data.RagCorpus]: @@ -116,6 +141,8 @@ def __init__( request: vertex_rag_data_service.ListRagCorporaRequest, response: vertex_rag_data_service.ListRagCorporaResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -127,12 +154,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListRagCorporaResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = vertex_rag_data_service.ListRagCorporaRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -145,7 +177,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[vertex_rag_data.RagCorpus]: @@ -184,6 +221,8 @@ def __init__( request: vertex_rag_data_service.ListRagFilesRequest, response: vertex_rag_data_service.ListRagFilesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -195,12 +234,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListRagFilesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = vertex_rag_data_service.ListRagFilesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -211,7 +255,12 @@ def pages(self) -> Iterator[vertex_rag_data_service.ListRagFilesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[vertex_rag_data.RagFile]: @@ -246,6 +295,8 @@ def __init__( request: vertex_rag_data_service.ListRagFilesRequest, response: vertex_rag_data_service.ListRagFilesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -257,12 +308,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListRagFilesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = vertex_rag_data_service.ListRagFilesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -275,7 +331,12 @@ async def pages( yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[vertex_rag_data.RagFile]: diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py index 2665192714..ab7fbb1fc7 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py @@ -619,6 +619,8 @@ async def sample_list_studies(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1272,6 +1274,8 @@ async def sample_list_trials(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py index f7eeb79c46..27c11404b2 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py @@ -1069,6 +1069,8 @@ def sample_list_studies(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) @@ -1705,6 +1707,8 @@ def sample_list_trials(): method=rpc, request=request, response=response, + retry=retry, + timeout=timeout, metadata=metadata, ) diff --git 
a/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py index ccf6fa8e35..a847b197be 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async from typing import ( Any, AsyncIterator, @@ -22,8 +25,18 @@ Tuple, Optional, Iterator, + Union, ) +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + from google.cloud.aiplatform_v1beta1.types import study from google.cloud.aiplatform_v1beta1.types import vizier_service @@ -52,6 +65,8 @@ def __init__( request: vizier_service.ListStudiesRequest, response: vizier_service.ListStudiesResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -63,12 +78,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListStudiesResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = vizier_service.ListStudiesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -79,7 +99,12 @@ def pages(self) -> Iterator[vizier_service.ListStudiesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[study.Study]: @@ -114,6 +139,8 @@ def __init__( request: vizier_service.ListStudiesRequest, response: vizier_service.ListStudiesResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -125,12 +152,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListStudiesResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = vizier_service.ListStudiesRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -141,7 +173,12 @@ async def pages(self) -> AsyncIterator[vizier_service.ListStudiesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[study.Study]: @@ -180,6 +217,8 @@ def __init__( request: vizier_service.ListTrialsRequest, response: vizier_service.ListTrialsResponse, *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -191,12 +230,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTrialsResponse): The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = vizier_service.ListTrialsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -207,7 +251,12 @@ def pages(self) -> Iterator[vizier_service.ListTrialsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __iter__(self) -> Iterator[study.Trial]: @@ -242,6 +291,8 @@ def __init__( request: vizier_service.ListTrialsRequest, response: vizier_service.ListTrialsResponse, *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = () ): """Instantiates the pager. @@ -253,12 +304,17 @@ def __init__( The initial request object. response (google.cloud.aiplatform_v1beta1.types.ListTrialsResponse): The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
""" self._method = method self._request = vizier_service.ListTrialsRequest(request) self._response = response + self._retry = retry + self._timeout = timeout self._metadata = metadata def __getattr__(self, name: str) -> Any: @@ -269,7 +325,12 @@ async def pages(self) -> AsyncIterator[vizier_service.ListTrialsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) yield self._response def __aiter__(self) -> AsyncIterator[study.Trial]: diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index 2ed65d01fe..18aee27842 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -209,6 +209,10 @@ GroundednessInstance, GroundednessResult, GroundednessSpec, + PairwiseMetricInput, + PairwiseMetricInstance, + PairwiseMetricResult, + PairwiseMetricSpec, PairwiseQuestionAnsweringQualityInput, PairwiseQuestionAnsweringQualityInstance, PairwiseQuestionAnsweringQualityResult, @@ -217,6 +221,10 @@ PairwiseSummarizationQualityInstance, PairwiseSummarizationQualityResult, PairwiseSummarizationQualitySpec, + PointwiseMetricInput, + PointwiseMetricInstance, + PointwiseMetricResult, + PointwiseMetricSpec, QuestionAnsweringCorrectnessInput, QuestionAnsweringCorrectnessInstance, QuestionAnsweringCorrectnessResult, @@ -968,6 +976,9 @@ UpdateReasoningEngineOperationMetadata, UpdateReasoningEngineRequest, ) +from .reservation_affinity import ( + ReservationAffinity, +) from .saved_query import ( SavedQuery, ) @@ -1330,6 +1341,10 @@ "GroundednessInstance", "GroundednessResult", "GroundednessSpec", + "PairwiseMetricInput", + "PairwiseMetricInstance", + "PairwiseMetricResult", + 
"PairwiseMetricSpec", "PairwiseQuestionAnsweringQualityInput", "PairwiseQuestionAnsweringQualityInstance", "PairwiseQuestionAnsweringQualityResult", @@ -1338,6 +1353,10 @@ "PairwiseSummarizationQualityInstance", "PairwiseSummarizationQualityResult", "PairwiseSummarizationQualitySpec", + "PointwiseMetricInput", + "PointwiseMetricInstance", + "PointwiseMetricResult", + "PointwiseMetricSpec", "QuestionAnsweringCorrectnessInput", "QuestionAnsweringCorrectnessInstance", "QuestionAnsweringCorrectnessResult", @@ -1938,6 +1957,7 @@ "ListReasoningEnginesResponse", "UpdateReasoningEngineOperationMetadata", "UpdateReasoningEngineRequest", + "ReservationAffinity", "SavedQuery", "Schedule", "CreateScheduleRequest", diff --git a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py index 101df74d6f..ec707a42c1 100644 --- a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py +++ b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py @@ -36,7 +36,9 @@ class AcceleratorType(proto.Enum): Unspecified accelerator type, which means no accelerator. NVIDIA_TESLA_K80 (1): - Nvidia Tesla K80 GPU. + Deprecated: Nvidia Tesla K80 GPU has reached + end of support, see + https://cloud.google.com/compute/docs/eol/k80-eol. NVIDIA_TESLA_P100 (2): Nvidia Tesla P100 GPU. NVIDIA_TESLA_V100 (3): diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index 1274ff7a0a..bb26cc85a9 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -563,10 +563,18 @@ class Strategy(proto.Enum): LOW_COST (2): Low cost by making potential use of spot resources. + STANDARD (3): + Standard provisioning strategy uses regular + on-demand resources. + SPOT (4): + Spot provisioning strategy uses spot + resources. 
""" STRATEGY_UNSPECIFIED = 0 ON_DEMAND = 1 LOW_COST = 2 + STANDARD = 3 + SPOT = 4 timeout: duration_pb2.Duration = proto.Field( proto.MESSAGE, diff --git a/google/cloud/aiplatform_v1beta1/types/evaluation_service.py b/google/cloud/aiplatform_v1beta1/types/evaluation_service.py index 6dfbb855b1..32f525dff6 100644 --- a/google/cloud/aiplatform_v1beta1/types/evaluation_service.py +++ b/google/cloud/aiplatform_v1beta1/types/evaluation_service.py @@ -97,6 +97,14 @@ "QuestionAnsweringCorrectnessInstance", "QuestionAnsweringCorrectnessSpec", "QuestionAnsweringCorrectnessResult", + "PointwiseMetricInput", + "PointwiseMetricInstance", + "PointwiseMetricSpec", + "PointwiseMetricResult", + "PairwiseMetricInput", + "PairwiseMetricInstance", + "PairwiseMetricSpec", + "PairwiseMetricResult", "ToolCallValidInput", "ToolCallValidSpec", "ToolCallValidInstance", @@ -227,6 +235,14 @@ class EvaluateInstancesRequest(proto.Message): Input for question answering correctness metric. + This field is a member of `oneof`_ ``metric_inputs``. + pointwise_metric_input (google.cloud.aiplatform_v1beta1.types.PointwiseMetricInput): + Input for pointwise metric. + + This field is a member of `oneof`_ ``metric_inputs``. + pairwise_metric_input (google.cloud.aiplatform_v1beta1.types.PairwiseMetricInput): + Input for pairwise metric. + This field is a member of `oneof`_ ``metric_inputs``. tool_call_valid_input (google.cloud.aiplatform_v1beta1.types.ToolCallValidInput): Tool call metric instances. 
@@ -360,6 +376,18 @@ class EvaluateInstancesRequest(proto.Message): message="QuestionAnsweringCorrectnessInput", ) ) + pointwise_metric_input: "PointwiseMetricInput" = proto.Field( + proto.MESSAGE, + number=28, + oneof="metric_inputs", + message="PointwiseMetricInput", + ) + pairwise_metric_input: "PairwiseMetricInput" = proto.Field( + proto.MESSAGE, + number=29, + oneof="metric_inputs", + message="PairwiseMetricInput", + ) tool_call_valid_input: "ToolCallValidInput" = proto.Field( proto.MESSAGE, number=19, @@ -478,6 +506,15 @@ class EvaluateInstancesResponse(proto.Message): Result for question answering correctness metric. + This field is a member of `oneof`_ ``evaluation_results``. + pointwise_metric_result (google.cloud.aiplatform_v1beta1.types.PointwiseMetricResult): + Generic metrics. + Result for pointwise metric. + + This field is a member of `oneof`_ ``evaluation_results``. + pairwise_metric_result (google.cloud.aiplatform_v1beta1.types.PairwiseMetricResult): + Result for pairwise metric. + This field is a member of `oneof`_ ``evaluation_results``. tool_call_valid_results (google.cloud.aiplatform_v1beta1.types.ToolCallValidResults): Tool call metrics. @@ -609,6 +646,18 @@ class EvaluateInstancesResponse(proto.Message): message="QuestionAnsweringCorrectnessResult", ) ) + pointwise_metric_result: "PointwiseMetricResult" = proto.Field( + proto.MESSAGE, + number=27, + oneof="evaluation_results", + message="PointwiseMetricResult", + ) + pairwise_metric_result: "PairwiseMetricResult" = proto.Field( + proto.MESSAGE, + number=28, + oneof="evaluation_results", + message="PairwiseMetricResult", + ) tool_call_valid_results: "ToolCallValidResults" = proto.Field( proto.MESSAGE, number=18, @@ -2605,6 +2654,184 @@ class QuestionAnsweringCorrectnessResult(proto.Message): ) +class PointwiseMetricInput(proto.Message): + r"""Input for pointwise metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1beta1.types.PointwiseMetricSpec): + Required. 
Spec for pointwise metric. + instance (google.cloud.aiplatform_v1beta1.types.PointwiseMetricInstance): + Required. Pointwise metric instance. + """ + + metric_spec: "PointwiseMetricSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="PointwiseMetricSpec", + ) + instance: "PointwiseMetricInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="PointwiseMetricInstance", + ) + + +class PointwiseMetricInstance(proto.Message): + r"""Pointwise metric instance. Usually one instance corresponds + to one row in an evaluation dataset. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + json_instance (str): + Instance specified as a json string. String key-value pairs + are expected in the json_instance to render + PointwiseMetricSpec.instance_prompt_template. + + This field is a member of `oneof`_ ``instance``. + """ + + json_instance: str = proto.Field( + proto.STRING, + number=1, + oneof="instance", + ) + + +class PointwiseMetricSpec(proto.Message): + r"""Spec for pointwise metric. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + metric_prompt_template (str): + Required. Metric prompt template for + pointwise metric. + + This field is a member of `oneof`_ ``_metric_prompt_template``. + """ + + metric_prompt_template: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + + +class PointwiseMetricResult(proto.Message): + r"""Spec for pointwise metric result. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + score (float): + Output only. Pointwise metric score. + + This field is a member of `oneof`_ ``_score``. + explanation (str): + Output only. Explanation for pointwise metric + score. 
+ """ + + score: float = proto.Field( + proto.FLOAT, + number=1, + optional=True, + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + + +class PairwiseMetricInput(proto.Message): + r"""Input for pairwise metric. + + Attributes: + metric_spec (google.cloud.aiplatform_v1beta1.types.PairwiseMetricSpec): + Required. Spec for pairwise metric. + instance (google.cloud.aiplatform_v1beta1.types.PairwiseMetricInstance): + Required. Pairwise metric instance. + """ + + metric_spec: "PairwiseMetricSpec" = proto.Field( + proto.MESSAGE, + number=1, + message="PairwiseMetricSpec", + ) + instance: "PairwiseMetricInstance" = proto.Field( + proto.MESSAGE, + number=2, + message="PairwiseMetricInstance", + ) + + +class PairwiseMetricInstance(proto.Message): + r"""Pairwise metric instance. Usually one instance corresponds to + one row in an evaluation dataset. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + json_instance (str): + Instance specified as a json string. String key-value pairs + are expected in the json_instance to render + PairwiseMetricSpec.instance_prompt_template. + + This field is a member of `oneof`_ ``instance``. + """ + + json_instance: str = proto.Field( + proto.STRING, + number=1, + oneof="instance", + ) + + +class PairwiseMetricSpec(proto.Message): + r"""Spec for pairwise metric. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + metric_prompt_template (str): + Required. Metric prompt template for pairwise + metric. + + This field is a member of `oneof`_ ``_metric_prompt_template``. + """ + + metric_prompt_template: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + + +class PairwiseMetricResult(proto.Message): + r"""Spec for pairwise metric result. + + Attributes: + pairwise_choice (google.cloud.aiplatform_v1beta1.types.PairwiseChoice): + Output only. 
Pairwise metric choice. + explanation (str): + Output only. Explanation for pairwise metric + score. + """ + + pairwise_choice: "PairwiseChoice" = proto.Field( + proto.ENUM, + number=1, + enum="PairwiseChoice", + ) + explanation: str = proto.Field( + proto.STRING, + number=2, + ) + + class ToolCallValidInput(proto.Message): r"""Input for tool call valid metric. diff --git a/google/cloud/aiplatform_v1beta1/types/machine_resources.py b/google/cloud/aiplatform_v1beta1/types/machine_resources.py index 2012c50cff..ed91580b2f 100644 --- a/google/cloud/aiplatform_v1beta1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1beta1/types/machine_resources.py @@ -22,6 +22,9 @@ from google.cloud.aiplatform_v1beta1.types import ( accelerator_type as gca_accelerator_type, ) +from google.cloud.aiplatform_v1beta1.types import ( + reservation_affinity as gca_reservation_affinity, +) __protobuf__ = proto.module( @@ -73,6 +76,10 @@ class MachineSpec(proto.Message): Immutable. The topology of the TPUs. Corresponds to the TPU topologies available from GKE. (Example: tpu_topology: "2x2x1"). + reservation_affinity (google.cloud.aiplatform_v1beta1.types.ReservationAffinity): + Optional. Immutable. Configuration + controlling how this resource pool consumes + reservation. """ machine_type: str = proto.Field( @@ -92,6 +99,11 @@ class MachineSpec(proto.Message): proto.STRING, number=4, ) + reservation_affinity: gca_reservation_affinity.ReservationAffinity = proto.Field( + proto.MESSAGE, + number=5, + message=gca_reservation_affinity.ReservationAffinity, + ) class DedicatedResources(proto.Message): @@ -159,6 +171,9 @@ class DedicatedResources(proto.Message): and [autoscaling_metric_specs.target][google.cloud.aiplatform.v1beta1.AutoscalingMetricSpec.target] to ``80``. + spot (bool): + Optional. If true, schedule the deployment workload on `spot + VMs `__. 
""" machine_spec: "MachineSpec" = proto.Field( @@ -181,6 +196,10 @@ class DedicatedResources(proto.Message): number=4, message="AutoscalingMetricSpec", ) + spot: bool = proto.Field( + proto.BOOL, + number=5, + ) class AutomaticResources(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/types/prediction_service.py b/google/cloud/aiplatform_v1beta1/types/prediction_service.py index 75b0687a0f..bdf11698a4 100644 --- a/google/cloud/aiplatform_v1beta1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1beta1/types/prediction_service.py @@ -982,7 +982,9 @@ class UsageMetadata(proto.Message): Attributes: prompt_token_count (int): - Number of tokens in the request. + Number of tokens in the request. When ``cached_content`` is + set, this is still the total effective prompt size meaning + this includes the number of tokens in the cached content. candidates_token_count (int): Number of tokens in the response(s). total_token_count (int): diff --git a/google/cloud/aiplatform_v1beta1/types/reservation_affinity.py b/google/cloud/aiplatform_v1beta1/types/reservation_affinity.py new file mode 100644 index 0000000000..e84d92da74 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/reservation_affinity.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", + manifest={ + "ReservationAffinity", + }, +) + + +class ReservationAffinity(proto.Message): + r"""A ReservationAffinity can be used to configure a Vertex AI + resource (e.g., a DeployedModel) to draw its Compute Engine + resources from a Shared Reservation, or exclusively from + on-demand capacity. + + Attributes: + reservation_affinity_type (google.cloud.aiplatform_v1beta1.types.ReservationAffinity.Type): + Required. Specifies the reservation affinity + type. + key (str): + Optional. Corresponds to the label key of a reservation + resource. To target a SPECIFIC_RESERVATION by name, use + ``compute.googleapis.com/reservation-name`` as the key and + specify the name of your reservation as its value. + values (MutableSequence[str]): + Optional. Corresponds to the label values of + a reservation resource. This must be the full + resource name of the reservation. + """ + + class Type(proto.Enum): + r"""Identifies a type of reservation affinity. + + Values: + TYPE_UNSPECIFIED (0): + Default value. This should not be used. + NO_RESERVATION (1): + Do not consume from any reserved capacity, + only use on-demand. + ANY_RESERVATION (2): + Consume any reservation available, falling + back to on-demand. + SPECIFIC_RESERVATION (3): + Consume from a specific reservation. When chosen, the + reservation must be identified via the ``key`` and + ``values`` fields. 
+ """ + TYPE_UNSPECIFIED = 0 + NO_RESERVATION = 1 + ANY_RESERVATION = 2 + SPECIFIC_RESERVATION = 3 + + reservation_affinity_type: Type = proto.Field( + proto.ENUM, + number=1, + enum=Type, + ) + key: str = proto.Field( + proto.STRING, + number=2, + ) + values: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=3, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owlbot.py b/owlbot.py index 9c43cccb31..ea1a2b9795 100644 --- a/owlbot.py +++ b/owlbot.py @@ -113,6 +113,8 @@ ".kokoro/presubmit/prerelease-deps.cfg", ".kokoro/docs/docs-presubmit.cfg", ".kokoro/release.sh", + ".kokoro/release/common.cfg", + ".kokoro/requirements*", # exclude sample configs so periodic samples are tested against main # instead of pypi ".kokoro/samples/python3.7/common.cfg", diff --git a/pypi/_vertex_ai_placeholder/version.py b/pypi/_vertex_ai_placeholder/version.py index 1f5f4512a0..30a38aa3a2 100644 --- a/pypi/_vertex_ai_placeholder/version.py +++ b/pypi/_vertex_ai_placeholder/version.py @@ -15,4 +15,4 @@ # limitations under the License. # -__version__ = "1.61.0" +__version__ = "1.62.0" diff --git a/samples/generated_samples/aiplatform_v1_generated_evaluation_service_evaluate_instances_async.py b/samples/generated_samples/aiplatform_v1_generated_evaluation_service_evaluate_instances_async.py new file mode 100644 index 0000000000..8a52fa6e9a --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_evaluation_service_evaluate_instances_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for EvaluateInstances +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_EvaluationService_EvaluateInstances_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +async def sample_evaluate_instances(): + # Create a client + client = aiplatform_v1.EvaluationServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.EvaluateInstancesRequest( + location="location_value", + ) + + # Make the request + response = await client.evaluate_instances(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_EvaluationService_EvaluateInstances_async] diff --git a/samples/generated_samples/aiplatform_v1_generated_evaluation_service_evaluate_instances_sync.py b/samples/generated_samples/aiplatform_v1_generated_evaluation_service_evaluate_instances_sync.py new file mode 100644 index 0000000000..763c5d078f --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_evaluation_service_evaluate_instances_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for EvaluateInstances +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. 
+ +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_EvaluationService_EvaluateInstances_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +def sample_evaluate_instances(): + # Create a client + client = aiplatform_v1.EvaluationServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.EvaluateInstancesRequest( + location="location_value", + ) + + # Make the request + response = client.evaluate_instances(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_EvaluationService_EvaluateInstances_sync] diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index 1e023bf576..0bc02249aa 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.61.0" + "version": "1.62.0" }, "snippets": [ { @@ -5460,6 +5460,159 @@ ], "title": "aiplatform_v1_generated_endpoint_service_update_endpoint_sync.py" }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1.EvaluationServiceAsyncClient", + "shortName": "EvaluationServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1.EvaluationServiceAsyncClient.evaluate_instances", + "method": { + "fullName": 
"google.cloud.aiplatform.v1.EvaluationService.EvaluateInstances", + "service": { + "fullName": "google.cloud.aiplatform.v1.EvaluationService", + "shortName": "EvaluationService" + }, + "shortName": "EvaluateInstances" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.EvaluateInstancesRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1.types.EvaluateInstancesResponse", + "shortName": "evaluate_instances" + }, + "description": "Sample for EvaluateInstances", + "file": "aiplatform_v1_generated_evaluation_service_evaluate_instances_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_EvaluationService_EvaluateInstances_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_evaluation_service_evaluate_instances_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1.EvaluationServiceClient", + "shortName": "EvaluationServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1.EvaluationServiceClient.evaluate_instances", + "method": { + "fullName": "google.cloud.aiplatform.v1.EvaluationService.EvaluateInstances", + "service": { + "fullName": "google.cloud.aiplatform.v1.EvaluationService", + "shortName": "EvaluationService" + }, + "shortName": "EvaluateInstances" + }, + "parameters": [ + { + "name": "request", + "type": 
"google.cloud.aiplatform_v1.types.EvaluateInstancesRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.cloud.aiplatform_v1.types.EvaluateInstancesResponse", + "shortName": "evaluate_instances" + }, + "description": "Sample for EvaluateInstances", + "file": "aiplatform_v1_generated_evaluation_service_evaluate_instances_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_EvaluationService_EvaluateInstances_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_evaluation_service_evaluate_instances_sync.py" + }, { "canonical": true, "clientMethod": { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index 506e06f4c3..0cec733527 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.61.0" + "version": "1.62.0" }, "snippets": [ { diff --git a/sdk_schema_tests/__init__.py b/sdk_schema_tests/__init__.py new file mode 100644 index 0000000000..1ee14b7820 --- /dev/null +++ b/sdk_schema_tests/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not 
use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/sdk_schema_tests/common_contract.py b/sdk_schema_tests/common_contract.py new file mode 100644 index 0000000000..9626372cac --- /dev/null +++ b/sdk_schema_tests/common_contract.py @@ -0,0 +1,24 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +expected_generate_content_common_arg_keys = ( + "self", + "contents", + "generation_config", + "safety_settings", + "tools", + "tool_config", + "stream", +) diff --git a/sdk_schema_tests/google_ai_only_contract.py b/sdk_schema_tests/google_ai_only_contract.py new file mode 100644 index 0000000000..26e5797583 --- /dev/null +++ b/sdk_schema_tests/google_ai_only_contract.py @@ -0,0 +1,16 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +expected_generate_content_return_annotation = "generation_types.GenerateContentResponse" diff --git a/sdk_schema_tests/google_ai_tests/__init__.py b/sdk_schema_tests/google_ai_tests/__init__.py new file mode 100644 index 0000000000..1ee14b7820 --- /dev/null +++ b/sdk_schema_tests/google_ai_tests/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/sdk_schema_tests/google_ai_tests/method_signature_tests.py b/sdk_schema_tests/google_ai_tests/method_signature_tests.py new file mode 100644 index 0000000000..66890e4e95 --- /dev/null +++ b/sdk_schema_tests/google_ai_tests/method_signature_tests.py @@ -0,0 +1,50 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from inspect import signature +import unittest + +from google.generativeai import GenerativeModel +from sdk_schema_tests import common_contract +from sdk_schema_tests import google_ai_only_contract as specific_contract + +_SDK_NAME = "GoogleAI" + + +class TestGenerativeModelMethodSignatures(unittest.TestCase): + def test_generate_content_signature(self): + generate_content_signature = signature(GenerativeModel.generate_content) + actual_method_arg_keys = generate_content_signature.parameters.keys() + actual_return_annotation = generate_content_signature.return_annotation + + for expected_key in common_contract.expected_generate_content_common_arg_keys: + self.assertIn( + member=expected_key, + container=actual_method_arg_keys, + msg=( + f"[{_SDK_NAME}]: expected common key {expected_key} not found in" + f" actual keys: {actual_method_arg_keys}" + ), + ) + + self.assertEqual( + actual_return_annotation, + specific_contract.expected_generate_content_return_annotation, + msg=( + f"[{_SDK_NAME}]: expected return annotation" + f" {specific_contract.expected_generate_content_return_annotation}" + f" not equal to actual return annotation {actual_return_annotation}" + ), + ) diff --git a/sdk_schema_tests/vertex_ai_only_contract.py b/sdk_schema_tests/vertex_ai_only_contract.py new file mode 100644 index 0000000000..9825ee0564 --- /dev/null +++ b/sdk_schema_tests/vertex_ai_only_contract.py @@ -0,0 +1,21 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ForwardRef, Iterable, Union + +expected_generate_content_return_annotation = Union[ + ForwardRef("GenerationResponse"), # noqa: F821 + Iterable[ForwardRef("GenerationResponse")], # noqa: F821 +] diff --git a/sdk_schema_tests/vertex_ai_tests/__init__.py b/sdk_schema_tests/vertex_ai_tests/__init__.py new file mode 100644 index 0000000000..1ee14b7820 --- /dev/null +++ b/sdk_schema_tests/vertex_ai_tests/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/sdk_schema_tests/vertex_ai_tests/method_signature_tests.py b/sdk_schema_tests/vertex_ai_tests/method_signature_tests.py new file mode 100644 index 0000000000..93901139a6 --- /dev/null +++ b/sdk_schema_tests/vertex_ai_tests/method_signature_tests.py @@ -0,0 +1,38 @@ +"""Tests for method_signature.""" + +import inspect +import unittest + +from vertexai.generative_models import GenerativeModel + +from sdk_schema_tests import common_contract +from sdk_schema_tests import vertex_ai_only_contract as specific_contract + +_SDK_NAME = "VertexAI" + + +class TestGenerativeModelMethodSignatures(unittest.TestCase): + def test_generate_content_signature(self): + generate_content_signature = inspect.signature(GenerativeModel.generate_content) + actual_method_arg_keys = generate_content_signature.parameters.keys() + actual_return_annotation = generate_content_signature.return_annotation + + for expected_key in common_contract.expected_generate_content_common_arg_keys: + self.assertIn( + member=expected_key, + container=actual_method_arg_keys, + msg=( + f"[{_SDK_NAME}]: expected common key {expected_key} not found in" + f" actual keys: {actual_method_arg_keys}" + ), + ) + + self.assertEqual( + actual_return_annotation, + specific_contract.expected_generate_content_return_annotation, + msg=( + f"[{_SDK_NAME}]: expected return annotation" + f" {specific_contract.expected_generate_content_return_annotation}" + f" not equal to actual return annotation {actual_return_annotation}" + ), + ) diff --git a/setup.py b/setup.py index d870daa42c..e00f00f412 100644 --- a/setup.py +++ b/setup.py @@ -160,6 +160,7 @@ "langchain-google-vertexai < 2", "openinference-instrumentation-langchain >= 0.1.19, < 0.2", "tenacity <= 8.3", + "orjson <= 3.10.6", ] langchain_testing_extra_require = list( diff --git a/tests/system/vertexai/test_generative_models.py b/tests/system/vertexai/test_generative_models.py index ea26151ccf..6395cf557e 100644 --- 
a/tests/system/vertexai/test_generative_models.py +++ b/tests/system/vertexai/test_generative_models.py @@ -18,6 +18,7 @@ """System tests for generative models.""" import json +import os import pytest # Google imports @@ -36,6 +37,9 @@ GEMINI_15_MODEL_NAME = "gemini-1.5-pro-preview-0409" GEMINI_15_PRO_MODEL_NAME = "gemini-1.5-pro-001" +STAGING_API_ENDPOINT = os.getenv("STAGING_ENDPOINT") +PROD_API_ENDPOINT = None + # A dummy function for function calling def get_current_weather(location: str, unit: str = "centigrade"): @@ -84,12 +88,14 @@ def get_current_weather(location: str, unit: str = "centigrade"): } +@pytest.mark.parametrize("api_endpoint", [STAGING_API_ENDPOINT, PROD_API_ENDPOINT]) class TestGenerativeModels(e2e_base.TestEndToEnd): """System tests for generative models.""" _temp_prefix = "temp_generative_models_test_" - def setup_method(self): + @pytest.fixture(scope="function", autouse=True) + def setup_method(self, api_endpoint): super().setup_method() credentials, _ = auth.default( scopes=["https://www.googleapis.com/auth/cloud-platform"] @@ -98,9 +104,10 @@ def setup_method(self): project=e2e_base._PROJECT, location=e2e_base._LOCATION, credentials=credentials, + api_endpoint=api_endpoint, ) - def test_generate_content_with_cached_content_from_text(self): + def test_generate_content_with_cached_content_from_text(self, api_endpoint): cached_content = caching.CachedContent.create( model_name=GEMINI_15_PRO_MODEL_NAME, system_instruction="Please answer all the questions like a pirate.", @@ -138,7 +145,7 @@ def test_generate_content_with_cached_content_from_text(self): finally: cached_content.delete() - def test_generate_content_from_text(self): + def test_generate_content_from_text(self, api_endpoint): model = generative_models.GenerativeModel(GEMINI_MODEL_NAME) response = model.generate_content( "Why is sky blue?", @@ -147,7 +154,7 @@ def test_generate_content_from_text(self): assert response.text @pytest.mark.asyncio - async def 
test_generate_content_async(self): + async def test_generate_content_async(self, api_endpoint): model = generative_models.GenerativeModel(GEMINI_MODEL_NAME) response = await model.generate_content_async( "Why is sky blue?", @@ -155,7 +162,7 @@ async def test_generate_content_async(self): ) assert response.text - def test_generate_content_streaming(self): + def test_generate_content_streaming(self, api_endpoint): model = generative_models.GenerativeModel(GEMINI_MODEL_NAME) stream = model.generate_content( "Why is sky blue?", @@ -170,7 +177,7 @@ def test_generate_content_streaming(self): ) @pytest.mark.asyncio - async def test_generate_content_streaming_async(self): + async def test_generate_content_streaming_async(self, api_endpoint): model = generative_models.GenerativeModel(GEMINI_MODEL_NAME) async_stream = await model.generate_content_async( "Why is sky blue?", @@ -184,7 +191,7 @@ async def test_generate_content_streaming_async(self): is generative_models.FinishReason.STOP ) - def test_generate_content_with_parameters(self): + def test_generate_content_with_parameters(self, api_endpoint): model = generative_models.GenerativeModel( GEMINI_MODEL_NAME, system_instruction=[ @@ -211,7 +218,7 @@ def test_generate_content_with_parameters(self): ) assert response.text - def test_generate_content_with_gemini_15_parameters(self): + def test_generate_content_with_gemini_15_parameters(self, api_endpoint): model = generative_models.GenerativeModel(GEMINI_15_MODEL_NAME) response = model.generate_content( contents="Why is sky blue? 
Respond in JSON Format.", @@ -237,7 +244,7 @@ def test_generate_content_with_gemini_15_parameters(self): assert response.text assert json.loads(response.text) - def test_generate_content_from_list_of_content_dict(self): + def test_generate_content_from_list_of_content_dict(self, api_endpoint): model = generative_models.GenerativeModel(GEMINI_MODEL_NAME) response = model.generate_content( contents=[{"role": "user", "parts": [{"text": "Why is sky blue?"}]}], @@ -248,7 +255,7 @@ def test_generate_content_from_list_of_content_dict(self): @pytest.mark.skip( reason="Breaking change in the gemini-pro-vision model. See b/315803556#comment3" ) - def test_generate_content_from_remote_image(self): + def test_generate_content_from_remote_image(self, api_endpoint): vision_model = generative_models.GenerativeModel(GEMINI_VISION_MODEL_NAME) image_part = generative_models.Part.from_uri( uri="gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg", @@ -261,7 +268,7 @@ def test_generate_content_from_remote_image(self): assert response.text assert "cat" in response.text - def test_generate_content_from_text_and_remote_image(self): + def test_generate_content_from_text_and_remote_image(self, api_endpoint): vision_model = generative_models.GenerativeModel(GEMINI_VISION_MODEL_NAME) image_part = generative_models.Part.from_uri( uri="gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg", @@ -274,7 +281,7 @@ def test_generate_content_from_text_and_remote_image(self): assert response.text assert "cat" in response.text - def test_generate_content_from_text_and_remote_video(self): + def test_generate_content_from_text_and_remote_video(self, api_endpoint): vision_model = generative_models.GenerativeModel(GEMINI_VISION_MODEL_NAME) video_part = generative_models.Part.from_uri( uri="gs://cloud-samples-data/video/animals.mp4", @@ -287,13 +294,11 @@ def test_generate_content_from_text_and_remote_video(self): assert response.text assert "Zootopia" in 
response.text - def test_grounding_google_search_retriever(self): + def test_grounding_google_search_retriever(self, api_endpoint): model = preview_generative_models.GenerativeModel(GEMINI_MODEL_NAME) google_search_retriever_tool = ( preview_generative_models.Tool.from_google_search_retrieval( - preview_generative_models.grounding.GoogleSearchRetrieval( - disable_attribution=False - ) + preview_generative_models.grounding.GoogleSearchRetrieval() ) ) response = model.generate_content( @@ -309,7 +314,7 @@ def test_grounding_google_search_retriever(self): # Chat - def test_send_message_from_text(self): + def test_send_message_from_text(self, api_endpoint): model = generative_models.GenerativeModel(GEMINI_MODEL_NAME) chat = model.start_chat() response1 = chat.send_message( @@ -326,7 +331,7 @@ def test_send_message_from_text(self): assert response2.text assert len(chat.history) == 4 - def test_chat_function_calling(self): + def test_chat_function_calling(self, api_endpoint): get_current_weather_func = generative_models.FunctionDeclaration( name="get_current_weather", description="Get the current weather in a given location", @@ -360,7 +365,7 @@ def test_chat_function_calling(self): ) assert response2.text - def test_generate_content_function_calling(self): + def test_generate_content_function_calling(self, api_endpoint): get_current_weather_func = generative_models.FunctionDeclaration( name="get_current_weather", description="Get the current weather in a given location", @@ -440,7 +445,7 @@ def test_generate_content_function_calling(self): assert summary - def test_chat_automatic_function_calling(self): + def test_chat_automatic_function_calling(self, api_endpoint): get_current_weather_func = generative_models.FunctionDeclaration.from_func( get_current_weather ) @@ -471,7 +476,7 @@ def test_chat_automatic_function_calling(self): assert chat.history[-2].parts[0].function_response assert chat.history[-2].parts[0].function_response.name == "get_current_weather" - def 
test_additional_request_metadata(self): + def test_additional_request_metadata(self, api_endpoint): aiplatform.init(request_metadata=[("foo", "bar")]) model = generative_models.GenerativeModel(GEMINI_MODEL_NAME) response = model.generate_content( @@ -480,7 +485,7 @@ def test_additional_request_metadata(self): ) assert response - def test_compute_tokens_from_text(self): + def test_compute_tokens_from_text(self, api_endpoint): model = generative_models.GenerativeModel(GEMINI_MODEL_NAME) response = model.compute_tokens(["Why is sky blue?", "Explain it like I'm 5."]) assert len(response.tokens_info) == 2 diff --git a/tests/unit/aiplatform/constants.py b/tests/unit/aiplatform/constants.py index 9c9758c2df..cf85d9827d 100644 --- a/tests/unit/aiplatform/constants.py +++ b/tests/unit/aiplatform/constants.py @@ -150,6 +150,25 @@ class TrainingJobConstants: }, } ] + _TEST_RESERVATION_AFFINITY_WORKER_POOL_SPEC = [ + { + "machine_spec": { + "machine_type": "n1-standard-4", + "accelerator_type": "NVIDIA_TESLA_K80", + "accelerator_count": 1, + "reservation_affinity": { + "reservation_affinity_type": "ANY_RESERVATION" + }, + }, + "replica_count": 1, + "disk_spec": {"boot_disk_type": "pd-ssd", "boot_disk_size_gb": 100}, + "container_spec": { + "image_uri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": [], + "args": _TEST_RUN_ARGS, + }, + } + ] _TEST_ID = "1028944691210842416" _TEST_NETWORK = ( f"projects/{ProjectConstants._TEST_PROJECT}/global/networks/{_TEST_ID}" @@ -197,6 +216,7 @@ class TrainingJobConstants: "projects/my-project/locations/us-central1/trainingPipelines/12345" ) _TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default" + _TEST_SPOT_STRATEGY = custom_job.Scheduling.Strategy.SPOT def create_tpu_job_proto(tpu_version): worker_pool_spec = ( diff --git a/tests/unit/aiplatform/test_custom_job.py b/tests/unit/aiplatform/test_custom_job.py index 46b9ca3fa0..19762a5059 100644 --- a/tests/unit/aiplatform/test_custom_job.py +++ b/tests/unit/aiplatform/test_custom_job.py @@ -62,6 
+62,7 @@ test_constants.TrainingJobConstants._TEST_TRAINING_CONTAINER_IMAGE ) _TEST_PREBUILT_CONTAINER_IMAGE = "gcr.io/cloud-aiplatform/container:image" +_TEST_SPOT_STRATEGY = test_constants.TrainingJobConstants._TEST_SPOT_STRATEGY _TEST_RUN_ARGS = test_constants.TrainingJobConstants._TEST_RUN_ARGS _TEST_EXPERIMENT = "test-experiment" @@ -226,6 +227,12 @@ def _get_custom_tpu_job_proto(state=None, name=None, error=None, tpu_version=Non return custom_job_proto +def _get_custom_job_proto_with_spot_strategy(state=None, name=None, error=None): + custom_job_proto = _get_custom_job_proto(state=state, name=name, error=error) + custom_job_proto.job_spec.scheduling.strategy = _TEST_SPOT_STRATEGY + return custom_job_proto + + @pytest.fixture def mock_builtin_open(): with patch("builtins.open", mock_open(read_data="data")) as mock_file: @@ -396,6 +403,28 @@ def get_custom_job_mock_with_enable_web_access_succeeded(): yield get_custom_job_mock +@pytest.fixture +def get_custom_job_mock_with_spot_strategy(): + with patch.object( + job_service_client.JobServiceClient, "get_custom_job" + ) as get_custom_job_mock: + get_custom_job_mock.side_effect = [ + _get_custom_job_proto_with_spot_strategy( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_PENDING, + ), + _get_custom_job_proto_with_spot_strategy( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_RUNNING, + ), + _get_custom_job_proto_with_spot_strategy( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED, + ), + ] + yield get_custom_job_mock + + @pytest.fixture def create_custom_job_mock(): with mock.patch.object( @@ -445,6 +474,18 @@ def create_custom_job_mock_fail(): yield create_custom_job_mock +@pytest.fixture +def create_custom_job_mock_with_spot_strategy(): + with mock.patch.object( + job_service_client.JobServiceClient, "create_custom_job" + ) as create_custom_job_mock: + create_custom_job_mock.return_value = 
_get_custom_job_proto_with_spot_strategy( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_PENDING, + ) + yield create_custom_job_mock + + _EXPERIMENT_MOCK = copy.deepcopy(_EXPERIMENT_MOCK) _EXPERIMENT_MOCK.metadata[ constants._BACKING_TENSORBOARD_RESOURCE_KEY @@ -1433,3 +1474,52 @@ def test_create_custom_job_tpu_v3( assert ( job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED ) + + def test_create_custom_job_with_spot_strategy( + self, + create_custom_job_mock_with_spot_strategy, + get_custom_job_mock_with_spot_strategy, + ): + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = aiplatform.CustomJob( + display_name=_TEST_DISPLAY_NAME, + worker_pool_specs=_TEST_WORKER_POOL_SPEC, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + labels=_TEST_LABELS, + ) + + job.run( + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + timeout=_TEST_TIMEOUT, + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, + scheduling_strategy=_TEST_SPOT_STRATEGY, + ) + + job.wait_for_resource_creation() + + job.wait() + + assert job.resource_name == _TEST_CUSTOM_JOB_NAME + + expected_custom_job = _get_custom_job_proto_with_spot_strategy() + + create_custom_job_mock_with_spot_strategy.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + assert job.job_spec == expected_custom_job.job_spec + assert ( + job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED + ) diff --git a/tests/unit/aiplatform/test_hyperparameter_tuning_job.py b/tests/unit/aiplatform/test_hyperparameter_tuning_job.py index f49af1868f..e78977e8e4 100644 --- a/tests/unit/aiplatform/test_hyperparameter_tuning_job.py +++ b/tests/unit/aiplatform/test_hyperparameter_tuning_job.py @@ 
-202,6 +202,22 @@ def _get_hyperparameter_tuning_job_proto_with_enable_web_access( return hyperparameter_tuning_job_proto +def _get_hyperparameter_tuning_job_proto_with_spot_strategy( + state=None, name=None, error=None, trials=[] +): + hyperparameter_tuning_job_proto = _get_hyperparameter_tuning_job_proto( + state=state, + name=name, + error=error, + ) + hyperparameter_tuning_job_proto.trial_job_spec.scheduling.strategy = ( + test_constants.TrainingJobConstants._TEST_SPOT_STRATEGY + ) + if state == gca_job_state_compat.JobState.JOB_STATE_RUNNING: + hyperparameter_tuning_job_proto.trials = trials + return hyperparameter_tuning_job_proto + + @pytest.fixture def get_hyperparameter_tuning_job_mock(): with patch.object( @@ -331,6 +347,28 @@ def get_hyperparameter_tuning_job_mock_with_fail(): yield get_hyperparameter_tuning_job_mock +@pytest.fixture +def get_hyperparameter_tuning_job_mock_with_spot_strategy(): + with patch.object( + job_service_client.JobServiceClient, "get_hyperparameter_tuning_job" + ) as get_hyperparameter_tuning_job_mock: + get_hyperparameter_tuning_job_mock.side_effect = [ + _get_hyperparameter_tuning_job_proto_with_spot_strategy( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_PENDING, + ), + _get_hyperparameter_tuning_job_proto_with_spot_strategy( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_RUNNING, + ), + _get_hyperparameter_tuning_job_proto_with_spot_strategy( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED, + ), + ] + yield get_hyperparameter_tuning_job_mock + + @pytest.fixture def create_hyperparameter_tuning_job_mock(): with mock.patch.object( @@ -386,6 +424,20 @@ def create_hyperparameter_tuning_job_mock_with_tensorboard(): yield create_hyperparameter_tuning_job_mock +@pytest.fixture +def create_hyperparameter_tuning_job_mock_with_spot_strategy(): + with mock.patch.object( + 
job_service_client.JobServiceClient, "create_hyperparameter_tuning_job" + ) as create_hyperparameter_tuning_job_mock: + create_hyperparameter_tuning_job_mock.return_value = ( + _get_hyperparameter_tuning_job_proto_with_spot_strategy( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_PENDING, + ) + ) + yield create_hyperparameter_tuning_job_mock + + @pytest.mark.usefixtures("google_auth_mock") class TestHyperparameterTuningJob: def setup_method(self): @@ -908,3 +960,71 @@ def test_log_enable_web_access_after_get_hyperparameter_tuning_job( assert hp_job._logged_web_access_uris == set( test_constants.TrainingJobConstants._TEST_WEB_ACCESS_URIS.values() ) + + def test_create_hyperparameter_tuning_job_with_spot_strategy( + self, + create_hyperparameter_tuning_job_mock_with_spot_strategy, + get_hyperparameter_tuning_job_mock_with_spot_strategy, + ): + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + custom_job = aiplatform.CustomJob( + display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME, + worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC, + base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR, + ) + + job = aiplatform.HyperparameterTuningJob( + display_name=_TEST_DISPLAY_NAME, + custom_job=custom_job, + metric_spec={_TEST_METRIC_SPEC_KEY: _TEST_METRIC_SPEC_VALUE}, + parameter_spec={ + "lr": hpt.DoubleParameterSpec(min=0.001, max=0.1, scale="log"), + "units": hpt.IntegerParameterSpec(min=4, max=1028, scale="linear"), + "activation": hpt.CategoricalParameterSpec( + values=["relu", "sigmoid", "elu", "selu", "tanh"] + ), + "batch_size": hpt.DiscreteParameterSpec( + values=[4, 8, 16, 32, 64], + scale="linear", + conditional_parameter_spec={ + "decay": _TEST_CONDITIONAL_PARAMETER_DECAY, + "learning_rate": _TEST_CONDITIONAL_PARAMETER_LR, + }, + 
), + }, + parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT, + max_trial_count=_TEST_MAX_TRIAL_COUNT, + max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT, + search_algorithm=_TEST_SEARCH_ALGORITHM, + measurement_selection=_TEST_MEASUREMENT_SELECTION, + labels=_TEST_LABELS, + ) + + job.run( + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + timeout=_TEST_TIMEOUT, + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, + scheduling_strategy=test_constants.TrainingJobConstants._TEST_SPOT_STRATEGY, + ) + + job.wait() + + expected_hyperparameter_tuning_job = ( + _get_hyperparameter_tuning_job_proto_with_spot_strategy() + ) + + create_hyperparameter_tuning_job_mock_with_spot_strategy.assert_called_once_with( + parent=_TEST_PARENT, + hyperparameter_tuning_job=expected_hyperparameter_tuning_job, + timeout=None, + ) diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index 7e3fe8944b..3da90e6a41 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -67,6 +67,7 @@ job_state as gca_job_state, machine_resources as gca_machine_resources, machine_resources_v1beta1 as gca_machine_resources_v1beta1, + reservation_affinity_v1 as gca_reservation_affinity_v1, manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat, model as gca_model, model_evaluation as gca_model_evaluation, @@ -134,6 +135,13 @@ _TEST_STARTING_REPLICA_COUNT = 2 _TEST_MAX_REPLICA_COUNT = 12 +_TEST_SPOT = True +_TEST_RESERVATION_AFFINITY_TYPE = "SPECIFIC_RESERVATION" +_TEST_RESERVATION_AFFINITY_KEY = "compute.googleapis.com/reservation-name" +_TEST_RESERVATION_AFFINITY_VALUES = [ + "projects/fake-project-id/zones/fake-zone/reservations/fake-reservation-name" +] + _TEST_TPU_MACHINE_TYPE = "ct5lp-hightpu-4t" _TEST_TPU_TOPOLOGY = "2x2" @@ -2101,7 +2109,10 @@ def test_deploy_no_endpoint_dedicated_resources(self, 
deploy_model_mock, sync): accelerator_count=_TEST_ACCELERATOR_COUNT, ) expected_dedicated_resources = gca_machine_resources.DedicatedResources( - machine_spec=expected_machine_spec, min_replica_count=1, max_replica_count=1 + machine_spec=expected_machine_spec, + min_replica_count=1, + max_replica_count=1, + spot=False, ) expected_deployed_model = gca_endpoint.DeployedModel( dedicated_resources=expected_dedicated_resources, @@ -2141,7 +2152,10 @@ def test_deploy_no_endpoint_with_tpu_topology(self, deploy_model_mock, sync): tpu_topology=_TEST_TPU_TOPOLOGY, ) expected_dedicated_resources = gca_machine_resources.DedicatedResources( - machine_spec=expected_machine_spec, min_replica_count=1, max_replica_count=1 + machine_spec=expected_machine_spec, + min_replica_count=1, + max_replica_count=1, + spot=False, ) expected_deployed_model = gca_endpoint.DeployedModel( dedicated_resources=expected_dedicated_resources, @@ -2156,6 +2170,163 @@ def test_deploy_no_endpoint_with_tpu_topology(self, deploy_model_mock, sync): timeout=None, ) + @pytest.mark.usefixtures( + "get_endpoint_mock", "get_model_mock", "create_endpoint_mock" + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_no_endpoint_with_spot(self, deploy_model_mock, sync): + test_model = models.Model(_TEST_ID) + test_model._gca_resource.supported_deployment_resources_types.append( + aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ) + test_endpoint = test_model.deploy( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + service_account=_TEST_SERVICE_ACCOUNT, + sync=sync, + deploy_request_timeout=None, + spot=_TEST_SPOT, + ) + + if not sync: + test_endpoint.wait() + + expected_machine_spec = gca_machine_resources.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ) + expected_dedicated_resources = 
gca_machine_resources.DedicatedResources( + machine_spec=expected_machine_spec, + min_replica_count=1, + max_replica_count=1, + spot=True, + ) + expected_deployed_model = gca_endpoint.DeployedModel( + dedicated_resources=expected_dedicated_resources, + model=test_model.resource_name, + display_name=None, + service_account=_TEST_SERVICE_ACCOUNT, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=expected_deployed_model, + traffic_split={"0": 100}, + metadata=(), + timeout=None, + ) + + @pytest.mark.usefixtures( + "get_endpoint_mock", "get_model_mock", "create_endpoint_mock" + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_no_endpoint_with_specific_reservation_affinity( + self, deploy_model_mock, sync + ): + test_model = models.Model(_TEST_ID) + test_model._gca_resource.supported_deployment_resources_types.append( + aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ) + test_endpoint = test_model.deploy( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + service_account=_TEST_SERVICE_ACCOUNT, + sync=sync, + deploy_request_timeout=None, + reservation_affinity_type=_TEST_RESERVATION_AFFINITY_TYPE, + reservation_affinity_key=_TEST_RESERVATION_AFFINITY_KEY, + reservation_affinity_values=_TEST_RESERVATION_AFFINITY_VALUES, + ) + + if not sync: + test_endpoint.wait() + + expected_reservation_affinity = gca_reservation_affinity_v1.ReservationAffinity( + reservation_affinity_type=_TEST_RESERVATION_AFFINITY_TYPE, + key=_TEST_RESERVATION_AFFINITY_KEY, + values=_TEST_RESERVATION_AFFINITY_VALUES, + ) + expected_machine_spec = gca_machine_resources.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + reservation_affinity=expected_reservation_affinity, + ) + expected_dedicated_resources = gca_machine_resources.DedicatedResources( + 
machine_spec=expected_machine_spec, + min_replica_count=1, + max_replica_count=1, + spot=False, + ) + expected_deployed_model = gca_endpoint.DeployedModel( + dedicated_resources=expected_dedicated_resources, + model=test_model.resource_name, + display_name=None, + service_account=_TEST_SERVICE_ACCOUNT, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=expected_deployed_model, + traffic_split={"0": 100}, + metadata=(), + timeout=None, + ) + + @pytest.mark.usefixtures( + "get_endpoint_mock", "get_model_mock", "create_endpoint_mock" + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_no_endpoint_with_any_reservation_affinity( + self, deploy_model_mock, sync + ): + test_model = models.Model(_TEST_ID) + test_model._gca_resource.supported_deployment_resources_types.append( + aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ) + test_endpoint = test_model.deploy( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + service_account=_TEST_SERVICE_ACCOUNT, + sync=sync, + deploy_request_timeout=None, + reservation_affinity_type="ANY_RESERVATION", + ) + + if not sync: + test_endpoint.wait() + + expected_reservation_affinity = gca_reservation_affinity_v1.ReservationAffinity( + reservation_affinity_type="ANY_RESERVATION", + ) + expected_machine_spec = gca_machine_resources.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + reservation_affinity=expected_reservation_affinity, + ) + expected_dedicated_resources = gca_machine_resources.DedicatedResources( + machine_spec=expected_machine_spec, + min_replica_count=1, + max_replica_count=1, + spot=False, + ) + expected_deployed_model = gca_endpoint.DeployedModel( + dedicated_resources=expected_dedicated_resources, + model=test_model.resource_name, + display_name=None, + 
service_account=_TEST_SERVICE_ACCOUNT, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=expected_deployed_model, + traffic_split={"0": 100}, + metadata=(), + timeout=None, + ) + @pytest.mark.usefixtures( "get_endpoint_mock", "get_model_mock", "create_endpoint_mock" ) @@ -2184,7 +2355,10 @@ def test_deploy_no_endpoint_with_explanations(self, deploy_model_mock, sync): accelerator_count=_TEST_ACCELERATOR_COUNT, ) expected_dedicated_resources = gca_machine_resources.DedicatedResources( - machine_spec=expected_machine_spec, min_replica_count=1, max_replica_count=1 + machine_spec=expected_machine_spec, + min_replica_count=1, + max_replica_count=1, + spot=False, ) expected_deployed_model = gca_endpoint.DeployedModel( dedicated_resources=expected_dedicated_resources, diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index e7f7f9d15a..3ae935ca90 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -248,6 +248,7 @@ _TEST_PERSISTENT_RESOURCE_ID = ( test_constants.PersistentResourceConstants._TEST_PERSISTENT_RESOURCE_ID ) +_TEST_SPOT_STRATEGY = test_constants.TrainingJobConstants._TEST_SPOT_STRATEGY _TEST_BASE_CUSTOM_JOB_PROTO = gca_custom_job.CustomJob( job_spec=gca_custom_job.CustomJobSpec(), @@ -305,6 +306,15 @@ def _get_custom_job_proto_with_scheduling(state=None, name=None, version="v1"): return custom_job_proto +def _get_custom_job_proto_with_spot_strategy(state=None, name=None, version="v1"): + custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO) + custom_job_proto.name = name + custom_job_proto.state = state + + custom_job_proto.job_spec.scheduling.strategy = _TEST_SPOT_STRATEGY + return custom_job_proto + + def local_copy_method(path): shutil.copy(path, ".") return pathlib.Path(path).name @@ -727,9 +737,11 @@ def make_training_pipeline(state, add_training_task_metadata=True): state=state, 
model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), training_task_inputs={"tensorboard": _TEST_TENSORBOARD_RESOURCE_NAME}, - training_task_metadata={"backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME} - if add_training_task_metadata - else None, + training_task_metadata=( + {"backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME} + if add_training_task_metadata + else None + ), ) @@ -741,9 +753,11 @@ def make_training_pipeline_with_version(state, add_training_task_metadata=True): name=_TEST_MODEL_NAME, version_id=_TEST_MODEL_VERSION_ID ), training_task_inputs={"tensorboard": _TEST_TENSORBOARD_RESOURCE_NAME}, - training_task_metadata={"backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME} - if add_training_task_metadata - else None, + training_task_metadata=( + {"backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME} + if add_training_task_metadata + else None + ), ) @@ -810,6 +824,21 @@ def make_training_pipeline_with_scheduling(state): return training_pipeline +def make_training_pipeline_with_spot_strategy(state): + training_pipeline = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=state, + training_task_inputs={ + "scheduling_strategy": _TEST_SPOT_STRATEGY, + }, + ) + if state == gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING: + training_pipeline.training_task_metadata = { + "backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME + } + return training_pipeline + + @pytest.fixture def mock_pipeline_service_get(make_call=make_training_pipeline): with mock.patch.object( @@ -952,6 +981,35 @@ def mock_pipeline_service_get_with_scheduling(): yield mock_get_training_pipeline +@pytest.fixture +def mock_pipeline_service_get_with_spot_strategy(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.side_effect = [ + make_training_pipeline_with_spot_strategy( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_PENDING, + ), + 
make_training_pipeline_with_spot_strategy( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ), + make_training_pipeline_with_spot_strategy( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ), + make_training_pipeline_with_spot_strategy( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ), + make_training_pipeline_with_spot_strategy( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + ), + make_training_pipeline_with_spot_strategy( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + ), + ] + + yield mock_get_training_pipeline + + @pytest.fixture def mock_pipeline_service_cancel(): with mock.patch.object( @@ -1026,6 +1084,19 @@ def mock_pipeline_service_create_with_scheduling(): yield mock_create_training_pipeline +@pytest.fixture +def mock_pipeline_service_create_with_spot_strategy(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = ( + make_training_pipeline_with_spot_strategy( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_PENDING, + ) + ) + yield mock_create_training_pipeline + + @pytest.fixture def mock_pipeline_service_get_with_no_model_to_upload(): with mock.patch.object( @@ -2388,6 +2459,58 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog): == _TEST_DISABLE_RETRIES ) + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) + @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_spot_strategy", + "mock_pipeline_service_get_with_spot_strategy", + "mock_python_package_to_gcs", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_spot_strategy(self, sync): + + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = 
training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + job.run( + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + timeout=_TEST_TIMEOUT, + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + sync=sync, + create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, + ) + + if not sync: + job.wait() + + assert job._gca_resource == make_training_pipeline_with_spot_strategy( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + + assert ( + job._gca_resource.state + == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + + assert ( + job._gca_resource.training_task_inputs["scheduling_strategy"] + == _TEST_SPOT_STRATEGY + ) + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) @pytest.mark.usefixtures( @@ -3440,6 +3563,99 @@ def test_training_job_tpu_v3_pod(self, mock_pipeline_service_create): timeout=None, ) + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) + @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_python_package_to_gcs", + "mock_model_service_get", + ) + def test_training_job_reservation_affinity(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + labels=_TEST_LABELS, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + 
model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + job.run( + machine_type=_TEST_MACHINE_TYPE, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + reservation_affinity_type="ANY_RESERVATION", + ) + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_DISPLAY_NAME + "-model", + labels=_TEST_LABELS, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + version_aliases=["default"], + ) + + true_worker_pool_spec = { + "replica_count": _TEST_REPLICA_COUNT, + "machine_spec": { + "machine_type": _TEST_MACHINE_TYPE, + "reservation_affinity": { + "reservation_affinity_type": "ANY_RESERVATION" + }, + }, + "disk_spec": { + "boot_disk_type": _TEST_BOOT_DISK_TYPE_DEFAULT, + "boot_disk_size_gb": _TEST_BOOT_DISK_SIZE_GB_DEFAULT, + }, + "python_package_spec": { + "executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE, + "python_module": _TEST_MODULE_NAME, + "package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + }, + } + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + labels=_TEST_LABELS, + training_task_definition=schema.training_job.definition.custom_task, + 
training_task_inputs=json_format.ParseDict( + { + "worker_pool_specs": [true_worker_pool_spec], + "base_output_directory": { + "output_uri_prefix": _TEST_BASE_OUTPUT_DIR + }, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + timeout=None, + ) + @pytest.mark.usefixtures("google_auth_mock") class TestCustomContainerTrainingJob: @@ -5458,6 +5674,99 @@ def test_training_job_tpu_v3_pod(self, mock_pipeline_service_create): timeout=None, ) + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) + @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_python_package_to_gcs", + "mock_model_service_get", + ) + def test_training_job_reservation_affinity(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + job.run( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=32, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + reservation_affinity_type="ANY_RESERVATION", + ) + + true_container_spec = 
gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_DISPLAY_NAME + "-model", + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + version_aliases=["default"], + ) + + true_worker_pool_spec = { + "replica_count": _TEST_REPLICA_COUNT, + "machine_spec": { + "machine_type": _TEST_MACHINE_TYPE, + "accelerator_type": _TEST_ACCELERATOR_TYPE, + "accelerator_count": 32, + "reservation_affinity": { + "reservation_affinity_type": "ANY_RESERVATION" + }, + }, + "disk_spec": { + "boot_disk_type": _TEST_BOOT_DISK_TYPE_DEFAULT, + "boot_disk_size_gb": _TEST_BOOT_DISK_SIZE_GB_DEFAULT, + }, + "containerSpec": { + "imageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": _TEST_TRAINING_CONTAINER_CMD, + }, + } + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "worker_pool_specs": [true_worker_pool_spec], + "base_output_directory": { + "output_uri_prefix": _TEST_BASE_OUTPUT_DIR + }, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + timeout=None, + ) + class Test_WorkerPoolSpec: def test_machine_spec_return_spec_dict(self): @@ -5466,6 +5775,9 @@ def test_machine_spec_return_spec_dict(self): machine_type=_TEST_MACHINE_TYPE, 
accelerator_count=_TEST_ACCELERATOR_COUNT, accelerator_type=_TEST_ACCELERATOR_TYPE, + reservation_affinity_type="SPECIFIC_RESERVATION", + reservation_affinity_key="compute.googleapis.com/reservation-name", + reservation_affinity_values="projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}", ) true_spec_dict = { @@ -5473,6 +5785,11 @@ def test_machine_spec_return_spec_dict(self): "machine_type": _TEST_MACHINE_TYPE, "accelerator_type": _TEST_ACCELERATOR_TYPE, "accelerator_count": _TEST_ACCELERATOR_COUNT, + "reservation_affinity": { + "reservation_affinity_type": "SPECIFIC_RESERVATION", + "key": "compute.googleapis.com/reservation-name", + "values": "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}", + }, }, "replica_count": _TEST_REPLICA_COUNT, "disk_spec": { @@ -5568,18 +5885,23 @@ def test_machine_spec_returns_pool_spec(self): machine_type=_TEST_MACHINE_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, accelerator_type=_TEST_ACCELERATOR_TYPE, + reservation_affinity_type="ANY_RESERVATION", ), worker_spec=worker_spec_utils._WorkerPoolSpec( replica_count=10, machine_type=_TEST_MACHINE_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, accelerator_type=_TEST_ACCELERATOR_TYPE, + reservation_affinity_type="SPECIFIC_RESERVATION", + reservation_affinity_key="compute.googleapis.com/reservation-name", + reservation_affinity_values="projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}", ), server_spec=worker_spec_utils._WorkerPoolSpec( replica_count=3, machine_type=_TEST_MACHINE_TYPE, accelerator_count=_TEST_ACCELERATOR_COUNT, accelerator_type=_TEST_ACCELERATOR_TYPE, + reservation_affinity_type="NO_RESERVATION", ), evaluator_spec=worker_spec_utils._WorkerPoolSpec( replica_count=1, @@ -5595,6 +5917,9 @@ def test_machine_spec_returns_pool_spec(self): "machine_type": _TEST_MACHINE_TYPE, "accelerator_type": _TEST_ACCELERATOR_TYPE, "accelerator_count": _TEST_ACCELERATOR_COUNT, + "reservation_affinity": { + 
"reservation_affinity_type": "ANY_RESERVATION", + }, }, "replica_count": 1, "disk_spec": { @@ -5607,6 +5932,11 @@ def test_machine_spec_returns_pool_spec(self): "machine_type": _TEST_MACHINE_TYPE, "accelerator_type": _TEST_ACCELERATOR_TYPE, "accelerator_count": _TEST_ACCELERATOR_COUNT, + "reservation_affinity": { + "reservation_affinity_type": "SPECIFIC_RESERVATION", + "key": "compute.googleapis.com/reservation-name", + "values": "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}", + }, }, "replica_count": 10, "disk_spec": { @@ -5619,6 +5949,9 @@ def test_machine_spec_returns_pool_spec(self): "machine_type": _TEST_MACHINE_TYPE, "accelerator_type": _TEST_ACCELERATOR_TYPE, "accelerator_count": _TEST_ACCELERATOR_COUNT, + "reservation_affinity": { + "reservation_affinity_type": "NO_RESERVATION", + }, }, "replica_count": 3, "disk_spec": { @@ -5644,13 +5977,14 @@ def test_machine_spec_returns_pool_spec(self): def test_chief_worker_pool_returns_spec(self): - chief_worker_spec = ( - worker_spec_utils._DistributedTrainingSpec.chief_worker_pool( - replica_count=10, - machine_type=_TEST_MACHINE_TYPE, - accelerator_count=_TEST_ACCELERATOR_COUNT, - accelerator_type=_TEST_ACCELERATOR_TYPE, - ) + chief_worker_spec = worker_spec_utils._DistributedTrainingSpec.chief_worker_pool( + replica_count=10, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + reservation_affinity_type="SPECIFIC_RESERVATION", + reservation_affinity_key="compute.googleapis.com/reservation-name", + reservation_affinity_values="projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}", ) true_pool_spec = [ @@ -5659,6 +5993,11 @@ def test_chief_worker_pool_returns_spec(self): "machine_type": _TEST_MACHINE_TYPE, "accelerator_type": _TEST_ACCELERATOR_TYPE, "accelerator_count": _TEST_ACCELERATOR_COUNT, + "reservation_affinity": { + "reservation_affinity_type": "SPECIFIC_RESERVATION", + "key": 
"compute.googleapis.com/reservation-name", + "values": "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}", + }, }, "replica_count": 1, "disk_spec": { @@ -5671,6 +6010,11 @@ def test_chief_worker_pool_returns_spec(self): "machine_type": _TEST_MACHINE_TYPE, "accelerator_type": _TEST_ACCELERATOR_TYPE, "accelerator_count": _TEST_ACCELERATOR_COUNT, + "reservation_affinity": { + "reservation_affinity_type": "SPECIFIC_RESERVATION", + "key": "compute.googleapis.com/reservation-name", + "values": "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}", + }, }, "replica_count": 9, "disk_spec": { @@ -5684,13 +6028,14 @@ def test_chief_worker_pool_returns_spec(self): def test_chief_worker_pool_returns_just_chief(self): - chief_worker_spec = ( - worker_spec_utils._DistributedTrainingSpec.chief_worker_pool( - replica_count=1, - machine_type=_TEST_MACHINE_TYPE, - accelerator_count=_TEST_ACCELERATOR_COUNT, - accelerator_type=_TEST_ACCELERATOR_TYPE, - ) + chief_worker_spec = worker_spec_utils._DistributedTrainingSpec.chief_worker_pool( + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + reservation_affinity_type="SPECIFIC_RESERVATION", + reservation_affinity_key="compute.googleapis.com/reservation-name", + reservation_affinity_values="projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}", ) true_pool_spec = [ @@ -5699,6 +6044,11 @@ def test_chief_worker_pool_returns_just_chief(self): "machine_type": _TEST_MACHINE_TYPE, "accelerator_type": _TEST_ACCELERATOR_TYPE, "accelerator_count": _TEST_ACCELERATOR_COUNT, + "reservation_affinity": { + "reservation_affinity_type": "SPECIFIC_RESERVATION", + "key": "compute.googleapis.com/reservation-name", + "values": "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}", + }, }, "replica_count": 1, "disk_spec": { @@ -7978,6 +8328,101 @@ def 
test_training_job_tpu_v3_pod(self, mock_pipeline_service_create): timeout=None, ) + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) + @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_python_package_to_gcs", + "mock_model_service_get", + ) + def test_training_job_reservation_affinity(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_description=_TEST_MODEL_DESCRIPTION, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + job.run( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=32, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + reservation_affinity_type="ANY_RESERVATION", + ) + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_DISPLAY_NAME + "-model", + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + 
instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + version_aliases=["default"], + ) + + true_worker_pool_spec = { + "replica_count": _TEST_REPLICA_COUNT, + "machine_spec": { + "machine_type": _TEST_MACHINE_TYPE, + "accelerator_type": _TEST_ACCELERATOR_TYPE, + "accelerator_count": 32, + "reservation_affinity": { + "reservation_affinity_type": "ANY_RESERVATION" + }, + }, + "disk_spec": { + "boot_disk_type": _TEST_BOOT_DISK_TYPE_DEFAULT, + "boot_disk_size_gb": _TEST_BOOT_DISK_SIZE_GB_DEFAULT, + }, + "python_package_spec": { + "executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE, + "python_module": _TEST_PYTHON_MODULE_NAME, + "package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + }, + } + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "worker_pool_specs": [true_worker_pool_spec], + "base_output_directory": { + "output_uri_prefix": _TEST_BASE_OUTPUT_DIR + }, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + timeout=None, + ) + class TestVersionedTrainingJobs: @pytest.mark.usefixtures("mock_pipeline_service_get") diff --git a/tests/unit/aiplatform/test_uploader.py b/tests/unit/aiplatform/test_uploader.py index deb3be0118..e677eb7565 100644 --- a/tests/unit/aiplatform/test_uploader.py +++ b/tests/unit/aiplatform/test_uploader.py @@ -133,6 +133,18 @@ class AbortUploadError(Exception): """Exception used in testing to abort the upload process.""" +def _create_tensorboard_run_mock( + run_display_name=_TEST_RUN_NAME, + run_resource_name=_TEST_TENSORBOARD_RESOURCE_NAME, + 
time_series_name=_TEST_TIME_SERIES_NAME, +): + tensorboard_run_mock = mock.create_autospec(tensorboard_resource.TensorboardRun) + tensorboard_run_mock.resource_name = run_resource_name + tensorboard_run_mock.display_name = run_display_name + tensorboard_run_mock.get_tensorboard_time_series_id.return_value = time_series_name + return tensorboard_run_mock + + def _create_mock_client(): # Create a stub instance (using a test channel) in order to derive a mock # from it with autospec enabled. Mocking TensorBoardWriterServiceStub itself @@ -177,6 +189,11 @@ def create_tensorboard_time_series( display_name=tensorboard_time_series.display_name, ) + def get_tensorboard_time_series( + request=tensorboard_service.GetTensorboardTimeSeriesRequest, + ): # pylint: disable=unused-argument + return None + def parse_tensorboard_path_response(path): """Parses a tensorboard path into its component segments.""" m = re.match( @@ -201,6 +218,7 @@ def parse_tensorboard_path_response(path): create_tensorboard_time_series ) mock_client.parse_tensorboard_path.side_effect = parse_tensorboard_path_response + mock_client.get_tensorboard_time_series.side_effect = get_tensorboard_time_series return mock_client @@ -508,6 +526,17 @@ def add_meta_graph(self, meta_graph_def, global_step=None): @pytest.mark.usefixtures("google_auth_mock") class TensorboardUploaderTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + super(TensorboardUploaderTest, self).setUp() + self.mock_client = _create_mock_client() + self.mock_run_resource_mock = self.enter_context( + patch.object( + uploader_utils.OnePlatformResourceManager, + "_get_or_create_run_resource", + autospec=True, + ) + ) + @patch.object(metadata, "_experiment_tracker", autospec=True) @patch.object(experiment_resources, "Experiment", autospec=True) def test_create_experiment( @@ -519,7 +548,7 @@ def test_create_experiment( _TEST_TENSORBOARD_RESOURCE_NAME ) logdir = _TEST_LOG_DIR_NAME - uploader = _create_uploader(_create_mock_client(), 
logdir) + uploader = _create_uploader(self.mock_client, logdir) uploader.create_experiment() self.assertEqual( uploader._tensorboard_experiment_resource_name, @@ -537,9 +566,8 @@ def test_create_experiment_with_name( _TEST_TENSORBOARD_RESOURCE_NAME ) logdir = _TEST_LOG_DIR_NAME - mock_client = _create_mock_client() new_name = "This is the new name" - uploader = _create_uploader(mock_client, logdir, experiment_name=new_name) + uploader = _create_uploader(self.mock_client, logdir, experiment_name=new_name) uploader.create_experiment() @patch.object(metadata, "_experiment_tracker", autospec=True) @@ -553,12 +581,13 @@ def test_create_experiment_with_description( _TEST_TENSORBOARD_RESOURCE_NAME ) logdir = _TEST_LOG_DIR_NAME - mock_client = _create_mock_client() new_description = """ **description**" may have "strange" unicode chars 🌴 \\/<> """ - uploader = _create_uploader(mock_client, logdir, description=new_description) + uploader = _create_uploader( + self.mock_client, logdir, description=new_description + ) uploader.create_experiment() self.assertEqual(uploader._experiment_name, _TEST_EXPERIMENT_NAME) @@ -573,21 +602,22 @@ def test_create_experiment_with_all_metadata( _TEST_TENSORBOARD_RESOURCE_NAME ) logdir = _TEST_LOG_DIR_NAME - mock_client = _create_mock_client() new_description = """ **description**" may have "strange" unicode chars 🌴 \\/<> """ new_name = "This is a cool name." 
uploader = _create_uploader( - mock_client, logdir, experiment_name=new_name, description=new_description + self.mock_client, + logdir, + experiment_name=new_name, + description=new_description, ) uploader.create_experiment() self.assertEqual(uploader._experiment_name, new_name) def test_start_uploading_without_create_experiment_fails(self): - mock_client = _create_mock_client() - uploader = _create_uploader(mock_client, _TEST_LOG_DIR_NAME) + uploader = _create_uploader(self.mock_client, _TEST_LOG_DIR_NAME) with self.assertRaisesRegex(RuntimeError, "call create_experiment()"): uploader.start_uploading() @@ -602,11 +632,11 @@ def test_start_uploading_scalars( self, experiment_resources_mock, experiment_tracker_mock, run_resource_mock ): experiment_resources_mock.get.return_value = _TEST_EXPERIMENT_NAME + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() experiment_tracker_mock.set_experiment.return_value = _TEST_EXPERIMENT_NAME experiment_tracker_mock.set_tensorboard.return_value = ( _TEST_TENSORBOARD_RESOURCE_NAME ) - mock_client = _create_mock_client() mock_rate_limiter = mock.create_autospec(uploader_utils.RateLimiter) mock_tensor_rate_limiter = mock.create_autospec(uploader_utils.RateLimiter) mock_blob_rate_limiter = mock.create_autospec(uploader_utils.RateLimiter) @@ -616,7 +646,7 @@ def test_start_uploading_scalars( upload_tracker, "UploadTracker", return_value=mock_tracker ): uploader = _create_uploader( - writer_client=mock_client, + writer_client=self.mock_client, logdir=_TEST_LOG_DIR_NAME, # Send each Event below in a separate WriteScalarRequest max_scalar_request_size=180, @@ -653,7 +683,9 @@ def test_start_uploading_scalars( uploader, "_logdir_loader", mock_logdir_loader ), self.assertRaises(AbortUploadError): uploader.start_uploading() - self.assertEqual(5, mock_client.write_tensorboard_experiment_data.call_count) + self.assertEqual( + 5, self.mock_client.write_tensorboard_experiment_data.call_count + ) self.assertEqual(5, 
mock_rate_limiter.tick.call_count) self.assertEqual(0, mock_tensor_rate_limiter.tick.call_count) self.assertEqual(0, mock_blob_rate_limiter.tick.call_count) @@ -666,33 +698,17 @@ def test_start_uploading_scalars( self.assertEqual(mock_tracker.blob_tracker.call_count, 0) @parameterized.parameters( - {"existing_experiment": None, "one_platform_run_name": None}, - {"existing_experiment": None, "one_platform_run_name": "."}, - { - "existing_experiment": _TEST_EXPERIMENT_NAME, - "one_platform_run_name": _TEST_ONE_PLATFORM_RUN_NAME, - }, - ) - @patch.object( - uploader_utils.OnePlatformResourceManager, - "get_run_resource_name", - autospec=True, + {"existing_experiment": None}, + {"existing_experiment": None}, + {"existing_experiment": _TEST_EXPERIMENT_NAME}, ) @patch.object(metadata, "_experiment_tracker", autospec=True) - @patch.object( - uploader_utils.OnePlatformResourceManager, - "_create_or_get_run_resource", - autospec=True, - ) @patch.object(experiment_resources, "Experiment", autospec=True) def test_start_uploading_scalars_one_shot( self, experiment_resources_mock, - experiment_run_resource_mock, experiment_tracker_mock, - run_resource_mock, existing_experiment, - one_platform_run_name, ): """Check that one-shot uploading stops without AbortUploadError.""" @@ -724,29 +740,24 @@ def batch_create_time_series(parent, requests): tensorboard_time_series=tb_time_series ) - tensorboard_run_mock = mock.create_autospec(tensorboard_resource.TensorboardRun) experiment_resources_mock.get.return_value = existing_experiment - tensorboard_run_mock.resource_name = _TEST_TENSORBOARD_RESOURCE_NAME - tensorboard_run_mock.display_name = _TEST_RUN_NAME - experiment_run_resource_mock.return_value = tensorboard_run_mock + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() experiment_tracker_mock.set_experiment.return_value = _TEST_EXPERIMENT_NAME experiment_tracker_mock.set_tensorboard.return_value = ( _TEST_TENSORBOARD_RESOURCE_NAME ) - mock_client = 
_create_mock_client() - mock_client.batch_create_tensorboard_runs.side_effect = batch_create_runs - mock_client.batch_create_tensorboard_time_series.side_effect = ( + self.mock_client.batch_create_tensorboard_runs.side_effect = batch_create_runs + self.mock_client.batch_create_tensorboard_time_series.side_effect = ( batch_create_time_series ) mock_rate_limiter = mock.create_autospec(uploader_utils.RateLimiter) mock_tracker = mock.MagicMock() - run_resource_mock.return_value = one_platform_run_name with mock.patch.object( upload_tracker, "UploadTracker", return_value=mock_tracker ): uploader = _create_uploader( - writer_client=mock_client, + writer_client=self.mock_client, logdir=_TEST_LOG_DIR_NAME, # Send each Event below in a separate WriteScalarRequest max_scalar_request_size=200, @@ -793,7 +804,9 @@ def batch_create_time_series(parent, requests): uploader._end_experiment_runs.assert_called_once() self.assertEqual(existing_experiment is None, uploader._is_brand_new_experiment) - self.assertEqual(2, mock_client.write_tensorboard_experiment_data.call_count) + self.assertEqual( + 2, self.mock_client.write_tensorboard_experiment_data.call_count + ) self.assertEqual(2, mock_rate_limiter.tick.call_count) # Check upload tracker calls. 
@@ -815,11 +828,10 @@ def test_upload_empty_logdir( _TEST_TENSORBOARD_RESOURCE_NAME ) logdir = self.get_temp_dir() - mock_client = _create_mock_client() - uploader = _create_uploader(mock_client, logdir) + uploader = _create_uploader(self.mock_client, logdir) uploader.create_experiment() uploader._upload_once() - mock_client.write_tensorboard_experiment_data.assert_not_called() + self.mock_client.write_tensorboard_experiment_data.assert_not_called() experiment_tracker_mock.set_experiment.assert_called_once() @parameterized.parameters( @@ -846,6 +858,7 @@ def test_default_run_name( experiment_tracker_mock.set_tensorboard.return_value = ( _TEST_TENSORBOARD_RESOURCE_NAME ) + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() logdir = self.get_temp_dir() with FileWriter(logdir) as writer: writer.add_test_summary("foo") @@ -881,6 +894,7 @@ class SuccessError(Exception): experiment_tracker_mock.set_tensorboard.return_value = ( _TEST_TENSORBOARD_RESOURCE_NAME ) + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() mock_rate_limiter = mock.create_autospec(uploader_utils.RateLimiter) upload_call_count_box = [0] @@ -917,17 +931,17 @@ def test_upload_swallows_rpc_failure( experiment_tracker_mock.set_tensorboard.return_value = ( _TEST_TENSORBOARD_RESOURCE_NAME ) + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() logdir = self.get_temp_dir() with FileWriter(logdir) as writer: writer.add_test_summary("foo") - mock_client = _create_mock_client() - uploader = _create_uploader(mock_client, logdir) + uploader = _create_uploader(self.mock_client, logdir) uploader.create_experiment() run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME error = _grpc_error(grpc.StatusCode.INTERNAL, "Failure") - mock_client.write_tensorboard_experiment_data.side_effect = error + self.mock_client.write_tensorboard_experiment_data.side_effect = error uploader._upload_once() - 
mock_client.write_tensorboard_experiment_data.assert_called_once() + self.mock_client.write_tensorboard_experiment_data.assert_called_once() experiment_tracker_mock.set_experiment.assert_called_once() @patch.object( @@ -945,9 +959,9 @@ def test_upload_full_logdir( experiment_tracker_mock.set_tensorboard.return_value = ( _TEST_TENSORBOARD_RESOURCE_NAME ) + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() logdir = self.get_temp_dir() - mock_client = _create_mock_client() - uploader = _create_uploader(mock_client, logdir) + uploader = _create_uploader(self.mock_client, logdir) uploader.create_experiment() run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME @@ -978,14 +992,18 @@ def test_upload_full_logdir( writer_a.add_test_summary("qux", simple_value=9.0, step=2) writer_a.flush() uploader._upload_once() - self.assertEqual(3, mock_client.create_tensorboard_time_series.call_count) - call_args_list = mock_client.create_tensorboard_time_series.call_args_list + self.assertEqual(3, self.mock_client.create_tensorboard_time_series.call_count) + call_args_list = self.mock_client.create_tensorboard_time_series.call_args_list request = call_args_list[1][1]["tensorboard_time_series"] self.assertEqual("scalars", request.plugin_name) self.assertEqual(b"12345", request.plugin_data) - self.assertEqual(1, mock_client.write_tensorboard_experiment_data.call_count) - call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list + self.assertEqual( + 1, self.mock_client.write_tensorboard_experiment_data.call_count + ) + call_args_list = ( + self.mock_client.write_tensorboard_experiment_data.call_args_list + ) request1, request2 = ( call_args_list[0][1]["write_run_data_requests"][0].time_series_data, call_args_list[0][1]["write_run_data_requests"][1].time_series_data, @@ -1020,7 +1038,7 @@ def test_upload_full_logdir( self.assertProtoEquals(expected_request1[1], request1[1]) self.assertProtoEquals(expected_request2[0], request2[0]) - 
mock_client.write_tensorboard_experiment_data.reset_mock() + self.mock_client.write_tensorboard_experiment_data.reset_mock() # Second round writer.add_test_summary("foo", simple_value=10.0, step=5) @@ -1030,8 +1048,12 @@ def test_upload_full_logdir( writer_b.add_test_summary("xyz", simple_value=12.0, step=1) writer_b.flush() uploader._upload_once() - self.assertEqual(1, mock_client.write_tensorboard_experiment_data.call_count) - call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list + self.assertEqual( + 1, self.mock_client.write_tensorboard_experiment_data.call_count + ) + call_args_list = ( + self.mock_client.write_tensorboard_experiment_data.call_args_list + ) request3, request4 = ( call_args_list[0][1]["write_run_data_requests"][0].time_series_data, call_args_list[0][1]["write_run_data_requests"][1].time_series_data, @@ -1060,12 +1082,12 @@ def test_upload_full_logdir( self.assertProtoEquals(expected_request3[0], request3[0]) self.assertProtoEquals(expected_request3[1], request3[1]) self.assertProtoEquals(expected_request4[0], request4[0]) - mock_client.write_tensorboard_experiment_data.reset_mock() + self.mock_client.write_tensorboard_experiment_data.reset_mock() experiment_tracker_mock.set_experiment.assert_called_once() # Empty third round uploader._upload_once() - mock_client.write_tensorboard_experiment_data.assert_not_called() + self.mock_client.write_tensorboard_experiment_data.assert_not_called() experiment_tracker_mock.set_experiment.assert_called_once() @patch.object( @@ -1083,14 +1105,14 @@ def test_verbosity_zero_creates_upload_tracker_with_verbosity_zero( experiment_tracker_mock.set_tensorboard.return_value = ( _TEST_TENSORBOARD_RESOURCE_NAME ) - mock_client = _create_mock_client() + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME mock_tracker = mock.MagicMock() with mock.patch.object( upload_tracker, "UploadTracker", 
return_value=mock_tracker ) as mock_constructor: uploader = _create_uploader( - mock_client, + self.mock_client, _TEST_LOG_DIR_NAME, verbosity=0, # Explicitly set verbosity to 0. ) @@ -1131,7 +1153,7 @@ def test_start_uploading_graphs( experiment_tracker_mock.set_tensorboard.return_value = ( _TEST_TENSORBOARD_RESOURCE_NAME ) - mock_client = _create_mock_client() + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() mock_rate_limiter = mock.create_autospec(uploader_utils.RateLimiter) mock_bucket = mock.create_autospec(storage.Bucket) mock_blob = mock.create_autospec(storage.Blob) @@ -1145,12 +1167,12 @@ def create_time_series(tensorboard_time_series, parent=None): display_name=tensorboard_time_series.display_name, ) - mock_client.create_tensorboard_time_series.side_effect = create_time_series + self.mock_client.create_tensorboard_time_series.side_effect = create_time_series with mock.patch.object( upload_tracker, "UploadTracker", return_value=mock_tracker ): uploader = _create_uploader( - writer_client=mock_client, + writer_client=self.mock_client, logdir=_TEST_LOG_DIR_NAME, max_blob_request_size=1000, rpc_rate_limiter=mock_rate_limiter, @@ -1201,7 +1223,7 @@ def create_time_series(tensorboard_time_series, parent=None): actual_graph_def = graph_pb2.GraphDef.FromString(request) self.assertProtoEquals(expected_graph_def, actual_graph_def) - for call in mock_client.write_tensorboard_experiment_data.call_args_list: + for call in self.mock_client.write_tensorboard_experiment_data.call_args_list: kargs = call[1] time_series_data = kargs["write_run_data_requests"][0].time_series_data self.assertEqual(len(time_series_data), 1) @@ -1235,6 +1257,7 @@ def test_filter_graphs( experiment_tracker_mock.set_tensorboard.return_value = ( _TEST_TENSORBOARD_RESOURCE_NAME ) + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() # Three graphs: one short, one long, one corrupt. 
bytes_0 = _create_example_graph_bytes(123) bytes_1 = _create_example_graph_bytes(9999) @@ -1255,7 +1278,6 @@ def test_filter_graphs( mock_bucket = mock.create_autospec(storage.Bucket) mock_blob = mock.create_autospec(storage.Blob) mock_bucket.blob.return_value = mock_blob - mock_client = _create_mock_client() run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME def create_time_series(tensorboard_time_series, parent=None): @@ -1264,9 +1286,9 @@ def create_time_series(tensorboard_time_series, parent=None): display_name=tensorboard_time_series.display_name, ) - mock_client.create_tensorboard_time_series.side_effect = create_time_series + self.mock_client.create_tensorboard_time_series.side_effect = create_time_series uploader = _create_uploader( - mock_client, + self.mock_client, logdir, logdir_poll_rate_limiter=limiter, blob_storage_bucket=mock_bucket, @@ -1327,7 +1349,7 @@ def test_profile_plugin_included_by_default( os.makedirs(prof_path) uploader = _create_uploader( - _create_mock_client(), + self.mock_client, logdir, one_shot=True, run_name_prefix=run_name, @@ -1355,9 +1377,8 @@ def test_active_experiment_set_experiment_not_called( _TEST_TENSORBOARD_RESOURCE_NAME ) logdir = self.get_temp_dir() - mock_client = _create_mock_client() - uploader = _create_uploader(mock_client, logdir) + uploader = _create_uploader(self.mock_client, logdir) uploader.create_experiment() uploader._upload_once() @@ -1369,6 +1390,24 @@ def test_active_experiment_set_experiment_not_called( @pytest.mark.usefixtures("google_auth_mock") class _TensorBoardTrackerTest(tf.test.TestCase): + def setUp(self): + super(_TensorBoardTrackerTest, self).setUp() + self.mock_client = _create_mock_client() + self.mock_run_resource_mock = self.enter_context( + patch.object( + uploader_utils.OnePlatformResourceManager, + "_get_or_create_run_resource", + autospec=True, + ) + ) + self.mock_time_series_resource_mock = self.enter_context( + patch.object( + uploader_utils.TimeSeriesResourceManager, + 
"_get_run_resource", + autospec=True, + ) + ) + @patch.object( uploader_utils.OnePlatformResourceManager, "get_run_resource_name", @@ -1386,13 +1425,16 @@ def test_thread_continuously_uploads( experiment_tracker_mock.set_tensorboard.return_value = ( _TEST_TENSORBOARD_RESOURCE_NAME ) + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() + self.mock_time_series_resource_mock.return_value = ( + _create_tensorboard_run_mock() + ) logdir = self.get_temp_dir() - mock_client = _create_mock_client() run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME builder = _create_dispatcher( experiment_resource_name=_TEST_ONE_PLATFORM_EXPERIMENT_NAME, - api=mock_client, + api=self.mock_client, allowed_plugins=_SCALARS_HISTOGRAMS_AND_PROFILE, logdir=logdir, ) @@ -1400,7 +1442,7 @@ def test_thread_continuously_uploads( mock_bucket = _create_mock_blob_storage() uploader = _create_uploader( - mock_client, + self.mock_client, logdir, allowed_plugins=_SCALARS_HISTOGRAMS_AND_PROFILE, rpc_rate_limiter=mock_rate_limiter, @@ -1454,8 +1496,8 @@ def test_thread_continuously_uploads( time.sleep(5) # Check create_time_series calls - self.assertEqual(4, mock_client.create_tensorboard_time_series.call_count) - call_args_list = mock_client.create_tensorboard_time_series.call_args_list + self.assertEqual(4, self.mock_client.create_tensorboard_time_series.call_count) + call_args_list = self.mock_client.create_tensorboard_time_series.call_args_list request1, request2, request3, request4 = ( call_args_list[0][1]["tensorboard_time_series"], call_args_list[1][1]["tensorboard_time_series"], @@ -1470,8 +1512,12 @@ def test_thread_continuously_uploads( experiment_tracker_mock.set_experiment.assert_called_once() # Check write_tensorboard_experiment_data calls - self.assertEqual(1, mock_client.write_tensorboard_experiment_data.call_count) - call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list + self.assertEqual( + 1, 
self.mock_client.write_tensorboard_experiment_data.call_count + ) + call_args_list = ( + self.mock_client.write_tensorboard_experiment_data.call_args_list + ) request1, request2 = ( call_args_list[0][1]["write_run_data_requests"][0].time_series_data, call_args_list[0][1]["write_run_data_requests"][1].time_series_data, @@ -1511,11 +1557,11 @@ def test_thread_continuously_uploads( uploader._end_experiment_runs.assert_called_once() time.sleep(1) self.assertFalse(uploader_thread.is_alive()) - mock_client.write_tensorboard_experiment_data.reset_mock() + self.mock_client.write_tensorboard_experiment_data.reset_mock() # Empty directory uploader._upload_once() - mock_client.write_tensorboard_experiment_data.assert_not_called() + self.mock_client.write_tensorboard_experiment_data.assert_not_called() with mock.patch.object(uploader, "_end_experiment_runs", return_value=None): uploader._end_uploading() uploader._end_experiment_runs.assert_called_once() @@ -1526,17 +1572,29 @@ def test_thread_continuously_uploads( @pytest.mark.usefixtures("google_auth_mock") class BatchedRequestSenderTest(tf.test.TestCase): + def setUp(self): + super(BatchedRequestSenderTest, self).setUp() + self.mock_client = _create_mock_client() + self.mock_run_resource_mock = self.enter_context( + patch.object( + uploader_utils.OnePlatformResourceManager, + "_get_or_create_run_resource", + autospec=True, + ) + ) + def _populate_run_from_events( self, n_scalar_events, events, allowed_plugins=_USE_DEFAULT ): - mock_client = _create_mock_client() builder = _create_dispatcher( experiment_resource_name="123", - api=mock_client, + api=self.mock_client, allowed_plugins=allowed_plugins, ) builder.dispatch_requests({"": _apply_compat(events)}) - scalar_requests = mock_client.write_tensorboard_experiment_data.call_args_list + scalar_requests = ( + self.mock_client.write_tensorboard_experiment_data.call_args_list + ) if scalar_requests: self.assertLen(scalar_requests, 1) self.assertLen( @@ -1550,6 +1608,7 @@ def 
test_empty_events(self): self.assertProtoEquals(call_args_list, []) def test_scalar_events(self): + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() events = [ event_pb2.Event(summary=scalar_v2_pb("scalar1", 5.0)), event_pb2.Event(summary=scalar_v2_pb("scalar2", 5.0)), @@ -1559,6 +1618,7 @@ def test_scalar_events(self): self.assertEqual(scalar_tag_counts, {"scalar1": 1, "scalar2": 1}) def test_skips_non_scalar_events(self): + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() events = [ event_pb2.Event(summary=scalar_v2_pb("scalar1", 5.0)), event_pb2.Event(file_version="brain.Event:2"), @@ -1568,6 +1628,7 @@ def test_skips_non_scalar_events(self): self.assertEqual(scalar_tag_counts, {"scalar1": 1}) def test_skips_non_scalar_events_in_scalar_time_series(self): + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() events = [ event_pb2.Event(file_version="brain.Event:2"), event_pb2.Event(summary=scalar_v2_pb("scalar1", 5.0)), @@ -1589,6 +1650,7 @@ def test_skips_events_from_disallowed_plugins(self): self.assertEqual(call_args_lists, []) def test_remembers_first_metadata_in_time_series(self): + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() scalar_1 = event_pb2.Event(summary=scalar_v2_pb("loss", 4.0)) scalar_2 = event_pb2.Event(summary=scalar_v2_pb("loss", 3.0)) scalar_2.summary.value[0].ClearField("metadata") @@ -1602,6 +1664,7 @@ def test_remembers_first_metadata_in_time_series(self): self.assertEqual(scalar_tag_counts, {"loss": 2}) def test_expands_multiple_values_in_event(self): + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() event = event_pb2.Event(step=1, wall_time=123.456) event.summary.value.add(tag="foo", simple_value=1.0) event.summary.value.add(tag="foo", simple_value=2.0) @@ -1638,10 +1701,21 @@ def test_expands_multiple_values_in_event(self): @pytest.mark.usefixtures("google_auth_mock") class 
ProfileRequestSenderTest(tf.test.TestCase): - def _create_builder(self, mock_client, logdir): + def setUp(self): + super(ProfileRequestSenderTest, self).setUp() + self.mock_client = _create_mock_client() + self.mock_time_series_resource_mock = self.enter_context( + patch.object( + uploader_utils.TimeSeriesResourceManager, + "_get_run_resource", + autospec=True, + ) + ) + + def _create_builder(self, logdir): return _create_dispatcher( experiment_resource_name=_TEST_ONE_PLATFORM_EXPERIMENT_NAME, - api=mock_client, + api=self.mock_client, logdir=logdir, allowed_plugins=frozenset({"profile"}), ) @@ -1650,17 +1724,13 @@ def _populate_run_from_events( self, events, logdir, - mock_client=None, builder=None, ): - if not mock_client: - mock_client = _create_mock_client() - if not builder: - builder = self._create_builder(mock_client, logdir) + builder = self._create_builder(logdir) builder.dispatch_requests({"": _apply_compat(events)}) - profile_requests = mock_client.write_tensorboard_run_data.call_args_list + profile_requests = self.mock_client.write_tensorboard_run_data.call_args_list return profile_requests @@ -1701,6 +1771,9 @@ def test_profile_event_single_prof_run(self, run_resource_mock): ] prof_run_name = "2021_01_01_01_10_10" run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME + self.mock_time_series_resource_mock.return_value = ( + _create_tensorboard_run_mock() + ) with tempfile.TemporaryDirectory() as logdir: prof_path = os.path.join( @@ -1724,11 +1797,13 @@ def test_profile_event_single_prof_run_new_files(self, run_resource_mock): event_pb2.Event(file_version="brain.Event:2"), ] prof_run_name = "2021_01_01_01_10_10" - mock_client = _create_mock_client() run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME + self.mock_time_series_resource_mock.return_value = ( + _create_tensorboard_run_mock() + ) with tempfile.TemporaryDirectory() as logdir: - builder = self._create_builder(mock_client=mock_client, logdir=logdir) + builder = 
self._create_builder(logdir=logdir) prof_path = os.path.join( logdir, profile_uploader.ProfileRequestSender.PROFILE_PATH ) @@ -1741,13 +1816,13 @@ def test_profile_event_single_prof_run_new_files(self, run_resource_mock): prefix="a", suffix=".xplane.pb", dir=run_path ): call_args_list = self._populate_run_from_events( - events, logdir, builder=builder, mock_client=mock_client + events, logdir, builder=builder ) with tempfile.NamedTemporaryFile( prefix="b", suffix=".xplane.pb", dir=run_path ): call_args_list = self._populate_run_from_events( - events, logdir, builder=builder, mock_client=mock_client + events, logdir, builder=builder ) profile_tag_counts = _extract_tag_counts_time_series(call_args_list) @@ -1763,6 +1838,9 @@ def test_profile_event_multi_prof_run(self, run_resource_mock): "2021_02_02_02_20_20", ] run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME + self.mock_time_series_resource_mock.return_value = ( + _create_tensorboard_run_mock() + ) with tempfile.TemporaryDirectory() as logdir: prof_path = os.path.join( @@ -1798,10 +1876,12 @@ def test_profile_event_add_consecutive_prof_runs(self, run_resource_mock): prof_run_name = "2021_01_01_01_10_10" run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME - mock_client = _create_mock_client() + self.mock_time_series_resource_mock.return_value = ( + _create_tensorboard_run_mock() + ) with tempfile.TemporaryDirectory() as logdir: - builder = self._create_builder(mock_client=mock_client, logdir=logdir) + builder = self._create_builder(logdir=logdir) prof_path = os.path.join( logdir, profile_uploader.ProfileRequestSender.PROFILE_PATH @@ -1819,7 +1899,6 @@ def test_profile_event_add_consecutive_prof_runs(self, run_resource_mock): call_args_list = self._populate_run_from_events( events, logdir, - mock_client=mock_client, builder=builder, ) @@ -1833,13 +1912,12 @@ def test_profile_event_add_consecutive_prof_runs(self, run_resource_mock): run_path = os.path.join(prof_path, prof_run_name_2) 
os.makedirs(run_path) - mock_client.write_tensorboard_run_data.reset_mock() + self.mock_client.write_tensorboard_run_data.reset_mock() with named_temp(dir=run_path): call_args_list = self._populate_run_from_events( events, logdir, - mock_client=mock_client, builder=builder, ) @@ -1852,21 +1930,31 @@ def test_profile_event_add_consecutive_prof_runs(self, run_resource_mock): @pytest.mark.usefixtures("google_auth_mock") class ScalarBatchedRequestSenderTest(tf.test.TestCase): + def setUp(self): + super(ScalarBatchedRequestSenderTest, self).setUp() + self.mock_client = _create_mock_client() + self.mock_run_resource_mock = self.enter_context( + patch.object( + uploader_utils.OnePlatformResourceManager, + "_get_or_create_run_resource", + autospec=True, + ) + ) + def _add_events(self, sender, events): for event in events: for value in event.summary.value: sender.add_event(_TEST_RUN_NAME, event, value, value.metadata) def _add_events_and_flush(self, events, expected_n_time_series): - mock_client = _create_mock_client() sender = _create_scalar_request_sender( experiment_resource_id=_TEST_EXPERIMENT_NAME, - api=mock_client, + api=self.mock_client, ) self._add_events(sender, events) sender.flush() - requests = mock_client.write_tensorboard_experiment_data.call_args_list + requests = self.mock_client.write_tensorboard_experiment_data.call_args_list self.assertLen(requests, 1) call_args = requests[0] self.assertLen( @@ -1878,6 +1966,7 @@ def _add_events_and_flush(self, events, expected_n_time_series): @patch.object(uploader_utils.OnePlatformResourceManager, "get_run_resource_name") def test_aggregation_by_tag(self, run_resource_mock): run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() def make_event(step, wall_time, tag, value): return event_pb2.Event( @@ -1919,6 +2008,7 @@ def make_event(step, wall_time, tag, value): @patch.object(uploader_utils.OnePlatformResourceManager, 
"get_run_resource_name") def test_v1_summary(self, run_resource_mock): run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() event = event_pb2.Event(step=1, wall_time=123.456) event.summary.value.add(tag="foo", simple_value=5.0) call_args = self._add_events_and_flush(_apply_compat([event]), 1) @@ -1944,6 +2034,7 @@ def test_v1_summary(self, run_resource_mock): @patch.object(uploader_utils.OnePlatformResourceManager, "get_run_resource_name") def test_v1_summary_tb_summary(self, run_resource_mock): run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() tf_summary = summary_v1.scalar_pb("foo", 5.0) tb_summary = summary_pb2.Summary.FromString(tf_summary.SerializeToString()) event = event_pb2.Event(step=1, wall_time=123.456, summary=tb_summary) @@ -1970,6 +2061,7 @@ def test_v1_summary_tb_summary(self, run_resource_mock): @patch.object(uploader_utils.OnePlatformResourceManager, "get_run_resource_name") def test_v2_summary(self, run_resource_mock): run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() event = event_pb2.Event( step=1, wall_time=123.456, summary=scalar_v2_pb("foo", 5.0) ) @@ -1996,44 +2088,45 @@ def test_v2_summary(self, run_resource_mock): @patch.object(uploader_utils.OnePlatformResourceManager, "get_run_resource_name") def test_propagates_experiment_deletion(self, run_resource_mock): run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() event = event_pb2.Event(step=1) event.summary.value.add(tag="foo", simple_value=1.0) - mock_client = _create_mock_client() - sender = _create_scalar_request_sender("123", mock_client) + sender = _create_scalar_request_sender("123", self.mock_client) self._add_events(sender, _apply_compat([event])) 
error = _grpc_error(grpc.StatusCode.NOT_FOUND, "nope") - mock_client.write_tensorboard_experiment_data.side_effect = error + self.mock_client.write_tensorboard_experiment_data.side_effect = error with self.assertRaises(uploader_lib.ExperimentNotFoundError): sender.flush() def test_no_budget_for_base_request(self): - mock_client = _create_mock_client() long_experiment_id = "A" * 12 with self.assertRaises(uploader_lib._OutOfSpaceError) as cm: _create_scalar_request_sender( experiment_resource_id=long_experiment_id, - api=mock_client, + api=self.mock_client, max_request_size=12, ) self.assertEqual(str(cm.exception), "Byte budget too small for base request") @patch.object(uploader_utils.OnePlatformResourceManager, "get_run_resource_name") def test_no_room_for_single_point(self, run_resource_mock): - mock_client = _create_mock_client() + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME event = event_pb2.Event(step=1, wall_time=123.456) event.summary.value.add(tag="foo", simple_value=1.0) - sender = _create_scalar_request_sender("123", mock_client, max_request_size=12) + sender = _create_scalar_request_sender( + "123", self.mock_client, max_request_size=12 + ) with self.assertRaises(RuntimeError) as cm: self._add_events(sender, [event]) self.assertEqual(str(cm.exception), "add_event failed despite flush") @patch.object(uploader_utils.OnePlatformResourceManager, "get_run_resource_name") def test_break_at_run_boundary(self, run_resource_mock): - mock_client = _create_mock_client() run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() # Choose run name sizes such that one run fits in a 1024 byte request, # but not two. 
long_run_1 = "A" * 768 @@ -2045,14 +2138,14 @@ def test_break_at_run_boundary(self, run_resource_mock): sender_1 = _create_scalar_request_sender( long_run_1, - mock_client, + self.mock_client, # Set a limit to request size max_request_size=1024, ) sender_2 = _create_scalar_request_sender( long_run_2, - mock_client, + self.mock_client, # Set a limit to request size max_request_size=1024, ) @@ -2060,7 +2153,9 @@ def test_break_at_run_boundary(self, run_resource_mock): self._add_events(sender_2, _apply_compat([event_2])) sender_1.flush() sender_2.flush() - call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list + call_args_list = ( + self.mock_client.write_tensorboard_experiment_data.call_args_list + ) for call_args in call_args_list: _clear_wall_times( @@ -2105,8 +2200,8 @@ def test_break_at_run_boundary(self, run_resource_mock): @patch.object(uploader_utils.OnePlatformResourceManager, "get_run_resource_name") def test_break_at_tag_boundary(self, run_resource_mock): - mock_client = _create_mock_client() run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() # Choose tag name sizes such that one tag fits in a 1024 byte request, # but not two. Note that tag names appear in both `Tag.name` and the # summary metadata. 
@@ -2118,13 +2213,15 @@ def test_break_at_tag_boundary(self, run_resource_mock): sender = _create_scalar_request_sender( "train", - mock_client, + self.mock_client, # Set a limit to request size max_request_size=1024, ) self._add_events(sender, _apply_compat([event])) sender.flush() - call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list + call_args_list = ( + self.mock_client.write_tensorboard_experiment_data.call_args_list + ) request1 = call_args_list[0][1]["write_run_data_requests"][0].time_series_data _clear_wall_times(request1) @@ -2151,8 +2248,8 @@ def test_break_at_tag_boundary(self, run_resource_mock): @patch.object(uploader_utils.OnePlatformResourceManager, "get_run_resource_name") def test_break_at_scalar_point_boundary(self, run_resource_mock): - mock_client = _create_mock_client() run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() point_count = 2000 # comfortably saturates a single 1024-byte request events = [] for step in range(point_count): @@ -2163,13 +2260,15 @@ def test_break_at_scalar_point_boundary(self, run_resource_mock): sender = _create_scalar_request_sender( "train", - mock_client, + self.mock_client, # Set a limit to request size max_request_size=1024, ) self._add_events(sender, _apply_compat(events)) sender.flush() - call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list + call_args_list = ( + self.mock_client.write_tensorboard_experiment_data.call_args_list + ) for call_args in call_args_list: _clear_wall_times( @@ -2201,8 +2300,8 @@ def test_break_at_scalar_point_boundary(self, run_resource_mock): @patch.object(uploader_utils.OnePlatformResourceManager, "get_run_resource_name") def test_prunes_tags_and_runs(self, run_resource_mock): - mock_client = _create_mock_client() run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() 
event_1 = event_pb2.Event(step=1) event_1.summary.value.add(tag="foo", simple_value=1.0) event_2 = event_pb2.Event(step=2) @@ -2222,12 +2321,14 @@ def mock_add_point(byte_budget_manager_self, point): "add_point", mock_add_point, ): - sender = _create_scalar_request_sender("123", mock_client) + sender = _create_scalar_request_sender("123", self.mock_client) self._add_events(sender, _apply_compat([event_1])) self._add_events(sender, _apply_compat([event_2])) sender.flush() - call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list + call_args_list = ( + self.mock_client.write_tensorboard_experiment_data.call_args_list + ) request1, request2 = ( call_args_list[0][1]["write_run_data_requests"][0].time_series_data, call_args_list[1][1]["write_run_data_requests"][0].time_series_data, @@ -2261,6 +2362,7 @@ def mock_add_point(byte_budget_manager_self, point): @patch.object(uploader_utils.OnePlatformResourceManager, "get_run_resource_name") def test_wall_time_precision(self, run_resource_mock): run_resource_mock.return_value = _TEST_ONE_PLATFORM_RUN_NAME + self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock() # Test a wall time that is exactly representable in float64 but has enough # digits to incur error if converted to nanoseconds the naive way (* 1e9). 
event1 = event_pb2.Event(step=1, wall_time=1567808404.765432119) @@ -2292,10 +2394,20 @@ def test_wall_time_precision(self, run_resource_mock): @pytest.mark.usefixtures("google_auth_mock") class FileRequestSenderTest(tf.test.TestCase): + def setUp(self): + super(FileRequestSenderTest, self).setUp() + self.mock_client = _create_mock_client() + self.mock_time_series_resource_mock = self.enter_context( + patch.object( + uploader_utils.TimeSeriesResourceManager, + "_get_run_resource", + autospec=True, + ) + ) + def test_empty_files_no_messages(self): - mock_client = _create_mock_client() sender = _create_file_request_sender( - api=mock_client, + api=self.mock_client, run_resource_id=_TEST_ONE_PLATFORM_RUN_NAME, ) @@ -2303,12 +2415,11 @@ def test_empty_files_no_messages(self): files=[], tag="my_tag", plugin="test_plugin", event_timestamp="" ) - self.assertEmpty(mock_client.write_tensorboard_run_data.call_args_list) + self.assertEmpty(self.mock_client.write_tensorboard_run_data.call_args_list) def test_fake_files_no_sent_messages(self): - mock_client = _create_mock_client() sender = _create_file_request_sender( - api=mock_client, + api=self.mock_client, run_resource_id=_TEST_ONE_PLATFORM_RUN_NAME, ) @@ -2320,12 +2431,14 @@ def test_fake_files_no_sent_messages(self): event_timestamp="", ) - self.assertEmpty(mock_client.write_tensorboard_run_data.call_args_list) + self.assertEmpty(self.mock_client.write_tensorboard_run_data.call_args_list) def test_files_too_large(self): - mock_client = _create_mock_client() + self.mock_time_series_resource_mock.return_value = ( + _create_tensorboard_run_mock() + ) sender = _create_file_request_sender( - api=mock_client, + api=self.mock_client, run_resource_id=_TEST_ONE_PLATFORM_RUN_NAME, max_blob_size=10, ) @@ -2342,12 +2455,14 @@ def test_files_too_large(self): ), ) - self.assertEmpty(mock_client.write_tensorboard_run_data.call_args_list) + self.assertEmpty(self.mock_client.write_tensorboard_run_data.call_args_list) def 
test_single_file_upload(self): - mock_client = _create_mock_client() + self.mock_time_series_resource_mock.return_value = ( + _create_tensorboard_run_mock() + ) sender = _create_file_request_sender( - api=mock_client, + api=self.mock_client, run_resource_id=_TEST_ONE_PLATFORM_RUN_NAME, ) @@ -2362,15 +2477,19 @@ def test_single_file_upload(self): ), ) - call_args_list = mock_client.write_tensorboard_run_data.call_args_list[0][1] + call_args_list = self.mock_client.write_tensorboard_run_data.call_args_list[0][ + 1 + ] self.assertEqual( fn, call_args_list["time_series_data"][0].values[0].blobs.values[0].id ) def test_multi_file_upload(self): - mock_client = _create_mock_client() + self.mock_time_series_resource_mock.return_value = ( + _create_tensorboard_run_mock() + ) sender = _create_file_request_sender( - api=mock_client, + api=self.mock_client, run_resource_id=_TEST_ONE_PLATFORM_RUN_NAME, ) @@ -2386,7 +2505,9 @@ def test_multi_file_upload(self): ), ) - call_args_list = mock_client.write_tensorboard_run_data.call_args_list[0][1] + call_args_list = self.mock_client.write_tensorboard_run_data.call_args_list[0][ + 1 + ] self.assertEqual( files, @@ -2397,11 +2518,13 @@ def test_multi_file_upload(self): ) def test_add_files_no_experiment(self): - mock_client = _create_mock_client() - mock_client.write_tensorboard_run_data.side_effect = grpc.RpcError + self.mock_time_series_resource_mock.return_value = ( + _create_tensorboard_run_mock() + ) + self.mock_client.write_tensorboard_run_data.side_effect = grpc.RpcError sender = _create_file_request_sender( - api=mock_client, + api=self.mock_client, run_resource_id=_TEST_ONE_PLATFORM_RUN_NAME, ) @@ -2415,14 +2538,17 @@ def test_add_files_no_experiment(self): ), ) - mock_client.write_tensorboard_run_data.assert_called_once() + self.mock_client.write_tensorboard_run_data.assert_called_once() - def test_add_files_from_local(self): - mock_client = _create_mock_client() + @patch.object(uploader_utils.OnePlatformResourceManager, 
"get_run_resource_name") + def test_add_files_from_local(self, run_resource_mock): + self.mock_time_series_resource_mock.return_value = ( + _create_tensorboard_run_mock() + ) bucket = _create_mock_blob_storage() sender = _create_file_request_sender( - api=mock_client, + api=self.mock_client, run_resource_id=_TEST_ONE_PLATFORM_RUN_NAME, blob_storage_bucket=bucket, source_bucket=None, @@ -2432,7 +2558,7 @@ def test_add_files_from_local(self): sender.add_files( files=[f1.name], tag="my_tag", - plugin="test_plugin", + plugin="profile", event_timestamp=timestamp_pb2.Timestamp().FromDatetime( datetime.datetime.strptime("2020-01-01", "%Y-%m-%d") ), @@ -2441,9 +2567,8 @@ def test_add_files_from_local(self): bucket.blob.assert_called_once() def test_copy_blobs(self): - mock_client = _create_mock_client() sender = _create_file_request_sender( - api=mock_client, + api=self.mock_client, run_resource_id=_TEST_ONE_PLATFORM_RUN_NAME, ) diff --git a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py index 8cf5b53e00..286fe51291 100644 --- a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.dataset_service import ( @@ -2723,12 +2724,16 @@ def test_list_datasets_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_datasets(request={}) + pager = client.list_datasets(request={}, retry=retry, timeout=timeout) assert 
pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -6064,12 +6069,16 @@ def test_list_dataset_versions_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_dataset_versions(request={}) + pager = client.list_dataset_versions(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -7018,12 +7027,16 @@ def test_list_data_items_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_data_items(request={}) + pager = client.list_data_items(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -7522,12 +7535,16 @@ def test_search_data_items_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("dataset", ""),)), ) - pager = client.search_data_items(request={}) + pager = client.search_data_items(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -8113,12 +8130,16 @@ def test_list_saved_queries_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( 
gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_saved_queries(request={}) + pager = client.list_saved_queries(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -9465,12 +9486,16 @@ def test_list_annotations_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_annotations(request={}) + pager = client.list_annotations(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1/test_deployment_resource_pool_service.py b/tests/unit/gapic/aiplatform_v1/test_deployment_resource_pool_service.py index 1398b39b8a..abe5837c89 100644 --- a/tests/unit/gapic/aiplatform_v1/test_deployment_resource_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_deployment_resource_pool_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.deployment_resource_pool_service import ( @@ -69,6 +70,7 @@ from google.cloud.aiplatform_v1.types import endpoint from google.cloud.aiplatform_v1.types import machine_resources from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import reservation_affinity from google.cloud.location import locations_pb2 from google.iam.v1 import 
iam_policy_pb2 # type: ignore from google.iam.v1 import options_pb2 # type: ignore @@ -2522,12 +2524,18 @@ def test_list_deployment_resource_pools_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_deployment_resource_pools(request={}) + pager = client.list_deployment_resource_pools( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -3945,14 +3953,18 @@ def test_query_deployed_models_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata( (("deployment_resource_pool", ""),) ), ) - pager = client.query_deployed_models(request={}) + pager = client.query_deployed_models(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -5210,12 +5222,18 @@ def test_update_deployment_resource_pool_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "min_replica_count": 1803, "max_replica_count": 1805, "autoscaling_metric_specs": [ {"metric_name": "metric_name_value", "target": 647} ], + "spot": True, }, "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "service_account": "service_account_value", @@ -7048,8 +7066,36 @@ def test_parse_model_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "cuttlefish" + zone = "mussel" + 
reservation_name = "winkle" + expected = "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = DeploymentResourcePoolServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "nautilus", + "zone": "scallop", + "reservation_name": "abalone", + } + path = DeploymentResourcePoolServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. + actual = DeploymentResourcePoolServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_common_billing_account_path(): - billing_account = "cuttlefish" + billing_account = "squid" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -7061,7 +7107,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", + "billing_account": "clam", } path = DeploymentResourcePoolServiceClient.common_billing_account_path(**expected) @@ -7071,7 +7117,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "winkle" + folder = "whelk" expected = "folders/{folder}".format( folder=folder, ) @@ -7081,7 +7127,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "nautilus", + "folder": "octopus", } path = DeploymentResourcePoolServiceClient.common_folder_path(**expected) @@ -7091,7 +7137,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "scallop" + organization = "oyster" expected = "organizations/{organization}".format( organization=organization, ) @@ -7101,7 +7147,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": 
"abalone", + "organization": "nudibranch", } path = DeploymentResourcePoolServiceClient.common_organization_path(**expected) @@ -7111,7 +7157,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "squid" + project = "cuttlefish" expected = "projects/{project}".format( project=project, ) @@ -7121,7 +7167,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "clam", + "project": "mussel", } path = DeploymentResourcePoolServiceClient.common_project_path(**expected) @@ -7131,8 +7177,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "whelk" - location = "octopus" + project = "winkle" + location = "nautilus" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -7143,8 +7189,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", + "project": "scallop", + "location": "abalone", } path = DeploymentResourcePoolServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py index 07e3218cd7..ebaacda248 100644 --- a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.endpoint_service import ( @@ -65,6 +66,7 @@ from google.cloud.aiplatform_v1.types import io from google.cloud.aiplatform_v1.types import machine_resources from google.cloud.aiplatform_v1.types import operation as gca_operation +from 
google.cloud.aiplatform_v1.types import reservation_affinity from google.cloud.aiplatform_v1.types import service_networking from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore @@ -2378,12 +2380,16 @@ def test_list_endpoints_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_endpoints(request={}) + pager = client.list_endpoints(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4597,12 +4603,18 @@ def test_create_endpoint_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "min_replica_count": 1803, "max_replica_count": 1805, "autoscaling_metric_specs": [ {"metric_name": "metric_name_value", "target": 647} ], + "spot": True, }, "automatic_resources": { "min_replica_count": 1803, @@ -5798,12 +5810,18 @@ def test_update_endpoint_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "min_replica_count": 1803, "max_replica_count": 1805, "autoscaling_metric_specs": [ {"metric_name": "metric_name_value", "target": 647} ], + "spot": True, }, "automatic_resources": { "min_replica_count": 1803, @@ -8320,8 +8338,36 @@ def test_parse_network_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "oyster" + zone = "nudibranch" + reservation_name = "cuttlefish" + expected = 
"projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = EndpointServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "mussel", + "zone": "winkle", + "reservation_name": "nautilus", + } + path = EndpointServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. + actual = EndpointServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_common_billing_account_path(): - billing_account = "oyster" + billing_account = "scallop" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -8331,7 +8377,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "nudibranch", + "billing_account": "abalone", } path = EndpointServiceClient.common_billing_account_path(**expected) @@ -8341,7 +8387,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "cuttlefish" + folder = "squid" expected = "folders/{folder}".format( folder=folder, ) @@ -8351,7 +8397,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "mussel", + "folder": "clam", } path = EndpointServiceClient.common_folder_path(**expected) @@ -8361,7 +8407,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "winkle" + organization = "whelk" expected = "organizations/{organization}".format( organization=organization, ) @@ -8371,7 +8417,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "nautilus", + "organization": "octopus", } path = EndpointServiceClient.common_organization_path(**expected) @@ 
-8381,7 +8427,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "scallop" + project = "oyster" expected = "projects/{project}".format( project=project, ) @@ -8391,7 +8437,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "abalone", + "project": "nudibranch", } path = EndpointServiceClient.common_project_path(**expected) @@ -8401,8 +8447,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "squid" - location = "clam" + project = "cuttlefish" + location = "mussel" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -8413,8 +8459,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "whelk", - "location": "octopus", + "project": "winkle", + "location": "nautilus", } path = EndpointServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1/test_evaluation_service.py b/tests/unit/gapic/aiplatform_v1/test_evaluation_service.py new file mode 100644 index 0000000000..b46ada2273 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1/test_evaluation_service.py @@ -0,0 +1,4555 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +import grpc +from grpc.experimental import aio +from collections.abc import Iterable +from google.protobuf import json_format +import json +import math +import pytest +from google.api_core import api_core_version +from proto.marshal.rules.dates import DurationRule, TimestampRule +from proto.marshal.rules import wrappers +from requests import Response +from requests import Request, PreparedRequest +from requests.sessions import Session +from google.protobuf import json_format + +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import path_template +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1.services.evaluation_service import ( + EvaluationServiceAsyncClient, +) +from google.cloud.aiplatform_v1.services.evaluation_service import ( + EvaluationServiceClient, +) +from google.cloud.aiplatform_v1.services.evaluation_service import transports +from google.cloud.aiplatform_v1.types import evaluation_service +from google.cloud.location import locations_pb2 +from google.iam.v1 import iam_policy_pb2 # type: ignore +from google.iam.v1 import options_pb2 # type: ignore +from google.iam.v1 import policy_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account +import google.auth + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. 
+# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert EvaluationServiceClient._get_default_mtls_endpoint(None) is None + assert ( + EvaluationServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + EvaluationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + EvaluationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EvaluationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EvaluationServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +def test__read_environment_variables(): + assert EvaluationServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert EvaluationServiceClient._read_environment_variables() == ( + True, + "auto", + None, + ) + + with mock.patch.dict(os.environ, 
{"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert EvaluationServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + EvaluationServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert EvaluationServiceClient._read_environment_variables() == ( + False, + "never", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert EvaluationServiceClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert EvaluationServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + EvaluationServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert EvaluationServiceClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert EvaluationServiceClient._get_client_cert_source(None, False) is None + assert ( + EvaluationServiceClient._get_client_cert_source( + mock_provided_cert_source, False + ) + is None + ) + assert ( + EvaluationServiceClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with 
mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + EvaluationServiceClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + EvaluationServiceClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + EvaluationServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(EvaluationServiceClient), +) +@mock.patch.object( + EvaluationServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(EvaluationServiceAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = EvaluationServiceClient._DEFAULT_UNIVERSE + default_endpoint = EvaluationServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = EvaluationServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + EvaluationServiceClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + EvaluationServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == EvaluationServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + EvaluationServiceClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + EvaluationServiceClient._get_api_endpoint( + None, None, default_universe, "always" + ) + == EvaluationServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + EvaluationServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == EvaluationServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + 
EvaluationServiceClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + EvaluationServiceClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + EvaluationServiceClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + EvaluationServiceClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + EvaluationServiceClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + EvaluationServiceClient._get_universe_domain(None, None) + == EvaluationServiceClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + EvaluationServiceClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EvaluationServiceClient, transports.EvaluationServiceGrpcTransport, "grpc"), + (EvaluationServiceClient, transports.EvaluationServiceRestTransport, "rest"), + ], +) +def test__validate_universe_domain(client_class, transport_class, transport_name): + client = client_class( + transport=transport_class(credentials=ga_credentials.AnonymousCredentials()) + ) + assert client._validate_universe_domain() == True + + # Test the case when universe is already validated. + assert client._validate_universe_domain() == True + + if transport_name == "grpc": + # Test the case where credentials are provided by the + # `local_channel_credentials`. The default universes in both match. 
+ channel = grpc.secure_channel( + "http://localhost/", grpc.local_channel_credentials() + ) + client = client_class(transport=transport_class(channel=channel)) + assert client._validate_universe_domain() == True + + # Test the case where credentials do not exist: e.g. a transport is provided + # with no credentials. Validation should still succeed because there is no + # mismatch with non-existent credentials. + channel = grpc.secure_channel( + "http://localhost/", grpc.local_channel_credentials() + ) + transport = transport_class(channel=channel) + transport._credentials = None + client = client_class(transport=transport) + assert client._validate_universe_domain() == True + + # TODO: This is needed to cater for older versions of google-auth + # Make this test unconditional once the minimum supported version of + # google-auth becomes 2.23.0 or higher. + google_auth_major, google_auth_minor = [ + int(part) for part in google.auth.__version__.split(".")[0:2] + ] + if google_auth_major > 2 or (google_auth_major == 2 and google_auth_minor >= 23): + credentials = ga_credentials.AnonymousCredentials() + credentials._universe_domain = "foo.com" + # Test the case when there is a universe mismatch from the credentials. + client = client_class(transport=transport_class(credentials=credentials)) + with pytest.raises(ValueError) as excinfo: + client._validate_universe_domain() + assert ( + str(excinfo.value) + == "The configured universe domain (googleapis.com) does not match the universe domain found in the credentials (foo.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default." + ) + + # Test the case when there is a universe mismatch from the client. + # + # TODO: Make this test unconditional once the minimum supported version of + # google-api-core becomes 2.15.0 or higher. 
+ api_core_major, api_core_minor = [ + int(part) for part in api_core_version.__version__.split(".")[0:2] + ] + if api_core_major > 2 or (api_core_major == 2 and api_core_minor >= 15): + client = client_class( + client_options={"universe_domain": "bar.com"}, + transport=transport_class( + credentials=ga_credentials.AnonymousCredentials(), + ), + ) + with pytest.raises(ValueError) as excinfo: + client._validate_universe_domain() + assert ( + str(excinfo.value) + == "The configured universe domain (bar.com) does not match the universe domain found in the credentials (googleapis.com). If you haven't configured the universe domain explicitly, `googleapis.com` is the default." + ) + + # Test that ValueError is raised if universe_domain is provided via client options and credentials is None + with pytest.raises(ValueError): + client._compare_universes("foo.bar", None) + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (EvaluationServiceClient, "grpc"), + (EvaluationServiceAsyncClient, "grpc_asyncio"), + (EvaluationServiceClient, "rest"), + ], +) +def test_evaluation_service_client_from_service_account_info( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "aiplatform.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://aiplatform.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.EvaluationServiceGrpcTransport, "grpc"), + (transports.EvaluationServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.EvaluationServiceRestTransport, "rest"), + ], +) +def 
test_evaluation_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (EvaluationServiceClient, "grpc"), + (EvaluationServiceAsyncClient, "grpc_asyncio"), + (EvaluationServiceClient, "rest"), + ], +) +def test_evaluation_service_client_from_service_account_file( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "aiplatform.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://aiplatform.googleapis.com" + ) + + +def test_evaluation_service_client_get_transport_class(): + transport = EvaluationServiceClient.get_transport_class() + available_transports = [ + transports.EvaluationServiceGrpcTransport, + transports.EvaluationServiceRestTransport, + ] + assert transport in available_transports + 
+ transport = EvaluationServiceClient.get_transport_class("grpc") + assert transport == transports.EvaluationServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EvaluationServiceClient, transports.EvaluationServiceGrpcTransport, "grpc"), + ( + EvaluationServiceAsyncClient, + transports.EvaluationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (EvaluationServiceClient, transports.EvaluationServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + EvaluationServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(EvaluationServiceClient), +) +@mock.patch.object( + EvaluationServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(EvaluationServiceAsyncClient), +) +def test_evaluation_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(EvaluationServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(EvaluationServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. 
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + 
always_use_jwt_access=True, + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + EvaluationServiceClient, + transports.EvaluationServiceGrpcTransport, + "grpc", + "true", + ), + ( + EvaluationServiceAsyncClient, + transports.EvaluationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + EvaluationServiceClient, + transports.EvaluationServiceGrpcTransport, + "grpc", + "false", + ), + ( + EvaluationServiceAsyncClient, + transports.EvaluationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ( + EvaluationServiceClient, + transports.EvaluationServiceRestTransport, + "rest", + "true", + ), + ( + EvaluationServiceClient, + transports.EvaluationServiceRestTransport, + "rest", + "false", + ), + ], +) +@mock.patch.object( + EvaluationServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(EvaluationServiceClient), +) +@mock.patch.object( + EvaluationServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(EvaluationServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_evaluation_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class", [EvaluationServiceClient, EvaluationServiceAsyncClient] +) +@mock.patch.object( + EvaluationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EvaluationServiceClient), +) +@mock.patch.object( + EvaluationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EvaluationServiceAsyncClient), +) +def test_evaluation_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize( + "client_class", [EvaluationServiceClient, EvaluationServiceAsyncClient] +) +@mock.patch.object( + EvaluationServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(EvaluationServiceClient), +) +@mock.patch.object( + EvaluationServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(EvaluationServiceAsyncClient), +) +def test_evaluation_service_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = EvaluationServiceClient._DEFAULT_UNIVERSE + default_endpoint = EvaluationServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = EvaluationServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=ga_credentials.AnonymousCredentials(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + else: + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. 
+ options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == default_endpoint + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EvaluationServiceClient, transports.EvaluationServiceGrpcTransport, "grpc"), + ( + EvaluationServiceAsyncClient, + transports.EvaluationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (EvaluationServiceClient, transports.EvaluationServiceRestTransport, "rest"), + ], +) +def test_evaluation_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + EvaluationServiceClient, + transports.EvaluationServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + EvaluationServiceAsyncClient, + transports.EvaluationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ( + EvaluationServiceClient, + transports.EvaluationServiceRestTransport, + "rest", + None, + ), + ], +) +def test_evaluation_service_client_client_options_credentials_file( + 
client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +def test_evaluation_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1.services.evaluation_service.transports.EvaluationServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = EvaluationServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + EvaluationServiceClient, + transports.EvaluationServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + EvaluationServiceAsyncClient, + transports.EvaluationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_evaluation_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. 
+ options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + evaluation_service.EvaluateInstancesRequest, + dict, + ], +) +def test_evaluate_instances(request_type, transport: str = "grpc"): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and 
we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.evaluate_instances), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = evaluation_service.EvaluateInstancesResponse() + response = client.evaluate_instances(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = evaluation_service.EvaluateInstancesRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, evaluation_service.EvaluateInstancesResponse) + + +def test_evaluate_instances_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.evaluate_instances), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.evaluate_instances() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == evaluation_service.EvaluateInstancesRequest() + + +def test_evaluate_instances_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. 
+ client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = evaluation_service.EvaluateInstancesRequest( + location="location_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.evaluate_instances), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.evaluate_instances(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == evaluation_service.EvaluateInstancesRequest( + location="location_value", + ) + + +def test_evaluate_instances_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.evaluate_instances in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.evaluate_instances + ] = mock_rpc + request = {} + client.evaluate_instances(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.evaluate_instances(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_evaluate_instances_empty_call_async(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.evaluate_instances), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + evaluation_service.EvaluateInstancesResponse() + ) + response = await client.evaluate_instances() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == evaluation_service.EvaluateInstancesRequest() + + +@pytest.mark.asyncio +async def test_evaluate_instances_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.evaluate_instances + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_object = mock.AsyncMock() + client._client._transport._wrapped_methods[ + client._client._transport.evaluate_instances + ] = mock_object + + request = {} + await 
client.evaluate_instances(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_object.call_count == 1 + + await client.evaluate_instances(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_object.call_count == 2 + + +@pytest.mark.asyncio +async def test_evaluate_instances_async( + transport: str = "grpc_asyncio", + request_type=evaluation_service.EvaluateInstancesRequest, +): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.evaluate_instances), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + evaluation_service.EvaluateInstancesResponse() + ) + response = await client.evaluate_instances(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = evaluation_service.EvaluateInstancesRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, evaluation_service.EvaluateInstancesResponse) + + +@pytest.mark.asyncio +async def test_evaluate_instances_async_from_dict(): + await test_evaluate_instances_async(request_type=dict) + + +def test_evaluate_instances_field_headers(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = evaluation_service.EvaluateInstancesRequest() + + request.location = "location_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.evaluate_instances), "__call__" + ) as call: + call.return_value = evaluation_service.EvaluateInstancesResponse() + client.evaluate_instances(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "location=location_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_evaluate_instances_field_headers_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = evaluation_service.EvaluateInstancesRequest() + + request.location = "location_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.evaluate_instances), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + evaluation_service.EvaluateInstancesResponse() + ) + await client.evaluate_instances(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "location=location_value", + ) in kw["metadata"] + + +@pytest.mark.parametrize( + "request_type", + [ + evaluation_service.EvaluateInstancesRequest, + dict, + ], +) +def test_evaluate_instances_rest(request_type): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"location": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = evaluation_service.EvaluateInstancesResponse() + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = evaluation_service.EvaluateInstancesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.evaluate_instances(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, evaluation_service.EvaluateInstancesResponse) + + +def test_evaluate_instances_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.evaluate_instances in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.evaluate_instances + ] = mock_rpc + + request = {} + client.evaluate_instances(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + client.evaluate_instances(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_evaluate_instances_rest_required_fields( + request_type=evaluation_service.EvaluateInstancesRequest, +): + transport_class = transports.EvaluationServiceRestTransport + + request_init = {} + request_init["location"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).evaluate_instances._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["location"] = "location_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).evaluate_instances._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "location" in jsonified_request + assert jsonified_request["location"] == "location_value" + + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = evaluation_service.EvaluateInstancesResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. 
+ with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = evaluation_service.EvaluateInstancesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.evaluate_instances(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_evaluate_instances_rest_unset_required_fields(): + transport = transports.EvaluationServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.evaluate_instances._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("location",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_evaluate_instances_rest_interceptors(null_interceptor): + transport = transports.EvaluationServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.EvaluationServiceRestInterceptor(), + ) + client = EvaluationServiceClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.EvaluationServiceRestInterceptor, "post_evaluate_instances" + ) as post, mock.patch.object( + transports.EvaluationServiceRestInterceptor, 
"pre_evaluate_instances" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = evaluation_service.EvaluateInstancesRequest.pb( + evaluation_service.EvaluateInstancesRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = ( + evaluation_service.EvaluateInstancesResponse.to_json( + evaluation_service.EvaluateInstancesResponse() + ) + ) + + request = evaluation_service.EvaluateInstancesRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = evaluation_service.EvaluateInstancesResponse() + + client.evaluate_instances( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_evaluate_instances_rest_bad_request( + transport: str = "rest", request_type=evaluation_service.EvaluateInstancesRequest +): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"location": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.evaluate_instances(request) + + +def test_evaluate_instances_rest_error(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.EvaluationServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.EvaluationServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = EvaluationServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.EvaluationServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = EvaluationServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = EvaluationServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. 
+ transport = transports.EvaluationServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = EvaluationServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.EvaluationServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = EvaluationServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.EvaluationServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.EvaluationServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.EvaluationServiceGrpcTransport, + transports.EvaluationServiceGrpcAsyncIOTransport, + transports.EvaluationServiceRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "rest", + ], +) +def test_transport_kind(transport_name): + transport = EvaluationServiceClient.get_transport_class(transport_name)( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert transport.kind == transport_name + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. 
+ client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.EvaluationServiceGrpcTransport, + ) + + +def test_evaluation_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.EvaluationServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_evaluation_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1.services.evaluation_service.transports.EvaluationServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.EvaluationServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "evaluate_instances", + "set_iam_policy", + "get_iam_policy", + "test_iam_permissions", + "get_location", + "list_locations", + "get_operation", + "wait_operation", + "cancel_operation", + "delete_operation", + "list_operations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_evaluation_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.evaluation_service.transports.EvaluationServiceTransport._prep_wrapped_messages" + ) as Transport: + 
Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.EvaluationServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_evaluation_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.cloud.aiplatform_v1.services.evaluation_service.transports.EvaluationServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.EvaluationServiceTransport() + adc.assert_called_once() + + +def test_evaluation_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + EvaluationServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.EvaluationServiceGrpcTransport, + transports.EvaluationServiceGrpcAsyncIOTransport, + ], +) +def test_evaluation_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
+ with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.EvaluationServiceGrpcTransport, + transports.EvaluationServiceGrpcAsyncIOTransport, + transports.EvaluationServiceRestTransport, + ], +) +def test_evaluation_service_transport_auth_gdch_credentials(transport_class): + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, "default", autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with(e) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.EvaluationServiceGrpcTransport, grpc_helpers), + (transports.EvaluationServiceGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_evaluation_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
+ with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "aiplatform.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=["1", "2"], + default_host="aiplatform.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.EvaluationServiceGrpcTransport, + transports.EvaluationServiceGrpcAsyncIOTransport, + ], +) +def test_evaluation_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. 
+ with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_evaluation_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.EvaluationServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_evaluation_service_host_no_port(transport_name): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "aiplatform.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://aiplatform.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_evaluation_service_host_with_port(transport_name): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "aiplatform.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://aiplatform.googleapis.com:8000" + ) + + 
+@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_evaluation_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = EvaluationServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = EvaluationServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.evaluate_instances._session + session2 = client2.transport.evaluate_instances._session + assert session1 != session2 + + +def test_evaluation_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.EvaluationServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_evaluation_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.EvaluationServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. 
+@pytest.mark.parametrize( + "transport_class", + [ + transports.EvaluationServiceGrpcTransport, + transports.EvaluationServiceGrpcAsyncIOTransport, + ], +) +def test_evaluation_service_transport_channel_mtls_with_client_cert_source( + transport_class, +): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. 
+@pytest.mark.parametrize( + "transport_class", + [ + transports.EvaluationServiceGrpcTransport, + transports.EvaluationServiceGrpcAsyncIOTransport, + ], +) +def test_evaluation_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_common_billing_account_path(): + billing_account = "squid" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = EvaluationServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + } + path = EvaluationServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. 
+ actual = EvaluationServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "whelk" + expected = "folders/{folder}".format( + folder=folder, + ) + actual = EvaluationServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + } + path = EvaluationServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = EvaluationServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "oyster" + expected = "organizations/{organization}".format( + organization=organization, + ) + actual = EvaluationServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + } + path = EvaluationServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = EvaluationServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "cuttlefish" + expected = "projects/{project}".format( + project=project, + ) + actual = EvaluationServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + } + path = EvaluationServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. 
+ actual = EvaluationServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + actual = EvaluationServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + } + path = EvaluationServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = EvaluationServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.EvaluationServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.EvaluationServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = EvaluationServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_get_location_rest_bad_request( + transport: str = "rest", request_type=locations_pb2.GetLocationRequest +): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), 
+ transport=transport, + ) + + request = request_type() + request = json_format.ParseDict( + {"name": "projects/sample1/locations/sample2"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_location(request) + + +@pytest.mark.parametrize( + "request_type", + [ + locations_pb2.GetLocationRequest, + dict, + ], +) +def test_get_location_rest(request_type): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = {"name": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = locations_pb2.Location() + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_location(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, locations_pb2.Location) + + +def test_list_locations_rest_bad_request( + transport: str = "rest", request_type=locations_pb2.ListLocationsRequest +): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + request = request_type() + request = json_format.ParseDict({"name": "projects/sample1"}, request) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_locations(request) + + +@pytest.mark.parametrize( + "request_type", + [ + locations_pb2.ListLocationsRequest, + dict, + ], +) +def test_list_locations_rest(request_type): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = {"name": "projects/sample1"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = locations_pb2.ListLocationsResponse() + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_locations(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, locations_pb2.ListLocationsResponse) + + +def test_get_iam_policy_rest_bad_request( + transport: str = "rest", request_type=iam_policy_pb2.GetIamPolicyRequest +): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + request = request_type() + request = json_format.ParseDict( + {"resource": "projects/sample1/locations/sample2/featurestores/sample3"}, + request, + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_iam_policy(request) + + +@pytest.mark.parametrize( + "request_type", + [ + iam_policy_pb2.GetIamPolicyRequest, + dict, + ], +) +def test_get_iam_policy_rest(request_type): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = { + "resource": "projects/sample1/locations/sample2/featurestores/sample3" + } + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = policy_pb2.Policy() + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_iam_policy(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, policy_pb2.Policy) + + +def test_set_iam_policy_rest_bad_request( + transport: str = "rest", request_type=iam_policy_pb2.SetIamPolicyRequest +): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + request = request_type() + request = json_format.ParseDict( + {"resource": "projects/sample1/locations/sample2/featurestores/sample3"}, + request, + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.set_iam_policy(request) + + +@pytest.mark.parametrize( + "request_type", + [ + iam_policy_pb2.SetIamPolicyRequest, + dict, + ], +) +def test_set_iam_policy_rest(request_type): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = { + "resource": "projects/sample1/locations/sample2/featurestores/sample3" + } + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = policy_pb2.Policy() + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.set_iam_policy(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, policy_pb2.Policy) + + +def test_test_iam_permissions_rest_bad_request( + transport: str = "rest", request_type=iam_policy_pb2.TestIamPermissionsRequest +): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + request = request_type() + request = json_format.ParseDict( + {"resource": "projects/sample1/locations/sample2/featurestores/sample3"}, + request, + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.test_iam_permissions(request) + + +@pytest.mark.parametrize( + "request_type", + [ + iam_policy_pb2.TestIamPermissionsRequest, + dict, + ], +) +def test_test_iam_permissions_rest(request_type): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = { + "resource": "projects/sample1/locations/sample2/featurestores/sample3" + } + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = iam_policy_pb2.TestIamPermissionsResponse() + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.test_iam_permissions(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, iam_policy_pb2.TestIamPermissionsResponse) + + +def test_cancel_operation_rest_bad_request( + transport: str = "rest", request_type=operations_pb2.CancelOperationRequest +): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + request = request_type() + request = json_format.ParseDict( + {"name": "projects/sample1/locations/sample2/operations/sample3"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.cancel_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.CancelOperationRequest, + dict, + ], +) +def test_cancel_operation_rest(request_type): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = {"name": "projects/sample1/locations/sample2/operations/sample3"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "{}" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.cancel_operation(request) + + # Establish that the response is the type that we expect. 
+ assert response is None + + +def test_delete_operation_rest_bad_request( + transport: str = "rest", request_type=operations_pb2.DeleteOperationRequest +): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + request = request_type() + request = json_format.ParseDict( + {"name": "projects/sample1/locations/sample2/operations/sample3"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.delete_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.DeleteOperationRequest, + dict, + ], +) +def test_delete_operation_rest(request_type): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = {"name": "projects/sample1/locations/sample2/operations/sample3"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "{}" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.delete_operation(request) + + # Establish that the response is the type that we expect. 
+ assert response is None + + +def test_get_operation_rest_bad_request( + transport: str = "rest", request_type=operations_pb2.GetOperationRequest +): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + request = request_type() + request = json_format.ParseDict( + {"name": "projects/sample1/locations/sample2/operations/sample3"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.GetOperationRequest, + dict, + ], +) +def test_get_operation_rest(request_type): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = {"name": "projects/sample1/locations/sample2/operations/sample3"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation() + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_operation(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, operations_pb2.Operation) + + +def test_list_operations_rest_bad_request( + transport: str = "rest", request_type=operations_pb2.ListOperationsRequest +): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + request = request_type() + request = json_format.ParseDict( + {"name": "projects/sample1/locations/sample2"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_operations(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.ListOperationsRequest, + dict, + ], +) +def test_list_operations_rest(request_type): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = {"name": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.ListOperationsResponse() + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_operations(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_wait_operation_rest_bad_request( + transport: str = "rest", request_type=operations_pb2.WaitOperationRequest +): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + request = request_type() + request = json_format.ParseDict( + {"name": "projects/sample1/locations/sample2/operations/sample3"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.wait_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.WaitOperationRequest, + dict, + ], +) +def test_wait_operation_rest(request_type): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = {"name": "projects/sample1/locations/sample2/operations/sample3"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation() + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.wait_operation(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, operations_pb2.Operation) + + +def test_delete_operation(transport: str = "grpc"): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.DeleteOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.delete_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_operation_async(transport: str = "grpc_asyncio"): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.DeleteOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert response is None + + +def test_delete_operation_field_headers(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.DeleteOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + call.return_value = None + + client.delete_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_operation_field_headers_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.DeleteOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_delete_operation_from_dict(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.delete_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_delete_operation_from_dict_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_cancel_operation(transport: str = "grpc"): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.CancelOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.cancel_operation(request) + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_operation_async(transport: str = "grpc_asyncio"): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.CancelOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.cancel_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_operation_field_headers(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.CancelOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + call.return_value = None + + client.cancel_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_cancel_operation_field_headers_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.CancelOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.cancel_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_cancel_operation_from_dict(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_cancel_operation_from_dict_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.cancel_operation), "__call__") as call: + # Designate an appropriate return value for the call. 
+    call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+    response = await client.cancel_operation(
+        request={
+            "name": "locations",
+        }
+    )
+    call.assert_called()
+
+
+def test_wait_operation(transport: str = "grpc"):
+    client = EvaluationServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = operations_pb2.WaitOperationRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.wait_operation), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation()
+        response = client.wait_operation(request)
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, operations_pb2.Operation)
+
+
+@pytest.mark.asyncio
+async def test_wait_operation_async(transport: str = "grpc_asyncio"):
+    client = EvaluationServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = operations_pb2.WaitOperationRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.wait_operation), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation()
+        )
+        response = await client.wait_operation(request)
+        # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_wait_operation_field_headers(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.WaitOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.wait_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_wait_operation_field_headers_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.WaitOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.wait_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_wait_operation_from_dict(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.wait_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_wait_operation_from_dict_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.wait_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.wait_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_get_operation(transport: str = "grpc"): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.get_operation(request) + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +@pytest.mark.asyncio +async def test_get_operation_async(transport: str = "grpc_asyncio"): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_get_operation_field_headers(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.get_operation(request) + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_operation_field_headers_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_get_operation_from_dict(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_operation_from_dict_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_operations(transport: str = "grpc"): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + response = client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +@pytest.mark.asyncio +async def test_list_operations_async(transport: str = "grpc_asyncio"): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_list_operations_field_headers(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = operations_pb2.ListOperationsResponse() + + client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_operations_field_headers_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_operations_from_dict(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + + response = client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_operations_from_dict_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_locations(transport: str = "grpc"): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+ request = locations_pb2.ListLocationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = locations_pb2.ListLocationsResponse() + response = client.list_locations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, locations_pb2.ListLocationsResponse) + + +@pytest.mark.asyncio +async def test_list_locations_async(transport: str = "grpc_asyncio"): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = locations_pb2.ListLocationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.ListLocationsResponse() + ) + response = await client.list_locations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, locations_pb2.ListLocationsResponse) + + +def test_list_locations_field_headers(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = locations_pb2.ListLocationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + call.return_value = locations_pb2.ListLocationsResponse() + + client.list_locations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_locations_field_headers_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = locations_pb2.ListLocationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.ListLocationsResponse() + ) + await client.list_locations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_locations_from_dict(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = locations_pb2.ListLocationsResponse() + + response = client.list_locations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_locations_from_dict_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_locations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.ListLocationsResponse() + ) + response = await client.list_locations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_get_location(transport: str = "grpc"): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = locations_pb2.GetLocationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_location), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = locations_pb2.Location() + response = client.get_location(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, locations_pb2.Location) + + +@pytest.mark.asyncio +async def test_get_location_async(transport: str = "grpc_asyncio"): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = locations_pb2.GetLocationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_location), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + locations_pb2.Location() + ) + response = await client.get_location(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, locations_pb2.Location) + + +def test_get_location_field_headers(): + client = EvaluationServiceClient(credentials=ga_credentials.AnonymousCredentials()) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = locations_pb2.GetLocationRequest() + request.name = "locations/abc" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_location), "__call__") as call: + call.return_value = locations_pb2.Location() + + client.get_location(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+        _, _, kw = call.mock_calls[0]
+        assert (
+            "x-goog-request-params",
+            "name=locations/abc",
+        ) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_get_location_field_headers_async():
+    client = EvaluationServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials()
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = locations_pb2.GetLocationRequest()
+    request.name = "locations/abc"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.get_location), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            locations_pb2.Location()
+        )
+        await client.get_location(request)
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+        # Establish that the field header was sent.
+        _, _, kw = call.mock_calls[0]
+        assert (
+            "x-goog-request-params",
+            "name=locations/abc",
+        ) in kw["metadata"]
+
+
+def test_get_location_from_dict():
+    client = EvaluationServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.get_location), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = locations_pb2.Location()
+
+        response = client.get_location(
+            request={
+                "name": "locations/abc",
+            }
+        )
+        call.assert_called()
+
+
+@pytest.mark.asyncio
+async def test_get_location_from_dict_async():
+    client = EvaluationServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.get_location), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            locations_pb2.Location()
+        )
+        response = await client.get_location(
+            request={
+                "name": "locations/abc",
+            }
+        )
+        call.assert_called()
+
+
+def test_set_iam_policy(transport: str = "grpc"):
+    client = EvaluationServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = iam_policy_pb2.SetIamPolicyRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = policy_pb2.Policy(
+            version=774,
+            etag=b"etag_blob",
+        )
+        response = client.set_iam_policy(request)
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, policy_pb2.Policy)
+
+    assert response.version == 774
+
+    assert response.etag == b"etag_blob"
+
+
+@pytest.mark.asyncio
+async def test_set_iam_policy_async(transport: str = "grpc_asyncio"):
+    client = EvaluationServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = iam_policy_pb2.SetIamPolicyRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + policy_pb2.Policy( + version=774, + etag=b"etag_blob", + ) + ) + response = await client.set_iam_policy(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, policy_pb2.Policy) + + assert response.version == 774 + + assert response.etag == b"etag_blob" + + +def test_set_iam_policy_field_headers(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.SetIamPolicyRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + call.return_value = policy_pb2.Policy() + + client.set_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_set_iam_policy_field_headers_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.SetIamPolicyRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(policy_pb2.Policy()) + + await client.set_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +def test_set_iam_policy_from_dict(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = policy_pb2.Policy() + + response = client.set_iam_policy( + request={ + "resource": "resource_value", + "policy": policy_pb2.Policy(version=774), + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_set_iam_policy_from_dict_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.set_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(policy_pb2.Policy()) + + response = await client.set_iam_policy( + request={ + "resource": "resource_value", + "policy": policy_pb2.Policy(version=774), + } + ) + call.assert_called() + + +def test_get_iam_policy(transport: str = "grpc"): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.GetIamPolicyRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = policy_pb2.Policy( + version=774, + etag=b"etag_blob", + ) + + response = client.get_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, policy_pb2.Policy) + + assert response.version == 774 + + assert response.etag == b"etag_blob" + + +@pytest.mark.asyncio +async def test_get_iam_policy_async(transport: str = "grpc_asyncio"): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.GetIamPolicyRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + policy_pb2.Policy( + version=774, + etag=b"etag_blob", + ) + ) + + response = await client.get_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, policy_pb2.Policy) + + assert response.version == 774 + + assert response.etag == b"etag_blob" + + +def test_get_iam_policy_field_headers(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.GetIamPolicyRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + call.return_value = policy_pb2.Policy() + + client.get_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_iam_policy_field_headers_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.GetIamPolicyRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(policy_pb2.Policy()) + + await client.get_iam_policy(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +def test_get_iam_policy_from_dict(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = policy_pb2.Policy() + + response = client.get_iam_policy( + request={ + "resource": "resource_value", + "options": options_pb2.GetPolicyOptions(requested_policy_version=2598), + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_iam_policy_from_dict_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_iam_policy), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(policy_pb2.Policy()) + + response = await client.get_iam_policy( + request={ + "resource": "resource_value", + "options": options_pb2.GetPolicyOptions(requested_policy_version=2598), + } + ) + call.assert_called() + + +def test_test_iam_permissions(transport: str = "grpc"): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.TestIamPermissionsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = iam_policy_pb2.TestIamPermissionsResponse( + permissions=["permissions_value"], + ) + + response = client.test_iam_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, iam_policy_pb2.TestIamPermissionsResponse) + + assert response.permissions == ["permissions_value"] + + +@pytest.mark.asyncio +async def test_test_iam_permissions_async(transport: str = "grpc_asyncio"): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = iam_policy_pb2.TestIamPermissionsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + iam_policy_pb2.TestIamPermissionsResponse( + permissions=["permissions_value"], + ) + ) + + response = await client.test_iam_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, iam_policy_pb2.TestIamPermissionsResponse) + + assert response.permissions == ["permissions_value"] + + +def test_test_iam_permissions_field_headers(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = iam_policy_pb2.TestIamPermissionsRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + call.return_value = iam_policy_pb2.TestIamPermissionsResponse() + + client.test_iam_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_test_iam_permissions_field_headers_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = iam_policy_pb2.TestIamPermissionsRequest() + request.resource = "resource/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + iam_policy_pb2.TestIamPermissionsResponse() + ) + + await client.test_iam_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "resource=resource/value", + ) in kw["metadata"] + + +def test_test_iam_permissions_from_dict(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = iam_policy_pb2.TestIamPermissionsResponse() + + response = client.test_iam_permissions( + request={ + "resource": "resource_value", + "permissions": ["permissions_value"], + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_test_iam_permissions_from_dict_async(): + client = EvaluationServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.test_iam_permissions), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + iam_policy_pb2.TestIamPermissionsResponse() + ) + + response = await client.test_iam_permissions( + request={ + "resource": "resource_value", + "permissions": ["permissions_value"], + } + ) + call.assert_called() + + +def test_transport_close(): + transports = { + "rest": "_session", + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "rest", + "grpc", + ] + for transport in transports: + client = EvaluationServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (EvaluationServiceClient, transports.EvaluationServiceGrpcTransport), + ( + EvaluationServiceAsyncClient, + transports.EvaluationServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + 
UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/tests/unit/gapic/aiplatform_v1/test_feature_online_store_admin_service.py b/tests/unit/gapic/aiplatform_v1/test_feature_online_store_admin_service.py index bd4d3b52fa..fbb4779996 100644 --- a/tests/unit/gapic/aiplatform_v1/test_feature_online_store_admin_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_feature_online_store_admin_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.feature_online_store_admin_service import ( @@ -2556,12 +2557,18 @@ def test_list_feature_online_stores_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_feature_online_stores(request={}) + pager = client.list_feature_online_stores( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4807,12 +4814,16 @@ def test_list_feature_views_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_feature_views(request={}) + pager = client.list_feature_views(request={}, retry=retry, timeout=timeout) assert pager._metadata == 
expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -6994,12 +7005,16 @@ def test_list_feature_view_syncs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_feature_view_syncs(request={}) + pager = client.list_feature_view_syncs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1/test_feature_online_store_service.py b/tests/unit/gapic/aiplatform_v1/test_feature_online_store_service.py index 6f1fc60387..346ca71070 100644 --- a/tests/unit/gapic/aiplatform_v1/test_feature_online_store_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_feature_online_store_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.feature_online_store_service import ( diff --git a/tests/unit/gapic/aiplatform_v1/test_feature_registry_service.py b/tests/unit/gapic/aiplatform_v1/test_feature_registry_service.py index dd961c5f03..b86125db31 100644 --- a/tests/unit/gapic/aiplatform_v1/test_feature_registry_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_feature_registry_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from 
google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.feature_registry_service import ( @@ -2476,12 +2477,16 @@ def test_list_feature_groups_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_feature_groups(request={}) + pager = client.list_feature_groups(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4634,12 +4639,16 @@ def test_list_features_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_features(request={}) + pager = client.list_features(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1/test_featurestore_online_serving_service.py b/tests/unit/gapic/aiplatform_v1/test_featurestore_online_serving_service.py index b63ba8867e..812122ad0e 100644 --- a/tests/unit/gapic/aiplatform_v1/test_featurestore_online_serving_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_featurestore_online_serving_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.featurestore_online_serving_service import ( diff --git 
a/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py b/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py index f03a0772de..e326c9e6cb 100644 --- a/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_featurestore_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.featurestore_service import ( @@ -2440,12 +2441,16 @@ def test_list_featurestores_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_featurestores(request={}) + pager = client.list_featurestores(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4608,12 +4613,16 @@ def test_list_entity_types_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_entity_types(request={}) + pager = client.list_entity_types(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -7155,12 +7164,16 @@ def test_list_features_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( 
gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_features(request={}) + pager = client.list_features(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -10042,12 +10055,16 @@ def test_search_features_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("location", ""),)), ) - pager = client.search_features(request={}) + pager = client.search_features(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1/test_gen_ai_tuning_service.py b/tests/unit/gapic/aiplatform_v1/test_gen_ai_tuning_service.py index e0e20bf3e6..04bfa104eb 100644 --- a/tests/unit/gapic/aiplatform_v1/test_gen_ai_tuning_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_gen_ai_tuning_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.gen_ai_tuning_service import ( @@ -2392,12 +2393,16 @@ def test_list_tuning_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_tuning_jobs(request={}) + pager = client.list_tuning_jobs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + 
assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1/test_index_endpoint_service.py b/tests/unit/gapic/aiplatform_v1/test_index_endpoint_service.py index 80ad6f24aa..83c655a389 100644 --- a/tests/unit/gapic/aiplatform_v1/test_index_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_index_endpoint_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.index_endpoint_service import ( @@ -64,6 +65,7 @@ from google.cloud.aiplatform_v1.types import index_endpoint_service from google.cloud.aiplatform_v1.types import machine_resources from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import reservation_affinity from google.cloud.aiplatform_v1.types import service_networking from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore @@ -2461,12 +2463,16 @@ def test_list_index_endpoints_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_index_endpoints(request={}) + pager = client.list_index_endpoints(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4633,12 +4639,18 @@ def test_create_index_endpoint_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + 
"reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "min_replica_count": 1803, "max_replica_count": 1805, "autoscaling_metric_specs": [ {"metric_name": "metric_name_value", "target": 647} ], + "spot": True, }, "enable_access_logging": True, "deployed_index_auth_config": { @@ -5811,12 +5823,18 @@ def test_update_index_endpoint_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "min_replica_count": 1803, "max_replica_count": 1805, "autoscaling_metric_specs": [ {"metric_name": "metric_name_value", "target": 647} ], + "spot": True, }, "enable_access_logging": True, "deployed_index_auth_config": { @@ -7229,12 +7247,18 @@ def test_mutate_deployed_index_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "min_replica_count": 1803, "max_replica_count": 1805, "autoscaling_metric_specs": [ {"metric_name": "metric_name_value", "target": 647} ], + "spot": True, }, "enable_access_logging": True, "deployed_index_auth_config": { @@ -8302,8 +8326,36 @@ def test_parse_index_endpoint_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "squid" + zone = "clam" + reservation_name = "whelk" + expected = "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = IndexEndpointServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + 
"project_id_or_number": "octopus", + "zone": "oyster", + "reservation_name": "nudibranch", + } + path = IndexEndpointServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. + actual = IndexEndpointServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_common_billing_account_path(): - billing_account = "squid" + billing_account = "cuttlefish" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -8313,7 +8365,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", + "billing_account": "mussel", } path = IndexEndpointServiceClient.common_billing_account_path(**expected) @@ -8323,7 +8375,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "whelk" + folder = "winkle" expected = "folders/{folder}".format( folder=folder, ) @@ -8333,7 +8385,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "octopus", + "folder": "nautilus", } path = IndexEndpointServiceClient.common_folder_path(**expected) @@ -8343,7 +8395,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "oyster" + organization = "scallop" expected = "organizations/{organization}".format( organization=organization, ) @@ -8353,7 +8405,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", + "organization": "abalone", } path = IndexEndpointServiceClient.common_organization_path(**expected) @@ -8363,7 +8415,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "cuttlefish" + project = "squid" expected = "projects/{project}".format( project=project, ) @@ -8373,7 +8425,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "mussel", + "project": "clam", } 
path = IndexEndpointServiceClient.common_project_path(**expected) @@ -8383,8 +8435,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "winkle" - location = "nautilus" + project = "whelk" + location = "octopus" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -8395,8 +8447,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", + "project": "oyster", + "location": "nudibranch", } path = IndexEndpointServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1/test_index_service.py b/tests/unit/gapic/aiplatform_v1/test_index_service.py index 595fb98c65..eb67bdcf24 100644 --- a/tests/unit/gapic/aiplatform_v1/test_index_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_index_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.index_service import IndexServiceAsyncClient @@ -2277,12 +2278,16 @@ def test_list_indexes_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_indexes(request={}) + pager = client.list_indexes(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1/test_job_service.py b/tests/unit/gapic/aiplatform_v1/test_job_service.py index a04938cb6c..e04279d057 100644 --- 
a/tests/unit/gapic/aiplatform_v1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_job_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.job_service import JobServiceAsyncClient @@ -85,6 +86,7 @@ from google.cloud.aiplatform_v1.types import nas_job from google.cloud.aiplatform_v1.types import nas_job as gca_nas_job from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import reservation_affinity from google.cloud.aiplatform_v1.types import study from google.cloud.aiplatform_v1.types import unmanaged_container_model from google.cloud.location import locations_pb2 @@ -2318,12 +2320,16 @@ def test_list_custom_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_custom_jobs(request={}) + pager = client.list_custom_jobs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4521,12 +4527,16 @@ def test_list_data_labeling_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_data_labeling_jobs(request={}) + pager = client.list_data_labeling_jobs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert 
pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -6719,12 +6729,18 @@ def test_list_hyperparameter_tuning_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_hyperparameter_tuning_jobs(request={}) + pager = client.list_hyperparameter_tuning_jobs( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -8809,12 +8825,16 @@ def test_list_nas_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_nas_jobs(request={}) + pager = client.list_nas_jobs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -10493,12 +10513,16 @@ def test_list_nas_trial_details_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_nas_trial_details(request={}) + pager = client.list_nas_trial_details(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -11947,12 +11971,18 @@ def test_list_batch_prediction_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( 
gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_batch_prediction_jobs(request={}) + pager = client.list_batch_prediction_jobs( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -13812,14 +13842,20 @@ def test_search_model_deployment_monitoring_stats_anomalies_pager( ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata( (("model_deployment_monitoring_job", ""),) ), ) - pager = client.search_model_deployment_monitoring_stats_anomalies(request={}) + pager = client.search_model_deployment_monitoring_stats_anomalies( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -14858,12 +14894,18 @@ def test_list_model_deployment_monitoring_jobs_pager(transport_name: str = "grpc ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_model_deployment_monitoring_jobs(request={}) + pager = client.list_model_deployment_monitoring_jobs( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -16600,6 +16642,11 @@ def test_create_custom_job_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "replica_count": 1384, "nfs_mounts": [ @@ -20194,6 +20241,11 @@ def 
test_create_hyperparameter_tuning_job_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "replica_count": 1384, "nfs_mounts": [ @@ -22076,6 +22128,11 @@ def test_create_nas_job_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "replica_count": 1384, "nfs_mounts": [ @@ -24570,6 +24627,11 @@ def test_create_batch_prediction_job_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "starting_replica_count": 2355, "max_replica_count": 1805, @@ -30665,10 +30727,38 @@ def test_parse_persistent_resource_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "squid" + zone = "clam" + reservation_name = "whelk" + expected = "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = JobServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "octopus", + "zone": "oyster", + "reservation_name": "nudibranch", + } + path = JobServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. 
+ actual = JobServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_tensorboard_path(): - project = "squid" - location = "clam" - tensorboard = "whelk" + project = "cuttlefish" + location = "mussel" + tensorboard = "winkle" expected = ( "projects/{project}/locations/{location}/tensorboards/{tensorboard}".format( project=project, @@ -30682,9 +30772,9 @@ def test_tensorboard_path(): def test_parse_tensorboard_path(): expected = { - "project": "octopus", - "location": "oyster", - "tensorboard": "nudibranch", + "project": "nautilus", + "location": "scallop", + "tensorboard": "abalone", } path = JobServiceClient.tensorboard_path(**expected) @@ -30694,10 +30784,10 @@ def test_parse_tensorboard_path(): def test_trial_path(): - project = "cuttlefish" - location = "mussel" - study = "winkle" - trial = "nautilus" + project = "squid" + location = "clam" + study = "whelk" + trial = "octopus" expected = ( "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( project=project, @@ -30712,10 +30802,10 @@ def test_trial_path(): def test_parse_trial_path(): expected = { - "project": "scallop", - "location": "abalone", - "study": "squid", - "trial": "clam", + "project": "oyster", + "location": "nudibranch", + "study": "cuttlefish", + "trial": "mussel", } path = JobServiceClient.trial_path(**expected) @@ -30725,7 +30815,7 @@ def test_parse_trial_path(): def test_common_billing_account_path(): - billing_account = "whelk" + billing_account = "winkle" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -30735,7 +30825,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "octopus", + "billing_account": "nautilus", } path = JobServiceClient.common_billing_account_path(**expected) @@ -30745,7 +30835,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "oyster" + folder = "scallop" 
expected = "folders/{folder}".format( folder=folder, ) @@ -30755,7 +30845,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "nudibranch", + "folder": "abalone", } path = JobServiceClient.common_folder_path(**expected) @@ -30765,7 +30855,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "cuttlefish" + organization = "squid" expected = "organizations/{organization}".format( organization=organization, ) @@ -30775,7 +30865,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "mussel", + "organization": "clam", } path = JobServiceClient.common_organization_path(**expected) @@ -30785,7 +30875,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "winkle" + project = "whelk" expected = "projects/{project}".format( project=project, ) @@ -30795,7 +30885,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "nautilus", + "project": "octopus", } path = JobServiceClient.common_project_path(**expected) @@ -30805,8 +30895,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "scallop" - location = "abalone" + project = "oyster" + location = "nudibranch" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -30817,8 +30907,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "squid", - "location": "clam", + "project": "cuttlefish", + "location": "mussel", } path = JobServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1/test_llm_utility_service.py b/tests/unit/gapic/aiplatform_v1/test_llm_utility_service.py index af206d66b6..e432612e73 100644 --- a/tests/unit/gapic/aiplatform_v1/test_llm_utility_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_llm_utility_service.py @@ -43,6 
+43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.llm_utility_service import ( diff --git a/tests/unit/gapic/aiplatform_v1/test_match_service.py b/tests/unit/gapic/aiplatform_v1/test_match_service.py index f4acd34ac0..1bf49df1da 100644 --- a/tests/unit/gapic/aiplatform_v1/test_match_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_match_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.match_service import MatchServiceAsyncClient diff --git a/tests/unit/gapic/aiplatform_v1/test_metadata_service.py b/tests/unit/gapic/aiplatform_v1/test_metadata_service.py index 174a643e03..0a9dfe4638 100644 --- a/tests/unit/gapic/aiplatform_v1/test_metadata_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_metadata_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.metadata_service import ( @@ -2403,12 +2404,16 @@ def test_list_metadata_stores_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = 
client.list_metadata_stores(request={}) + pager = client.list_metadata_stores(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4173,12 +4178,16 @@ def test_list_artifacts_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_artifacts(request={}) + pager = client.list_artifacts(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -6672,12 +6681,16 @@ def test_list_contexts_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_contexts(request={}) + pager = client.list_contexts(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -10732,12 +10745,16 @@ def test_list_executions_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_executions(request={}) + pager = client.list_executions(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -14075,12 +14092,16 @@ def test_list_metadata_schemas_pager(transport_name: str = "grpc"): ) 
expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_metadata_schemas(request={}) + pager = client.list_metadata_schemas(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index 4b3143f2c0..c94733990a 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.migration_service import ( @@ -1593,12 +1594,18 @@ def test_search_migratable_resources_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.search_migratable_resources(request={}) + pager = client.search_migratable_resources( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -3549,19 +3556,22 @@ def test_parse_dataset_path(): def test_dataset_path(): project = "squid" - dataset = "clam" - expected = "projects/{project}/datasets/{dataset}".format( + location = "clam" + dataset = "whelk" + expected = 
"projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "whelk", - "dataset": "octopus", + "project": "octopus", + "location": "oyster", + "dataset": "nudibranch", } path = MigrationServiceClient.dataset_path(**expected) @@ -3571,22 +3581,19 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "oyster" - location = "nudibranch" - dataset = "cuttlefish" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project = "cuttlefish" + dataset = "mussel" + expected = "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "mussel", - "location": "winkle", + "project": "winkle", "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1/test_model_garden_service.py b/tests/unit/gapic/aiplatform_v1/test_model_garden_service.py index 8475441c28..f51ac200dc 100644 --- a/tests/unit/gapic/aiplatform_v1/test_model_garden_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_model_garden_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.model_garden_service import ( @@ -2512,8 +2513,36 @@ def 
test_parse_publisher_model_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "oyster" + zone = "nudibranch" + reservation_name = "cuttlefish" + expected = "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = ModelGardenServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "mussel", + "zone": "winkle", + "reservation_name": "nautilus", + } + path = ModelGardenServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. + actual = ModelGardenServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_common_billing_account_path(): - billing_account = "oyster" + billing_account = "scallop" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -2523,7 +2552,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "nudibranch", + "billing_account": "abalone", } path = ModelGardenServiceClient.common_billing_account_path(**expected) @@ -2533,7 +2562,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "cuttlefish" + folder = "squid" expected = "folders/{folder}".format( folder=folder, ) @@ -2543,7 +2572,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "mussel", + "folder": "clam", } path = ModelGardenServiceClient.common_folder_path(**expected) @@ -2553,7 +2582,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "winkle" + organization = "whelk" expected = "organizations/{organization}".format( organization=organization, ) @@ -2563,7 +2592,7 @@ def 
test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "nautilus", + "organization": "octopus", } path = ModelGardenServiceClient.common_organization_path(**expected) @@ -2573,7 +2602,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "scallop" + project = "oyster" expected = "projects/{project}".format( project=project, ) @@ -2583,7 +2612,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "abalone", + "project": "nudibranch", } path = ModelGardenServiceClient.common_project_path(**expected) @@ -2593,8 +2622,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "squid" - location = "clam" + project = "cuttlefish" + location = "mussel" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -2605,8 +2634,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "whelk", - "location": "octopus", + "project": "winkle", + "location": "nautilus", } path = ModelGardenServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1/test_model_service.py b/tests/unit/gapic/aiplatform_v1/test_model_service.py index 2909a6890e..d5ab2feeda 100644 --- a/tests/unit/gapic/aiplatform_v1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_model_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.model_service import ModelServiceAsyncClient @@ -2376,12 +2377,16 @@ def test_list_models_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() 
+ timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_models(request={}) + pager = client.list_models(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -2960,12 +2965,16 @@ def test_list_model_versions_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("name", ""),)), ) - pager = client.list_model_versions(request={}) + pager = client.list_model_versions(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -8096,12 +8105,16 @@ def test_list_model_evaluations_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_model_evaluations(request={}) + pager = client.list_model_evaluations(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -9079,12 +9092,18 @@ def test_list_model_evaluation_slices_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_model_evaluation_slices(request={}) + pager = client.list_model_evaluation_slices( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry 
== retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1/test_notebook_service.py b/tests/unit/gapic/aiplatform_v1/test_notebook_service.py index ef708931cf..b6c8d609e7 100644 --- a/tests/unit/gapic/aiplatform_v1/test_notebook_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_notebook_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.notebook_service import ( @@ -71,6 +72,7 @@ from google.cloud.aiplatform_v1.types import notebook_runtime_template_ref from google.cloud.aiplatform_v1.types import notebook_service from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import reservation_affinity from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import options_pb2 # type: ignore @@ -2456,12 +2458,18 @@ def test_list_notebook_runtime_templates_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_notebook_runtime_templates(request={}) + pager = client.list_notebook_runtime_templates( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4767,12 +4775,16 @@ def test_list_notebook_runtimes_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = 
tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_notebook_runtimes(request={}) + pager = client.list_notebook_runtimes(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -7363,12 +7375,18 @@ def test_list_notebook_execution_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_notebook_execution_jobs(request={}) + pager = client.list_notebook_execution_jobs( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -7939,6 +7957,11 @@ def test_create_notebook_runtime_template_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "data_persistent_disk_spec": { "disk_type": "disk_type_value", @@ -9431,6 +9454,11 @@ def test_update_notebook_runtime_template_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "data_persistent_disk_spec": { "disk_type": "disk_type_value", @@ -14083,10 +14111,38 @@ def test_parse_notebook_runtime_template_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "scallop" + zone = "abalone" + reservation_name = "squid" + expected = 
"projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = NotebookServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "clam", + "zone": "whelk", + "reservation_name": "octopus", + } + path = NotebookServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. + actual = NotebookServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_schedule_path(): - project = "scallop" - location = "abalone" - schedule = "squid" + project = "oyster" + location = "nudibranch" + schedule = "cuttlefish" expected = "projects/{project}/locations/{location}/schedules/{schedule}".format( project=project, location=location, @@ -14098,9 +14154,9 @@ def test_schedule_path(): def test_parse_schedule_path(): expected = { - "project": "clam", - "location": "whelk", - "schedule": "octopus", + "project": "mussel", + "location": "winkle", + "schedule": "nautilus", } path = NotebookServiceClient.schedule_path(**expected) @@ -14110,9 +14166,9 @@ def test_parse_schedule_path(): def test_subnetwork_path(): - project = "oyster" - region = "nudibranch" - subnetwork = "cuttlefish" + project = "scallop" + region = "abalone" + subnetwork = "squid" expected = "projects/{project}/regions/{region}/subnetworks/{subnetwork}".format( project=project, region=region, @@ -14124,9 +14180,9 @@ def test_subnetwork_path(): def test_parse_subnetwork_path(): expected = { - "project": "mussel", - "region": "winkle", - "subnetwork": "nautilus", + "project": "clam", + "region": "whelk", + "subnetwork": "octopus", } path = NotebookServiceClient.subnetwork_path(**expected) @@ -14136,7 +14192,7 @@ def test_parse_subnetwork_path(): def test_common_billing_account_path(): - billing_account 
= "scallop" + billing_account = "oyster" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -14146,7 +14202,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "abalone", + "billing_account": "nudibranch", } path = NotebookServiceClient.common_billing_account_path(**expected) @@ -14156,7 +14212,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "squid" + folder = "cuttlefish" expected = "folders/{folder}".format( folder=folder, ) @@ -14166,7 +14222,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "clam", + "folder": "mussel", } path = NotebookServiceClient.common_folder_path(**expected) @@ -14176,7 +14232,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "whelk" + organization = "winkle" expected = "organizations/{organization}".format( organization=organization, ) @@ -14186,7 +14242,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "octopus", + "organization": "nautilus", } path = NotebookServiceClient.common_organization_path(**expected) @@ -14196,7 +14252,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "oyster" + project = "scallop" expected = "projects/{project}".format( project=project, ) @@ -14206,7 +14262,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "nudibranch", + "project": "abalone", } path = NotebookServiceClient.common_project_path(**expected) @@ -14216,8 +14272,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "cuttlefish" - location = "mussel" + project = "squid" + location = "clam" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -14228,8 +14284,8 @@ 
def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "winkle", - "location": "nautilus", + "project": "whelk", + "location": "octopus", } path = NotebookServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1/test_persistent_resource_service.py b/tests/unit/gapic/aiplatform_v1/test_persistent_resource_service.py index c1820d952a..436b8a9d84 100644 --- a/tests/unit/gapic/aiplatform_v1/test_persistent_resource_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_persistent_resource_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.persistent_resource_service import ( @@ -66,6 +67,7 @@ persistent_resource as gca_persistent_resource, ) from google.cloud.aiplatform_v1.types import persistent_resource_service +from google.cloud.aiplatform_v1.types import reservation_affinity from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import options_pb2 # type: ignore @@ -2491,12 +2493,18 @@ def test_list_persistent_resources_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_persistent_resources(request={}) + pager = client.list_persistent_resources( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -3857,6 +3865,11 @@ def test_create_persistent_resource_rest(request_type): 
"accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "replica_count": 1384, "disk_spec": { @@ -5371,6 +5384,11 @@ def test_update_persistent_resource_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "replica_count": 1384, "disk_spec": { @@ -6787,8 +6805,36 @@ def test_parse_persistent_resource_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "scallop" + zone = "abalone" + reservation_name = "squid" + expected = "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = PersistentResourceServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "clam", + "zone": "whelk", + "reservation_name": "octopus", + } + path = PersistentResourceServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. 
+ actual = PersistentResourceServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_common_billing_account_path(): - billing_account = "scallop" + billing_account = "oyster" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -6800,7 +6846,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "abalone", + "billing_account": "nudibranch", } path = PersistentResourceServiceClient.common_billing_account_path(**expected) @@ -6810,7 +6856,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "squid" + folder = "cuttlefish" expected = "folders/{folder}".format( folder=folder, ) @@ -6820,7 +6866,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "clam", + "folder": "mussel", } path = PersistentResourceServiceClient.common_folder_path(**expected) @@ -6830,7 +6876,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "whelk" + organization = "winkle" expected = "organizations/{organization}".format( organization=organization, ) @@ -6840,7 +6886,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "octopus", + "organization": "nautilus", } path = PersistentResourceServiceClient.common_organization_path(**expected) @@ -6850,7 +6896,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "oyster" + project = "scallop" expected = "projects/{project}".format( project=project, ) @@ -6860,7 +6906,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "nudibranch", + "project": "abalone", } path = PersistentResourceServiceClient.common_project_path(**expected) @@ -6870,8 +6916,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "cuttlefish" 
- location = "mussel" + project = "squid" + location = "clam" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -6882,8 +6928,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "winkle", - "location": "nautilus", + "project": "whelk", + "location": "octopus", } path = PersistentResourceServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py index b4f387f2af..27dc2399b2 100644 --- a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.pipeline_service import ( @@ -2447,12 +2448,16 @@ def test_list_training_pipelines_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_training_pipelines(request={}) + pager = client.list_training_pipelines(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4653,12 +4658,16 @@ def test_list_pipeline_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_pipeline_jobs(request={}) + pager = 
client.list_pipeline_jobs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1/test_prediction_service.py index b11c9d47ef..3668ef84d6 100644 --- a/tests/unit/gapic/aiplatform_v1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_prediction_service.py @@ -44,6 +44,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.prediction_service import ( diff --git a/tests/unit/gapic/aiplatform_v1/test_schedule_service.py b/tests/unit/gapic/aiplatform_v1/test_schedule_service.py index 63e7e759aa..6851be1dc7 100644 --- a/tests/unit/gapic/aiplatform_v1/test_schedule_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_schedule_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.schedule_service import ( @@ -2750,12 +2751,16 @@ def test_list_schedules_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_schedules(request={}) + pager = client.list_schedules(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert 
pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py index 8e06d860c0..5dfb76821a 100644 --- a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.specialist_pool_service import ( @@ -2447,12 +2448,16 @@ def test_list_specialist_pools_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_specialist_pools(request={}) + pager = client.list_specialist_pools(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py b/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py index cded3dfb68..80d1769025 100644 --- a/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_tensorboard_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from 
google.cloud.aiplatform_v1.services.tensorboard_service import ( @@ -2821,12 +2822,16 @@ def test_list_tensorboards_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_tensorboards(request={}) + pager = client.list_tensorboards(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -5836,12 +5841,18 @@ def test_list_tensorboard_experiments_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_tensorboard_experiments(request={}) + pager = client.list_tensorboard_experiments( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -8459,12 +8470,16 @@ def test_list_tensorboard_runs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_tensorboard_runs(request={}) + pager = client.list_tensorboard_runs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -11171,12 +11186,18 @@ def test_list_tensorboard_time_series_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( 
gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_tensorboard_time_series(request={}) + pager = client.list_tensorboard_time_series( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -14168,14 +14189,20 @@ def test_export_tensorboard_time_series_data_pager(transport_name: str = "grpc") ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata( (("tensorboard_time_series", ""),) ), ) - pager = client.export_tensorboard_time_series_data(request={}) + pager = client.export_tensorboard_time_series_data( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1/test_vizier_service.py b/tests/unit/gapic/aiplatform_v1/test_vizier_service.py index aea6ad63d5..73c5a592df 100644 --- a/tests/unit/gapic/aiplatform_v1/test_vizier_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_vizier_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1.services.vizier_service import VizierServiceAsyncClient @@ -2292,12 +2293,16 @@ def test_list_studies_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_studies(request={}) + pager = 
client.list_studies(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4635,12 +4640,16 @@ def test_list_trials_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_trials(request={}) + pager = client.list_trials(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py index bd67758ad4..8348b38eff 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.dataset_service import ( @@ -2745,12 +2746,16 @@ def test_list_datasets_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_datasets(request={}) + pager = client.list_datasets(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -6106,12 +6111,16 @@ 
def test_list_dataset_versions_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_dataset_versions(request={}) + pager = client.list_dataset_versions(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -7060,12 +7069,16 @@ def test_list_data_items_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_data_items(request={}) + pager = client.list_data_items(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -7564,12 +7577,16 @@ def test_search_data_items_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("dataset", ""),)), ) - pager = client.search_data_items(request={}) + pager = client.search_data_items(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -8155,12 +8172,16 @@ def test_list_saved_queries_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_saved_queries(request={}) + pager = client.list_saved_queries(request={}, retry=retry, 
timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -9507,12 +9528,16 @@ def test_list_annotations_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_annotations(request={}) + pager = client.list_annotations(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_deployment_resource_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_deployment_resource_pool_service.py index e0ceb095a0..02e389f607 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_deployment_resource_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_deployment_resource_pool_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.deployment_resource_pool_service import ( @@ -71,6 +72,7 @@ from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import machine_resources from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import reservation_affinity from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import options_pb2 # type: ignore @@ -2534,12 +2536,18 @@ def 
test_list_deployment_resource_pools_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_deployment_resource_pools(request={}) + pager = client.list_deployment_resource_pools( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -3957,14 +3965,18 @@ def test_query_deployed_models_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata( (("deployment_resource_pool", ""),) ), ) - pager = client.query_deployed_models(request={}) + pager = client.query_deployed_models(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -5226,12 +5238,18 @@ def test_update_deployment_resource_pool_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "min_replica_count": 1803, "max_replica_count": 1805, "autoscaling_metric_specs": [ {"metric_name": "metric_name_value", "target": 647} ], + "spot": True, }, "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "service_account": "service_account_value", @@ -7066,8 +7084,36 @@ def test_parse_model_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "cuttlefish" + zone = "mussel" + reservation_name = "winkle" + expected = 
"projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = DeploymentResourcePoolServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "nautilus", + "zone": "scallop", + "reservation_name": "abalone", + } + path = DeploymentResourcePoolServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. + actual = DeploymentResourcePoolServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_common_billing_account_path(): - billing_account = "cuttlefish" + billing_account = "squid" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -7079,7 +7125,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", + "billing_account": "clam", } path = DeploymentResourcePoolServiceClient.common_billing_account_path(**expected) @@ -7089,7 +7135,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "winkle" + folder = "whelk" expected = "folders/{folder}".format( folder=folder, ) @@ -7099,7 +7145,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "nautilus", + "folder": "octopus", } path = DeploymentResourcePoolServiceClient.common_folder_path(**expected) @@ -7109,7 +7155,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "scallop" + organization = "oyster" expected = "organizations/{organization}".format( organization=organization, ) @@ -7119,7 +7165,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "abalone", + "organization": "nudibranch", } 
path = DeploymentResourcePoolServiceClient.common_organization_path(**expected) @@ -7129,7 +7175,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "squid" + project = "cuttlefish" expected = "projects/{project}".format( project=project, ) @@ -7139,7 +7185,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "clam", + "project": "mussel", } path = DeploymentResourcePoolServiceClient.common_project_path(**expected) @@ -7149,8 +7195,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "whelk" - location = "octopus" + project = "winkle" + location = "nautilus" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -7161,8 +7207,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", + "project": "scallop", + "location": "abalone", } path = DeploymentResourcePoolServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py index cfd91e7e2b..cfa9a14533 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( @@ -67,6 +68,7 @@ from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import machine_resources from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from 
google.cloud.aiplatform_v1beta1.types import reservation_affinity from google.cloud.aiplatform_v1beta1.types import service_networking from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore @@ -2378,12 +2380,16 @@ def test_list_endpoints_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_endpoints(request={}) + pager = client.list_endpoints(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4597,12 +4603,18 @@ def test_create_endpoint_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "min_replica_count": 1803, "max_replica_count": 1805, "autoscaling_metric_specs": [ {"metric_name": "metric_name_value", "target": 647} ], + "spot": True, }, "automatic_resources": { "min_replica_count": 1803, @@ -5800,12 +5812,18 @@ def test_update_endpoint_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "min_replica_count": 1803, "max_replica_count": 1805, "autoscaling_metric_specs": [ {"metric_name": "metric_name_value", "target": 647} ], + "spot": True, }, "automatic_resources": { "min_replica_count": 1803, @@ -8324,8 +8342,36 @@ def test_parse_network_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "oyster" + zone = "nudibranch" + reservation_name = "cuttlefish" + 
expected = "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = EndpointServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "mussel", + "zone": "winkle", + "reservation_name": "nautilus", + } + path = EndpointServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. + actual = EndpointServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_common_billing_account_path(): - billing_account = "oyster" + billing_account = "scallop" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -8335,7 +8381,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "nudibranch", + "billing_account": "abalone", } path = EndpointServiceClient.common_billing_account_path(**expected) @@ -8345,7 +8391,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "cuttlefish" + folder = "squid" expected = "folders/{folder}".format( folder=folder, ) @@ -8355,7 +8401,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "mussel", + "folder": "clam", } path = EndpointServiceClient.common_folder_path(**expected) @@ -8365,7 +8411,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "winkle" + organization = "whelk" expected = "organizations/{organization}".format( organization=organization, ) @@ -8375,7 +8421,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "nautilus", + "organization": "octopus", } path = 
EndpointServiceClient.common_organization_path(**expected) @@ -8385,7 +8431,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "scallop" + project = "oyster" expected = "projects/{project}".format( project=project, ) @@ -8395,7 +8441,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "abalone", + "project": "nudibranch", } path = EndpointServiceClient.common_project_path(**expected) @@ -8405,8 +8451,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "squid" - location = "clam" + project = "cuttlefish" + location = "mussel" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -8417,8 +8463,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "whelk", - "location": "octopus", + "project": "winkle", + "location": "nautilus", } path = EndpointServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_evaluation_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_evaluation_service.py index ccd697fbff..3eb0abf96a 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_evaluation_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_evaluation_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.evaluation_service import ( diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_extension_execution_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_extension_execution_service.py index 5715e32b44..138f81a131 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_extension_execution_service.py +++ 
b/tests/unit/gapic/aiplatform_v1beta1/test_extension_execution_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.extension_execution_service import ( diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py index 15ba4b6825..d6534e2abc 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.extension_registry_service import ( @@ -2383,12 +2384,16 @@ def test_list_extensions_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_extensions(request={}) + pager = client.list_extensions(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_admin_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_admin_service.py index 3826036d0c..858f3c04da 100644 --- 
a/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_admin_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_admin_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.feature_online_store_admin_service import ( @@ -2556,12 +2557,18 @@ def test_list_feature_online_stores_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_feature_online_stores(request={}) + pager = client.list_feature_online_stores( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4823,12 +4830,16 @@ def test_list_feature_views_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_feature_views(request={}) + pager = client.list_feature_views(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -7010,12 +7021,16 @@ def test_list_feature_view_syncs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = 
client.list_feature_view_syncs(request={}) + pager = client.list_feature_view_syncs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_service.py index 72df206de0..e7e955bc38 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.feature_online_store_service import ( diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_feature_registry_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_feature_registry_service.py index 9bb1af8e00..cdfaf654cc 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_feature_registry_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_feature_registry_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.feature_registry_service import ( @@ -2478,12 +2479,16 @@ def test_list_feature_groups_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( 
gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_feature_groups(request={}) + pager = client.list_feature_groups(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4636,12 +4641,16 @@ def test_list_features_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_features(request={}) + pager = client.list_features(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py index 5e7e538f44..32e71952cb 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import ( diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py index 38eadb21b0..1b5f3d4033 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py @@ -47,6 +47,7 
@@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.featurestore_service import ( @@ -2441,12 +2442,16 @@ def test_list_featurestores_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_featurestores(request={}) + pager = client.list_featurestores(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4609,12 +4614,16 @@ def test_list_entity_types_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_entity_types(request={}) + pager = client.list_entity_types(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -7156,12 +7165,16 @@ def test_list_features_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_features(request={}) + pager = client.list_features(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert 
len(results) == 6 @@ -10043,12 +10056,16 @@ def test_search_features_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("location", ""),)), ) - pager = client.search_features(request={}) + pager = client.search_features(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_cache_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_cache_service.py index c8d5a71a86..c816015140 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_cache_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_cache_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.gen_ai_cache_service import ( @@ -3209,12 +3210,16 @@ def test_list_cached_contents_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_cached_contents(request={}) + pager = client.list_cached_contents(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_tuning_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_tuning_service.py index 342935faf4..3eb8d2fabf 
100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_tuning_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_tuning_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.gen_ai_tuning_service import ( @@ -2402,12 +2403,16 @@ def test_list_tuning_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_tuning_jobs(request={}) + pager = client.list_tuning_jobs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py index fd59b09017..07df1083a9 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import ( @@ -64,6 +65,7 @@ from google.cloud.aiplatform_v1beta1.types import index_endpoint_service from google.cloud.aiplatform_v1beta1.types import machine_resources from 
google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import reservation_affinity from google.cloud.aiplatform_v1beta1.types import service_networking from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore @@ -2461,12 +2463,16 @@ def test_list_index_endpoints_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_index_endpoints(request={}) + pager = client.list_index_endpoints(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4633,12 +4639,18 @@ def test_create_index_endpoint_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "min_replica_count": 1803, "max_replica_count": 1805, "autoscaling_metric_specs": [ {"metric_name": "metric_name_value", "target": 647} ], + "spot": True, }, "enable_access_logging": True, "deployed_index_auth_config": { @@ -5811,12 +5823,18 @@ def test_update_index_endpoint_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "min_replica_count": 1803, "max_replica_count": 1805, "autoscaling_metric_specs": [ {"metric_name": "metric_name_value", "target": 647} ], + "spot": True, }, "enable_access_logging": True, "deployed_index_auth_config": { @@ -7229,12 +7247,18 @@ def test_mutate_deployed_index_rest(request_type): 
"accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "min_replica_count": 1803, "max_replica_count": 1805, "autoscaling_metric_specs": [ {"metric_name": "metric_name_value", "target": 647} ], + "spot": True, }, "enable_access_logging": True, "deployed_index_auth_config": { @@ -8302,8 +8326,36 @@ def test_parse_index_endpoint_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "squid" + zone = "clam" + reservation_name = "whelk" + expected = "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = IndexEndpointServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "octopus", + "zone": "oyster", + "reservation_name": "nudibranch", + } + path = IndexEndpointServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. 
+ actual = IndexEndpointServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_common_billing_account_path(): - billing_account = "squid" + billing_account = "cuttlefish" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -8313,7 +8365,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", + "billing_account": "mussel", } path = IndexEndpointServiceClient.common_billing_account_path(**expected) @@ -8323,7 +8375,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "whelk" + folder = "winkle" expected = "folders/{folder}".format( folder=folder, ) @@ -8333,7 +8385,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "octopus", + "folder": "nautilus", } path = IndexEndpointServiceClient.common_folder_path(**expected) @@ -8343,7 +8395,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "oyster" + organization = "scallop" expected = "organizations/{organization}".format( organization=organization, ) @@ -8353,7 +8405,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", + "organization": "abalone", } path = IndexEndpointServiceClient.common_organization_path(**expected) @@ -8363,7 +8415,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "cuttlefish" + project = "squid" expected = "projects/{project}".format( project=project, ) @@ -8373,7 +8425,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "mussel", + "project": "clam", } path = IndexEndpointServiceClient.common_project_path(**expected) @@ -8383,8 +8435,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "winkle" - location = "nautilus" + project 
= "whelk" + location = "octopus" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -8395,8 +8447,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", + "project": "oyster", + "location": "nudibranch", } path = IndexEndpointServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py index 74a1f9afd3..4a13ab4095 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.index_service import ( @@ -2279,12 +2280,16 @@ def test_list_indexes_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_indexes(request={}) + pager = client.list_indexes(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index aae3afafa4..8045604c54 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from 
google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceAsyncClient @@ -88,6 +89,7 @@ from google.cloud.aiplatform_v1beta1.types import nas_job from google.cloud.aiplatform_v1beta1.types import nas_job as gca_nas_job from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import reservation_affinity from google.cloud.aiplatform_v1beta1.types import study from google.cloud.aiplatform_v1beta1.types import unmanaged_container_model from google.cloud.location import locations_pb2 @@ -2321,12 +2323,16 @@ def test_list_custom_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_custom_jobs(request={}) + pager = client.list_custom_jobs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4524,12 +4530,16 @@ def test_list_data_labeling_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_data_labeling_jobs(request={}) + pager = client.list_data_labeling_jobs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -6722,12 +6732,18 @@ def test_list_hyperparameter_tuning_jobs_pager(transport_name: str 
= "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_hyperparameter_tuning_jobs(request={}) + pager = client.list_hyperparameter_tuning_jobs( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -8812,12 +8828,16 @@ def test_list_nas_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_nas_jobs(request={}) + pager = client.list_nas_jobs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -10496,12 +10516,16 @@ def test_list_nas_trial_details_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_nas_trial_details(request={}) + pager = client.list_nas_trial_details(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -11970,12 +11994,18 @@ def test_list_batch_prediction_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_batch_prediction_jobs(request={}) + pager = client.list_batch_prediction_jobs( + request={}, 
retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -13835,14 +13865,20 @@ def test_search_model_deployment_monitoring_stats_anomalies_pager( ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata( (("model_deployment_monitoring_job", ""),) ), ) - pager = client.search_model_deployment_monitoring_stats_anomalies(request={}) + pager = client.search_model_deployment_monitoring_stats_anomalies( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -14881,12 +14917,18 @@ def test_list_model_deployment_monitoring_jobs_pager(transport_name: str = "grpc ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_model_deployment_monitoring_jobs(request={}) + pager = client.list_model_deployment_monitoring_jobs( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -16623,6 +16665,11 @@ def test_create_custom_job_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "replica_count": 1384, "nfs_mounts": [ @@ -20235,6 +20282,11 @@ def test_create_hyperparameter_tuning_job_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + 
"reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "replica_count": 1384, "nfs_mounts": [ @@ -22117,6 +22169,11 @@ def test_create_nas_job_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "replica_count": 1384, "nfs_mounts": [ @@ -24615,6 +24672,11 @@ def test_create_batch_prediction_job_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "starting_replica_count": 2355, "max_replica_count": 1805, @@ -30791,10 +30853,38 @@ def test_parse_persistent_resource_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "squid" + zone = "clam" + reservation_name = "whelk" + expected = "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = JobServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "octopus", + "zone": "oyster", + "reservation_name": "nudibranch", + } + path = JobServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. 
+ actual = JobServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_tensorboard_path(): - project = "squid" - location = "clam" - tensorboard = "whelk" + project = "cuttlefish" + location = "mussel" + tensorboard = "winkle" expected = ( "projects/{project}/locations/{location}/tensorboards/{tensorboard}".format( project=project, @@ -30808,9 +30898,9 @@ def test_tensorboard_path(): def test_parse_tensorboard_path(): expected = { - "project": "octopus", - "location": "oyster", - "tensorboard": "nudibranch", + "project": "nautilus", + "location": "scallop", + "tensorboard": "abalone", } path = JobServiceClient.tensorboard_path(**expected) @@ -30820,10 +30910,10 @@ def test_parse_tensorboard_path(): def test_trial_path(): - project = "cuttlefish" - location = "mussel" - study = "winkle" - trial = "nautilus" + project = "squid" + location = "clam" + study = "whelk" + trial = "octopus" expected = ( "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( project=project, @@ -30838,10 +30928,10 @@ def test_trial_path(): def test_parse_trial_path(): expected = { - "project": "scallop", - "location": "abalone", - "study": "squid", - "trial": "clam", + "project": "oyster", + "location": "nudibranch", + "study": "cuttlefish", + "trial": "mussel", } path = JobServiceClient.trial_path(**expected) @@ -30851,7 +30941,7 @@ def test_parse_trial_path(): def test_common_billing_account_path(): - billing_account = "whelk" + billing_account = "winkle" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -30861,7 +30951,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "octopus", + "billing_account": "nautilus", } path = JobServiceClient.common_billing_account_path(**expected) @@ -30871,7 +30961,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "oyster" + folder = "scallop" 
expected = "folders/{folder}".format( folder=folder, ) @@ -30881,7 +30971,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "nudibranch", + "folder": "abalone", } path = JobServiceClient.common_folder_path(**expected) @@ -30891,7 +30981,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "cuttlefish" + organization = "squid" expected = "organizations/{organization}".format( organization=organization, ) @@ -30901,7 +30991,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "mussel", + "organization": "clam", } path = JobServiceClient.common_organization_path(**expected) @@ -30911,7 +31001,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "winkle" + project = "whelk" expected = "projects/{project}".format( project=project, ) @@ -30921,7 +31011,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "nautilus", + "project": "octopus", } path = JobServiceClient.common_project_path(**expected) @@ -30931,8 +31021,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "scallop" - location = "abalone" + project = "oyster" + location = "nudibranch" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -30943,8 +31033,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "squid", - "location": "clam", + "project": "cuttlefish", + "location": "mussel", } path = JobServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_llm_utility_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_llm_utility_service.py index 79d565c295..fc384c5da3 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_llm_utility_service.py +++ 
b/tests/unit/gapic/aiplatform_v1beta1/test_llm_utility_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.llm_utility_service import ( diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_match_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_match_service.py index 4dfe0eccb3..7dfb85c105 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_match_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_match_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.match_service import ( diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py index bfeb2fdfba..1ff5e6a6b7 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.metadata_service import ( @@ -2405,12 +2406,16 @@ def test_list_metadata_stores_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = 
tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_metadata_stores(request={}) + pager = client.list_metadata_stores(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4175,12 +4180,16 @@ def test_list_artifacts_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_artifacts(request={}) + pager = client.list_artifacts(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -6674,12 +6683,16 @@ def test_list_contexts_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_contexts(request={}) + pager = client.list_contexts(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -10734,12 +10747,16 @@ def test_list_executions_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_executions(request={}) + pager = client.list_executions(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ 
-14077,12 +14094,16 @@ def test_list_metadata_schemas_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_metadata_schemas(request={}) + pager = client.list_metadata_schemas(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index e651d6ef7b..6f321cc1dc 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.migration_service import ( @@ -1595,12 +1596,18 @@ def test_search_migratable_resources_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.search_migratable_resources(request={}) + pager = client.search_migratable_resources( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -3525,19 +3532,22 @@ def test_parse_annotated_dataset_path(): def test_dataset_path(): project = "cuttlefish" - dataset = "mussel" - expected = 
"projects/{project}/datasets/{dataset}".format( + location = "mussel" + dataset = "winkle" + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", - "dataset": "nautilus", + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", } path = MigrationServiceClient.dataset_path(**expected) @@ -3547,22 +3557,19 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "scallop" - location = "abalone" - dataset = "squid" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project = "squid" + dataset = "clam" + expected = "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "clam", - "location": "whelk", + "project": "whelk", "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_garden_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_garden_service.py index 7fb6cf5dbd..28627b23e7 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_garden_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_garden_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from 
google.cloud.aiplatform_v1beta1.services.model_garden_service import ( @@ -2043,12 +2044,16 @@ def test_list_publisher_models_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_publisher_models(request={}) + pager = client.list_publisher_models(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -3503,8 +3508,36 @@ def test_parse_publisher_model_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "oyster" + zone = "nudibranch" + reservation_name = "cuttlefish" + expected = "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = ModelGardenServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "mussel", + "zone": "winkle", + "reservation_name": "nautilus", + } + path = ModelGardenServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ModelGardenServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_common_billing_account_path(): - billing_account = "oyster" + billing_account = "scallop" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -3514,7 +3547,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "nudibranch", + "billing_account": "abalone", } path = ModelGardenServiceClient.common_billing_account_path(**expected) @@ -3524,7 +3557,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "cuttlefish" + folder = "squid" expected = "folders/{folder}".format( folder=folder, ) @@ -3534,7 +3567,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "mussel", + "folder": "clam", } path = ModelGardenServiceClient.common_folder_path(**expected) @@ -3544,7 +3577,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "winkle" + organization = "whelk" expected = "organizations/{organization}".format( organization=organization, ) @@ -3554,7 +3587,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "nautilus", + "organization": "octopus", } path = ModelGardenServiceClient.common_organization_path(**expected) @@ -3564,7 +3597,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "scallop" + project = "oyster" expected = "projects/{project}".format( project=project, ) @@ -3574,7 +3607,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "abalone", + "project": "nudibranch", } path = ModelGardenServiceClient.common_project_path(**expected) @@ -3584,8 +3617,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "squid" - location = "clam" + project = 
"cuttlefish" + location = "mussel" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -3596,8 +3629,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "whelk", - "location": "octopus", + "project": "winkle", + "location": "nautilus", } path = ModelGardenServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_monitoring_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_monitoring_service.py index 32ae56caaf..dc6d1c16a9 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_monitoring_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_monitoring_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.model_monitoring_service import ( @@ -74,6 +75,7 @@ from google.cloud.aiplatform_v1beta1.types import model_monitoring_spec from google.cloud.aiplatform_v1beta1.types import model_monitoring_stats from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import reservation_affinity from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import options_pb2 # type: ignore @@ -2922,12 +2924,16 @@ def test_list_model_monitors_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_model_monitors(request={}) + pager = client.list_model_monitors(request={}, retry=retry, 
timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4725,12 +4731,18 @@ def test_list_model_monitoring_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_model_monitoring_jobs(request={}) + pager = client.list_model_monitoring_jobs( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -5711,12 +5723,18 @@ def test_search_model_monitoring_stats_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("model_monitor", ""),)), ) - pager = client.search_model_monitoring_stats(request={}) + pager = client.search_model_monitoring_stats( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -6323,12 +6341,18 @@ def test_search_model_monitoring_alerts_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("model_monitor", ""),)), ) - pager = client.search_model_monitoring_alerts(request={}) + pager = client.search_model_monitoring_alerts( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -6522,6 +6546,11 @@ def 
test_create_model_monitor_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "starting_replica_count": 2355, "max_replica_count": 1805, @@ -7051,6 +7080,11 @@ def test_update_model_monitor_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "starting_replica_count": 2355, "max_replica_count": 1805, @@ -8600,6 +8634,11 @@ def test_create_model_monitoring_job_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "starting_replica_count": 2355, "max_replica_count": 1805, @@ -11745,10 +11784,38 @@ def test_parse_model_monitoring_job_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "whelk" + zone = "octopus" + reservation_name = "oyster" + expected = "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = ModelMonitoringServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "nudibranch", + "zone": "cuttlefish", + "reservation_name": "mussel", + } + path = ModelMonitoringServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ModelMonitoringServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_schedule_path(): - project = "whelk" - location = "octopus" - schedule = "oyster" + project = "winkle" + location = "nautilus" + schedule = "scallop" expected = "projects/{project}/locations/{location}/schedules/{schedule}".format( project=project, location=location, @@ -11760,9 +11827,9 @@ def test_schedule_path(): def test_parse_schedule_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "schedule": "mussel", + "project": "abalone", + "location": "squid", + "schedule": "clam", } path = ModelMonitoringServiceClient.schedule_path(**expected) @@ -11772,7 +11839,7 @@ def test_parse_schedule_path(): def test_common_billing_account_path(): - billing_account = "winkle" + billing_account = "whelk" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -11782,7 +11849,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", + "billing_account": "octopus", } path = ModelMonitoringServiceClient.common_billing_account_path(**expected) @@ -11792,7 +11859,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "scallop" + folder = "oyster" expected = "folders/{folder}".format( folder=folder, ) @@ -11802,7 +11869,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "abalone", + "folder": "nudibranch", } path = ModelMonitoringServiceClient.common_folder_path(**expected) @@ -11812,7 +11879,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "squid" + organization = "cuttlefish" expected = "organizations/{organization}".format( organization=organization, ) @@ -11822,7 +11889,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "clam", + 
"organization": "mussel", } path = ModelMonitoringServiceClient.common_organization_path(**expected) @@ -11832,7 +11899,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "whelk" + project = "winkle" expected = "projects/{project}".format( project=project, ) @@ -11842,7 +11909,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "octopus", + "project": "nautilus", } path = ModelMonitoringServiceClient.common_project_path(**expected) @@ -11852,8 +11919,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "oyster" - location = "nudibranch" + project = "scallop" + location = "abalone" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -11864,8 +11931,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", + "project": "squid", + "location": "clam", } path = ModelMonitoringServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py index 1fa64b03d5..1df7b630bc 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.model_service import ( @@ -2373,12 +2374,16 @@ def test_list_models_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( 
gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_models(request={}) + pager = client.list_models(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -2957,12 +2962,16 @@ def test_list_model_versions_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("name", ""),)), ) - pager = client.list_model_versions(request={}) + pager = client.list_model_versions(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -8063,12 +8072,16 @@ def test_list_model_evaluations_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_model_evaluations(request={}) + pager = client.list_model_evaluations(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -9046,12 +9059,18 @@ def test_list_model_evaluation_slices_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_model_evaluation_slices(request={}) + pager = client.list_model_evaluation_slices( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = 
list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_notebook_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_notebook_service.py index e59ed840ef..96a1d07ac9 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_notebook_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_notebook_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.notebook_service import ( @@ -75,6 +76,7 @@ from google.cloud.aiplatform_v1beta1.types import notebook_runtime_template_ref from google.cloud.aiplatform_v1beta1.types import notebook_service from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import reservation_affinity from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import options_pb2 # type: ignore @@ -2460,12 +2462,18 @@ def test_list_notebook_runtime_templates_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_notebook_runtime_templates(request={}) + pager = client.list_notebook_runtime_templates( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4771,12 +4779,16 @@ def test_list_notebook_runtimes_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = 
tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_notebook_runtimes(request={}) + pager = client.list_notebook_runtimes(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -7367,12 +7379,18 @@ def test_list_notebook_execution_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_notebook_execution_jobs(request={}) + pager = client.list_notebook_execution_jobs( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -7943,6 +7961,11 @@ def test_create_notebook_runtime_template_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "data_persistent_disk_spec": { "disk_type": "disk_type_value", @@ -9435,6 +9458,11 @@ def test_update_notebook_runtime_template_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "data_persistent_disk_spec": { "disk_type": "disk_type_value", @@ -14087,10 +14115,38 @@ def test_parse_notebook_runtime_template_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "scallop" + zone = "abalone" + reservation_name = "squid" + expected = 
"projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = NotebookServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "clam", + "zone": "whelk", + "reservation_name": "octopus", + } + path = NotebookServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. + actual = NotebookServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_schedule_path(): - project = "scallop" - location = "abalone" - schedule = "squid" + project = "oyster" + location = "nudibranch" + schedule = "cuttlefish" expected = "projects/{project}/locations/{location}/schedules/{schedule}".format( project=project, location=location, @@ -14102,9 +14158,9 @@ def test_schedule_path(): def test_parse_schedule_path(): expected = { - "project": "clam", - "location": "whelk", - "schedule": "octopus", + "project": "mussel", + "location": "winkle", + "schedule": "nautilus", } path = NotebookServiceClient.schedule_path(**expected) @@ -14114,9 +14170,9 @@ def test_parse_schedule_path(): def test_subnetwork_path(): - project = "oyster" - region = "nudibranch" - subnetwork = "cuttlefish" + project = "scallop" + region = "abalone" + subnetwork = "squid" expected = "projects/{project}/regions/{region}/subnetworks/{subnetwork}".format( project=project, region=region, @@ -14128,9 +14184,9 @@ def test_subnetwork_path(): def test_parse_subnetwork_path(): expected = { - "project": "mussel", - "region": "winkle", - "subnetwork": "nautilus", + "project": "clam", + "region": "whelk", + "subnetwork": "octopus", } path = NotebookServiceClient.subnetwork_path(**expected) @@ -14140,7 +14196,7 @@ def test_parse_subnetwork_path(): def test_common_billing_account_path(): - billing_account 
= "scallop" + billing_account = "oyster" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -14150,7 +14206,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "abalone", + "billing_account": "nudibranch", } path = NotebookServiceClient.common_billing_account_path(**expected) @@ -14160,7 +14216,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "squid" + folder = "cuttlefish" expected = "folders/{folder}".format( folder=folder, ) @@ -14170,7 +14226,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "clam", + "folder": "mussel", } path = NotebookServiceClient.common_folder_path(**expected) @@ -14180,7 +14236,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "whelk" + organization = "winkle" expected = "organizations/{organization}".format( organization=organization, ) @@ -14190,7 +14246,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "octopus", + "organization": "nautilus", } path = NotebookServiceClient.common_organization_path(**expected) @@ -14200,7 +14256,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "oyster" + project = "scallop" expected = "projects/{project}".format( project=project, ) @@ -14210,7 +14266,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "nudibranch", + "project": "abalone", } path = NotebookServiceClient.common_project_path(**expected) @@ -14220,8 +14276,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "cuttlefish" - location = "mussel" + project = "squid" + location = "clam" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -14232,8 +14288,8 @@ 
def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "winkle", - "location": "nautilus", + "project": "whelk", + "location": "octopus", } path = NotebookServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_persistent_resource_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_persistent_resource_service.py index 1bb0456b55..4c22c3eb1d 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_persistent_resource_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_persistent_resource_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import ( @@ -68,6 +69,7 @@ persistent_resource as gca_persistent_resource, ) from google.cloud.aiplatform_v1beta1.types import persistent_resource_service +from google.cloud.aiplatform_v1beta1.types import reservation_affinity from google.cloud.aiplatform_v1beta1.types import service_networking from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore @@ -2494,12 +2496,18 @@ def test_list_persistent_resources_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_persistent_resources(request={}) + pager = client.list_persistent_resources( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -3860,6 +3868,11 @@ def 
test_create_persistent_resource_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "replica_count": 1384, "disk_spec": { @@ -5378,6 +5391,11 @@ def test_update_persistent_resource_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": ["values_value1", "values_value2"], + }, }, "replica_count": 1384, "disk_spec": { @@ -6854,8 +6872,36 @@ def test_parse_persistent_resource_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "scallop" + zone = "abalone" + reservation_name = "squid" + expected = "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = PersistentResourceServiceClient.reservation_path( + project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "clam", + "zone": "whelk", + "reservation_name": "octopus", + } + path = PersistentResourceServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. 
+ actual = PersistentResourceServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_common_billing_account_path(): - billing_account = "scallop" + billing_account = "oyster" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -6867,7 +6913,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "abalone", + "billing_account": "nudibranch", } path = PersistentResourceServiceClient.common_billing_account_path(**expected) @@ -6877,7 +6923,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "squid" + folder = "cuttlefish" expected = "folders/{folder}".format( folder=folder, ) @@ -6887,7 +6933,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "clam", + "folder": "mussel", } path = PersistentResourceServiceClient.common_folder_path(**expected) @@ -6897,7 +6943,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "whelk" + organization = "winkle" expected = "organizations/{organization}".format( organization=organization, ) @@ -6907,7 +6953,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "octopus", + "organization": "nautilus", } path = PersistentResourceServiceClient.common_organization_path(**expected) @@ -6917,7 +6963,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "oyster" + project = "scallop" expected = "projects/{project}".format( project=project, ) @@ -6927,7 +6973,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "nudibranch", + "project": "abalone", } path = PersistentResourceServiceClient.common_project_path(**expected) @@ -6937,8 +6983,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "cuttlefish" 
- location = "mussel" + project = "squid" + location = "clam" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -6949,8 +6995,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "winkle", - "location": "nautilus", + "project": "whelk", + "location": "octopus", } path = PersistentResourceServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py index 1efabf73c9..5038c08314 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( @@ -2451,12 +2452,16 @@ def test_list_training_pipelines_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_training_pipelines(request={}) + pager = client.list_training_pipelines(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4677,12 +4682,16 @@ def test_list_pipeline_jobs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = 
client.list_pipeline_jobs(request={}) + pager = client.list_pipeline_jobs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py index 2b20e07918..509e0cb4e1 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py @@ -44,6 +44,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.prediction_service import ( diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_execution_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_execution_service.py index c7265fd470..697cf2c24f 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_execution_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_execution_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.reasoning_engine_execution_service import ( diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_service.py index 0e904a65eb..1bc1019fce 100644 --- 
a/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_reasoning_engine_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.reasoning_engine_service import ( @@ -2445,12 +2446,16 @@ def test_list_reasoning_engines_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_reasoning_engines(request={}) + pager = client.list_reasoning_engines(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py index a3089eb8f7..ed2428de38 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.schedule_service import ( @@ -78,6 +79,7 @@ from google.cloud.aiplatform_v1beta1.types import pipeline_job from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import 
pipeline_state +from google.cloud.aiplatform_v1beta1.types import reservation_affinity from google.cloud.aiplatform_v1beta1.types import schedule from google.cloud.aiplatform_v1beta1.types import schedule as gca_schedule from google.cloud.aiplatform_v1beta1.types import schedule_service @@ -2766,12 +2768,16 @@ def test_list_schedules_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_schedules(request={}) + pager = client.list_schedules(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4191,6 +4197,14 @@ def test_create_schedule_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": [ + "values_value1", + "values_value2", + ], + }, }, "starting_replica_count": 2355, "max_replica_count": 1805, @@ -6473,6 +6487,14 @@ def test_update_schedule_rest(request_type): "accelerator_type": 1, "accelerator_count": 1805, "tpu_topology": "tpu_topology_value", + "reservation_affinity": { + "reservation_affinity_type": 1, + "key": "key_value", + "values": [ + "values_value1", + "values_value2", + ], + }, }, "starting_replica_count": 2355, "max_replica_count": 1805, @@ -8001,10 +8023,38 @@ def test_parse_pipeline_job_path(): assert expected == actual +def test_reservation_path(): + project_id_or_number = "squid" + zone = "clam" + reservation_name = "whelk" + expected = "projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}".format( + project_id_or_number=project_id_or_number, + zone=zone, + reservation_name=reservation_name, + ) + actual = ScheduleServiceClient.reservation_path( + 
project_id_or_number, zone, reservation_name + ) + assert expected == actual + + +def test_parse_reservation_path(): + expected = { + "project_id_or_number": "octopus", + "zone": "oyster", + "reservation_name": "nudibranch", + } + path = ScheduleServiceClient.reservation_path(**expected) + + # Check that the path construction is reversible. + actual = ScheduleServiceClient.parse_reservation_path(path) + assert expected == actual + + def test_schedule_path(): - project = "squid" - location = "clam" - schedule = "whelk" + project = "cuttlefish" + location = "mussel" + schedule = "winkle" expected = "projects/{project}/locations/{location}/schedules/{schedule}".format( project=project, location=location, @@ -8016,9 +8066,9 @@ def test_schedule_path(): def test_parse_schedule_path(): expected = { - "project": "octopus", - "location": "oyster", - "schedule": "nudibranch", + "project": "nautilus", + "location": "scallop", + "schedule": "abalone", } path = ScheduleServiceClient.schedule_path(**expected) @@ -8028,7 +8078,7 @@ def test_parse_schedule_path(): def test_common_billing_account_path(): - billing_account = "cuttlefish" + billing_account = "squid" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -8038,7 +8088,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", + "billing_account": "clam", } path = ScheduleServiceClient.common_billing_account_path(**expected) @@ -8048,7 +8098,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "winkle" + folder = "whelk" expected = "folders/{folder}".format( folder=folder, ) @@ -8058,7 +8108,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "nautilus", + "folder": "octopus", } path = ScheduleServiceClient.common_folder_path(**expected) @@ -8068,7 +8118,7 @@ def test_parse_common_folder_path(): def 
test_common_organization_path(): - organization = "scallop" + organization = "oyster" expected = "organizations/{organization}".format( organization=organization, ) @@ -8078,7 +8128,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "abalone", + "organization": "nudibranch", } path = ScheduleServiceClient.common_organization_path(**expected) @@ -8088,7 +8138,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "squid" + project = "cuttlefish" expected = "projects/{project}".format( project=project, ) @@ -8098,7 +8148,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "clam", + "project": "mussel", } path = ScheduleServiceClient.common_project_path(**expected) @@ -8108,8 +8158,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "whelk" - location = "octopus" + project = "winkle" + location = "nautilus" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -8120,8 +8170,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", + "project": "scallop", + "location": "abalone", } path = ScheduleServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py index af928e0092..fdc352c5a6 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions 
import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( @@ -2447,12 +2448,16 @@ def test_list_specialist_pools_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_specialist_pools(request={}) + pager = client.list_specialist_pools(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py index 94f12b554a..4bc8fbe11c 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.tensorboard_service import ( @@ -2818,12 +2819,16 @@ def test_list_tensorboards_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_tensorboards(request={}) + pager = client.list_tensorboards(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -5833,12 +5838,18 @@ def 
test_list_tensorboard_experiments_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_tensorboard_experiments(request={}) + pager = client.list_tensorboard_experiments( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -8456,12 +8467,16 @@ def test_list_tensorboard_runs_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_tensorboard_runs(request={}) + pager = client.list_tensorboard_runs(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -11168,12 +11183,18 @@ def test_list_tensorboard_time_series_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_tensorboard_time_series(request={}) + pager = client.list_tensorboard_time_series( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -14165,14 +14186,20 @@ def test_export_tensorboard_time_series_data_pager(transport_name: str = "grpc") ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata( (("tensorboard_time_series", ""),) 
), ) - pager = client.export_tensorboard_time_series_data(request={}) + pager = client.export_tensorboard_time_series_data( + request={}, retry=retry, timeout=timeout + ) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py index ce5ff3f467..cc8ffd20b7 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import ( @@ -2384,12 +2385,16 @@ def test_list_rag_corpora_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_rag_corpora(request={}) + pager = client.list_rag_corpora(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4513,12 +4518,16 @@ def test_list_rag_files_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_rag_files(request={}) + pager = client.list_rag_files(request={}, retry=retry, timeout=timeout) assert 
pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_service.py index 9f678ff623..ab451c6beb 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_service.py @@ -43,6 +43,7 @@ from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.vertex_rag_service import ( diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py index 746c254b2d..5a08d23d50 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py @@ -47,6 +47,7 @@ from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template +from google.api_core import retry as retries from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.aiplatform_v1beta1.services.vizier_service import ( @@ -2294,12 +2295,16 @@ def test_list_studies_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_studies(request={}) + pager = client.list_studies(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert 
pager._timeout == timeout results = list(pager) assert len(results) == 6 @@ -4637,12 +4642,16 @@ def test_list_trials_pager(transport_name: str = "grpc"): ) expected_metadata = () + retry = retries.Retry() + timeout = 5 expected_metadata = tuple(expected_metadata) + ( gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) - pager = client.list_trials(request={}) + pager = client.list_trials(request={}, retry=retry, timeout=timeout) assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout results = list(pager) assert len(results) == 6 diff --git a/tests/unit/vertex_ray/conftest.py b/tests/unit/vertex_ray/conftest.py index de20c135e0..9bebe10e1f 100644 --- a/tests/unit/vertex_ray/conftest.py +++ b/tests/unit/vertex_ray/conftest.py @@ -19,16 +19,16 @@ from google.auth import credentials as auth_credentials from google.cloud import resourcemanager from google.cloud.aiplatform import vertex_ray -from google.cloud.aiplatform_v1.services.persistent_resource_service import ( +from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import ( PersistentResourceServiceClient, ) -from google.cloud.aiplatform_v1.types.persistent_resource import ( +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( PersistentResource, ) -from google.cloud.aiplatform_v1.types.persistent_resource import ( +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( ResourceRuntime, ) -from google.cloud.aiplatform_v1.types.persistent_resource_service import ( +from google.cloud.aiplatform_v1beta1.types.persistent_resource_service import ( DeletePersistentResourceRequest, ) import test_constants as tc diff --git a/tests/unit/vertex_ray/test_cluster_init.py b/tests/unit/vertex_ray/test_cluster_init.py index fcb4fa5b6e..a4ddd5f818 100644 --- a/tests/unit/vertex_ray/test_cluster_init.py +++ b/tests/unit/vertex_ray/test_cluster_init.py @@ -22,10 +22,10 @@ Resources, NodeImages, ) -from 
google.cloud.aiplatform_v1.services.persistent_resource_service import ( +from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import ( PersistentResourceServiceClient, ) -from google.cloud.aiplatform_v1.types import persistent_resource_service +from google.cloud.aiplatform_v1beta1.types import persistent_resource_service import test_constants as tc import mock import pytest @@ -352,6 +352,7 @@ def test_create_ray_cluster_1_pool_gpu_with_labels_success( self, create_persistent_resource_1_pool_mock ): """If head and worker nodes are duplicate, merge to head pool.""" + # Also test disable logging and metrics collection. cluster_name = vertex_ray.create_ray_cluster( head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_1_POOL, worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_1_POOL, @@ -359,6 +360,7 @@ def test_create_ray_cluster_1_pool_gpu_with_labels_success( cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID, labels=tc.ClusterConstants.TEST_LABELS, enable_metrics_collection=False, + enable_logging=False, ) assert tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS == cluster_name @@ -401,11 +403,15 @@ def test_create_ray_cluster_2_pools_success( self, create_persistent_resource_2_pools_mock ): """If head and worker nodes are not duplicate, create separate resource_pools.""" + # Also test PSC-I. 
+ psc_interface_config = vertex_ray.PscIConfig( + network_attachment=tc.ClusterConstants.TEST_PSC_NETWORK_ATTACHMENT + ) cluster_name = vertex_ray.create_ray_cluster( head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_2_POOLS, worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_2_POOLS, - network=tc.ProjectConstants.TEST_VPC_NETWORK, cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID, + psc_interface_config=psc_interface_config, ) assert tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS == cluster_name diff --git a/tests/unit/vertex_ray/test_constants.py b/tests/unit/vertex_ray/test_constants.py index 8326b81595..1f1755b1c5 100644 --- a/tests/unit/vertex_ray/test_constants.py +++ b/tests/unit/vertex_ray/test_constants.py @@ -20,31 +20,36 @@ from google.cloud.aiplatform.vertex_ray.util.resources import Cluster from google.cloud.aiplatform.vertex_ray.util.resources import ( + PscIConfig, Resources, ) -from google.cloud.aiplatform_v1.types.machine_resources import DiskSpec -from google.cloud.aiplatform_v1.types.machine_resources import ( +from google.cloud.aiplatform_v1beta1.types.machine_resources import DiskSpec +from google.cloud.aiplatform_v1beta1.types.machine_resources import ( MachineSpec, ) -from google.cloud.aiplatform_v1.types.persistent_resource import ( +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( PersistentResource, ) -from google.cloud.aiplatform_v1.types.persistent_resource import ( +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( + RayLogsSpec, RayMetricSpec, ) -from google.cloud.aiplatform_v1.types.persistent_resource import RaySpec -from google.cloud.aiplatform_v1.types.persistent_resource import ( +from google.cloud.aiplatform_v1beta1.types.persistent_resource import RaySpec +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( ResourcePool, ) -from google.cloud.aiplatform_v1.types.persistent_resource import ( +from 
google.cloud.aiplatform_v1beta1.types.persistent_resource import ( ResourceRuntime, ) -from google.cloud.aiplatform_v1.types.persistent_resource import ( +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( ResourceRuntimeSpec, ) -from google.cloud.aiplatform_v1.types.persistent_resource import ( +from google.cloud.aiplatform_v1beta1.types.persistent_resource import ( ServiceAccountSpec, ) +from google.cloud.aiplatform_v1beta1.types.service_networking import ( + PscInterfaceConfig, +) import pytest @@ -93,6 +98,7 @@ class ClusterConstants: TEST_CPU_IMAGE = "us-docker.pkg.dev/vertex-ai/training/ray-cpu.2-9.py310:latest" TEST_GPU_IMAGE = "us-docker.pkg.dev/vertex-ai/training/ray-gpu.2-9.py310:latest" TEST_CUSTOM_IMAGE = "us-docker.pkg.dev/my-project/ray-custom-image.2.9:latest" + TEST_PSC_NETWORK_ATTACHMENT = "my-network-attachment" # RUNNING Persistent Cluster w/o Ray TEST_RESPONSE_NO_RAY_RUNNING = PersistentResource( name=TEST_VERTEX_RAY_PR_ADDRESS, @@ -129,8 +135,10 @@ class ClusterConstants: ray_spec=RaySpec( resource_pool_images={"head-node": TEST_GPU_IMAGE}, ray_metric_spec=RayMetricSpec(disabled=False), + ray_logs_spec=RayLogsSpec(disabled=False), ), ), + psc_interface_config=None, network=ProjectConstants.TEST_VPC_NETWORK, ) TEST_REQUEST_RUNNING_1_POOL_WITH_LABELS = PersistentResource( @@ -139,8 +147,10 @@ class ClusterConstants: ray_spec=RaySpec( resource_pool_images={"head-node": TEST_GPU_IMAGE}, ray_metric_spec=RayMetricSpec(disabled=True), + ray_logs_spec=RayLogsSpec(disabled=True), ), ), + psc_interface_config=None, network=ProjectConstants.TEST_VPC_NETWORK, labels=TEST_LABELS, ) @@ -150,8 +160,10 @@ class ClusterConstants: ray_spec=RaySpec( resource_pool_images={"head-node": TEST_CUSTOM_IMAGE}, ray_metric_spec=RayMetricSpec(disabled=False), + ray_logs_spec=RayLogsSpec(disabled=False), ), ), + psc_interface_config=None, network=ProjectConstants.TEST_VPC_NETWORK, ) TEST_REQUEST_RUNNING_1_POOL_BYOSA = PersistentResource( @@ -160,12 
+172,14 @@ class ClusterConstants: ray_spec=RaySpec( resource_pool_images={"head-node": TEST_GPU_IMAGE}, ray_metric_spec=RayMetricSpec(disabled=False), + ray_logs_spec=RayLogsSpec(disabled=False), ), service_account_spec=ServiceAccountSpec( enable_custom_service_account=True, service_account=ProjectConstants.TEST_SERVICE_ACCOUNT, ), ), + psc_interface_config=None, network=None, ) # Get response has generated name, and URIs @@ -176,8 +190,10 @@ class ClusterConstants: ray_spec=RaySpec( resource_pool_images={"head-node": TEST_GPU_IMAGE}, ray_metric_spec=RayMetricSpec(disabled=False), + ray_logs_spec=RayLogsSpec(disabled=False), ), ), + psc_interface_config=None, network=ProjectConstants.TEST_VPC_NETWORK, resource_runtime=ResourceRuntime( access_uris={ @@ -197,6 +213,7 @@ class ClusterConstants: ray_metric_spec=RayMetricSpec(disabled=False), ), ), + psc_interface_config=None, network=ProjectConstants.TEST_VPC_NETWORK, resource_runtime=ResourceRuntime( access_uris={ @@ -219,6 +236,7 @@ class ClusterConstants: service_account=ProjectConstants.TEST_SERVICE_ACCOUNT, ), ), + psc_interface_config=None, network=None, resource_runtime=ResourceRuntime( access_uris={ @@ -241,6 +259,7 @@ class ClusterConstants: service_account=ProjectConstants.TEST_SERVICE_ACCOUNT, ), ), + psc_interface_config=None, network=ProjectConstants.TEST_VPC_NETWORK, resource_runtime=ResourceRuntime( access_uris={ @@ -303,9 +322,12 @@ class ClusterConstants: "worker-pool1": TEST_GPU_IMAGE, }, ray_metric_spec=RayMetricSpec(disabled=False), + ray_logs_spec=RayLogsSpec(disabled=False), ), ), - network=ProjectConstants.TEST_VPC_NETWORK, + psc_interface_config=PscInterfaceConfig( + network_attachment=TEST_PSC_NETWORK_ATTACHMENT + ), ) TEST_REQUEST_RUNNING_2_POOLS_CUSTOM_IMAGE = PersistentResource( resource_pools=[TEST_RESOURCE_POOL_1, TEST_RESOURCE_POOL_2], @@ -316,8 +338,10 @@ class ClusterConstants: "worker-pool1": TEST_CUSTOM_IMAGE, }, ray_metric_spec=RayMetricSpec(disabled=False), + 
ray_logs_spec=RayLogsSpec(disabled=False), ), ), + psc_interface_config=None, network=ProjectConstants.TEST_VPC_NETWORK, ) TEST_RESPONSE_RUNNING_2_POOLS = PersistentResource( @@ -332,11 +356,13 @@ class ClusterConstants: ray_metric_spec=RayMetricSpec(disabled=False), ), ), - network=ProjectConstants.TEST_VPC_NETWORK, + psc_interface_config=PscInterfaceConfig( + network_attachment=TEST_PSC_NETWORK_ATTACHMENT + ), + network=None, resource_runtime=ResourceRuntime( access_uris={ "RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS, - "RAY_HEAD_NODE_INTERNAL_IP": TEST_VERTEX_RAY_HEAD_NODE_IP, } ), state="RUNNING", @@ -372,17 +398,22 @@ class ClusterConstants: head_node_type=TEST_HEAD_NODE_TYPE_1_POOL, worker_node_types=TEST_WORKER_NODE_TYPES_1_POOL, dashboard_address=TEST_VERTEX_RAY_DASHBOARD_ADDRESS, + ray_metric_enabled=True, + ray_logs_enabled=True, ) TEST_CLUSTER_2 = Cluster( cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS, python_version="3.10", ray_version="2.9", - network=ProjectConstants.TEST_VPC_NETWORK, + network="", service_account=None, state="RUNNING", head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS, worker_node_types=TEST_WORKER_NODE_TYPES_2_POOLS, dashboard_address=TEST_VERTEX_RAY_DASHBOARD_ADDRESS, + ray_metric_enabled=True, + ray_logs_enabled=True, + psc_interface_config=PscIConfig(network_attachment=TEST_PSC_NETWORK_ATTACHMENT), ) TEST_CLUSTER_CUSTOM_IMAGE = Cluster( cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS, @@ -392,6 +423,8 @@ class ClusterConstants: head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE, worker_node_types=TEST_WORKER_NODE_TYPES_2_POOLS_CUSTOM_IMAGE, dashboard_address=TEST_VERTEX_RAY_DASHBOARD_ADDRESS, + ray_metric_enabled=True, + ray_logs_enabled=True, ) TEST_CLUSTER_BYOSA = Cluster( cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS, @@ -403,6 +436,8 @@ class ClusterConstants: head_node_type=TEST_HEAD_NODE_TYPE_1_POOL, worker_node_types=TEST_WORKER_NODE_TYPES_1_POOL, dashboard_address=TEST_VERTEX_RAY_DASHBOARD_ADDRESS, + 
ray_metric_enabled=True, + ray_logs_enabled=True, ) TEST_BEARER_TOKEN = "test-bearer-token" TEST_HEADERS = { diff --git a/tests/unit/vertexai/feature_store_constants.py b/tests/unit/vertexai/feature_store_constants.py index 60a9cbce34..dbcec2a1ef 100644 --- a/tests/unit/vertexai/feature_store_constants.py +++ b/tests/unit/vertexai/feature_store_constants.py @@ -100,7 +100,12 @@ _TEST_PSC_OPTIMIZED_FOS = types.feature_online_store_v1.FeatureOnlineStore( name=_TEST_PSC_OPTIMIZED_FOS_PATH, optimized=types.feature_online_store_v1.FeatureOnlineStore.Optimized(), - dedicated_serving_endpoint=types.feature_online_store_v1.FeatureOnlineStore.DedicatedServingEndpoint(), + dedicated_serving_endpoint=types.feature_online_store_v1.FeatureOnlineStore.DedicatedServingEndpoint( + private_service_connect_config=types.service_networking_v1.PrivateServiceConnectConfig( + enable_private_service_connect=True, + project_allowlist=_TEST_PSC_PROJECT_ALLOWLIST, + ), + ), labels=_TEST_PSC_OPTIMIZED_FOS_LABELS, ) diff --git a/tests/unit/vertexai/test_extensions.py b/tests/unit/vertexai/test_extensions.py index ecb3afafd0..8de81baee4 100644 --- a/tests/unit/vertexai/test_extensions.py +++ b/tests/unit/vertexai/test_extensions.py @@ -23,8 +23,13 @@ from google.cloud.aiplatform import initializer from google.cloud.aiplatform import utils as aip_utils from google.cloud.aiplatform_v1beta1 import types -from google.cloud.aiplatform_v1beta1.services import extension_execution_service -from google.cloud.aiplatform_v1beta1.services import extension_registry_service +from google.cloud.aiplatform_v1beta1.services import ( + extension_execution_service, +) +from google.cloud.aiplatform_v1beta1.services import ( + extension_registry_service, +) +from vertexai.generative_models import _generative_models from vertexai.preview import extensions from vertexai.reasoning_engines import _utils import pytest @@ -180,6 +185,33 @@ def execute_extension_mock(): yield execute_extension_mock +@pytest.fixture +def 
query_extension_mock(): + with mock.patch.object( + extension_execution_service.ExtensionExecutionServiceClient, "query_extension" + ) as query_extension_mock: + query_extension_mock.return_value.steps = [ + types.Content( + role="user", + parts=[ + types.Part( + text=_TEST_QUERY_PROMPT, + ) + ], + ), + types.Content( + role="extension", + parts=[ + types.Part( + text=_TEST_RESPONSE_CONTENT, + ) + ], + ), + ] + query_extension_mock.return_value.failure_message = "" + yield query_extension_mock + + @pytest.fixture def delete_extension_mock(): with mock.patch.object( @@ -325,6 +357,49 @@ def test_execute_extension( ), ) + def test_query_extension( + self, + get_extension_mock, + query_extension_mock, + load_yaml_mock, + ): + test_extension = extensions.Extension(_TEST_RESOURCE_ID) + get_extension_mock.assert_called_once_with( + name=_TEST_EXTENSION_RESOURCE_NAME, + retry=aiplatform.base._DEFAULT_RETRY, + ) + # Manually set _gca_resource here to prevent the mocks from propagating. + test_extension._gca_resource = _TEST_EXTENSION_OBJ + response = test_extension.query( + contents=[ + _generative_models.Content( + parts=[ + _generative_models.Part.from_text( + _TEST_QUERY_PROMPT, + ) + ], + role="user", + ) + ], + ) + assert response.steps[-1].parts[0].text == _TEST_RESPONSE_CONTENT + + query_extension_mock.assert_called_once_with( + types.QueryExtensionRequest( + name=_TEST_EXTENSION_RESOURCE_NAME, + contents=[ + types.Content( + role="user", + parts=[ + types.Part( + text=_TEST_QUERY_PROMPT, + ) + ], + ) + ], + ), + ) + def test_api_spec_from_yaml(self, get_extension_mock, load_yaml_mock): test_extension = extensions.Extension(_TEST_RESOURCE_ID) get_extension_mock.assert_called_once_with( diff --git a/tests/unit/vertexai/test_feature.py b/tests/unit/vertexai/test_feature.py index 3be234539b..106f8e1723 100644 --- a/tests/unit/vertexai/test_feature.py +++ b/tests/unit/vertexai/test_feature.py @@ -93,7 +93,7 @@ def 
test_init_with_feature_id_and_no_fg_id_raises_error(get_feature_mock): with pytest.raises( ValueError, match=re.escape( - "Since feature is not provided as a path, please specify" + "Since feature 'my_fg1_f1' is not provided as a path, please specify" + " feature_group_id." ), ): @@ -106,7 +106,7 @@ def test_init_with_feature_path_and_fg_id_raises_error(get_feature_mock): with pytest.raises( ValueError, match=re.escape( - "Since feature is provided as a path, feature_group_id should not be specified." + "Since feature 'projects/test-project/locations/us-central1/featureGroups/my_fg1/features/my_fg1_f1' is provided as a path, feature_group_id should not be specified." ), ): Feature(_TEST_FG1_F1_PATH, feature_group_id=_TEST_FG1_ID) diff --git a/tests/unit/vertexai/test_feature_online_store.py b/tests/unit/vertexai/test_feature_online_store.py index dae961b2bc..a131041d22 100644 --- a/tests/unit/vertexai/test_feature_online_store.py +++ b/tests/unit/vertexai/test_feature_online_store.py @@ -15,61 +15,61 @@ # limitations under the License. 
# -from unittest.mock import call import re +from typing import Dict from unittest import mock +from unittest.mock import call from unittest.mock import patch -from typing import Dict from google.api_core import operation as ga_operation from google.cloud import aiplatform from google.cloud.aiplatform import base from google.cloud.aiplatform.compat import types -from vertexai.resources.preview import ( - FeatureOnlineStore, - FeatureOnlineStoreType, - FeatureViewBigQuerySource, - IndexConfig, - DistanceMeasureType, - TreeAhConfig, -) -from vertexai.resources.preview.feature_store import ( - feature_online_store, -) from google.cloud.aiplatform.compat.services import ( feature_online_store_admin_service_client, ) -import pytest - -from test_feature_view import fv_eq from feature_store_constants import ( - _TEST_PROJECT, - _TEST_LOCATION, - _TEST_PARENT, _TEST_BIGTABLE_FOS1_ID, - _TEST_BIGTABLE_FOS1_PATH, _TEST_BIGTABLE_FOS1_LABELS, + _TEST_BIGTABLE_FOS1_PATH, _TEST_BIGTABLE_FOS2_ID, - _TEST_BIGTABLE_FOS2_PATH, _TEST_BIGTABLE_FOS2_LABELS, + _TEST_BIGTABLE_FOS2_PATH, _TEST_BIGTABLE_FOS3_ID, - _TEST_BIGTABLE_FOS3_PATH, _TEST_BIGTABLE_FOS3_LABELS, + _TEST_BIGTABLE_FOS3_PATH, _TEST_ESF_OPTIMIZED_FOS_ID, - _TEST_ESF_OPTIMIZED_FOS_PATH, _TEST_ESF_OPTIMIZED_FOS_LABELS, - _TEST_PSC_OPTIMIZED_FOS_ID, - _TEST_PSC_OPTIMIZED_FOS_LABELS, - _TEST_PSC_PROJECT_ALLOWLIST, + _TEST_ESF_OPTIMIZED_FOS_PATH, _TEST_FOS_LIST, - _TEST_FV1_ID, - _TEST_FV1_PATH, - _TEST_FV1_LABELS, _TEST_FV1_BQ_URI, _TEST_FV1_ENTITY_ID_COLUMNS, + _TEST_FV1_ID, + _TEST_FV1_LABELS, + _TEST_FV1_PATH, + _TEST_LOCATION, _TEST_OPTIMIZED_EMBEDDING_FV_ID, _TEST_OPTIMIZED_EMBEDDING_FV_PATH, + _TEST_PARENT, + _TEST_PROJECT, + _TEST_PSC_OPTIMIZED_FOS_ID, + _TEST_PSC_OPTIMIZED_FOS_LABELS, + _TEST_PSC_OPTIMIZED_FOS_PATH, + _TEST_PSC_PROJECT_ALLOWLIST, ) +from test_feature_view import fv_eq +from vertexai.resources.preview import ( + DistanceMeasureType, + FeatureOnlineStore, + FeatureOnlineStoreType, + 
FeatureViewBigQuerySource, + IndexConfig, + TreeAhConfig, +) +from vertexai.resources.preview.feature_store import ( + feature_online_store, +) +import pytest @pytest.fixture @@ -277,24 +277,109 @@ def test_create_esf_optimized_store( ) -@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) -def test_create_psc_optimized_store( - create_request_timeout, -): +def test_create_psc_optimized_store_no_project_allowlist_raises_error(): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with pytest.raises( ValueError, - match=re.escape("private_service_connect is not supported"), + match=re.escape( + "`project_allowlist` cannot be empty when `enable_private_service_connect` is" + " set to true." + ), ): FeatureOnlineStore.create_optimized_store( _TEST_PSC_OPTIMIZED_FOS_ID, labels=_TEST_PSC_OPTIMIZED_FOS_LABELS, - create_request_timeout=create_request_timeout, enable_private_service_connect=True, - project_allowlist=_TEST_PSC_PROJECT_ALLOWLIST, ) +def test_create_psc_optimized_store_empty_project_allowlist_raises_error(): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape( + "`project_allowlist` cannot be empty when `enable_private_service_connect` is" + " set to true." 
+ ), + ): + FeatureOnlineStore.create_optimized_store( + _TEST_PSC_OPTIMIZED_FOS_ID, + enable_private_service_connect=True, + project_allowlist=[], + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +@pytest.mark.parametrize("sync", [True, False]) +def test_create_psc_optimized_store( + create_psc_optimized_fos_mock, + get_psc_optimized_fos_mock, + fos_logger_mock, + create_request_timeout, + sync, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore.create_optimized_store( + _TEST_PSC_OPTIMIZED_FOS_ID, + labels=_TEST_PSC_OPTIMIZED_FOS_LABELS, + create_request_timeout=create_request_timeout, + enable_private_service_connect=True, + project_allowlist=_TEST_PSC_PROJECT_ALLOWLIST, + ) + + if not sync: + fos.wait() + + expected_feature_online_store = types.feature_online_store_v1.FeatureOnlineStore( + optimized=types.feature_online_store_v1.FeatureOnlineStore.Optimized(), + dedicated_serving_endpoint=types.feature_online_store_v1.FeatureOnlineStore.DedicatedServingEndpoint( + private_service_connect_config=types.service_networking_v1.PrivateServiceConnectConfig( + enable_private_service_connect=True, + project_allowlist=_TEST_PSC_PROJECT_ALLOWLIST, + ) + ), + labels=_TEST_PSC_OPTIMIZED_FOS_LABELS, + ) + create_psc_optimized_fos_mock.assert_called_once_with( + parent=_TEST_PARENT, + feature_online_store=expected_feature_online_store, + feature_online_store_id=_TEST_PSC_OPTIMIZED_FOS_ID, + metadata=(), + timeout=create_request_timeout, + ) + + fos_logger_mock.assert_has_calls( + [ + call("Creating FeatureOnlineStore"), + call( + "Create FeatureOnlineStore backing LRO:" + f" {create_psc_optimized_fos_mock.return_value.operation.name}" + ), + call( + "FeatureOnlineStore created. 
Resource name:" + " projects/test-project/locations/us-central1/featureOnlineStores/my_psc_optimized_fos" + ), + call("To use this FeatureOnlineStore in another session:"), + call( + "feature_online_store =" + " aiplatform.FeatureOnlineStore('projects/test-project/locations/us-central1/featureOnlineStores/my_psc_optimized_fos')" + ), + ] + ) + + fos_eq( + fos, + name=_TEST_PSC_OPTIMIZED_FOS_ID, + resource_name=_TEST_PSC_OPTIMIZED_FOS_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_PSC_OPTIMIZED_FOS_LABELS, + type=FeatureOnlineStoreType.OPTIMIZED, + ) + + def test_list(list_fos_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) diff --git a/tests/unit/vertexai/test_generative_models.py b/tests/unit/vertexai/test_generative_models.py index ada5a80564..6f284dc112 100644 --- a/tests/unit/vertexai/test_generative_models.py +++ b/tests/unit/vertexai/test_generative_models.py @@ -1016,9 +1016,7 @@ def test_generate_content_grounding_google_search_retriever_preview(self): model = preview_generative_models.GenerativeModel("gemini-pro") google_search_retriever_tool = ( preview_generative_models.Tool.from_google_search_retrieval( - preview_generative_models.grounding.GoogleSearchRetrieval( - disable_attribution=False - ) + preview_generative_models.grounding.GoogleSearchRetrieval() ) ) response = model.generate_content( diff --git a/tests/unit/vertexai/test_prompts.py b/tests/unit/vertexai/test_prompts.py new file mode 100644 index 0000000000..c0451f304a --- /dev/null +++ b/tests/unit/vertexai/test_prompts.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Unit tests for generative model prompts.""" +# pylint: disable=protected-access,bad-continuation + +from vertexai.generative_models._prompts import Prompt +from vertexai.generative_models import Content, Part, Image + +import io +import pytest + +from typing import Any, List + + +def is_list_of_type(obj: Any, T: Any) -> bool: + return isinstance(obj, list) and all(isinstance(s, T) for s in obj) + + +def assert_prompt_contents_equal( + prompt_contents: List[Content], + expected_prompt_contents: List[Content], +) -> None: + assert len(prompt_contents) == len(expected_prompt_contents) + for i in range(len(prompt_contents)): + assert prompt_contents[i].role == expected_prompt_contents[i].role + assert len(prompt_contents[i].parts) == len(expected_prompt_contents[i].parts) + for j in range(len(prompt_contents[i].parts)): + assert ( + prompt_contents[i].parts[j]._raw_part.text + == expected_prompt_contents[i].parts[j]._raw_part.text + ) + + +def create_image(): + # Importing external library lazily to reduce the scope of import errors. 
+ from PIL import Image as PIL_Image # pylint: disable=g-import-not-at-top + + pil_image: PIL_Image.Image = PIL_Image.new(mode="RGB", size=(200, 200)) + image_bytes_io = io.BytesIO() + pil_image.save(image_bytes_io, format="jpeg") + return Image.from_bytes(image_bytes_io.getvalue()) + + +@pytest.mark.usefixtures("google_auth_mock") +class TestPrompt: + """Unit tests for generative model prompts.""" + + def test_string_prompt_constructor_string_variables(self): + # Create string prompt with string only variable values + prompt = Prompt( + prompt_data="Rate the movie {movie1}", + variables=[ + { + "movie1": "The Avengers", + } + ], + ) + # String prompt data should remain as string before compilation + assert prompt.prompt_data == "Rate the movie {movie1}" + # Variables values should be converted to List[Part] + assert is_list_of_type(prompt.variables[0]["movie1"], Part) + + def test_string_prompt_constructor_part_variables(self): + # Create string prompt with List[Part] variable values + prompt = Prompt( + prompt_data="Rate the movie {movie1}", + variables=[ + { + "movie1": [Part.from_text("The Avengers")], + } + ], + ) + # Variables values should be converted to List[Part] + assert is_list_of_type(prompt.variables[0]["movie1"], Part) + + def test_string_prompt_constructor_invalid_variables(self): + # String prompt variables must be PartsType + with pytest.raises(TypeError): + Prompt( + prompt_data="Rate the movie {movie1}", + variables=[ + { + "movie1": 12345, + } + ], + ) + + def test_partstype_prompt_constructor(self): + image = create_image() + # Create PartsType prompt with List[Part] variable values + prompt_data = [ + "Compare the movie posters for The Avengers and {movie2}: ", + image, + "{movie2_poster}", + ] + prompt = Prompt( + prompt_data=prompt_data, + variables=[{"movie2": "Frozen", "movie2_poster": [Part.from_image(image)]}], + ) + # Variables values should be List[Part] + assert is_list_of_type(prompt.variables[0]["movie2"], Part) + assert 
is_list_of_type(prompt.variables[0]["movie2_poster"], Part) + + def test_string_prompt_assemble_contents(self): + prompt = Prompt( + prompt_data="Which movie is better, {movie1} or {movie2}?", + variables=[ + { + "movie1": "The Avengers", + "movie2": "Frozen", + } + ], + ) + assembled_prompt_content = prompt.assemble_contents(**prompt.variables[0]) + expected_content = [ + Content( + parts=[ + Part.from_text("Which movie is better, The Avengers or Frozen?"), + ], + role="user", + ) + ] + assert_prompt_contents_equal(assembled_prompt_content, expected_content) + + def test_partstype_prompt_assemble_contents(self): + image1 = create_image() + image2 = create_image() + prompt_data = [ + "Compare the movie posters for The Avengers and {movie2}: ", + image1, + "{movie2_poster}", + ] + prompt = Prompt( + prompt_data=prompt_data, + variables=[ + { + "movie2": "Frozen", + "movie2_poster": [Part.from_image(image=image2)], + } + ], + ) + + # Check assembled prompt content + assembled_prompt_content = prompt.assemble_contents(**prompt.variables[0]) + expected_content = [ + Content( + parts=[ + Part.from_text( + "Compare the movie posters for The Avengers and Frozen: " + ), + Part.from_image(image=image1), + Part.from_image(image=image2), + ], + role="user", + ) + ] + assert_prompt_contents_equal(assembled_prompt_content, expected_content) + + def test_string_prompt_partial_assemble_contents(self): + prompt = Prompt( + prompt_data="Which movie is better, {movie1} or {movie2}?", + variables=[ + { + "movie1": "The Avengers", + } + ], + ) + + # Check partially assembled prompt content + assembled1_prompt_content = prompt.assemble_contents(**prompt.variables[0]) + expected1_content = [ + Content( + parts=[ + Part.from_text("Which movie is better, The Avengers or {movie2}?"), + ], + role="user", + ) + ] + assert_prompt_contents_equal(assembled1_prompt_content, expected1_content) + + # Check fully assembled prompt + assembled2_prompt_content = prompt.assemble_contents( + 
movie1="Inception", movie2="Frozen" + ) + expected2_content = [ + Content( + parts=[ + Part.from_text("Which movie is better, Inception or Frozen?"), + ], + role="user", + ) + ] + assert_prompt_contents_equal(assembled2_prompt_content, expected2_content) + + def test_string_prompt_assemble_unused_variables(self): + # Variables must present in prompt_data if specified + prompt = Prompt(prompt_data="Rate the movie {movie1}") + with pytest.raises(ValueError): + prompt.assemble_contents(day="Tuesday") diff --git a/vertexai/extensions/_extensions.py b/vertexai/extensions/_extensions.py index 42c0544a17..fa4b8a060d 100644 --- a/vertexai/extensions/_extensions.py +++ b/vertexai/extensions/_extensions.py @@ -14,14 +14,14 @@ # limitations under the License. # import json -from typing import Optional, Sequence, Union +from typing import List, Optional, Sequence, Union from google.cloud.aiplatform import base from google.cloud.aiplatform import initializer from google.cloud.aiplatform import utils as aip_utils from google.cloud.aiplatform_v1beta1 import types +from vertexai.generative_models import _generative_models from vertexai.reasoning_engines import _utils - from google.protobuf import struct_pb2 _LOGGER = base.Logger(__name__) @@ -248,6 +248,36 @@ def execute( response = self.execution_api_client.execute_extension(request) return _try_parse_execution_response(response) + def query( + self, + contents: _generative_models.ContentsType, + ) -> "QueryExtensionResponse": + """Queries an extension with the specified contents. + + Args: + contents (ContentsType): + Required. The content of the current + conversation with the model. + For single-turn queries, this is a single + instance. For multi-turn queries, this is a + repeated field that contains conversation + history + latest request. + + Returns: + The result of querying the extension. + + Raises: + RuntimeError: If the response contains an error. 
+ """ + request = types.QueryExtensionRequest( + name=self.resource_name, + contents=_generative_models._content_types_to_gapic_contents(contents), + ) + response = self.execution_api_client.query_extension(request) + if response.failure_message: + raise RuntimeError(response.failure_message) + return QueryExtensionResponse._from_gapic(response) + @classmethod def from_hub( cls, @@ -317,6 +347,29 @@ def from_hub( ) +class QueryExtensionResponse: + """A class representing the response from querying an extension.""" + + def __init__(self, steps: List[_generative_models.Content]): + """Initializes the QueryExtensionResponse with the given steps.""" + self.steps = steps + + @classmethod + def _from_gapic( + cls, response: types.QueryExtensionResponse + ) -> "QueryExtensionResponse": + """Creates a QueryExtensionResponse from a gapic response.""" + return cls( + steps=[ + _generative_models.Content( + parts=[_generative_models.Part._from_gapic(p) for p in c.parts], + role=c.role, + ) + for c in response.steps + ] + ) + + def _try_parse_execution_response( response: types.ExecuteExtensionResponse, ) -> Union[_utils.JsonDict, str]: diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index b74226e6f5..5041433869 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -126,6 +126,8 @@ def _get_resource_name_from_model_name( ) -> str: """Returns the full resource name starting with projects/ given a model name.""" if model_name.startswith("publishers/"): + if not project: + return model_name return f"projects/{project}/locations/{location}/{model_name}" elif model_name.startswith("projects/"): return model_name @@ -337,7 +339,7 @@ def __init__( location = aiplatform_utils.extract_project_and_location_from_parent( prediction_resource_name - )["location"] + ).get("location") self._model_name = model_name self._prediction_resource_name = 
prediction_resource_name @@ -823,7 +825,7 @@ async def count_tokens_async( def compute_tokens( self, contents: ContentsType ) -> gapic_llm_utility_service_types.ComputeTokensResponse: - """Counts tokens. + """Computes tokens. Args: contents: Contents to send to the model. @@ -835,9 +837,13 @@ def compute_tokens( * List[Content] Returns: - A CountTokensResponse object that has the following attributes: - total_tokens: The total number of tokens counted across all instances from the request. - total_billable_characters: The total number of billable characters counted across all instances from the request. + A ComputeTokensResponse object that has the following attributes: + tokens_info: Lists of tokens_info from the input. + The input `contents: ContentsType` could have + multiple string instances and each tokens_info + item represents each string instance. Each token + info consists tokens list, token_ids list and + a role. """ return self._llm_utility_client.compute_tokens( request=gapic_llm_utility_service_types.ComputeTokensRequest( @@ -850,7 +856,7 @@ def compute_tokens( async def compute_tokens_async( self, contents: ContentsType ) -> gapic_llm_utility_service_types.ComputeTokensResponse: - """Counts tokens asynchronously. + """Computes tokens asynchronously. Args: contents: Contents to send to the model. @@ -862,9 +868,13 @@ async def compute_tokens_async( * List[Content] Returns: - And awaitable for a CountTokensResponse object that has the following attributes: - total_tokens: The total number of tokens counted across all instances from the request. - total_billable_characters: The total number of billable characters counted across all instances from the request. + And awaitable for a ComputeTokensResponse object that has the following attributes: + tokens_info: Lists of tokens_info from the input. + The input `contents: ContentsType` could have + multiple string instances and each tokens_info + item represents each string instance. 
Each token + info consists tokens list, token_ids list and + a role. """ return await self._llm_utility_async_client.compute_tokens( request=gapic_llm_utility_service_types.ComputeTokensRequest( @@ -2300,20 +2310,22 @@ class GoogleSearchRetrieval: def __init__( self, - disable_attribution: Optional[bool] = None, + disable_attribution: Optional[ + bool + ] = None, # pylint: disable=unused-argument ): """Initializes a Google Search Retrieval tool. Args: disable_attribution (bool): - Optional. Disable using the result from this - tool in detecting grounding attribution. This + Optional. This field is Deprecated. Disable using the result + from this tool in detecting grounding attribution. This does not affect how the result is given to the model for generation. """ - self._raw_google_search_retrieval = gapic_tool_types.GoogleSearchRetrieval( - disable_attribution=disable_attribution, - ) + if disable_attribution is not None: + warnings.warn("disable_attribution is deprecated.") + self._raw_google_search_retrieval = gapic_tool_types.GoogleSearchRetrieval() def _to_content( diff --git a/vertexai/generative_models/_prompts.py b/vertexai/generative_models/_prompts.py new file mode 100644 index 0000000000..69b55f52a1 --- /dev/null +++ b/vertexai/generative_models/_prompts.py @@ -0,0 +1,301 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from google.cloud.aiplatform import base +from vertexai.generative_models import ( + Content, + Image, + Part, +) +from vertexai.generative_models._generative_models import ( + _to_content, + PartsType, +) + +import re +from typing import ( + Any, + Dict, + List, + Optional, + Union, +) + +_LOGGER = base.Logger(__name__) + +VARIABLE_NAME_REGEX = r"(\{[^\W0-9]\w*\})" + + +class Prompt: + """A prompt which may be a template with variables. + + The `Prompt` class allows users to define a template string with + variables represented in curly braces `{variable}`. The variable + name must be a valid Python variable name (no spaces, must start with a + letter). These placeholders can be replaced with specific values using the + `assemble_contents` method, providing flexibility in generating dynamic prompts. + + Usage: + Generate content from a single set of variables: + ``` + prompt = Prompt( + prompt_data="Hello, {name}! Today is {day}. How are you?", + variables=[{"name": "Alice", "day": "Monday"}] + ) + + # Generate content using the assembled prompt. + model.generate_content(contents=prompt.assemble_contents(**prompt.variables[0])) + ``` + """ + + def __init__( + self, + prompt_data: PartsType, + variables: Optional[List[Dict[str, PartsType]]] = None, + ): + """Initializes the Prompt with a given prompt, and variables. + + Args: + prompt: A PartsType prompt which may be a template with variables or a prompt with no variables. + variables: A list of dictionaries containing the variable names and values. + """ + self._prompt_data = None + self._variables = None + + self.prompt_data = prompt_data + self.variables = variables if variables else [{}] + + @property + def prompt_data(self) -> PartsType: + return self._prompt_data + + @property + def variables(self) -> Optional[List[Dict[str, PartsType]]]: + return self._variables + + @prompt_data.setter + def prompt_data(self, prompt_data: PartsType) -> None: + """Overwrites the existing saved local prompt_data. 
+ + Args: + prompt_data: A PartsType prompt. + """ + Prompt._validate_prompt_data(prompt_data) + self._prompt_data = prompt_data + + @variables.setter + def variables(self, variables: List[Dict[str, PartsType]]) -> None: + """Overwrites the existing saved local variables. + + Args: + variables: A list of dictionaries containing the variable names and values. + """ + if isinstance(variables, list): + for i in range(len(variables)): + variables[i] = variables[i].copy() + Prompt._format_variable_value_to_parts(variables[i]) + self._variables = variables + else: + raise TypeError( + f"Variables must be a list of dictionaries, not {type(variables)}" + ) + + def _format_variable_value_to_parts(variables_dict: Dict[str, PartsType]) -> None: + """Formats the variables values to be List[Part]. + + Args: + variables_dict: A single dictionary containing the variable names and values. + + Raises: + TypeError: If a variable value is not a PartsType Object. + """ + for key in variables_dict.keys(): + # Disallow Content as variable value. + if isinstance(variables_dict[key], Content): + raise TypeError( + "Variable values must be a PartsType object, not Content" + ) + + # Rely on type checks in _to_content for validation. + content = Content._from_gapic(_to_content(value=variables_dict[key])) + variables_dict[key] = content.parts + + def _validate_prompt_data(prompt_data: Any) -> None: + """ + Args: + prompt_data: The prompt input to validate + + Raises: + TypeError: If prompt_data is not a PartsType Object. + """ + # Disallow Content as prompt_data. + if isinstance(prompt_data, Content): + raise TypeError("Prompt data must be a PartsType object, not Content") + + # Rely on type checks in _to_content. + _to_content(value=prompt_data) + + def assemble_contents(self, **variables_dict: PartsType) -> List[Content]: + """Returns the prompt data, as a List[Content], assembled with variables if applicable. + Can be ingested into model.generate_content to make API calls. 
+ + Returns: + A List[Content] prompt. + Usage: + ``` + prompt = Prompt( + prompt_data="Hello, {name}! Today is {day}. How are you?", + ) + + model.generate_content( + contents=prompt.assemble_contents(name="Alice", day="Monday") + ) + ``` + """ + variables_dict = variables_dict.copy() + + # If there are no variables, return the prompt_data as a Content object. + if not variables_dict: + return [Content._from_gapic(_to_content(value=self.prompt_data))] + + # Step 1) Convert the variables values to List[Part]. + Prompt._format_variable_value_to_parts(variables_dict) + + # Step 2) Assemble the prompt. + # prompt_data must have been previously validated using _validate_prompt_data. + assembled_prompt = [] + assembled_variables_cnt = {} + if isinstance(self.prompt_data, list): + # User inputted a List of Parts as prompt_data. + for part in self.prompt_data: + assembled_prompt.extend( + self._assemble_singular_part( + part, variables_dict, assembled_variables_cnt + ) + ) + else: + # User inputted a single str, Image, or Part as prompt_data. + assembled_prompt.extend( + self._assemble_singular_part( + self.prompt_data, variables_dict, assembled_variables_cnt + ) + ) + + # Step 3) Simplify adjacent string Parts + simplified_assembled_prompt = [assembled_prompt[0]] + for i in range(1, len(assembled_prompt)): + # If the previous Part and current Part is a string, concatenate them. + try: + prev_text = simplified_assembled_prompt[-1].text + curr_text = assembled_prompt[i].text + if isinstance(prev_text, str) and isinstance(curr_text, str): + simplified_assembled_prompt[-1] = Part.from_text( + prev_text + curr_text + ) + else: + simplified_assembled_prompt.append(assembled_prompt[i]) + except AttributeError: + simplified_assembled_prompt.append(assembled_prompt[i]) + continue + + # Step 4) Validate that all variables were used, if specified. 
+ for key in variables_dict: + if key not in assembled_variables_cnt: + raise ValueError(f"Variable {key} is not present in prompt_data.") + + assemble_cnt_msg = "Assembled prompt replacing: " + for key in assembled_variables_cnt: + assemble_cnt_msg += ( + f"{assembled_variables_cnt[key]} instances of variable {key}, " + ) + if assemble_cnt_msg[-2:] == ", ": + assemble_cnt_msg = assemble_cnt_msg[:-2] + _LOGGER.info(assemble_cnt_msg) + + # Step 5) Wrap List[Part] as a single Content object. + return [ + Content( + parts=simplified_assembled_prompt, + role="user", + ) + ] + + def _assemble_singular_part( + self, + prompt_data_part: Union[str, Image, Part], + formatted_variables_set: Dict[str, List[Part]], + assembled_variables_cnt: Dict[str, int], + ) -> List[Part]: + """Assemble a str, Image, or Part.""" + if isinstance(prompt_data_part, Image): + # Templating is not supported for Image prompt_data. + return [Part.from_image(prompt_data_part)] + elif isinstance(prompt_data_part, str): + # Assemble a single string + return self._assemble_single_str( + prompt_data_part, formatted_variables_set, assembled_variables_cnt + ) + elif isinstance(prompt_data_part, Part): + # If the Part is a text Part, assemble it. + try: + text = prompt_data_part.text + except AttributeError: + return [prompt_data_part] + return self._assemble_single_str( + text, formatted_variables_set, assembled_variables_cnt + ) + + def _assemble_single_str( + self, + prompt_data_str: str, + formatted_variables_set: Dict[str, List[Part]], + assembled_variables_cnt: Dict[str, int], + ) -> List[Part]: + """Assemble a single string with 0 or more variables within the string.""" + # Step 1) Find and isolate variables as their own string. + prompt_data_str_split = re.split(VARIABLE_NAME_REGEX, prompt_data_str) + + assembled_data = [] + # Step 2) Assemble variables with their values, creating a list of Parts. 
+ for s in prompt_data_str_split: + if not s: + continue + variable_name = s[1:-1] + if ( + re.match(VARIABLE_NAME_REGEX, s) + and variable_name in formatted_variables_set + ): + assembled_data.extend(formatted_variables_set[variable_name]) + assembled_variables_cnt[variable_name] = ( + assembled_variables_cnt.get(variable_name, 0) + 1 + ) + else: + assembled_data.append(Part.from_text(s)) + + return assembled_data + + def get_unassembled_prompt_data(self) -> PartsType: + """Returns the prompt data, without any variables replaced.""" + return self.prompt_data + + def __str__(self) -> str: + """Returns the prompt data as a string, without any variables replaced.""" + return str(self.prompt_data) + + def __repr__(self) -> str: + """Returns a string representation of the unassembled prompt.""" + return f"Prompt(prompt_data='{self.prompt_data}', variables={self.variables})" diff --git a/vertexai/preview/evaluation/_base.py b/vertexai/preview/evaluation/_base.py index 337dd16ab8..30202ca263 100644 --- a/vertexai/preview/evaluation/_base.py +++ b/vertexai/preview/evaluation/_base.py @@ -37,7 +37,7 @@ class EvaluationRunConfig: Attributes: dataset: The dataset to evaluate. - metrics: The list of metric names, or metric bundle names, or Metric instances to evaluate. + metrics: The list of metrics, or Metric instances to evaluate. column_map: The dictionary of column name overrides in the dataset. client: The evaluation service client. evaluation_service_qps: The custom QPS limit for the evaluation service. @@ -73,10 +73,11 @@ class EvalResult: """Evaluation result. Attributes: - summary_metrics: The summary evaluation metrics for an evaluation run. - metrics_table: A table containing eval inputs, ground truth, and metrics per - row. + summary_metrics: The summary evaluation metrics for the evaluation run. + metrics_table: A table containing evaluation dataset, and metric results. + metadata: The metadata for the evaluation run. 
""" summary_metrics: Dict[str, float] metrics_table: Optional["pd.DataFrame"] = None + metadata: Optional[Dict[str, str]] = None diff --git a/vertexai/preview/evaluation/_eval_tasks.py b/vertexai/preview/evaluation/_eval_tasks.py index 96a11e835d..32a5a059dc 100644 --- a/vertexai/preview/evaluation/_eval_tasks.py +++ b/vertexai/preview/evaluation/_eval_tasks.py @@ -298,6 +298,10 @@ def _evaluate_with_experiment( k: ("NaN" if isinstance(v, float) and np.isnan(v) else v) for k, v in eval_result.summary_metrics.items() } + eval_result.metadata = { + "experiment": self.experiment, + "experiment_run": experiment_run_name, + } try: vertexai.preview.log_metrics(eval_result.summary_metrics) except (TypeError, exceptions.InvalidArgument) as e: diff --git a/vertexai/preview/evaluation/_evaluation.py b/vertexai/preview/evaluation/_evaluation.py index 0a83c6173e..9085800c8c 100644 --- a/vertexai/preview/evaluation/_evaluation.py +++ b/vertexai/preview/evaluation/_evaluation.py @@ -19,6 +19,7 @@ from concurrent import futures import time from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union +import warnings from google.cloud.aiplatform import base from google.cloud.aiplatform_v1beta1.types import ( @@ -113,6 +114,20 @@ def _validate_metrics(metrics: List[Union[str, metrics_base._Metric]]) -> None: if isinstance(metric, str): if metric in seen_strings: raise ValueError(f"Duplicate string metric name found: '{metric}'") + if metric in constants.Metric.MODEL_BASED_METRIC_LIST: + warnings.warn( + f"After google-cloud-aiplatform>1.61.0, using " + f"metric name `{metric}` will result in an error. Please" + f" use metric constant as Pointwise.{metric.upper()} or" + " define a PointwiseMetric instead." + ) + if metric in constants.Metric.PAIRWISE_METRIC_LIST: + warnings.warn( + f"After google-cloud-aiplatform>1.61.0, using " + f"metric name `{metric}` will result in an error. 
Please" + f" use metric constant as Pairwise.{metric.upper()} or" + " define a PairwiseMetric instead." + ) seen_strings.add(metric) elif isinstance(metric, metrics_base._Metric): if metric.metric_name in seen_metric_names: diff --git a/vertexai/resources/preview/feature_store/feature.py b/vertexai/resources/preview/feature_store/feature.py index 8ff43d415b..5ad5fe9129 100644 --- a/vertexai/resources/preview/feature_store/feature.py +++ b/vertexai/resources/preview/feature_store/feature.py @@ -80,7 +80,7 @@ def __init__( ): if feature_group_id: raise ValueError( - "Since feature is provided as a path, feature_group_id should not be specified." + f"Since feature '{name}' is provided as a path, feature_group_id should not be specified." ) feature = name else: @@ -90,8 +90,7 @@ def __init__( # feature group ID is provided. if not feature_group_id: raise ValueError( - "Since feature is not provided as a path, please specify" - + " feature_group_id." + f"Since feature '{name}' is not provided as a path, please specify feature_group_id." ) feature_group_path = utils.full_resource_name( diff --git a/vertexai/resources/preview/feature_store/feature_online_store.py b/vertexai/resources/preview/feature_store/feature_online_store.py index 86f625174e..205706b049 100644 --- a/vertexai/resources/preview/feature_store/feature_online_store.py +++ b/vertexai/resources/preview/feature_store/feature_online_store.py @@ -31,6 +31,7 @@ ) from google.cloud.aiplatform.compat.types import ( feature_online_store as gca_feature_online_store, + service_networking as gca_service_networking, feature_view as gca_feature_view, ) from vertexai.resources.preview.feature_store.feature_view import ( @@ -245,18 +246,30 @@ def create_optimized_store( Example Usage: + ``` + # Create optimized store with public endpoint. my_fos = vertexai.preview.FeatureOnlineStore.create_optimized_store('my_fos') + ``` + + ``` + # Create optimized online store with private service connect. 
+ my_fos = vertexai.preview.FeatureOnlineStore.create_optimized_store( + 'my_fos', + enable_private_service_connect=True, + project_allowlist=['my-project'], + ) + ``` Args: name: The name of the feature online store. - enable_private_service_connect (bool): + enable_private_service_connect: Optional. If true, expose the optimized online store via private service connect. Otherwise the optimized online - store will be accessible through public endpoint - project_allowlist (MutableSequence[str]): + store will be accessible through public endpoint. + project_allowlist: A list of Projects from which the forwarding rule will target the service attachment. Only needed when - enable_private_service_connect is set to true. + `enable_private_service_connect` is set to true. labels: The labels with user-defined metadata to organize your feature online store. Label keys and values can be no longer than 64 @@ -290,7 +303,17 @@ def create_optimized_store( FeatureOnlineStore - the FeatureOnlineStore resource object. """ if enable_private_service_connect: - raise ValueError("private_service_connect is not supported") + if not project_allowlist: + raise ValueError( + "`project_allowlist` cannot be empty when `enable_private_service_connect` is set to true." + ) + + dedicated_serving_endpoint = gca_feature_online_store.FeatureOnlineStore.DedicatedServingEndpoint( + private_service_connect_config=gca_service_networking.PrivateServiceConnectConfig( + enable_private_service_connect=True, + project_allowlist=project_allowlist, + ), + ) else: dedicated_serving_endpoint = ( gca_feature_online_store.FeatureOnlineStore.DedicatedServingEndpoint()