diff --git a/auth/auth.go b/auth/auth.go index 36ff259e..40fa259f 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -8,6 +8,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" "slices" "strings" @@ -73,8 +74,17 @@ func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) tokenInfo, errmsg, code := verify(r, verifier, opts) if code != 0 { if code == http.StatusUnauthorized || code == http.StatusForbidden { - if opts != nil && opts.ResourceMetadataURL != "" { - w.Header().Add("WWW-Authenticate", "Bearer resource_metadata="+opts.ResourceMetadataURL) + if opts != nil { + var params []string + if opts.ResourceMetadataURL != "" { + params = append(params, fmt.Sprintf("resource_metadata=%q", opts.ResourceMetadataURL)) + } + if len(opts.Scopes) > 0 { + params = append(params, fmt.Sprintf("scope=%q", strings.Join(opts.Scopes, " "))) + } + if len(params) > 0 { + w.Header().Add("WWW-Authenticate", "Bearer "+strings.Join(params, ", ")) + } } } http.Error(w, errmsg, code) diff --git a/auth/auth_test.go b/auth/auth_test.go index a943404c..4028c907 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -188,3 +188,99 @@ func TestProtectedResourceMetadataHandler(t *testing.T) { }) } } + +func TestRequireBearerToken(t *testing.T) { + verifier := func(_ context.Context, token string, _ *http.Request) (*TokenInfo, error) { + if token == "valid" { + return &TokenInfo{Expiration: time.Now().Add(time.Hour), Scopes: []string{"read"}}, nil + } + return nil, ErrInvalidToken + } + + tests := []struct { + name string + opts *RequireBearerTokenOptions + authHeader string + wantHeader string + wantStatus int + }{ + { + name: "no middleware options", + opts: nil, + authHeader: "Bearer invalid", + wantHeader: "", + wantStatus: http.StatusUnauthorized, + }, + { + name: "metadata only", + opts: &RequireBearerTokenOptions{ + ResourceMetadataURL: "https://example.com/resource-metadata", + }, + authHeader: "Bearer invalid", + wantHeader: "Bearer resource_metadata=\"https://example.com/resource-metadata\"", + wantStatus: http.StatusUnauthorized, + }, + { + name: "scopes only", + opts: &RequireBearerTokenOptions{ + Scopes: []string{"read", "write"}, + }, + authHeader: "Bearer invalid", + wantHeader: "Bearer scope=\"read write\"", + wantStatus: http.StatusUnauthorized, + }, + { + name: "metadata and scopes", + opts: &RequireBearerTokenOptions{ + ResourceMetadataURL: "https://example.com/resource-metadata", + Scopes: []string{"read", "write"}, + }, + authHeader: "Bearer invalid", + wantHeader: "Bearer resource_metadata=\"https://example.com/resource-metadata\", scope=\"read write\"", + wantStatus: http.StatusUnauthorized, + }, + { + name: "insufficient scope", + opts: &RequireBearerTokenOptions{ + Scopes: []string{"admin"}, + }, + authHeader: "Bearer valid", // Has "read", needs "admin" -> 403 + wantHeader: "Bearer scope=\"admin\"", + wantStatus: http.StatusForbidden, + }, + { + name: "success", + opts: &RequireBearerTokenOptions{ + Scopes: []string{"read"}, + }, + authHeader: "Bearer valid", + wantHeader: "", + wantStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := RequireBearerToken(verifier, tt.opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/", nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("status = %d, want %d", rec.Code, tt.wantStatus) + } + + got := rec.Header().Get("WWW-Authenticate") + if got != tt.wantHeader { + t.Errorf("WWW-Authenticate = %q, want %q", got, tt.wantHeader) + } + }) + } +}