4545 artifact as gca_artifact ,
4646 prediction_service as gca_prediction_service ,
4747 context as gca_context ,
48- endpoint as gca_endpoint ,
48+ endpoint_v1 as gca_endpoint ,
4949 pipeline_job as gca_pipeline_job ,
5050 pipeline_state as gca_pipeline_state ,
5151 deployed_model_ref_v1 ,
@@ -1030,6 +1030,11 @@ def get_endpoint_mock():
10301030 get_endpoint_mock .return_value = gca_endpoint .Endpoint (
10311031 display_name = "test-display-name" ,
10321032 name = test_constants .EndpointConstants ._TEST_ENDPOINT_NAME ,
1033+ deployed_models = [
1034+ gca_endpoint .DeployedModel (
1035+ model = test_constants .ModelConstants ._TEST_MODEL_RESOURCE_NAME
1036+ ),
1037+ ],
10331038 )
10341039 yield get_endpoint_mock
10351040
@@ -2420,7 +2425,10 @@ def test_text_embedding_ga(self):
24202425 assert len (vector ) == _TEXT_EMBEDDING_VECTOR_LENGTH
24212426 assert vector == _TEST_TEXT_EMBEDDING_PREDICTION ["embeddings" ]["values" ]
24222427
2423- def test_batch_prediction (self ):
2428+ def test_batch_prediction (
2429+ self ,
2430+ get_endpoint_mock ,
2431+ ):
24242432 """Tests batch prediction."""
24252433 aiplatform .init (
24262434 project = _TEST_PROJECT ,
@@ -2447,7 +2455,29 @@ def test_batch_prediction(self):
24472455 model_parameters = {"temperature" : 0.1 },
24482456 )
24492457 mock_create .assert_called_once_with (
2450- model_name = "publishers/google/models/text-bison@001" ,
2458+ model_name = f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } /publishers/google/models/text-bison@001" ,
2459+ job_display_name = None ,
2460+ gcs_source = "gs://test-bucket/test_table.jsonl" ,
2461+ gcs_destination_prefix = "gs://test-bucket/results/" ,
2462+ model_parameters = {"temperature" : 0.1 },
2463+ )
2464+
2465+ # Testing tuned model batch prediction
2466+ tuned_model = language_models .TextGenerationModel (
2467+ model_id = model ._model_id ,
2468+ endpoint_name = test_constants .EndpointConstants ._TEST_ENDPOINT_NAME ,
2469+ )
2470+ with mock .patch .object (
2471+ target = aiplatform .BatchPredictionJob ,
2472+ attribute = "create" ,
2473+ ) as mock_create :
2474+ tuned_model .batch_predict (
2475+ dataset = "gs://test-bucket/test_table.jsonl" ,
2476+ destination_uri_prefix = "gs://test-bucket/results/" ,
2477+ model_parameters = {"temperature" : 0.1 },
2478+ )
2479+ mock_create .assert_called_once_with (
2480+ model_name = test_constants .ModelConstants ._TEST_MODEL_RESOURCE_NAME ,
24512481 job_display_name = None ,
24522482 gcs_source = "gs://test-bucket/test_table.jsonl" ,
24532483 gcs_destination_prefix = "gs://test-bucket/results/" ,
@@ -2481,7 +2511,7 @@ def test_batch_prediction_for_text_embedding(self):
24812511 model_parameters = {},
24822512 )
24832513 mock_create .assert_called_once_with (
2484- model_name = " publishers/google/models/textembedding-gecko@001" ,
2514+ model_name = f"projects/ { _TEST_PROJECT } /locations/ { _TEST_LOCATION } / publishers/google/models/textembedding-gecko@001" ,
24852515 job_display_name = None ,
24862516 gcs_source = "gs://test-bucket/test_table.jsonl" ,
24872517 gcs_destination_prefix = "gs://test-bucket/results/" ,
0 commit comments