Last active
April 11, 2019 08:45
-
-
Save benjbaron/bc4ebbc1146dfe6c564d55c344c9b60f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package types | |
import ( | |
"encoding/json" | |
"time" | |
) | |
// FIXME: This does not currently work with JSON's omitempty: encoding/json
// never considers a struct-typed value empty, so a null NullString field
// tagged `omitempty` is still written out as `null`.
// See: https://github.com/golang/go/issues/11939

// NullString represents a string that may be null.
//
// It encodes to a JSON string when Valid is true and to JSON null otherwise.
// Use NullString as follows:
//
//	if s.Valid {
//		// use s.String
//	} else {
//		// NULL value
//	}
type NullString struct {
	String string
	Valid  bool // Valid is true if String is not NULL
}

// NewNullString creates a new NullString holding s with the given validity.
func NewNullString(s string, valid bool) NullString {
	return NullString{
		String: s,
		Valid:  valid,
	}
}

// NullStringFrom creates a new NullString that is always valid (never null).
// The empty string is kept as a valid, non-null value.
func NullStringFrom(s string) NullString {
	return NewNullString(s, true)
}

// MarshalJSON implements json.Marshaler. An invalid NullString encodes as
// JSON null; a valid one encodes as a JSON string, including "".
func (s NullString) MarshalJSON() ([]byte, error) {
	if !s.Valid {
		// Constant output; no need to go through reflection-based json.Marshal.
		return []byte("null"), nil
	}
	return json.Marshal(s.String)
}

// UnmarshalJSON implements json.Unmarshaler. It accepts a JSON string or JSON
// null. A blank string ("") yields a valid NullString, not a null one.
//
// On null the receiver is reset to the zero NullString so no stale String
// survives (mirroring database/sql.NullString.Scan with a nil value). Any other
// JSON type is rejected with the error from json.Unmarshal, and the receiver is
// left unchanged.
func (s *NullString) UnmarshalJSON(data []byte) error {
	var x *string
	if err := json.Unmarshal(data, &x); err != nil {
		return err
	}
	if x == nil {
		*s = NullString{}
		return nil
	}
	*s = NullString{String: *x, Valid: true}
	return nil
}
// NullInt represents an int that may be null.
//
// It encodes to a JSON number when Valid is true and to JSON null otherwise.
type NullInt struct {
	Int   int
	Valid bool // Valid is true if Int is not NULL
}

// NewNullInt creates a new NullInt holding i with the given validity.
func NewNullInt(i int, valid bool) NullInt {
	return NullInt{
		Int:   i,
		Valid: valid,
	}
}

// NullIntFrom creates a new NullInt that will always be valid; 0 is kept as a
// valid, non-null value.
func NullIntFrom(i int) NullInt {
	return NewNullInt(i, true)
}

// MarshalJSON implements json.Marshaler. An invalid NullInt encodes as JSON
// null; a valid one encodes as a JSON number.
func (i NullInt) MarshalJSON() ([]byte, error) {
	if !i.Valid {
		// Constant output; no need to go through reflection-based json.Marshal.
		return []byte("null"), nil
	}
	return json.Marshal(i.Int)
}

// UnmarshalJSON implements json.Unmarshaler. It accepts a JSON integer or JSON
// null. On null the receiver is reset to the zero NullInt so no stale Int
// survives. Inputs json.Unmarshal cannot store in an int (strings, fractions,
// out-of-range numbers, ...) return its error and leave the receiver unchanged.
func (i *NullInt) UnmarshalJSON(data []byte) error {
	var x *int
	if err := json.Unmarshal(data, &x); err != nil {
		return err
	}
	if x == nil {
		*i = NullInt{}
		return nil
	}
	*i = NullInt{Int: *x, Valid: true}
	return nil
}
// NullUInt represents a uint that may be null.
//
// It encodes to a JSON number when Valid is true and to JSON null otherwise.
type NullUInt struct {
	UInt  uint
	Valid bool // Valid is true if UInt is not NULL
}

