-
Notifications
You must be signed in to change notification settings - Fork 4
/
sparseMatrix.go
114 lines (103 loc) · 2.65 KB
/
sparseMatrix.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
package sm
import (
"fmt"
"math"
)
type (
	// SparseMatrixIndex addresses a single cell with one coordinate per
	// dimension, e.g. (0, 1) for a cell of a 2-D matrix or (0, 1, 23) for a
	// cell of a 3-D one.
	SparseMatrixIndex []int

	// SparseMatrix stores only explicitly set cells, as two parallel slices:
	// Values[i] holds the value of the cell addressed by Indices[i].
	SparseMatrix struct {
		Indices []SparseMatrixIndex
		Values  []float64
	}
)
// Equal reports whether si and other address the same cell: both must have
// the same number of coordinates and agree on every one of them.
func (si SparseMatrixIndex) Equal(other SparseMatrixIndex) bool {
	if len(other) != len(si) {
		return false
	}
	for pos, coord := range si {
		if other[pos] != coord {
			return false
		}
	}
	return true
}
// Get returns the value stored at index smi.
//
// Cells that were never set are implicitly zero: an index that is not stored
// yields 0.0 and a nil error. Stored indices are scanned in order; if one with
// a different number of dimensions than smi is reached before a match, Get
// returns -1.0 and a dimension-mismatch error. Entries after a match are not
// inspected.
func (s SparseMatrix) Get(smi SparseMatrixIndex) (float64, error) {
	for i, idx := range s.Indices {
		// A length mismatch can never be an Equal match, so checking it first
		// gives the same result as checking it after Equal.
		if len(idx) != len(smi) {
			return -1.0, fmt.Errorf("Sparse Matrix: Dimension mismatch %d vs %d", len(idx), len(smi))
		}
		if idx.Equal(smi) {
			return s.Values[i], nil
		}
	}
	return 0.0, nil
}
// Set stores value at index smi, overwriting any value already stored there.
//
// New entries keep a copy of smi, so the caller may reuse or modify its index
// slice afterwards without affecting the matrix.
func (s *SparseMatrix) Set(smi SparseMatrixIndex, value float64) {
	for i, idx := range s.Indices {
		if idx.Equal(smi) {
			s.Values[i] = value
			return
		}
	}
	// Copy the index: storing the caller's slice would alias it, and a later
	// mutation of that slice would silently move this entry to another cell.
	stored := make(SparseMatrixIndex, len(smi))
	copy(stored, smi)
	s.Indices = append(s.Indices, stored)
	s.Values = append(s.Values, value)
}
// Add adds value to the value stored at index smi. An index that is not
// stored yet counts as zero, so a new entry holding value is created.
//
// New entries keep a copy of smi, so the caller may reuse or modify its index
// slice afterwards without affecting the matrix.
func (s *SparseMatrix) Add(smi SparseMatrixIndex, value float64) {
	for i, idx := range s.Indices {
		if idx.Equal(smi) {
			s.Values[i] += value
			return
		}
	}
	// Reached only when smi is not stored yet. Copy it: keeping the caller's
	// slice would alias it, and a later mutation would move this entry.
	stored := make(SparseMatrixIndex, len(smi))
	copy(stored, smi)
	s.Indices = append(s.Indices, stored)
	s.Values = append(s.Values, value)
}
// Mul multiplies the value stored at index smi by value.
//
// If smi is not stored, the cell is implicitly zero and stays zero, so the
// matrix is left unchanged. Although the receiver is a copy, its Values slice
// shares the caller's backing array, so the update is visible to the caller.
func (s SparseMatrix) Mul(smi SparseMatrixIndex, value float64) {
	for pos := range s.Indices {
		if !s.Indices[pos].Equal(smi) {
			continue
		}
		s.Values[pos] *= value
		return
	}
}
// Scale multiplies every stored value by value; unstored cells are zero and
// remain zero. Although the receiver is a copy, its Values slice shares the
// caller's backing array, so the update is visible to the caller.
func (s SparseMatrix) Scale(value float64) {
	vals := s.Values
	for i, v := range vals {
		vals[i] = v * value
	}
}
// Equal reports whether s and other store the same entries, regardless of the
// order in which those entries were inserted. Two entries match when their
// indices are equal and their values are equal or differ by at most 1e-6.
// A NaN value never matches anything.
//
// The comparison is over stored entries: a cell explicitly stored as 0.0 is
// not considered equal to an absent cell. Each entry of other is paired with
// at most one entry of s. The cost is O(n²) for n stored entries.
func (s SparseMatrix) Equal(other SparseMatrix) bool {
	const tolerance = 0.000001
	if len(s.Indices) != len(other.Indices) || len(s.Values) != len(other.Values) {
		return false
	}
	// a == b lets equal infinities match (their difference is NaN). The "<="
	// form makes NaN fail, whereas a "> tolerance" test would let it pass.
	same := func(a, b float64) bool {
		return a == b || math.Abs(a-b) <= tolerance
	}
	// used[j] marks entries of other that are already paired, so duplicated
	// indices in s cannot all be matched against one entry of other.
	used := make([]bool, len(other.Indices))
	for i, idx := range s.Indices {
		found := false
		for j, otherIdx := range other.Indices {
			if used[j] || !idx.Equal(otherIdx) || !same(s.Values[i], other.Values[j]) {
				continue
			}
			used[j] = true
			found = true
			break
		}
		if !found {
			return false
		}
	}
	return true
}
// CreateSparseMatrix returns a sparse matrix with no stored entries. Both the
// Indices and Values slices are allocated (non-nil) and empty.
func CreateSparseMatrix() SparseMatrix {
	var m SparseMatrix
	m.Indices = make([]SparseMatrixIndex, 0)
	m.Values = make([]float64, 0)
	return m
}