Skip to content

Commit 7432c2c

Browse files
matthew29tangcopybara-github
authored andcommitted
feat: GenAI - Release the Prompt Management feature to Public Preview
PiperOrigin-RevId: 702592033
1 parent e220312 commit 7432c2c

File tree

4 files changed

+1063
-8
lines changed

4 files changed

+1063
-8
lines changed
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2024 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
# pylint: disable=protected-access, g-multiple-import
18+
"""System tests for GenAI prompts."""
19+
20+
from google.cloud import aiplatform
21+
from vertexai import generative_models
22+
from vertexai.generative_models import (
23+
GenerationConfig,
24+
SafetySetting,
25+
ToolConfig,
26+
)
27+
from vertexai.preview import prompts
28+
from vertexai.preview.prompts import Prompt
29+
30+
from tests.system.aiplatform import e2e_base
31+
from google import auth
32+
33+
_REQUEST_FUNCTION_PARAMETER_SCHEMA_STRUCT = {
34+
"type": "object",
35+
"properties": {
36+
"location": {
37+
"type": "string",
38+
"description": "The city and state, e.g. San Francisco, CA",
39+
},
40+
"unit": {
41+
"type": "string",
42+
"enum": [
43+
"celsius",
44+
"fahrenheit",
45+
],
46+
},
47+
},
48+
"required": ["location"],
49+
}
50+
51+
52+
class TestPrompts(e2e_base.TestEndToEnd):
53+
"""System tests for prompts."""
54+
55+
_temp_prefix = "temp_prompts_test_"
56+
57+
def setup_method(self):
58+
super().setup_method()
59+
credentials, _ = auth.default(
60+
scopes=["https://www.googleapis.com/auth/cloud-platform"]
61+
)
62+
aiplatform.init(
63+
project=e2e_base._PROJECT,
64+
location=e2e_base._LOCATION,
65+
credentials=credentials,
66+
)
67+
68+
def test_create_prompt_with_variables(self):
69+
# Create local Prompt
70+
prompt = Prompt(
71+
prompt_data="Hello, {name}! Today is {day}. How are you?",
72+
variables=[
73+
{"name": "Alice", "day": "Monday"},
74+
{"name": "Bob", "day": "Tuesday"},
75+
],
76+
generation_config=GenerationConfig(temperature=0.1),
77+
model_name="gemini-1.0-pro-002",
78+
safety_settings=[
79+
SafetySetting(
80+
category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
81+
threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
82+
method=SafetySetting.HarmBlockMethod.SEVERITY,
83+
)
84+
],
85+
system_instruction="Please answer in a short sentence.",
86+
)
87+
88+
# Generate content using the assembled prompt for each variable set.
89+
for i in range(len(prompt.variables)):
90+
prompt.generate_content(
91+
contents=prompt.assemble_contents(**prompt.variables[i])
92+
)
93+
94+
# Save Prompt to online resource. Returns a new Prompt object associated with the online resource
95+
prompt1 = prompts.create_version(prompt=prompt)
96+
97+
# Only new prompt should be associated with a prompt resource
98+
assert prompt1.prompt_id
99+
assert not prompt.prompt_id
100+
101+
# Update prompt and save a new version
102+
prompt1.prompt_data = "Hi, {name}! How are you? Today is {day}."
103+
prompt2 = prompts.create_version(prompt=prompt1, version_name="v2")
104+
assert prompt2.prompt_id == prompt1.prompt_id
105+
assert prompt2.version_id != prompt1.version_id
106+
107+
# Restore previous version
108+
metadata = prompts.restore_version(
109+
prompt_id=prompt2.prompt_id, version_id=prompt1.version_id
110+
)
111+
assert metadata.prompt_id == prompt2.prompt_id
112+
assert metadata.version_id != prompt2.version_id
113+
114+
# List prompt versions
115+
versions_metadata = prompts.list_versions(prompt_id=metadata.prompt_id)
116+
assert len(versions_metadata) == 3
117+
118+
# Delete the prompt resource
119+
prompts.delete(prompt_id=prompt2.prompt_id)
120+
121+
def test_create_prompt_with_function_calling(self):
122+
# Create local Prompt
123+
get_current_weather_func = generative_models.FunctionDeclaration(
124+
name="get_current_weather",
125+
description="Get the current weather in a given location",
126+
parameters=_REQUEST_FUNCTION_PARAMETER_SCHEMA_STRUCT,
127+
)
128+
weather_tool = generative_models.Tool(
129+
function_declarations=[get_current_weather_func],
130+
)
131+
132+
tool_config = ToolConfig(
133+
function_calling_config=ToolConfig.FunctionCallingConfig(
134+
mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
135+
allowed_function_names=["get_current_weather"],
136+
)
137+
)
138+
139+
prompt = Prompt(
140+
prompt_data="What is the weather like in Boston?",
141+
tools=[weather_tool],
142+
tool_config=tool_config,
143+
model_name="gemini-1.0-pro-002",
144+
)
145+
146+
# (Optional) Create a separate prompt resource to save the version to
147+
prompt_temp = Prompt(model_name="gemini-1.0-pro-002")
148+
prompt_temp1 = prompts.create_version(prompt=prompt_temp, version_name="empty")
149+
150+
# Create a new version to an existing prompt
151+
prompt1 = prompts.create_version(
152+
prompt=prompt, prompt_id=prompt_temp1.prompt_id, version_name="fc"
153+
)
154+
155+
# Delete the prompt resource
156+
prompts.delete(prompt_id=prompt1.prompt_id)
157+
158+
def test_get_prompt_with_variables(self):
159+
# List prompts
160+
prompts_list = prompts.list()
161+
assert prompts_list
162+
163+
# Get prompt created in UI
164+
prompt_id = "3217694940163211264"
165+
prompt = prompts.get(prompt_id=prompt_id)
166+
assert prompt.prompt_id == prompt_id
167+
assert prompt.prompt_data
168+
assert prompt.generation_config
169+
assert prompt.system_instruction
170+
# UI has a bug where safety settings are not saved
171+
# assert prompt.safety_settings
172+
173+
# Generate content using the assembled prompt for each variable set.
174+
for i in range(len(prompt.variables)):
175+
response = prompt.generate_content(
176+
contents=prompt.assemble_contents(**prompt.variables[i])
177+
)
178+
assert response.text
179+
180+
def test_get_prompt_with_function_calling(self):
181+
# List prompts
182+
prompts_list = prompts.list()
183+
assert prompts_list
184+
185+
# Get prompt created in UI
186+
prompt_id = "1173060709337006080"
187+
prompt = prompts.get(prompt_id=prompt_id)
188+
assert prompt.prompt_id == prompt_id
189+
assert prompt.tools
190+
191+
# Generate content using the prompt
192+
response = prompt.generate_content(contents=prompt.assemble_contents())
193+
assert response

vertexai/preview/prompts.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,21 @@
1616
from vertexai.prompts._prompts import (
1717
Prompt,
1818
)
19+
from vertexai.prompts._prompt_management import (
20+
create_version,
21+
delete,
22+
get,
23+
list_prompts as list,
24+
list_versions,
25+
restore_version,
26+
)
1927

2028
__all__ = [
2129
"Prompt",
30+
"delete",
31+
"create_version",
32+
"get",
33+
"list",
34+
"list_versions",
35+
"restore_version",
2236
]

0 commit comments

Comments
 (0)