192 lines
4.5 KiB
Go
192 lines
4.5 KiB
Go
|
package sasl
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
)
|
||
|
|
||
|
// OAuthBearer is the name of the OAUTHBEARER SASL mechanism.
const OAuthBearer = "OAUTHBEARER"
|
||
|
|
||
|
// OAuthBearerError is the JSON-encoded error object exchanged when
// OAUTHBEARER authentication fails. The server marshals it into the
// challenge it sends back, and the client decodes the challenge into it.
type OAuthBearerError struct {
	// Status is the failure status, e.g. "invalid_request".
	Status string `json:"status"`
	// Schemes is encoded under the "schemes" key, e.g. "bearer".
	Schemes string `json:"schemes"`
	// Scope is encoded under the "scope" key.
	Scope string `json:"scope"`
}
|
||
|
|
||
|
// OAuthBearerOptions holds the values carried by an OAUTHBEARER initial
// client response. The client serialises them in Start; the server fills one
// in from the parsed response and passes it to the authenticator callback.
type OAuthBearerOptions struct {
	// Username is sent in the "a=" part of the response header.
	Username string
	// Token is the bearer token sent in the "auth" field.
	Token string
	// Host is the optional "host" field; the client omits it when empty.
	Host string
	// Port is the optional "port" field; the client omits it when zero.
	Port int
}
|
||
|
|
||
|
// Implements error
|
||
|
func (err *OAuthBearerError) Error() string {
|
||
|
return fmt.Sprintf("OAUTHBEARER authentication error (%v)", err.Status)
|
||
|
}
|
||
|
|
||
|
type oauthBearerClient struct {
|
||
|
OAuthBearerOptions
|
||
|
}
|
||
|
|
||
|
func (a *oauthBearerClient) Start() (mech string, ir []byte, err error) {
|
||
|
mech = OAuthBearer
|
||
|
var str = "n,a=" + a.Username + ","
|
||
|
|
||
|
if a.Host != "" {
|
||
|
str += "\x01host=" + a.Host
|
||
|
}
|
||
|
|
||
|
if a.Port != 0 {
|
||
|
str += "\x01port=" + strconv.Itoa(a.Port)
|
||
|
}
|
||
|
str += "\x01auth=Bearer " + a.Token + "\x01\x01"
|
||
|
ir = []byte(str)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (a *oauthBearerClient) Next(challenge []byte) ([]byte, error) {
|
||
|
authBearerErr := &OAuthBearerError{}
|
||
|
if err := json.Unmarshal(challenge, authBearerErr); err != nil {
|
||
|
return nil, err
|
||
|
} else {
|
||
|
return nil, authBearerErr
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// An implementation of the OAUTHBEARER authentication mechanism, as
|
||
|
// described in RFC 7628.
|
||
|
func NewOAuthBearerClient(opt *OAuthBearerOptions) Client {
|
||
|
return &oauthBearerClient{*opt}
|
||
|
}
|
||
|
|
||
|
// OAuthBearerAuthenticator is the callback the OAUTHBEARER server invokes
// with the options parsed from the client response. A nil result means the
// authentication succeeded; a non-nil result is JSON-encoded and sent to the
// client as the failure challenge.
type OAuthBearerAuthenticator func(opts OAuthBearerOptions) *OAuthBearerError
|
||
|
|
||
|
type oauthBearerServer struct {
|
||
|
done bool
|
||
|
failErr error
|
||
|
authenticate OAuthBearerAuthenticator
|
||
|
}
|
||
|
|
||
|
func (a *oauthBearerServer) fail(descr string) ([]byte, bool, error) {
|
||
|
blob, err := json.Marshal(OAuthBearerError{
|
||
|
Status: "invalid_request",
|
||
|
Schemes: "bearer",
|
||
|
})
|
||
|
if err != nil {
|
||
|
panic(err) // wtf
|
||
|
}
|
||
|
a.failErr = errors.New(descr)
|
||
|
return blob, false, nil
|
||
|
}
|
||
|
|
||
|
func (a *oauthBearerServer) Next(response []byte) (challenge []byte, done bool, err error) {
|
||
|
// Per RFC, we cannot just send an error, we need to return JSON-structured
|
||
|
// value as a challenge and then after getting dummy response from the
|
||
|
// client stop the exchange.
|
||
|
if a.failErr != nil {
|
||
|
// Server libraries (go-smtp, go-imap) will not call Next on
|
||
|
// protocol-specific SASL cancel response ('*'). However, GS2 (and
|
||
|
// indirectly OAUTHBEARER) defines a protocol-independent way to do so
|
||
|
// using 0x01.
|
||
|
if len(response) != 1 && response[0] != 0x01 {
|
||
|
return nil, true, errors.New("unexpected response")
|
||
|
}
|
||
|
return nil, true, a.failErr
|
||
|
}
|
||
|
|
||
|
if a.done {
|
||
|
err = ErrUnexpectedClientResponse
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Generate empty challenge.
|
||
|
if response == nil {
|
||
|
return []byte{}, false, nil
|
||
|
}
|
||
|
|
||
|
a.done = true
|
||
|
|
||
|
// Cut n,a=username,\x01host=...\x01auth=...
|
||
|
// into
|
||
|
// n
|
||
|
// a=username
|
||
|
// \x01host=...\x01auth=...\x01\x01
|
||
|
parts := bytes.SplitN(response, []byte{','}, 3)
|
||
|
if len(parts) != 3 {
|
||
|
return a.fail("Invalid response")
|
||
|
}
|
||
|
if !bytes.Equal(parts[0], []byte{'n'}) {
|
||
|
return a.fail("Invalid response, missing 'n'")
|
||
|
}
|
||
|
opts := OAuthBearerOptions{}
|
||
|
if !bytes.HasPrefix(parts[1], []byte("a=")) {
|
||
|
return a.fail("Invalid response, missing 'a'")
|
||
|
}
|
||
|
opts.Username = string(bytes.TrimPrefix(parts[1], []byte("a=")))
|
||
|
|
||
|
// Cut \x01host=...\x01auth=...\x01\x01
|
||
|
// into
|
||
|
// *empty*
|
||
|
// host=...
|
||
|
// auth=...
|
||
|
// *empty*
|
||
|
//
|
||
|
// Note that this code does not do a lot of checks to make sure the input
|
||
|
// follows the exact format specified by RFC.
|
||
|
params := bytes.Split(parts[2], []byte{0x01})
|
||
|
for _, p := range params {
|
||
|
// Skip empty fields (one at start and end).
|
||
|
if len(p) == 0 {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
pParts := bytes.SplitN(p, []byte{'='}, 2)
|
||
|
if len(pParts) != 2 {
|
||
|
return a.fail("Invalid response, missing '='")
|
||
|
}
|
||
|
|
||
|
switch string(pParts[0]) {
|
||
|
case "host":
|
||
|
opts.Host = string(pParts[1])
|
||
|
case "port":
|
||
|
port, err := strconv.ParseUint(string(pParts[1]), 10, 16)
|
||
|
if err != nil {
|
||
|
return a.fail("Invalid response, malformed 'port' value")
|
||
|
}
|
||
|
opts.Port = int(port)
|
||
|
case "auth":
|
||
|
const prefix = "bearer "
|
||
|
strValue := string(pParts[1])
|
||
|
// Token type is case-insensitive.
|
||
|
if !strings.HasPrefix(strings.ToLower(strValue), prefix) {
|
||
|
return a.fail("Unsupported token type")
|
||
|
}
|
||
|
opts.Token = strValue[len(prefix):]
|
||
|
default:
|
||
|
return a.fail("Invalid response, unknown parameter: " + string(pParts[0]))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
authzErr := a.authenticate(opts)
|
||
|
if authzErr != nil {
|
||
|
blob, err := json.Marshal(authzErr)
|
||
|
if err != nil {
|
||
|
panic(err) // wtf
|
||
|
}
|
||
|
a.failErr = authzErr
|
||
|
return blob, false, nil
|
||
|
}
|
||
|
|
||
|
return nil, true, nil
|
||
|
}
|
||
|
|
||
|
func NewOAuthBearerServer(auth OAuthBearerAuthenticator) Server {
|
||
|
return &oauthBearerServer{authenticate: auth}
|
||
|
}
|