Skip to content

Commit d84a1c2

Browse files
committed
Add model key
1 parent a787251 commit d84a1c2

File tree

2 files changed

+160
-0
lines changed

2 files changed

+160
-0
lines changed

internal/modelkey/modelkey.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package modelkey
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
)
7+
8+
type ModelKey struct {
9+
Provider string
10+
Publisher string
11+
ModelName string
12+
}
13+
14+
func ParseModelKey(modelKey string) (*ModelKey, error) {
15+
if modelKey == "" {
16+
return nil, fmt.Errorf("invalid model key format: %s", modelKey)
17+
}
18+
19+
parts := strings.Split(modelKey, "/")
20+
21+
// Check for empty parts
22+
for _, part := range parts {
23+
if part == "" {
24+
return nil, fmt.Errorf("invalid model key format: %s", modelKey)
25+
}
26+
}
27+
28+
switch len(parts) {
29+
case 2:
30+
// Format: publisher/model-name (provider defaults to "azureml")
31+
return &ModelKey{
32+
Provider: "azureml",
33+
Publisher: parts[0],
34+
ModelName: parts[1],
35+
}, nil
36+
case 3:
37+
// Format: provider/publisher/model-name
38+
return &ModelKey{
39+
Provider: parts[0],
40+
Publisher: parts[1],
41+
ModelName: parts[2],
42+
}, nil
43+
default:
44+
return nil, fmt.Errorf("invalid model key format: %s", modelKey)
45+
}
46+
}

internal/modelkey/modelkey_test.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package modelkey
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
func TestParseModelKey(t *testing.T) {
10+
tests := []struct {
11+
name string
12+
input string
13+
expected *ModelKey
14+
expectError bool
15+
}{
16+
{
17+
name: "valid format with provider",
18+
input: "custom/openai/gpt-4",
19+
expected: &ModelKey{
20+
Provider: "custom",
21+
Publisher: "openai",
22+
ModelName: "gpt-4",
23+
},
24+
expectError: false,
25+
},
26+
{
27+
name: "valid format without provider (defaults to azureml)",
28+
input: "openai/gpt-4",
29+
expected: &ModelKey{
30+
Provider: "azureml",
31+
Publisher: "openai",
32+
ModelName: "gpt-4",
33+
},
34+
expectError: false,
35+
},
36+
{
37+
name: "valid format with azureml provider explicitly",
38+
input: "azureml/microsoft/phi-3",
39+
expected: &ModelKey{
40+
Provider: "azureml",
41+
Publisher: "microsoft",
42+
ModelName: "phi-3",
43+
},
44+
expectError: false,
45+
},
46+
{
47+
name: "valid format with hyphens in model name",
48+
input: "cohere/command-r-plus",
49+
expected: &ModelKey{
50+
Provider: "azureml",
51+
Publisher: "cohere",
52+
ModelName: "command-r-plus",
53+
},
54+
expectError: false,
55+
},
56+
{
57+
name: "valid format with underscores in model name",
58+
input: "ai21/jamba_instruct",
59+
expected: &ModelKey{
60+
Provider: "azureml",
61+
Publisher: "ai21",
62+
ModelName: "jamba_instruct",
63+
},
64+
expectError: false,
65+
},
66+
{
67+
name: "invalid format with only one part",
68+
input: "gpt-4",
69+
expected: nil,
70+
expectError: true,
71+
},
72+
{
73+
name: "invalid format with four parts",
74+
input: "provider/publisher/model/extra",
75+
expected: nil,
76+
expectError: true,
77+
},
78+
{
79+
name: "invalid format with empty string",
80+
input: "",
81+
expected: nil,
82+
expectError: true,
83+
},
84+
{
85+
name: "invalid format with only slashes",
86+
input: "//",
87+
expected: nil,
88+
expectError: true,
89+
},
90+
{
91+
name: "invalid format with empty parts",
92+
input: "provider//model",
93+
expected: nil,
94+
expectError: true,
95+
},
96+
}
97+
98+
for _, tt := range tests {
99+
t.Run(tt.name, func(t *testing.T) {
100+
result, err := ParseModelKey(tt.input)
101+
102+
if tt.expectError {
103+
require.Error(t, err)
104+
require.Nil(t, result)
105+
} else {
106+
require.NoError(t, err)
107+
require.NotNil(t, result)
108+
require.Equal(t, tt.expected.Provider, result.Provider)
109+
require.Equal(t, tt.expected.Publisher, result.Publisher)
110+
require.Equal(t, tt.expected.ModelName, result.ModelName)
111+
}
112+
})
113+
}
114+
}

0 commit comments

Comments
 (0)