Skip to content

Commit

Permalink
move code around
Browse files Browse the repository at this point in the history
  • Loading branch information
kjk committed Oct 20, 2021
1 parent 79c7312 commit 7ee1848
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 56 deletions.
26 changes: 26 additions & 0 deletions httputil/httputil.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ import (
"mime/multipart"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"

"github.com/kjk/common/u"
)

// can be used for http.Get() requests with better timeouts. New one must be created
Expand Down Expand Up @@ -117,3 +120,26 @@ func JoinURL(s1, s2 string) string {
}
return s1 + "/" + s2
}

func MakeFullRedirectURL(path string, reqURL *url.URL) string {
// TODO: could verify that path is really a path
// and doesn't have query / fragment
if reqURL.RawQuery != "" {
path = path + "?" + reqURL.RawQuery
}
if reqURL.Fragment != "" {
path = path + "#" + reqURL.EscapedFragment()
}
return path
}

// SmartRedirect redirects to uri but also adds query / fragment from r.URL
func SmartRedirect(w http.ResponseWriter, r *http.Request, uri string, code int) {
u.PanicIf(code < 300 || code >= 400)
uri = MakeFullRedirectURL(uri, r.URL)
http.Redirect(w, r, uri, code)
}

func SmartPermanentRedirect(w http.ResponseWriter, r *http.Request, uri string) {
SmartRedirect(w, r, uri, http.StatusMovedPermanently) // 301
}
18 changes: 18 additions & 0 deletions httputil/httputil_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package httputil

import (
"net/url"
"testing"

"github.com/kjk/common/assert"
Expand All @@ -20,3 +21,20 @@ func TestJoinURL(t *testing.T) {
assert.Equal(t, exp, got)
}
}

func TestMakeFullRedirectURL(t *testing.T) {
tests := []string{
"/foo.html#me;him", "/bar", "/bar#me;him",
"/foo.html", "/bar", "/bar",
"/foo.html?me=him", "/bar", "/bar?me=him",
"/foo.html?me=him#me", "/bar", "/bar?me=him#me",
}
for i := 0; i < len(tests); i += 3 {
u, err := url.Parse(tests[i])
assert.NoError(t, err)
path := tests[i+1]
exp := tests[i+2]
got := MakeFullRedirectURL(path, u)
assert.Equal(t, exp, got)
}
}
41 changes: 5 additions & 36 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"io/ioutil"
"mime"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
Expand All @@ -19,6 +18,8 @@ import (
"time"

"github.com/andybalholm/brotli"
"github.com/kjk/common/httputil"
"github.com/kjk/common/u"
)

// Server represents all files known to the server
Expand Down Expand Up @@ -504,41 +505,9 @@ func Gen404Candidates(uri string) []string {
return res
}

// returns true if s ends with extension (e.g. ".html")
// case-insensitive
func hasExtFold(s string, ext string) bool {
e := filepath.Ext(s)
return strings.EqualFold(e, ext)
}

func trimExt(s string) string {
idx := strings.LastIndex(s, ".")
if idx == -1 {
return s
}
return s[:idx]
}

func MakeFullRedirectURL(path string, reqURL *url.URL) string {
// TODO: could verify that path is really a path
// and doesn't have query / fragment
if reqURL.RawQuery != "" {
path = path + "?" + reqURL.RawQuery
}
if reqURL.Fragment != "" {
path = path + "#" + reqURL.EscapedFragment()
}
return path
}

func permRedirectTo(w http.ResponseWriter, r *http.Request, uri string) {
uri = MakeFullRedirectURL(uri, r.URL)
http.Redirect(w, r, uri, http.StatusMovedPermanently) // 301
}

func makePermRedirect(uri string) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
permRedirectTo(w, r, uri)
httputil.SmartPermanentRedirect(w, r, uri)
}
}

Expand All @@ -548,8 +517,8 @@ func (s *Server) FindHandler(uri string) (h HandlerFunc, is404 bool) {
uri = path.Join(uri, "/index.html")
}
if h = s.FindHandlerExact(uri); h != nil {
if s.ForceCleanURLS && hasExtFold(uri, ".html") {
uri = trimExt(uri)
if s.ForceCleanURLS && u.ExtEqualFold(uri, ".html") {
uri = u.TrimExt(uri)
h = makePermRedirect(uri)
}
return
Expand Down
21 changes: 2 additions & 19 deletions server/server_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package server

import (
"net/url"
"reflect"
"testing"

"github.com/kjk/common/assert"
"github.com/kjk/common/u"
)

func TestGen404Candidates(t *testing.T) {
Expand All @@ -27,23 +27,6 @@ func TestGen404Candidates(t *testing.T) {
}
}

func TestMakeFullRedirectURL(t *testing.T) {
tests := []string{
"/foo.html#me;him", "/bar", "/bar#me;him",
"/foo.html", "/bar", "/bar",
"/foo.html?me=him", "/bar", "/bar?me=him",
"/foo.html?me=him#me", "/bar", "/bar?me=him#me",
}
for i := 0; i < len(tests); i += 3 {
u, err := url.Parse(tests[i])
assert.NoError(t, err)
path := tests[i+1]
exp := tests[i+2]
got := MakeFullRedirectURL(path, u)
assert.Equal(t, exp, got)
}
}

func TestTrimExt(t *testing.T) {
tests := []string{
"foo.html", "foo",
Expand All @@ -54,7 +37,7 @@ func TestTrimExt(t *testing.T) {

for i := 0; i < len(tests); i += 2 {
exp := tests[i+1]
got := trimExt(tests[i])
got := u.TrimExt(tests[i])
assert.Equal(t, exp, got)
}
}
21 changes: 20 additions & 1 deletion u/strings.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package u

import "strings"
import (
"path/filepath"
"strings"
)

// NormalizeNewlinesInPlace changes CRLF (Windows) and
// CR (Mac) to LF (Unix)
Expand Down Expand Up @@ -64,3 +67,19 @@ func ToTrimmedLines(d []byte) []string {
}
return lines[:i]
}

// TrimExt removes extension from s
func TrimExt(s string) string {
idx := strings.LastIndex(s, ".")
if idx == -1 {
return s
}
return s[:idx]
}

// ExtEqualFold returns true if s ends with extension (e.g. ".html")
// case-insensitive
func ExtEqualFold(s string, ext string) bool {
e := filepath.Ext(s)
return strings.EqualFold(e, ext)
}
17 changes: 17 additions & 0 deletions u/strings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@ func TestTrimPrefix(t *testing.T) {
}
}

func TestTrimExt(t *testing.T) {
tests := []string{
"foo", "foo",
"foo.html", "foo",
"foo.", "foo",
"foo.html.txt", "foo.html",
}

n := len(tests)
for i := 0; i < n; i += 2 {
got := TrimExt(tests[i])
exp := tests[i+1]
assert.Equal(t, exp, got)
assert.Equal(t, exp, got, "%#v, %#v", tests[i], tests[i+1])
}
}

func TestCapitalize(t *testing.T) {
tests := []struct {
s string
Expand Down

0 comments on commit 7ee1848

Please sign in to comment.