// NewNullUInt creates a new NullUInt holding i with the given validity.
func NewNullUInt(i uint, valid bool) NullUInt {
	return NullUInt{
		UInt:  i,
		Valid: valid,
	}
}

// NullUIntFrom creates a new NullUInt that will always be valid; 0 is kept as a
// valid, non-null value.
func NullUIntFrom(i uint) NullUInt {
	return NewNullUInt(i, true)
}

// MarshalJSON implements json.Marshaler. An invalid NullUInt encodes as JSON
// null; a valid one encodes as a JSON number.
func (i NullUInt) MarshalJSON() ([]byte, error) {
	if !i.Valid {
		// Constant output; no need to go through reflection-based json.Marshal.
		return []byte("null"), nil
	}
	return json.Marshal(i.UInt)
}

// UnmarshalJSON implements json.Unmarshaler. It accepts a non-negative JSON
// integer or JSON null. On null the receiver is reset to the zero NullUInt so
// no stale UInt survives. Inputs json.Unmarshal cannot store in a uint
// (negative numbers, fractions, strings, ...) return its error and leave the
// receiver unchanged.
func (i *NullUInt) UnmarshalJSON(data []byte) error {
	var x *uint
	if err := json.Unmarshal(data, &x); err != nil {
		return err
	}
	if x == nil {
		*i = NullUInt{}
		return nil
	}
	*i = NullUInt{UInt: *x, Valid: true}
	return nil
}
// NullInt64 represents an int64 that may be null.
//
// It encodes to a JSON number when Valid is true and to JSON null otherwise.
type NullInt64 struct {
	Int64 int64
	Valid bool // Valid is true if Int64 is not NULL
}

// NewNullInt64 creates a new NullInt64 holding i with the given validity.
func NewNullInt64(i int64, valid bool) NullInt64 {
	return NullInt64{
		Int64: i,
		Valid: valid,
	}
}

// NullInt64From creates a new NullInt64 that will always be valid; 0 is kept as
// a valid, non-null value.
func NullInt64From(i int64) NullInt64 {
	return NewNullInt64(i, true)
}

// MarshalJSON implements json.Marshaler. An invalid NullInt64 encodes as JSON
// null; a valid one encodes as a JSON number.
func (i NullInt64) MarshalJSON() ([]byte, error) {
	if !i.Valid {
		// Constant output; no need to go through reflection-based json.Marshal.
		return []byte("null"), nil
	}
	return json.Marshal(i.Int64)
}

// UnmarshalJSON implements json.Unmarshaler. It accepts a JSON integer or JSON
// null. On null the receiver is reset to the zero NullInt64 (mirroring
// database/sql.NullInt64.Scan with a nil value). Inputs json.Unmarshal cannot
// store in an int64 return its error and leave the receiver unchanged.
func (i *NullInt64) UnmarshalJSON(data []byte) error {
	var x *int64
	if err := json.Unmarshal(data, &x); err != nil {
		return err
	}
	if x == nil {
		*i = NullInt64{}
		return nil
	}
	*i = NullInt64{Int64: *x, Valid: true}
	return nil
}
// NullBool represents a bool that may be null.
//
// It encodes to JSON true/false when Valid is true and to JSON null otherwise.
type NullBool struct {
	Bool  bool
	Valid bool // Valid is true if Bool is not NULL
}

// NewNullBool creates a new NullBool holding b with the given validity.
func NewNullBool(b bool, valid bool) NullBool {
	return NullBool{
		Bool:  b,
		Valid: valid,
	}
}

// NullBoolFrom creates a new NullBool that will always be valid; false is kept
// as a valid, non-null value.
func NullBoolFrom(b bool) NullBool {
	return NewNullBool(b, true)
}

// MarshalJSON implements json.Marshaler. An invalid NullBool encodes as JSON
// null; a valid one encodes as true or false.
func (b NullBool) MarshalJSON() ([]byte, error) {
	if !b.Valid {
		// Constant output; no need to go through reflection-based json.Marshal.
		return []byte("null"), nil
	}
	return json.Marshal(b.Bool)
}

