@@ -14,6 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+import threading
+import time
 from unittest import mock
 
 from google.cloud import aiplatform
@@ -32,21 +35,20 @@
 from vertexai.preview.evaluation import _evaluation
 from vertexai.preview.evaluation import utils
 from vertexai.preview.evaluation.metrics import (
-    _summarization_quality,
+    _pairwise_question_answering_quality,
 )
 from vertexai.preview.evaluation.metrics import (
     _pairwise_summarization_quality,
 )
+from vertexai.preview.evaluation.metrics import _rouge
 from vertexai.preview.evaluation.metrics import (
-    _pairwise_question_answering_quality,
-)
-from vertexai.preview.evaluation.metrics import (
-    _rouge,
+    _summarization_quality,
 )
 import numpy as np
 import pandas as pd
 import pytest
 
+
 _TEST_PROJECT = "test-project"
 _TEST_LOCATION = "us-central1"
 _TEST_METRICS = (
@@ -221,12 +223,6 @@
 )
 
 
-@pytest.fixture
-def mock_async_event_loop():
-    with mock.patch("asyncio.get_event_loop") as mock_async_event_loop:
-        yield mock_async_event_loop
-
-
 @pytest.fixture
 def mock_experiment_tracker():
     with mock.patch.object(
@@ -267,32 +263,6 @@ def test_create_eval_task(self):
         assert test_eval_task.reference_column_name == test_reference_column_name
         assert test_eval_task.response_column_name == test_response_column_name
 
-    def test_evaluate_saved_response(self, mock_async_event_loop):
-        eval_dataset = _TEST_EVAL_DATASET
-        test_metrics = _TEST_METRICS
-        mock_summary_metrics = {
-            "row_count": 2,
-            "mock_metric/mean": 0.5,
-            "mock_metric/std": 0.5,
-        }
-        mock_metrics_table = pd.DataFrame(
-            {
-                "response": ["test", "text"],
-                "reference": ["test", "ref"],
-                "mock_metric": [1.0, 0.0],
-            }
-        )
-        mock_async_event_loop.return_value.run_until_complete.return_value = (
-            mock_summary_metrics,
-            mock_metrics_table,
-        )
-
-        test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
-        test_result = test_eval_task.evaluate()
-
-        assert test_result.summary_metrics == mock_summary_metrics
-        assert test_result.metrics_table.equals(mock_metrics_table)
-
     @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
     def test_compute_automatic_metrics(self, api_transport):
         aiplatform.init(
@@ -310,7 +280,7 @@ def test_compute_automatic_metrics(self, api_transport):
         test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
         mock_metric_results = _MOCK_EXACT_MATCH_RESULT
         with mock.patch.object(
-            target=gapic_evaluation_services.EvaluationServiceAsyncClient,
+            target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
             side_effect=mock_metric_results,
         ):
@@ -343,7 +313,7 @@ def test_compute_pointwise_metrics(self, api_transport):
         test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
         mock_metric_results = _MOCK_FLUENCY_RESULT
         with mock.patch.object(
-            target=gapic_evaluation_services.EvaluationServiceAsyncClient,
+            target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
             side_effect=mock_metric_results,
         ):
@@ -398,7 +368,7 @@ def test_compute_pointwise_metrics_with_custom_metric_spec(self, api_transport):
         test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
         mock_metric_results = _MOCK_SUMMARIZATION_QUALITY_RESULT
         with mock.patch.object(
-            target=gapic_evaluation_services.EvaluationServiceAsyncClient,
+            target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
             side_effect=mock_metric_results,
         ):
@@ -465,7 +435,7 @@ def test_compute_automatic_metrics_with_custom_metric_spec(self, api_transport):
         ]
         test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
         with mock.patch.object(
-            target=gapic_evaluation_services.EvaluationServiceAsyncClient,
+            target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
             side_effect=_MOCK_ROUGE_RESULT,
         ) as mock_evaluate_instances:
@@ -527,7 +497,7 @@ def test_compute_pairwise_metrics_with_model_inference(self, api_transport):
         test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
         mock_metric_results = _MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT
         with mock.patch.object(
-            target=gapic_evaluation_services.EvaluationServiceAsyncClient,
+            target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
             side_effect=mock_metric_results,
         ):
@@ -613,7 +583,7 @@ def test_compute_pairwise_metrics_without_inference(self, api_transport):
         test_eval_task = evaluation.EvalTask(dataset=eval_dataset, metrics=test_metrics)
         mock_metric_results = _MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT
        with mock.patch.object(
-            target=gapic_evaluation_services.EvaluationServiceAsyncClient,
+            target=gapic_evaluation_services.EvaluationServiceClient,
             attribute="evaluate_instances",
             side_effect=mock_metric_results,
         ):
@@ -869,9 +839,9 @@ def setup_method(self):
     def teardown_method(self):
         initializer.global_pool.shutdown(wait=True)
 
-    def test_create_evaluation_service_async_client(self):
-        client = utils.create_evaluation_service_async_client()
-        assert isinstance(client, utils._EvaluationServiceAsyncClientWithOverride)
+    def test_create_evaluation_service_client(self):
+        client = utils.create_evaluation_service_client()
+        assert isinstance(client, utils._EvaluationServiceClientWithOverride)
 
     def test_load_dataset_from_dataframe(self):
         data = {"col1": [1, 2], "col2": ["a", "b"]}
@@ -924,6 +894,57 @@ def test_load_dataset_from_bigquery(self):
         assert isinstance(loaded_df, pd.DataFrame)
         assert loaded_df.equals(_TEST_EVAL_DATASET)
 
+    def test_initialization(self):
+        limiter = utils.RateLimiter(rate=2)
+        assert limiter.seconds_per_event == 0.5
+
+        with pytest.raises(ValueError, match="Rate must be a positive number"):
+            utils.RateLimiter(-1)
+        with pytest.raises(ValueError, match="Rate must be a positive number"):
+            utils.RateLimiter(0)
+
+    def test_admit(self):
+        rate_limiter = utils.RateLimiter(rate=2)
+
+        assert rate_limiter._admit() == 0
+
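+        # With rate=2, events are spaced 0.5 seconds apart, so after only
+        # 0.1 seconds the limiter should report ~0.4 seconds of remaining delay.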
+        time.sleep(0.1)
+        delay = rate_limiter._admit()
+        assert delay == pytest.approx(0.4, 0.01)
+
+        time.sleep(0.5)
+        delay = rate_limiter._admit()
+        assert delay == 0
+
+    def test_sleep_and_advance(self):
+        rate_limiter = utils.RateLimiter(rate=2)
+
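+        # The first call is admitted immediately; the second must wait out
+        # the remainder of the 0.5-second interval before returning.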
+        start_time = time.time()
+        rate_limiter.sleep_and_advance()
+        assert (time.time() - start_time) < 0.1
+
+        start_time = time.time()
+        rate_limiter.sleep_and_advance()
+        assert (time.time() - start_time) >= 0.5
+
+    def test_thread_safety(self):
+        rate_limiter = utils.RateLimiter(rate=2)
+        start_time = time.time()
+
+        def target():
+            rate_limiter.sleep_and_advance()
+
+        threads = [threading.Thread(target=target) for _ in range(10)]
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+
+        # The 10 serialized calls span at least 9 intervals of 0.5 seconds
+        # each, so the total elapsed time should be at least 4.5 seconds.
+        total_time = time.time() - start_time
+        assert total_time >= 4.5
+
 
 class TestPromptTemplate:
     def test_init(self):
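
The new tests above pin down the observable behavior of utils.RateLimiter: a positive rate maps to seconds_per_event = 1 / rate, _admit() returns the remaining wait in seconds without sleeping (0 when the next event is already allowed), and sleep_and_advance() blocks under a lock so concurrent callers are serialized. The following is a minimal sketch consistent with those tests; attribute names such as `last` and `_lock` are illustrative, and the shipped utils.RateLimiter may differ in detail.

import threading
import time
from typing import Optional


class RateLimiter:
    """Sketch of a limiter allowing at most `rate` events per second."""

    def __init__(self, rate: Optional[float] = None):
        if not rate or rate <= 0:
            raise ValueError("Rate must be a positive number")
        self.seconds_per_event = 1.0 / rate
        # Back-date the last event so the first call is admitted immediately.
        self.last = time.time() - self.seconds_per_event
        self._lock = threading.Lock()

    def _admit(self) -> float:
        # If a full interval has elapsed, record the event and return 0;
        # otherwise return the remaining wait in seconds without sleeping.
        now = time.time()
        since_last = now - self.last
        if since_last >= self.seconds_per_event:
            self.last = now
            return 0
        return self.seconds_per_event - since_last

    def sleep_and_advance(self):
        # The lock serializes concurrent callers, which is what
        # test_thread_safety relies on for its 4.5-second lower bound.
        with self._lock:
            delay = self._admit()
            if delay > 0:
                time.sleep(delay)
                self.last = time.time()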