Skip to content

Commit 45c198f

Browse files
authored
Merge branch 'main' into dapr-state-store-clickhouse
2 parents 51f4274 + 3816706 commit 45c198f

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

common/component/kafka/auth_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,20 @@ import (
1818
"crypto/rsa"
1919
"crypto/x509"
2020
"crypto/x509/pkix"
21+
"encoding/base64"
22+
"encoding/json"
2123
"encoding/pem"
2224
"fmt"
2325
"math/big"
26+
"net/http"
27+
"net/http/httptest"
28+
"strings"
2429
"testing"
2530
"time"
2631

2732
"github.com/IBM/sarama"
2833
"github.com/stretchr/testify/require"
34+
"golang.org/x/oauth2"
2935
)
3036

3137
func getAuthBaseMetadata() map[string]string {
@@ -122,4 +128,53 @@ func TestAuth(t *testing.T) {
122128
require.False(t, mockConfig.Net.TLS.Enable)
123129
require.Nil(t, mockConfig.Net.TLS.Config)
124130
})
131+
132+
t.Run("oidc private key jwt uses flattened audience", func(t *testing.T) {
133+
key, err := rsa.GenerateKey(rand.Reader, 2048)
134+
require.NoError(t, err)
135+
keyPEM := pem.EncodeToMemory(&pem.Block{
136+
Type: "RSA PRIVATE KEY",
137+
Bytes: x509.MarshalPKCS1PrivateKey(key),
138+
})
139+
certPEM, _, err := createTestCert()
140+
require.NoError(t, err)
141+
142+
var receivedAssertion string
143+
144+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
145+
require.NoError(t, r.ParseForm())
146+
receivedAssertion = r.FormValue("client_assertion")
147+
148+
w.Header().Set("Content-Type", "application/json")
149+
json.NewEncoder(w).Encode(map[string]interface{}{
150+
"access_token": "test-token",
151+
"expires_in": 3600,
152+
})
153+
}))
154+
defer server.Close()
155+
156+
ts := &OAuthTokenSourcePrivateKeyJWT{
157+
TokenEndpoint: oauth2.Endpoint{TokenURL: server.URL},
158+
ClientID: "test-client",
159+
ClientAssertionCert: string(certPEM),
160+
ClientAssertionKey: string(keyPEM),
161+
}
162+
163+
_, err = ts.Token()
164+
require.NoError(t, err)
165+
require.NotEmpty(t, receivedAssertion)
166+
167+
parts := strings.Split(receivedAssertion, ".")
168+
require.Len(t, parts, 3, "JWT should have 3 parts")
169+
170+
decodedPayload, err := base64.RawURLEncoding.DecodeString(parts[1])
171+
require.NoError(t, err)
172+
173+
var rawClaims map[string]interface{}
174+
err = json.Unmarshal(decodedPayload, &rawClaims)
175+
require.NoError(t, err)
176+
177+
audValue := rawClaims["aud"]
178+
require.IsType(t, "", audValue)
179+
})
125180
}

common/component/kafka/sasl_oauthbearer_private_key_jwt.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,9 @@ func (ts *OAuthTokenSourcePrivateKeyJWT) Token() (*sarama.AccessToken, error) {
169169
return nil, fmt.Errorf("failed to build token: %w", err)
170170
}
171171

172+
// Some IdPs require the audience to be set as a single string
173+
token.Options().Enable(jwt.FlattenAudience)
174+
172175
var signOptions []jwt.Option
173176
if ts.Kid != "" {
174177
headers := jws.NewHeaders()

0 commit comments

Comments
 (0)