Skip to content

Commit 223c75e

Browse files
Use external library for SCRAM authentication
Removes custom SCRAM implementation replacing it with a wrapper for the existing xdg-go/scram library. Changes the saslNewScram interface to take a new type *scram.Method argument replacing the func () hash.Hash type. Adds a scram.NewMethod function that validates and returns a supported method.
1 parent cee0d26 commit 223c75e

File tree

4 files changed

+77
-242
lines changed

4 files changed

+77
-242
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ install:
3131
- go get gopkg.in/yaml.v2
3232
- go get gopkg.in/tomb.v2
3333
- go get github.com/golang/lint
34+
- go get github.com/xdg-go/scram
3435

3536
before_script:
3637
- golint ./... | grep -v 'ID' | cat

auth.go

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,9 @@ package mgo
2828

2929
import (
3030
"crypto/md5"
31-
"crypto/sha1"
32-
"crypto/sha256"
3331
"encoding/hex"
3432
"errors"
3533
"fmt"
36-
"hash"
3734
"sync"
3835

3936
"github.com/globalsign/mgo/bson"
@@ -276,11 +273,11 @@ func (socket *mongoSocket) loginPlain(cred Credential) error {
276273
func (socket *mongoSocket) loginSASL(cred Credential) error {
277274
var sasl saslStepper
278275
var err error
279-
if cred.Mechanism == "SCRAM-SHA-1" {
280-
// SCRAM is handled without external libraries.
281-
sasl = saslNewScram(sha1.New, cred)
282-
} else if cred.Mechanism == "SCRAM-SHA-256" {
283-
sasl = saslNewScram(sha256.New, cred)
276+
if cred.Mechanism == "SCRAM-SHA-1" || cred.Mechanism == "SCRAM-SHA-256" {
277+
// SCRAM is handled with github.com/xdg-go/scram.
278+
var method *scram.Method
279+
method, err = scram.NewMethod(cred.Mechanism)
280+
sasl = saslNewScram(method, cred)
284281
} else if len(cred.ServiceHost) > 0 {
285282
sasl, err = saslNew(cred, cred.ServiceHost)
286283
} else {
@@ -357,10 +354,10 @@ func (socket *mongoSocket) loginSASL(cred Credential) error {
357354
return nil
358355
}
359356

360-
func saslNewScram(hash func() hash.Hash, cred Credential) *saslScram {
357+
func saslNewScram(method *scram.Method, cred Credential) *saslScram {
361358
credsum := md5.New()
362359
credsum.Write([]byte(cred.Username + ":mongo:" + cred.Password))
363-
client := scram.NewClient(hash, cred.Username, hex.EncodeToString(credsum.Sum(nil)))
360+
client := scram.NewClient(method, cred.Username, hex.EncodeToString(credsum.Sum(nil)))
364361
return &saslScram{cred: cred, client: client}
365362
}
366363

internal/scram/scram.go

Lines changed: 59 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,23 @@ package scram
3232

3333
import (
3434
"bytes"
35-
"crypto/hmac"
36-
"crypto/rand"
37-
"encoding/base64"
38-
"fmt"
39-
"hash"
40-
"strconv"
41-
"strings"
35+
"errors"
36+
37+
xdg "github.com/xdg-go/scram"
4238
)
4339

44-
// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc).
40+
// Client adapts a SCRAM client (SCRAM-SHA-1, SCRAM-SHA-256).
4541
//
4642
// A Client may be used within a SASL conversation with logic resembling:
4743
//
44+
// mechanism, err := scram.NewMethod("SCRAM-SHA-256")
45+
//
46+
// if err != nil {
47+
// log.Fatal(err)
48+
// }
49+
//
4850
// var in []byte
49-
// var client = scram.NewClient(sha1.New, user, pass)
51+
// var client = scram.NewClient(, user, pass)
5052
// for client.Step(in) {
5153
// out := client.Out()
5254
// // send out to server
@@ -57,34 +59,62 @@ import (
5759
// }
5860
//
5961
type Client struct {
60-
newHash func() hash.Hash
61-
62-
user string
63-
pass string
64-
step int
6562
out bytes.Buffer
6663
err error
64+
conv *xdg.ClientConversation
65+
}
66+
67+
// Method defines the variant of SCRAM to use
68+
type Method struct {
69+
method string
70+
}
71+
72+
const (
73+
// ScramSha1 use the SCRAM-SHA-1 variant
74+
ScramSha1 = "SCRAM-SHA-1"
75+
76+
// ScramSha256 use the SCRAM-SHA-256 variant
77+
ScramSha256 = "SCRAM-SHA-256"
78+
)
6779

68-
clientNonce []byte
69-
serverNonce []byte
70-
saltedPass []byte
71-
authMsg bytes.Buffer
80+
// NewMethod returns a Method if the input method string is supported
81+
// otherwise it returns an error.
82+
// Supported method strings:
83+
// - "SCRAM-SHA-1"
84+
// - "SCRAM-SHA-256"
85+
func NewMethod(methodString string) (*Method, error) {
86+
switch methodString {
87+
case ScramSha1, ScramSha256:
88+
return &Method{method: methodString}, nil
89+
default:
90+
return nil, errors.New("invalid SCRAM mechanism")
91+
}
7292
}
7393

74-
// NewClient returns a new SCRAM-* client with the provided hash algorithm.
94+
// NewClient returns a new SCRAM client with the provided hash algorithm.
7595
//
7696
// For SCRAM-SHA-1, for example, use:
7797
//
78-
// client := scram.NewClient(sha1.New, user, pass)
98+
// method, _ := scram.NewMethod("SCRAM-SHA-1")
99+
//
100+
// client := scram.NewClient(method, user, pass)
79101
//
80-
func NewClient(newHash func() hash.Hash, user, pass string) *Client {
102+
func NewClient(method *Method, user, pass string) *Client {
103+
var client *xdg.Client
104+
var err error
105+
106+
switch method.method {
107+
case ScramSha1:
108+
client, err = xdg.SHA1.NewClient(user, pass, "")
109+
case ScramSha256:
110+
client, err = xdg.SHA256.NewClient(user, pass, "")
111+
}
112+
81113
c := &Client{
82-
newHash: newHash,
83-
user: user,
84-
pass: pass,
114+
conv: client.NewConversation(),
115+
err: err,
85116
}
86117
c.out.Grow(256)
87-
c.authMsg.Grow(256)
88118
return c
89119
}
90120

@@ -101,166 +131,14 @@ func (c *Client) Err() error {
101131
return c.err
102132
}
103133

104-
// SetNonce sets the client nonce to the provided value.
105-
// If not set, the nonce is generated automatically out of crypto/rand on the first step.
106-
func (c *Client) SetNonce(nonce []byte) {
107-
c.clientNonce = nonce
108-
}
109-
110-
var escaper = strings.NewReplacer("=", "=3D", ",", "=2C")
111-
112134
// Step processes the incoming data from the server and makes the
113135
// next round of data for the server available via Client.Out.
114136
// Step returns false if there are no errors and more data is
115137
// still expected.
116138
func (c *Client) Step(in []byte) bool {
139+
var resp string
117140
c.out.Reset()
118-
if c.step > 2 || c.err != nil {
119-
return false
120-
}
121-
c.step++
122-
switch c.step {
123-
case 1:
124-
c.err = c.step1(in)
125-
case 2:
126-
c.err = c.step2(in)
127-
case 3:
128-
c.err = c.step3(in)
129-
}
130-
return c.step > 2 || c.err != nil
131-
}
132-
133-
func (c *Client) step1(in []byte) error {
134-
if len(c.clientNonce) == 0 {
135-
const nonceLen = 6
136-
buf := make([]byte, nonceLen+b64.EncodedLen(nonceLen))
137-
if _, err := rand.Read(buf[:nonceLen]); err != nil {
138-
return fmt.Errorf("cannot read random SCRAM-SHA-1 nonce from operating system: %v", err)
139-
}
140-
c.clientNonce = buf[nonceLen:]
141-
b64.Encode(c.clientNonce, buf[:nonceLen])
142-
}
143-
c.authMsg.WriteString("n=")
144-
escaper.WriteString(&c.authMsg, c.user)
145-
c.authMsg.WriteString(",r=")
146-
c.authMsg.Write(c.clientNonce)
147-
148-
c.out.WriteString("n,,")
149-
c.out.Write(c.authMsg.Bytes())
150-
return nil
151-
}
152-
153-
var b64 = base64.StdEncoding
154-
155-
func (c *Client) step2(in []byte) error {
156-
c.authMsg.WriteByte(',')
157-
c.authMsg.Write(in)
158-
159-
fields := bytes.Split(in, []byte(","))
160-
if len(fields) != 3 {
161-
return fmt.Errorf("expected 3 fields in first SCRAM-SHA-1 server message, got %d: %q", len(fields), in)
162-
}
163-
if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 {
164-
return fmt.Errorf("server sent an invalid SCRAM-SHA-1 nonce: %q", fields[0])
165-
}
166-
if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 {
167-
return fmt.Errorf("server sent an invalid SCRAM-SHA-1 salt: %q", fields[1])
168-
}
169-
if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 {
170-
return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2])
171-
}
172-
173-
c.serverNonce = fields[0][2:]
174-
if !bytes.HasPrefix(c.serverNonce, c.clientNonce) {
175-
return fmt.Errorf("server SCRAM-SHA-1 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce)
176-
}
177-
178-
salt := make([]byte, b64.DecodedLen(len(fields[1][2:])))
179-
n, err := b64.Decode(salt, fields[1][2:])
180-
if err != nil {
181-
return fmt.Errorf("cannot decode SCRAM-SHA-1 salt sent by server: %q", fields[1])
182-
}
183-
salt = salt[:n]
184-
iterCount, err := strconv.Atoi(string(fields[2][2:]))
185-
if err != nil {
186-
return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2])
187-
}
188-
c.saltPassword(salt, iterCount)
189-
190-
c.authMsg.WriteString(",c=biws,r=")
191-
c.authMsg.Write(c.serverNonce)
192-
193-
c.out.WriteString("c=biws,r=")
194-
c.out.Write(c.serverNonce)
195-
c.out.WriteString(",p=")
196-
c.out.Write(c.clientProof())
197-
return nil
198-
}
199-
200-
func (c *Client) step3(in []byte) error {
201-
var isv, ise bool
202-
var fields = bytes.Split(in, []byte(","))
203-
if len(fields) == 1 {
204-
isv = bytes.HasPrefix(fields[0], []byte("v="))
205-
ise = bytes.HasPrefix(fields[0], []byte("e="))
206-
}
207-
if ise {
208-
return fmt.Errorf("SCRAM-SHA-1 authentication error: %s", fields[0][2:])
209-
} else if !isv {
210-
return fmt.Errorf("unsupported SCRAM-SHA-1 final message from server: %q", in)
211-
}
212-
if !bytes.Equal(c.serverSignature(), fields[0][2:]) {
213-
return fmt.Errorf("cannot authenticate SCRAM-SHA-1 server signature: %q", fields[0][2:])
214-
}
215-
return nil
216-
}
217-
218-
func (c *Client) saltPassword(salt []byte, iterCount int) {
219-
mac := hmac.New(c.newHash, []byte(c.pass))
220-
mac.Write(salt)
221-
mac.Write([]byte{0, 0, 0, 1})
222-
ui := mac.Sum(nil)
223-
hi := make([]byte, len(ui))
224-
copy(hi, ui)
225-
for i := 1; i < iterCount; i++ {
226-
mac.Reset()
227-
mac.Write(ui)
228-
mac.Sum(ui[:0])
229-
for j, b := range ui {
230-
hi[j] ^= b
231-
}
232-
}
233-
c.saltedPass = hi
234-
}
235-
236-
func (c *Client) clientProof() []byte {
237-
mac := hmac.New(c.newHash, c.saltedPass)
238-
mac.Write([]byte("Client Key"))
239-
clientKey := mac.Sum(nil)
240-
hash := c.newHash()
241-
hash.Write(clientKey)
242-
storedKey := hash.Sum(nil)
243-
mac = hmac.New(c.newHash, storedKey)
244-
mac.Write(c.authMsg.Bytes())
245-
clientProof := mac.Sum(nil)
246-
for i, b := range clientKey {
247-
clientProof[i] ^= b
248-
}
249-
clientProof64 := make([]byte, b64.EncodedLen(len(clientProof)))
250-
b64.Encode(clientProof64, clientProof)
251-
return clientProof64
252-
}
253-
254-
func (c *Client) serverSignature() []byte {
255-
mac := hmac.New(c.newHash, c.saltedPass)
256-
mac.Write([]byte("Server Key"))
257-
serverKey := mac.Sum(nil)
258-
259-
mac = hmac.New(c.newHash, serverKey)
260-
mac.Write(c.authMsg.Bytes())
261-
serverSignature := mac.Sum(nil)
262-
263-
encoded := make([]byte, b64.EncodedLen(len(serverSignature)))
264-
b64.Encode(encoded, serverSignature)
265-
return encoded
141+
resp, c.err = c.conv.Step(string(in))
142+
_, c.err = c.out.Write([]byte(resp))
143+
return c.conv.Done() || c.err != nil
266144
}

0 commit comments

Comments
 (0)