// UnmarshalJSON implements json.Unmarshaler. It accepts JSON true, false or
// null. On null the receiver is reset to the zero NullBool (mirroring
// database/sql.NullBool.Scan with a nil value). Any other JSON type returns
// json.Unmarshal's error and leaves the receiver unchanged.
func (b *NullBool) UnmarshalJSON(data []byte) error {
	var x *bool
	if err := json.Unmarshal(data, &x); err != nil {
		return err
	}
	if x == nil {
		*b = NullBool{}
		return nil
	}
	*b = NullBool{Bool: *x, Valid: true}
	return nil
}
// NullFloat64 represents a float64 that may be null.
//
// It encodes to a JSON number when Valid is true and to JSON null otherwise.
type NullFloat64 struct {
	Float64 float64
	Valid   bool // Valid is true if Float64 is not NULL
}

// NewNullFloat64 creates a new NullFloat64 holding f with the given validity.
func NewNullFloat64(f float64, valid bool) NullFloat64 {
	return NullFloat64{
		Float64: f,
		Valid:   valid,
	}
}

// NullFloat64From creates a new NullFloat64 that will always be valid; 0 is
// kept as a valid, non-null value.
func NullFloat64From(f float64) NullFloat64 {
	return NewNullFloat64(f, true)
}

// MarshalJSON implements json.Marshaler. An invalid NullFloat64 encodes as JSON
// null; a valid one encodes as a JSON number. As with json.Marshal, NaN and
// ±Inf cannot be encoded and produce an error.
func (f NullFloat64) MarshalJSON() ([]byte, error) {
	if !f.Valid {
		// Constant output; no need to go through reflection-based json.Marshal.
		return []byte("null"), nil
	}
	return json.Marshal(f.Float64)
}

// UnmarshalJSON implements json.Unmarshaler. It accepts a JSON number or JSON
// null. On null the receiver is reset to the zero NullFloat64 (mirroring
// database/sql.NullFloat64.Scan with a nil value). Any other JSON type returns
// json.Unmarshal's error and leaves the receiver unchanged.
func (f *NullFloat64) UnmarshalJSON(data []byte) error {
	var x *float64
	if err := json.Unmarshal(data, &x); err != nil {
		return err
	}
	if x == nil {
		*f = NullFloat64{}
		return nil
	}
	*f = NullFloat64{Float64: *x, Valid: true}
	return nil
}
// NullTime represents a time.Time that may be null.
//
// It has the same Time/Valid field layout as database/sql.NullTime, but it is
// a distinct type, not an alias. It encodes to an RFC 3339 JSON string when
// Valid is true and to JSON null otherwise.
type NullTime struct {
	Time  time.Time
	Valid bool // Valid is true if Time is not NULL
}

// NewNullTime creates a new NullTime holding t with the given validity.
func NewNullTime(t time.Time, valid bool) NullTime {
	return NullTime{
		Time:  t,
		Valid: valid,
	}
}

// NullTimeFrom creates a new NullTime that will always be valid; the zero
// time.Time is kept as a valid, non-null value.
func NullTimeFrom(t time.Time) NullTime {
	return NewNullTime(t, true)
}

// MarshalJSON implements json.Marshaler. An invalid NullTime encodes as JSON
// null; a valid one is encoded by time.Time.MarshalJSON, which reports an
// error for times it cannot represent as RFC 3339 (e.g. years outside
// [0, 9999]).
func (t NullTime) MarshalJSON() ([]byte, error) {
	if !t.Valid {
		// Constant output; no need to go through reflection-based json.Marshal.
		return []byte("null"), nil
	}
	return t.Time.MarshalJSON()
}

