Skip to content

Commit 1549187

Browse files
committed
fix: use client authn on device auth request
According to https://datatracker.ietf.org/doc/html/rfc8628#section-3.1, the device auth request must include client authentication. Fixes #685
1 parent 22134a4 commit 1549187

File tree

5 files changed

+245
-49
lines changed

5 files changed

+245
-49
lines changed

deviceauth.go

+25-35
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@ import (
44
"context"
55
"encoding/json"
66
"errors"
7-
"fmt"
8-
"io"
9-
"net/http"
107
"net/url"
118
"strings"
129
"time"
@@ -93,47 +90,40 @@ func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*Devic
9390
return retrieveDeviceAuth(ctx, c, v)
9491
}
9592

96-
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
97-
if c.Endpoint.DeviceAuthURL == "" {
98-
return nil, errors.New("endpoint missing DeviceAuthURL")
93+
// deviceAuthFromInternal maps an *internal.DeviceAuthResponse struct into
94+
// a *DeviceAuthResponse struct.
95+
func deviceAuthFromInternal(da *internal.DeviceAuthResponse) *DeviceAuthResponse {
96+
if da == nil {
97+
return nil
9998
}
100-
101-
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
102-
if err != nil {
103-
return nil, err
99+
return &DeviceAuthResponse{
100+
DeviceCode: da.DeviceCode,
101+
UserCode: da.UserCode,
102+
VerificationURI: da.VerificationURI,
103+
VerificationURIComplete: da.VerificationURIComplete,
104+
Expiry: time.Now().UTC().Add(time.Second * time.Duration(da.Expiry)),
105+
Interval: da.Interval,
104106
}
105-
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
106-
req.Header.Set("Accept", "application/json")
107+
}
107108

108-
t := time.Now()
109-
r, err := internal.ContextClient(ctx).Do(req)
110-
if err != nil {
111-
return nil, err
109+
// retrieveDeviceAuth takes a *Config and uses that to retrieve an *internal.DeviceAuthResponse.
110+
// This response is then mapped from *internal.DeviceAuthResponse into an *oauth2.DeviceAuthResponse which is returned along
111+
// with an error.
112+
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
113+
if c.Endpoint.DeviceAuthURL == "" {
114+
return nil, errors.New("endpoint missing DeviceAuthURL")
112115
}
113116

114-
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
117+
da, err := internal.RetrieveDeviceAuth(ctx, c.ClientID, c.ClientSecret, c.Endpoint.DeviceAuthURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
115118
if err != nil {
116-
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
117-
}
118-
if code := r.StatusCode; code < 200 || code > 299 {
119-
return nil, &RetrieveError{
120-
Response: r,
121-
Body: body,
119+
if rErr, ok := err.(*internal.RetrieveError); ok {
120+
return nil, (*RetrieveError)(rErr)
122121
}
122+
return nil, err
123123
}
124+
dar := deviceAuthFromInternal(da)
124125

125-
da := &DeviceAuthResponse{}
126-
err = json.Unmarshal(body, &da)
127-
if err != nil {
128-
return nil, fmt.Errorf("unmarshal %s", err)
129-
}
130-
131-
if !da.Expiry.IsZero() {
132-
// Make a small adjustment to account for time taken by the request
133-
da.Expiry = da.Expiry.Add(-time.Since(t))
134-
}
135-
136-
return da, nil
126+
return dar, err
137127
}
138128

139129
// DeviceAccessToken polls the server to exchange a device code for a token.

internal/deviceauth.go

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright 2014 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package internal
6+
7+
import (
8+
"context"
9+
"encoding/json"
10+
"fmt"
11+
"io"
12+
"net/url"
13+
"time"
14+
)
15+
16+
// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
17+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
18+
//
19+
// This type is a mirror of oauth2.DeviceAuthResponse, with the only difference
20+
// being that in this struct `expires_in` isn't mapped to a timestamp. It solely
21+
// exists to break an otherwise-circular dependency. Other internal packages should
22+
// convert this DeviceAuthResponse into an oauth2.DeviceAuthResponse before use.
23+
type DeviceAuthResponse struct {
24+
// DeviceCode
25+
DeviceCode string `json:"device_code"`
26+
// UserCode is the code the user should enter at the verification uri
27+
UserCode string `json:"user_code"`
28+
// VerificationURI is where user should enter the user code
29+
VerificationURI string `json:"verification_uri"`
30+
// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
31+
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
32+
// Expiry is when the device code and user code expire
33+
Expiry int64 `json:"expires_in,omitempty"`
34+
// Interval is the duration in seconds that Poll should wait between requests
35+
Interval int64 `json:"interval,omitempty"`
36+
}
37+
38+
func RetrieveDeviceAuth(ctx context.Context, clientID, clientSecret, deviceAuthURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*DeviceAuthResponse, error) {
39+
needsAuthStyleProbe := authStyle == AuthStyleUnknown
40+
if needsAuthStyleProbe {
41+
if style, ok := styleCache.lookupAuthStyle(deviceAuthURL); ok {
42+
authStyle = style
43+
needsAuthStyleProbe = false
44+
} else {
45+
authStyle = AuthStyleInHeader // the first way we'll try
46+
}
47+
}
48+
49+
req, err := NewRequestWithClientAuthn("POST", deviceAuthURL, clientID, clientSecret, v, authStyle)
50+
if err != nil {
51+
return nil, err
52+
}
53+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
54+
req.Header.Set("Accept", "application/json")
55+
56+
t := time.Now()
57+
r, err := ContextClient(ctx).Do(req)
58+
59+
if err != nil && needsAuthStyleProbe {
60+
// If we get an error, assume the server wants the
61+
// clientID & clientSecret in a different form.
62+
authStyle = AuthStyleInParams // the second way we'll try
63+
req, _ := NewRequestWithClientAuthn("POST", deviceAuthURL, clientID, clientSecret, v, authStyle)
64+
r, err = ContextClient(ctx).Do(req)
65+
}
66+
if needsAuthStyleProbe && err == nil {
67+
styleCache.setAuthStyle(deviceAuthURL, authStyle)
68+
}
69+
70+
if err != nil {
71+
return nil, err
72+
}
73+
defer r.Body.Close()
74+
75+
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
76+
if err != nil {
77+
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
78+
}
79+
if code := r.StatusCode; code < 200 || code > 299 {
80+
return nil, &RetrieveError{
81+
Response: r,
82+
Body: body,
83+
}
84+
}
85+
86+
da := &DeviceAuthResponse{}
87+
err = json.Unmarshal(body, &da)
88+
if err != nil {
89+
return nil, fmt.Errorf("unmarshal %s", err)
90+
}
91+
92+
if da.Expiry != 0 {
93+
// Make a small adjustment to account for time taken by the request
94+
da.Expiry = da.Expiry + int64(t.Nanosecond())
95+
}
96+
return da, nil
97+
}

internal/deviceauth_test.go

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright 2014 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package internal
6+
7+
import (
8+
"context"
9+
"io"
10+
"net/http"
11+
"net/http/httptest"
12+
"net/url"
13+
"testing"
14+
)
15+
16+
func TestDeviceAuth_ClientAuthnInParams(t *testing.T) {
17+
styleCache := new(AuthStyleCache)
18+
const clientID = "client-id"
19+
const clientSecret = "client-secret"
20+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
21+
if got, want := r.FormValue("client_id"), clientID; got != want {
22+
t.Errorf("client_id = %q; want %q", got, want)
23+
}
24+
if got, want := r.FormValue("client_secret"), clientSecret; got != want {
25+
t.Errorf("client_secret = %q; want %q", got, want)
26+
}
27+
io.WriteString(w, `{"device_code":"code","user_code":"user_code","verification_uri":"http://example.device.com","expires_in":300,"interval":5}`)
28+
}))
29+
defer ts.Close()
30+
_, err := RetrieveDeviceAuth(context.Background(), clientID, clientSecret, ts.URL, url.Values{}, AuthStyleInParams, styleCache)
31+
if err != nil {
32+
t.Errorf("RetrieveDeviceAuth = %v; want no error", err)
33+
}
34+
}
35+
36+
func TestDeviceAuth_ClientAuthnInHeader(t *testing.T) {
37+
styleCache := new(AuthStyleCache)
38+
const clientID = "client-id"
39+
const clientSecret = "client-secret"
40+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
41+
u, p, ok := r.BasicAuth()
42+
if !ok {
43+
io.WriteString(w, `{"error":"invalid_client"}`)
44+
w.WriteHeader(http.StatusBadRequest)
45+
}
46+
if got, want := u, clientID; got != want {
47+
io.WriteString(w, `{"error":"invalid_client"}`)
48+
w.WriteHeader(http.StatusBadRequest)
49+
}
50+
if got, want := p, clientSecret; got != want {
51+
io.WriteString(w, `{"error":"invalid_client"}`)
52+
w.WriteHeader(http.StatusBadRequest)
53+
}
54+
io.WriteString(w, `{"device_code":"code","user_code":"user_code","verification_uri":"http://example.device.com","expires_in":300,"interval":5}`)
55+
}))
56+
defer ts.Close()
57+
_, err := RetrieveDeviceAuth(context.Background(), clientID, clientSecret, ts.URL, url.Values{}, AuthStyleInHeader, styleCache)
58+
if err != nil {
59+
t.Errorf("RetrieveDeviceAuth = %v; want no error", err)
60+
}
61+
}
62+
63+
func TestDeviceAuth_ClientAuthnProbe(t *testing.T) {
64+
styleCache := new(AuthStyleCache)
65+
const clientID = "client-id"
66+
const clientSecret = "client-secret"
67+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
68+
u, p, ok := r.BasicAuth()
69+
if !ok {
70+
io.WriteString(w, `{"error":"invalid_client"}`)
71+
w.WriteHeader(http.StatusBadRequest)
72+
}
73+
if got, want := u, clientID; got != want {
74+
io.WriteString(w, `{"error":"invalid_client"}`)
75+
w.WriteHeader(http.StatusBadRequest)
76+
}
77+
if got, want := p, clientSecret; got != want {
78+
io.WriteString(w, `{"error":"invalid_client"}`)
79+
w.WriteHeader(http.StatusBadRequest)
80+
}
81+
io.WriteString(w, `{"device_code":"code","user_code":"user_code","verification_uri":"http://example.device.com","expires_in":300,"interval":5}`)
82+
}))
83+
defer ts.Close()
84+
_, err := RetrieveDeviceAuth(context.Background(), clientID, clientSecret, ts.URL, url.Values{}, AuthStyleUnknown, styleCache)
85+
if err != nil {
86+
t.Errorf("RetrieveDeviceAuth = %v; want no error", err)
87+
}
88+
}

