// SecureValueRecovery2/host/auth/auth.go
// 2026-02-19 17:25:44 -05:00
//
// 102 lines
// 2.6 KiB
// Go

// Copyright 2023 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
// Package auth provides the ability to authenticate clients using the
// basic-auth credentials they obtain from the Signal chat servers'
// ExternalServiceCredentialsGenerator.
package auth
import (
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"fmt"
"strconv"
"strings"
"time"
"github.com/signalapp/svr2/util"
)
// DefaultAuthenticationTokenMaxAge is a default maximum token age of 120 days.
const DefaultAuthenticationTokenMaxAge = 120 * 24 * time.Hour
// Auth verifies username/password pairs and can mint a password for a user.
type Auth interface {
	// PassFor produces a password that is valid for user at the current time.
	PassFor(user string) string

	// Check reports whether pass is a legitimate password for user. It
	// returns nil when the pair is accepted; otherwise the returned error
	// describes why the pair was rejected.
	Check(user, pass string) error
}
// New returns a production Auth that signs and verifies passwords with the
// given secret, using the real clock. Passwords whose embedded timestamp is
// more than authenticationTokenMaxAge before or after the current time are
// rejected.
//
// The secret slice is retained as-is (not copied).
func New(secret []byte, authenticationTokenMaxAge time.Duration) Auth {
	impl := &auth{
		secret:     secret,
		expiration: authenticationTokenMaxAge,
		clock:      util.RealClock,
	}
	return impl
}
// alwaysAllow is an Auth that performs no verification at all.
type alwaysAllow struct{}

// Check accepts every user/password pair; it always returns nil.
func (alwaysAllow) Check(user, pass string) error { return nil }

// PassFor ignores user and returns the fixed placeholder password "wheee".
func (alwaysAllow) PassFor(user string) string { return "wheee" }

// AlwaysAllow provides an Auth that will always allow clients to connect.
// Its Check never returns an error, so it offers no protection.
var AlwaysAllow Auth = alwaysAllow{}
// auth is the production Auth implementation. Passwords have the form
// "<unix seconds>:<hex signature>", where the signature is a truncated
// HMAC-SHA256 keyed with secret.
type auth struct {
	// secret is the HMAC-SHA256 key used to sign and verify passwords.
	secret []byte
	// clock supplies the current time; New sets it to util.RealClock.
	clock util.Clock
	// expiration bounds how far a password's timestamp may lie before or
	// after clock.Now() before the password is rejected as expired.
	expiration time.Duration
}
// Check parses pass into its timestamp and signature, then verifies both
// against user. The first failure, whether from parsing or validation, is
// returned unchanged; nil means the credentials are accepted.
func (a *auth) Check(user, pass string) error {
	when, signature, parseErr := a.parsePass(pass)
	if parseErr != nil {
		return parseErr
	}
	return a.valid(user, when, signature)
}
// parsePass splits a password of the form "<unix seconds>:<hex signature>"
// into its timestamp and raw signature bytes. It only parses; it does not
// check the signature or the timestamp's age (valid does that).
//
// The timestamp is the text before the first ':' and must be a base-10
// int64; everything after that ':' must be hex. On error, the returned time
// and signature are zero values, and parse failures from strconv/hex are
// wrapped with %w so callers can inspect them with errors.Is / errors.As.
func (a *auth) parsePass(pass string) (ts time.Time, sig []byte, _ error) {
	i := strings.IndexByte(pass, ':')
	if i < 0 {
		return time.Time{}, nil, fmt.Errorf("no separator")
	}
	unixSecs, err := strconv.ParseInt(pass[:i], 10, 64)
	if err != nil {
		return time.Time{}, nil, fmt.Errorf("parsing timestamp: %w", err)
	}
	sig, err = hex.DecodeString(pass[i+1:])
	if err != nil {
		// Don't leak a partially-decoded slice or a timestamp alongside
		// the error.
		return time.Time{}, nil, fmt.Errorf("parsing signature: %w", err)
	}
	return time.Unix(unixSecs, 0), sig, nil
}
// passMACLen is the number of leading HMAC-SHA256 bytes kept as a
// password's signature. valid and PassFor share it so signing and
// verification cannot drift apart.
const passMACLen = 10

// passMAC returns the first passMACLen bytes of HMAC-SHA256, keyed with
// a.secret, over the string "<user>:<unixSecs>".
func (a *auth) passMAC(user string, unixSecs int64) []byte {
	mac := hmac.New(sha256.New, a.secret)
	// hash.Hash writes never return an error, so the result is ignored.
	fmt.Fprintf(mac, "%s:%d", user, unixSecs)
	return mac.Sum(nil)[:passMACLen]
}

// valid checks a parsed password for user. It returns an "expired" error
// if ts is more than a.expiration before or after the current time. It
// returns a "mac failure" error if sig is not the expected signature for
// (user, ts). Otherwise it returns nil.
func (a *auth) valid(user string, ts time.Time, sig []byte) error {
	diff := a.clock.Now().Sub(ts)
	if diff > a.expiration || diff < -a.expiration {
		return fmt.Errorf("expired")
	}
	// ConstantTimeCompare avoids timing leaks and returns 0 when the
	// lengths differ, so a signature of the wrong length is rejected.
	if subtle.ConstantTimeCompare(a.passMAC(user, ts.Unix()), sig) != 1 {
		return fmt.Errorf("mac failure")
	}
	return nil
}

// PassFor returns a password for user stamped with the current time, in
// the "<unix seconds>:<lowercase hex signature>" form that Check accepts.
func (a *auth) PassFor(user string) string {
	unixSecs := a.clock.Now().Unix()
	return fmt.Sprintf("%d:%x", unixSecs, a.passMAC(user, unixSecs))
}