diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index 4b3678d63d..5fb5a9ad37 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
 {
-  ".": "1.36.0"
+  ".": "1.36.1"
 }
diff --git a/CHANGELOG.md b/CHANGELOG.md
index bb3dd1de90..168673b4d3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,36 @@
 # Changelog

+## [1.36.1](https://github.com/googleapis/python-aiplatform/compare/v1.36.0...v1.36.1) (2023-11-07)
+
+
+### Features
+
+* Add `per_crowding_attribute_neighbor_count`, `approx_num_neighbors`, `fraction_leaf_nodes_to_search_override`, and `return_full_datapoint` to MatchingEngineIndexEndpoint `find_neighbors` ([33c551e](https://github.com/googleapis/python-aiplatform/commit/33c551efca38688c8c62ef5847dfcef0221e848c))
+* Add profiler support to the TensorBoard uploader SDK ([be1df7f](https://github.com/googleapis/python-aiplatform/commit/be1df7f4823f7b40022d31f529204dfe27fdb4d7))
+* Add support for `per_crowding_attribute_num_neighbors` and `approx_num_neighbors` to MatchingEngineIndexEndpoint `match()` ([e5c20c3](https://github.com/googleapis/python-aiplatform/commit/e5c20c3b5c0078c9dfc70e2d1d13513a4dcefa63))
+* Add support for `per_crowding_attribute_num_neighbors` and `approx_num_neighbors` to MatchingEngineIndexEndpoint `match()` ([53d31b5](https://github.com/googleapis/python-aiplatform/commit/53d31b5b6ec477e6f2b4391aaeadc8ae349800b8))
+* Add support for `per_crowding_attribute_num_neighbors` and `approx_num_neighbors` to MatchingEngineIndexEndpoint `match()` ([4e357d5](https://github.com/googleapis/python-aiplatform/commit/4e357d5121d053dc313f3a3f180131e1850bebe2))
+* Enable grounding to ChatModel send_message and send_message_async methods ([d4667f2](https://github.com/googleapis/python-aiplatform/commit/d4667f25a7c95bd16511beaed85edf45307176b5))
+* Enable grounding to TextGenerationModel predict and predict_async methods ([b0b4e6b](https://github.com/googleapis/python-aiplatform/commit/b0b4e6b8243cbdb829288e3fc204d94005f1e8b4))
+* LLM - Added support for the `enable_checkpoint_selection` tuning evaluation parameter ([eaf4420](https://github.com/googleapis/python-aiplatform/commit/eaf4420479b64740cdd464afb64b8780f57c8199))
+* LLM - Added tuning support for the `*-bison-32k` models ([9eba18f](https://github.com/googleapis/python-aiplatform/commit/9eba18f70d36ac3901ba8b580cde6dde04413bc3))
+* LLM - Released `CodeChatModel` tuning to GA ([621af52](https://github.com/googleapis/python-aiplatform/commit/621af5244797a0e218195c72d9781cbd86b24fa0))
+
+
+### Bug Fixes
+
+* Correct class name in system test ([b822b57](https://github.com/googleapis/python-aiplatform/commit/b822b57fa490c8d89802ee5fbf0f3736e0811208))
+
+
+### Documentation
+
+* Clean up RoV create_ray_cluster docstring ([1473e19](https://github.com/googleapis/python-aiplatform/commit/1473e19c9b05c89ba2229f42a8d72588fa267d17))
+
+
+### Miscellaneous Chores
+
+* Release 1.36.1 ([1cde170](https://github.com/googleapis/python-aiplatform/commit/1cde1708fd26357995f3ee86194aa92aa7de5519))
+
 ## [1.36.0](https://github.com/googleapis/python-aiplatform/compare/v1.35.0...v1.36.0) (2023-10-31)
diff --git a/docs/aiplatform_v1/services_.rst b/docs/aiplatform_v1/services_.rst
new file mode 100644
index 0000000000..93afd80841
--- /dev/null
+++ b/docs/aiplatform_v1/services_.rst
@@ -0,0 +1,23 @@
+Services for Google Cloud Aiplatform v1 API
+===========================================
+.. 
toctree:: + :maxdepth: 2 + + dataset_service + endpoint_service + featurestore_online_serving_service + featurestore_service + index_endpoint_service + index_service + job_service + match_service + metadata_service + migration_service + model_garden_service + model_service + pipeline_service + prediction_service + schedule_service + specialist_pool_service + tensorboard_service + vizier_service diff --git a/docs/aiplatform_v1/types_.rst b/docs/aiplatform_v1/types_.rst new file mode 100644 index 0000000000..da19f0e39e --- /dev/null +++ b/docs/aiplatform_v1/types_.rst @@ -0,0 +1,6 @@ +Types for Google Cloud Aiplatform v1 API +======================================== + +.. automodule:: google.cloud.aiplatform_v1.types + :members: + :show-inheritance: diff --git a/docs/aiplatform_v1beta1/services_.rst b/docs/aiplatform_v1beta1/services_.rst new file mode 100644 index 0000000000..809b60c29c --- /dev/null +++ b/docs/aiplatform_v1beta1/services_.rst @@ -0,0 +1,28 @@ +Services for Google Cloud Aiplatform v1beta1 API +================================================ +.. toctree:: + :maxdepth: 2 + + dataset_service + deployment_resource_pool_service + endpoint_service + feature_online_store_admin_service + feature_online_store_service + feature_registry_service + featurestore_online_serving_service + featurestore_service + index_endpoint_service + index_service + job_service + match_service + metadata_service + migration_service + model_garden_service + model_service + persistent_resource_service + pipeline_service + prediction_service + schedule_service + specialist_pool_service + tensorboard_service + vizier_service diff --git a/docs/aiplatform_v1beta1/types_.rst b/docs/aiplatform_v1beta1/types_.rst new file mode 100644 index 0000000000..19bab68ada --- /dev/null +++ b/docs/aiplatform_v1beta1/types_.rst @@ -0,0 +1,6 @@ +Types for Google Cloud Aiplatform v1beta1 API +============================================= + +.. automodule:: google.cloud.aiplatform_v1beta1.types + :members: + :show-inheritance: diff --git a/docs/definition_v1/services_.rst b/docs/definition_v1/services_.rst new file mode 100644 index 0000000000..ba6b1940e8 --- /dev/null +++ b/docs/definition_v1/services_.rst @@ -0,0 +1,4 @@ +Services for Google Cloud Aiplatform V1 Schema Trainingjob Definition v1 API +============================================================================ +.. toctree:: + :maxdepth: 2 diff --git a/docs/definition_v1/types_.rst b/docs/definition_v1/types_.rst new file mode 100644 index 0000000000..0add260eee --- /dev/null +++ b/docs/definition_v1/types_.rst @@ -0,0 +1,6 @@ +Types for Google Cloud Aiplatform V1 Schema Trainingjob Definition v1 API +========================================================================= + +.. automodule:: google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types + :members: + :show-inheritance: diff --git a/docs/definition_v1beta1/services_.rst b/docs/definition_v1beta1/services_.rst new file mode 100644 index 0000000000..5f1ed5f2b7 --- /dev/null +++ b/docs/definition_v1beta1/services_.rst @@ -0,0 +1,4 @@ +Services for Google Cloud Aiplatform V1beta1 Schema Trainingjob Definition v1beta1 API +====================================================================================== +.. 
toctree:: + :maxdepth: 2 diff --git a/docs/definition_v1beta1/types_.rst b/docs/definition_v1beta1/types_.rst new file mode 100644 index 0000000000..3f351d03fc --- /dev/null +++ b/docs/definition_v1beta1/types_.rst @@ -0,0 +1,6 @@ +Types for Google Cloud Aiplatform V1beta1 Schema Trainingjob Definition v1beta1 API +=================================================================================== + +.. automodule:: google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types + :members: + :show-inheritance: diff --git a/docs/instance_v1/services_.rst b/docs/instance_v1/services_.rst new file mode 100644 index 0000000000..50c011c69a --- /dev/null +++ b/docs/instance_v1/services_.rst @@ -0,0 +1,4 @@ +Services for Google Cloud Aiplatform V1 Schema Predict Instance v1 API +====================================================================== +.. toctree:: + :maxdepth: 2 diff --git a/docs/instance_v1/types_.rst b/docs/instance_v1/types_.rst new file mode 100644 index 0000000000..81597999f2 --- /dev/null +++ b/docs/instance_v1/types_.rst @@ -0,0 +1,6 @@ +Types for Google Cloud Aiplatform V1 Schema Predict Instance v1 API +=================================================================== + +.. automodule:: google.cloud.aiplatform.v1.schema.predict.instance_v1.types + :members: + :show-inheritance: diff --git a/docs/instance_v1beta1/services_.rst b/docs/instance_v1beta1/services_.rst new file mode 100644 index 0000000000..941dbcca59 --- /dev/null +++ b/docs/instance_v1beta1/services_.rst @@ -0,0 +1,4 @@ +Services for Google Cloud Aiplatform V1beta1 Schema Predict Instance v1beta1 API +================================================================================ +.. toctree:: + :maxdepth: 2 diff --git a/docs/instance_v1beta1/types_.rst b/docs/instance_v1beta1/types_.rst new file mode 100644 index 0000000000..c52ae4800c --- /dev/null +++ b/docs/instance_v1beta1/types_.rst @@ -0,0 +1,6 @@ +Types for Google Cloud Aiplatform V1beta1 Schema Predict Instance v1beta1 API +============================================================================= + +.. automodule:: google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types + :members: + :show-inheritance: diff --git a/docs/params_v1/services_.rst b/docs/params_v1/services_.rst new file mode 100644 index 0000000000..bf08ea6e98 --- /dev/null +++ b/docs/params_v1/services_.rst @@ -0,0 +1,4 @@ +Services for Google Cloud Aiplatform V1 Schema Predict Params v1 API +==================================================================== +.. toctree:: + :maxdepth: 2 diff --git a/docs/params_v1/types_.rst b/docs/params_v1/types_.rst new file mode 100644 index 0000000000..afc962c218 --- /dev/null +++ b/docs/params_v1/types_.rst @@ -0,0 +1,6 @@ +Types for Google Cloud Aiplatform V1 Schema Predict Params v1 API +================================================================= + +.. automodule:: google.cloud.aiplatform.v1.schema.predict.params_v1.types + :members: + :show-inheritance: diff --git a/docs/params_v1beta1/services_.rst b/docs/params_v1beta1/services_.rst new file mode 100644 index 0000000000..b3b897a0f4 --- /dev/null +++ b/docs/params_v1beta1/services_.rst @@ -0,0 +1,4 @@ +Services for Google Cloud Aiplatform V1beta1 Schema Predict Params v1beta1 API +============================================================================== +.. 
toctree:: + :maxdepth: 2 diff --git a/docs/params_v1beta1/types_.rst b/docs/params_v1beta1/types_.rst new file mode 100644 index 0000000000..ce7a29cb01 --- /dev/null +++ b/docs/params_v1beta1/types_.rst @@ -0,0 +1,6 @@ +Types for Google Cloud Aiplatform V1beta1 Schema Predict Params v1beta1 API +=========================================================================== + +.. automodule:: google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types + :members: + :show-inheritance: diff --git a/docs/prediction_v1/services_.rst b/docs/prediction_v1/services_.rst new file mode 100644 index 0000000000..ad6f034387 --- /dev/null +++ b/docs/prediction_v1/services_.rst @@ -0,0 +1,4 @@ +Services for Google Cloud Aiplatform V1 Schema Predict Prediction v1 API +======================================================================== +.. toctree:: + :maxdepth: 2 diff --git a/docs/prediction_v1/types_.rst b/docs/prediction_v1/types_.rst new file mode 100644 index 0000000000..739ca93799 --- /dev/null +++ b/docs/prediction_v1/types_.rst @@ -0,0 +1,6 @@ +Types for Google Cloud Aiplatform V1 Schema Predict Prediction v1 API +===================================================================== + +.. automodule:: google.cloud.aiplatform.v1.schema.predict.prediction_v1.types + :members: + :show-inheritance: diff --git a/docs/prediction_v1beta1/services_.rst b/docs/prediction_v1beta1/services_.rst new file mode 100644 index 0000000000..6de5e17520 --- /dev/null +++ b/docs/prediction_v1beta1/services_.rst @@ -0,0 +1,4 @@ +Services for Google Cloud Aiplatform V1beta1 Schema Predict Prediction v1beta1 API +================================================================================== +.. toctree:: + :maxdepth: 2 diff --git a/docs/prediction_v1beta1/types_.rst b/docs/prediction_v1beta1/types_.rst new file mode 100644 index 0000000000..cdbe7f2842 --- /dev/null +++ b/docs/prediction_v1beta1/types_.rst @@ -0,0 +1,6 @@ +Types for Google Cloud Aiplatform V1beta1 Schema Predict Prediction v1beta1 API +=============================================================================== + +.. automodule:: google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types + :members: + :show-inheritance: diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py index 0b15726445..b41d58b0dd 100644 --- a/google/cloud/aiplatform/gapic_version.py +++ b/google/cloud/aiplatform/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.36.0" # {x-release-please-version} +__version__ = "1.36.1" # {x-release-please-version} diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index ab6ad877e1..3ccfc06fb6 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -956,6 +956,10 @@ def find_neighbors( queries: List[List[float]], num_neighbors: int = 10, filter: Optional[List[Namespace]] = [], + per_crowding_attribute_neighbor_count: Optional[int] = None, + approx_num_neighbors: Optional[int] = None, + fraction_leaf_nodes_to_search_override: Optional[float] = None, + return_full_datapoint: bool = False, ) -> List[List[MatchNeighbor]]: """Retrieves nearest neighbors for the given embedding queries on the specified deployed index which is deployed to public endpoint. 
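The four new `find_neighbors` parameters added below control the recall/latency trade-off of a single query. A minimal usage sketch, with placeholder resource names, IDs, and embedding values (none of which come from this change):

```python
from google.cloud import aiplatform

# Hypothetical endpoint; assumes an index has already been deployed to a
# public endpoint, as the method's docstring requires.
endpoint = aiplatform.MatchingEngineIndexEndpoint(
    index_endpoint_name="projects/123/locations/us-central1/indexEndpoints/456"
)

neighbors = endpoint.find_neighbors(
    deployed_index_id="my_deployed_index",
    queries=[[0.1, 0.2, 0.3]],
    num_neighbors=10,
    per_crowding_attribute_neighbor_count=2,  # at most 2 results per crowding tag
    approx_num_neighbors=100,  # candidates retrieved before exact reordering
    fraction_leaf_nodes_to_search_override=0.5,  # search half of the leaves
    return_full_datapoint=False,  # return only IDs and distances
)
```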
@@ -979,25 +983,58 @@ def find_neighbors(
                 For example, [Namespace("color", ["red"], []), Namespace("shape",
                 [], ["squared"])] will match datapoints that satisfy "red color" but
                 not include datapoints with "squared shape". Please refer to
                 https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail.
+
+            per_crowding_attribute_neighbor_count (int):
+                Optional. Crowding is a constraint on a neighbor list produced
+                by nearest neighbor search requiring that no more than some
+                value k' of the k neighbors returned have the same value of
+                crowding_attribute. It's used for improving result diversity.
+                This field is the maximum number of matches with the same crowding tag.
+
+            approx_num_neighbors (int):
+                Optional. The number of neighbors to find via approximate search
+                before exact reordering is performed. If not set, the default
+                value from scam config is used; if set, this value must be > 0.
+
+            fraction_leaf_nodes_to_search_override (float):
+                Optional. The fraction of the number of leaves to search, set at
+                query time, allows the user to tune search performance. Increasing
+                this value increases both search accuracy and latency.
+                The value should be between 0.0 and 1.0.
+
+            return_full_datapoint (bool):
+                Optional. If set to true, the full datapoints (including all
+                vector values) of the nearest neighbors are returned.
+                Note that returning full datapoints will significantly increase the
+                latency and cost of the query.
+
        Returns:
            List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
        """
        if not self._public_match_client:
            raise ValueError(
                "Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
            )

        # Create the FindNeighbors request
        find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest()
        find_neighbors_request.index_endpoint = self.resource_name
        find_neighbors_request.deployed_index_id = deployed_index_id
+        find_neighbors_request.return_full_datapoint = return_full_datapoint

        for query in queries:
            find_neighbors_query = (
                gca_match_service_v1beta1.FindNeighborsRequest.Query()
            )
            find_neighbors_query.neighbor_count = num_neighbors
+            find_neighbors_query.per_crowding_attribute_neighbor_count = (
+                per_crowding_attribute_neighbor_count
+            )
+            find_neighbors_query.approximate_neighbor_count = approx_num_neighbors
+            find_neighbors_query.fraction_leaf_nodes_to_search_override = (
+                fraction_leaf_nodes_to_search_override
+            )
            datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query)
            for namespace in filter:
                restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
@@ -1073,6 +1110,8 @@ def match(
        queries: List[List[float]],
        num_neighbors: int = 1,
        filter: Optional[List[Namespace]] = [],
+        per_crowding_attribute_num_neighbors: Optional[int] = None,
+        approx_num_neighbors: Optional[int] = None,
    ) -> List[List[MatchNeighbor]]:
        """Retrieves nearest neighbors for the given embedding queries on the
        specified deployed index.
@@ -1089,6 +1128,15 @@ def match(
                For example, [Namespace("color", ["red"], []), Namespace("shape",
                [], ["squared"])] will match datapoints that satisfy "red color" but
                not include datapoints with "squared shape". Please refer to
                https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail.
+            per_crowding_attribute_num_neighbors (int):
+                Optional. 
Crowding is a constraint on a neighbor list produced by nearest neighbor + search requiring that no more than some value k' of the k neighbors + returned have the same value of crowding_attribute. + It's used for improving result diversity. + This field is the maximum number of matches with the same crowding tag. + approx_num_neighbors (int): + The number of neighbors to find via approximate search before exact reordering is performed. + If not set, the default value from scam config is used; if set, this value must be > 0. Returns: List[List[MatchNeighbor]] - A list of nearest neighbors for each query. @@ -1123,6 +1171,8 @@ def match( num_neighbors=num_neighbors, deployed_index_id=deployed_index_id, float_val=query, + per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors, + approx_num_neighbors=approx_num_neighbors, ) for namespace in filter: restrict = match_service_pb2.Namespace() diff --git a/google/cloud/aiplatform/preview/vertex_ray/cluster_init.py b/google/cloud/aiplatform/preview/vertex_ray/cluster_init.py index 740bc95a16..6438812b82 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/cluster_init.py +++ b/google/cloud/aiplatform/preview/vertex_ray/cluster_init.py @@ -16,7 +16,8 @@ # import copy -from typing import List, Optional +import logging +from typing import Dict, List, Optional from google.cloud.aiplatform import initializer from google.cloud.aiplatform import utils @@ -46,6 +47,7 @@ def create_ray_cluster( network: Optional[str] = None, cluster_name: Optional[str] = None, worker_node_types: Optional[List[resources.Resources]] = None, + labels: Optional[Dict[str, str]] = None, ) -> str: """Create a ray cluster on the Vertex AI. @@ -68,17 +70,17 @@ def create_ray_cluster( )] cluster_resource_name = vertex_ray.create_ray_cluster( - head_node_type=head_node_type, - network="my-vpc", - worker_node_types=worker_node_types, + head_node_type=head_node_type, + network="projects/my-project-number/global/networks/my-vpc-name", + worker_node_types=worker_node_types, ) After a ray cluster is set up, you can call - `ray.init(vertex_ray://{cluster_resource_name}, runtime_env=...)` without + `ray.init(f"vertex_ray://{cluster_resource_name}", runtime_env=...)` without specifying ray cluster address to connect to the cluster. To shut down the cluster you can call `ray.delete_ray_cluster()`. - Note: If the active ray cluster haven't shut down, you cannot create a new ray - cluster with the same cluster_name. + Note: If the active ray cluster has not finished shutting down, you cannot + create a new ray cluster with the same cluster_name. Args: head_node_type: The head node resource. Resources.node_count must be 1. @@ -95,14 +97,22 @@ def create_ray_cluster( or hyphen. worker_node_types: The list of Resources of the worker nodes. The same Resources object should not appear multiple times in the list. + labels: + The labels with user-defined metadata to organize Ray cluster. + + Label keys and values can be no longer than 64 characters (Unicode + codepoints), can only contain lowercase letters, numeric characters, + underscores and dashes. International characters are allowed. + + See https://goo.gl/xmQnxf for more information and examples of labels. Returns: The cluster_resource_name of the initiated Ray cluster on Vertex. """ if network is None: - raise ValueError( - "[Ray on Vertex]: VPC network is required for client connection." + logging.info( + "[Ray on Vertex]: No VPC network configured. It is required for client connection." 
) if cluster_name is None: @@ -176,6 +186,7 @@ def create_ray_cluster( persistent_resource = PersistentResource( resource_pools=resource_pools, network=network, + labels=labels, resource_runtime_spec=resource_runtime_spec, ) diff --git a/google/cloud/aiplatform/preview/vertex_ray/dashboard_sdk.py b/google/cloud/aiplatform/preview/vertex_ray/dashboard_sdk.py index 2328299893..bc4da844c7 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/dashboard_sdk.py +++ b/google/cloud/aiplatform/preview/vertex_ray/dashboard_sdk.py @@ -46,6 +46,19 @@ def get_job_submission_client_cluster_info( Raises: RuntimeError if head_address is None. """ + # If passing the dashboard uri, programmatically get headers + if _validation_utils.valid_dashboard_address(address): + bearer_token = _validation_utils.get_bearer_token() + if kwargs.get("headers", None) is None: + kwargs["headers"] = { + "Content-Type": "application/json", + "Authorization": "Bearer {}".format(bearer_token), + } + return oss_dashboard_sdk.get_job_submission_client_cluster_info( + address=address, + *args, + **kwargs, + ) address = _validation_utils.maybe_reconstruct_resource_name(address) _validation_utils.valid_resource_name(address) diff --git a/google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py b/google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py index 1cb283b264..fc61fb794a 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py +++ b/google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py @@ -142,6 +142,7 @@ def persistent_resource_to_cluster( cluster_resource_name=persistent_resource.name, network=persistent_resource.network, state=persistent_resource.state.name, + labels=persistent_resource.labels, ) if not persistent_resource.resource_runtime_spec.ray_spec: # skip PersistentResource without RaySpec diff --git a/google/cloud/aiplatform/preview/vertex_ray/util/_validation_utils.py b/google/cloud/aiplatform/preview/vertex_ray/util/_validation_utils.py index 9aceb0872c..0b6ef1847e 100644 --- a/google/cloud/aiplatform/preview/vertex_ray/util/_validation_utils.py +++ b/google/cloud/aiplatform/preview/vertex_ray/util/_validation_utils.py @@ -15,6 +15,8 @@ # limitations under the License. 
#
+import google.auth
+import google.auth.transport.requests
 import logging
 import re
@@ -29,6 +31,7 @@
 _PERSISTENT_RESOURCE_NAME_PATTERN = "projects/{}/locations/{}/persistentResources/{}"
 _VALID_RESOURCE_NAME_REGEX = "[a-z][a-zA-Z0-9._-]{0,127}"
+_DASHBOARD_URI_SUFFIX = "aiplatform-training.googleusercontent.com"

 def valid_resource_name(resource_name):
@@ -91,3 +94,19 @@ def get_versions_from_image_uri(image_uri):
     py_version = image_label[-3] + "_" + image_label[-2:]
     ray_version = image_label.split(".")[1].replace("-", "_")
     return py_version, ray_version
+
+
+def valid_dashboard_address(address):
+    """Check if address is a valid dashboard uri."""
+    return address.endswith(_DASHBOARD_URI_SUFFIX)
+
+
+def get_bearer_token():
+    """Get bearer token through Application Default Credentials."""
+    creds, _ = google.auth.default()
+
+    # creds.valid is False, and creds.token is None
+    # Need to refresh credentials to populate those
+    auth_req = google.auth.transport.requests.Request()
+    creds.refresh(auth_req)
+    return creds.token
diff --git a/google/cloud/aiplatform/preview/vertex_ray/util/resources.py b/google/cloud/aiplatform/preview/vertex_ray/util/resources.py
index e7a0e58eaf..7dbffe23d7 100644
--- a/google/cloud/aiplatform/preview/vertex_ray/util/resources.py
+++ b/google/cloud/aiplatform/preview/vertex_ray/util/resources.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 import dataclasses
-from typing import List, Optional
+from typing import Dict, List, Optional

 from google.cloud.aiplatform_v1beta1.types import PersistentResource
@@ -82,6 +82,16 @@ class Cluster:
         If not set, by default it is a CPU node with machine_type of n1-standard-4.
     worker_node_types: The list of Resources of the worker nodes. Should not
         duplicate the elements in the list.
+    dashboard_address: The address for the Ray Job API (JobSubmissionClient);
+        connecting through this address does not require VPC peering.
+    labels:
+        The labels with user-defined metadata to organize Ray cluster.
+
+        Label keys and values can be no longer than 64 characters (Unicode
+        codepoints), can only contain lowercase letters, numeric characters,
+        underscores and dashes. International characters are allowed.
+
+        See https://goo.gl/xmQnxf for more information and examples of labels.
 """

 cluster_resource_name: str = None
@@ -91,6 +101,8 @@ class Cluster:
     ray_version: str = None
     head_node_type: Resources = None
     worker_node_types: List[Resources] = None
+    dashboard_address: str = None
+    labels: Dict[str, str] = None

 def _check_machine_spec_identical(
diff --git a/google/cloud/aiplatform/releases.txt b/google/cloud/aiplatform/releases.txt
index 8aad820a3f..7db3982039 100644
--- a/google/cloud/aiplatform/releases.txt
+++ b/google/cloud/aiplatform/releases.txt
@@ -1,4 +1,4 @@
 Use this file when you need to force a patch release with release-please.
 Edit line 4 below with the version for the release. 
-1.32.0
\ No newline at end of file
+1.36.1
\ No newline at end of file
diff --git a/google/cloud/aiplatform/tensorboard/uploader_tracker.py b/google/cloud/aiplatform/tensorboard/uploader_tracker.py
index 3929d04712..4e15e471d4 100644
--- a/google/cloud/aiplatform/tensorboard/uploader_tracker.py
+++ b/google/cloud/aiplatform/tensorboard/uploader_tracker.py
@@ -17,7 +17,7 @@
 """Launches Tensorboard Uploader for SDK."""

 import threading
-from typing import Optional
+from typing import FrozenSet, Optional

 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import initializer
@@ -48,6 +48,7 @@ def upload_tb_log(
         run_name_prefix: Optional[str] = None,
         description: Optional[str] = None,
         verbosity: Optional[int] = 1,
+        allowed_plugins: Optional[FrozenSet[str]] = None,
     ):
         """Upload only the existing data in the logdir and then return immediately.

@@ -74,6 +75,7 @@ def upload_tb_log(
            verbosity (int): Optional. Level of verbosity, an integer. Supported
                values: 0 - No upload statistics are printed. 1 - Print upload
                statistics while uploading data (default).
+            allowed_plugins (FrozenSet[str]): Optional. Set of additional allowed plugin names.
        """
        self._create_uploader(
            tensorboard_id=tensorboard_id,
@@ -85,6 +87,8 @@ def upload_tb_log(
            experiment_display_name=experiment_display_name,
            run_name_prefix=run_name_prefix,
            description=description,
+            verbosity=verbosity,
+            allowed_plugins=allowed_plugins,
        ).start_uploading()

        _LOGGER.info("One time TensorBoard log upload completed.")
@@ -98,6 +102,7 @@ def start_upload_tb_log(
        experiment_display_name: Optional[str] = None,
        run_name_prefix: Optional[str] = None,
        description: Optional[str] = None,
+        allowed_plugins: Optional[FrozenSet[str]] = None,
    ):
        """Continues to listen for new data in the logdir and uploads when it appears.

@@ -121,6 +126,7 @@ def start_upload_tb_log(
            invocation will have their name prefixed by this value.
            description (str): Optional. String description to assign to the
                experiment.
+            allowed_plugins (FrozenSet[str]): Optional. Set of additional allowed plugin names.
        """
        if self._tensorboard_uploader:
            _LOGGER.info(
@@ -141,6 +147,7 @@ def start_upload_tb_log(
            run_name_prefix=run_name_prefix,
            description=description,
            verbosity=0,
+            allowed_plugins=allowed_plugins,
        )
        threading.Thread(target=self._tensorboard_uploader.start_uploading).start()

@@ -174,6 +181,7 @@ def _create_uploader(
        run_name_prefix: Optional[str] = None,
        description: Optional[str] = None,
        verbosity: Optional[int] = 1,
+        allowed_plugins: Optional[FrozenSet[str]] = None,
    ) -> "TensorBoardUploader":  # noqa: F821
        """Create a TensorBoardUploader and a TensorBoard Experiment

@@ -188,6 +196,7 @@ def _create_uploader(
            run_name_prefix (str): Optional. If present, all runs created by this
                invocation will have their name prefixed by this value.
            description (str): Optional. String description to assign to the experiment.
            verbosity (int): Optional. Level of verbosity. Supported values: 0 - No upload statistics are printed. 1 - Print upload statistics while uploading data (default).
+            allowed_plugins (FrozenSet[str]): Optional. Set of additional allowed plugin names.

        Returns:
            An instance of TensorBoardUploader. 
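The new `allowed_plugins` parameter extends the uploader's default plugin allowlist rather than replacing it. A minimal one-time upload sketch, modeled on the sample added later in this diff; the `"profile"` plugin name is an assumption, consistent with the profiler support noted in this release's changelog:

```python
from google.cloud import aiplatform

aiplatform.init(
    project="my-project",
    location="us-central1",
    experiment="my-experiment",
)

# "profile" is an assumed plugin name; any plugin not already in the
# default allowlist would be appended the same way.
aiplatform.upload_tb_log(
    tensorboard_experiment_name="my-experiment",
    logdir="gs://my-bucket/tensorboard-logs",
    allowed_plugins=frozenset({"profile"}),
)
```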
@@ -244,13 +253,21 @@ def _create_uploader(
        ) = uploader_utils.get_blob_storage_bucket_and_folder(
            api_client, tensorboard_resource_name, project
        )
+
+        # Copy the defaults so the module-level allowlist is not mutated,
+        # and guard against the default allowed_plugins=None.
+        plugins = list(uploader_constants.ALLOWED_PLUGINS)
+        if allowed_plugins:
+            plugins += [
+                plugin
+                for plugin in allowed_plugins
+                if plugin not in uploader_constants.ALLOWED_PLUGINS
+            ]
+
        tensorboard_uploader = TensorBoardUploader(
            experiment_name=tensorboard_experiment_name,
            tensorboard_resource_name=tensorboard_resource_name,
            experiment_display_name=experiment_display_name,
            blob_storage_bucket=blob_storage_bucket,
            blob_storage_folder=blob_storage_folder,
-            allowed_plugins=uploader_constants.ALLOWED_PLUGINS,
+            allowed_plugins=plugins,
            writer_client=api_client,
            logdir=logdir,
            one_shot=one_shot,
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
index 0b15726445..b41d58b0dd 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.36.0"  # {x-release-please-version}
+__version__ = "1.36.1"  # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
index 0b15726445..b41d58b0dd 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.36.0"  # {x-release-please-version}
+__version__ = "1.36.1"  # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
index 0b15726445..b41d58b0dd 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.36.0"  # {x-release-please-version}
+__version__ = "1.36.1"  # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
index 0b15726445..b41d58b0dd 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.36.0"  # {x-release-please-version}
+__version__ = "1.36.1"  # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
index 0b15726445..b41d58b0dd 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License. 
# -__version__ = "1.36.0" # {x-release-please-version} +__version__ = "1.36.1" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py index 0b15726445..b41d58b0dd 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.36.0" # {x-release-please-version} +__version__ = "1.36.1" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py index 0b15726445..b41d58b0dd 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.36.0" # {x-release-please-version} +__version__ = "1.36.1" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py index 0b15726445..b41d58b0dd 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.36.0" # {x-release-please-version} +__version__ = "1.36.1" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py index 0b15726445..b41d58b0dd 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.36.0" # {x-release-please-version} +__version__ = "1.36.1" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py index 0b15726445..b41d58b0dd 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.36.0" # {x-release-please-version} +__version__ = "1.36.1" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py index 0b15726445..b41d58b0dd 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.36.0" # {x-release-please-version} +__version__ = "1.36.1" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py index 0b15726445..b41d58b0dd 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.36.0" # {x-release-please-version} +__version__ = "1.36.1" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py index 0b15726445..b41d58b0dd 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.36.0" # {x-release-please-version} +__version__ = "1.36.1" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py index 0b15726445..b41d58b0dd 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.36.0" # {x-release-please-version} +__version__ = "1.36.1" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py index 0b15726445..b41d58b0dd 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.36.0" # {x-release-please-version} +__version__ = "1.36.1" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py index 0b15726445..b41d58b0dd 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.36.0" # {x-release-please-version} +__version__ = "1.36.1" # {x-release-please-version} diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py index d23ea8a705..ddcf4bffd2 100644 --- a/google/cloud/aiplatform/version.py +++ b/google/cloud/aiplatform/version.py @@ -15,4 +15,4 @@ # limitations under the License. 
# -__version__ = "1.36.0" +__version__ = "1.36.1" diff --git a/google/cloud/aiplatform_v1/gapic_metadata.json b/google/cloud/aiplatform_v1/gapic_metadata.json index 481256e086..7137ea8899 100644 --- a/google/cloud/aiplatform_v1/gapic_metadata.json +++ b/google/cloud/aiplatform_v1/gapic_metadata.json @@ -1975,6 +1975,90 @@ } } }, + "ScheduleService": { + "clients": { + "grpc": { + "libraryClient": "ScheduleServiceClient", + "rpcs": { + "CreateSchedule": { + "methods": [ + "create_schedule" + ] + }, + "DeleteSchedule": { + "methods": [ + "delete_schedule" + ] + }, + "GetSchedule": { + "methods": [ + "get_schedule" + ] + }, + "ListSchedules": { + "methods": [ + "list_schedules" + ] + }, + "PauseSchedule": { + "methods": [ + "pause_schedule" + ] + }, + "ResumeSchedule": { + "methods": [ + "resume_schedule" + ] + }, + "UpdateSchedule": { + "methods": [ + "update_schedule" + ] + } + } + }, + "grpc-async": { + "libraryClient": "ScheduleServiceAsyncClient", + "rpcs": { + "CreateSchedule": { + "methods": [ + "create_schedule" + ] + }, + "DeleteSchedule": { + "methods": [ + "delete_schedule" + ] + }, + "GetSchedule": { + "methods": [ + "get_schedule" + ] + }, + "ListSchedules": { + "methods": [ + "list_schedules" + ] + }, + "PauseSchedule": { + "methods": [ + "pause_schedule" + ] + }, + "ResumeSchedule": { + "methods": [ + "resume_schedule" + ] + }, + "UpdateSchedule": { + "methods": [ + "update_schedule" + ] + } + } + } + } + }, "SpecialistPoolService": { "clients": { "grpc": { diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py index 0b15726445..b41d58b0dd 100644 --- a/google/cloud/aiplatform_v1/gapic_version.py +++ b/google/cloud/aiplatform_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
#
-__version__ = "1.36.0"  # {x-release-please-version}
+__version__ = "1.36.1"  # {x-release-please-version}
diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py
index ec924555fe..6f737a3810 100644
--- a/google/cloud/aiplatform_v1/services/migration_service/client.py
+++ b/google/cloud/aiplatform_v1/services/migration_service/client.py
@@ -230,40 +230,40 @@ def parse_dataset_path(path: str) -> Dict[str, str]:
    @staticmethod
    def dataset_path(
        project: str,
+        location: str,
        dataset: str,
    ) -> str:
        """Returns a fully-qualified dataset string."""
-        return "projects/{project}/datasets/{dataset}".format(
+        return "projects/{project}/locations/{location}/datasets/{dataset}".format(
            project=project,
+            location=location,
            dataset=dataset,
        )

    @staticmethod
    def parse_dataset_path(path: str) -> Dict[str, str]:
        """Parses a dataset path into its component segments."""
-        m = re.match(r"^projects/(?P<project>.+?)/datasets/(?P<dataset>.+?)$", path)
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)$",
+            path,
+        )
        return m.groupdict() if m else {}

    @staticmethod
    def dataset_path(
        project: str,
-        location: str,
        dataset: str,
    ) -> str:
        """Returns a fully-qualified dataset string."""
-        return "projects/{project}/locations/{location}/datasets/{dataset}".format(
+        return "projects/{project}/datasets/{dataset}".format(
            project=project,
-            location=location,
            dataset=dataset,
        )

    @staticmethod
    def parse_dataset_path(path: str) -> Dict[str, str]:
        """Parses a dataset path into its component segments."""
-        m = re.match(
-            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)$",
-            path,
-        )
+        m = re.match(r"^projects/(?P<project>.+?)/datasets/(?P<dataset>.+?)$", path)
        return m.groupdict() if m else {}

    @staticmethod
diff --git a/google/cloud/aiplatform_v1/types/__init__.py b/google/cloud/aiplatform_v1/types/__init__.py
index 14603a1447..7709740465 100644
--- a/google/cloud/aiplatform_v1/types/__init__.py
+++ b/google/cloud/aiplatform_v1/types/__init__.py
@@ -559,6 +559,30 @@
 from .service_networking import (
     PrivateServiceConnectConfig,
 )
+
+from .publisher_model import (
+    PublisherModel,
+)
+from .saved_query import (
+    SavedQuery,
+)
+from .schedule import (
+    Schedule,
+)
+from .schedule_service import (
+    CreateScheduleRequest,
+    DeleteScheduleRequest,
+    GetScheduleRequest,
+    ListSchedulesRequest,
+    ListSchedulesResponse,
+    PauseScheduleRequest,
+    ResumeScheduleRequest,
+    UpdateScheduleRequest,
+)
+from .service_networking import (
+    PrivateServiceConnectConfig,
+)
+
 from .specialist_pool import (
     SpecialistPool,
 )
diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py
index 0b15726445..b41d58b0dd 100644
--- a/google/cloud/aiplatform_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License. 
#
-__version__ = "1.36.0"  # {x-release-please-version}
+__version__ = "1.36.1"  # {x-release-please-version}
diff --git a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/async_client.py
index f64a8e0510..00135fb2f7 100644
--- a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/async_client.py
+++ b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/async_client.py
@@ -1035,7 +1035,7 @@ async def sample_create_feature():
            response,
            self._client._transport.operations_client,
            gca_feature.Feature,
-            metadata_type=feature_registry_service.CreateRegistryFeatureOperationMetadata,
+            metadata_type=featurestore_service.CreateFeatureOperationMetadata,
        )

        # Done; return the response.
diff --git a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/client.py b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/client.py
index b7a71b4f6d..c1ef0d14c3 100644
--- a/google/cloud/aiplatform_v1beta1/services/feature_registry_service/client.py
+++ b/google/cloud/aiplatform_v1beta1/services/feature_registry_service/client.py
@@ -1281,7 +1281,7 @@ def sample_create_feature():
            response,
            self._transport.operations_client,
            gca_feature.Feature,
-            metadata_type=feature_registry_service.CreateRegistryFeatureOperationMetadata,
+            metadata_type=featurestore_service.CreateFeatureOperationMetadata,
        )

        # Done; return the response.
diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py
index 55a32d1c1b..31dce2992e 100644
--- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py
+++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py
@@ -227,6 +227,23 @@ def parse_dataset_path(path: str) -> Dict[str, str]:
    )
    return m.groupdict() if m else {}

+    @staticmethod
+    def dataset_path(
+        project: str,
+        dataset: str,
+    ) -> str:
+        """Returns a fully-qualified dataset string."""
+        return "projects/{project}/datasets/{dataset}".format(
+            project=project,
+            dataset=dataset,
+        )
+
+    @staticmethod
+    def parse_dataset_path(path: str) -> Dict[str, str]:
+        """Parses a dataset path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/datasets/(?P<dataset>.+?)$", path)
+        return m.groupdict() if m else {}
+
    @staticmethod
    def dataset_path(
        project: str,
@@ -252,18 +269,23 @@ def parse_dataset_path(path: str) -> Dict[str, str]:
    @staticmethod
    def dataset_path(
        project: str,
+        location: str,
        dataset: str,
    ) -> str:
        """Returns a fully-qualified dataset string."""
-        return "projects/{project}/datasets/{dataset}".format(
+        return "projects/{project}/locations/{location}/datasets/{dataset}".format(
            project=project,
+            location=location,
            dataset=dataset,
        )

    @staticmethod
    def parse_dataset_path(path: str) -> Dict[str, str]:
        """Parses a dataset path into its component segments."""
-        m = re.match(r"^projects/(?P<project>.+?)/datasets/(?P<dataset>.+?)$", path)
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)$",
+            path,
+        )
        return m.groupdict() if m else {}

    @staticmethod
diff --git a/google/cloud/aiplatform_v1beta1/types/feature_online_store.py b/google/cloud/aiplatform_v1beta1/types/feature_online_store.py
index e15989a3c5..9f35aa29d6 100644
--- a/google/cloud/aiplatform_v1beta1/types/feature_online_store.py
+++ b/google/cloud/aiplatform_v1beta1/types/feature_online_store.py
@@ -19,6 +19,7 @@
 import proto  # type: ignore

+from google.cloud.aiplatform_v1beta1.types import service_networking
 from google.protobuf import timestamp_pb2  # type: ignore
@@ -35,6 +36,10 @@ class FeatureOnlineStore(proto.Message):
    repository for serving ML features and embedding indexes at
    low latency. The Feature Online Store is a top-level
    container.
+    This message has `oneof`_ fields (mutually exclusive fields).
+    For each oneof, at most one member field can be set at the same time.
+    Setting any member of the oneof automatically clears all other
+    members.

    .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields

@@ -45,6 +50,16 @@ class FeatureOnlineStore(proto.Message):
            featureValues for all FeatureViews under this
            FeatureOnlineStore.

+            This field is a member of `oneof`_ ``storage_type``.
+        optimized (google.cloud.aiplatform_v1beta1.types.FeatureOnlineStore.Optimized):
+            Contains settings for the Optimized store that will be
+            created to serve featureValues for all FeatureViews under
+            this FeatureOnlineStore. When choosing the Optimized
+            storage type, set
+            [PrivateServiceConnectConfig.enable_private_service_connect][google.cloud.aiplatform.v1beta1.PrivateServiceConnectConfig.enable_private_service_connect]
+            to use a private endpoint; otherwise a public endpoint is
+            used by default.
+
            This field is a member of `oneof`_ ``storage_type``.
        name (str):
            Output only. Name of the FeatureOnlineStore. Format:
@@ -161,20 +176,45 @@ class AutoScaling(proto.Message):
            message="FeatureOnlineStore.Bigtable.AutoScaling",
        )

+    class Optimized(proto.Message):
+        r"""Optimized storage type to replace lightning"""
+
    class DedicatedServingEndpoint(proto.Message):
        r"""The dedicated serving endpoint for this FeatureOnlineStore.
+        Only needs to be set when you choose the Optimized storage type
+        or enable EmbeddingManagement; a public endpoint is used by
+        default.

        Attributes:
            public_endpoint_domain_name (str):
                Output only. This field will be populated
                with the domain name to use for this
                FeatureOnlineStore
+            private_service_connect_config (google.cloud.aiplatform_v1beta1.types.PrivateServiceConnectConfig):
+                Optional. Private service connect config. If
+                [PrivateServiceConnectConfig.enable_private_service_connect][google.cloud.aiplatform.v1beta1.PrivateServiceConnectConfig.enable_private_service_connect]
+                is set to true, requests are sent over a private
+                service connection. Otherwise, the public
+                endpoint is used.
+            service_attachment (str):
+                Output only. The name of the service
+                attachment resource. Populated if private
+                service connect is enabled and after
+                FeatureViewSync is created.
        """

        public_endpoint_domain_name: str = proto.Field(
            proto.STRING,
            number=2,
        )
+        private_service_connect_config: service_networking.PrivateServiceConnectConfig = proto.Field(
+            proto.MESSAGE,
+            number=3,
+            message=service_networking.PrivateServiceConnectConfig,
+        )
+        service_attachment: str = proto.Field(
+            proto.STRING,
+            number=4,
+        )

    class EmbeddingManagement(proto.Message):
        r"""Contains settings for embedding management. 
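The new `optimized` field joins `bigtable` in the `storage_type` oneof, so setting it selects Optimized storage. A minimal construction sketch, assuming these messages are exported from `google.cloud.aiplatform_v1beta1.types` as is usual for this package; whether additional fields are required at creation time is not shown in this diff:

```python
from google.cloud.aiplatform_v1beta1 import types

# Optimized is an empty message, so its mere presence selects the
# "optimized" branch of the storage_type oneof.
store = types.FeatureOnlineStore(
    optimized=types.FeatureOnlineStore.Optimized(),
)

# Private service connect for the dedicated endpoint, per the new
# private_service_connect_config field.
psc = types.PrivateServiceConnectConfig(
    enable_private_service_connect=True,
)
```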
@@ -198,6 +238,12 @@ class EmbeddingManagement(proto.Message): oneof="storage_type", message=Bigtable, ) + optimized: Optimized = proto.Field( + proto.MESSAGE, + number=12, + oneof="storage_type", + message=Optimized, + ) name: str = proto.Field( proto.STRING, number=1, diff --git a/google/cloud/aiplatform_v1beta1/types/prediction_service.py b/google/cloud/aiplatform_v1beta1/types/prediction_service.py index fd6b807e14..7b9e12444d 100644 --- a/google/cloud/aiplatform_v1beta1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1beta1/types/prediction_service.py @@ -290,6 +290,14 @@ class ExplainRequest(proto.Message): methods to reduce approximate errors; - Using different baselines for explaining the prediction results. + concurrent_explanation_spec_override (MutableMapping[str, google.cloud.aiplatform_v1beta1.types.ExplanationSpecOverride]): + Optional. This field is the same as the one above, but + supports multiple explanations to occur in parallel. The key + can be any string. Each override will be run against the + model, then its explanations will be grouped together. + + Note - these explanations are run **In Addition** to the + default Explanation in the deployed model. deployed_model_id (str): If specified, this ExplainRequest will be served by the chosen DeployedModel, overriding @@ -315,6 +323,14 @@ class ExplainRequest(proto.Message): number=5, message=explanation.ExplanationSpecOverride, ) + concurrent_explanation_spec_override: MutableMapping[ + str, explanation.ExplanationSpecOverride + ] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=6, + message=explanation.ExplanationSpecOverride, + ) deployed_model_id: str = proto.Field( proto.STRING, number=3, @@ -333,6 +349,10 @@ class ExplainResponse(proto.Message): It has the same number of elements as [instances][google.cloud.aiplatform.v1beta1.ExplainRequest.instances] to be explained. + concurrent_explanations (MutableMapping[str, google.cloud.aiplatform_v1beta1.types.ExplainResponse.ConcurrentExplanation]): + This field stores the results of the + explanations run in parallel with the default + explanation strategy/method. deployed_model_id (str): ID of the Endpoint's DeployedModel that served this explanation. @@ -342,11 +362,38 @@ class ExplainResponse(proto.Message): [PredictResponse.predictions][google.cloud.aiplatform.v1beta1.PredictResponse.predictions]. """ + class ConcurrentExplanation(proto.Message): + r"""This message is a wrapper grouping Concurrent Explanations. + + Attributes: + explanations (MutableSequence[google.cloud.aiplatform_v1beta1.types.Explanation]): + The explanations of the Model's + [PredictResponse.predictions][google.cloud.aiplatform.v1beta1.PredictResponse.predictions]. + + It has the same number of elements as + [instances][google.cloud.aiplatform.v1beta1.ExplainRequest.instances] + to be explained. 
+ """ + + explanations: MutableSequence[explanation.Explanation] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=explanation.Explanation, + ) + explanations: MutableSequence[explanation.Explanation] = proto.RepeatedField( proto.MESSAGE, number=1, message=explanation.Explanation, ) + concurrent_explanations: MutableMapping[ + str, ConcurrentExplanation + ] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=4, + message=ConcurrentExplanation, + ) deployed_model_id: str = proto.Field( proto.STRING, number=2, diff --git a/noxfile.py b/noxfile.py index 4635c0c558..c34817f137 100644 --- a/noxfile.py +++ b/noxfile.py @@ -303,7 +303,7 @@ def docs(session): ) -@nox.session(python="3.9") +@nox.session(python="3.10") def docfx(session): """Build the docfx yaml files for this library.""" diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index 92f6b8d44a..18b7cb8c22 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.36.0" + "version": "1.36.1" }, "snippets": [ { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index 24118e342b..ea699a3921 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.36.0" + "version": "1.36.1" }, "snippets": [ { diff --git a/samples/model-builder/experiment_tracking/autologging_with_auto_run_creation_sample.py b/samples/model-builder/experiment_tracking/autologging_with_auto_run_creation_sample.py index e4f1c896af..2b19745f1d 100644 --- a/samples/model-builder/experiment_tracking/autologging_with_auto_run_creation_sample.py +++ b/samples/model-builder/experiment_tracking/autologging_with_auto_run_creation_sample.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Optional, Union from google.cloud import aiplatform @@ -20,9 +20,9 @@ # [START aiplatform_sdk_autologging_with_auto_run_creation_sample] def autologging_with_auto_run_creation_sample( experiment_name: str, - experiment_tensorboard: Union[str, aiplatform.Tensorboard], project: str, location: str, + experiment_tensorboard: Optional[Union[str, aiplatform.Tensorboard]] = None, ): aiplatform.init( experiment=experiment_name, diff --git a/samples/model-builder/experiment_tracking/autologging_with_manual_run_creation_sample.py b/samples/model-builder/experiment_tracking/autologging_with_manual_run_creation_sample.py index c41b1eb285..5d1398850a 100644 --- a/samples/model-builder/experiment_tracking/autologging_with_manual_run_creation_sample.py +++ b/samples/model-builder/experiment_tracking/autologging_with_manual_run_creation_sample.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Union +from typing import Optional, Union from google.cloud import aiplatform @@ -21,9 +21,9 @@ def autologging_with_manual_run_creation_sample( experiment_name: str, run_name: str, - experiment_tensorboard: Union[str, aiplatform.Tensorboard], project: str, location: str, + experiment_tensorboard: Optional[Union[str, aiplatform.Tensorboard]] = None, ): aiplatform.init( experiment=experiment_name, diff --git a/samples/model-builder/experiment_tracking/create_experiment_default_tensorboard_sample.py b/samples/model-builder/experiment_tracking/create_experiment_default_tensorboard_sample.py new file mode 100644 index 0000000000..c8e2ab53de --- /dev/null +++ b/samples/model-builder/experiment_tracking/create_experiment_default_tensorboard_sample.py @@ -0,0 +1,33 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_experiment_default_tensorboard_sample] +def create_experiment_default_tensorboard_sample( + experiment_name: str, + experiment_description: str, + project: str, + location: str, +): + aiplatform.init( + experiment=experiment_name, + experiment_description=experiment_description, + project=project, + location=location, + ) + + +# [END aiplatform_sdk_create_experiment_default_tensorboard_sample] diff --git a/samples/model-builder/experiment_tracking/create_experiment_default_tensorboard_sample_test.py b/samples/model-builder/experiment_tracking/create_experiment_default_tensorboard_sample_test.py new file mode 100644 index 0000000000..98589f7a54 --- /dev/null +++ b/samples/model-builder/experiment_tracking/create_experiment_default_tensorboard_sample_test.py @@ -0,0 +1,33 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from experiment_tracking import create_experiment_default_tensorboard_sample +import test_constants as constants + + +def test_create_experiment_default_tensorboard_sample(mock_sdk_init): + + create_experiment_default_tensorboard_sample.create_experiment_default_tensorboard_sample( + experiment_name=constants.EXPERIMENT_NAME, + experiment_description=constants.DESCRIPTION, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + mock_sdk_init.assert_called_with( + experiment=constants.EXPERIMENT_NAME, + experiment_description=constants.DESCRIPTION, + project=constants.PROJECT, + location=constants.LOCATION, + ) diff --git a/samples/model-builder/experiment_tracking/upload_tensorboard_log_to_experiment_sample.py b/samples/model-builder/experiment_tracking/upload_tensorboard_log_to_experiment_sample.py new file mode 100644 index 0000000000..bcc22971a5 --- /dev/null +++ b/samples/model-builder/experiment_tracking/upload_tensorboard_log_to_experiment_sample.py @@ -0,0 +1,39 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_upload_tensorboard_to_experiment_sample] +def upload_tensorboard_log_to_experiment_sample( + experiment_name: str, + logdir: str, + project: str, + location: str, + run_name_prefix: Optional[str] = None, +) -> None: + + aiplatform.init(project=project, location=location, experiment=experiment_name) + + # one time upload + aiplatform.upload_tb_log( + tensorboard_experiment_name=experiment_name, + logdir=logdir, + run_name_prefix=run_name_prefix, + ) + + +# [END aiplatform_sdk_upload_tensorboard_to_experiment_sample] diff --git a/samples/model-builder/experiment_tracking/upload_tensorboard_log_to_experiment_sample_test.py b/samples/model-builder/experiment_tracking/upload_tensorboard_log_to_experiment_sample_test.py new file mode 100644 index 0000000000..98fc1be10e --- /dev/null +++ b/samples/model-builder/experiment_tracking/upload_tensorboard_log_to_experiment_sample_test.py @@ -0,0 +1,42 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from experiment_tracking import upload_tensorboard_log_to_experiment_sample +import test_constants as constants + + +def test_upload_tensorboard_to_experiment_sample( + mock_sdk_init, + mock_tensorboard_uploader_onetime, +): + upload_tensorboard_log_to_experiment_sample.upload_tensorboard_log_to_experiment_sample( + project=constants.PROJECT, + location=constants.LOCATION, + logdir=constants.TENSORBOARD_LOG_DIR, + experiment_name=constants.EXPERIMENT_NAME, + run_name_prefix=constants.EXPERIMENT_RUN_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, + location=constants.LOCATION, + experiment=constants.EXPERIMENT_NAME, + ) + + mock_tensorboard_uploader_onetime.assert_called_once_with( + logdir=constants.TENSORBOARD_LOG_DIR, + tensorboard_experiment_name=constants.EXPERIMENT_NAME, + run_name_prefix=constants.EXPERIMENT_RUN_NAME, + ) diff --git a/testing/constraints-3.10.txt b/testing/constraints-3.10.txt index f1932a360e..844dc4cfba 100644 --- a/testing/constraints-3.10.txt +++ b/testing/constraints-3.10.txt @@ -8,3 +8,5 @@ mock==4.0.2 google-cloud-storage==2.0.0 packaging==20.0 # Increased for compatibility with MLFlow grpcio-testing==1.34.0 +mlflow==1.30.1 # Pinned to speed up installation + diff --git a/tests/system/aiplatform/test_experiments.py b/tests/system/aiplatform/test_experiments.py index 657e706bd5..e9632fa078 100644 --- a/tests/system/aiplatform/test_experiments.py +++ b/tests/system/aiplatform/test_experiments.py @@ -63,6 +63,7 @@ class TestExperiments(e2e_base.TestEndToEnd): def setup_class(cls): cls._experiment_name = cls._make_display_name("")[:64] + cls._experiment_model_name = cls._make_display_name("sklearn-model")[:64] cls._dataset_artifact_name = cls._make_display_name("")[:64] cls._dataset_artifact_uri = cls._make_display_name("ds-uri") cls._pipeline_job_id = cls._make_display_name("job-id") @@ -199,45 +200,31 @@ def test_log_model(self, shared_state): model = LinearRegression() model.fit(train_x, train_y) - try: - model_artifact = aiplatform.log_model( - model=model, - artifact_id="sklearn-model", - uri=f"gs://{shared_state['staging_bucket_name']}/sklearn-model", - input_example=train_x, - ) - shared_state["resources"].append(model_artifact) + model_artifact = aiplatform.log_model( + model=model, + artifact_id=self._experiment_model_name, + uri=f"gs://{shared_state['staging_bucket_name']}/sklearn-model", + input_example=train_x, + ) + shared_state["resources"].append(model_artifact) - run = aiplatform.ExperimentRun( - run_name=_RUN, experiment=self._experiment_name - ) - experiment_model = run.get_experiment_models()[0] - assert experiment_model.name == "sklearn-model" - assert ( - experiment_model.uri - == f"gs://{shared_state['staging_bucket_name']}/sklearn-model" - ) - assert experiment_model.get_model_info() == { - "model_class": "sklearn.linear_model._base.LinearRegression", - "framework_name": "sklearn", - "framework_version": sklearn.__version__, - "input_example": { - "type": "numpy.ndarray", - "data": train_x.tolist(), - }, - } - experiment_model.delete() - finally: - # Make sure that, if the model resources already exists but the call - # aiplatform.log_model fails, we clean up the model resource. - run = aiplatform.ExperimentRun( - run_name=_RUN, experiment=self._experiment_name - ) - experiment_models = run.get_experiment_models() - if experiment_models: - experiment_model = experiment_models[0] - experiment_model.delete() - assert False, "log_model() call failed and assertions are not run." 
+ run = aiplatform.ExperimentRun(run_name=_RUN, experiment=self._experiment_name) + experiment_model = run.get_experiment_models()[0] + assert "sklearn-model" in experiment_model.name + assert ( + experiment_model.uri + == f"gs://{shared_state['staging_bucket_name']}/sklearn-model" + ) + assert experiment_model.get_model_info() == { + "model_class": "sklearn.linear_model._base.LinearRegression", + "framework_name": "sklearn", + "framework_version": sklearn.__version__, + "input_example": { + "type": "numpy.ndarray", + "data": train_x.tolist(), + }, + } + experiment_model.delete() def test_create_artifact(self, shared_state): ds = aiplatform.Artifact.create( diff --git a/tests/system/aiplatform/test_language_models.py b/tests/system/aiplatform/test_language_models.py index d522f0f09a..a9533d735b 100644 --- a/tests/system/aiplatform/test_language_models.py +++ b/tests/system/aiplatform/test_language_models.py @@ -50,7 +50,7 @@ def test_text_generation(self): aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) model = TextGenerationModel.from_pretrained("google/text-bison@001") - + grounding_source = language_models.GroundingSource.WebSearch() assert model.predict( "What is the best recipe for banana bread? Recipe:", max_output_tokens=128, @@ -58,6 +58,7 @@ def test_text_generation(self): top_p=1.0, top_k=5, stop_sequences=["# %%"], + grounding_source=grounding_source, ).text def test_text_generation_preview_count_tokens(self): @@ -77,7 +78,7 @@ async def test_text_generation_model_predict_async(self): aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) model = TextGenerationModel.from_pretrained("google/text-bison@001") - + grounding_source = language_models.GroundingSource.WebSearch() response = await model.predict_async( "What is the best recipe for banana bread? Recipe:", max_output_tokens=128, @@ -85,6 +86,7 @@ async def test_text_generation_model_predict_async(self): top_p=1.0, top_k=5, stop_sequences=["# %%"], + grounding_source=grounding_source, ) assert response.text @@ -122,8 +124,8 @@ def test_preview_text_embedding_top_level_from_pretrained(self): def test_chat_on_chat_model(self): aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) - chat_model = ChatModel.from_pretrained("google/chat-bison@001") + grounding_source = language_models.GroundingSource.WebSearch() chat = chat_model.start_chat( context="My name is Ned. You are my personal assistant. My favorite movies are Lord of the Rings and Hobbit.", examples=[ @@ -141,8 +143,12 @@ def test_chat_on_chat_model(self): ) message1 = "Are my favorite movies based on a book series?" - response1 = chat.send_message(message1) + response1 = chat.send_message( + message1, + grounding_source=grounding_source, + ) assert response1.text + assert response1.grounding_metadata assert len(chat.message_history) == 2 assert chat.message_history[0].author == chat.USER_AUTHOR assert chat.message_history[0].content == message1 @@ -150,10 +156,10 @@ def test_chat_on_chat_model(self): message2 = "When were these books published?" 
response2 = chat.send_message( - message2, - temperature=0.1, + message2, temperature=0.1, grounding_source=grounding_source ) assert response2.text + assert response2.grounding_metadata assert len(chat.message_history) == 4 assert chat.message_history[2].author == chat.USER_AUTHOR assert chat.message_history[2].content == message2 @@ -187,6 +193,7 @@ async def test_chat_model_async(self): aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) chat_model = ChatModel.from_pretrained("google/chat-bison@001") + grounding_source = language_models.GroundingSource.WebSearch() chat = chat_model.start_chat( context="My name is Ned. You are my personal assistant. My favorite movies are Lord of the Rings and Hobbit.", examples=[ @@ -204,8 +211,12 @@ async def test_chat_model_async(self): ) message1 = "Are my favorite movies based on a book series?" - response1 = await chat.send_message_async(message1) + response1 = await chat.send_message_async( + message1, + grounding_source=grounding_source, + ) assert response1.text + assert response1.grounding_metadata assert len(chat.message_history) == 2 assert chat.message_history[0].author == chat.USER_AUTHOR assert chat.message_history[0].content == message1 @@ -215,8 +226,10 @@ async def test_chat_model_async(self): response2 = await chat.send_message_async( message2, temperature=0.1, + grounding_source=grounding_source, ) assert response2.text + assert response2.grounding_metadata assert len(chat.message_history) == 4 assert chat.message_history[2].author == chat.USER_AUTHOR assert chat.message_history[2].content == message2 diff --git a/tests/system/vertex_ray/test_cluster_management.py b/tests/system/vertex_ray/test_cluster_management.py new file mode 100644 index 0000000000..0e853ec612 --- /dev/null +++ b/tests/system/vertex_ray/test_cluster_management.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from google.cloud import aiplatform +from google.cloud.aiplatform.preview import vertex_ray +from tests.system.aiplatform import e2e_base +import ray + + +class TestClusterManagement(e2e_base.TestEndToEnd): + _temp_prefix = "temp-rov-cluster-management" + + def test_cluster_management(self): + assert ray.__version__ == "2.4.0" + aiplatform.init(project="ucaip-sample-tests", location="us-central1") + + clusters = vertex_ray.list_ray_clusters() + assert clusters[0].ray_version == "2_4" diff --git a/tests/system/vertexai/test_bigframes_sklearn.py b/tests/system/vertexai/test_bigframes_sklearn.py index c9b7b1313f..addc4f8363 100644 --- a/tests/system/vertexai/test_bigframes_sklearn.py +++ b/tests/system/vertexai/test_bigframes_sklearn.py @@ -61,7 +61,7 @@ else "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/python-aiplatform.git@main", ) # To avoid flaky test due to autolog enabled in parallel tests -@mock.patch.object(vertexai.preview.global_config, "autolog", False) +@mock.patch.object(vertexai.preview.initializer._Config, "autolog", False) @pytest.mark.usefixtures( "prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources" ) diff --git a/tests/system/vertexai/test_bigframes_tensorflow.py b/tests/system/vertexai/test_bigframes_tensorflow.py index 1e15fd8b87..da64e6abab 100644 --- a/tests/system/vertexai/test_bigframes_tensorflow.py +++ b/tests/system/vertexai/test_bigframes_tensorflow.py @@ -57,7 +57,7 @@ else "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/python-aiplatform.git@main", ) # To avoid flaky test due to autolog enabled in parallel tests -@mock.patch.object(vertexai.preview.global_config, "autolog", False) +@mock.patch.object(vertexai.preview.initializer._Config, "autolog", False) @pytest.mark.usefixtures( "prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources" ) diff --git a/tests/system/vertexai/test_pytorch.py b/tests/system/vertexai/test_pytorch.py index 27611881a1..9c6ab8b606 100644 --- a/tests/system/vertexai/test_pytorch.py +++ b/tests/system/vertexai/test_pytorch.py @@ -49,7 +49,7 @@ else "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/python-aiplatform.git@main", ) # To avoid flaky test due to autolog enabled in parallel tests -@mock.patch.object(vertexai.preview.global_config, "autolog", False) +@mock.patch.object(vertexai.preview.initializer._Config, "autolog", False) @pytest.mark.usefixtures( "prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources" ) diff --git a/tests/system/vertexai/test_sklearn.py b/tests/system/vertexai/test_sklearn.py index b3fe9a1788..c458737126 100644 --- a/tests/system/vertexai/test_sklearn.py +++ b/tests/system/vertexai/test_sklearn.py @@ -55,7 +55,7 @@ else "google-cloud-aiplatform[preview,autologging] @ git+https://github.com/googleapis/python-aiplatform.git@main", ) # To avoid flaky test due to autolog enabled in parallel tests -@mock.patch.object(vertexai.preview.global_config, "autolog", False) +@mock.patch.object(vertexai.preview.initializer._Config, "autolog", False) @pytest.mark.usefixtures( "prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources" ) diff --git a/tests/system/vertexai/test_tensorflow.py b/tests/system/vertexai/test_tensorflow.py index 37b1a3fd71..af53fd2c90 100644 --- a/tests/system/vertexai/test_tensorflow.py +++ b/tests/system/vertexai/test_tensorflow.py @@ -54,7 +54,7 @@ else "google-cloud-aiplatform[preview,autologging] @ 
git+https://github.com/googleapis/python-aiplatform.git@main", ) # To avoid flaky test due to autolog enabled in parallel tests -@mock.patch.object(vertexai.preview.global_config, "autolog", False) +@mock.patch.object(vertexai.preview.initializer._Config, "autolog", False) @pytest.mark.usefixtures( "prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources" ) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index e96735bac6..c866953e1c 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -16,7 +16,7 @@ # # pylint: disable=protected-access,bad-continuation - +import dataclasses import json import pytest from importlib import reload @@ -74,6 +74,7 @@ from vertexai.language_models import ( _evaluatable_language_models, ) +from vertexai.language_models import GroundingSource from google.cloud.aiplatform_v1 import Execution as GapicExecution from google.cloud.aiplatform.compat.types import ( encryption_spec as gca_encryption_spec, @@ -166,6 +167,53 @@ }, } +_TEST_GROUNDING_WEB_SEARCH = GroundingSource.WebSearch() + +_TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE = GroundingSource.VertexAISearch( + data_store_id="test_datastore", location="global" +) + +_TEST_TEXT_GENERATION_PREDICTION_GROUNDING = { + "safetyAttributes": { + "categories": ["Violent"], + "blocked": False, + "scores": [0.10000000149011612], + }, + "groundingMetadata": { + "citations": [ + {"url": "url1", "startIndex": 1, "endIndex": 2}, + {"url": "url2", "startIndex": 3, "endIndex": 4}, + ] + }, + "content": """ +Ingredients: +* 3 cups all-purpose flour + +Instructions: +1. Preheat oven to 350 degrees F (175 degrees C).""", +} + +_EXPECTED_PARSED_GROUNDING_METADATA = { + "citations": [ + { + "url": "url1", + "start_index": 1, + "end_index": 2, + "title": None, + "license": None, + "publication_date": None, + }, + { + "url": "url2", + "start_index": 3, + "end_index": 4, + "title": None, + "license": None, + "publication_date": None, + }, + ] +} + _TEST_TEXT_GENERATION_PREDICTION = { "safetyAttributes": { "categories": ["Violent"], @@ -263,6 +311,97 @@ ], } +_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING = { + "safetyAttributes": [ + { + "scores": [], + "categories": [], + "blocked": False, + }, + { + "scores": [0.1], + "categories": ["Finance"], + "blocked": True, + }, + ], + "groundingMetadata": [ + { + "citations": [ + { + "startIndex": 1, + "endIndex": 2, + "url": "url1", + } + ] + }, + { + "citations": [ + { + "startIndex": 3, + "endIndex": 4, + "url": "url2", + } + ] + }, + ], + "candidates": [ + { + "author": "1", + "content": "Chat response 2", + }, + { + "author": "1", + "content": "", + }, + ], +} + +_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING_NONE = { + "safetyAttributes": [ + { + "scores": [], + "categories": [], + "blocked": False, + }, + { + "scores": [0.1], + "categories": ["Finance"], + "blocked": True, + }, + ], + "groundingMetadata": [ + None, + None, + ], + "candidates": [ + { + "author": "1", + "content": "Chat response 2", + }, + { + "author": "1", + "content": "", + }, + ], +} + +_EXPECTED_PARSED_GROUNDING_METADATA_CHAT = { + "citations": [ + { + "url": "url1", + "start_index": 1, + "end_index": 2, + "title": None, + "license": None, + "publication_date": None, + }, + ], +} + +_EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE = { + "citations": [], +} + _TEST_CHAT_PREDICTION_STREAMING = [ { "candidates": [ @@ -342,7 +481,6 @@ def reverse_string_2(s):""", 
"total_billable_characters": 25, } - _TEST_TEXT_BISON_TRAINING_DF = pd.DataFrame( { "input_text": [ @@ -390,6 +528,11 @@ def reverse_string_2(s):""", "isOptional": True, "parameterType": "STRING", }, + "enable_checkpoint_selection": { + "defaultValue": "default", + "isOptional": True, + "parameterType": "STRING", + }, "enable_early_stopping": { "defaultValue": True, "isOptional": True, @@ -1392,6 +1535,78 @@ def test_text_generation_multiple_candidates(self): response.candidates[0].text == _TEST_TEXT_GENERATION_PREDICTION["content"] ) + def test_text_generation_multiple_candidates_grounding(self): + """Tests the text generation model with multiple candidates with web grounding.""" + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_BISON_PUBLISHER_MODEL_DICT + ), + ): + model = language_models.TextGenerationModel.from_pretrained( + "text-bison@001" + ) + + gca_predict_response = gca_prediction_service.PredictResponse() + # Discrepancy between the number of `instances` and the number of `predictions` + # is a violation of the prediction service invariant, but the service does this. + gca_predict_response.predictions.append( + _TEST_TEXT_GENERATION_PREDICTION_GROUNDING + ) + gca_predict_response.predictions.append( + _TEST_TEXT_GENERATION_PREDICTION_GROUNDING + ) + + test_grounding_sources = [ + _TEST_GROUNDING_WEB_SEARCH, + _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE, + ] + datastore_path = ( + "projects/test-project/locations/global/" + "collections/default_collection/dataStores/test_datastore" + ) + expected_grounding_sources = [ + {"sources": [{"type": "WEB"}]}, + { + "sources": [ + { + "type": "ENTERPRISE", + "enterpriseDatastore": datastore_path, + } + ] + }, + ] + + for test_grounding_source, expected_grounding_source in zip( + test_grounding_sources, expected_grounding_sources + ): + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response, + ) as mock_predict: + response = model.predict( + "What is the best recipe for banana bread? 
Recipe:", + candidate_count=2, + grounding_source=test_grounding_source, + ) + prediction_parameters = mock_predict.call_args[1]["parameters"] + assert prediction_parameters["candidateCount"] == 2 + assert prediction_parameters["groundingConfig"] == expected_grounding_source + assert ( + response.text == _TEST_TEXT_GENERATION_PREDICTION_GROUNDING["content"] + ) + assert len(response.candidates) == 2 + assert ( + response.candidates[0].text + == _TEST_TEXT_GENERATION_PREDICTION_GROUNDING["content"] + ) + assert ( + dataclasses.asdict(response.candidates[0].grounding_metadata) + == _EXPECTED_PARSED_GROUNDING_METADATA + ) + @pytest.mark.asyncio async def test_text_generation_async(self): """Tests the text generation model.""" @@ -1435,6 +1650,79 @@ async def test_text_generation_async(self): assert prediction_parameters["stopSequences"] == ["\n"] assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"] + @pytest.mark.asyncio + async def test_text_generation_multiple_candidates_grounding_async(self): + """Tests the text generation model with multiple candidates async with web grounding.""" + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_BISON_PUBLISHER_MODEL_DICT + ), + ): + model = language_models.TextGenerationModel.from_pretrained( + "text-bison@001" + ) + + gca_predict_response = gca_prediction_service.PredictResponse() + # Discrepancy between the number of `instances` and the number of `predictions` + # is a violation of the prediction service invariant, but the service does this. + gca_predict_response.predictions.append( + _TEST_TEXT_GENERATION_PREDICTION_GROUNDING + ) + + test_grounding_sources = [ + _TEST_GROUNDING_WEB_SEARCH, + _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE, + ] + datastore_path = ( + "projects/test-project/locations/global/" + "collections/default_collection/dataStores/test_datastore" + ) + expected_grounding_sources = [ + {"sources": [{"type": "WEB"}]}, + { + "sources": [ + { + "type": "ENTERPRISE", + "enterpriseDatastore": datastore_path, + } + ] + }, + ] + + for test_grounding_source, expected_grounding_source in zip( + test_grounding_sources, expected_grounding_sources + ): + with mock.patch.object( + target=prediction_service_async_client.PredictionServiceAsyncClient, + attribute="predict", + return_value=gca_predict_response, + ) as mock_predict: + response = await model.predict_async( + "What is the best recipe for banana bread? 
Recipe:", + max_output_tokens=128, + temperature=0.0, + top_p=1.0, + top_k=5, + stop_sequences=["\n"], + grounding_source=test_grounding_source, + ) + prediction_parameters = mock_predict.call_args[1]["parameters"] + assert prediction_parameters["maxDecodeSteps"] == 128 + assert prediction_parameters["temperature"] == 0.0 + assert prediction_parameters["topP"] == 1.0 + assert prediction_parameters["topK"] == 5 + assert prediction_parameters["stopSequences"] == ["\n"] + assert prediction_parameters["groundingConfig"] == expected_grounding_source + assert ( + response.text == _TEST_TEXT_GENERATION_PREDICTION_GROUNDING["content"] + ) + assert ( + dataclasses.asdict(response.grounding_metadata) + == _EXPECTED_PARSED_GROUNDING_METADATA + ) + def test_text_generation_model_predict_streaming(self): """Tests the TextGenerationModel.predict_streaming method.""" with mock.patch.object( @@ -1645,6 +1933,7 @@ def test_tune_text_generation_model_ga( evaluation_data_uri = "gs://bucket/eval.jsonl" evaluation_interval = 37 enable_early_stopping = True + enable_checkpoint_selection = True tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123" tuning_job = model.tune_model( @@ -1657,6 +1946,7 @@ def test_tune_text_generation_model_ga( evaluation_data=evaluation_data_uri, evaluation_interval=evaluation_interval, enable_early_stopping=enable_early_stopping, + enable_checkpoint_selection=enable_checkpoint_selection, tensorboard=tensorboard_name, ), accelerator_type="TPU", @@ -1670,6 +1960,10 @@ def test_tune_text_generation_model_ga( assert pipeline_arguments["evaluation_data_uri"] == evaluation_data_uri assert pipeline_arguments["evaluation_interval"] == evaluation_interval assert pipeline_arguments["enable_early_stopping"] == enable_early_stopping + assert ( + pipeline_arguments["enable_checkpoint_selection"] + == enable_checkpoint_selection + ) assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name assert pipeline_arguments["large_model_reference"] == "text-bison@001" assert pipeline_arguments["accelerator_type"] == "TPU" @@ -1867,7 +2161,7 @@ def test_tune_code_generation_model( _CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT ), ): - model = language_models.CodeGenerationModel.from_pretrained( + model = preview_language_models.CodeGenerationModel.from_pretrained( "code-bison@001" ) # The tune_model call needs to be inside the PublisherModel mock @@ -1915,9 +2209,7 @@ def test_tune_code_chat_model( _CODECHAT_BISON_PUBLISHER_MODEL_DICT ), ): - model = preview_language_models.CodeChatModel.from_pretrained( - "codechat-bison@001" - ) + model = language_models.CodeChatModel.from_pretrained("codechat-bison@001") # The tune_model call needs to be inside the PublisherModel mock # since it gets a new PublisherModel when tuning completes. 
@@ -2111,6 +2403,221 @@ def test_chat(self): assert prediction_parameters["topK"] == message_top_k assert prediction_parameters["topP"] == message_top_p + gca_predict_response4 = gca_prediction_service.PredictResponse() + gca_predict_response4.predictions.append( + _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING + ) + test_grounding_sources = [ + _TEST_GROUNDING_WEB_SEARCH, + _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE, + ] + datastore_path = ( + "projects/test-project/locations/global/" + "collections/default_collection/dataStores/test_datastore" + ) + expected_grounding_sources = [ + {"sources": [{"type": "WEB"}]}, + { + "sources": [ + { + "type": "ENTERPRISE", + "enterpriseDatastore": datastore_path, + } + ] + }, + ] + for test_grounding_source, expected_grounding_source in zip( + test_grounding_sources, expected_grounding_sources + ): + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response4, + ) as mock_predict4: + response = chat2.send_message( + "Are my favorite movies based on a book series?", + grounding_source=test_grounding_source, + ) + prediction_parameters = mock_predict4.call_args[1]["parameters"] + assert ( + prediction_parameters["groundingConfig"] + == expected_grounding_source + ) + assert ( + dataclasses.asdict(response.grounding_metadata) + == _EXPECTED_PARSED_GROUNDING_METADATA_CHAT + ) + + gca_predict_response5 = gca_prediction_service.PredictResponse() + gca_predict_response5.predictions.append( + _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING_NONE + ) + test_grounding_sources = [ + _TEST_GROUNDING_WEB_SEARCH, + _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE, + ] + datastore_path = ( + "projects/test-project/locations/global/" + "collections/default_collection/dataStores/test_datastore" + ) + expected_grounding_sources = [ + {"sources": [{"type": "WEB"}]}, + { + "sources": [ + { + "type": "ENTERPRISE", + "enterpriseDatastore": datastore_path, + } + ] + }, + ] + for test_grounding_source, expected_grounding_source in zip( + test_grounding_sources, expected_grounding_sources + ): + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response5, + ) as mock_predict5: + response = chat2.send_message( + "Are my favorite movies based on a book series?", + grounding_source=test_grounding_source, + ) + prediction_parameters = mock_predict5.call_args[1]["parameters"] + assert ( + prediction_parameters["groundingConfig"] + == expected_grounding_source + ) + assert ( + dataclasses.asdict(response.grounding_metadata) + == _EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE + ) + + @pytest.mark.asyncio + async def test_chat_async(self): + """Test the chat generation model async api.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _CHAT_BISON_PUBLISHER_MODEL_DICT + ), + ) as mock_get_publisher_model: + model = preview_language_models.ChatModel.from_pretrained("chat-bison@001") + + mock_get_publisher_model.assert_called_once_with( + name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY + ) + chat_temperature = 0.1 + chat_max_output_tokens = 100 + chat_top_k = 1 + chat_top_p = 0.1 + + chat = model.start_chat( + temperature=chat_temperature, + 
max_output_tokens=chat_max_output_tokens, + top_k=chat_top_k, + top_p=chat_top_p, + ) + + gca_predict_response6 = gca_prediction_service.PredictResponse() + gca_predict_response6.predictions.append( + _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING + ) + test_grounding_sources = [ + _TEST_GROUNDING_WEB_SEARCH, + _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE, + ] + datastore_path = ( + "projects/test-project/locations/global/" + "collections/default_collection/dataStores/test_datastore" + ) + expected_grounding_sources = [ + {"sources": [{"type": "WEB"}]}, + { + "sources": [ + { + "type": "ENTERPRISE", + "enterpriseDatastore": datastore_path, + } + ] + }, + ] + for test_grounding_source, expected_grounding_source in zip( + test_grounding_sources, expected_grounding_sources + ): + with mock.patch.object( + target=prediction_service_async_client.PredictionServiceAsyncClient, + attribute="predict", + return_value=gca_predict_response6, + ) as mock_predict6: + response = await chat.send_message_async( + "Are my favorite movies based on a book series?", + grounding_source=test_grounding_source, + ) + prediction_parameters = mock_predict6.call_args[1]["parameters"] + assert prediction_parameters["temperature"] == chat_temperature + assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens + assert prediction_parameters["topK"] == chat_top_k + assert prediction_parameters["topP"] == chat_top_p + assert ( + prediction_parameters["groundingConfig"] + == expected_grounding_source + ) + assert ( + dataclasses.asdict(response.grounding_metadata) + == _EXPECTED_PARSED_GROUNDING_METADATA_CHAT + ) + + gca_predict_response7 = gca_prediction_service.PredictResponse() + gca_predict_response7.predictions.append( + _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING_NONE + ) + test_grounding_sources = [ + _TEST_GROUNDING_WEB_SEARCH, + _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE, + ] + datastore_path = ( + "projects/test-project/locations/global/" + "collections/default_collection/dataStores/test_datastore" + ) + expected_grounding_sources = [ + {"sources": [{"type": "WEB"}]}, + { + "sources": [ + { + "type": "ENTERPRISE", + "enterpriseDatastore": datastore_path, + } + ] + }, + ] + for test_grounding_source, expected_grounding_source in zip( + test_grounding_sources, expected_grounding_sources + ): + with mock.patch.object( + target=prediction_service_async_client.PredictionServiceAsyncClient, + attribute="predict", + return_value=gca_predict_response7, + ) as mock_predict7: + response = await chat.send_message_async( + "Are my favorite movies based on a book series?", + grounding_source=test_grounding_source, + ) + prediction_parameters = mock_predict7.call_args[1]["parameters"] + assert ( + prediction_parameters["groundingConfig"] + == expected_grounding_source + ) + assert ( + dataclasses.asdict(response.grounding_metadata) + == _EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE + ) + def test_chat_ga(self): """Tests the chat generation model.""" aiplatform.init( @@ -3060,6 +3567,10 @@ def test_text_embedding_ga(self): assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"] + # Validating that a single string is not accepted. 
+ with pytest.raises(TypeError): + model.get_embeddings("What is life?") + def test_batch_prediction( self, get_endpoint_mock, diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 87ab5a9c5f..48e7c3c506 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -232,6 +232,10 @@ Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"]) ] _TEST_IDS = ["123", "456", "789"] +_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3 +_TEST_APPROX_NUM_NEIGHBORS = 2 +_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE = 0.8 +_TEST_RETURN_FULL_DATAPOINT = True def uuid_mock(): @@ -853,6 +857,47 @@ def test_delete_index_endpoint_with_force( name=_TEST_INDEX_ENDPOINT_NAME ) + @pytest.mark.usefixtures("get_index_endpoint_mock") + def test_index_endpoint_match_queries_backward_compatibility( + self, index_endpoint_match_queries_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + + my_index_endpoint.match( + _TEST_DEPLOYED_INDEX_ID, + _TEST_QUERIES, + _TEST_NUM_NEIGHBOURS, + _TEST_FILTER, + ) + + batch_request = match_service_pb2.BatchMatchRequest( + requests=[ + match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + requests=[ + match_service_pb2.MatchRequest( + num_neighbors=_TEST_NUM_NEIGHBOURS, + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + float_val=_TEST_QUERIES[0], + restricts=[ + match_service_pb2.Namespace( + name="class", + allow_tokens=["token_1"], + deny_tokens=["token_2"], + ) + ], + ) + ], + ) + ] + ) + + index_endpoint_match_queries_mock.assert_called_with(batch_request) + @pytest.mark.usefixtures("get_index_endpoint_mock") def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock): aiplatform.init(project=_TEST_PROJECT) @@ -866,6 +911,8 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock): queries=_TEST_QUERIES, num_neighbors=_TEST_NUM_NEIGHBOURS, filter=_TEST_FILTER, + per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, ) batch_request = match_service_pb2.BatchMatchRequest( @@ -884,6 +931,8 @@ def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock): deny_tokens=["token_2"], ) ], + per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, ) ], ) @@ -907,6 +956,10 @@ def test_index_public_endpoint_match_queries( queries=_TEST_QUERIES, num_neighbors=_TEST_NUM_NEIGHBOURS, filter=_TEST_FILTER, + per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, + fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, + return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT, ) find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest( @@ -925,8 +978,12 @@ def test_index_public_endpoint_match_queries( ) ], ), + per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approximate_neighbor_count=_TEST_APPROX_NUM_NEIGHBORS, + fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, ) ], + return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT, ) 
index_public_endpoint_match_queries_mock.assert_called_with( diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index ed1cd95f39..c85deda21c 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -2032,19 +2032,22 @@ def test_parse_dataset_path(): def test_dataset_path(): project = "squid" - dataset = "clam" - expected = "projects/{project}/datasets/{dataset}".format( + location = "clam" + dataset = "whelk" + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "whelk", - "dataset": "octopus", + "project": "octopus", + "location": "oyster", + "dataset": "nudibranch", } path = MigrationServiceClient.dataset_path(**expected) @@ -2054,22 +2057,19 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "oyster" - location = "nudibranch" - dataset = "cuttlefish" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project = "cuttlefish" + dataset = "mussel" + expected = "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "mussel", - "location": "winkle", + "project": "winkle", "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_admin_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_admin_service.py index 60647315dd..b11bba9ffa 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_admin_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_feature_online_store_admin_service.py @@ -62,6 +62,7 @@ from google.cloud.aiplatform_v1beta1.types import feature_view as gca_feature_view from google.cloud.aiplatform_v1beta1.types import feature_view_sync from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import service_networking from google.cloud.location import locations_pb2 from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import options_pb2 # type: ignore diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index b287e8fe7f..c56b99ac1b 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -2034,8 +2034,31 @@ def test_parse_dataset_path(): def test_dataset_path(): project = "squid" - location = "clam" - dataset = "whelk" + dataset = "clam" + expected = "projects/{project}/datasets/{dataset}".format( + project=project, + dataset=dataset, + ) + actual = MigrationServiceClient.dataset_path(project, dataset) + assert expected == actual + + +def test_parse_dataset_path(): + expected = { + "project": "whelk", + "dataset": "octopus", + } + path = MigrationServiceClient.dataset_path(**expected) + + # Check that the path 
construction is reversible. + actual = MigrationServiceClient.parse_dataset_path(path) + assert expected == actual + + +def test_dataset_path(): + project = "oyster" + location = "nudibranch" + dataset = "cuttlefish" expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, location=location, @@ -2047,9 +2070,9 @@ def test_dataset_path(): def test_parse_dataset_path(): expected = { - "project": "octopus", - "location": "oyster", - "dataset": "nudibranch", + "project": "mussel", + "location": "winkle", + "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -2059,19 +2082,48 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "cuttlefish" - dataset = "mussel" - expected = "projects/{project}/datasets/{dataset}".format( + project = "scallop" + location = "abalone" + dataset = "squid" + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", + "project": "clam", + "location": "whelk", + "dataset": "octopus", + } + path = MigrationServiceClient.dataset_path(**expected) + + # Check that the path construction is reversible. + actual = MigrationServiceClient.parse_dataset_path(path) + assert expected == actual + + +def test_dataset_path(): + project = "oyster" + location = "nudibranch" + dataset = "cuttlefish" + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, + location=location, + dataset=dataset, + ) + actual = MigrationServiceClient.dataset_path(project, location, dataset) + assert expected == actual + + +def test_parse_dataset_path(): + expected = { + "project": "mussel", + "location": "winkle", "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) diff --git a/tests/unit/vertex_ray/test_cluster_init.py b/tests/unit/vertex_ray/test_cluster_init.py index 1556c6e86d..f4cee9380f 100644 --- a/tests/unit/vertex_ray/test_cluster_init.py +++ b/tests/unit/vertex_ray/test_cluster_init.py @@ -213,6 +213,31 @@ def test_create_ray_cluster_1_pool_gpu_success( request, ) + @pytest.mark.usefixtures("get_persistent_resource_1_pool_mock") + def test_create_ray_cluster_1_pool_gpu_with_labels_success( + self, create_persistent_resource_1_pool_mock + ): + """If head and worker nodes are duplicate, merge to head pool.""" + cluster_name = vertex_ray.create_ray_cluster( + head_node_type=tc.ClusterConstants._TEST_HEAD_NODE_TYPE_1_POOL, + worker_node_types=tc.ClusterConstants._TEST_WORKER_NODE_TYPES_1_POOL, + network=tc.ProjectConstants._TEST_VPC_NETWORK, + cluster_name=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ID, + labels=tc.ClusterConstants._TEST_LABELS, + ) + + assert tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS == cluster_name + + request = persistent_resource_service.CreatePersistentResourceRequest( + parent=tc.ProjectConstants._TEST_PARENT, + persistent_resource=tc.ClusterConstants._TEST_REQUEST_RUNNING_1_POOL_WITH_LABELS, + persistent_resource_id=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ID, + ) + + create_persistent_resource_1_pool_mock.assert_called_with( + request, + ) + @pytest.mark.usefixtures("get_persistent_resource_2_pools_mock") def test_create_ray_cluster_2_pools_success( self, create_persistent_resource_2_pools_mock diff --git 
a/tests/unit/vertex_ray/test_constants.py b/tests/unit/vertex_ray/test_constants.py index 8d972c08ea..97dc92260d 100644 --- a/tests/unit/vertex_ray/test_constants.py +++ b/tests/unit/vertex_ray/test_constants.py @@ -63,9 +63,10 @@ class ProjectConstants: class ClusterConstants: """Defines cluster constants used by tests.""" + _TEST_LABELS = {"my_key": "my_value"} _TEST_VERTEX_RAY_HEAD_NODE_IP = "1.2.3.4:10001" _TEST_VERTEX_RAY_JOB_CLIENT_IP = "1.2.3.4:8888" - _TEST_VERTEX_RAY_DASHBOARD_URL = ( + _TEST_VERTEX_RAY_DASHBOARD_ADDRESS = ( "48b400ad90b8dd3c-dot-us-central1.aiplatform-training.googleusercontent.com" ) _TEST_VERTEX_RAY_PR_ID = "user-persistent-resource-1234567890" @@ -111,6 +112,14 @@ class ClusterConstants: ), network=ProjectConstants._TEST_VPC_NETWORK, ) + _TEST_REQUEST_RUNNING_1_POOL_WITH_LABELS = PersistentResource( + resource_pools=[_TEST_RESOURCE_POOL_0], + resource_runtime_spec=ResourceRuntimeSpec( + ray_spec=RaySpec(resource_pool_images={"head-node": _TEST_GPU_IMAGE}), + ), + network=ProjectConstants._TEST_VPC_NETWORK, + labels=_TEST_LABELS, + ) # Get response has generated name, and URIs _TEST_RESPONSE_RUNNING_1_POOL = PersistentResource( name=_TEST_VERTEX_RAY_PR_ADDRESS, @@ -121,7 +130,7 @@ class ClusterConstants: network=ProjectConstants._TEST_VPC_NETWORK, resource_runtime=ResourceRuntime( access_uris={ - "RAY_DASHBOARD_URI": _TEST_VERTEX_RAY_DASHBOARD_URL, + "RAY_DASHBOARD_URI": _TEST_VERTEX_RAY_DASHBOARD_ADDRESS, "RAY_HEAD_NODE_INTERNAL_IP": _TEST_VERTEX_RAY_HEAD_NODE_IP, } ), @@ -182,7 +191,7 @@ class ClusterConstants: network=ProjectConstants._TEST_VPC_NETWORK, resource_runtime=ResourceRuntime( access_uris={ - "RAY_DASHBOARD_URI": _TEST_VERTEX_RAY_DASHBOARD_URL, + "RAY_DASHBOARD_URI": _TEST_VERTEX_RAY_DASHBOARD_ADDRESS, "RAY_HEAD_NODE_INTERNAL_IP": _TEST_VERTEX_RAY_HEAD_NODE_IP, } ), @@ -196,6 +205,7 @@ class ClusterConstants: state="RUNNING", head_node_type=_TEST_HEAD_NODE_TYPE_1_POOL, worker_node_types=_TEST_WORKER_NODE_TYPES_1_POOL, + dashboard_address=_TEST_VERTEX_RAY_DASHBOARD_ADDRESS, ) _TEST_CLUSTER_2 = Cluster( cluster_resource_name=_TEST_VERTEX_RAY_PR_ADDRESS, @@ -205,4 +215,10 @@ class ClusterConstants: state="RUNNING", head_node_type=_TEST_HEAD_NODE_TYPE_2_POOLS, worker_node_types=_TEST_WORKER_NODE_TYPES_2_POOLS, + dashboard_address=_TEST_VERTEX_RAY_DASHBOARD_ADDRESS, ) + _TEST_BEARER_TOKEN = "test-bearer-token" + _TEST_HEADERS = { + "Content-Type": "application/json", + "Authorization": "Bearer {}".format(_TEST_BEARER_TOKEN), + } diff --git a/tests/unit/vertex_ray/test_dashboard_sdk.py b/tests/unit/vertex_ray/test_dashboard_sdk.py index 05ffd35dbb..bebf78f01b 100644 --- a/tests/unit/vertex_ray/test_dashboard_sdk.py +++ b/tests/unit/vertex_ray/test_dashboard_sdk.py @@ -44,6 +44,15 @@ def get_persistent_resource_status_running_mock(): yield get_persistent_resource +@pytest.fixture +def get_bearer_token_mock(): + with mock.patch.object( + vertex_ray.util._validation_utils, "get_bearer_token" + ) as get_bearer_token_mock: + get_bearer_token_mock.return_value = tc.ClusterConstants._TEST_BEARER_TOKEN + yield get_bearer_token_mock + + class TestGetJobSubmissionClientClusterInfo: def setup_method(self): importlib.reload(aiplatform.initializer) @@ -83,3 +92,22 @@ def test_job_submission_client_cluster_info_with_cluster_name( ray_get_job_submission_client_cluster_info_mock.assert_called_once_with( address=tc.ClusterConstants._TEST_VERTEX_RAY_JOB_CLIENT_IP ) + + @pytest.mark.usefixtures( + "get_persistent_resource_status_running_mock", "google_auth_mock" + ) + def 
test_job_submission_client_cluster_info_with_dashboard_address( + self, + ray_get_job_submission_client_cluster_info_mock, + get_bearer_token_mock, + ): + aiplatform.init(project=tc.ProjectConstants._TEST_GCP_PROJECT_ID) + + vertex_ray.get_job_submission_client_cluster_info( + tc.ClusterConstants._TEST_VERTEX_RAY_DASHBOARD_ADDRESS + ) + get_bearer_token_mock.assert_called_once_with() + ray_get_job_submission_client_cluster_info_mock.assert_called_once_with( + address=tc.ClusterConstants._TEST_VERTEX_RAY_DASHBOARD_ADDRESS, + headers=tc.ClusterConstants._TEST_HEADERS, + ) diff --git a/tests/unit/vertex_ray/test_vertex_ray_client.py b/tests/unit/vertex_ray/test_vertex_ray_client.py index 1ea55e35c5..a421f84ab2 100644 --- a/tests/unit/vertex_ray/test_vertex_ray_client.py +++ b/tests/unit/vertex_ray/test_vertex_ray_client.py @@ -25,7 +25,7 @@ # -*- coding: utf-8 -*- _TEST_CLIENT_CONTEXT = ray.client_builder.ClientContext( - dashboard_url=tc.ClusterConstants._TEST_VERTEX_RAY_DASHBOARD_URL, + dashboard_url=tc.ClusterConstants._TEST_VERTEX_RAY_DASHBOARD_ADDRESS, python_version="MOCK_PYTHON_VERSION", ray_version="MOCK_RAY_VERSION", ray_commit="MOCK_RAY_COMMIT", @@ -37,7 +37,7 @@ _TEST_VERTEX_RAY_CLIENT_CONTEXT = vertex_ray.client_builder._VertexRayClientContext( persistent_resource_id="MOCK_PERSISTENT_RESOURCE_ID", ray_head_uris={ - "RAY_DASHBOARD_URI": tc.ClusterConstants._TEST_VERTEX_RAY_DASHBOARD_URL, + "RAY_DASHBOARD_URI": tc.ClusterConstants._TEST_VERTEX_RAY_DASHBOARD_ADDRESS, "RAY_HEAD_NODE_INTERNAL_IP": tc.ClusterConstants._TEST_VERTEX_RAY_HEAD_NODE_IP, }, ray_client_context=_TEST_CLIENT_CONTEXT, diff --git a/tests/unit/vertexai/test_any_serializer.py b/tests/unit/vertexai/test_any_serializer.py index d991f047eb..9ea44e9895 100644 --- a/tests/unit/vertexai/test_any_serializer.py +++ b/tests/unit/vertexai/test_any_serializer.py @@ -18,6 +18,7 @@ import pytest import cloudpickle +import logging import json import os from typing import Any @@ -31,11 +32,36 @@ from vertexai.preview._workflow.shared import constants import pandas as pd +import sklearn from sklearn.linear_model import LogisticRegression import tensorflow as tf from tensorflow import keras import torch +try: + # pylint: disable=g-import-not-at-top + import lightning.pytorch as pl +except ImportError: + pl = None + +try: + import bigframes as bf +except ImportError: + bf = None + +# lightning trainer and bigframes dataframe are not in this scheme since +# the test environment may not have these packages. 
+_TEST_SERIALIZATION_SCHEME = { + object: serializers.CloudPickleSerializer, + sklearn.base.BaseEstimator: serializers.SklearnEstimatorSerializer, + keras.models.Model: serializers.KerasModelSerializer, + keras.callbacks.History: serializers.KerasHistoryCallbackSerializer, + tf.data.Dataset: serializers.TFDatasetSerializer, + torch.nn.Module: serializers.TorchModelSerializer, + torch.utils.data.DataLoader: serializers.TorchDataLoaderSerializer, + pd.DataFrame: serializers.PandasDataSerializer, +} + @pytest.fixture def any_serializer_instance(): @@ -262,6 +288,41 @@ def forward(self, x): class TestAnySerializer: """Tests that AnySerializer is acting as 'controller' and router.""" + def test_any_serializer_register_predefined_serializers(self, caplog): + with caplog.at_level( + level=logging.DEBUG, logger="vertexai.serialization_engine" + ): + serializers_base.Serializer._instances = {} + serializer_instance = any_serializer.AnySerializer() + + if pl: + _TEST_SERIALIZATION_SCHEME[ + pl.Trainer + ] = serializers.LightningTrainerSerializer + else: + # Lightning trainer is not registered. + # Check the logs to make sure we tried to register them. + assert ( + f"Failed to register {serializers.LightningTrainerSerializer} due to" + in caplog.text + ) + + if bf: + _TEST_SERIALIZATION_SCHEME[ + bf.dataframe.DataFrame + ] = serializers.BigframeSerializer + else: + # Bigframes dataframe is not registered. + # Check the logs to make sure we tried to register them. + assert ( + f"Failed to register {serializers.BigframeSerializer} due to" + in caplog.text + ) + + assert ( + serializer_instance._serialization_scheme == _TEST_SERIALIZATION_SCHEME + ) + def test_any_serializer_global_metadata_created( self, mock_cloudpickle_serialize, any_serializer_instance, tmp_path ): diff --git a/vertexai/_model_garden/_model_garden_models.py b/vertexai/_model_garden/_model_garden_models.py index 0dc249e1a6..d1e650e0f3 100644 --- a/vertexai/_model_garden/_model_garden_models.py +++ b/vertexai/_model_garden/_model_garden_models.py @@ -35,9 +35,13 @@ _SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP = { "text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0", + "text-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0", "code-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0", + "code-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0", "chat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0", + "chat-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0", "codechat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0", + "codechat-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0", } _SDK_PRIVATE_PREVIEW_LAUNCH_STAGE = frozenset( diff --git a/vertexai/language_models/__init__.py b/vertexai/language_models/__init__.py index 8d16584ecb..c1991017e8 100644 --- a/vertexai/language_models/__init__.py +++ b/vertexai/language_models/__init__.py @@ -27,6 +27,7 @@ TextEmbeddingModel, TextGenerationModel, TextGenerationResponse, + GroundingSource, ) __all__ = [ @@ -42,4 +43,5 @@ "TextEmbeddingModel", "TextGenerationModel", "TextGenerationResponse", + "GroundingSource", ] diff --git a/vertexai/language_models/_language_models.py 
b/vertexai/language_models/_language_models.py index 2316108cef..04f7bcb548 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -14,6 +14,7 @@
 #
 """Classes for working with language models."""
+import abc
 import dataclasses
 from typing import (
 Any,
@@ -244,6 +245,10 @@ def tune_model(
 tuning_parameters[
 "enable_early_stopping"
 ] = eval_spec.enable_early_stopping
+ if eval_spec.enable_checkpoint_selection is not None:
+ tuning_parameters[
+ "enable_checkpoint_selection"
+ ] = eval_spec.enable_checkpoint_selection
 if eval_spec.tensorboard is not None:
 if isinstance(eval_spec.tensorboard, aiplatform.Tensorboard):
 if eval_spec.tensorboard.location != tuning_job_location:
@@ -676,6 +681,10 @@ class TuningEvaluationSpec:
 evaluation_interval tuning steps. Default: 20.
 enable_early_stopping: If True, the tuning may stop early before
 completing all the tuning steps. Requires evaluation_data.
+ enable_checkpoint_selection: If set to True, the tuning process returns
+ the best model checkpoint (based on model evaluation).
+ If set to False, the latest model checkpoint is returned.
+ If unset, the selection is only enabled for `*-bison@001` models.
 tensorboard: Vertex Tensorboard where to write the evaluation metrics.
 The Tensorboard must be in the same location as the tuning job.
 """
@@ -685,9 +694,133 @@ class TuningEvaluationSpec:
 evaluation_data: Optional[str] = None
 evaluation_interval: Optional[int] = None
 enable_early_stopping: Optional[bool] = None
+ enable_checkpoint_selection: Optional[bool] = None
 tensorboard: Optional[Union[aiplatform.Tensorboard, str]] = None
+
+class _GroundingSourceBase(abc.ABC):
+ """Base interface for grounding source dataclasses."""
+
+ @abc.abstractmethod
+ def _to_grounding_source_dict(self) -> Dict[str, Any]:
+ """Constructs the grounding source into a dictionary."""
+ pass
+
+
+@dataclasses.dataclass
+class WebSearch(_GroundingSourceBase):
+ """WebSearch represents a grounding source using public web search."""
+
+ _type: str = dataclasses.field(default="WEB", init=False, repr=False)
+
+ def _to_grounding_source_dict(self) -> Dict[str, Any]:
+ return {"type": self._type}
+
+
+@dataclasses.dataclass
+class VertexAISearch(_GroundingSourceBase):
+ """VertexAISearch represents a grounding source using a Vertex AI Search data store.
+ Attributes:
+ data_store_id: Data store ID of the Vertex AI Search data store.
+ location: GCP multi-region where your Vertex AI Search data store is set up. Possible values include `global`, `us`, and `eu`.
+ Learn more about Vertex AI Search locations here:
+ https://cloud.google.com/generative-ai-app-builder/docs/locations
+ project: The project where your Vertex AI Search is set up.
+ If not specified, the current project is assumed.
+ """ + + data_store_id: str + location: str + project: Optional[str] = None + _type: str = dataclasses.field(default="ENTERPRISE", init=False, repr=False) + + def _get_datastore_path(self) -> str: + _project = self.project or aiplatform_initializer.global_config.project + return ( + f"projects/{_project}/locations/{self.location}" + f"/collections/default_collection/dataStores/{self.data_store_id}" + ) + + def _to_grounding_source_dict(self) -> Dict[str, Any]: + return {"type": self._type, "enterpriseDatastore": self._get_datastore_path()} + + +@dataclasses.dataclass +class GroundingSource: + + WebSearch = WebSearch + VertexAISearch = VertexAISearch + + +@dataclasses.dataclass +class GroundingCitation: + """Citaion used from grounding. + Attributes: + start_index: Index in the prediction output where the citation starts + (inclusive). Must be >= 0 and < end_index. + end_index: Index in the prediction output where the citation ends + (exclusive). Must be > start_index and < len(output). + url: URL associated with this citation. If present, this URL links to the + webpage of the source of this citation. Possible URLs include news + websites, GitHub repos, etc. + title: Title associated with this citation. If present, it refers to the title + of the source of this citation. Possible titles include + news titles, book titles, etc. + license: License associated with this citation. If present, it refers to the + license of the source of this citation. Possible licenses include code + licenses, e.g., mit license. + publication_date: Publication date associated with this citation. If present, it refers to + the date at which the source of this citation was published. + Possible formats are YYYY, YYYY-MM, YYYY-MM-DD. + """ + + start_index: Optional[int] = None + end_index: Optional[int] = None + url: Optional[str] = None + title: Optional[str] = None + license: Optional[str] = None + publication_date: Optional[str] = None + + +@dataclasses.dataclass +class GroundingMetadata: + """Metadata for grounding. + Attributes: + citations: List of grounding citations. + """ + + citations: Optional[List[GroundingCitation]] = None + + def _parse_citation_from_dict( + self, citation_dict_camel: Dict[str, Any] + ) -> GroundingCitation: + _start_index = citation_dict_camel.get("startIndex") + _end_index = citation_dict_camel.get("endIndex") + if _start_index is not None: + _start_index = int(_start_index) + if _end_index is not None: + _end_index = int(_end_index) + _url = citation_dict_camel.get("url") + _title = citation_dict_camel.get("title") + _license = citation_dict_camel.get("license") + _publication_date = citation_dict_camel.get("publicationDate") + + return GroundingCitation( + start_index=_start_index, + end_index=_end_index, + url=_url, + title=_title, + license=_license, + publication_date=_publication_date, + ) + + def __init__(self, response: Optional[Dict[str, Any]] = {}): + self.citations = [ + self._parse_citation_from_dict(citation) + for citation in response.get("citations", []) + ] + + @dataclasses.dataclass class TextGenerationResponse: """TextGenerationResponse represents a response of a language model. @@ -697,6 +830,7 @@ class TextGenerationResponse: safety_attributes: Scores for safety attributes. Learn more about the safety attributes here: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions + grounding_metadata: Metadata for grounding. 
""" __module__ = "vertexai.language_models" @@ -705,12 +839,22 @@ class TextGenerationResponse: _prediction_response: Any is_blocked: bool = False safety_attributes: Dict[str, float] = dataclasses.field(default_factory=dict) + grounding_metadata: Optional[GroundingMetadata] = None def __repr__(self): if self.text: return self.text + # Falling back to the full representation + elif self.grounding_metadata is not None: + return ( + "TextGenerationResponse(" + f"text={self.text!r}" + f", is_blocked={self.is_blocked!r}" + f", safety_attributes={self.safety_attributes!r}" + f", grounding_metadata={self.grounding_metadata!r}" + ")" + ) else: - # Falling back to the full representation return ( "TextGenerationResponse(" f"text={self.text!r}" @@ -735,6 +879,7 @@ class MultiCandidateTextGenerationResponse(TextGenerationResponse): safety_attributes: Scores for safety attributes for the first candidate. Learn more about the safety attributes here: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions + grounding_metadata: Grounding metadata for the first candidate. candidates: The candidate responses. Usually contains a single candidate unless `candidate_count` is used. """ @@ -780,6 +925,9 @@ def predict( top_p: Optional[float] = None, stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, + grounding_source: Optional[ + Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + ] = None, ) -> "MultiCandidateTextGenerationResponse": """Gets model response for a single prompt. @@ -791,6 +939,7 @@ def predict( top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. stop_sequences: Customized stop sequences to stop the decoding process. candidate_count: Number of response candidates to return. + grounding_source: If specified, grounding feature will be enabled using the grounding source. Default: None. Returns: A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model. @@ -803,6 +952,7 @@ def predict( top_p=top_p, stop_sequences=stop_sequences, candidate_count=candidate_count, + grounding_source=grounding_source, ) prediction_response = self._endpoint.predict( @@ -824,6 +974,9 @@ async def predict_async( top_p: Optional[float] = None, stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, + grounding_source: Optional[ + Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + ] = None, ) -> "MultiCandidateTextGenerationResponse": """Asynchronously gets model response for a single prompt. @@ -835,6 +988,7 @@ async def predict_async( top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. stop_sequences: Customized stop sequences to stop the decoding process. candidate_count: Number of response candidates to return. + grounding_source: If specified, grounding feature will be enabled using the grounding source. Default: None. Returns: A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model. 
@@ -847,6 +1001,7 @@ async def predict_async( top_p=top_p, stop_sequences=stop_sequences, candidate_count=candidate_count, + grounding_source=grounding_source, ) prediction_response = await self._endpoint.predict_async( @@ -966,6 +1121,9 @@ def _create_text_generation_prediction_request( top_p: Optional[float] = None, stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, + grounding_source: Optional[ + Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + ] = None, ) -> "_PredictionRequest": """Prepares the text generation request for a single prompt. @@ -977,6 +1135,8 @@ def _create_text_generation_prediction_request( top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. stop_sequences: Customized stop sequences to stop the decoding process. candidate_count: Number of candidates to return. + grounding_source: If specified, the grounding feature will be enabled using the given grounding source. Default: None. + Returns: A `_PredictionRequest` object that contains prediction instance and parameters. @@ -1006,6 +1166,10 @@ if candidate_count is not None: prediction_parameters["candidateCount"] = candidate_count + if grounding_source is not None: + sources = [grounding_source._to_grounding_source_dict()] + prediction_parameters["groundingConfig"] = {"sources": sources} + return _PredictionRequest( instance=instance, parameters=prediction_parameters, @@ -1019,6 +1183,7 @@ def _parse_text_generation_model_response( """Converts the raw text_generation model response to `TextGenerationResponse`.""" prediction = prediction_response.predictions[prediction_idx] safety_attributes_dict = prediction.get("safetyAttributes", {}) + grounding_metadata_dict = prediction.get("groundingMetadata", {}) return TextGenerationResponse( text=prediction["content"], _prediction_response=prediction_response, @@ -1029,6 +1194,7 @@ safety_attributes_dict.get("scores") or [], ) ), + grounding_metadata=GroundingMetadata(grounding_metadata_dict), ) @@ -1054,6 +1220,7 @@ def _parse_text_generation_model_multi_candidate_response( _prediction_response=prediction_response, is_blocked=candidates[0].is_blocked, safety_attributes=candidates[0].safety_attributes, + grounding_metadata=candidates[0].grounding_metadata, candidates=candidates, ) @@ -1353,6 +1520,10 @@ def _prepare_text_embedding_request( Returns: A `_MultiInstancePredictionRequest` object. """ + if isinstance(texts, str) or not isinstance(texts, Sequence): + raise TypeError( + "The `texts` argument must be a sequence (e.g. a list of strings), not a single string." + ) instances = [] for text in texts: if isinstance(text, TextEmbeddingInput): @@ -1641,7 +1812,7 @@ def start_chat( ) -class CodeChatModel(_ChatModelBase): +class CodeChatModel(_ChatModelBase, _TunableChatModelMixin): """CodeChatModel represents a model that is capable of completing code.
Examples: @@ -1693,7 +1864,7 @@ def start_chat( ) -class _PreviewCodeChatModel(CodeChatModel, _TunableChatModelMixin): +class _PreviewCodeChatModel(CodeChatModel): __name__ = "CodeChatModel" __module__ = "vertexai.preview.language_models" @@ -1778,6 +1949,9 @@ def _prepare_request( top_p: Optional[float] = None, stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, + grounding_source: Optional[ + Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + ] = None, ) -> _PredictionRequest: """Prepares a request for the language model. @@ -1793,6 +1967,7 @@ Uses the value specified when calling `ChatModel.start_chat` by default. stop_sequences: Customized stop sequences to stop the decoding process. candidate_count: Number of candidates to return. + grounding_source: If specified, the grounding feature will be enabled using the given grounding source. Default: None. Returns: A `_PredictionRequest` object. @@ -1827,6 +2002,10 @@ if candidate_count is not None: prediction_parameters["candidateCount"] = candidate_count + if grounding_source is not None: + sources = [grounding_source._to_grounding_source_dict()] + prediction_parameters["groundingConfig"] = {"sources": sources} + message_structs = [] for past_message in self._message_history: message_structs.append( @@ -1878,8 +2057,12 @@ def _parse_chat_prediction_response( prediction = prediction_response.predictions[prediction_idx] candidate_count = len(prediction["candidates"]) candidates = [] + grounding_metadata_list = prediction.get("groundingMetadata") for candidate_idx in range(candidate_count): safety_attributes = prediction["safetyAttributes"][candidate_idx] + grounding_metadata_dict = {} + if grounding_metadata_list and grounding_metadata_list[candidate_idx]: + grounding_metadata_dict = grounding_metadata_list[candidate_idx] candidate_response = TextGenerationResponse( text=prediction["candidates"][candidate_idx]["content"], _prediction_response=prediction_response, @@ -1892,6 +2075,7 @@ safety_attributes.get("scores") or [], ) ), + grounding_metadata=GroundingMetadata(grounding_metadata_dict), ) candidates.append(candidate_response) return MultiCandidateTextGenerationResponse( @@ -1899,6 +2083,7 @@ _prediction_response=prediction_response, is_blocked=candidates[0].is_blocked, safety_attributes=candidates[0].safety_attributes, + grounding_metadata=candidates[0].grounding_metadata, candidates=candidates, ) @@ -1912,6 +2097,9 @@ def send_message( top_p: Optional[float] = None, stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, + grounding_source: Optional[ + Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + ] = None, ) -> "MultiCandidateTextGenerationResponse": """Sends message to the language model and gets a response. @@ -1927,6 +2115,7 @@ Uses the value specified when calling `ChatModel.start_chat` by default. stop_sequences: Customized stop sequences to stop the decoding process. candidate_count: Number of candidates to return. + grounding_source: If specified, the grounding feature will be enabled using the given grounding source. Default: None.
Returns: A `MultiCandidateTextGenerationResponse` object that contains the @@ -1940,6 +2129,7 @@ top_p=top_p, stop_sequences=stop_sequences, candidate_count=candidate_count, + grounding_source=grounding_source, ) prediction_response = self._model._endpoint.predict( @@ -1970,6 +2160,9 @@ async def send_message_async( top_p: Optional[float] = None, stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, + grounding_source: Optional[ + Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + ] = None, ) -> "MultiCandidateTextGenerationResponse": """Asynchronously sends message to the language model and gets a response. @@ -1985,6 +2178,7 @@ Uses the value specified when calling `ChatModel.start_chat` by default. stop_sequences: Customized stop sequences to stop the decoding process. candidate_count: Number of candidates to return. + grounding_source: If specified, the grounding feature will be enabled using the given grounding source. Default: None. Returns: A `MultiCandidateTextGenerationResponse` object that contains @@ -1998,6 +2192,7 @@ top_p=top_p, stop_sequences=stop_sequences, candidate_count=candidate_count, + grounding_source=grounding_source, ) prediction_response = await self._model._endpoint.predict_async( @@ -2690,9 +2885,7 @@ class CodeGenerationModel(_CodeGenerationModel, _TunableTextModelMixin): pass -class _PreviewCodeGenerationModel( - CodeGenerationModel, _CountTokensCodeGenerationMixin -): +class _PreviewCodeGenerationModel(CodeGenerationModel, _CountTokensCodeGenerationMixin): __name__ = "CodeGenerationModel" __module__ = "vertexai.preview.language_models" diff --git a/vertexai/preview/_workflow/serialization_engine/any_serializer.py b/vertexai/preview/_workflow/serialization_engine/any_serializer.py index f0b855f030..64ed527e75 100644 --- a/vertexai/preview/_workflow/serialization_engine/any_serializer.py +++ b/vertexai/preview/_workflow/serialization_engine/any_serializer.py @@ -18,6 +18,7 @@ """Defines the Serializer classes.""" import collections import dataclasses +import importlib import json import os import sys @@ -36,79 +37,9 @@ from packaging import requirements -# TODO(b/272263750): use the centralized module and usage pattern to guard these -# imports -# pylint: disable=g-import-not-at-top -try: - import pandas as pd - import bigframes as bf - - PandasData = pd.DataFrame - BigframesData = bf.dataframe.DataFrame -except ImportError: - pd = None - bf = None - PandasData = Any - BigframesData = Any - -try: - import pandas as pd - - PandasData = pd.DataFrame -except ImportError: - pd = None - PandasData = Any - -try: - import sklearn - - SklearnEstimator = sklearn.base.BaseEstimator -except ImportError: - sklearn = None - SklearnEstimator = Any - -try: - from tensorflow import keras - import tensorflow as tf - - KerasModel = keras.models.Model - TFDataset = tf.data.Dataset -except ImportError: - keras = None - tf = None - KerasModel = Any - TFDataset = Any - -try: - import torch - - TorchModel = torch.nn.Module - TorchDataLoader = torch.utils.data.DataLoader -except ImportError: - torch = None - TorchModel = Any - TorchDataLoader = Any - -try: - import lightning.pytorch as pl - - LightningTrainer = pl.Trainer -except ImportError: - pl = None - LightningTrainer = Any - T = TypeVar("T") -Types = Union[ - PandasData, - BigframesData, - SklearnEstimator, - KerasModel, - TorchModel, - LightningTrainer, -] - _LOGGER = base.Logger("vertexai.serialization_engine")
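
Before the serializer refactor continues below: the chat-side changes above mirror the text model, so `send_message` and `send_message_async` accept the same `grounding_source` argument and expose per-candidate grounding metadata on the response. A hedged sketch of how this could be called (model name and prompt are placeholders):

```python
from vertexai.language_models import ChatModel, GroundingSource

chat_model = ChatModel.from_pretrained("chat-bison")  # illustrative model name
chat = chat_model.start_chat()

response = chat.send_message(
    "Which cloud regions support Vertex AI Search?",
    grounding_source=GroundingSource.WebSearch(),
)
print(response.text)
# GroundingMetadata for the first candidate; each entry in
# response.candidates carries its own grounding_metadata as well.
print(response.grounding_metadata)
```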
SERIALIZATION_METADATA_SERIALIZER_KEY = "serializer" @@ -119,6 +50,25 @@ _LIGHTNING_ROOT_DIR = "/vertex_lightning_root_dir/" _JSONABLE_TYPES = Union[int, float, bytes, bool, str, None] +# This is a collection of all the predefined serializers and the fully qualified +# class names that these serializers are intended to be used on. +_PREDEFINED_SERIALIZERS = frozenset( + [ + ("sklearn.base.BaseEstimator", serializers.SklearnEstimatorSerializer), + ("tensorflow.keras.models.Model", serializers.KerasModelSerializer), + ( + "tensorflow.keras.callbacks.History", + serializers.KerasHistoryCallbackSerializer, + ), + ("tensorflow.data.Dataset", serializers.TFDatasetSerializer), + ("torch.nn.Module", serializers.TorchModelSerializer), + ("torch.utils.data.DataLoader", serializers.TorchDataLoaderSerializer), + ("lightning.pytorch.Trainer", serializers.LightningTrainerSerializer), + ("bigframes.dataframe.DataFrame", serializers.BigframeSerializer), + ("pandas.DataFrame", serializers.PandasDataSerializer), + ] +) + def get_arg_path_from_file_gcs_uri(gcs_uri: str, arg_name: str) -> str: """Gets the argument gcs path from the to-be-serialized object's gcs uri.""" @@ -321,32 +271,9 @@ def __init__(self): super().__init__() # Register with default serializers AnySerializer._register(object, serializers.CloudPickleSerializer) - if sklearn: - AnySerializer._register( - sklearn.base.BaseEstimator, serializers.SklearnEstimatorSerializer - ) - if keras: - AnySerializer._register( - keras.models.Model, serializers.KerasModelSerializer - ) - AnySerializer._register( - keras.callbacks.History, serializers.KerasHistoryCallbackSerializer - ) - if tf: - AnySerializer._register(tf.data.Dataset, serializers.TFDatasetSerializer) - if torch: - AnySerializer._register(torch.nn.Module, serializers.TorchModelSerializer) - AnySerializer._register( - torch.utils.data.DataLoader, serializers.TorchDataLoaderSerializer - ) - if pl: - AnySerializer._register(pl.Trainer, serializers.LightningTrainerSerializer) - if bf: - AnySerializer._register( - bf.dataframe.DataFrame, serializers.BigframeSerializer - ) - if pd: - AnySerializer._register(pd.DataFrame, serializers.PandasDataSerializer) + + for args in _PREDEFINED_SERIALIZERS: + AnySerializer._register_predefined_serializer(*args) @classmethod def _get_custom_serializer(cls, type_cls): @@ -356,6 +283,24 @@ def _get_custom_serializer(cls, type_cls): def _get_predefined_serializer(cls, type_cls): return cls._serialization_scheme.get(type_cls) + @classmethod + def _register_predefined_serializer( + cls, + full_class_name: str, + serializer: serializers_base.Serializer, + ): + """Registers a predefined serializer to AnySerializer.""" + try: + module_name, class_name = full_class_name.rsplit(".", 1) + module = importlib.import_module(module_name) + to_serialize_class = getattr(module, class_name) + + AnySerializer._register(to_serialize_class, serializer) + _LOGGER.debug(f"Successfully registered {serializer}") + + except Exception as e: + _LOGGER.debug(f"Failed to register {serializer} due to: {e}") + def _gcs_path_in_metadata(self, obj) -> Optional[str]: """Checks if an object has been (de-)serialized before.""" for key, value in self._metadata.serialized.items(): @@ -631,9 +576,3 @@ def deserialize(self, gcs_path): any_serializer.register_custom( to_serialize_type=to_serialize_type, serializer_cls=serializer_cls ) - - -try: - _any_serializer = AnySerializer() -except ImportError: - pass diff --git a/vertexai/preview/_workflow/serialization_engine/serializers.py 
b/vertexai/preview/_workflow/serialization_engine/serializers.py index db3ff1dab1..04258ae893 100644 --- a/vertexai/preview/_workflow/serialization_engine/serializers.py +++ b/vertexai/preview/_workflow/serialization_engine/serializers.py @@ -41,6 +41,7 @@ from packaging import version try: + # pylint: disable=g-import-not-at-top import cloudpickle except ImportError: cloudpickle = None @@ -49,7 +50,6 @@ # TODO(b/272263750): use the centralized module and usage pattern to guard these # imports -# pylint: disable=g-import-not-at-top try: import pandas as pd import bigframes as bf @@ -74,14 +74,6 @@ pq = None PandasData = Any -try: - import sklearn - - SklearnEstimator = sklearn.base.BaseEstimator -except ImportError: - sklearn = None - SklearnEstimator = Any - try: from tensorflow import keras import tensorflow as tf @@ -106,23 +98,6 @@ TorchDataLoader = Any TorchTensor = Any -try: - import lightning.pytorch as pl - - LightningTrainer = pl.Trainer -except ImportError: - pl = None - LightningTrainer = Any - - -Types = Union[ - PandasData, - BigframesData, - SklearnEstimator, - KerasModel, - TorchModel, - LightningTrainer, -] _LIGHTNING_ROOT_DIR = "/vertex_lightning_root_dir/" SERIALIZATION_METADATA_FILENAME = "serialization_metadata" @@ -420,7 +395,12 @@ class SklearnEstimatorSerializer(serializers_base.Serializer): serializers_base.SerializationMetadata(serializer="SklearnEstimatorSerializer") ) - def serialize(self, to_serialize: SklearnEstimator, gcs_path: str, **kwargs) -> str: + def serialize( + self, + to_serialize: "sklearn.base.BaseEstimator", # noqa: F821 + gcs_path: str, + **kwargs, + ) -> str: """Serializes a sklearn estimator to a gcs path. Args: @@ -447,7 +427,9 @@ def serialize(self, to_serialize: SklearnEstimator, gcs_path: str, **kwargs) -> return gcs_path - def deserialize(self, serialized_gcs_path: str, **kwargs) -> SklearnEstimator: + def deserialize( + self, serialized_gcs_path: str, **kwargs + ) -> "sklearn.base.BaseEstimator": # noqa: F821 """Deserialize a sklearn estimator given the gcs file name. Args: @@ -572,7 +554,9 @@ class LightningTrainerSerializer(serializers_base.Serializer): serializers_base.SerializationMetadata(serializer="LightningTrainerSerializer") ) - def _serialize_to_local(self, to_serialize: LightningTrainer, path: str): + def _serialize_to_local( + self, to_serialize: "lightning.pytorch.Trainer", path: str # noqa: F821 + ): """Serializes a lightning.pytorch.Trainer to a local path. Args: @@ -626,7 +610,12 @@ def _serialize_to_local(self, to_serialize: LightningTrainer, path: str): dirs_exist_ok=True, ) - def serialize(self, to_serialize: LightningTrainer, gcs_path: str, **kwargs) -> str: + def serialize( + self, + to_serialize: "lightning.pytorch.Trainer", # noqa: F821 + gcs_path: str, + **kwargs, + ) -> str: """Serializes a lightning.pytorch.Trainer to a gcs path. Args: @@ -660,7 +649,9 @@ def serialize(self, to_serialize: LightningTrainer, gcs_path: str, **kwargs) -> return gcs_path - def _deserialize_from_local(self, path: str) -> LightningTrainer: + def _deserialize_from_local( + self, path: str + ) -> "lightning.pytorch.Trainer": # noqa: F821 """Deserialize a lightning.pytorch.Trainer given a local path. 
Args: @@ -734,7 +725,9 @@ def _deserialize_from_local(self, path: str) -> LightningTrainer: return trainer - def deserialize(self, serialized_gcs_path: str, **kwargs) -> LightningTrainer: + def deserialize( + self, serialized_gcs_path: str, **kwargs + ) -> "lightning.pytorch.Trainer": # noqa: F821 """Deserialize a lightning.pytorch.Trainer given the gcs path. Args: @@ -1154,10 +1147,14 @@ def serialize( elif detected_framework == "torch": # Install using custom_commands to avoid numpy dependency conflict BigframeSerializer._metadata.custom_commands.append("pip install torchdata") - BigframeSerializer._metadata.custom_commands.append("pip install torcharrow") + BigframeSerializer._metadata.custom_commands.append( + "pip install torcharrow" + ) elif detected_framework == "tensorflow": tensorflow_io_dep = "tensorflow-io==" + self._get_tfio_verison() - tensorflow_io_gcs_fs_dep = "tensorflow-io-gcs-filesystem==" + self._get_tfio_verison() + tensorflow_io_gcs_fs_dep = ( + "tensorflow-io-gcs-filesystem==" + self._get_tfio_verison() + ) BigframeSerializer._metadata.dependencies.append(tensorflow_io_dep) BigframeSerializer._metadata.dependencies.append(tensorflow_io_gcs_fs_dep)
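
The any_serializer.py refactor above replaces a wall of guarded imports with a `_PREDEFINED_SERIALIZERS` table of fully qualified class names that are resolved lazily via `importlib`, downgrading a missing optional dependency to a debug log. A minimal, self-contained sketch of that pattern (the registry and serializer names here are illustrative, not the SDK's internals):

```python
import importlib
import logging

logger = logging.getLogger("vertexai.serialization_engine")

# Hypothetical registry mapping resolved classes to serializer names.
_SERIALIZATION_SCHEME = {}


def register_predefined_serializer(full_class_name: str, serializer) -> None:
    """Resolves a class from its fully qualified name, then registers a serializer.

    Optional dependencies (torch, lightning, bigframes, ...) are imported only
    here, at registration time, so an absent library becomes a debug log entry
    instead of an ImportError at module load.
    """
    try:
        module_name, class_name = full_class_name.rsplit(".", 1)
        cls = getattr(importlib.import_module(module_name), class_name)
        _SERIALIZATION_SCHEME[cls] = serializer
        logger.debug(f"Successfully registered {serializer}")
    except Exception as e:
        logger.debug(f"Failed to register {serializer} due to: {e}")


# Registration silently no-ops when the library is absent.
register_predefined_serializer("pandas.DataFrame", "PandasDataSerializer")
register_predefined_serializer("lightning.pytorch.Trainer", "LightningTrainerSerializer")
```

The companion change in serializers.py points the same direction: concrete optional types in annotations become string forward references, so nothing outside the remaining try/except blocks needs the optional library at import time.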