Skip to content

Commit 8b5a3f4

Browse files
committed
Fixed mocks
1 parent 83b3268 commit 8b5a3f4

6 files changed

+39
-31
lines changed

samples/model-builder/conftest.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,18 @@ def mock_custom_training_job():
143143
yield mock
144144

145145

146+
@pytest.fixture
147+
def mock_custom_container_training_job():
148+
mock = MagicMock(aiplatform.training_jobs.CustomContainerTrainingJob)
149+
yield mock
150+
151+
152+
@pytest.fixture
153+
def mock_custom_package_training_job():
154+
mock = MagicMock(aiplatform.training_jobs.CustomPythonPackageTrainingJob)
155+
yield mock
156+
157+
146158
@pytest.fixture
147159
def mock_image_training_job():
148160
mock = MagicMock(aiplatform.training_jobs.AutoMLImageTrainingJob)
@@ -194,47 +206,41 @@ def mock_run_automl_image_training_job(mock_image_training_job):
194206

195207

196208
@pytest.fixture
197-
def mock_init_custom_training_job():
198-
with patch.object(aiplatform.CustomTrainingJob, "__init__") as mock:
199-
mock.return_value = None
209+
def mock_get_custom_training_job(mock_custom_training_job):
210+
with patch.object(aiplatform, "CustomTrainingJob") as mock:
211+
mock.return_value = mock_custom_training_job
200212
yield mock
201213

202214

203215
@pytest.fixture
204-
def mock_run_custom_training_job():
205-
with patch.object(aiplatform.CustomTrainingJob, "run") as mock:
216+
def mock_get_custom_container_training_job(mock_custom_container_training_job):
217+
with patch.object(aiplatform, "CustomContainerTrainingJob") as mock:
218+
mock.return_value = mock_custom_container_training_job
206219
yield mock
207220

208221

209222
@pytest.fixture
210-
def mock_init_custom_container_training_job():
211-
with patch.object(
212-
aiplatform.training_jobs.CustomContainerTrainingJob, "__init__"
213-
) as mock:
214-
mock.return_value = None
223+
def mock_get_custom_package_training_job(mock_custom_package_training_job):
224+
with patch.object(aiplatform, "CustomPythonPackageTrainingJob") as mock:
225+
mock.return_value = mock_custom_package_training_job
215226
yield mock
216227

217228

218229
@pytest.fixture
219-
def mock_run_custom_container_training_job():
220-
with patch.object(aiplatform.CustomContainerTrainingJob, "run") as mock:
230+
def mock_run_custom_training_job(mock_custom_training_job):
231+
with patch.object(mock_custom_training_job, "run") as mock:
221232
yield mock
222233

223234

224235
@pytest.fixture
225-
def mock_init_custom_package_training_job():
226-
with patch.object(
227-
aiplatform.training_jobs.CustomPythonPackageTrainingJob, "__init__"
228-
) as mock:
229-
mock.return_value = None
236+
def mock_run_custom_container_training_job(mock_custom_container_training_job):
237+
with patch.object(mock_custom_container_training_job, "run") as mock:
230238
yield mock
231239

232240

233241
@pytest.fixture
234-
def mock_run_custom_package_training_job():
235-
with patch.object(
236-
aiplatform.training_jobs.CustomPythonPackageTrainingJob, "run"
237-
) as mock:
242+
def mock_run_custom_package_training_job(mock_custom_package_training_job):
243+
with patch.object(mock_custom_package_training_job, "run") as mock:
238244
yield mock
239245

240246

samples/model-builder/create_training_pipeline_custom_container_job_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_create_training_pipeline_custom_container_job_sample(
2121
mock_sdk_init,
2222
mock_image_dataset,
2323
mock_get_image_dataset,
24-
mock_init_custom_container_training_job,
24+
mock_get_custom_container_training_job,
2525
mock_run_custom_container_training_job,
2626
):
2727

@@ -50,7 +50,7 @@ def test_create_training_pipeline_custom_container_job_sample(
5050
staging_bucket=constants.STAGING_BUCKET,
5151
)
5252

