From 55978945eb50549300d8f21994faf67afd96f308 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 30 May 2025 15:45:29 +0000 Subject: [PATCH 1/9] Bump github.com/cli/go-gh/v2 in the go_modules group across 1 directory Bumps the go_modules group with 1 update in the / directory: [github.com/cli/go-gh/v2](https://github.com/cli/go-gh). Updates `github.com/cli/go-gh/v2` from 2.11.2 to 2.12.1 - [Release notes](https://github.com/cli/go-gh/releases) - [Commits](https://github.com/cli/go-gh/compare/v2.11.2...v2.12.1) --- updated-dependencies: - dependency-name: github.com/cli/go-gh/v2 dependency-version: 2.12.1 dependency-type: direct:production dependency-group: go_modules ... Signed-off-by: dependabot[bot] --- go.mod | 22 +++++++++++++--------- go.sum | 55 +++++++++++++++++++++++++++++++++++-------------------- 2 files changed, 48 insertions(+), 29 deletions(-) diff --git a/go.mod b/go.mod index 56dae7eb..f4058ea0 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/MakeNowJust/heredoc v1.0.0 github.com/briandowns/spinner v1.23.1 github.com/cli/cli/v2 v2.67.0 - github.com/cli/go-gh/v2 v2.11.2 + github.com/cli/go-gh/v2 v2.12.1 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 @@ -22,9 +22,12 @@ require ( github.com/alecthomas/chroma/v2 v2.14.0 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect - github.com/charmbracelet/glamour v0.8.0 // indirect - github.com/charmbracelet/lipgloss v0.12.1 // indirect - github.com/charmbracelet/x/ansi v0.1.4 // indirect + github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect + github.com/charmbracelet/glamour v0.9.2-0.20250319212134-549f544650e3 // indirect + github.com/charmbracelet/lipgloss v1.1.1-0.20250319133953-166f707985bc // indirect + github.com/charmbracelet/x/ansi v0.8.0 // indirect + github.com/charmbracelet/x/cellbuf v0.0.13 // indirect + github.com/charmbracelet/x/term v0.2.1 // indirect github.com/cli/safeexec v1.0.1 // indirect github.com/cli/shurcooL-graphql v0.0.4 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect @@ -34,19 +37,20 @@ require ( github.com/henvic/httpretty v0.1.4 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect - github.com/kr/text v0.2.0 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect github.com/microcosm-cc/bluemonday v1.0.27 // indirect github.com/muesli/reflow v0.3.0 // indirect - github.com/muesli/termenv v0.15.3-0.20240618155329-98d742f6907a // indirect + github.com/muesli/termenv v0.16.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e // indirect - github.com/yuin/goldmark v1.7.4 // indirect - github.com/yuin/goldmark-emoji v1.0.3 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + github.com/yuin/goldmark v1.7.8 // indirect + github.com/yuin/goldmark-emoji v1.0.5 // indirect golang.org/x/net v0.38.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/term v0.30.0 // indirect diff --git a/go.sum b/go.sum index 47e61b9c..baa469a4 100644 --- a/go.sum +++ b/go.sum @@ -18,18 +18,24 @@ github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuP github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/briandowns/spinner v1.23.1 h1:t5fDPmScwUjozhDj4FA46p5acZWIPXYE30qW2Ptu650= github.com/briandowns/spinner v1.23.1/go.mod h1:LaZeM4wm2Ywy6vO571mvhQNRcWfRUnXOs0RcKV0wYKM= -github.com/charmbracelet/glamour v0.8.0 h1:tPrjL3aRcQbn++7t18wOpgLyl8wrOHUEDS7IZ68QtZs= -github.com/charmbracelet/glamour v0.8.0/go.mod h1:ViRgmKkf3u5S7uakt2czJ272WSg2ZenlYEZXT2x7Bjw= -github.com/charmbracelet/lipgloss v0.12.1 h1:/gmzszl+pedQpjCOH+wFkZr/N90Snz40J/NR7A0zQcs= -github.com/charmbracelet/lipgloss v0.12.1/go.mod h1:V2CiwIuhx9S1S1ZlADfOj9HmxeMAORuz5izHb0zGbB8= -github.com/charmbracelet/x/ansi v0.1.4 h1:IEU3D6+dWwPSgZ6HBH+v6oUuZ/nVawMiWj5831KfiLM= -github.com/charmbracelet/x/ansi v0.1.4/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw= -github.com/charmbracelet/x/exp/golden v0.0.0-20240715153702-9ba8adf781c4 h1:6KzMkQeAF56rggw2NZu1L+TH7j9+DM1/2Kmh7KUxg1I= -github.com/charmbracelet/x/exp/golden v0.0.0-20240715153702-9ba8adf781c4/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= +github.com/charmbracelet/glamour v0.9.2-0.20250319212134-549f544650e3 h1:hx6E25SvI2WiZdt/gxINcYBnHD7PE2Vr9auqwg5B05g= +github.com/charmbracelet/glamour v0.9.2-0.20250319212134-549f544650e3/go.mod h1:ihVqv4/YOY5Fweu1cxajuQrwJFh3zU4Ukb4mHVNjq3s= +github.com/charmbracelet/lipgloss v1.1.1-0.20250319133953-166f707985bc h1:nFRtCfZu/zkltd2lsLUPlVNv3ej/Atod9hcdbRZtlys= +github.com/charmbracelet/lipgloss v1.1.1-0.20250319133953-166f707985bc/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA= +github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE= +github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q= +github.com/charmbracelet/x/cellbuf v0.0.13 h1:/KBBKHuVRbq1lYx5BzEHBAFBP8VcQzJejZ/IA3iR28k= +github.com/charmbracelet/x/cellbuf v0.0.13/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= +github.com/charmbracelet/x/exp/golden v0.0.0-20240806155701-69247e0abc2a h1:G99klV19u0QnhiizODirwVksQB91TJKV/UaTnACcG30= +github.com/charmbracelet/x/exp/golden v0.0.0-20240806155701-69247e0abc2a/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= +github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= +github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= github.com/cli/cli/v2 v2.67.0 h1:uV40wKPbtHPJH8coGSKZDqxw9fNeqlWqPwE7pdefQFI= github.com/cli/cli/v2 v2.67.0/go.mod h1:6VPo4p7DcIiFfJtn5iBPwAjNcfmI0zlZKwVtM7EtIig= -github.com/cli/go-gh/v2 v2.11.2 h1:oad1+sESTPNTiTvh3I3t8UmxuovNDxhwLzeMHk45Q9w= -github.com/cli/go-gh/v2 v2.11.2/go.mod h1:vVFhi3TfjseIW26ED9itAR8gQK0aVThTm8sYrsZ5QTI= +github.com/cli/go-gh/v2 v2.12.1 h1:SVt1/afj5FRAythyMV3WJKaUfDNsxXTIe7arZbwTWKA= +github.com/cli/go-gh/v2 v2.12.1/go.mod h1:+5aXmEOJsH9fc9mBHfincDwnS02j2AIA/DsTH0Bk5uw= github.com/cli/safeexec v1.0.1 h1:e/C79PbXF4yYTN/wauC4tviMxEV13BwljGj0N9j+N00= github.com/cli/safeexec v1.0.1/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q= github.com/cli/shurcooL-graphql v0.0.4 h1:6MogPnQJLjKkaXPyGqPRXOI2qCsQdqNfUY1QSJu2GuY= @@ -61,10 +67,12 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leaanthony/go-ansi-parser v1.6.1 h1:xd8bzARK3dErqkPFtoF9F3/HgN8UQk0ed1YDKpEz01A= +github.com/leaanthony/go-ansi-parser v1.6.1/go.mod h1:+vva/2y4alzVmmIEpk9QDhA7vLC5zKDTRwfZGOp3IWU= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= @@ -74,8 +82,8 @@ github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= @@ -83,8 +91,9 @@ github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwX github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= -github.com/muesli/termenv v0.15.3-0.20240618155329-98d742f6907a h1:2MaM6YC3mGu54x+RKAA6JiFFHlHDY1UbkxqppT7wYOg= -github.com/muesli/termenv v0.15.3-0.20240618155329-98d742f6907a/go.mod h1:hxSnBBYLK21Vtq/PHd0S2FYCxBXzBua8ov5s1RobyRQ= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -92,6 +101,8 @@ github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= @@ -103,14 +114,18 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e h1:BuzhfgfWQbX0dWzYzT1zsORLnHRv3bcRcsaUk0VmXA8= github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e/go.mod h1:/Tnicc6m/lsJE0irFMA0LfIwTBo4QP7A8IfyIv4zZKI= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= -github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -github.com/yuin/goldmark-emoji v1.0.3 h1:aLRkLHOuBR2czCY4R8olwMjID+tENfhyFDMCRhbIQY4= -github.com/yuin/goldmark-emoji v1.0.3/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U= +github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= +github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC4cIk= +github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw= +golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= From 16482d9413e91494b928f4e1c915bd6a3a23325b Mon Sep 17 00:00:00 2001 From: Daniel Garman Date: Thu, 12 Jun 2025 02:54:45 +0000 Subject: [PATCH 2/9] add --org flag to run and eval --- cmd/eval/eval.go | 16 +++++++++++++--- cmd/run/run.go | 8 +++++--- internal/azuremodels/azure_client.go | 11 +++++++++-- internal/azuremodels/azure_client_config.go | 9 ++++++--- internal/azuremodels/client.go | 2 +- internal/azuremodels/mock_client.go | 8 ++++---- internal/azuremodels/types.go | 13 +++++++------ internal/azuremodels/unauthenticated_client.go | 2 +- 8 files changed, 46 insertions(+), 23 deletions(-) diff --git a/cmd/eval/eval.go b/cmd/eval/eval.go index 149fad26..3baba905 100644 --- a/cmd/eval/eval.go +++ b/cmd/eval/eval.go @@ -48,6 +48,10 @@ type EvaluationResult struct { Details string `json:"details,omitempty"` } +type Organization struct { + Name string `json:"name"` +} + var FailedTests = errors.New("❌ Some tests failed.") // NewEvalCommand returns a new command to evaluate prompts against models @@ -66,7 +70,7 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command { Example prompt.yml structure: name: My Evaluation - model: gpt-4o + model: openai/gpt-4o testData: - input: "Hello world" expected: "Hello there" @@ -94,6 +98,9 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command { return err } + // Get the org flag + org, _ := cmd.Flags().GetString("org") + // Load the evaluation prompt file evalFile, err := loadEvaluationPromptFile(promptFilePath) if err != nil { @@ -106,6 +113,7 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command { client: cfg.Client, evalFile: evalFile, jsonOutput: jsonOutput, + org: org, } err = handler.runEvaluation(cmd.Context()) @@ -120,6 +128,7 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command { } cmd.Flags().Bool("json", false, "Output results in JSON format") + cmd.Flags().String("org", "", "Organization to attribute usage to (omitting will attribute usage to the current actor") return cmd } @@ -128,6 +137,7 @@ type evalCommandHandler struct { client azuremodels.Client evalFile *prompt.File jsonOutput bool + org string } func loadEvaluationPromptFile(filePath string) (*prompt.File, error) { @@ -321,7 +331,7 @@ func (h *evalCommandHandler) templateString(templateStr string, data map[string] func (h *evalCommandHandler) callModel(ctx context.Context, messages []azuremodels.ChatMessage) (string, error) { req := h.evalFile.BuildChatCompletionOptions(messages) - resp, err := h.client.GetChatCompletionStream(ctx, req) + resp, err := h.client.GetChatCompletionStream(ctx, req, h.org) if err != nil { return "", err } @@ -460,7 +470,7 @@ func (h *evalCommandHandler) runLLMEvaluator(ctx context.Context, name string, e Stream: false, } - resp, err := h.client.GetChatCompletionStream(ctx, req) + resp, err := h.client.GetChatCompletionStream(ctx, req, h.org) if err != nil { return EvaluationResult{}, fmt.Errorf("failed to call evaluation model: %w", err) } diff --git a/cmd/run/run.go b/cmd/run/run.go index 418f4da7..5f87da7a 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -216,6 +216,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { filePath, _ := cmd.Flags().GetString("file") + org, _ := cmd.Flags().GetString("org") var pf *prompt.File if filePath != "" { var err error @@ -357,7 +358,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { //nolint:gocritic,revive // TODO defer sp.Stop() - reader, err := cmdHandler.getChatCompletionStreamReader(req) + reader, err := cmdHandler.getChatCompletionStreamReader(req, org) if err != nil { return err } @@ -408,6 +409,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { cmd.Flags().String("temperature", "", "Controls randomness in the response, use lower to be more deterministic.") cmd.Flags().String("top-p", "", "Controls text diversity by selecting the most probable words until a set probability is reached.") cmd.Flags().String("system-prompt", "", "Prompt the system.") + cmd.Flags().String("org", "", "Organization to attribute usage to (omitting will attribute usage to the current actor") return cmd } @@ -522,8 +524,8 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st return modelName, nil } -func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions) (sse.Reader[azuremodels.ChatCompletion], error) { - resp, err := h.client.GetChatCompletionStream(h.ctx, req) +func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions, org string) (sse.Reader[azuremodels.ChatCompletion], error) { + resp, err := h.client.GetChatCompletionStream(h.ctx, req, org) if err != nil { return nil, err } diff --git a/internal/azuremodels/azure_client.go b/internal/azuremodels/azure_client.go index a4a0c98b..bf747134 100644 --- a/internal/azuremodels/azure_client.go +++ b/internal/azuremodels/azure_client.go @@ -40,7 +40,7 @@ func NewAzureClient(httpClient *http.Client, authToken string, cfg *AzureClientC } // GetChatCompletionStream returns a stream of chat completions using the given options. -func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions) (*ChatCompletionResponse, error) { +func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions, org string) (*ChatCompletionResponse, error) { // Check for o1 models, which don't support streaming if req.Model == "o1-mini" || req.Model == "o1-preview" || req.Model == "o1" { req.Stream = false @@ -55,7 +55,14 @@ func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompl body := bytes.NewReader(bodyBytes) - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.InferenceURL, body) + var inferenceURL string + if org != "" { + inferenceURL = fmt.Sprintf("%s/orgs/%s/%s", c.cfg.InferenceRoot, org, c.cfg.InferencePath) + } else { + inferenceURL = c.cfg.InferenceRoot + "/" + c.cfg.InferencePath + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, inferenceURL, body) if err != nil { return nil, err } diff --git a/internal/azuremodels/azure_client_config.go b/internal/azuremodels/azure_client_config.go index 58433e83..da8eae04 100644 --- a/internal/azuremodels/azure_client_config.go +++ b/internal/azuremodels/azure_client_config.go @@ -1,14 +1,16 @@ package azuremodels const ( - defaultInferenceURL = "https://models.github.ai/inference/chat/completions" + defaultInferenceRoot = "https://models.github.ai" + defaultInferencePath = "inference/chat/completions" defaultAzureAiStudioURL = "https://api.catalog.azureml.ms" defaultModelsURL = defaultAzureAiStudioURL + "/asset-gallery/v1.0/models" ) // AzureClientConfig represents configurable settings for the Azure client. type AzureClientConfig struct { - InferenceURL string + InferenceRoot string + InferencePath string AzureAiStudioURL string ModelsURL string } @@ -16,7 +18,8 @@ type AzureClientConfig struct { // NewDefaultAzureClientConfig returns a new AzureClientConfig with default values for API URLs. func NewDefaultAzureClientConfig() *AzureClientConfig { return &AzureClientConfig{ - InferenceURL: defaultInferenceURL, + InferenceRoot: defaultInferenceRoot, + InferencePath: defaultInferencePath, AzureAiStudioURL: defaultAzureAiStudioURL, ModelsURL: defaultModelsURL, } diff --git a/internal/azuremodels/client.go b/internal/azuremodels/client.go index 9681decd..a3f68ca3 100644 --- a/internal/azuremodels/client.go +++ b/internal/azuremodels/client.go @@ -5,7 +5,7 @@ import "context" // Client represents a client for interacting with an API about models. type Client interface { // GetChatCompletionStream returns a stream of chat completions using the given options. - GetChatCompletionStream(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) + GetChatCompletionStream(context.Context, ChatCompletionOptions, string) (*ChatCompletionResponse, error) // GetModelDetails returns the details of the specified model in a particular registry. GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) // ListModels returns a list of available models. diff --git a/internal/azuremodels/mock_client.go b/internal/azuremodels/mock_client.go index c15cfb6d..a926b297 100644 --- a/internal/azuremodels/mock_client.go +++ b/internal/azuremodels/mock_client.go @@ -7,7 +7,7 @@ import ( // MockClient provides a client for interacting with the Azure models API in tests. type MockClient struct { - MockGetChatCompletionStream func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) + MockGetChatCompletionStream func(context.Context, ChatCompletionOptions, string) (*ChatCompletionResponse, error) MockGetModelDetails func(context.Context, string, string, string) (*ModelDetails, error) MockListModels func(context.Context) ([]*ModelSummary, error) } @@ -15,7 +15,7 @@ type MockClient struct { // NewMockClient returns a new mock client for stubbing out interactions with the models API. func NewMockClient() *MockClient { return &MockClient{ - MockGetChatCompletionStream: func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) { + MockGetChatCompletionStream: func(context.Context, ChatCompletionOptions, string) (*ChatCompletionResponse, error) { return nil, errors.New("GetChatCompletionStream not implemented") }, MockGetModelDetails: func(context.Context, string, string, string) (*ModelDetails, error) { @@ -28,8 +28,8 @@ func NewMockClient() *MockClient { } // GetChatCompletionStream calls the mocked function for getting a stream of chat completions for the given request. -func (c *MockClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions) (*ChatCompletionResponse, error) { - return c.MockGetChatCompletionStream(ctx, opt) +func (c *MockClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions, org string) (*ChatCompletionResponse, error) { + return c.MockGetChatCompletionStream(ctx, opt, org) } // GetModelDetails calls the mocked function for getting the details of the specified model in a particular registry. diff --git a/internal/azuremodels/types.go b/internal/azuremodels/types.go index 29d4a7d1..ab7b43a2 100644 --- a/internal/azuremodels/types.go +++ b/internal/azuremodels/types.go @@ -26,12 +26,13 @@ type ChatMessage struct { // ChatCompletionOptions represents available options for a chat completion request. type ChatCompletionOptions struct { - MaxTokens *int `json:"max_tokens,omitempty"` - Messages []ChatMessage `json:"messages"` - Model string `json:"model"` - Stream bool `json:"stream,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"top_p,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + Messages []ChatMessage `json:"messages"` + Model string `json:"model"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Organization *string `json:"organization,omitempty"` } // ChatChoiceMessage is a message from a choice in a chat conversation. diff --git a/internal/azuremodels/unauthenticated_client.go b/internal/azuremodels/unauthenticated_client.go index 2f35aa89..e755f0a8 100644 --- a/internal/azuremodels/unauthenticated_client.go +++ b/internal/azuremodels/unauthenticated_client.go @@ -15,7 +15,7 @@ func NewUnauthenticatedClient() *UnauthenticatedClient { } // GetChatCompletionStream returns an error because this functionality requires authentication. -func (c *UnauthenticatedClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions) (*ChatCompletionResponse, error) { +func (c *UnauthenticatedClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions, org string) (*ChatCompletionResponse, error) { return nil, errors.New("not authenticated") } From b12c4dffb4e2ce71fbeb4fea96fd999f7caad4ac Mon Sep 17 00:00:00 2001 From: Daniel Garman Date: Fri, 13 Jun 2025 01:09:46 +0000 Subject: [PATCH 3/9] update existing tests --- cmd/eval/eval_test.go | 12 ++++++------ cmd/run/run_test.go | 8 ++++---- internal/azuremodels/azure_client_test.go | 12 ++++++------ 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/cmd/eval/eval_test.go b/cmd/eval/eval_test.go index 123dcc2b..78b67439 100644 --- a/cmd/eval/eval_test.go +++ b/cmd/eval/eval_test.go @@ -162,7 +162,7 @@ evaluators: cfg := command.NewConfig(out, out, client, true, 100) // Mock a response that returns "4" for the LLM evaluator - client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) { reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{ { Choices: []azuremodels.ChatChoice{ @@ -228,7 +228,7 @@ evaluators: client := azuremodels.NewMockClient() // Mock a simple response - client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) { // Create a mock reader that returns "test response" reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{ { @@ -284,7 +284,7 @@ evaluators: client := azuremodels.NewMockClient() // Mock a response that will fail the evaluator - client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) { reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{ { Choices: []azuremodels.ChatChoice{ @@ -346,7 +346,7 @@ evaluators: // Mock responses for both test cases callCount := 0 - client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) { callCount++ var response string if callCount == 1 { @@ -444,7 +444,7 @@ evaluators: require.NoError(t, err) client := azuremodels.NewMockClient() - client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) { response := "hello world" reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{ { @@ -526,7 +526,7 @@ evaluators: require.NoError(t, err) client := azuremodels.NewMockClient() - client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) { response := "hello world" reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{ { diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go index c0a5a48b..43ef6a1c 100644 --- a/cmd/run/run_test.go +++ b/cmd/run/run_test.go @@ -45,7 +45,7 @@ func TestRun(t *testing.T) { Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}), } getChatCompletionCallCount := 0 - client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) { getChatCompletionCallCount++ return chatResp, nil } @@ -122,7 +122,7 @@ messages: }, }}, } - client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) { capturedReq = opt return &azuremodels.ChatCompletionResponse{ Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}), @@ -188,7 +188,7 @@ messages: }, }}, } - client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) { capturedReq = opt return &azuremodels.ChatCompletionResponse{ Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}), @@ -278,7 +278,7 @@ messages: }}, } - client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) { capturedReq = opt return &azuremodels.ChatCompletionResponse{ Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}), diff --git a/internal/azuremodels/azure_client_test.go b/internal/azuremodels/azure_client_test.go index 17002da7..8d84e302 100644 --- a/internal/azuremodels/azure_client_test.go +++ b/internal/azuremodels/azure_client_test.go @@ -49,7 +49,7 @@ func TestAzureClient(t *testing.T) { require.NoError(t, err) })) defer testServer.Close() - cfg := &AzureClientConfig{InferenceURL: testServer.URL} + cfg := &AzureClientConfig{InferenceRoot: testServer.URL} httpClient := testServer.Client() client := NewAzureClient(httpClient, authToken, cfg) opts := ChatCompletionOptions{ @@ -63,7 +63,7 @@ func TestAzureClient(t *testing.T) { }, } - chatCompletionStreamResp, err := client.GetChatCompletionStream(ctx, opts) + chatCompletionStreamResp, err := client.GetChatCompletionStream(ctx, opts, "") require.NoError(t, err) require.NotNil(t, chatCompletionStreamResp) @@ -125,7 +125,7 @@ func TestAzureClient(t *testing.T) { require.NoError(t, err) })) defer testServer.Close() - cfg := &AzureClientConfig{InferenceURL: testServer.URL} + cfg := &AzureClientConfig{InferenceRoot: testServer.URL} httpClient := testServer.Client() client := NewAzureClient(httpClient, authToken, cfg) opts := ChatCompletionOptions{ @@ -139,7 +139,7 @@ func TestAzureClient(t *testing.T) { }, } - chatCompletionStreamResp, err := client.GetChatCompletionStream(ctx, opts) + chatCompletionStreamResp, err := client.GetChatCompletionStream(ctx, opts, "") require.NoError(t, err) require.NotNil(t, chatCompletionStreamResp) @@ -173,7 +173,7 @@ func TestAzureClient(t *testing.T) { require.NoError(t, err) })) defer testServer.Close() - cfg := &AzureClientConfig{InferenceURL: testServer.URL} + cfg := &AzureClientConfig{InferenceRoot: testServer.URL} httpClient := testServer.Client() client := NewAzureClient(httpClient, "fake-token-123abc", cfg) opts := ChatCompletionOptions{ @@ -181,7 +181,7 @@ func TestAzureClient(t *testing.T) { Messages: []ChatMessage{{Role: "user", Content: util.Ptr("Tell me a story, test model.")}}, } - chatCompletionResp, err := client.GetChatCompletionStream(ctx, opts) + chatCompletionResp, err := client.GetChatCompletionStream(ctx, opts, "") require.Error(t, err) require.Nil(t, chatCompletionResp) From 68ce62d1eff0a5fb10cd056649b0c18e516a51af Mon Sep 17 00:00:00 2001 From: Christopher Schleiden Date: Tue, 17 Jun 2025 11:15:00 +0200 Subject: [PATCH 4/9] Add examples --- cmd/eval/eval.go | 7 +++++-- cmd/run/run.go | 4 ++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/cmd/eval/eval.go b/cmd/eval/eval.go index 3baba905..02877466 100644 --- a/cmd/eval/eval.go +++ b/cmd/eval/eval.go @@ -87,8 +87,11 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command { See https://docs.github.com/github-models/use-github-models/storing-prompts-in-github-repositories#supported-file-format for more information. `), - Example: "gh models eval my_prompt.prompt.yml", - Args: cobra.ExactArgs(1), + Example: heredoc.Doc(` + gh models eval my_prompt.prompt.yml + gh models eval --org my-org my_prompt.prompt.yml + `), + Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { promptFilePath := args[0] diff --git a/cmd/run/run.go b/cmd/run/run.go index 5f87da7a..e380de5b 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -207,10 +207,14 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { When using prompt files, you can pass template variables using the %[1]s--var%[1]s flag: %[1]sgh models run --file prompt.yml --var name=Alice --var topic=AI%[1]s + When running inference against an organization, pass the organization name using the %[1]s--org%[1]s flag: + %[1]sgh models run --org my-org openai/gpt-4o-mini "What is AI?"%[1]s + The return value will be the response to your prompt from the selected model. `, "`"), Example: heredoc.Doc(` gh models run openai/gpt-4o-mini "how many types of hyena are there?" + gh models run --org my-org openai/gpt-4o-mini "how many types of hyena are there?" gh models run --file prompt.yml --var name=Alice --var topic="machine learning" `), Args: cobra.ArbitraryArgs, From a78725160c51ddc5e20ee88e67d6b62e7580c6e1 Mon Sep 17 00:00:00 2001 From: Christopher Schleiden Date: Tue, 17 Jun 2025 11:23:28 +0200 Subject: [PATCH 5/9] Remove unused struct & field --- cmd/eval/eval.go | 4 ---- internal/azuremodels/types.go | 13 ++++++------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/cmd/eval/eval.go b/cmd/eval/eval.go index 02877466..5a6b39c2 100644 --- a/cmd/eval/eval.go +++ b/cmd/eval/eval.go @@ -48,10 +48,6 @@ type EvaluationResult struct { Details string `json:"details,omitempty"` } -type Organization struct { - Name string `json:"name"` -} - var FailedTests = errors.New("❌ Some tests failed.") // NewEvalCommand returns a new command to evaluate prompts against models diff --git a/internal/azuremodels/types.go b/internal/azuremodels/types.go index ab7b43a2..29d4a7d1 100644 --- a/internal/azuremodels/types.go +++ b/internal/azuremodels/types.go @@ -26,13 +26,12 @@ type ChatMessage struct { // ChatCompletionOptions represents available options for a chat completion request. type ChatCompletionOptions struct { - MaxTokens *int `json:"max_tokens,omitempty"` - Messages []ChatMessage `json:"messages"` - Model string `json:"model"` - Stream bool `json:"stream,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - Organization *string `json:"organization,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + Messages []ChatMessage `json:"messages"` + Model string `json:"model"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` } // ChatChoiceMessage is a message from a choice in a chat conversation. From d84a1c22cc11d9c3b6d33400aaca5d504478daeb Mon Sep 17 00:00:00 2001 From: Christopher Schleiden Date: Tue, 17 Jun 2025 12:15:33 +0200 Subject: [PATCH 6/9] Add model key --- internal/modelkey/modelkey.go | 46 ++++++++++++ internal/modelkey/modelkey_test.go | 114 +++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+) create mode 100644 internal/modelkey/modelkey.go create mode 100644 internal/modelkey/modelkey_test.go diff --git a/internal/modelkey/modelkey.go b/internal/modelkey/modelkey.go new file mode 100644 index 00000000..9cec0eac --- /dev/null +++ b/internal/modelkey/modelkey.go @@ -0,0 +1,46 @@ +package modelkey + +import ( + "fmt" + "strings" +) + +type ModelKey struct { + Provider string + Publisher string + ModelName string +} + +func ParseModelKey(modelKey string) (*ModelKey, error) { + if modelKey == "" { + return nil, fmt.Errorf("invalid model key format: %s", modelKey) + } + + parts := strings.Split(modelKey, "/") + + // Check for empty parts + for _, part := range parts { + if part == "" { + return nil, fmt.Errorf("invalid model key format: %s", modelKey) + } + } + + switch len(parts) { + case 2: + // Format: publisher/model-name (provider defaults to "azureml") + return &ModelKey{ + Provider: "azureml", + Publisher: parts[0], + ModelName: parts[1], + }, nil + case 3: + // Format: provider/publisher/model-name + return &ModelKey{ + Provider: parts[0], + Publisher: parts[1], + ModelName: parts[2], + }, nil + default: + return nil, fmt.Errorf("invalid model key format: %s", modelKey) + } +} diff --git a/internal/modelkey/modelkey_test.go b/internal/modelkey/modelkey_test.go new file mode 100644 index 00000000..561447c7 --- /dev/null +++ b/internal/modelkey/modelkey_test.go @@ -0,0 +1,114 @@ +package modelkey + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseModelKey(t *testing.T) { + tests := []struct { + name string + input string + expected *ModelKey + expectError bool + }{ + { + name: "valid format with provider", + input: "custom/openai/gpt-4", + expected: &ModelKey{ + Provider: "custom", + Publisher: "openai", + ModelName: "gpt-4", + }, + expectError: false, + }, + { + name: "valid format without provider (defaults to azureml)", + input: "openai/gpt-4", + expected: &ModelKey{ + Provider: "azureml", + Publisher: "openai", + ModelName: "gpt-4", + }, + expectError: false, + }, + { + name: "valid format with azureml provider explicitly", + input: "azureml/microsoft/phi-3", + expected: &ModelKey{ + Provider: "azureml", + Publisher: "microsoft", + ModelName: "phi-3", + }, + expectError: false, + }, + { + name: "valid format with hyphens in model name", + input: "cohere/command-r-plus", + expected: &ModelKey{ + Provider: "azureml", + Publisher: "cohere", + ModelName: "command-r-plus", + }, + expectError: false, + }, + { + name: "valid format with underscores in model name", + input: "ai21/jamba_instruct", + expected: &ModelKey{ + Provider: "azureml", + Publisher: "ai21", + ModelName: "jamba_instruct", + }, + expectError: false, + }, + { + name: "invalid format with only one part", + input: "gpt-4", + expected: nil, + expectError: true, + }, + { + name: "invalid format with four parts", + input: "provider/publisher/model/extra", + expected: nil, + expectError: true, + }, + { + name: "invalid format with empty string", + input: "", + expected: nil, + expectError: true, + }, + { + name: "invalid format with only slashes", + input: "//", + expected: nil, + expectError: true, + }, + { + name: "invalid format with empty parts", + input: "provider//model", + expected: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ParseModelKey(tt.input) + + if tt.expectError { + require.Error(t, err) + require.Nil(t, result) + } else { + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, tt.expected.Provider, result.Provider) + require.Equal(t, tt.expected.Publisher, result.Publisher) + require.Equal(t, tt.expected.ModelName, result.ModelName) + } + }) + } +} From 9a0e37bbf94ccf20770ef1726ce102cabbc26386 Mon Sep 17 00:00:00 2001 From: Christopher Schleiden Date: Tue, 17 Jun 2025 12:18:40 +0200 Subject: [PATCH 7/9] Convert model key to string --- internal/modelkey/modelkey.go | 5 +++ internal/modelkey/modelkey_test.go | 61 ++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/internal/modelkey/modelkey.go b/internal/modelkey/modelkey.go index 9cec0eac..e58990a7 100644 --- a/internal/modelkey/modelkey.go +++ b/internal/modelkey/modelkey.go @@ -44,3 +44,8 @@ func ParseModelKey(modelKey string) (*ModelKey, error) { return nil, fmt.Errorf("invalid model key format: %s", modelKey) } } + +// String returns the string representation of the ModelKey in the format provider/publisher/model-name +func (mk *ModelKey) String() string { + return fmt.Sprintf("%s/%s/%s", mk.Provider, mk.Publisher, mk.ModelName) +} diff --git a/internal/modelkey/modelkey_test.go b/internal/modelkey/modelkey_test.go index 561447c7..ea4583fa 100644 --- a/internal/modelkey/modelkey_test.go +++ b/internal/modelkey/modelkey_test.go @@ -112,3 +112,64 @@ func TestParseModelKey(t *testing.T) { }) } } + +func TestModelKey_String(t *testing.T) { + tests := []struct { + name string + modelKey *ModelKey + expected string + }{ + { + name: "standard format with azureml provider", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "openai", + ModelName: "gpt-4", + }, + expected: "azureml/openai/gpt-4", + }, + { + name: "custom provider", + modelKey: &ModelKey{ + Provider: "custom", + Publisher: "microsoft", + ModelName: "phi-3", + }, + expected: "custom/microsoft/phi-3", + }, + { + name: "model name with hyphens", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "cohere", + ModelName: "command-r-plus", + }, + expected: "azureml/cohere/command-r-plus", + }, + { + name: "model name with underscores", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "ai21", + ModelName: "jamba_instruct", + }, + expected: "azureml/ai21/jamba_instruct", + }, + { + name: "long provider name", + modelKey: &ModelKey{ + Provider: "custom-provider", + Publisher: "test-publisher", + ModelName: "test-model", + }, + expected: "custom-provider/test-publisher/test-model", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.modelKey.String() + require.Equal(t, tt.expected, result) + }) + } +} From d9164f069b51283349db20299128cac32f0c2198 Mon Sep 17 00:00:00 2001 From: Christopher Schleiden Date: Tue, 17 Jun 2025 12:43:16 +0200 Subject: [PATCH 8/9] Do not validate models for the custom provider --- cmd/run/run.go | 17 +++++++++++++-- cmd/run/run_test.go | 53 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/cmd/run/run.go b/cmd/run/run.go index e380de5b..1fe574b2 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -16,6 +16,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/briandowns/spinner" "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/internal/modelkey" "github.com/github/gh-models/internal/sse" "github.com/github/gh-models/pkg/command" "github.com/github/gh-models/pkg/prompt" @@ -513,9 +514,21 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st return "", errors.New(noMatchErrorMessage) } + parsedModel, err := modelkey.ParseModelKey(modelName) + if err != nil { + return "", fmt.Errorf("invalid model format: %w", err) + } + + if parsedModel.Provider == "custom" { + // Skip validation for custom provider + return parsedModel.String(), nil + } + + // For non-custom providers, validate the model exists + expectedModelID := azuremodels.FormatIdentifier(parsedModel.Publisher, parsedModel.ModelName) foundMatch := false for _, model := range models { - if model.HasName(modelName) { + if model.HasName(expectedModelID) { foundMatch = true break } @@ -525,7 +538,7 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st return "", errors.New(noMatchErrorMessage) } - return modelName, nil + return expectedModelID, nil } func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions, org string) (sse.Reader[azuremodels.ChatCompletion], error) { diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go index 43ef6a1c..eb10649c 100644 --- a/cmd/run/run_test.go +++ b/cmd/run/run_test.go @@ -403,3 +403,56 @@ func TestParseTemplateVariables(t *testing.T) { }) } } + +func TestValidateModelName(t *testing.T) { + tests := []struct { + name string + modelName string + expectedModel string + expectError bool + }{ + { + name: "custom provider skips validation", + modelName: "custom/mycompany/custom-model", + expectedModel: "custom/mycompany/custom-model", + expectError: false, + }, + { + name: "azureml provider requires validation", + modelName: "openai/gpt-4", + expectedModel: "openai/gpt-4", + expectError: false, + }, + { + name: "invalid model format", + modelName: "invalid-format", + expectError: true, + }, + { + name: "nonexistent azureml model", + modelName: "nonexistent/model", + expectError: true, + }, + } + + // Create a mock model for testing + mockModel := &azuremodels.ModelSummary{ + Name: "gpt-4", + Publisher: "openai", + Task: "chat-completion", + } + models := []*azuremodels.ModelSummary{mockModel} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := validateModelName(tt.modelName, models) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedModel, result) + } + }) + } +} From bd103368cece54e331e91a08b3d9accb0b3d5a8a Mon Sep 17 00:00:00 2001 From: Christopher Schleiden Date: Tue, 17 Jun 2025 14:38:47 +0200 Subject: [PATCH 9/9] Refactor model key formatting to use centralized function and update tests for azureml provider behavior --- internal/azuremodels/model_details.go | 12 ++------ internal/modelkey/modelkey.go | 29 ++++++++++++++++-- internal/modelkey/modelkey_test.go | 43 ++++++++++++++++++++++----- 3 files changed, 65 insertions(+), 19 deletions(-) diff --git a/internal/azuremodels/model_details.go b/internal/azuremodels/model_details.go index ecd135ac..53289cf0 100644 --- a/internal/azuremodels/model_details.go +++ b/internal/azuremodels/model_details.go @@ -2,7 +2,8 @@ package azuremodels import ( "fmt" - "strings" + + "github.com/github/gh-models/internal/modelkey" ) // ModelDetails includes detailed information about a model. @@ -28,12 +29,5 @@ func (m *ModelDetails) ContextLimits() string { // FormatIdentifier formats the model identifier based on the publisher and model name. func FormatIdentifier(publisher, name string) string { - formatPart := func(s string) string { - // Replace spaces with dashes and convert to lowercase - result := strings.ToLower(s) - result = strings.ReplaceAll(result, " ", "-") - return result - } - - return fmt.Sprintf("%s/%s", formatPart(publisher), formatPart(name)) + return modelkey.FormatIdentifier("azureml", publisher, name) } diff --git a/internal/modelkey/modelkey.go b/internal/modelkey/modelkey.go index e58990a7..bd18562d 100644 --- a/internal/modelkey/modelkey.go +++ b/internal/modelkey/modelkey.go @@ -45,7 +45,32 @@ func ParseModelKey(modelKey string) (*ModelKey, error) { } } -// String returns the string representation of the ModelKey in the format provider/publisher/model-name +// String returns the string representation of the ModelKey. func (mk *ModelKey) String() string { - return fmt.Sprintf("%s/%s/%s", mk.Provider, mk.Publisher, mk.ModelName) + provider := formatPart(mk.Provider) + publisher := formatPart(mk.Publisher) + modelName := formatPart(mk.ModelName) + + if provider == "azureml" { + return fmt.Sprintf("%s/%s", publisher, modelName) + } + + return fmt.Sprintf("%s/%s/%s", provider, publisher, modelName) +} + +func formatPart(s string) string { + s = strings.ToLower(s) + s = strings.ReplaceAll(s, " ", "-") + + return s +} + +func FormatIdentifier(provider, publisher, name string) string { + mk := &ModelKey{ + Provider: provider, + Publisher: publisher, + ModelName: name, + } + + return mk.String() } diff --git a/internal/modelkey/modelkey_test.go b/internal/modelkey/modelkey_test.go index ea4583fa..f4d13410 100644 --- a/internal/modelkey/modelkey_test.go +++ b/internal/modelkey/modelkey_test.go @@ -120,16 +120,16 @@ func TestModelKey_String(t *testing.T) { expected string }{ { - name: "standard format with azureml provider", + name: "standard format with azureml provider - should omit provider", modelKey: &ModelKey{ Provider: "azureml", Publisher: "openai", ModelName: "gpt-4", }, - expected: "azureml/openai/gpt-4", + expected: "openai/gpt-4", }, { - name: "custom provider", + name: "custom provider - should include provider", modelKey: &ModelKey{ Provider: "custom", Publisher: "microsoft", @@ -138,25 +138,25 @@ func TestModelKey_String(t *testing.T) { expected: "custom/microsoft/phi-3", }, { - name: "model name with hyphens", + name: "azureml provider with hyphens - should omit provider", modelKey: &ModelKey{ Provider: "azureml", Publisher: "cohere", ModelName: "command-r-plus", }, - expected: "azureml/cohere/command-r-plus", + expected: "cohere/command-r-plus", }, { - name: "model name with underscores", + name: "azureml provider with underscores - should omit provider", modelKey: &ModelKey{ Provider: "azureml", Publisher: "ai21", ModelName: "jamba_instruct", }, - expected: "azureml/ai21/jamba_instruct", + expected: "ai21/jamba_instruct", }, { - name: "long provider name", + name: "non-azureml provider - should include provider", modelKey: &ModelKey{ Provider: "custom-provider", Publisher: "test-publisher", @@ -164,6 +164,33 @@ func TestModelKey_String(t *testing.T) { }, expected: "custom-provider/test-publisher/test-model", }, + { + name: "azureml provider with uppercase and spaces - should format and omit provider", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "Open AI", + ModelName: "GPT 4", + }, + expected: "open-ai/gpt-4", + }, + { + name: "non-azureml provider with uppercase and spaces - should format and include provider", + modelKey: &ModelKey{ + Provider: "Custom Provider", + Publisher: "Test Publisher", + ModelName: "Test Model Name", + }, + expected: "custom-provider/test-publisher/test-model-name", + }, + { + name: "mixed case with multiple spaces", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "Microsoft Corporation", + ModelName: "Phi 3 Mini Instruct", + }, + expected: "microsoft-corporation/phi-3-mini-instruct", + }, } for _, tt := range tests {