@@ -691,6 +691,35 @@ def raw_prediction_response(self) -> aiplatform.models.Prediction:
         return self._prediction_response


+@dataclasses.dataclass
+class MultiCandidateTextGenerationResponse(TextGenerationResponse):
+    """Represents a multi-candidate response of a language model.
+
+    Attributes:
+        text: The generated text for the first candidate.
+        is_blocked: Whether the first candidate response was blocked.
+        safety_attributes: Scores for safety attributes for the first candidate.
+            Learn more about the safety attributes here:
+            https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions
+        candidates: The candidate responses.
+            Usually contains a single candidate unless `candidate_count` is used.
+    """
+
+    __module__ = "vertexai.language_models"
+
+    candidates: List[TextGenerationResponse] = dataclasses.field(default_factory=list)
+
+    def _repr_pretty_(self, p, cycle):
+        """Pretty prints self in IPython environments."""
+        if cycle:
+            p.text(f"{self.__class__.__name__}(...)")
+        else:
+            if len(self.candidates) == 1:
+                p.text(repr(self.candidates[0]))
+            else:
+                p.text(repr(self))
+
+
 class _TextGenerationModel(_LanguageModel):
     """TextGenerationModel represents a general language model.
@@ -716,7 +745,8 @@ def predict(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
         """Gets model response for a single prompt.

         Args:
@@ -726,9 +756,10 @@ def predict(
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of response candidates to return.

         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model.
         """
         prediction_request = _create_text_generation_prediction_request(
             prompt=prompt,
@@ -737,14 +768,15 @@ def predict(
             top_k=top_k,
             top_p=top_p,
             stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
         )

         prediction_response = self._endpoint.predict(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )

-        return _parse_text_generation_model_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(prediction_response)

     async def predict_async(
         self,
@@ -755,7 +787,8 @@ async def predict_async(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
         """Asynchronously gets model response for a single prompt.

         Args:
@@ -765,9 +798,10 @@ async def predict_async(
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of response candidates to return.

         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model.
         """
         prediction_request = _create_text_generation_prediction_request(
             prompt=prompt,
@@ -776,14 +810,15 @@ async def predict_async(
             top_k=top_k,
             top_p=top_p,
             stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
         )

         prediction_response = await self._endpoint.predict_async(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )

-        return _parse_text_generation_model_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(prediction_response)

     def predict_streaming(
         self,
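The async variant takes the same parameter; a hedged sketch of driving it, reusing the `model` from the sketch above:

```python
import asyncio

async def main() -> list:
    # predict_async forwards candidate_count the same way predict does.
    response = await model.predict_async(
        "Suggest a name for a bakery.",
        candidate_count=3,
    )
    return [candidate.text for candidate in response.candidates]

texts = asyncio.run(main())
```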
@@ -844,6 +879,7 @@ def _create_text_generation_prediction_request(
     top_k: Optional[int] = None,
     top_p: Optional[float] = None,
     stop_sequences: Optional[List[str]] = None,
+    candidate_count: Optional[int] = None,
 ) -> "_PredictionRequest":
     """Prepares the text generation request for a single prompt.

@@ -854,6 +890,7 @@ def _create_text_generation_prediction_request(
         top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
         top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
         stop_sequences: Customized stop sequences to stop the decoding process.
+        candidate_count: Number of candidates to return.

     Returns:
         A `_PredictionRequest` object that contains prediction instance and parameters.
@@ -880,6 +917,9 @@ def _create_text_generation_prediction_request(
     if stop_sequences:
         prediction_parameters["stopSequences"] = stop_sequences

+    if candidate_count is not None:
+        prediction_parameters["candidateCount"] = candidate_count
+
     return _PredictionRequest(
         instance=instance,
         parameters=prediction_parameters,
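With the new guard in place, the snake_case argument is translated to the camelCase wire key, and the `is not None` test means an explicit value is always forwarded while an omitted argument keeps the key out of the payload so server-side defaults apply. An illustrative check:

```python
request = _create_text_generation_prediction_request(
    prompt="Suggest a name for a bakery.",
    candidate_count=3,
)

# The request parameters now carry the camelCase wire parameter.
assert request.parameters["candidateCount"] == 3
```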
@@ -906,6 +946,32 @@ def _parse_text_generation_model_response(
     )


+def _parse_text_generation_model_multi_candidate_response(
+    prediction_response: aiplatform.models.Prediction,
+) -> MultiCandidateTextGenerationResponse:
+    """Converts the raw text_generation model response to `MultiCandidateTextGenerationResponse`."""
+    # The contract for the PredictionService is that there is a 1:1 mapping
+    # between request `instances` and response `predictions`.
+    # Unfortunately, the text-bison models violate this contract.
+
+    prediction_count = len(prediction_response.predictions)
+    candidates = []
+    for prediction_idx in range(prediction_count):
+        candidate = _parse_text_generation_model_response(
+            prediction_response=prediction_response,
+            prediction_idx=prediction_idx,
+        )
+        candidates.append(candidate)
+
+    return MultiCandidateTextGenerationResponse(
+        text=candidates[0].text,
+        _prediction_response=prediction_response,
+        is_blocked=candidates[0].is_blocked,
+        safety_attributes=candidates[0].safety_attributes,
+        candidates=candidates,
+    )
+
+
 class _ModelWithBatchPredict(_LanguageModel):
     """Model that supports batch prediction."""

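A test-style sketch of the parser, using a hypothetical stand-in for `aiplatform.models.Prediction`; the exact shape of each prediction dict (`content`, `safetyAttributes`) is an assumption about what `_parse_text_generation_model_response` expects, since that helper is not shown in this diff:

```python
from types import SimpleNamespace

# Hypothetical fake: real Prediction objects come from the PredictionService.
fake_prediction_response = SimpleNamespace(
    predictions=[
        {"content": "Candidate one", "safetyAttributes": {"blocked": False}},
        {"content": "Candidate two", "safetyAttributes": {"blocked": False}},
    ]
)

response = _parse_text_generation_model_multi_candidate_response(
    fake_prediction_response
)
assert len(response.candidates) == 2      # one candidate per prediction
assert response.text == "Candidate one"   # mirrors the first candidate
```

Note that the first candidate's `text`, `is_blocked`, and `safety_attributes` are copied onto the top-level response, which is what keeps single-candidate callers working.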