2121from unittest import mock
2222from importlib import reload
2323from unittest .mock import patch
24+ from datetime import datetime
2425
2526from google .auth import credentials as auth_credentials
26-
2727from google .cloud import aiplatform
2828from google .cloud import storage
29-
3029from google .cloud .aiplatform import pipeline_jobs
3130from google .cloud .aiplatform import initializer
3231from google .protobuf import json_format
7271_TEST_PIPELINE_RESOURCE_NAME = (
7372 f"{ _TEST_PARENT } /fakePipelineJobs/{ _TEST_PIPELINE_JOB_ID } "
7473)
74+ _TEST_PIPELINE_CREATE_TIME = datetime .now ()
7575
7676
7777@pytest .fixture
@@ -82,13 +82,16 @@ def mock_pipeline_service_create():
8282 mock_create_pipeline_job .return_value = gca_pipeline_job_v1beta1 .PipelineJob (
8383 name = _TEST_PIPELINE_JOB_NAME ,
8484 state = gca_pipeline_state_v1beta1 .PipelineState .PIPELINE_STATE_SUCCEEDED ,
85+ create_time = _TEST_PIPELINE_CREATE_TIME ,
8586 )
8687 yield mock_create_pipeline_job
8788
8889
8990def make_pipeline_job (state ):
9091 return gca_pipeline_job_v1beta1 .PipelineJob (
91- name = _TEST_PIPELINE_JOB_NAME , state = state ,
92+ name = _TEST_PIPELINE_JOB_NAME ,
93+ state = state ,
94+ create_time = _TEST_PIPELINE_CREATE_TIME ,
9295 )
9396
9497
@@ -130,6 +133,14 @@ def mock_pipeline_service_get():
130133 yield mock_get_pipeline_job
131134
132135
136+ @pytest .fixture
137+ def mock_pipeline_service_cancel ():
138+ with mock .patch .object (
139+ pipeline_service_client_v1beta1 .PipelineServiceClient , "cancel_pipeline_job"
140+ ) as mock_cancel_pipeline_job :
141+ yield mock_cancel_pipeline_job
142+
143+
133144@pytest .fixture
134145def mock_load_json ():
135146 with patch .object (storage .Blob , "download_as_bytes" ) as mock_load_json :
@@ -155,13 +166,10 @@ def setup_method(self):
155166 def teardown_method (self ):
156167 initializer .global_pool .shutdown (wait = True )
157168
169+ @pytest .mark .usefixtures ("mock_load_json" )
158170 @pytest .mark .parametrize ("sync" , [True , False ])
159171 def test_run_call_pipeline_service_create (
160- self ,
161- mock_pipeline_service_create ,
162- mock_pipeline_service_get ,
163- mock_load_json ,
164- sync ,
172+ self , mock_pipeline_service_create , mock_pipeline_service_get , sync ,
165173 ):
166174 aiplatform .init (
167175 project = _TEST_PROJECT ,
@@ -213,3 +221,51 @@ def test_run_call_pipeline_service_create(
213221 assert job ._gca_resource == make_pipeline_job (
214222 gca_pipeline_state_v1beta1 .PipelineState .PIPELINE_STATE_SUCCEEDED
215223 )
224+
225+ @pytest .mark .usefixtures (
226+ "mock_pipeline_service_create" , "mock_pipeline_service_get" , "mock_load_json" ,
227+ )
228+ def test_cancel_pipeline_job (
229+ self , mock_pipeline_service_cancel ,
230+ ):
231+ aiplatform .init (
232+ project = _TEST_PROJECT ,
233+ staging_bucket = _TEST_GCS_BUCKET_NAME ,
234+ credentials = _TEST_CREDENTIALS ,
235+ )
236+
237+ job = pipeline_jobs .PipelineJob (
238+ display_name = _TEST_PIPELINE_JOB_ID ,
239+ template_path = _TEST_TEMPLATE_PATH ,
240+ job_id = _TEST_PIPELINE_JOB_ID ,
241+ )
242+
243+ job .run ()
244+ job .cancel ()
245+
246+ mock_pipeline_service_cancel .assert_called_once_with (
247+ name = _TEST_PIPELINE_JOB_NAME
248+ )
249+
250+ @pytest .mark .usefixtures (
251+ "mock_pipeline_service_create" , "mock_pipeline_service_get" , "mock_load_json" ,
252+ )
253+ def test_cancel_pipeline_job_without_running (
254+ self , mock_pipeline_service_cancel ,
255+ ):
256+ aiplatform .init (
257+ project = _TEST_PROJECT ,
258+ staging_bucket = _TEST_GCS_BUCKET_NAME ,
259+ credentials = _TEST_CREDENTIALS ,
260+ )
261+
262+ job = pipeline_jobs .PipelineJob (
263+ display_name = _TEST_PIPELINE_JOB_ID ,
264+ template_path = _TEST_TEMPLATE_PATH ,
265+ job_id = _TEST_PIPELINE_JOB_ID ,
266+ )
267+
268+ with pytest .raises (RuntimeError ) as e :
269+ job .cancel ()
270+
271+ assert e .match (regexp = r"PipelineJob has not been launched" )
0 commit comments