Skip to content

Commit 37875b5

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Enable Ray Job submission without VPC peering
PiperOrigin-RevId: 641037130
1 parent 6592042 commit 37875b5

File tree

2 files changed

+56
-3
lines changed

2 files changed

+56
-3
lines changed

google/cloud/aiplatform/vertex_ray/dashboard_sdk.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,27 @@ def get_job_submission_client_cluster_info(
6969
"RAY_HEAD_NODE_INTERNAL_IP", None
7070
)
7171
if head_address is None:
72-
raise RuntimeError(
73-
"[Ray on Vertex AI]: Unable to obtain a response from the backend."
72+
# No peering. Try to get the dashboard address.
73+
dashboard_address = response.resource_runtime.access_uris.get(
74+
"RAY_DASHBOARD_URI", None
7475
)
75-
76+
if dashboard_address is None:
77+
raise RuntimeError(
78+
"[Ray on Vertex AI]: Unable to obtain a response from the backend."
79+
)
80+
if _validation_utils.valid_dashboard_address(dashboard_address):
81+
bearer_token = _validation_utils.get_bearer_token()
82+
if kwargs.get("headers", None) is None:
83+
kwargs["headers"] = {
84+
"Content-Type": "application/json",
85+
"Authorization": "Bearer {}".format(bearer_token),
86+
}
87+
return oss_dashboard_sdk.get_job_submission_client_cluster_info(
88+
address=dashboard_address,
89+
_use_tls=True,
90+
*args,
91+
**kwargs,
92+
)
7693
# Assume that head node internal IP in a form of xxx.xxx.xxx.xxx:10001.
7794
# Ray-on-Vertex cluster serves the Dashboard at port 8888 instead of
7895
# the default 8251.

tests/unit/vertex_ray/test_dashboard_sdk.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ def get_persistent_resource_status_running_mock():
4444
yield get_persistent_resource
4545

4646

47+
@pytest.fixture
48+
def get_persistent_resource_status_running_byosa_public_mock():
49+
# Cluster with BYOSA and no peering
50+
with mock.patch.object(
51+
vertex_ray.util._gapic_utils, "get_persistent_resource"
52+
) as get_persistent_resource:
53+
get_persistent_resource.return_value = (
54+
tc.ClusterConstants.TEST_RESPONSE_RUNNING_1_POOL_BYOSA
55+
)
56+
yield get_persistent_resource
57+
58+
4759
@pytest.fixture
4860
def get_bearer_token_mock():
4961
with mock.patch.object(
@@ -112,3 +124,27 @@ def test_job_submission_client_cluster_info_with_dashboard_address(
112124
_use_tls=True,
113125
headers=tc.ClusterConstants.TEST_HEADERS,
114126
)
127+
128+
@pytest.mark.usefixtures(
129+
"get_persistent_resource_status_running_byosa_public_mock", "google_auth_mock"
130+
)
131+
def test_job_submission_client_cluster_info_with_cluster_name_byosa_public(
132+
self,
133+
ray_get_job_submission_client_cluster_info_mock,
134+
get_bearer_token_mock,
135+
get_project_number_mock,
136+
):
137+
aiplatform.init(project=tc.ProjectConstants.TEST_GCP_PROJECT_ID)
138+
139+
vertex_ray.get_job_submission_client_cluster_info(
140+
tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID
141+
)
142+
get_project_number_mock.assert_called_once_with(
143+
name="projects/{}".format(tc.ProjectConstants.TEST_GCP_PROJECT_ID)
144+
)
145+
get_bearer_token_mock.assert_called_once_with()
146+
ray_get_job_submission_client_cluster_info_mock.assert_called_once_with(
147+
address=tc.ClusterConstants.TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
148+
_use_tls=True,
149+
headers=tc.ClusterConstants.TEST_HEADERS,
150+
)

0 commit comments

Comments
 (0)