Add further tests
This commit is contained in:
parent
78a02529ab
commit
2306bbaad7
3 changed files with 154 additions and 27 deletions
|
@ -1,17 +1,17 @@
|
||||||
package app
|
package app
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
"path/filepath"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
"strconv"
|
|
||||||
"bytes"
|
|
||||||
"io/ioutil"
|
|
||||||
"encoding/json"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseConfig(t *testing.T) {
|
func TestParseConfig(t *testing.T) {
|
||||||
|
@ -25,13 +25,29 @@ func TestParseConfig(t *testing.T) {
|
||||||
name string
|
name string
|
||||||
want *Config
|
want *Config
|
||||||
}{
|
}{
|
||||||
{"default", cfg(t, tmpDir, `
|
{"default", cfg(t, tmpDir)},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
if got := ParseConfig(); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
gotJSON, _ := json.Marshal(got)
|
||||||
|
wantJSON, _ := json.Marshal(tt.want)
|
||||||
|
t.Errorf("ParseConfig() = %s, want %s", gotJSON, wantJSON)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cfg(t *testing.T, tmpDir string) *Config {
|
||||||
|
viper.SetConfigType("yaml")
|
||||||
|
var yamlCfg = []byte(`
|
||||||
address: 1.2.3.4
|
address: 1.2.3.4
|
||||||
port: 42
|
port: 42
|
||||||
prefix: /oh-de-lally
|
prefix: /oh-de-lally
|
||||||
tls:
|
tls:
|
||||||
keyFile: `+ tmpDir+ `/robin.pem
|
keyFile: ` + tmpDir + `/robin.pem
|
||||||
certFile: `+ tmpDir+ `/tuck.pem
|
certFile: ` + tmpDir + `/tuck.pem
|
||||||
dir: /sherwood/forest
|
dir: /sherwood/forest
|
||||||
realm: uk
|
realm: uk
|
||||||
users:
|
users:
|
||||||
|
@ -43,23 +59,7 @@ users:
|
||||||
subdir: /sheriff
|
subdir: /sheriff
|
||||||
log:
|
log:
|
||||||
error: true
|
error: true
|
||||||
`)},
|
`)
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
|
|
||||||
if got := ParseConfig(); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
gotJson, _ := json.Marshal(got)
|
|
||||||
wantJson, _ := json.Marshal(tt.want)
|
|
||||||
t.Errorf("ParseConfig() = %s, want %s", gotJson, wantJson)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func cfg(t *testing.T, tmpDir string, content string) *Config {
|
|
||||||
viper.SetConfigType("yaml")
|
|
||||||
var yamlCfg = []byte(content)
|
|
||||||
|
|
||||||
err := ioutil.WriteFile(filepath.Join(tmpDir, "config.yaml"), yamlCfg, 0600)
|
err := ioutil.WriteFile(filepath.Join(tmpDir, "config.yaml"), yamlCfg, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -94,6 +94,7 @@ func writeUnauthorized(w http.ResponseWriter, realm string) {
|
||||||
w.Write([]byte(fmt.Sprintf("%d %s", http.StatusUnauthorized, "Unauthorized")))
|
w.Write([]byte(fmt.Sprintf("%d %s", http.StatusUnauthorized, "Unauthorized")))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenHash generates a bcrypt hashed password string
|
||||||
func GenHash(password []byte) string {
|
func GenHash(password []byte) string {
|
||||||
pw, err := bcrypt.GenerateFromPassword(password, 10)
|
pw, err := bcrypt.GenerateFromPassword(password, 10)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
package app
|
package app
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"golang.org/x/net/webdav"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -106,3 +110,125 @@ func TestAuthenticate(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthFromContext(t *testing.T) {
|
||||||
|
type fakeKey int
|
||||||
|
var fakeKeyValue fakeKey
|
||||||
|
|
||||||
|
baseCtx := context.Background()
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *AuthInfo
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"success",
|
||||||
|
args{
|
||||||
|
ctx: context.WithValue(baseCtx, authInfoKey, &AuthInfo{"username", true}),
|
||||||
|
},
|
||||||
|
&AuthInfo{"username", true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"failure",
|
||||||
|
args{
|
||||||
|
ctx: context.WithValue(baseCtx, fakeKeyValue, &AuthInfo{"username", true}),
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := AuthFromContext(tt.args.ctx); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("AuthFromContext() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandle(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
w *httptest.ResponseRecorder
|
||||||
|
r *http.Request
|
||||||
|
username []byte
|
||||||
|
password []byte
|
||||||
|
a *App
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"basic auth error",
|
||||||
|
args{
|
||||||
|
context.Background(),
|
||||||
|
httptest.NewRecorder(),
|
||||||
|
httptest.NewRequest("PROPFIND", "/", nil),
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
&App{Config: &Config{Users: map[string]*UserInfo{
|
||||||
|
"foo": {
|
||||||
|
Password: GenHash([]byte("password")),
|
||||||
|
},
|
||||||
|
}}},
|
||||||
|
},
|
||||||
|
401,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"unauthorized error",
|
||||||
|
args{
|
||||||
|
context.Background(),
|
||||||
|
httptest.NewRecorder(),
|
||||||
|
httptest.NewRequest("PROPFIND", "/", nil),
|
||||||
|
[]byte("u"),
|
||||||
|
[]byte("p"),
|
||||||
|
&App{Config: &Config{Users: map[string]*UserInfo{
|
||||||
|
"foo": {
|
||||||
|
Password: GenHash([]byte("password")),
|
||||||
|
},
|
||||||
|
}}},
|
||||||
|
},
|
||||||
|
401,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ok",
|
||||||
|
args{
|
||||||
|
context.Background(),
|
||||||
|
httptest.NewRecorder(),
|
||||||
|
httptest.NewRequest("PROPFIND", "/", nil),
|
||||||
|
[]byte("foo"),
|
||||||
|
[]byte("password"),
|
||||||
|
&App{
|
||||||
|
Config: &Config{Users: map[string]*UserInfo{
|
||||||
|
"foo": {
|
||||||
|
Password: GenHash([]byte("password")),
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
Handler: &webdav.Handler{
|
||||||
|
FileSystem: webdav.NewMemFS(),
|
||||||
|
LockSystem: webdav.NewMemLS(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
207,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.args.username != nil || tt.args.password != nil {
|
||||||
|
tt.args.r.SetBasicAuth(string(tt.args.username), string(tt.args.password))
|
||||||
|
}
|
||||||
|
|
||||||
|
handle(tt.args.ctx, tt.args.w, tt.args.r, tt.args.a)
|
||||||
|
resp := tt.args.w.Result()
|
||||||
|
|
||||||
|
if resp.StatusCode != tt.statusCode {
|
||||||
|
t.Errorf("TestHandle() = %v, want %v", resp.StatusCode, tt.statusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue