@@ -18,14 +18,20 @@ import (
1818 "crypto/rsa"
1919 "crypto/x509"
2020 "crypto/x509/pkix"
21+ "encoding/base64"
22+ "encoding/json"
2123 "encoding/pem"
2224 "fmt"
2325 "math/big"
26+ "net/http"
27+ "net/http/httptest"
28+ "strings"
2429 "testing"
2530 "time"
2631
2732 "github.com/IBM/sarama"
2833 "github.com/stretchr/testify/require"
34+ "golang.org/x/oauth2"
2935)
3036
3137func getAuthBaseMetadata () map [string ]string {
@@ -122,4 +128,53 @@ func TestAuth(t *testing.T) {
122128 require .False (t , mockConfig .Net .TLS .Enable )
123129 require .Nil (t , mockConfig .Net .TLS .Config )
124130 })
131+
132+ t .Run ("oidc private key jwt uses flattened audience" , func (t * testing.T ) {
133+ key , err := rsa .GenerateKey (rand .Reader , 2048 )
134+ require .NoError (t , err )
135+ keyPEM := pem .EncodeToMemory (& pem.Block {
136+ Type : "RSA PRIVATE KEY" ,
137+ Bytes : x509 .MarshalPKCS1PrivateKey (key ),
138+ })
139+ certPEM , _ , err := createTestCert ()
140+ require .NoError (t , err )
141+
142+ var receivedAssertion string
143+
144+ server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
145+ require .NoError (t , r .ParseForm ())
146+ receivedAssertion = r .FormValue ("client_assertion" )
147+
148+ w .Header ().Set ("Content-Type" , "application/json" )
149+ json .NewEncoder (w ).Encode (map [string ]interface {}{
150+ "access_token" : "test-token" ,
151+ "expires_in" : 3600 ,
152+ })
153+ }))
154+ defer server .Close ()
155+
156+ ts := & OAuthTokenSourcePrivateKeyJWT {
157+ TokenEndpoint : oauth2.Endpoint {TokenURL : server .URL },
158+ ClientID : "test-client" ,
159+ ClientAssertionCert : string (certPEM ),
160+ ClientAssertionKey : string (keyPEM ),
161+ }
162+
163+ _ , err = ts .Token ()
164+ require .NoError (t , err )
165+ require .NotEmpty (t , receivedAssertion )
166+
167+ parts := strings .Split (receivedAssertion , "." )
168+ require .Len (t , parts , 3 , "JWT should have 3 parts" )
169+
170+ decodedPayload , err := base64 .RawURLEncoding .DecodeString (parts [1 ])
171+ require .NoError (t , err )
172+
173+ var rawClaims map [string ]interface {}
174+ err = json .Unmarshal (decodedPayload , & rawClaims )
175+ require .NoError (t , err )
176+
177+ audValue := rawClaims ["aud" ]
178+ require .IsType (t , "" , audValue )
179+ })
125180}
0 commit comments