Skip to content

Commit 049d3f9

Browse files
Implement the saslStepper interface in SCRAM
1 parent cdbd809 commit 049d3f9

File tree

2 files changed

+25
-54
lines changed

2 files changed

+25
-54
lines changed

auth.go

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,10 @@ func (socket *mongoSocket) loginSASL(cred Credential) error {
277277
// SCRAM is handled with github.com/xdg-go/scram.
278278
var method *scram.Method
279279
method, err = scram.NewMethod(cred.Mechanism)
280-
sasl = saslNewScram(method, cred)
280+
if err != nil {
281+
return err
282+
}
283+
sasl, err = scram.NewClient(method, cred.Username, cred.Password)
281284
} else if len(cred.ServiceHost) > 0 {
282285
sasl, err = saslNew(cred, cred.ServiceHost)
283286
} else {
@@ -350,25 +353,6 @@ func (socket *mongoSocket) loginSASL(cred Credential) error {
350353
return nil
351354
}
352355

353-
func saslNewScram(method *scram.Method, cred Credential) *saslScram {
354-
credsum := md5.New()
355-
credsum.Write([]byte(cred.Username + ":mongo:" + cred.Password))
356-
client := scram.NewClient(method, cred.Username, hex.EncodeToString(credsum.Sum(nil)))
357-
return &saslScram{cred: cred, client: client}
358-
}
359-
360-
type saslScram struct {
361-
cred Credential
362-
client *scram.Client
363-
}
364-
365-
func (s *saslScram) Close() {}
366-
367-
func (s *saslScram) Step(serverData []byte) (clientData []byte, done bool, err error) {
368-
more := s.client.Step(serverData)
369-
return s.client.Out(), !more, s.client.Err()
370-
}
371-
372356
func (socket *mongoSocket) loginRun(db string, query, result interface{}, f func() error) error {
373357
var mutex sync.Mutex
374358
var replyErr error

internal/scram/scram.go

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
package scram
3232

3333
import (
34-
"bytes"
3534
"errors"
3635

3736
xdg "github.com/xdg-go/scram"
@@ -59,8 +58,6 @@ import (
5958
// }
6059
//
6160
type Client struct {
62-
out bytes.Buffer
63-
err error
6461
conv *xdg.ClientConversation
6562
}
6663

@@ -97,48 +94,38 @@ func NewMethod(methodString string) (*Method, error) {
9794
//
9895
// method, _ := scram.NewMethod("SCRAM-SHA-1")
9996
//
100-
// client := scram.NewClient(method, user, pass)
97+
// client, _ := scram.NewClient(method, user, pass)
10198
//
102-
func NewClient(method *Method, user, pass string) *Client {
103-
var client *xdg.Client
104-
var err error
99+
func NewClient(method *Method, user, pass string) (client *Client, err error) {
100+
var internalClient *xdg.Client
105101

106102
switch method.method {
107103
case ScramSha1:
108-
client, err = xdg.SHA1.NewClient(user, pass, "")
104+
internalClient, err = xdg.SHA1.NewClient(user, pass, "")
109105
case ScramSha256:
110-
client, err = xdg.SHA256.NewClient(user, pass, "")
106+
internalClient, err = xdg.SHA256.NewClient(user, pass, "")
111107
}
112108

113-
c := &Client{
114-
conv: client.NewConversation(),
115-
err: err,
109+
client = &Client{
110+
conv: internalClient.NewConversation(),
116111
}
117-
c.out.Grow(256)
118-
return c
112+
return
119113
}
120114

121-
// Out returns the data to be sent to the server in the current step.
122-
func (c *Client) Out() []byte {
123-
if c.out.Len() == 0 {
124-
return []byte{}
125-
}
126-
return c.out.Bytes()
127-
}
128-
129-
// Err returns the error that occurred, or nil if there were no errors.
130-
func (c *Client) Err() error {
131-
return c.err
115+
// Implement saslStepper (auth.go)
116+
type saslStepper interface {
117+
Step(serverData []byte) (clientData []byte, done bool, err error)
118+
Close()
132119
}
133120

134-
// Step processes the incoming data from the server and makes the
135-
// next round of data for the server available via Client.Out.
136-
// Step returns false if there are no errors and more data is
137-
// still expected.
138-
func (c *Client) Step(in []byte) bool {
121+
// Step progresses the underlying SASL SCRAM process
122+
func (c *Client) Step(serverData []byte) (clientData []byte, done bool, err error) {
139123
var resp string
140-
c.out.Reset()
141-
resp, c.err = c.conv.Step(string(in))
142-
_, c.err = c.out.Write([]byte(resp))
143-
return c.conv.Valid() || c.err != nil
124+
resp, err = c.conv.Step(string(serverData))
125+
clientData = []byte(resp)
126+
done = c.conv.Done()
127+
return
144128
}
129+
130+
// Close is a no opp to fit the saslStepper interface
131+
func (c *Client) Close() {}

0 commit comments

Comments
 (0)