diff --git a/reflect/reflect.go b/reflect/reflect.go index 88b4c02..e21def5 100644 --- a/reflect/reflect.go +++ b/reflect/reflect.go @@ -36,7 +36,7 @@ func StructToMap(val interface{}) (map[string]interface{}, bool) { return mapVal, true } -// StructFieldMap takes a struct and extracts the Field info into a map by +// StructFieldMap takes a struct type and extracts the Field info into a map by // field name. The "cql" key in the struct field's tag value is the key // name. Examples: // @@ -50,15 +50,12 @@ func StructToMap(val interface{}) (map[string]interface{}, bool) { // Field int "myName" // // If lowercaseFields is set to true, field names are lowercased in the map -func StructFieldMap(val interface{}, lowercaseFields bool) (map[string]Field, error) { - // indirect so function works with both structs and pointers to them - structVal := r.Indirect(r.ValueOf(val)) - kind := structVal.Kind() - if kind != r.Struct { - return nil, fmt.Errorf("expected val to be a struct, got %T", val) +func StructFieldMap(structType r.Type, lowercaseFields bool) (map[string]Field, error) { + if structType.Kind() != r.Struct { + return nil, fmt.Errorf("expected val to be a struct, got %v", structType) } - structFields := cachedTypeFields(structVal.Type()) + structFields := cachedTypeFields(structType) mapVal := make(map[string]Field, len(structFields)) for _, info := range structFields { name := info.name diff --git a/reflect/reflect_test.go b/reflect/reflect_test.go index 9f28504..1a21589 100644 --- a/reflect/reflect_test.go +++ b/reflect/reflect_test.go @@ -138,7 +138,7 @@ func TestMapToStruct(t *testing.T) { } func TestStructFieldMap(t *testing.T) { - m, err := StructFieldMap(Tweet{}, false) + m, err := StructFieldMap(reflect.TypeOf(Tweet{}), false) if err != nil { t.Fatalf("expected field map to be created, err: %v", err) } @@ -185,7 +185,7 @@ func TestStructFieldMap(t *testing.T) { } // Test lowercasing fields - m2, err := StructFieldMap(Tweet{}, true) + m2, err := StructFieldMap(reflect.TypeOf(Tweet{}), true) if err != nil { t.Fatalf("expected field map to be created, err: %v", err) } @@ -203,7 +203,7 @@ func TestStructFieldMapEmbeddedStruct(t *testing.T) { Embedder string } - m, err := StructFieldMap(EmbeddedTweet{}, false) + m, err := StructFieldMap(reflect.TypeOf(EmbeddedTweet{}), false) if err != nil { t.Fatalf("expected field map to be created, err: %v", err) } @@ -234,7 +234,7 @@ func TestStructFieldMapEmbeddedStruct(t *testing.T) { } func TestStructFieldMapNonStruct(t *testing.T) { - _, err := StructFieldMap(42, false) + _, err := StructFieldMap(reflect.TypeOf(42), false) if err == nil { t.Fatalf("expected StructFieldMap to have an error, got nil error") } diff --git a/scanner.go b/scanner.go index 091f6ce..6ec6e2d 100644 --- a/scanner.go +++ b/scanner.go @@ -81,6 +81,7 @@ func (s *scanner) iterSlice(iter Scannable) (int, error) { if err != nil { return rowsScanned, err } + fillInZeroedPtrs(ptrs) sliceElem.Set(reflect.Append(sliceElem, wrapPtrValue(outVal, sliceElemType))) rowsScanned++ @@ -124,6 +125,7 @@ func (s *scanner) iterSingle(iter Scannable) (int, error) { if err != nil { return 0, err } + fillInZeroedPtrs(ptrs) s.rowsScanned++ return 1, nil @@ -132,10 +134,9 @@ func (s *scanner) iterSingle(iter Scannable) (int, error) { // structFields matches the SelectStatement field names selected to names of // fields within the target struct type func (s *scanner) structFields(structType reflect.Type) ([]*r.Field, error) { - fmPtr := reflect.New(structType).Interface() - m, err := r.StructFieldMap(fmPtr, true) + m, err := r.StructFieldMap(structType, true) if err != nil { - return nil, fmt.Errorf("could not decode struct of type %T: %v", fmPtr, err) + return nil, fmt.Errorf("could not decode struct of type %v: %v", structType, err) } structFields := make([]*r.Field, len(s.stmt.fields)) @@ -185,9 +186,13 @@ func generatePtrs(structFields []*r.Field, structVal reflect.Value) []interface{ switch elem.Kind() { case reflect.Map: - elem.Set(reflect.MakeMap(elem.Type())) + if elem.IsNil() { + elem.Set(reflect.MakeMap(elem.Type())) + } case reflect.Slice: - elem.Set(reflect.MakeSlice(elem.Type(), 0, 0)) + if elem.IsNil() { + elem.Set(reflect.MakeSlice(elem.Type(), 0, 0)) + } } ptrs[i] = elem.Addr().Interface() @@ -195,6 +200,33 @@ func generatePtrs(structFields []*r.Field, structVal reflect.Value) []interface{ return ptrs } +// fillInZeroedPtrs is necessary to re-allocate nil slices/maps in our ptr +// list. Gocql unfortunately sees no data as an opportunity to zero out the +// entire slice rather than leaving it as the empty slice. This means something +// like []string{} will get turned into []string(nil) which aren't technically +// the same +func fillInZeroedPtrs(ptrs []interface{}) { + for _, ptr := range ptrs { + if _, ok := ptr.(*ignoreFieldType); ok { + continue + } + + elem := reflect.ValueOf(ptr).Elem() + + switch elem.Kind() { + case reflect.Map: + if elem.IsNil() || elem.IsZero() { + elem.Set(reflect.MakeMap(elem.Type())) + } + case reflect.Slice: + if elem.IsNil() || elem.IsZero() { + elem.Set(reflect.MakeSlice(elem.Type(), 0, 0)) + } + } + + } +} + // allocateNilReference checks to see if the in is not nil itself but points to // an object which itself is nil. Note that it only checks one depth down. // Returns true if any allocation has happened, false if no allocation was needed diff --git a/scanner_test.go b/scanner_test.go index 0981362..0f31d8e 100644 --- a/scanner_test.go +++ b/scanner_test.go @@ -254,6 +254,27 @@ func TestScanIterEmbedded(t *testing.T) { iter.Reset() } +func TestFillInZeroedPtrs(t *testing.T) { + str := "" + strSlice := []string{} + strMap := map[string]string{} + strSliceNil := []string(nil) + strMapNil := map[string]string(nil) + + // Test with already allocated + fillInZeroedPtrs([]interface{}{&str, &strSlice, &strMap}) + assert.Equal(t, "", str) + assert.Equal(t, []string{}, strSlice) + assert.Equal(t, map[string]string{}, strMap) + + // Test with nil allocated + assert.NotEqual(t, []string{}, strSliceNil) + assert.NotEqual(t, map[string]string{}, strMapNil) + fillInZeroedPtrs([]interface{}{&strSliceNil, &strMapNil}) + assert.Equal(t, []string{}, strSliceNil) + assert.Equal(t, map[string]string{}, strMapNil) +} + func TestAllocateNilReference(t *testing.T) { // Test non pointer, should do nothing var a string