Skip to content

Commit 674227d

Browse files
authored
feat: parse project location when passed full resource name to get apis (#297)
1 parent 10b89e2 commit 674227d

File tree

7 files changed

+156
-9
lines changed

7 files changed

+156
-9
lines changed

google/cloud/aiplatform/base.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import functools
2121
import inspect
2222
import threading
23-
from typing import Any, Callable, Dict, Optional, Sequence, Type, Union
23+
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
2424

2525
import proto
2626

@@ -266,6 +266,7 @@ def __init__(
266266
project: Optional[str] = None,
267267
location: Optional[str] = None,
268268
credentials: Optional[auth_credentials.Credentials] = None,
269+
resource_name: Optional[str] = None,
269270
):
270271
"""Initializes class with project, location, and api_client.
271272
@@ -274,8 +275,14 @@ def __init__(
274275
location(str): The location of the resource noun.
275276
credentials(google.auth.crendentials.Crendentials): Optional custom
276277
credentials to use when accessing interacting with resource noun.
278+
resource_name(str): A fully-qualified resource name or ID.
277279
"""
278280

281+
if resource_name:
282+
project, location = self._get_and_validate_project_location(
283+
resource_name=resource_name, project=project, location=location
284+
)
285+
279286
self.project = project or initializer.global_config.project
280287
self.location = location or initializer.global_config.location
281288
self.credentials = credentials or initializer.global_config.credentials
@@ -306,6 +313,41 @@ def _instantiate_client(
306313
prediction_client=cls._is_client_prediction_client,
307314
)
308315

316+
def _get_and_validate_project_location(
317+
self,
318+
resource_name: str,
319+
project: Optional[str] = None,
320+
location: Optional[str] = None,
321+
) -> Tuple:
322+
323+
"""Validate the project and location for the resource.
324+
325+
Args:
326+
resource_name(str): Required. A fully-qualified resource name or ID.
327+
project(str): Project of the resource noun.
328+
location(str): The location of the resource noun.
329+
330+
Raises:
331+
RuntimeError if location is different from resource location
332+
"""
333+
334+
if not project and not location:
335+
return project, location
336+
337+
fields = utils.extract_fields_from_resource_name(
338+
resource_name, self._resource_noun
339+
)
340+
if not fields:
341+
return project, location
342+
343+
if location and fields.location != location:
344+
raise RuntimeError(
345+
f"location {location} is provided, but different from "
346+
f"the resource location {fields.location}"
347+
)
348+
349+
return fields.project, fields.location
350+
309351
def _get_gca_resource(self, resource_name: str) -> proto.Message:
310352
"""Returns GAPIC service representation of client class resource."""
311353
"""
@@ -493,6 +535,7 @@ def __init__(
493535
project: Optional[str] = None,
494536
location: Optional[str] = None,
495537
credentials: Optional[auth_credentials.Credentials] = None,
538+
resource_name: Optional[str] = None,
496539
):
497540
"""Initializes class with project, location, and api_client.
498541
@@ -502,9 +545,14 @@ def __init__(
502545
credentials(google.auth.crendentials.Crendentials):
503546
Optional. custom credentials to use when accessing interacting with
504547
resource noun.
548+
resource_name(str): A fully-qualified resource name or ID.
505549
"""
506550
AiPlatformResourceNoun.__init__(
507-
self, project=project, location=location, credentials=credentials
551+
self,
552+
project=project,
553+
location=location,
554+
credentials=credentials,
555+
resource_name=resource_name,
508556
)
509557
FutureManager.__init__(self)
510558

@@ -514,6 +562,7 @@ def _empty_constructor(
514562
project: Optional[str] = None,
515563
location: Optional[str] = None,
516564
credentials: Optional[auth_credentials.Credentials] = None,
565+
resource_name: Optional[str] = None,
517566
) -> "AiPlatformResourceNounWithFutureManager":
518567
"""Initializes with all attributes set to None.
519568
@@ -526,11 +575,18 @@ def _empty_constructor(
526575
credentials(google.auth.crendentials.Crendentials):
527576
Optional. custom credentials to use when accessing interacting with
528577
resource noun.
578+
resource_name(str): A fully-qualified resource name or ID.
529579
Returns:
530580
An instance of this class with attributes set to None.
531581
"""
532582
self = cls.__new__(cls)
533-
AiPlatformResourceNoun.__init__(self, project, location, credentials)
583+
AiPlatformResourceNoun.__init__(
584+
self,
585+
project=project,
586+
location=location,
587+
credentials=credentials,
588+
resource_name=resource_name,
589+
)
534590
FutureManager.__init__(self)
535591
self._gca_resource = None
536592
return self

google/cloud/aiplatform/datasets/dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,10 @@ def __init__(
7171
"""
7272

7373
super().__init__(
74-
project=project, location=location, credentials=credentials,
74+
project=project,
75+
location=location,
76+
credentials=credentials,
77+
resource_name=dataset_name,
7578
)
7679
self._gca_resource = self._get_gca_resource(resource_name=dataset_name)
7780
self._validate_metadata_schema_uri()

google/cloud/aiplatform/jobs.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,12 @@ def __init__(
107107
Custom credentials to use. If not set, credentials set in
108108
aiplatform.init will be used.
109109
"""
110-
super().__init__(project=project, location=location, credentials=credentials)
110+
super().__init__(
111+
project=project,
112+
location=location,
113+
credentials=credentials,
114+
resource_name=job_name,
115+
)
111116
self._gca_resource = self._get_gca_resource(resource_name=job_name)
112117

113118
@property

google/cloud/aiplatform/models.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,12 @@ def __init__(
9999
credentials set in aiplatform.init.
100100
"""
101101

102-
super().__init__(project=project, location=location, credentials=credentials)
102+
super().__init__(
103+
project=project,
104+
location=location,
105+
credentials=credentials,
106+
resource_name=endpoint_name,
107+
)
103108
self._gca_resource = self._get_gca_resource(resource_name=endpoint_name)
104109
self._prediction_client = self._instantiate_prediction_client(
105110
location=location or initializer.global_config.location,
@@ -1144,7 +1149,12 @@ def __init__(
11441149
credentials set in aiplatform.init will be used.
11451150
"""
11461151

1147-
super().__init__(project=project, location=location, credentials=credentials)
1152+
super().__init__(
1153+
project=project,
1154+
location=location,
1155+
credentials=credentials,
1156+
resource_name=model_name,
1157+
)
11481158
self._gca_resource = self._get_gca_resource(resource_name=model_name)
11491159

11501160
# TODO(b/170979552) Add support for predict schemata

google/cloud/aiplatform/training_jobs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,10 @@ def get(
180180
# These parameters won't be used as user can not run the job again.
181181
# If they try, an exception will be raised.
182182
self = cls._empty_constructor(
183-
project=project, location=location, credentials=credentials
183+
project=project,
184+
location=location,
185+
credentials=credentials,
186+
resource_name=resource_name,
184187
)
185188

186189
self._gca_resource = self._get_gca_resource(resource_name=resource_name)

tests/unit/aiplatform/test_datasets.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
_TEST_PROJECT = "test-project"
4949
_TEST_LOCATION = "us-central1"
5050
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
51+
_TEST_ALT_PROJECT = "test-project_alt"
5152

5253
_TEST_ALT_LOCATION = "europe-west4"
5354
_TEST_INVALID_LOCATION = "us-central2"
@@ -259,6 +260,38 @@ def test_init_dataset(self, get_dataset_mock):
259260
datasets.Dataset(dataset_name=_TEST_NAME)
260261
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
261262

263+
def test_init_dataset_with_id_only_with_project_and_location(
264+
self, get_dataset_mock
265+
):
266+
aiplatform.init(project=_TEST_PROJECT)
267+
datasets.Dataset(
268+
dataset_name=_TEST_ID, project=_TEST_PROJECT, location=_TEST_LOCATION
269+
)
270+
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
271+
272+
def test_init_dataset_with_project_and_location(self, get_dataset_mock):
273+
aiplatform.init(project=_TEST_PROJECT)
274+
datasets.Dataset(
275+
dataset_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_LOCATION
276+
)
277+
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
278+
279+
def test_init_dataset_with_alt_project_and_location(self, get_dataset_mock):
280+
aiplatform.init(project=_TEST_PROJECT)
281+
datasets.Dataset(
282+
dataset_name=_TEST_NAME, project=_TEST_ALT_PROJECT, location=_TEST_LOCATION
283+
)
284+
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
285+
286+
def test_init_dataset_with_project_and_alt_location(self):
287+
aiplatform.init(project=_TEST_PROJECT)
288+
with pytest.raises(RuntimeError):
289+
datasets.Dataset(
290+
dataset_name=_TEST_NAME,
291+
project=_TEST_PROJECT,
292+
location=_TEST_ALT_LOCATION,
293+
)
294+
262295
def test_init_dataset_with_id_only(self, get_dataset_mock):
263296
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
264297
datasets.Dataset(dataset_name=_TEST_ID)

tests/unit/aiplatform/test_training_jobs.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@
6565
_TEST_GCS_PATH_WITH_TRAILING_SLASH = f"{_TEST_GCS_PATH}/"
6666
_TEST_LOCAL_SCRIPT_FILE_NAME = "____test____script.py"
6767
_TEST_LOCAL_SCRIPT_FILE_PATH = f"path/to/{_TEST_LOCAL_SCRIPT_FILE_NAME}"
68-
_TEST_PROJECT = "test-project"
6968
_TEST_PYTHON_SOURCE = """
7069
print('hello world')
7170
"""
@@ -107,6 +106,8 @@
107106
_TEST_NAME = (
108107
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/trainingPipelines/{_TEST_ID}"
109108
)
109+
_TEST_ALT_PROJECT = "test-project-alt"
110+
_TEST_ALT_LOCATION = "europe-west4"
110111

111112
_TEST_MODEL_INSTANCE_SCHEMA_URI = "instance_schema_uri.yaml"
112113
_TEST_MODEL_PARAMETERS_SCHEMA_URI = "parameters_schema_uri.yaml"
@@ -1381,6 +1382,42 @@ def test_get_training_job_with_id_only(self, get_training_job_custom_mock):
13811382
training_jobs.CustomTrainingJob.get(resource_name=_TEST_ID)
13821383
get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME)
13831384

1385+
def test_get_training_job_with_id_only_with_project_and_location(
1386+
self, get_training_job_custom_mock
1387+
):
1388+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
1389+
training_jobs.CustomTrainingJob.get(
1390+
resource_name=_TEST_ID, project=_TEST_PROJECT, location=_TEST_LOCATION
1391+
)
1392+
get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME)
1393+
1394+
def test_get_training_job_with_project_and_location(
1395+
self, get_training_job_custom_mock
1396+
):
1397+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
1398+
training_jobs.CustomTrainingJob.get(
1399+
resource_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_LOCATION
1400+
)
1401+
get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME)
1402+
1403+
def test_get_training_job_with_alt_project_and_location(
1404+
self, get_training_job_custom_mock
1405+
):
1406+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
1407+
training_jobs.CustomTrainingJob.get(
1408+
resource_name=_TEST_NAME, project=_TEST_ALT_PROJECT, location=_TEST_LOCATION
1409+
)
1410+
get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME)
1411+
1412+
def test_get_training_job_with_project_and_alt_location(self):
1413+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
1414+
with pytest.raises(RuntimeError):
1415+
training_jobs.CustomTrainingJob.get(
1416+
resource_name=_TEST_NAME,
1417+
project=_TEST_PROJECT,
1418+
location=_TEST_ALT_LOCATION,
1419+
)
1420+
13841421
@pytest.mark.parametrize("sync", [True, False])
13851422
def test_run_call_pipeline_service_create_with_nontabular_dataset(
13861423
self,

0 commit comments

Comments
 (0)