Add further tests
This commit is contained in:
parent
78a02529ab
commit
2306bbaad7
3 changed files with 154 additions and 27 deletions
|
@ -1,17 +1,17 @@
|
|||
package app
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"path/filepath"
|
||||
"os"
|
||||
"time"
|
||||
"strconv"
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
|
@ -25,13 +25,29 @@ func TestParseConfig(t *testing.T) {
|
|||
name string
|
||||
want *Config
|
||||
}{
|
||||
{"default", cfg(t, tmpDir, `
|
||||
{"default", cfg(t, tmpDir)},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
if got := ParseConfig(); !reflect.DeepEqual(got, tt.want) {
|
||||
gotJSON, _ := json.Marshal(got)
|
||||
wantJSON, _ := json.Marshal(tt.want)
|
||||
t.Errorf("ParseConfig() = %s, want %s", gotJSON, wantJSON)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func cfg(t *testing.T, tmpDir string) *Config {
|
||||
viper.SetConfigType("yaml")
|
||||
var yamlCfg = []byte(`
|
||||
address: 1.2.3.4
|
||||
port: 42
|
||||
prefix: /oh-de-lally
|
||||
tls:
|
||||
keyFile: `+ tmpDir+ `/robin.pem
|
||||
certFile: `+ tmpDir+ `/tuck.pem
|
||||
keyFile: ` + tmpDir + `/robin.pem
|
||||
certFile: ` + tmpDir + `/tuck.pem
|
||||
dir: /sherwood/forest
|
||||
realm: uk
|
||||
users:
|
||||
|
@ -43,23 +59,7 @@ users:
|
|||
subdir: /sheriff
|
||||
log:
|
||||
error: true
|
||||
`)},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
if got := ParseConfig(); !reflect.DeepEqual(got, tt.want) {
|
||||
gotJson, _ := json.Marshal(got)
|
||||
wantJson, _ := json.Marshal(tt.want)
|
||||
t.Errorf("ParseConfig() = %s, want %s", gotJson, wantJson)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func cfg(t *testing.T, tmpDir string, content string) *Config {
|
||||
viper.SetConfigType("yaml")
|
||||
var yamlCfg = []byte(content)
|
||||
`)
|
||||
|
||||
err := ioutil.WriteFile(filepath.Join(tmpDir, "config.yaml"), yamlCfg, 0600)
|
||||
if err != nil {
|
||||
|
|
|
@ -94,6 +94,7 @@ func writeUnauthorized(w http.ResponseWriter, realm string) {
|
|||
w.Write([]byte(fmt.Sprintf("%d %s", http.StatusUnauthorized, "Unauthorized")))
|
||||
}
|
||||
|
||||
// GenHash generates a bcrypt hashed password string
|
||||
func GenHash(password []byte) string {
|
||||
pw, err := bcrypt.GenerateFromPassword(password, 10)
|
||||
if err != nil {
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"golang.org/x/net/webdav"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
@ -106,3 +110,125 @@ func TestAuthenticate(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthFromContext(t *testing.T) {
|
||||
type fakeKey int
|
||||
var fakeKeyValue fakeKey
|
||||
|
||||
baseCtx := context.Background()
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *AuthInfo
|
||||
}{
|
||||
{
|
||||
"success",
|
||||
args{
|
||||
ctx: context.WithValue(baseCtx, authInfoKey, &AuthInfo{"username", true}),
|
||||
},
|
||||
&AuthInfo{"username", true},
|
||||
},
|
||||
{
|
||||
"failure",
|
||||
args{
|
||||
ctx: context.WithValue(baseCtx, fakeKeyValue, &AuthInfo{"username", true}),
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := AuthFromContext(tt.args.ctx); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("AuthFromContext() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandle(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
w *httptest.ResponseRecorder
|
||||
r *http.Request
|
||||
username []byte
|
||||
password []byte
|
||||
a *App
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
"basic auth error",
|
||||
args{
|
||||
context.Background(),
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("PROPFIND", "/", nil),
|
||||
nil,
|
||||
nil,
|
||||
&App{Config: &Config{Users: map[string]*UserInfo{
|
||||
"foo": {
|
||||
Password: GenHash([]byte("password")),
|
||||
},
|
||||
}}},
|
||||
},
|
||||
401,
|
||||
},
|
||||
{
|
||||
"unauthorized error",
|
||||
args{
|
||||
context.Background(),
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("PROPFIND", "/", nil),
|
||||
[]byte("u"),
|
||||
[]byte("p"),
|
||||
&App{Config: &Config{Users: map[string]*UserInfo{
|
||||
"foo": {
|
||||
Password: GenHash([]byte("password")),
|
||||
},
|
||||
}}},
|
||||
},
|
||||
401,
|
||||
},
|
||||
{
|
||||
"ok",
|
||||
args{
|
||||
context.Background(),
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("PROPFIND", "/", nil),
|
||||
[]byte("foo"),
|
||||
[]byte("password"),
|
||||
&App{
|
||||
Config: &Config{Users: map[string]*UserInfo{
|
||||
"foo": {
|
||||
Password: GenHash([]byte("password")),
|
||||
},
|
||||
}},
|
||||
Handler: &webdav.Handler{
|
||||
FileSystem: webdav.NewMemFS(),
|
||||
LockSystem: webdav.NewMemLS(),
|
||||
},
|
||||
},
|
||||
},
|
||||
207,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.args.username != nil || tt.args.password != nil {
|
||||
tt.args.r.SetBasicAuth(string(tt.args.username), string(tt.args.password))
|
||||
}
|
||||
|
||||
handle(tt.args.ctx, tt.args.w, tt.args.r, tt.args.a)
|
||||
resp := tt.args.w.Result()
|
||||
|
||||
if resp.StatusCode != tt.statusCode {
|
||||
t.Errorf("TestHandle() = %v, want %v", resp.StatusCode, tt.statusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue