Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add kdtree #171

Merged
merged 6 commits into from
Apr 16, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add heap.go and heap_test.go for heap using in kdtree
  • Loading branch information
frozenkp committed Apr 15, 2017
commit 759ee645c5f5ac9c914194c83267d11a4f25ac01
88 changes: 88 additions & 0 deletions kdtree/heap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package kdtree

import (
"errors"
)

// heapNode is one entry of the heap: a float64 slice paired with the key
// the heap orders entries by.
type heapNode struct {
	// value is the payload carried by the entry; insert stores a private copy.
	value []float64
	// length is the ordering key; larger lengths sit closer to the root.
	length float64
}

// heap is a binary max-heap of heapNode entries keyed by their length.
type heap struct {
	// tree holds the entries in array form: index 0 is the root and the
	// children of index i live at 2*i+1 and 2*i+2.
	tree []heapNode
}

// newHeap returns a pointer to an empty heap whose backing slice is
// already allocated.
//
// The previous version initialised a heap and then returned a different,
// uninitialised one; the initialised value is now the one handed back.
func newHeap() *heap {
	return &heap{tree: make([]heapNode, 0)}
}

// maximum returns the root node of the heap without removing it.
// When the heap holds no nodes it returns a zero heapNode together with
// the error produced by errEmpty.
func (h *heap) maximum() (heapNode, error) {
	if len(h.tree) > 0 {
		return h.tree[0], nil
	}
	return heapNode{}, h.errEmpty()
}

// extractMax removes the root node (the one with the largest length) from
// the heap and restores the max-heap ordering. It is a no-op on an empty
// heap.
//
// The previous sift-down compared the left child against the wrong slot
// and stopped before swapping when a node had only a left child, which
// could leave a smaller node at the root.
func (h *heap) extractMax() {
	if len(h.tree) == 0 {
		return
	}

	// Move the last node into the root slot and shrink the slice by one.
	last := len(h.tree) - 1
	h.tree[0] = h.tree[last]
	h.tree[last] = heapNode{} // release the vacated slot's value slice
	h.tree = h.tree[:last]

	// Sift the new root down: swap it with its larger child until neither
	// child has a greater length.
	parent := 0
	for {
		largest := parent
		left, right := 2*parent+1, 2*parent+2

		if left < len(h.tree) && h.tree[left].length > h.tree[largest].length {
			largest = left
		}
		if right < len(h.tree) && h.tree[right].length > h.tree[largest].length {
			largest = right
		}
		if largest == parent {
			return
		}

		h.tree[parent], h.tree[largest] = h.tree[largest], h.tree[parent]
		parent = largest
	}
}

// insert adds a new node keyed by length to the heap. The value slice is
// copied, so later changes to the caller's slice do not affect the heap.
func (h *heap) insert(value []float64, length float64) {
	stored := make([]float64, len(value))
	copy(stored, value)
	h.tree = append(h.tree, heapNode{value: stored, length: length})

	// Sift the new node up while its parent has a strictly smaller length.
	for child := len(h.tree) - 1; child > 0; {
		parent := (child - 1) / 2
		if h.tree[parent].length >= h.tree[child].length {
			break
		}
		h.tree[child], h.tree[parent] = h.tree[parent], h.tree[child]
		child = parent
	}
}

// errEmptyHeap is the single error value reported when an operation
// needs a node but the heap is empty.
var errEmptyHeap = errors.New("empty heap")

// errEmpty returns the error used when the heap is empty.
//
// Every call returns the same value, so callers can compare results with
// == or errors.Is; building a new error on each call would make two
// results never compare equal.
func (h *heap) errEmpty() error {
	return errEmptyHeap
}
40 changes: 40 additions & 0 deletions kdtree/heap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package kdtree

import (
"testing"

. "github.com/smartystreets/goconvey/convey"
)

