blob: 53f6752a3c3c8a8095fb052d95a8e702c414d621 [file] [log] [blame]
// Copyright (C) 2016 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package auth provides simple token-based authorization for streams.
package auth
import (
	"bytes"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"io"

	"android.googlesource.com/platform/tools/gpu/framework/binary/endian"
	"android.googlesource.com/platform/tools/gpu/framework/device"
)
// header is the 4-byte magic prefix that Write emits before a token and
// that Check expects to read back.
var header = []byte("AUTH")

var (
	// ErrInvalidHeader is returned by Check when the bytes read from the
	// stream do not match the expected auth-token header.
	ErrInvalidHeader = fmt.Errorf("Invalid auth-token header")

	// ErrInvalidToken is returned by Check when the header was valid but the
	// token that followed it differs from the expected token.
	ErrInvalidToken = fmt.Errorf("Invalid auth-token code")
)

// NoAuth is the empty Token. Write and Check both treat it as an
// unauthenticated connection and return nil without touching the stream.
var NoAuth Token = ""
// Token is a secret password that must be sent on connection.
type Token string
// Write sends token over s in the form that Check expects. It writes the
// auth header followed by the token string through a little-endian
// endian.Writer, and returns that writer's error state.
//
// If token is NoAuth, nothing is written and nil is returned.
func Write(s io.Writer, token Token) error {
	if token != NoAuth {
		out := endian.Writer(s, device.LittleEndian)
		out.Data(header)
		out.String(string(token))
		return out.Error()
	}
	return nil // Non-authenticated connection
}
// Check reads an authorization token from s and verifies it against token.
//
// If token is NoAuth, nothing is read and nil is returned. Otherwise Check
// reads the auth header followed by a string, both through a little-endian
// endian.Reader, mirroring what Write emits.
//
// It returns ErrInvalidHeader if the header does not match, ErrInvalidToken
// if the received token does not match, or the reader's error if reading
// failed. Whenever a non-nil error is returned, s is closed first.
func Check(s io.ReadCloser, token Token) (err error) {
	defer func() {
		if err != nil {
			// Best-effort close of a rejected stream; the original error is
			// more useful to the caller than any error from Close.
			s.Close()
		}
	}()
	if token == NoAuth {
		return nil // Non-authenticated connection
	}
	r := endian.Reader(s, device.LittleEndian)
	gotHeader := make([]byte, len(header))
	r.Data(gotHeader)
	if err = r.Error(); err != nil {
		return err
	}
	if !bytes.Equal(gotHeader, header) {
		return ErrInvalidHeader
	}
	gotToken := r.String()
	if err = r.Error(); err != nil {
		return err
	}
	// Compare in constant time so response timing does not reveal how many
	// leading bytes of a guessed token were correct.
	if subtle.ConstantTimeCompare([]byte(gotToken), []byte(token)) != 1 {
		return ErrInvalidToken
	}
	return nil
}
// GenToken returns a new random Token that is 8 characters long.
//
// The token is 6 bytes from crypto/rand encoded with standard base64. Six
// bytes encode to exactly 8 characters, so there is no '=' padding.
// GenToken panics if the random source returns an error.
func GenToken() Token {
	var raw [6]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic(fmt.Errorf("rand.Read returned error: %v", err))
	}
	encoded := base64.StdEncoding.EncodeToString(raw[:])
	return Token(encoded)
}