53-
mock_init_custom_container_training_job.assert_called_once_with(
53+
mock_get_custom_container_training_job.assert_called_once_with(
5454
display_name=constants.DISPLAY_NAME,
5555
container_uri=constants.CONTAINER_URI,
5656
model_serving_container_image_uri=constants.CONTAINER_URI,

samples/model-builder/create_training_pipeline_custom_job_sample.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
def create_training_pipeline_custom_job_sample(
2222
project: str,
2323
location: str,
24+
staging_bucket: str,
2425
display_name: str,
2526
script_path: str,
2627
container_uri: str,
@@ -37,7 +38,7 @@ def create_training_pipeline_custom_job_sample(
3738
test_fraction_split: float = 0.1,
3839
sync: bool = True,
3940
):
40-
aiplatform.init(project=project, location=location)
41+
aiplatform.init(project=project, location=location, staging_bucket=staging_bucket)
4142

4243
job = aiplatform.CustomTrainingJob(
4344
display_name=display_name,

samples/model-builder/create_training_pipeline_custom_job_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ def test_create_training_pipeline_custom_job_sample(
2121
mock_sdk_init,
2222
mock_image_dataset,
2323
mock_get_image_dataset,
24-
mock_init_custom_training_job,
24+
mock_get_custom_training_job,
2525
mock_run_custom_training_job,
2626
):
2727

2828
create_training_pipeline_custom_job_sample.create_training_pipeline_custom_job_sample(
2929
project=constants.PROJECT,
3030
location=constants.LOCATION,
31+
staging_bucket=constants.STAGING_BUCKET,
3132
display_name=constants.DISPLAY_NAME,
3233
args=constants.ARGS,
3334
script_path=constants.SCRIPT_PATH,
@@ -45,9 +46,11 @@ def test_create_training_pipeline_custom_job_sample(
4546
)
4647

4748
mock_sdk_init.assert_called_once_with(
48-
project=constants.PROJECT, location=constants.LOCATION
49+
project=constants.PROJECT,
50+
location=constants.LOCATION,
51+
staging_bucket=constants.STAGING_BUCKET,
4952
)
50-
mock_init_custom_training_job.assert_called_once_with(
53+
mock_get_custom_training_job.assert_called_once_with(
5154
display_name=constants.DISPLAY_NAME,
5255
script_path=constants.SCRIPT_PATH,
5356
container_uri=constants.CONTAINER_URI,

samples/model-builder/create_training_pipeline_custom_package_job_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_create_training_pipeline_custom_package_job_sample(
2121
mock_sdk_init,
2222
mock_image_dataset,
2323
mock_get_image_dataset,
24-
mock_init_custom_package_training_job,
24+
mock_get_custom_package_training_job,
2525
mock_run_custom_package_training_job,
2626
):
2727

@@ -52,7 +52,7 @@ def test_create_training_pipeline_custom_package_job_sample(
5252
staging_bucket=constants.STAGING_BUCKET,
5353
)
5454

55-
mock_init_custom_package_training_job.assert_called_once_with(
55+
mock_get_custom_package_training_job.assert_called_once_with(
5656
display_name=constants.DISPLAY_NAME,
5757
python_package_gcs_uri=constants.PYTHON_PACKAGE_GCS_URI,
5858
python_module_name=constants.PYTHON_MODULE_NAME,

samples/model-builder/test_constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,6 @@
159159
"--file_system_poll_wait_seconds=31540000",
160160
],
161161
)
162-
MODEL_SERVING_CONTAINER_PREDICT_ROUTE = (f"/v1/models/{MODEL_NAME}:predict",)
163-
MODEL_SERVING_CONTAINER_HEALTH_ROUTE = f"/v1/models/{MODEL_NAME}"
164162
PYTHON_PACKAGE_GCS_URI = (
165163
"gs://bucket3/custom-training-python-package/my_app/trainer-0.1.tar.gz"
166164
)

0 commit comments

Comments
 (0)