diff --git a/app/config_test.go b/app/config_test.go index f5ba383..db654b5 100644 --- a/app/config_test.go +++ b/app/config_test.go @@ -1,17 +1,17 @@ package app import ( + "bytes" + "encoding/json" + "io/ioutil" + "os" + "path/filepath" "reflect" + "strconv" "testing" + "time" "github.com/spf13/viper" - "path/filepath" - "os" - "time" - "strconv" - "bytes" - "io/ioutil" - "encoding/json" ) func TestParseConfig(t *testing.T) { @@ -25,13 +25,29 @@ func TestParseConfig(t *testing.T) { name string want *Config }{ - {"default", cfg(t, tmpDir, ` + {"default", cfg(t, tmpDir)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + if got := ParseConfig(); !reflect.DeepEqual(got, tt.want) { + gotJSON, _ := json.Marshal(got) + wantJSON, _ := json.Marshal(tt.want) + t.Errorf("ParseConfig() = %s, want %s", gotJSON, wantJSON) + } + }) + } +} + +func cfg(t *testing.T, tmpDir string) *Config { + viper.SetConfigType("yaml") + var yamlCfg = []byte(` address: 1.2.3.4 port: 42 prefix: /oh-de-lally tls: - keyFile: `+ tmpDir+ `/robin.pem - certFile: `+ tmpDir+ `/tuck.pem + keyFile: ` + tmpDir + `/robin.pem + certFile: ` + tmpDir + `/tuck.pem dir: /sherwood/forest realm: uk users: @@ -43,23 +59,7 @@ users: subdir: /sheriff log: error: true -`)}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - if got := ParseConfig(); !reflect.DeepEqual(got, tt.want) { - gotJson, _ := json.Marshal(got) - wantJson, _ := json.Marshal(tt.want) - t.Errorf("ParseConfig() = %s, want %s", gotJson, wantJson) - } - }) - } -} - -func cfg(t *testing.T, tmpDir string, content string) *Config { - viper.SetConfigType("yaml") - var yamlCfg = []byte(content) +`) err := ioutil.WriteFile(filepath.Join(tmpDir, "config.yaml"), yamlCfg, 0600) if err != nil { diff --git a/app/security.go b/app/security.go index c934e11..2e39aab 100644 --- a/app/security.go +++ b/app/security.go @@ -94,6 +94,7 @@ func writeUnauthorized(w http.ResponseWriter, realm string) { w.Write([]byte(fmt.Sprintf("%d %s", http.StatusUnauthorized, "Unauthorized"))) } +// GenHash generates a bcrypt hashed password string func GenHash(password []byte) string { pw, err := bcrypt.GenerateFromPassword(password, 10) if err != nil { diff --git a/app/security_test.go b/app/security_test.go index c21863b..c6765ab 100644 --- a/app/security_test.go +++ b/app/security_test.go @@ -1,6 +1,10 @@ package app import ( + "context" + "golang.org/x/net/webdav" + "net/http" + "net/http/httptest" "reflect" "testing" ) @@ -106,3 +110,125 @@ func TestAuthenticate(t *testing.T) { }) } } + +func TestAuthFromContext(t *testing.T) { + type fakeKey int + var fakeKeyValue fakeKey + + baseCtx := context.Background() + type args struct { + ctx context.Context + } + tests := []struct { + name string + args args + want *AuthInfo + }{ + { + "success", + args{ + ctx: context.WithValue(baseCtx, authInfoKey, &AuthInfo{"username", true}), + }, + &AuthInfo{"username", true}, + }, + { + "failure", + args{ + ctx: context.WithValue(baseCtx, fakeKeyValue, &AuthInfo{"username", true}), + }, + nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := AuthFromContext(tt.args.ctx); !reflect.DeepEqual(got, tt.want) { + t.Errorf("AuthFromContext() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHandle(t *testing.T) { + type args struct { + ctx context.Context + w *httptest.ResponseRecorder + r *http.Request + username []byte + password []byte + a *App + } + tests := []struct { + name string + args args + statusCode int + }{ + { + "basic auth error", + args{ + context.Background(), + httptest.NewRecorder(), + httptest.NewRequest("PROPFIND", "/", nil), + nil, + nil, + &App{Config: &Config{Users: map[string]*UserInfo{ + "foo": { + Password: GenHash([]byte("password")), + }, + }}}, + }, + 401, + }, + { + "unauthorized error", + args{ + context.Background(), + httptest.NewRecorder(), + httptest.NewRequest("PROPFIND", "/", nil), + []byte("u"), + []byte("p"), + &App{Config: &Config{Users: map[string]*UserInfo{ + "foo": { + Password: GenHash([]byte("password")), + }, + }}}, + }, + 401, + }, + { + "ok", + args{ + context.Background(), + httptest.NewRecorder(), + httptest.NewRequest("PROPFIND", "/", nil), + []byte("foo"), + []byte("password"), + &App{ + Config: &Config{Users: map[string]*UserInfo{ + "foo": { + Password: GenHash([]byte("password")), + }, + }}, + Handler: &webdav.Handler{ + FileSystem: webdav.NewMemFS(), + LockSystem: webdav.NewMemLS(), + }, + }, + }, + 207, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.args.username != nil || tt.args.password != nil { + tt.args.r.SetBasicAuth(string(tt.args.username), string(tt.args.password)) + } + + handle(tt.args.ctx, tt.args.w, tt.args.r, tt.args.a) + resp := tt.args.w.Result() + + if resp.StatusCode != tt.statusCode { + t.Errorf("TestHandle() = %v, want %v", resp.StatusCode, tt.statusCode) + } + }) + } +}