Skip to content

Commit df2c83f

Browse files
authored
Add --var template variable support to generate command with command-specific reserved key validation (#83)
1 parent cb8a394 commit df2c83f

File tree

7 files changed

+322
-53
lines changed

7 files changed

+322
-53
lines changed

cmd/generate/generate.go

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@ import (
1313
)
1414

1515
type generateCommandHandler struct {
16-
ctx context.Context
17-
cfg *command.Config
18-
client azuremodels.Client
19-
options *PromptPexOptions
20-
promptFile string
21-
org string
22-
sessionFile *string
16+
ctx context.Context
17+
cfg *command.Config
18+
client azuremodels.Client
19+
options *PromptPexOptions
20+
promptFile string
21+
org string
22+
sessionFile *string
23+
templateVars map[string]string
2324
}
2425

2526
// NewGenerateCommand returns a new command to generate tests using PromptPex.
@@ -37,6 +38,7 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command {
3738
gh models generate prompt.yml
3839
gh models generate --org my-org --groundtruth-model "openai/gpt-4.1" prompt.yml
3940
gh models generate --session-file prompt.session.json prompt.yml
41+
gh models generate --var name=Alice --var topic="machine learning" prompt.yml
4042
`),
4143
Args: cobra.ExactArgs(1),
4244
RunE: func(cmd *cobra.Command, args []string) error {
@@ -50,6 +52,17 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command {
5052
return fmt.Errorf("failed to parse flags: %w", err)
5153
}
5254

55+
// Parse template variables from flags
56+
templateVars, err := util.ParseTemplateVariables(cmd.Flags())
57+
if err != nil {
58+
return err
59+
}
60+
61+
// Check for reserved keys specific to generate command
62+
if _, exists := templateVars["input"]; exists {
63+
return fmt.Errorf("'input' is a reserved variable name and cannot be used with --var")
64+
}
65+
5366
// Get organization
5467
org, _ := cmd.Flags().GetString("org")
5568

@@ -67,13 +80,14 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command {
6780

6881
// Create the command handler
6982
handler := &generateCommandHandler{
70-
ctx: ctx,
71-
cfg: cfg,
72-
client: cfg.Client,
73-
options: options,
74-
promptFile: promptFile,
75-
org: org,
76-
sessionFile: util.Ptr(sessionFile),
83+
ctx: ctx,
84+
cfg: cfg,
85+
client: cfg.Client,
86+
options: options,
87+
promptFile: promptFile,
88+
org: org,
89+
sessionFile: util.Ptr(sessionFile),
90+
templateVars: templateVars,
7791
}
7892

7993
// Create prompt context
@@ -105,6 +119,7 @@ func AddCommandLineFlags(cmd *cobra.Command) {
105119
flags.String("effort", "", "Effort level (low, medium, high)")
106120
flags.String("groundtruth-model", "", "Model to use for generating groundtruth outputs. Defaults to openai/gpt-4o. Use 'none' to disable groundtruth generation.")
107121
flags.String("session-file", "", "Session file to load existing context from")
122+
flags.StringSlice("var", []string{}, "Template variables for prompt files (can be used multiple times: --var name=value)")
108123

109124
// Custom instruction flags for each phase
110125
flags.String("instruction-intent", "", "Custom system instruction for intent generation phase")

cmd/generate/generate_test.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ import (
1111
"testing"
1212

1313
"github.com/github/gh-models/internal/azuremodels"
14+
"github.com/github/gh-models/internal/sse"
1415
"github.com/github/gh-models/pkg/command"
16+
"github.com/github/gh-models/pkg/util"
1517
"github.com/stretchr/testify/require"
1618
)
1719

@@ -393,3 +395,128 @@ messages:
393395
require.Contains(t, err.Error(), "failed to load prompt file")
394396
})
395397
}
398+
399+
func TestGenerateCommandWithTemplateVariables(t *testing.T) {
400+
t.Run("parse template variables in command handler", func(t *testing.T) {
401+
client := azuremodels.NewMockClient()
402+
cfg := command.NewConfig(new(bytes.Buffer), new(bytes.Buffer), client, true, 100)
403+
404+
cmd := NewGenerateCommand(cfg)
405+
args := []string{
406+
"--var", "name=Bob",
407+
"--var", "location=Seattle",
408+
"dummy.yml",
409+
}
410+
411+
// Parse flags without executing
412+
err := cmd.ParseFlags(args[:len(args)-1]) // Exclude positional arg
413+
require.NoError(t, err)
414+
415+
// Test that the util.ParseTemplateVariables function works correctly
416+
templateVars, err := util.ParseTemplateVariables(cmd.Flags())
417+
require.NoError(t, err)
418+
require.Equal(t, map[string]string{
419+
"name": "Bob",
420+
"location": "Seattle",
421+
}, templateVars)
422+
})
423+
424+
t.Run("runSingleTestWithContext applies template variables", func(t *testing.T) {
425+
// Create test prompt file with template variables
426+
const yamlBody = `
427+
name: Template Variable Test
428+
description: Test prompt with template variables
429+
model: openai/gpt-4o-mini
430+
messages:
431+
- role: system
432+
content: "You are a helpful assistant for {{name}}."
433+
- role: user
434+
content: "Tell me about {{topic}} in {{style}} style."
435+
`
436+
437+
tmpDir := t.TempDir()
438+
promptFile := filepath.Join(tmpDir, "test.prompt.yml")
439+
err := os.WriteFile(promptFile, []byte(yamlBody), 0644)
440+
require.NoError(t, err)
441+
442+
// Setup mock client to capture template-rendered messages
443+
var capturedOptions azuremodels.ChatCompletionOptions
444+
client := azuremodels.NewMockClient()
445+
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
446+
capturedOptions = opt
447+
448+
// Create a proper mock response with reader
449+
mockResponse := "test response"
450+
mockCompletion := azuremodels.ChatCompletion{
451+
Choices: []azuremodels.ChatChoice{
452+
{
453+
Message: &azuremodels.ChatChoiceMessage{
454+
Content: &mockResponse,
455+
},
456+
},
457+
},
458+
}
459+
460+
return &azuremodels.ChatCompletionResponse{
461+
Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{mockCompletion}),
462+
}, nil
463+
}
464+
465+
out := new(bytes.Buffer)
466+
cfg := command.NewConfig(out, out, client, true, 100)
467+
468+
// Create handler with template variables
469+
templateVars := map[string]string{
470+
"name": "Alice",
471+
"topic": "machine learning",
472+
"style": "academic",
473+
}
474+
475+
handler := &generateCommandHandler{
476+
ctx: context.Background(),
477+
cfg: cfg,
478+
client: client,
479+
options: GetDefaultOptions(),
480+
promptFile: promptFile,
481+
org: "",
482+
templateVars: templateVars,
483+
}
484+
485+
// Create context from prompt
486+
promptCtx, err := handler.CreateContextFromPrompt()
487+
require.NoError(t, err)
488+
489+
// Call runSingleTestWithContext directly
490+
_, err = handler.runSingleTestWithContext("test input", "openai/gpt-4o-mini", promptCtx)
491+
require.NoError(t, err)
492+
493+
// Verify that template variables were applied correctly
494+
require.NotNil(t, capturedOptions.Messages)
495+
require.Len(t, capturedOptions.Messages, 2)
496+
497+
// Check system message
498+
systemMsg := capturedOptions.Messages[0]
499+
require.Equal(t, azuremodels.ChatMessageRoleSystem, systemMsg.Role)
500+
require.NotNil(t, systemMsg.Content)
501+
require.Contains(t, *systemMsg.Content, "helpful assistant for Alice")
502+
503+
// Check user message
504+
userMsg := capturedOptions.Messages[1]
505+
require.Equal(t, azuremodels.ChatMessageRoleUser, userMsg.Role)
506+
require.NotNil(t, userMsg.Content)
507+
require.Contains(t, *userMsg.Content, "about machine learning")
508+
require.Contains(t, *userMsg.Content, "academic style")
509+
})
510+
511+
t.Run("rejects input as template variable", func(t *testing.T) {
512+
client := azuremodels.NewMockClient()
513+
cfg := command.NewConfig(new(bytes.Buffer), new(bytes.Buffer), client, true, 100)
514+
515+
cmd := NewGenerateCommand(cfg)
516+
cmd.SetArgs([]string{"--var", "input=test", "dummy.yml"})
517+
518+
err := cmd.Execute()
519+
require.Error(t, err)
520+
require.Contains(t, err.Error(), "'input' is a reserved variable name and cannot be used with --var")
521+
})
522+
}

cmd/generate/pipeline.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,15 @@ func (h *generateCommandHandler) runSingleTestWithContext(input string, modelNam
460460
openaiMessages := []azuremodels.ChatMessage{}
461461
for _, msg := range messages {
462462
templateData := make(map[string]interface{})
463+
464+
// Add the input variable (backward compatibility)
463465
templateData["input"] = input
466+
467+
// Add custom variables
468+
for key, value := range h.templateVars {
469+
templateData[key] = value
470+
}
471+
464472
// Replace template variables in content
465473
content, err := prompt.TemplateString(msg.Content, templateData)
466474
if err != nil {

cmd/run/run.go

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
236236
}
237237

238238
// Parse template variables from flags
239-
templateVars, err := parseTemplateVariables(cmd.Flags())
239+
templateVars, err := util.ParseTemplateVariables(cmd.Flags())
240240
if err != nil {
241241
return err
242242
}
@@ -427,43 +427,6 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
427427
return cmd
428428
}
429429

430-
// parseTemplateVariables parses template variables from the --var flags
431-
func parseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) {
432-
varFlags, err := flags.GetStringSlice("var")
433-
if err != nil {
434-
return nil, err
435-
}
436-
437-
templateVars := make(map[string]string)
438-
for _, varFlag := range varFlags {
439-
// Handle empty strings
440-
if strings.TrimSpace(varFlag) == "" {
441-
continue
442-
}
443-
444-
parts := strings.SplitN(varFlag, "=", 2)
445-
if len(parts) != 2 {
446-
return nil, fmt.Errorf("invalid variable format '%s', expected 'key=value'", varFlag)
447-
}
448-
449-
key := strings.TrimSpace(parts[0])
450-
value := parts[1] // Don't trim value to preserve intentional whitespace
451-
452-
if key == "" {
453-
return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag)
454-
}
455-
456-
// Check for duplicate keys
457-
if _, exists := templateVars[key]; exists {
458-
return nil, fmt.Errorf("duplicate variable key '%s'", key)
459-
}
460-
461-
templateVars[key] = value
462-
}
463-
464-
return templateVars, nil
465-
}
466-
467430
type runCommandHandler struct {
468431
ctx context.Context
469432
cfg *command.Config

cmd/run/run_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,14 +470,19 @@ func TestParseTemplateVariables(t *testing.T) {
470470
varFlags: []string{"name=John", "name=Jane"},
471471
expectErr: true,
472472
},
473+
{
474+
name: "input variable is allowed in run command",
475+
varFlags: []string{"input=test value"},
476+
expected: map[string]string{"input": "test value"},
477+
},
473478
}
474479

475480
for _, tt := range tests {
476481
t.Run(tt.name, func(t *testing.T) {
477482
flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
478483
flags.StringSlice("var", tt.varFlags, "test flag")
479484

480-
result, err := parseTemplateVariables(flags)
485+
result, err := util.ParseTemplateVariables(flags)
481486

482487
if tt.expectErr {
483488
require.Error(t, err)

pkg/util/util.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ package util
44
import (
55
"fmt"
66
"io"
7+
"strings"
8+
9+
"github.com/spf13/pflag"
710
)
811

912
// WriteToOut writes a message to the given io.Writer.
@@ -18,3 +21,40 @@ func WriteToOut(out io.Writer, message string) {
1821
func Ptr[T any](value T) *T {
1922
return &value
2023
}
24+
25+
// ParseTemplateVariables parses template variables from the --var flags
26+
func ParseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) {
27+
varFlags, err := flags.GetStringSlice("var")
28+
if err != nil {
29+
return nil, err
30+
}
31+
32+
templateVars := make(map[string]string)
33+
for _, varFlag := range varFlags {
34+
// Handle empty strings
35+
if strings.TrimSpace(varFlag) == "" {
36+
continue
37+
}
38+
39+
parts := strings.SplitN(varFlag, "=", 2)
40+
if len(parts) != 2 {
41+
return nil, fmt.Errorf("invalid variable format '%s', expected 'key=value'", varFlag)
42+
}
43+
44+
key := strings.TrimSpace(parts[0])
45+
value := parts[1] // Don't trim value to preserve intentional whitespace
46+
47+
if key == "" {
48+
return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag)
49+
}
50+
51+
// Check for duplicate keys
52+
if _, exists := templateVars[key]; exists {
53+
return nil, fmt.Errorf("duplicate variable key '%s'", key)
54+
}
55+
56+
templateVars[key] = value
57+
}
58+
59+
return templateVars, nil
60+
}

0 commit comments

Comments
 (0)