Skip to content

Commit

Permalink
Distheap (#342)
Browse files Browse the repository at this point in the history
* minmax dist heap

* complete refactor

* feat: distheap

* fix

* refactor
  • Loading branch information
friendlymatthew authored Jun 14, 2024
1 parent 30b6f22 commit 430a483
Show file tree
Hide file tree
Showing 9 changed files with 383 additions and 423 deletions.
8 changes: 4 additions & 4 deletions pkg/hnsw/friends.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ import (
type Point []float32

type Friends struct {
friends []*BaseQueue
friends []*DistHeap
}

// NewFriends creates a new vector, note the max level is inclusive.
func NewFriends(topLevel int) *Friends {
friends := make([]*BaseQueue, topLevel+1)
friends := make([]*DistHeap, topLevel+1)

for i := 0; i <= topLevel; i++ {
friends[i] = NewBaseQueue(MinComparator{})
friends[i] = NewDistHeap()
}

return &Friends{
Expand Down Expand Up @@ -51,7 +51,7 @@ func (v *Friends) InsertFriendsAtLevel(level int, friendId Id, dist float32) {
}
}

func (v *Friends) GetFriendsAtLevel(level int) (*BaseQueue, error) {
func (v *Friends) GetFriendsAtLevel(level int) (*DistHeap, error) {
if !v.HasLevel(level) {
return nil, errors.New("failed to get friends at level")
}
Expand Down
10 changes: 8 additions & 2 deletions pkg/hnsw/friends_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,18 @@ func TestVector_LevelManagement(t *testing.T) {
t.Fatalf("expected hex and oct friends list at level %v to be 1, got: %v || %v", i, hexFriends.Len(), octFriends.Len())
}

top := hexFriends.Top()
top, err := hexFriends.PeekMinItem()
if err != nil {
t.Fatal(err)
}
if top.id != octId {
t.Fatalf("expected %v, got %v", octId, top.id)
}

top = octFriends.Top()
top, err = octFriends.PeekMinItem()
if err != nil {
t.Fatal(err)
}
if top.id != hexId {
t.Fatalf("expected %v, got %v", hexId, top.id)
}
Expand Down
140 changes: 140 additions & 0 deletions pkg/hnsw/heap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package hnsw

import (
"fmt"
)

// Item is a single heap entry: an id paired with the dist value the heap
// orders by.
type Item struct {
// id identifies the entry; DistHeap keys its visited map by this value.
id Id
// dist is the value compared by DistHeap.Less.
dist float32
}

// EmptyHeapError is the sentinel error returned by the peek and pop
// operations of DistHeap when the heap holds no items.
// NOTE(review): Go convention would name this ErrEmptyHeap and use a
// lowercase message; left unchanged because callers may depend on both.
var EmptyHeapError = fmt.Errorf("Empty Heap")

// DistHeap is a heap of Items ordered by their dist value, with the smallest
// dist kept at items[0]. The ordering is maintained by the up and down
// helpers, which are defined outside this view.
// NOTE(review): PeekMaxItem and PopMaxItem look for the maximum among the
// root and its two children, which presumably means this is a min-max heap;
// confirm against up/down.
type DistHeap struct {
// items is the backing slice, stored in heap order.
items []*Item
// visited marks the ids that Insert and FromItems have added; Insert
// consults it to decide between updating an existing item and pushing a
// new one.
visited map[Id]bool
}

// NewDistHeap returns an empty DistHeap whose item slice and visited map are
// both allocated, so the result can be used immediately.
func NewDistHeap() *DistHeap {
	heap := new(DistHeap)
	heap.items = []*Item{}
	heap.visited = map[Id]bool{}
	return heap
}
// FromItems builds a DistHeap directly on top of the given slice: every id in
// items is marked visited, and Init is then called to establish heap order.
// The slice is not copied, so the heap reorders the caller's slice in place.
func FromItems(items []*Item) *DistHeap {
	seen := make(map[Id]bool, len(items))
	for idx := range items {
		seen[items[idx].id] = true
	}

	heap := &DistHeap{
		items:   items,
		visited: seen,
	}
	heap.Init()
	return heap
}

// Init establishes heap order over the current items by calling down on every
// index from the last parent (Len()/2 - 1) back to the root. Heaps with fewer
// than two items need no work.
func (d *DistHeap) Init() {
	size := d.Len()
	lastParent := size/2 - 1
	for idx := lastParent; idx >= 0; idx-- {
		d.down(idx, size)
	}
}

// PeekMinItem returns the item at the root of the heap without removing it.
// It returns EmptyHeapError when the heap holds no items.
func (d *DistHeap) PeekMinItem() (*Item, error) {
	if !d.IsEmpty() {
		return d.items[0], nil
	}
	return nil, EmptyHeapError
}
// PeekMaxItem returns, without removing it, the item with the largest dist
// among the root and its two children. It returns EmptyHeapError when the
// heap holds no items.
//
// Only those three positions are inspected; presumably the heap layout
// guarantees the overall maximum sits there (as in a min-max heap).
func (d *DistHeap) PeekMaxItem() (*Item, error) {
	size := d.Len()
	if size == 0 {
		return nil, EmptyHeapError
	}

	// A child replaces the current pick unless it is strictly smaller, so on
	// ties the later candidate wins.
	largest := 0
	for _, child := range [2]int{lchild(0), rchild(0)} {
		if child >= size {
			continue
		}
		if d.Less(child, largest) {
			continue
		}
		largest = child
	}

	return d.items[largest], nil
}
// PopMinItem removes and returns the item at the root of the heap. It returns
// EmptyHeapError when the heap holds no items.
//
// The root is swapped to the end of the slice, the element that took its
// place is sifted down within the remaining prefix, and the former root is
// then taken off the tail with Pop.
func (d *DistHeap) PopMinItem() (*Item, error) {
	if d.Len() == 0 {
		return nil, EmptyHeapError
	}

	last := d.Len() - 1
	d.Swap(0, last)
	d.down(0, last)

	popped := d.Pop()
	return popped, nil
}
// PopMaxItem removes and returns the item with the largest dist among the
// root and its two children. It returns EmptyHeapError when the heap holds no
// items.
//
// Only those three positions are inspected; presumably the heap layout
// guarantees the overall maximum sits there (as in a min-max heap). The
// chosen item is swapped to the end of the slice, the element that took its
// place is sifted down within the remaining prefix, and the chosen item is
// then taken off the tail with Pop.
func (d *DistHeap) PopMaxItem() (*Item, error) {
	if d.IsEmpty() {
		return nil, EmptyHeapError
	}

	size := d.Len()
	last := size - 1

	// A child replaces the current pick unless it is strictly smaller, so on
	// ties the later candidate wins.
	target := 0
	for _, child := range [2]int{lchild(0), rchild(0)} {
		if child < size && !d.Less(child, target) {
			target = child
		}
	}

	d.Swap(target, last)
	d.down(target, last)

	return d.Pop(), nil
}
// Insert adds id with the given dist, or updates the dist of id if it has
// been inserted before.
//
// An id not yet marked visited is pushed as a new item, sifted up from the
// end of the slice, and marked visited. For an id already marked visited, the
// items are scanned linearly; the matching item gets the new dist and its
// position is repaired with Fix.
//
// NOTE(review): if id is marked visited but no item with that id is found in
// the heap, the call changes nothing.
func (d *DistHeap) Insert(id Id, dist float32) {
	if !d.visited[id] {
		d.Push(&Item{id: id, dist: dist})
		d.up(d.Len() - 1)
		d.visited[id] = true
		return
	}

	for pos := range d.items {
		entry := d.items[pos]
		if entry.id != id {
			continue
		}
		entry.dist = dist
		d.Fix(pos)
		return
	}
}
// Fix repairs the heap after the dist of the item at index i has changed. It
// first calls down on i over the whole heap and, only when down returns
// false, calls up on i. Presumably down reports whether the item moved, as in
// container/heap.
func (d *DistHeap) Fix(i int) {
	if sunk := d.down(i, d.Len()); sunk {
		return
	}
	d.up(i)
}

// IsEmpty reports whether the heap holds no items.
func (d DistHeap) IsEmpty() bool {
	return len(d.items) < 1
}
// Len returns the number of items currently in the heap.
func (d DistHeap) Len() int {
	count := len(d.items)
	return count
}
// Less reports whether the item at index i has a strictly smaller dist than
// the item at index j.
func (d DistHeap) Less(i, j int) bool {
	left, right := d.items[i], d.items[j]
	return left.dist < right.dist
}
// Swap exchanges the items at indexes i and j. The receiver is a value, but
// the exchange is visible to the caller because the copy shares the same
// backing array.
func (d DistHeap) Swap(i, j int) {
	held := d.items[i]
	d.items[i] = d.items[j]
	d.items[j] = held
}
// Push appends x to the end of the backing slice. It neither restores heap
// order nor touches the visited map; Insert does both around its call to
// Push.
func (d *DistHeap) Push(x *Item) {
	d.items = append(d.items, x)
}
// Pop removes and returns the last element of the backing slice and clears
// that element's id from the visited map. It does not restore heap order;
// PopMinItem and PopMaxItem move the item to be removed to the end before
// calling it.
//
// Pop must not be called on an empty heap: indexing the last element panics.
func (d *DistHeap) Pop() *Item {
	last := len(d.items) - 1
	popped := d.items[last]
	d.items = d.items[:last]

	// Forget the id so that a later Insert of the same id pushes a fresh
	// item. Insert only scans for an existing item when the id is marked
	// visited, so a stale mark would turn that Insert into a silent no-op.
	// delete is a no-op on a nil map or a missing key.
	if popped != nil {
		delete(d.visited, popped.id)
	}

	return popped
}
174 changes: 174 additions & 0 deletions pkg/hnsw/heap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package hnsw

import "testing"

// TestHeap exercises DistHeap through three scenarios: draining the heap by
// minimum, draining it alternately by maximum and minimum, and solving a
// "bricks and ladders" puzzle with furthestBuildings, which uses the heap as
// a min heap.
func TestHeap(t *testing.T) {
	t.Run("basic min max properties", func(t *testing.T) {
		h := NewDistHeap()

		// Ids 10..1 receive dists 0..9, so popping by minimum must yield
		// id 10 first and id 1 last.
		for i := 10; i > 0; i-- {
			h.Insert(Id(i), float32(10-i))
		}

		if h.Len() != 10 {
			t.Fatalf("heap length should be 10, got %v", h.Len())
		}

		expectedId := Id(10)
		for !h.IsEmpty() {
			peekMinItem, err := h.PeekMinItem()
			if err != nil {
				t.Fatalf("failed to peek min item: %v", err)
			}

			minItem, err := h.PopMinItem()
			if err != nil {
				t.Fatalf("failed to pop min item, err: %v", err)
			}

			// The peeked item must be the one that is popped next.
			if peekMinItem.id != minItem.id {
				t.Fatalf("mismatched item id, expected %v, got %v", minItem.id, peekMinItem.id)
			}

			if minItem.id != expectedId {
				t.Fatalf("mismatched ids, expected %v, got: %v", expectedId, minItem.id)
			}

			expectedId -= 1
		}
	})

	t.Run("basic min max properties 2", func(t *testing.T) {
		h := NewDistHeap()

		// Ids 0..10 receive dists 10..0: id 0 is the maximum, id 10 the
		// minimum.
		for i := 0; i <= 10; i++ {
			h.Insert(Id(i), float32(10-i))
		}

		maxExpectedId := Id(0)
		minExpectedId := Id(10)

		// Each round removes the current maximum and then the current
		// minimum, closing in on the middle id from both ends.
		for !h.IsEmpty() {
			peekMaxItem, err := h.PeekMaxItem()

			if err != nil {
				t.Fatalf("failed to peek max item, err: %v", err)
			}

			maxItem, err := h.PopMaxItem()

			if err != nil {
				t.Fatalf("failed to pop max item, err: %v", err)
			}

			if peekMaxItem.id != maxItem.id {
				t.Fatalf("mismatched max ids, expected %v, got: %v", maxItem.id, peekMaxItem.id)
			}

			if maxItem.id != maxExpectedId {
				t.Fatalf("expected id to be %v, got %v", maxExpectedId, maxItem.id)
			}

			// With an odd item count the final round has no minimum left
			// to pop.
			if h.IsEmpty() {
				continue
			}

			peekMinItem, err := h.PeekMinItem()
			if err != nil {
				t.Fatalf("failed to peek min item, err: %v", err)
			}

			minItem, err := h.PopMinItem()

			if err != nil {
				t.Fatalf("failed to pop min item, err: %v", err)
			}

			if peekMinItem.id != minItem.id {
				t.Fatalf("mismatched min ids, expected %v, got: %v", minItem.id, peekMinItem.id)
			}

			if minItem.id != minExpectedId {
				t.Fatalf("expected id to be %v, got %v", minExpectedId, minItem.id)
			}

			minExpectedId -= 1
			maxExpectedId += 1
		}
	})

	t.Run("bricks and ladders || min heap", func(t *testing.T) {
		type Case struct {
			heights  []int
			bricks   int
			ladders  int
			expected int
		}

		cases := [3]Case{
			{
				heights:  []int{4, 2, 7, 6, 9, 14, 12},
				bricks:   5,
				ladders:  1,
				expected: 4,
			},
			{
				heights:  []int{4, 12, 2, 7, 3, 18, 20, 3, 19},
				bricks:   10,
				ladders:  2,
				expected: 7,
			},
			{
				heights:  []int{14, 3, 19, 3},
				bricks:   17,
				ladders:  0,
				expected: 3,
			},
		}

		for _, c := range cases {
			res, err := furthestBuildings(c.heights, c.bricks, c.ladders)
			if err != nil {
				t.Fatal(err)
			}

			if res != c.expected {
				t.Errorf("got %d, want %d", res, c.expected)
			}
		}
	})
}

// furthestBuildings returns the index of the furthest building reachable when
// walking heights from left to right with the given bricks and ladders.
//
// Moving to a lower or equal building is free. Every upward jump is first
// recorded in a DistHeap; once more jumps are recorded than there are
// ladders, the smallest recorded jump is popped and paid for with bricks. If
// the bricks do not cover it, the index of the current building is returned.
// When every jump can be covered the last index, len(heights)-1, is returned.
// An error from the heap is returned together with -1.
func furthestBuildings(heights []int, bricks, ladders int) (int, error) {
	climbs := NewDistHeap()
	last := len(heights) - 1

	for pos := 0; pos < last; pos++ {
		from, to := heights[pos], heights[pos+1]
		if from >= to {
			continue
		}

		climbs.Insert(Id(pos), float32(to-from))

		// Ladders still cover every recorded jump.
		if climbs.Len() <= ladders {
			continue
		}

		smallest, err := climbs.PopMinItem()
		if err != nil {
			return -1, err
		}

		remaining := bricks - int(smallest.dist)
		if remaining < 0 {
			return pos, nil
		}
		bricks = remaining
	}

	return last, nil
}
Loading

0 comments on commit 430a483

Please sign in to comment.