internal/oauth2.go

+34
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ import (
1010
"encoding/pem"
1111
"errors"
1212
"fmt"
13+
"net/http"
14+
"net/url"
15+
"strings"
1316
)
1417

1518
// ParseKey converts the binary contents of a private key file
@@ -35,3 +38,34 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) {
3538
}
3639
return parsed, nil
3740
}
41+
42+
// addClientAuthnRequestParams adds client_secret_post client authentication
43+
func addClientAuthnRequestParams(clientID, clientSecret string, v url.Values, authStyle AuthStyle) url.Values {
44+
if authStyle == AuthStyleInParams {
45+
v = cloneURLValues(v)
46+
if clientID != "" {
47+
v.Set("client_id", clientID)
48+
}
49+
if clientSecret != "" {
50+
v.Set("client_secret", clientSecret)
51+
}
52+
}
53+
return v
54+
}
55+
56+
// addClientAuthnRequestHeaders adds client_secret_basic client authentication
57+
func addClientAuthnRequestHeaders(clientID, clientSecret string, req *http.Request, authStyle AuthStyle) {
58+
if authStyle == AuthStyleInHeader {
59+
req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
60+
}
61+
}
62+
63+
func NewRequestWithClientAuthn(httpMethod string, endpointURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) {
64+
v = addClientAuthnRequestParams(clientID, clientSecret, v, authStyle)
65+
req, err := http.NewRequest(httpMethod, endpointURL, strings.NewReader(v.Encode()))
66+
if err != nil {
67+
return nil, err
68+
}
69+
addClientAuthnRequestHeaders(clientID, clientSecret, req, authStyle)
70+
return req, nil
71+
}

internal/token.go

+1-14
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616
"net/http"
1717
"net/url"
1818
"strconv"
19-
"strings"
2019
"sync"
2120
"sync/atomic"
2221
"time"
@@ -181,23 +180,11 @@ func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
181180
// the POST body (along with any values in v); false means to send it
182181
// in the Authorization header.
183182
func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) {
184-
if authStyle == AuthStyleInParams {
185-
v = cloneURLValues(v)
186-
if clientID != "" {
187-
v.Set("client_id", clientID)
188-
}
189-
if clientSecret != "" {
190-
v.Set("client_secret", clientSecret)
191-
}
192-
}
193-
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
183+
req, err := NewRequestWithClientAuthn("POST", tokenURL, clientID, clientSecret, v, authStyle)
194184
if err != nil {
195185
return nil, err
196186
}
197187
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
198-
if authStyle == AuthStyleInHeader {
199-
req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
200-
}
201188
return req, nil
202189
}
203190

0 commit comments

Comments
 (0)