diff --git a/safe_map_test.go b/safe_map_test.go index 7890853..ac39fb7 100644 --- a/safe_map_test.go +++ b/safe_map_test.go @@ -16,7 +16,7 @@ func init() { gob.Register(new(SafeMap[string, int])) } -func TestNil(t *testing.T) { +func TestSafeMap_Nil(t *testing.T) { var m SafeMap[string, int] assert.False(t, m.Has("z")) diff --git a/set.go b/set.go index c806899..84779a2 100644 --- a/set.go +++ b/set.go @@ -32,11 +32,17 @@ func NewSet[K comparable](values ...K) *Set[K] { // Clear resets the set to an empty set func (m *Set[K]) Clear() { + if m == nil { + return + } m.items = nil } // Len returns the number of items in the set func (m *Set[K]) Len() int { + if m == nil || m.items == nil { + return 0 + } return m.items.Len() } @@ -45,7 +51,7 @@ func (m *Set[K]) Len() int { // While its safe to call methods of the set from within the Range function, its discouraged. // If you ever switch to one of the SafeSet sets, it will cause a deadlock. func (m *Set[K]) Range(f func(k K) bool) { - if m == nil || m.items == nil { + if m.Len() == 0 { return } for k := range m.items { @@ -57,22 +63,34 @@ func (m *Set[K]) Range(f func(k K) bool) { // Has returns true if the value exists in the set. func (m *Set[K]) Has(k K) bool { + if m.Len() == 0 { + return false + } return m.items.Has(k) } // Delete removes the value from the set. If the value does not exist, nothing happens. func (m *Set[K]) Delete(k K) { + if m.Len() == 0 { + return + } m.items.Delete(k) } // Values returns a new slice containing the values of the set. func (m *Set[K]) Values() []K { + if m.Len() == 0 { + return nil + } return m.items.Keys() } // Add adds the value to the set. // If the value already exists, nothing changes. func (m *Set[K]) Add(k ...K) SetI[K] { + if m == nil { + panic("cannot add values to a nil Set") + } if m.items == nil { m.items = make(map[K]struct{}) } @@ -90,6 +108,9 @@ func (m *Set[K]) Merge(in SetI[K]) { // Copy adds the values from in to the set. func (m *Set[K]) Copy(in SetI[K]) { + if m == nil { + panic("cannot copy to a nil Set") + } if in == nil || in.Len() == 0 { return } @@ -104,6 +125,9 @@ func (m *Set[K]) Copy(in SetI[K]) { // Equal returns true if the two sets are the same length and contain the same values. func (m *Set[K]) Equal(m2 SetI[K]) bool { + if m == nil { + return m2.Len() == 0 + } if m.Len() != m2.Len() { return false } @@ -168,10 +192,12 @@ func (m *Set[K]) UnmarshalJSON(in []byte) (err error) { // String returns the set as a string. func (m *Set[K]) String() string { ret := "{" - for i, v := range m.Values() { - ret += fmt.Sprintf("%#v", v) - if i < m.Len()-1 { - ret += "," + if m.Len() != 0 { + for i, v := range m.Values() { + ret += fmt.Sprintf("%#v", v) + if i < m.Len()-1 { + ret += "," + } } } ret += "}" @@ -180,12 +206,20 @@ func (m *Set[K]) String() string { // All returns an iterator over all the items in the set. Order is not determinate. func (m *Set[K]) All() iter.Seq[K] { + if m.Len() == 0 { + return func(yield func(K) bool) { + return + } + } return m.items.KeysIter() } // Insert adds the values from seq to the map. // Duplicates are overridden. func (m *Set[K]) Insert(seq iter.Seq[K]) { + if m == nil { + panic("cannot insert into a nil Set") + } if m.items == nil { m.items = NewStdMap[K, struct{}]() } @@ -207,12 +241,17 @@ func CollectSet[K comparable](seq iter.Seq[K]) *Set[K] { // the new keys and values are set using ordinary assignment. func (m *Set[K]) Clone() *Set[K] { m1 := NewSet[K]() - m1.items = m.items.Clone() + if m.Len() != 0 { + m1.items = m.items.Clone() + } return m1 } // DeleteFunc deletes any values for which del returns true. func (m *Set[K]) DeleteFunc(del func(K) bool) { + if m.Len() == 0 { + return + } del2 := func(k K, s struct{}) bool { return del(k) } diff --git a/set_ordered.go b/set_ordered.go index f91c5ec..2902952 100644 --- a/set_ordered.go +++ b/set_ordered.go @@ -3,6 +3,7 @@ package maps import ( "cmp" "encoding/json" + "fmt" "iter" "slices" ) @@ -25,9 +26,25 @@ func NewOrderedSet[K cmp.Ordered](values ...K) *OrderedSet[K] { return s } +// Clear resets the set to an empty set +func (m *OrderedSet[K]) Clear() { + if m == nil { + return + } + m.Set.Clear() +} + +// Len returns the number of items in the set +func (m *OrderedSet[K]) Len() int { + if m == nil || m.items == nil { + return 0 + } + return m.Set.Len() +} + // Range will range over the values in order. func (m *OrderedSet[K]) Range(f func(k K) bool) { - if m == nil || m.items == nil { + if m.Len() == 0 { return } values := m.Values() @@ -38,13 +55,58 @@ func (m *OrderedSet[K]) Range(f func(k K) bool) { } } -// Values returns the values as a slice, in order. +// Has returns true if the value exists in the set. +func (m *OrderedSet[K]) Has(k K) bool { + if m.Len() == 0 { + return false + } + return m.Set.Has(k) +} + +// Delete removes the value from the set. If the value does not exist, nothing happens. +func (m *OrderedSet[K]) Delete(k K) { + if m.Len() == 0 { + return + } + m.Set.Delete(k) +} + +// Equal returns true if the two sets are the same length and contain the same values. +func (m *OrderedSet[K]) Equal(m2 SetI[K]) bool { + if m == nil { + return m2.Len() == 0 + } + return m.Set.Equal(m2) +} + +// Values returns a new slice containing the values of the set. func (m *OrderedSet[K]) Values() []K { + if m.Len() == 0 { + return nil + } v := m.items.Keys() slices.Sort(v) return v } +// Add adds the value to the set. +// If the value already exists, nothing changes. +func (m *OrderedSet[K]) Add(k ...K) SetI[K] { + if m == nil { + panic("cannot add values to a nil Set") + } + m.Set.Add(k...) + return m +} + +// Copy adds the values from in to the set. +func (m *OrderedSet[K]) Copy(in SetI[K]) { + if m == nil { + panic("cannot copy to a nil Set") + } + m.Set.Copy(in) +} + // MarshalJSON implements the json.Marshaler interface to convert the map into a JSON object. func (m *OrderedSet[K]) MarshalJSON() (out []byte, err error) { if m.Len() == 0 { @@ -55,21 +117,56 @@ func (m *OrderedSet[K]) MarshalJSON() (out []byte, err error) { // All returns an iterator over all the items in the set. Order is determinate. func (m *OrderedSet[K]) All() iter.Seq[K] { + if m.Len() == 0 { + return func(yield func(K) bool) { + return + } + } v := m.Values() return slices.Values(v) } +// Insert adds the values from seq to the map. +// Duplicates are overridden. +func (m *OrderedSet[K]) Insert(seq iter.Seq[K]) { + if m == nil { + panic("cannot insert into a nil Set") + } + m.Set.Insert(seq) +} + // Clone returns a copy of the Set. This is a shallow clone: // the new keys and values are set using ordinary assignment. func (m *OrderedSet[K]) Clone() *OrderedSet[K] { m1 := NewOrderedSet[K]() - m1.items = m.items.Clone() + if m != nil { + m1.items = m.items.Clone() + } return m1 } -// Add adds the value to the set. -// If the value already exists, nothing changes. -func (m *OrderedSet[K]) Add(k ...K) SetI[K] { - m.Set.Add(k...) - return m +// DeleteFunc deletes any values for which del returns true. +func (m *OrderedSet[K]) DeleteFunc(del func(K) bool) { + if m.Len() == 0 { + return + } + m.Set.DeleteFunc(del) +} + +// String returns the set as a string. +func (m *OrderedSet[K]) String() string { + if m == nil { + return "{}" + } + ret := "{" + if m.Len() != 0 { + for i, v := range m.Values() { + ret += fmt.Sprintf("%#v", v) + if i < m.Len()-1 { + ret += "," + } + } + } + ret += "}" + return ret } diff --git a/set_ordered_test.go b/set_ordered_test.go index d45fde9..fab1235 100644 --- a/set_ordered_test.go +++ b/set_ordered_test.go @@ -3,7 +3,9 @@ package maps import ( "cmp" "encoding/gob" + "fmt" "github.com/stretchr/testify/assert" + "slices" "testing" ) @@ -11,11 +13,11 @@ type orderedSetT = OrderedSet[string] type orderedSetTI = SetI[string] func TestOrderedSet_SetI(t *testing.T) { - runSetITests[OrderedSet[string]](t, makeSetI[OrderedSet[string]]) + runSetITests[orderedSetT](t, makeSetI[orderedSetT]) } func init() { - gob.Register(new(OrderedSet[string])) + gob.Register(new(orderedSetT)) } func TestOrderedSet_Values(t *testing.T) { @@ -58,7 +60,7 @@ func TestOrderedSet_MarshalJSON(t *testing.T) { } } -func TestOrderedSetAll(t *testing.T) { +func TestOrderedSet_All(t *testing.T) { set := NewOrderedSet[int]() set.Add(5) set.Add(3) @@ -107,6 +109,54 @@ func TestOrderedSet_Clone(t *testing.T) { m2 := m1.Clone() assert.True(t, m1.Equal(m2)) + var m3 *OrderedSet[string] + m4 := m3.Clone() + m3.Equal(m4) + assert.True(t, m3.Equal(m4)) + m2.Add("d") assert.False(t, m1.Equal(m2)) } + +func TestOrderedSet_Nil(t *testing.T) { + t.Run("Nil", func(t *testing.T) { + var m1, m2 *OrderedSet[string] + + assert.Equal(t, 0, m1.Len()) + m1.Clear() + assert.True(t, m1.Equal(m2)) + m3 := m2.Clone() + assert.True(t, m1.Equal(m3)) + m3.Add("a") + assert.False(t, m1.Equal(m3)) + m1.Range(func(k string) bool { + assert.Fail(t, "no range should happen") + return false + }) + assert.False(t, m1.Has("b")) + m1.Delete("a") + assert.Empty(t, m1.Values()) + assert.Equal(t, "{}", m1.String()) + m1.DeleteFunc(func(k string) bool { + return false + }) + for _ = range m1.All() { + assert.Fail(t, "no range should happen") + } + assert.Panics(t, func() { + m1.Insert(slices.Values([]string{"a"})) + }) + assert.Panics(t, func() { + m1.Add("a") + }) + assert.Panics(t, func() { + m1.Copy(m2) + }) + }) +} + +func ExampleOrderedSet_String() { + m := NewOrderedSet("a", "c", "a", "b") + fmt.Print(m.String()) + // Output: {"a","b","c"} +} diff --git a/set_test.go b/set_test.go index a787c1a..4113b97 100644 --- a/set_test.go +++ b/set_test.go @@ -4,7 +4,7 @@ import ( "encoding/gob" "fmt" "github.com/stretchr/testify/assert" - "sort" + "slices" "testing" ) @@ -22,12 +22,8 @@ func init() { func ExampleSet_String() { m := new(Set[string]) m.Add("a") - m.Add("b") - m.Add("a") - v := m.Values() - sort.Strings(v) - fmt.Print(v) - // Output: [a b] + fmt.Print(m.String()) + // Output: {"a"} } func TestCollectSet(t *testing.T) { @@ -43,3 +39,42 @@ func TestSet_Clone(t *testing.T) { m3 := m2.Clone() assert.True(t, m1.Equal(m3)) } + +func TestSet_Nil(t *testing.T) { + t.Run("Nil", func(t *testing.T) { + var m1, m2 *Set[string] + + assert.Equal(t, 0, m1.Len()) + m1.Clear() + assert.True(t, m1.Equal(m2)) + m3 := m2.Clone() + assert.True(t, m1.Equal(m3)) + m3.Add("a") + assert.False(t, m1.Equal(m3)) + m1.Range(func(k string) bool { + assert.Fail(t, "no range should happen") + return false + }) + assert.False(t, m1.Has("b")) + m1.Delete("a") + assert.Empty(t, m1.Values()) + assert.Equal(t, "{}", m1.String()) + m1.DeleteFunc(func(k string) bool { + return false + }) + + for _ = range m1.All() { + assert.Fail(t, "no range should happen") + } + assert.Panics(t, func() { + m1.Insert(slices.Values([]string{"a"})) + }) + assert.Panics(t, func() { + m1.Add("a") + }) + assert.Panics(t, func() { + m1.Copy(m2) + }) + + }) +} diff --git a/seti.go b/seti.go index b3f834c..15547df 100644 --- a/seti.go +++ b/seti.go @@ -7,6 +7,7 @@ type SetI[K comparable] interface { Add(k ...K) SetI[K] Clear() Len() int + Copy(in SetI[K]) Range(func(k K) bool) Has(k K) bool Values() []K diff --git a/seti_test.go b/seti_test.go index 1c5e898..75469e3 100644 --- a/seti_test.go +++ b/seti_test.go @@ -37,6 +37,7 @@ func runSetITests[M any](t *testing.T, f makeSetF) { testSetAll(t, f) testSetInsert(t, f) testSetDeleteFunc(t, f) + testSetCopy(t, f) } func testSetClear(t *testing.T, f makeSetF) { @@ -201,8 +202,7 @@ func testSetMarshalJSON(t *testing.T, f makeSetF) { s, err = json.Marshal(m) assert.NoError(t, err) // Note: The below output is what is produced, but isn't guaranteed. go seems to currently be sorting keys - assert.Equal(t, string(s), "[]") - + assert.Equal(t, "[]", string(s)) }) } @@ -286,3 +286,12 @@ func testSetDeleteFunc(t *testing.T, f makeSetF) { assert.Equal(t, 1, m1.Len()) }) } + +func testSetCopy(t *testing.T, f makeSetF) { + t.Run("DeleteFunc", func(t *testing.T) { + m1 := f("a", "b", "c") + m2 := f() + m2.Copy(m1) + assert.True(t, m1.Equal(m2)) + }) +}