// UnmarshalJSON implements json.Unmarshaler. It accepts an RFC 3339 JSON
// string or JSON null. On null the receiver is reset to the zero NullTime
// (mirroring database/sql.NullTime.Scan with a nil value). Any other input
// returns the decoding error and leaves the receiver unchanged.
func (t *NullTime) UnmarshalJSON(data []byte) error {
	var x *time.Time
	if err := json.Unmarshal(data, &x); err != nil {
		return err
	}
	if x == nil {
		*t = NullTime{}
		return nil
	}
	*t = NullTime{Time: *x, Valid: true}
	return nil
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package types | |
import ( | |
"encoding/json" | |
"fmt" | |
"testing" | |
"time" | |
"github.com/stretchr/testify/require" | |
) | |
func TestNullStringFrom(t *testing.T) { | |
s := NullStringFrom("test") | |
require.Equal(t, "test", s.String) | |
require.True(t, s.Valid) | |
zero := NullStringFrom("") | |
require.Equal(t, "", zero.String) | |
require.True(t, zero.Valid) | |
} | |
var testNullString = []struct { | |
JSON []byte | |
Value NullString | |
}{ | |
{[]byte(`"test"`), NullString{"test", true}}, | |
{[]byte(`""`), NullString{"", true}}, | |
{[]byte(`null`), NullString{Valid: false}}, | |
} | |
func TestNullString_UnmarshalJSON(t *testing.T) { | |
for i, param := range testNullString { | |
t.Run(fmt.Sprintf("TestNullString_UnmarshalJSON (%d)", i), func(t *testing.T) { | |
var s NullString | |
err := json.Unmarshal(param.JSON, &s) | |
require.NoError(t, err) | |
require.Equal(t, param.Value, s) | |
}) | |
} | |
} | |
func TestNullString_MarshalJSON(t *testing.T) { | |
for i, param := range testNullString { | |
t.Run(fmt.Sprintf("TestNullString_MarshalJSON (%d)", i), func(t *testing.T) { | |
bytes, err := json.Marshal(param.Value) | |
require.NoError(t, err) | |
require.Equal(t, param.JSON, bytes) | |
}) | |
} | |
} | |
func TestNullIntFrom(t *testing.T) { | |
i := NullIntFrom(42) | |
require.Equal(t, 42, i.Int) | |
require.True(t, i.Valid) | |
zero := NullIntFrom(0) | |
require.Equal(t, 0, zero.Int) | |
require.True(t, zero.Valid) | |
} | |
var testNullInt = []struct { | |
JSON []byte | |
Value NullInt | |
}{ | |
{[]byte(`12345`), NullInt{12345, true}}, | |
{[]byte(`0`), NullInt{0, true}}, | |
{[]byte(`null`), NullInt{Valid: false}}, | |
} | |
func TestNullInt_UnmarshalJSON(t *testing.T) { | |
for i, param := range testNullInt { | |
t.Run(fmt.Sprintf("TestNullInt_UnmarshalJSON (%d)", i), func(t *testing.T) { | |
var i NullInt | |
err := json.Unmarshal(param.JSON, &i) | |
require.NoError(t, err) | |
require.Equal(t, param.Value, i) | |
}) | |
} | |
} | |
func TestNullInt_MarshalJSON(t *testing.T) { | |
for i, param := range testNullInt { | |
t.Run(fmt.Sprintf("TestNullInt_MarshalJSON (%d)", i), func(t *testing.T) { | |
bytes, err := json.Marshal(param.Value) | |
require.NoError(t, err) | |
require.Equal(t, param.JSON, bytes) | |
}) | |
} | |
} | |
func TestNullUIntFrom(t *testing.T) { | |
i := NullUIntFrom(42) | |
require.Equal(t, uint(42), i.UInt) | |
require.True(t, i.Valid) | |
zero := NullUIntFrom(0) | |
require.Equal(t, uint(0), zero.UInt) | |
require.True(t, zero.Valid) | |
} | |
var testNullUInt = []struct { | |
JSON []byte | |
Value NullUInt | |
}{ | |
{[]byte(`12345`), NullUInt{12345, true}}, | |
{[]byte(`0`), NullUInt{0, true}}, | |
{[]byte(`null`), NullUInt{Valid: false}}, | |
} | |
func TestNullUInt_UnmarshalJSON(t *testing.T) { | |
for i, param := range testNullUInt { | |
t.Run(fmt.Sprintf("TestNullUInt_UnmarshalJSON (%d)", i), func(t *testing.T) { | |
var i NullUInt | |
err := json.Unmarshal(param.JSON, &i) | |
require.NoError(t, err) | |
require.Equal(t, param.Value, i) | |
}) | |
} | |
} | |
func TestNullUInt_MarshalJSON(t *testing.T) { | |
for i, param := range testNullUInt { | |
t.Run(fmt.Sprintf("TestNullUInt_MarshalJSON (%d)", i), func(t *testing.T) { | |
bytes, err := json.Marshal(param.Value) | |
require.NoError(t, err) | |
require.Equal(t, param.JSON, bytes) | |
}) | |
} | |
} | |
func TestNullInt64From(t *testing.T) { | |
i := NullInt64From(42) | |
require.Equal(t, int64(42), i.Int64) | |
require.True(t, i.Valid) | |
zero := NullInt64From(0) | |
require.Equal(t, int64(0), zero.Int64) | |
require.True(t, zero.Valid) | |
} | |
var testNullInt64 = []struct { | |
JSON []byte | |
Value NullInt64 | |
}{ | |
{[]byte(`12345`), NullInt64{12345, true}}, | |
{[]byte(`0`), NullInt64{0, true}}, | |
{[]byte(`null`), NullInt64{Valid: false}}, | |
} | |
func TestNullInt64_UnmarshalJSON(t *testing.T) { | |
for i, param := range testNullInt64 { | |
t.Run(fmt.Sprintf("TestNullInt64_UnmarshalJSON (%d)", i), func(t *testing.T) { | |
var i NullInt64 | |
err := json.Unmarshal(param.JSON, &i) | |
require.NoError(t, err) | |
require.Equal(t, param.Value, i) | |
}) | |
} | |
} | |
func TestNullInt64_MarshalJSON(t *testing.T) { | |
for i, param := range testNullInt64 { | |
t.Run(fmt.Sprintf("TestNullInt64_MarshalJSON (%d)", i), func(t *testing.T) { | |
bytes, err := json.Marshal(param.Value) | |
require.NoError(t, err) | |
require.Equal(t, param.JSON, bytes) | |
}) | |
} | |
} | |
func TestNullBoolFrom(t *testing.T) { | |
b := NullBoolFrom(true) | |
require.Equal(t, true, b.Bool) | |
require.True(t, b.Valid) | |
zero := NullBoolFrom(false) | |
require.Equal(t, false, zero.Bool) | |
require.True(t, zero.Valid) | |
} | |
var testNullBool = []struct { | |
JSON []byte | |
Value NullBool | |
}{ | |
{[]byte(`true`), NullBool{true, true}}, | |
{[]byte(`false`), NullBool{false, true}}, | |
{[]byte(`null`), NullBool{Valid: false}}, | |
} | |
func TestNullBool_UnmarshalJSON(t *testing.T) { | |
for i, param := range testNullBool { | |
t.Run(fmt.Sprintf("TestNullBool_UnmarshalJSON (%d)", i), func(t *testing.T) { | |
var b NullBool | |
err := json.Unmarshal(param.JSON, &b) | |
require.NoError(t, err) | |
require.Equal(t, param.Value, b) | |
}) | |
} | |
} | |
func TestNullBool_MarshalJSON(t *testing.T) { | |
for i, param := range testNullBool { | |
t.Run(fmt.Sprintf("TestNullBool_MarshalJSON (%d)", i), func(t *testing.T) { | |
bytes, err := json.Marshal(param.Value) | |
require.NoError(t, err) | |
require.Equal(t, param.JSON, bytes) | |
}) | |
} | |
} | |
func TestNullFloat64From(t *testing.T) { | |
f := NullFloat64From(1.2345) | |
require.Equal(t, 1.2345, f.Float64) | |
require.True(t, f.Valid) | |
zero := NullFloat64From(0.0) | |
require.Equal(t, 0.0, zero.Float64) | |
require.True(t, zero.Valid) | |
} | |
var testNullFloat64 = []struct { | |
JSON []byte | |
Value NullFloat64 | |
}{ | |
{[]byte(`1.2345`), NullFloat64{1.2345, true}}, | |
{[]byte(`0`), NullFloat64{0.0, true}}, | |
{[]byte(`null`), NullFloat64{Valid: false}}, | |
} | |
func TestNullFloat64_UnmarshalJSON(t *testing.T) { | |
for i, param := range testNullFloat64 { | |
t.Run(fmt.Sprintf("TestNullFloat64_UnmarshalJSON (%d)", i), func(t *testing.T) { | |
var b NullFloat64 | |
err := json.Unmarshal(param.JSON, &b) | |
require.NoError(t, err) | |
require.Equal(t, param.Value, b) | |
}) | |
} | |
} | |
func TestNullFloat64_MarshalJSON(t *testing.T) { | |
for i, param := range testNullFloat64 { | |
t.Run(fmt.Sprintf("TestNullFloat64_MarshalJSON (%d)", i), func(t *testing.T) { | |
bytes, err := json.Marshal(param.Value) | |
require.NoError(t, err) | |
require.Equal(t, param.JSON, bytes) | |
}) | |
} | |
} | |
// Time fixtures shared by the NullTime and struct round-trip tests. timeValue
// is timeString parsed; mustParseRFC3339 panics at package init if the fixture
// is ever edited into an invalid timestamp, instead of silently leaving
// timeValue as the zero time.
var (
	timeString = "2012-12-21T21:21:21Z"
	timeZero   = "0001-01-01T00:00:00Z"
	timeValue  = mustParseRFC3339(timeString)
)

// mustParseRFC3339 parses s as an RFC 3339 timestamp and panics if it is not
// one. It is intended only for test fixtures.
func mustParseRFC3339(s string) time.Time {
	v, err := time.Parse(time.RFC3339, s)
	if err != nil {
		panic(fmt.Sprintf("invalid RFC 3339 fixture %q: %v", s, err))
	}
	return v
}
func TestNullTimeFrom(t *testing.T) { | |
v := NullTimeFrom(timeValue) | |
require.Equal(t, timeValue, v.Time) | |
require.True(t, v.Valid) | |
zero := NullTimeFrom(time.Time{}) | |
require.Equal(t, time.Time{}, zero.Time) | |
require.True(t, zero.Valid) | |
} | |
var testNullTime = []struct { | |
JSON []byte | |
Value NullTime | |
}{ | |
{[]byte(`"` + timeString + `"`), NullTime{timeValue, true}}, | |
{[]byte(`"` + timeZero + `"`), NullTime{time.Time{}, true}}, | |
{[]byte(`null`), NullTime{Valid: false}}, | |
} | |
func TestNullTime_UnmarshalJSON(t *testing.T) { | |
for i, param := range testNullTime { | |
t.Run(fmt.Sprintf("TestNullTime_UnmarshalJSON (%d)", i), func(t *testing.T) { | |
var v NullTime | |
err := json.Unmarshal(param.JSON, &v) | |
require.NoError(t, err) | |
require.Equal(t, param.Value, v) | |
}) | |
} | |
} | |
func TestNullTime_MarshalJSON(t *testing.T) { | |
for i, param := range testNullTime { | |
t.Run(fmt.Sprintf("TestNullTime_MarshalJSON (%d)", i), func(t *testing.T) { | |
bytes, err := json.Marshal(param.Value) | |
require.NoError(t, err) | |
require.Equal(t, param.JSON, bytes) | |
}) | |
} | |
} | |
type testStruct struct { | |
ID NullInt64 `json:"id"` | |
AppID NullString `json:"app_id"` | |
CreatedAt NullTime `json:"created_at"` | |
Score NullFloat64 `json:"score"` | |
Blocked NullBool `json:"blocked"` | |
} | |
var testParam = []struct { | |
JSON []byte | |
Value testStruct | |
}{ | |
{ | |
JSON: []byte(`{"id":1,"app_id":"test","created_at":"` + timeString + `","score":1.2,"blocked":true}`), | |
Value: testStruct{ID: NullInt64From(1), AppID: NullStringFrom("test"), CreatedAt: NullTimeFrom(timeValue), Score: NullFloat64From(1.2), Blocked: NullBoolFrom(true)}, | |
}, | |
{ | |
JSON: []byte(`{"id":1,"app_id":"","created_at":"` + timeString + `","score":1.2,"blocked":true}`), | |
Value: testStruct{ID: NullInt64From(1), AppID: NullStringFrom(""), CreatedAt: NullTimeFrom(timeValue), Score: NullFloat64From(1.2), Blocked: NullBoolFrom(true)}, | |
}, | |
{ | |
JSON: []byte(`{"id":1,"app_id":null,"created_at":"` + timeString + `","score":1.2,"blocked":true}`), | |
Value: testStruct{ID: NullInt64From(1), CreatedAt: NullTimeFrom(timeValue), Score: NullFloat64From(1.2), Blocked: NullBoolFrom(true)}, | |
}, | |
{ | |
JSON: []byte(`{"id":0,"app_id":"test","created_at":"` + timeZero + `","score":1.2,"blocked":true}`), | |
Value: testStruct{ID: NullInt64From(0), AppID: NullStringFrom("test"), CreatedAt: NullTimeFrom(time.Time{}), Score: NullFloat64From(1.2), Blocked: NullBoolFrom(true)}, | |
}, | |
{ | |
JSON: []byte(`{"id":1,"app_id":"test","created_at":null,"score":0,"blocked":true}`), | |
Value: testStruct{ID: NullInt64From(1), AppID: NullStringFrom("test"), Score: NullFloat64From(0), Blocked: NullBoolFrom(true)}, | |
}, | |
{ | |
JSON: []byte(`{"id":0,"app_id":"test","created_at":"` + timeString + `","score":1.2,"blocked":true}`), | |
Value: testStruct{ID: NullInt64From(0), AppID: NullStringFrom("test"), CreatedAt: NullTimeFrom(timeValue), Score: NullFloat64From(1.2), Blocked: NullBoolFrom(true)}, | |
}, | |
{ | |
JSON: []byte(`{"id":1,"app_id":"test","created_at":"` + timeString + `","score":0,"blocked":true}`), | |
Value: testStruct{ID: NullInt64From(1), AppID: NullStringFrom("test"), CreatedAt: NullTimeFrom(timeValue), Score: NullFloat64From(0), Blocked: NullBoolFrom(true)}, | |
}, | |
{ | |
JSON: []byte(`{"id":1,"app_id":"test","created_at":"` + timeString + `","score":null,"blocked":true}`), | |
Value: testStruct{ID: NullInt64From(1), AppID: NullStringFrom("test"), CreatedAt: NullTimeFrom(timeValue), Blocked: NullBoolFrom(true)}, | |
}, | |
{ | |
JSON: []byte(`{"id":1,"app_id":"test","created_at":"` + timeString + `","score":1.2,"blocked":false}`), | |
Value: testStruct{ID: NullInt64From(1), AppID: NullStringFrom("test"), CreatedAt: NullTimeFrom(timeValue), Score: NullFloat64From(1.2), Blocked: NullBoolFrom(false)}, | |
}, | |
{ | |
JSON: []byte(`{"id":1,"app_id":"test","created_at":"` + timeString + `","score":1.2,"blocked":null}`), | |
Value: testStruct{ID: NullInt64From(1), AppID: NullStringFrom("test"), CreatedAt: NullTimeFrom(timeValue), Score: NullFloat64From(1.2)}, | |
}, | |
} | |
func TestJSONUnmarshalling(t *testing.T) { | |
for i, param := range testParam { | |
t.Run(fmt.Sprintf("TestJSONUnmarshalling (%d)", i), func(t *testing.T) { | |
var v testStruct | |
err := json.Unmarshal(param.JSON, &v) | |
require.NoError(t, err) | |
require.Equal(t, param.Value, v) | |
}) | |
} | |
} | |
func TestJSONMarshalling(t *testing.T) { | |
for i, param := range testParam { | |
t.Run(fmt.Sprintf("TestJSONMarshalling (%d)", i), func(t *testing.T) { | |
bytes, err := json.Marshal(param.Value) | |
require.NoError(t, err) | |
require.Equal(t, param.JSON, bytes) | |
}) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment