Skip to content

Commit 3bc7b92

Browse files
authored
Use the GitHub models catalog endpoint for listing models (#72)
2 parents 8050279 + f560a7f commit 3bc7b92

15 files changed

+125
-134
lines changed

.vscode/launch.json

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
11
{
2-
// Use IntelliSense to learn about possible attributes.
3-
// Hover to view descriptions of existing attributes.
4-
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5-
"version": "0.2.0",
6-
"configurations": [
7-
{
8-
"name": "Run models list",
9-
"type": "go",
10-
"request": "launch",
11-
"mode": "auto",
12-
"program": "${workspaceFolder}/main.go",
13-
"args": ["list"]
14-
}
15-
]
16-
}
2+
// Use IntelliSense to learn about possible attributes.
3+
// Hover to view descriptions of existing attributes.
4+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5+
"version": "0.2.0",
6+
"configurations": [
7+
{
8+
"name": "Run models list",
9+
"type": "go",
10+
"request": "launch",
11+
"mode": "auto",
12+
"program": "${workspaceFolder}/main.go",
13+
"args": ["list"]
14+
},
15+
{
16+
"name": "Run models view",
17+
"type": "go",
18+
"request": "launch",
19+
"mode": "auto",
20+
"program": "${workspaceFolder}/main.go",
21+
"args": ["view"],
22+
"console": "integratedTerminal"
23+
}
24+
]
25+
}

