Skip to content

Commit c22220e

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Create PipelineJobSchedule in same project and location as associated PipelineJob by default
PiperOrigin-RevId: 568706789
1 parent 0c1c129 commit c22220e

File tree

2 files changed

+133
-2
lines changed

2 files changed

+133
-2
lines changed

google/cloud/aiplatform/pipeline_job_schedules.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,17 @@ def __init__(
7474
Overrides credentials set in aiplatform.init.
7575
project (str):
7676
Optional. The project that you want to run this PipelineJobSchedule in.
77-
If not set, the project set in aiplatform.init will be used.
77+
If not set, the project used for the PipelineJob will be used.
7878
location (str):
7979
Optional. Location to create PipelineJobSchedule. If not set,
80-
location set in aiplatform.init will be used.
80+
location used for the PipelineJob will be used.
8181
"""
8282
if not display_name:
8383
display_name = self.__class__._generate_display_name()
8484
utils.validate_display_name(display_name)
8585

86+
project = project or pipeline_job.project
87+
location = location or pipeline_job.location
8688
super().__init__(credentials=credentials, project=project, location=location)
8789

8890
self._parent = initializer.global_config.common_location_path(

tests/unit/aiplatform/test_pipeline_job_schedules.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,91 @@ def test_call_schedule_service_create(
597597
gca_schedule.Schedule.State.COMPLETED
598598
)
599599

600+
@pytest.mark.parametrize(
601+
"job_spec",
602+
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
603+
)
604+
def test_call_schedule_service_create_uses_pipeline_job_project_location(
605+
self,
606+
mock_schedule_service_create,
607+
mock_schedule_service_get,
608+
mock_schedule_bucket_exists,
609+
job_spec,
610+
mock_load_yaml_and_json,
611+
):
612+
"""Creates a PipelineJobSchedule.
613+
614+
Tests that the PipelineJobSchedule is created in the same project and location as the PipelineJob.
615+
"""
616+
aiplatform.init(
617+
project=_TEST_PROJECT,
618+
staging_bucket=_TEST_GCS_BUCKET_NAME,
619+
location=_TEST_LOCATION,
620+
credentials=_TEST_CREDENTIALS,
621+
)
622+
623+
job = pipeline_jobs.PipelineJob(
624+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
625+
template_path=_TEST_TEMPLATE_PATH,
626+
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
627+
input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
628+
enable_caching=True,
629+
project="managed-pipeline-test",
630+
location="europe-west4",
631+
)
632+
633+
pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
634+
pipeline_job=job,
635+
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
636+
)
637+
638+
assert pipeline_job_schedule.project == "managed-pipeline-test"
639+
assert pipeline_job_schedule.location == "europe-west4"
640+
641+
@pytest.mark.parametrize(
642+
"job_spec",
643+
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
644+
)
645+
def test_call_schedule_service_create_uses_specified_project_location(
646+
self,
647+
mock_schedule_service_create,
648+
mock_schedule_service_get,
649+
mock_schedule_bucket_exists,
650+
job_spec,
651+
mock_load_yaml_and_json,
652+
):
653+
"""Creates a PipelineJobSchedule.
654+
655+
Tests that PipelineJobSchedule is created in the specified project and location over the PipelineJob's.
656+
"""
657+
aiplatform.init(
658+
project=_TEST_PROJECT,
659+
staging_bucket=_TEST_GCS_BUCKET_NAME,
660+
location=_TEST_LOCATION,
661+
credentials=_TEST_CREDENTIALS,
662+
)
663+
664+
job = pipeline_jobs.PipelineJob(
665+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
666+
template_path=_TEST_TEMPLATE_PATH,
667+
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
668+
input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
669+
enable_caching=True,
670+
)
671+
672+
pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
673+
pipeline_job=job,
674+
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
675+
project="managed-pipeline-test",
676+
location="europe-west4",
677+
)
678+
679+
assert job.project == _TEST_PROJECT
680+
assert job.location == _TEST_LOCATION
681+
682+
assert pipeline_job_schedule.project == "managed-pipeline-test"
683+
assert pipeline_job_schedule.location == "europe-west4"
684+
600685
@pytest.mark.parametrize(
601686
"job_spec",
602687
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
@@ -1148,6 +1233,50 @@ def test_call_pipeline_job_create_schedule(
11481233
gca_schedule.Schedule.State.COMPLETED
11491234
)
11501235

1236+
@pytest.mark.parametrize(
1237+
"job_spec",
1238+
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
1239+
)
1240+
def test_call_pipeline_job_create_schedule_uses_pipeline_job_project_location(
1241+
self,
1242+
mock_schedule_service_create,
1243+
mock_schedule_service_get,
1244+
job_spec,
1245+
mock_load_yaml_and_json,
1246+
):
1247+
"""Creates a PipelineJobSchedule via PipelineJob.create_schedule().
1248+
1249+
Tests that the PipelineJobSchedule is created in the same project and location as the PipelineJob.
1250+
"""
1251+
aiplatform.init(
1252+
project=_TEST_PROJECT,
1253+
staging_bucket=_TEST_GCS_BUCKET_NAME,
1254+
location=_TEST_LOCATION,
1255+
credentials=_TEST_CREDENTIALS,
1256+
)
1257+
1258+
job = pipeline_jobs.PipelineJob(
1259+
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
1260+
template_path=_TEST_TEMPLATE_PATH,
1261+
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
1262+
input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
1263+
enable_caching=True,
1264+
project="managed-pipeline-test",
1265+
location="europe-west4",
1266+
)
1267+
1268+
pipeline_job_schedule = job.create_schedule(
1269+
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
1270+
cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
1271+
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
1272+
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
1273+
service_account=_TEST_SERVICE_ACCOUNT,
1274+
network=_TEST_NETWORK,
1275+
)
1276+
1277+
assert pipeline_job_schedule.project == "managed-pipeline-test"
1278+
assert pipeline_job_schedule.location == "europe-west4"
1279+
11511280
@pytest.mark.usefixtures("mock_schedule_service_get")
11521281
def test_get_schedule(self, mock_schedule_service_get):
11531282
aiplatform.init(project=_TEST_PROJECT)

0 commit comments

Comments
 (0)