func TestHeap(t *testing.T) {
h := newHeap()

Convey("Given a heap", t, func() {

Convey("When heap is empty", func() {
_, err := h.maximum()

Convey("The err should be errEmpty", func() {
So(err, ShouldEqual, h.errEmpty())
})
})

Convey("When insert 5 nodes", func() {
for i := 0; i < 5; i++ {
h.insert([]float64{}, float64(i))
}
max1, _ := h.maximum()
h.extractMax()
max2, _ := h.maximum()

Convey("The max1.value should be 4", func() {
So(max1.value, ShouldEqual, 4)
})
Convey("The max2.value should be 3", func() {
So(max2.value, ShouldEqual, 3)
})

})

})
}
111 changes: 55 additions & 56 deletions kdtree/kdtree.go
Original file line number Diff line number Diff line change
@@ -1,81 +1,80 @@
package kdtree

import(
"sort"
"errors"
import (
"errors"
"sort"
)

type node struct{
feature int
value []float64
left *node
right *node
type node struct {
feature int
value []float64
left *node
right *node
}

// Tree is a kdtree.
type Tree struct{
firstDiv *node
type Tree struct {
firstDiv *node
}

// New return a Tree pointer.
func New()*Tree{
return &Tree{}
func New() *Tree {
return &Tree{}
}

// Build builds the kdtree with specific data.
func (t *Tree) Build(data [][]float64)err{
if len(data)==0 {
return errors.New("no input data")
}
size := len(data[0])
for _, v := range data {
if len(v) != size {
return errors.New("amounts of features are not the same")
}
}
func (t *Tree) Build(data [][]float64) error {
if len(data) == 0 {
return errors.New("no input data")
}
size := len(data[0])
for _, v := range data {
if len(v) != size {
return errors.New("amounts of features are not the same")
}
}

t.firstDiv = t.buildHandle(data, 0)
t.firstDiv = t.buildHandle(data, 0)

return nil
return nil
}

// buildHandle builds the kdtree recursively.
func (t *tree) buildHandle(data [][]float64, featureIndex int)*node{
n := &node{feature:featureIndex}
func (t *Tree) buildHandle(data [][]float64, featureIndex int) *node {
n := &node{feature: featureIndex}

sort.Slice(data, func(i, j int)bool{
return data[i][featureIndex]<data[j][featureIndex]
})
middle:= len(data)/2
sort.Slice(data, func(i, j int) bool {
return data[i][featureIndex] < data[j][featureIndex]
})
middle := len(data) / 2

n.value = make([]float64, len(data[middle]))
copy(n.value, data[middle])
n.value = make([]float64, len(data[middle]))
copy(n.value, data[middle])

divPoint := middle
for i:=middle+1 ; i<len(data) ; i++ {
if data[i][featureIndex] == data[middle][featureIndex] {
divPoint=i;
}else{
break
}
}
divPoint := middle
for i := middle + 1; i < len(data); i++ {
if data[i][featureIndex] == data[middle][featureIndex] {
divPoint = i
} else {
break
}
}

if divPoint==1 {
n.Left = &node{feature:-1}
n.Left.value = make([]float64, len(data[0]))
copy(n.Left.value, data[0])
}else{
n.Left = t.buildHandle(data[:divPoint])
}
if divPoint == 1 {
n.left = &node{feature: -1}
n.left.value = make([]float64, len(data[0]))
copy(n.left.value, data[0])
} else {
n.left = t.buildHandle(data[:divPoint], (featureIndex+1)%len(data[0]))
}

if divPoint==(len(data)-2) {
n.Right = &node{feature:-1}
n.Right.value = make([]float64, len(data[divPoint+1]))
copy(n.Right.value, data[divPoint+1])
}else if divPoint!=(len(data)-1){
n.Right = t.buildHandle(data[divPoint+1:])
}
if divPoint == (len(data) - 2) {
n.right = &node{feature: -1}
n.right.value = make([]float64, len(data[divPoint+1]))
copy(n.right.value, data[divPoint+1])
} else if divPoint != (len(data) - 1) {
n.right = t.buildHandle(data[divPoint+1:], (featureIndex+1)%len(data[0]))
}

return n
return n
}