cmd/list/list.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func NewListCommand(cfg *command.Config) *cobra.Command {
5353
printer.EndRow()
5454

5555
for _, model := range models {
56-
printer.AddField(azuremodels.FormatIdentifier(model.Publisher, model.Name))
56+
printer.AddField(model.ID)
5757
printer.AddField(model.FriendlyName)
5858
printer.EndRow()
5959
}

cmd/list/list_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@ func TestList(t *testing.T) {
1414
t.Run("NewListCommand happy path", func(t *testing.T) {
1515
client := azuremodels.NewMockClient()
1616
modelSummary := &azuremodels.ModelSummary{
17-
ID: "test-id-1",
17+
ID: "openai/test-id-1",
1818
Name: "test-model-1",
1919
FriendlyName: "Test Model 1",
2020
Task: "chat-completion",
2121
Publisher: "OpenAI",
2222
Summary: "This is a test model",
2323
Version: "1.0",
24-
RegistryName: "azure-openai",
2524
}
2625
listModelsCallCount := 0
2726
client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
@@ -41,7 +40,7 @@ func TestList(t *testing.T) {
4140
require.Contains(t, output, "DISPLAY NAME")
4241
require.Contains(t, output, "ID")
4342
require.Contains(t, output, modelSummary.FriendlyName)
44-
require.Contains(t, output, azuremodels.FormatIdentifier(modelSummary.Publisher, modelSummary.Name))
43+
require.Contains(t, output, modelSummary.ID)
4544
})
4645

4746
t.Run("--help prints usage info", func(t *testing.T) {

cmd/run/run.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,8 @@ func (h *runCommandHandler) getModelNameFromArgs(models []*azuremodels.ModelSumm
500500
if !model.IsChatModel() {
501501
continue
502502
}
503-
prompt.Options = append(prompt.Options, azuremodels.FormatIdentifier(model.Publisher, model.Name))
503+
504+
prompt.Options = append(prompt.Options, model.ID)
504505
}
505506

506507
err := survey.AskOne(prompt, &modelName, survey.WithPageSize(10))
@@ -533,7 +534,7 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st
533534
}
534535

535536
// For non-custom providers, validate the model exists
536-
expectedModelID := azuremodels.FormatIdentifier(parsedModel.Publisher, parsedModel.ModelName)
537+
expectedModelID := parsedModel.String()
537538
foundMatch := false
538539
for _, model := range models {
539540
if model.HasName(expectedModelID) {

cmd/run/run_test.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@ func TestRun(t *testing.T) {
1919
t.Run("NewRunCommand happy path", func(t *testing.T) {
2020
client := azuremodels.NewMockClient()
2121
modelSummary := &azuremodels.ModelSummary{
22-
ID: "test-id-1",
22+
ID: "openai/test-model-1",
2323
Name: "test-model-1",
2424
FriendlyName: "Test Model 1",
2525
Task: "chat-completion",
2626
Publisher: "OpenAI",
2727
Summary: "This is a test model",
2828
Version: "1.0",
29-
RegistryName: "azure-openai",
3029
}
3130
listModelsCallCount := 0
3231
client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
@@ -52,7 +51,7 @@ func TestRun(t *testing.T) {
5251
buf := new(bytes.Buffer)
5352
cfg := command.NewConfig(buf, buf, client, true, 80)
5453
runCmd := NewRunCommand(cfg)
55-
runCmd.SetArgs([]string{azuremodels.FormatIdentifier(modelSummary.Publisher, modelSummary.Name), "this is my prompt"})
54+
runCmd.SetArgs([]string{modelSummary.ID, "this is my prompt"})
5655

5756
_, err := runCmd.ExecuteC()
5857

@@ -104,6 +103,7 @@ messages:
104103

105104
client := azuremodels.NewMockClient()
106105
modelSummary := &azuremodels.ModelSummary{
106+
ID: "openai/test-model",
107107
Name: "test-model",
108108
Publisher: "openai",
109109
Task: "chat-completion",
@@ -134,7 +134,7 @@ messages:
134134
runCmd := NewRunCommand(cfg)
135135
runCmd.SetArgs([]string{
136136
"--file", tmp.Name(),
137-
azuremodels.FormatIdentifier("openai", "test-model"),
137+
"openai/test-model",
138138
})
139139

140140
_, err = runCmd.ExecuteC()
@@ -170,6 +170,7 @@ messages:
170170

171171
client := azuremodels.NewMockClient()
172172
modelSummary := &azuremodels.ModelSummary{
173+
ID: "openai/test-model",
173174
Name: "test-model",
174175
Publisher: "openai",
175176
Task: "chat-completion",
@@ -214,7 +215,7 @@ messages:
214215
runCmd := NewRunCommand(cfg)
215216
runCmd.SetArgs([]string{
216217
"--file", tmp.Name(),
217-
azuremodels.FormatIdentifier("openai", "test-model"),
218+
"openai/test-model",
218219
initialPrompt,
219220
})
220221

@@ -252,11 +253,13 @@ messages:
252253

253254
client := azuremodels.NewMockClient()
254255
modelSummary := &azuremodels.ModelSummary{
256+
ID: "openai/example-model",
255257
Name: "example-model",
256258
Publisher: "openai",
257259
Task: "chat-completion",
258260
}
259261
modelSummary2 := &azuremodels.ModelSummary{
262+
ID: "openai/example-model-4o-mini-plus",
260263
Name: "example-model-4o-mini-plus",
261264
Publisher: "openai",
262265
Task: "chat-completion",
@@ -369,6 +372,7 @@ messages:
369372

370373
client := azuremodels.NewMockClient()
371374
modelSummary := &azuremodels.ModelSummary{
375+
ID: "openai/test-model",
372376
Name: "test-model",
373377
Publisher: "openai",
374378
Task: "chat-completion",
@@ -533,6 +537,7 @@ func TestValidateModelName(t *testing.T) {
533537

534538
// Create a mock model for testing
535539
mockModel := &azuremodels.ModelSummary{
540+
ID: "openai/gpt-4",
536541
Name: "gpt-4",
537542
Publisher: "openai",
538543
Task: "chat-completion",

cmd/view/view.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func NewViewCommand(cfg *command.Config) *cobra.Command {
5050
if !model.IsChatModel() {
5151
continue
5252
}
53-
prompt.Options = append(prompt.Options, azuremodels.FormatIdentifier(model.Publisher, model.Name))
53+
prompt.Options = append(prompt.Options, model.ID)
5454
}
5555

5656
err = survey.AskOne(prompt, &modelName, survey.WithPageSize(10))
@@ -61,13 +61,12 @@ func NewViewCommand(cfg *command.Config) *cobra.Command {
6161
case len(args) >= 1:
6262
modelName = args[0]
6363
}
64-
6564
modelSummary, err := getModelByName(modelName, models)
6665
if err != nil {
6766
return err
6867
}
6968

70-
modelDetails, err := client.GetModelDetails(ctx, modelSummary.RegistryName, modelSummary.Name, modelSummary.Version)
69+
modelDetails, err := client.GetModelDetails(ctx, modelSummary.Registry, modelSummary.Name, modelSummary.Version)
7170
if err != nil {
7271
return err
7372
}

cmd/view/view_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@ func TestView(t *testing.T) {
1414
t.Run("NewViewCommand happy path", func(t *testing.T) {
1515
client := azuremodels.NewMockClient()
1616
modelSummary := &azuremodels.ModelSummary{
17-
ID: "test-id-1",
17+
ID: "openai/test-model-1",
1818
Name: "test-model-1",
1919
FriendlyName: "Test Model 1",
2020
Task: "chat-completion",
2121
Publisher: "OpenAI",
2222
Summary: "This is a test model",
2323
Version: "1.0",
24-
RegistryName: "azure-openai",
2524
}
2625
listModelsCallCount := 0
2726
client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
@@ -49,7 +48,7 @@ func TestView(t *testing.T) {
4948
buf := new(bytes.Buffer)
5049
cfg := command.NewConfig(buf, buf, client, true, 80)
5150
viewCmd := NewViewCommand(cfg)
52-
viewCmd.SetArgs([]string{azuremodels.FormatIdentifier(modelSummary.Publisher, modelSummary.Name)})
51+
viewCmd.SetArgs([]string{modelSummary.ID})
5352

5453
_, err := viewCmd.ExecuteC()
5554

internal/azuremodels/azure_client.go

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@ import (
99
"fmt"
1010
"io"
1111
"net/http"
12+
"slices"
1213
"strings"
1314

1415
"github.com/cli/go-gh/v2/pkg/api"
16+
"github.com/github/gh-models/internal/modelkey"
1517
"github.com/github/gh-models/internal/sse"
1618
"golang.org/x/text/language"
1719
"golang.org/x/text/language/display"
@@ -185,19 +187,7 @@ func lowercaseStrings(input []string) []string {
185187

186188
// ListModels returns a list of available models.
187189
func (c *AzureClient) ListModels(ctx context.Context) ([]*ModelSummary, error) {
188-
body := bytes.NewReader([]byte(`
189-
{
190-
"filters": [
191-
{ "field": "freePlayground", "values": ["true"], "operator": "eq"},
192-
{ "field": "labels", "values": ["latest"], "operator": "eq"}
193-
],
194-
"order": [
195-
{ "field": "displayName", "direction": "asc" }
196-
]
197-
}
198-
`))
199-
200-
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.ModelsURL, body)
190+
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, c.cfg.ModelsURL, nil)
201191
if err != nil {
202192
return nil, err
203193
}
@@ -218,28 +208,34 @@ func (c *AzureClient) ListModels(ctx context.Context) ([]*ModelSummary, error) {
218208
decoder := json.NewDecoder(resp.Body)
219209
decoder.UseNumber()
220210

221-
var searchResponse modelCatalogSearchResponse
222-
err = decoder.Decode(&searchResponse)
211+
var catalog githubModelCatalogResponse
212+
err = decoder.Decode(&catalog)
223213
if err != nil {
224214
return nil, err
225215
}
226216

227-
models := make([]*ModelSummary, 0, len(searchResponse.Summaries))
228-
for _, summary := range searchResponse.Summaries {
217+
models := make([]*ModelSummary, 0, len(catalog))
218+
for _, catalogModel := range catalog {
219+
// Determine task from supported modalities - if it supports text input/output, it's likely a chat model
229220
inferenceTask := ""
230-
if len(summary.InferenceTasks) > 0 {
231-
inferenceTask = summary.InferenceTasks[0]
221+
if slices.Contains(catalogModel.SupportedInputModalities, "text") && slices.Contains(catalogModel.SupportedOutputModalities, "text") {
222+
inferenceTask = "chat-completion"
223+
}
224+
225+
modelKey, err := modelkey.ParseModelKey(catalogModel.ID)
226+
if err != nil {
227+
return nil, fmt.Errorf("parsing model key %q: %w", catalogModel.ID, err)
232228
}
233229

234230
models = append(models, &ModelSummary{
235-
ID: summary.AssetID,
236-
Name: summary.Name,
237-
FriendlyName: summary.DisplayName,
231+
ID: catalogModel.ID,
232+
Name: modelKey.ModelName,
233+
Registry: catalogModel.Registry,
234+
FriendlyName: catalogModel.Name,
238235
Task: inferenceTask,
239-
Publisher: summary.Publisher,
240-
Summary: summary.Summary,
241-
Version: summary.Version,
242-
RegistryName: summary.RegistryName,
236+
Publisher: catalogModel.Publisher,
237+
Summary: catalogModel.Summary,
238+
Version: catalogModel.Version,
243239
})
244240
}
245241

internal/azuremodels/azure_client_config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ const (
44
defaultInferenceRoot = "https://models.github.ai"
55
defaultInferencePath = "inference/chat/completions"
66
defaultAzureAiStudioURL = "https://api.catalog.azureml.ms"
7-
defaultModelsURL = defaultAzureAiStudioURL + "/asset-gallery/v1.0/models"
7+
defaultModelsURL = "https://models.github.ai/catalog/models"
88
)
99

1010
// AzureClientConfig represents configurable settings for the Azure client.

0 commit comments

Comments
 (0)