Skip to content

Commit 3b19fff

Browse files
authored
feat: add cancel method to pipeline client (#488)
1 parent 74627ba commit 3b19fff

File tree

2 files changed

+81
-9
lines changed

2 files changed

+81
-9
lines changed

google/cloud/aiplatform/pipeline_jobs.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def state(self) -> Optional[gca_pipeline_state_v1beta1.PipelineState]:
264264
@property
265265
def _has_run(self) -> bool:
266266
"""Helper property to check if this pipeline job has been run."""
267-
return bool(self._gca_resource.name)
267+
return bool(self._gca_resource.create_time)
268268

269269
@property
270270
def has_failed(self) -> bool:
@@ -310,3 +310,19 @@ def _block_until_complete(self):
310310
log_wait = min(log_wait * multiplier, max_wait)
311311
previous_time = current_time
312312
time.sleep(wait)
313+
314+
def cancel(self) -> None:
315+
"""Starts asynchronous cancellation on the PipelineJob. The server
316+
makes a best effort to cancel the job, but success is not guaranteed.
317+
On successful cancellation, the PipelineJob is not deleted; instead it
318+
becomes a job with state set to `CANCELLED`.
319+
320+
Raises:
321+
RuntimeError: If this PipelineJob has not started running.
322+
"""
323+
if not self._has_run:
324+
raise RuntimeError(
325+
"This PipelineJob has not been launched, use the `run()` method "
326+
"to start. `cancel()` can only be called on a job that is running."
327+
)
328+
self.api_client.cancel_pipeline_job(name=self.resource_name)

tests/unit/aiplatform/test_pipeline_jobs.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@
2121
from unittest import mock
2222
from importlib import reload
2323
from unittest.mock import patch
24+
from datetime import datetime
2425

2526
from google.auth import credentials as auth_credentials
26-
2727
from google.cloud import aiplatform
2828
from google.cloud import storage
29-
3029
from google.cloud.aiplatform import pipeline_jobs
3130
from google.cloud.aiplatform import initializer
3231
from google.protobuf import json_format
@@ -72,6 +71,7 @@
7271
_TEST_PIPELINE_RESOURCE_NAME = (
7372
f"{_TEST_PARENT}/fakePipelineJobs/{_TEST_PIPELINE_JOB_ID}"
7473
)
74+
_TEST_PIPELINE_CREATE_TIME = datetime.now()
7575

7676

7777
@pytest.fixture
@@ -82,13 +82,16 @@ def mock_pipeline_service_create():
8282
mock_create_pipeline_job.return_value = gca_pipeline_job_v1beta1.PipelineJob(
8383
name=_TEST_PIPELINE_JOB_NAME,
8484
state=gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED,
85+
create_time=_TEST_PIPELINE_CREATE_TIME,
8586
)
8687
yield mock_create_pipeline_job
8788

8889

8990
def make_pipeline_job(state):
9091
return gca_pipeline_job_v1beta1.PipelineJob(
91-
name=_TEST_PIPELINE_JOB_NAME, state=state,
92+
name=_TEST_PIPELINE_JOB_NAME,
93+
state=state,
94+
create_time=_TEST_PIPELINE_CREATE_TIME,
9295
)
9396

9497

@@ -130,6 +133,14 @@ def mock_pipeline_service_get():
130133
yield mock_get_pipeline_job
131134

132135

136+
@pytest.fixture
137+
def mock_pipeline_service_cancel():
138+
with mock.patch.object(
139+
pipeline_service_client_v1beta1.PipelineServiceClient, "cancel_pipeline_job"
140+
) as mock_cancel_pipeline_job:
141+
yield mock_cancel_pipeline_job
142+
143+
133144
@pytest.fixture
134145
def mock_load_json():
135146
with patch.object(storage.Blob, "download_as_bytes") as mock_load_json:
@@ -155,13 +166,10 @@ def setup_method(self):
155166
def teardown_method(self):
156167
initializer.global_pool.shutdown(wait=True)
157168

169+
@pytest.mark.usefixtures("mock_load_json")
158170
@pytest.mark.parametrize("sync", [True, False])
159171
def test_run_call_pipeline_service_create(
160-
self,
161-
mock_pipeline_service_create,
162-
mock_pipeline_service_get,
163-
mock_load_json,
164-
sync,
172+
self, mock_pipeline_service_create, mock_pipeline_service_get, sync,
165173
):
166174
aiplatform.init(
167175
project=_TEST_PROJECT,
@@ -213,3 +221,51 @@ def test_run_call_pipeline_service_create(
213221
assert job._gca_resource == make_pipeline_job(
214222
gca_pipeline_state_v1beta1.PipelineState.PIPELINE_STATE_SUCCEEDED
215223
)
224+
225+
@pytest.mark.usefixtures(
226+
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
227+
)
228+
def test_cancel_pipeline_job(
229+
self, mock_pipeline_service_cancel,
230+
):
231+
aiplatform.init(
232+
project=_TEST_PROJECT,
233+
staging_bucket=_TEST_GCS_BUCKET_NAME,
234+
credentials=_TEST_CREDENTIALS,
235+
)
236+
237+
job = pipeline_jobs.PipelineJob(
238+
display_name=_TEST_PIPELINE_JOB_ID,
239+
template_path=_TEST_TEMPLATE_PATH,
240+
job_id=_TEST_PIPELINE_JOB_ID,
241+
)
242+
243+
job.run()
244+
job.cancel()
245+
246+
mock_pipeline_service_cancel.assert_called_once_with(
247+
name=_TEST_PIPELINE_JOB_NAME
248+
)
249+
250+
@pytest.mark.usefixtures(
251+
"mock_pipeline_service_create", "mock_pipeline_service_get", "mock_load_json",
252+
)
253+
def test_cancel_pipeline_job_without_running(
254+
self, mock_pipeline_service_cancel,
255+
):
256+
aiplatform.init(
257+
project=_TEST_PROJECT,
258+
staging_bucket=_TEST_GCS_BUCKET_NAME,
259+
credentials=_TEST_CREDENTIALS,
260+
)
261+
262+
job = pipeline_jobs.PipelineJob(
263+
display_name=_TEST_PIPELINE_JOB_ID,
264+
template_path=_TEST_TEMPLATE_PATH,
265+
job_id=_TEST_PIPELINE_JOB_ID,
266+
)
267+
268+
with pytest.raises(RuntimeError) as e:
269+
job.cancel()
270+
271+
assert e.match(regexp=r"PipelineJob has not been launched")

0 commit comments

Comments
 (0)