Add further tests

This commit is contained in:
Christian Claus 2018-05-24 21:37:42 +02:00
parent 78a02529ab
commit 2306bbaad7
3 changed files with 154 additions and 27 deletions

View file

@ -1,17 +1,17 @@
package app
import (
"bytes"
"encoding/json"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"strconv"
"testing"
"time"
"github.com/spf13/viper"
"path/filepath"
"os"
"time"
"strconv"
"bytes"
"io/ioutil"
"encoding/json"
)
func TestParseConfig(t *testing.T) {
@ -25,13 +25,29 @@ func TestParseConfig(t *testing.T) {
name string
want *Config
}{
{"default", cfg(t, tmpDir, `
{"default", cfg(t, tmpDir)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ParseConfig(); !reflect.DeepEqual(got, tt.want) {
gotJSON, _ := json.Marshal(got)
wantJSON, _ := json.Marshal(tt.want)
t.Errorf("ParseConfig() = %s, want %s", gotJSON, wantJSON)
}
})
}
}
func cfg(t *testing.T, tmpDir string) *Config {
viper.SetConfigType("yaml")
var yamlCfg = []byte(`
address: 1.2.3.4
port: 42
prefix: /oh-de-lally
tls:
keyFile: `+ tmpDir+ `/robin.pem
certFile: `+ tmpDir+ `/tuck.pem
keyFile: ` + tmpDir + `/robin.pem
certFile: ` + tmpDir + `/tuck.pem
dir: /sherwood/forest
realm: uk
users:
@ -43,23 +59,7 @@ users:
subdir: /sheriff
log:
error: true
`)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ParseConfig(); !reflect.DeepEqual(got, tt.want) {
gotJson, _ := json.Marshal(got)
wantJson, _ := json.Marshal(tt.want)
t.Errorf("ParseConfig() = %s, want %s", gotJson, wantJson)
}
})
}
}
func cfg(t *testing.T, tmpDir string, content string) *Config {
viper.SetConfigType("yaml")
var yamlCfg = []byte(content)
`)
err := ioutil.WriteFile(filepath.Join(tmpDir, "config.yaml"), yamlCfg, 0600)
if err != nil {

View file

@ -94,6 +94,7 @@ func writeUnauthorized(w http.ResponseWriter, realm string) {
w.Write([]byte(fmt.Sprintf("%d %s", http.StatusUnauthorized, "Unauthorized")))
}
// GenHash generates a bcrypt hashed password string
func GenHash(password []byte) string {
pw, err := bcrypt.GenerateFromPassword(password, 10)
if err != nil {

View file

@ -1,6 +1,10 @@
package app
import (
"context"
"golang.org/x/net/webdav"
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
@ -106,3 +110,125 @@ func TestAuthenticate(t *testing.T) {
})
}
}
func TestAuthFromContext(t *testing.T) {
type fakeKey int
var fakeKeyValue fakeKey
baseCtx := context.Background()
type args struct {
ctx context.Context
}
tests := []struct {
name string
args args
want *AuthInfo
}{
{
"success",
args{
ctx: context.WithValue(baseCtx, authInfoKey, &AuthInfo{"username", true}),
},
&AuthInfo{"username", true},
},
{
"failure",
args{
ctx: context.WithValue(baseCtx, fakeKeyValue, &AuthInfo{"username", true}),
},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := AuthFromContext(tt.args.ctx); !reflect.DeepEqual(got, tt.want) {
t.Errorf("AuthFromContext() = %v, want %v", got, tt.want)
}
})
}
}
func TestHandle(t *testing.T) {
type args struct {
ctx context.Context
w *httptest.ResponseRecorder
r *http.Request
username []byte
password []byte
a *App
}
tests := []struct {
name string
args args
statusCode int
}{
{
"basic auth error",
args{
context.Background(),
httptest.NewRecorder(),
httptest.NewRequest("PROPFIND", "/", nil),
nil,
nil,
&App{Config: &Config{Users: map[string]*UserInfo{
"foo": {
Password: GenHash([]byte("password")),
},
}}},
},
401,
},
{
"unauthorized error",
args{
context.Background(),
httptest.NewRecorder(),
httptest.NewRequest("PROPFIND", "/", nil),
[]byte("u"),
[]byte("p"),
&App{Config: &Config{Users: map[string]*UserInfo{
"foo": {
Password: GenHash([]byte("password")),
},
}}},
},
401,
},
{
"ok",
args{
context.Background(),
httptest.NewRecorder(),
httptest.NewRequest("PROPFIND", "/", nil),
[]byte("foo"),
[]byte("password"),
&App{
Config: &Config{Users: map[string]*UserInfo{
"foo": {
Password: GenHash([]byte("password")),
},
}},
Handler: &webdav.Handler{
FileSystem: webdav.NewMemFS(),
LockSystem: webdav.NewMemLS(),
},
},
},
207,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.args.username != nil || tt.args.password != nil {
tt.args.r.SetBasicAuth(string(tt.args.username), string(tt.args.password))
}
handle(tt.args.ctx, tt.args.w, tt.args.r, tt.args.a)
resp := tt.args.w.Result()
if resp.StatusCode != tt.statusCode {
t.Errorf("TestHandle() = %v, want %v", resp.StatusCode, tt.statusCode)
}
})
}
}