From c7593e3d5250ee34fc1ed49146c23e5c65f59a6f Mon Sep 17 00:00:00 2001 From: Sean Goedecke Date: Wed, 11 Jun 2025 22:55:08 +0000 Subject: [PATCH] Support multiple variables in run --- cmd/run/run.go | 72 +++++++++++++++++++++++--- cmd/run/run_test.go | 72 ++++++++++++++++++++++++++ examples/advanced_template_prompt.yml | 27 ++++++++++ examples/template_variables_prompt.yml | 12 +++++ 4 files changed, 175 insertions(+), 8 deletions(-) create mode 100644 examples/advanced_template_prompt.yml create mode 100644 examples/template_variables_prompt.yml diff --git a/cmd/run/run.go b/cmd/run/run.go index 989017b9..418f4da7 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -204,10 +204,16 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { If you know which model you want to run inference with, you can run the request in a single command as %[1]sgh models run [model] [prompt]%[1]s + When using prompt files, you can pass template variables using the %[1]s--var%[1]s flag: + %[1]sgh models run --file prompt.yml --var name=Alice --var topic=AI%[1]s + The return value will be the response to your prompt from the selected model. `, "`"), - Example: "gh models run openai/gpt-4o-mini \"how many types of hyena are there?\"", - Args: cobra.ArbitraryArgs, + Example: heredoc.Doc(` + gh models run openai/gpt-4o-mini "how many types of hyena are there?" + gh models run --file prompt.yml --var name=Alice --var topic="machine learning" + `), + Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { filePath, _ := cmd.Flags().GetString("file") var pf *prompt.File @@ -223,6 +229,12 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { } } + // Parse template variables from flags + templateVars, err := parseTemplateVariables(cmd.Flags()) + if err != nil { + return err + } + cmdHandler := newRunCommandHandler(cmd, cfg, args) if cmdHandler == nil { return nil @@ -270,16 +282,22 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { } // If there is no prompt file, add the initialPrompt to the conversation. - // If a prompt file is passed, load the messages from the file, templating {{input}} - // using the initialPrompt. + // If a prompt file is passed, load the messages from the file, templating variables + // using the provided template variables and initialPrompt. if pf == nil { conversation.AddMessage(azuremodels.ChatMessageRoleUser, initialPrompt) } else { interactiveMode = false - // Template the messages with the input - templateData := map[string]interface{}{ - "input": initialPrompt, + // Template the messages with the variables + templateData := make(map[string]interface{}) + + // Add the input variable (backward compatibility) + templateData["input"] = initialPrompt + + // Add custom variables + for key, value := range templateVars { + templateData[key] = value } for _, m := range pf.Messages { @@ -385,6 +403,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { } cmd.Flags().String("file", "", "Path to a .prompt.yml file.") + cmd.Flags().StringSlice("var", []string{}, "Template variables for prompt files (can be used multiple times: --var name=value)") cmd.Flags().String("max-tokens", "", "Limit the maximum tokens for the model response.") cmd.Flags().String("temperature", "", "Controls randomness in the response, use lower to be more deterministic.") cmd.Flags().String("top-p", "", "Controls text diversity by selecting the most probable words until a set probability is reached.") @@ -393,6 +412,43 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { return cmd } +// parseTemplateVariables parses template variables from the --var flags +func parseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) { + varFlags, err := flags.GetStringSlice("var") + if err != nil { + return nil, err + } + + templateVars := make(map[string]string) + for _, varFlag := range varFlags { + // Handle empty strings + if strings.TrimSpace(varFlag) == "" { + continue + } + + parts := strings.SplitN(varFlag, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid variable format '%s', expected 'key=value'", varFlag) + } + + key := strings.TrimSpace(parts[0]) + value := parts[1] // Don't trim value to preserve intentional whitespace + + if key == "" { + return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag) + } + + // Check for duplicate keys + if _, exists := templateVars[key]; exists { + return nil, fmt.Errorf("duplicate variable key '%s'", key) + } + + templateVars[key] = value + } + + return templateVars, nil +} + type runCommandHandler struct { ctx context.Context cfg *command.Config @@ -445,7 +501,7 @@ func (h *runCommandHandler) getModelNameFromArgs(models []*azuremodels.ModelSumm } func validateModelName(modelName string, models []*azuremodels.ModelSummary) (string, error) { - noMatchErrorMessage := "The specified model name is not found. Run 'gh models list' to see available models or 'gh models run' to select interactively." + noMatchErrorMessage := fmt.Sprintf("The specified model '%s' is not found. Run 'gh models list' to see available models or 'gh models run' to select interactively.", modelName) if modelName == "" { return "", errors.New(noMatchErrorMessage) diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go index 7395e7cd..c0a5a48b 100644 --- a/cmd/run/run_test.go +++ b/cmd/run/run_test.go @@ -11,6 +11,7 @@ import ( "github.com/github/gh-models/internal/sse" "github.com/github/gh-models/pkg/command" "github.com/github/gh-models/pkg/util" + "github.com/spf13/pflag" "github.com/stretchr/testify/require" ) @@ -331,3 +332,74 @@ messages: require.Equal(t, "User message", *capturedReq.Messages[1].Content) }) } + +func TestParseTemplateVariables(t *testing.T) { + tests := []struct { + name string + varFlags []string + expected map[string]string + expectErr bool + }{ + { + name: "empty vars", + varFlags: []string{}, + expected: map[string]string{}, + }, + { + name: "single var", + varFlags: []string{"name=John"}, + expected: map[string]string{"name": "John"}, + }, + { + name: "multiple vars", + varFlags: []string{"name=John", "age=25", "city=New York"}, + expected: map[string]string{"name": "John", "age": "25", "city": "New York"}, + }, + { + name: "multi-word values", + varFlags: []string{"full_name=John Smith", "description=A senior developer"}, + expected: map[string]string{"full_name": "John Smith", "description": "A senior developer"}, + }, + { + name: "value with equals sign", + varFlags: []string{"equation=x = y + 2"}, + expected: map[string]string{"equation": "x = y + 2"}, + }, + { + name: "empty strings are skipped", + varFlags: []string{"", "name=John", " "}, + expected: map[string]string{"name": "John"}, + }, + { + name: "invalid format - no equals", + varFlags: []string{"invalid"}, + expectErr: true, + }, + { + name: "invalid format - empty key", + varFlags: []string{"=value"}, + expectErr: true, + }, + { + name: "duplicate keys", + varFlags: []string{"name=John", "name=Jane"}, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + flags := pflag.NewFlagSet("test", pflag.ContinueOnError) + flags.StringSlice("var", tt.varFlags, "test flag") + + result, err := parseTemplateVariables(flags) + + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, result) + } + }) + } +} diff --git a/examples/advanced_template_prompt.yml b/examples/advanced_template_prompt.yml new file mode 100644 index 00000000..2dd971eb --- /dev/null +++ b/examples/advanced_template_prompt.yml @@ -0,0 +1,27 @@ +# Advanced Template Variables Example +name: Advanced Template Example +description: Demonstrates advanced usage of template variables +model: openai/gpt-4o-mini +modelParameters: + temperature: 0.7 + maxTokens: 300 +messages: + - role: system + content: | + You are {{assistant_persona}}, a {{expertise_level}} {{domain}} specialist. + Your communication style should be {{tone}} and {{formality_level}}. + + Context: You are helping {{user_name}} who works as a {{user_role}} at {{company}}. + + - role: user + content: | + Hello! I'm {{user_name}} from {{company}}. + + Background: {{background_info}} + + Question: {{input}} + + Please provide your response considering my role as {{user_role}} and + make it appropriate for a {{formality_level}} setting. + + Additional context: {{additional_context}} diff --git a/examples/template_variables_prompt.yml b/examples/template_variables_prompt.yml new file mode 100644 index 00000000..bd0d6c2a --- /dev/null +++ b/examples/template_variables_prompt.yml @@ -0,0 +1,12 @@ +# Example demonstrating arbitrary template variables +name: Template Variables Example +description: Shows how to use custom template variables in prompt files +model: openai/gpt-4o +modelParameters: + temperature: 0.3 + maxTokens: 200 +messages: + - role: system + content: You are {{persona}}, a helpful assistant specializing in {{domain}}. + - role: user + content: Hello {{name}}! I need help with {{topic}}. {{input}}