Skip to content

Commit

Permalink
Added middleware to check matching charsets against request (go-chi#286)
Browse files Browse the repository at this point in the history
  • Loading branch information
csucu authored and pkieltyka committed Dec 11, 2017
1 parent 05458a1 commit 893f598
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 0 deletions.
51 changes: 51 additions & 0 deletions middleware/content_charset.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package middleware

import (
"net/http"
"strings"
)

// ContentCharset generates a handler that writes a 415 Unsupported Media Type response if none of the charsets match.
// An empty charset will allow requests with no Content-Type header or no specified charset.
func ContentCharset(charsets ...string) func(next http.Handler) http.Handler {
for i, c := range charsets {
charsets[i] = strings.ToLower(c)
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !contentEncoding(r.Header.Get("Content-Type"), charsets...) {
w.WriteHeader(http.StatusUnsupportedMediaType)
return
}

next.ServeHTTP(w, r)
})
}
}

// Check the content encoding against a list of acceptable values.
func contentEncoding(ce string, charsets ...string) bool {
_, ce = split(strings.ToLower(ce), ";")
_, ce = split(ce, "charset=")
ce, _ = split(ce, ";")
for _, c := range charsets {
if ce == c {
return true
}
}

return false
}

// Split a string in two parts, cleaning any whitespace.
func split(str, sep string) (string, string) {
var a, b string
var parts = strings.SplitN(str, sep, 2)
a = strings.TrimSpace(parts[0])
if len(parts) == 2 {
b = strings.TrimSpace(parts[1])
}

return a, b
}
124 changes: 124 additions & 0 deletions middleware/content_charset_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/go-chi/chi"
)

func TestContentCharset(t *testing.T) {
t.Parallel()

var tests = []struct {
name string
inputValue string
inputContentCharset []string
want int
}{
{
"should accept requests with a matching charset",
"application/json; charset=UTF-8",
[]string{"UTF-8"},
http.StatusOK,
},
{
"should be case-insensitive",
"application/json; charset=utf-8",
[]string{"UTF-8"},
http.StatusOK,
},
{
"should accept requests with a matching charset with extra values",
"application/json; foo=bar; charset=UTF-8; spam=eggs",
[]string{"UTF-8"},
http.StatusOK,
},
{
"should accept requests with a matching charset when multiple charsets are supported",
"text/xml; charset=UTF-8",
[]string{"UTF-8", "Latin-1"},
http.StatusOK,
},
{
"should accept requests with no charset if empty charset headers are allowed",
"text/xml",
[]string{"UTF-8", ""},
http.StatusOK,
},
{
"should not accept requests with no charset if empty charset headers are not allowed",
"text/xml",
[]string{"UTF-8"},
http.StatusUnsupportedMediaType,
},
{
"should not accept requests with a mismatching charset",
"text/plain; charset=Latin-1",
[]string{"UTF-8"},
http.StatusUnsupportedMediaType,
},
{
"should not accept requests with a mismatching charset even if empty charsets are allowed",
"text/plain; charset=Latin-1",
[]string{"UTF-8", ""},
http.StatusUnsupportedMediaType,
},
}

for _, tt := range tests {
var tt = tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

var recorder = httptest.NewRecorder()

var r = chi.NewRouter()
r.Use(ContentCharset(tt.inputContentCharset...))
r.Get("/", func(w http.ResponseWriter, r *http.Request) {})

var req, _ = http.NewRequest("GET", "/", nil)
req.Header.Set("Content-Type", tt.inputValue)

r.ServeHTTP(recorder, req)
var res = recorder.Result()

if res.StatusCode != tt.want {
t.Errorf("response is incorrect, got %d, want %d", recorder.Code, tt.want)
}
})
}
}

func TestSplit(t *testing.T) {
t.Parallel()

var s1, s2 = split(" type1;type2 ", ";")

if s1 != "type1" || s2 != "type2" {
t.Errorf("Want type1, type2 got %s, %s", s1, s2)
}

s1, s2 = split("type1 ", ";")

if s1 != "type1" {
t.Errorf("Want \"type1\" got \"%s\"", s1)
}
}

func TestContentEncoding(t *testing.T) {
t.Parallel()

if !contentEncoding("application/json; foo=bar; charset=utf-8; spam=eggs", []string{"utf-8"}...) {
t.Error("Want true, got false")
}

if contentEncoding("text/plain; charset=latin-1", []string{"utf-8"}...) {
t.Error("Want false, got true")
}

if !contentEncoding("text/xml; charset=UTF-8", []string{"latin-1", "utf-8"}...) {
t.Error("Want true, got false")
}
}

0 comments on commit 893f598

Please sign in to comment.