|
49 | 49 | env_var as gca_env_var, |
50 | 50 | explanation as gca_explanation, |
51 | 51 | machine_resources as gca_machine_resources, |
| 52 | + manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat, |
52 | 53 | model_service as gca_model_service, |
53 | 54 | model_evaluation as gca_model_evaluation, |
54 | 55 | endpoint_service as gca_endpoint_service, |
|
86 | 87 | _TEST_STARTING_REPLICA_COUNT = 2 |
87 | 88 | _TEST_MAX_REPLICA_COUNT = 12 |
88 | 89 |
|
| 90 | +_TEST_BATCH_SIZE = 16 |
| 91 | + |
89 | 92 | _TEST_PIPELINE_RESOURCE_NAME = ( |
90 | 93 | "projects/my-project/locations/us-central1/trainingPipeline/12345" |
91 | 94 | ) |
@@ -1402,47 +1405,47 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn |
1402 | 1405 | encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, |
1403 | 1406 | sync=sync, |
1404 | 1407 | create_request_timeout=None, |
| 1408 | + batch_size=_TEST_BATCH_SIZE, |
1405 | 1409 | ) |
1406 | 1410 |
|
1407 | 1411 | if not sync: |
1408 | 1412 | batch_prediction_job.wait() |
1409 | 1413 |
|
1410 | 1414 | # Construct expected request |
1411 | | - expected_gapic_batch_prediction_job = ( |
1412 | | - gca_batch_prediction_job.BatchPredictionJob( |
1413 | | - display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, |
1414 | | - model=model_service_client.ModelServiceClient.model_path( |
1415 | | - _TEST_PROJECT, _TEST_LOCATION, _TEST_ID |
1416 | | - ), |
1417 | | - input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( |
1418 | | - instances_format="jsonl", |
1419 | | - gcs_source=gca_io.GcsSource( |
1420 | | - uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE] |
1421 | | - ), |
1422 | | - ), |
1423 | | - output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( |
1424 | | - gcs_destination=gca_io.GcsDestination( |
1425 | | - output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX |
1426 | | - ), |
1427 | | - predictions_format="csv", |
1428 | | - ), |
1429 | | - dedicated_resources=gca_machine_resources.BatchDedicatedResources( |
1430 | | - machine_spec=gca_machine_resources.MachineSpec( |
1431 | | - machine_type=_TEST_MACHINE_TYPE, |
1432 | | - accelerator_type=_TEST_ACCELERATOR_TYPE, |
1433 | | - accelerator_count=_TEST_ACCELERATOR_COUNT, |
1434 | | - ), |
1435 | | - starting_replica_count=_TEST_STARTING_REPLICA_COUNT, |
1436 | | - max_replica_count=_TEST_MAX_REPLICA_COUNT, |
| 1415 | + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( |
| 1416 | + display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, |
| 1417 | + model=model_service_client.ModelServiceClient.model_path( |
| 1418 | + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID |
| 1419 | + ), |
| 1420 | + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( |
| 1421 | + instances_format="jsonl", |
| 1422 | + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), |
| 1423 | + ), |
| 1424 | + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( |
| 1425 | + gcs_destination=gca_io.GcsDestination( |
| 1426 | + output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX |
1437 | 1427 | ), |
1438 | | - generate_explanation=True, |
1439 | | - explanation_spec=gca_explanation.ExplanationSpec( |
1440 | | - metadata=_TEST_EXPLANATION_METADATA, |
1441 | | - parameters=_TEST_EXPLANATION_PARAMETERS, |
| 1428 | + predictions_format="csv", |
| 1429 | + ), |
| 1430 | + dedicated_resources=gca_machine_resources.BatchDedicatedResources( |
| 1431 | + machine_spec=gca_machine_resources.MachineSpec( |
| 1432 | + machine_type=_TEST_MACHINE_TYPE, |
| 1433 | + accelerator_type=_TEST_ACCELERATOR_TYPE, |
| 1434 | + accelerator_count=_TEST_ACCELERATOR_COUNT, |
1442 | 1435 | ), |
1443 | | - labels=_TEST_LABEL, |
1444 | | - encryption_spec=_TEST_ENCRYPTION_SPEC, |
1445 | | - ) |
| 1436 | + starting_replica_count=_TEST_STARTING_REPLICA_COUNT, |
| 1437 | + max_replica_count=_TEST_MAX_REPLICA_COUNT, |
| 1438 | + ), |
| 1439 | + manual_batch_tuning_parameters=gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters( |
| 1440 | + batch_size=_TEST_BATCH_SIZE |
| 1441 | + ), |
| 1442 | + generate_explanation=True, |
| 1443 | + explanation_spec=gca_explanation.ExplanationSpec( |
| 1444 | + metadata=_TEST_EXPLANATION_METADATA, |
| 1445 | + parameters=_TEST_EXPLANATION_PARAMETERS, |
| 1446 | + ), |
| 1447 | + labels=_TEST_LABEL, |
| 1448 | + encryption_spec=_TEST_ENCRYPTION_SPEC, |
1446 | 1449 | ) |
1447 | 1450 |
|
1448 | 1451 | create_batch_prediction_job_mock.assert_called_once_with( |
|
0 commit comments