-
Notifications
You must be signed in to change notification settings - Fork 1
/
bst_vector.py
140 lines (106 loc) · 3.08 KB
/
bst_vector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#!/usr/bin/env python
import random
class _BSTVectorBase(object):
def set(self, index, value):
pass
def get(self, index):
pass
def sample(self):
pass
@property
def norm2(self):
return self._norm2
def __repr__(self):
return "<BSTVector of dimension {}: {}>".format(
self._dim, self.__str__())
class _BSTVectorLeaf(_BSTVectorBase):
    """A vector of dimension 1 holding a single float entry at index 0."""

    def __init__(self):
        self.value = 0.0
        self._norm2 = 0.0
        self._dim = 1

    def set(self, index, value):
        """Store ``float(value)`` at *index*, which must be 0.

        Raises:
            IndexError: if *index* is not 0 (the same check ``get`` makes).
            ValueError, TypeError: if ``float(value)`` fails; the stored
                value is left unchanged in that case.
        """
        if index != 0:
            raise IndexError(index)
        self.value = float(value)
        self._update_norm2()

    def get(self, index):
        """Return the stored value; *index* must be 0.

        Raises:
            IndexError: if *index* is not 0.
        """
        if index != 0:
            raise IndexError(index)
        return self.value

    def sample(self, seed=None):
        """Return 0, the only index, provided the entry is nonzero.

        Args:
            seed: if not None, the module-level ``random`` generator is
                reseeded, as ``_BSTVectorNode.sample`` does; no random
                number is drawn because there is only one candidate.

        Raises:
            ValueError: if the stored value is zero (squared norm 0.0),
                matching the error raised for larger all-zero vectors.
        """
        if self.norm2 == 0.0:
            raise ValueError("No nonzero entries")
        if seed is not None:
            random.seed(seed)
        return 0

    def _update_norm2(self):
        # Squared value; a tiny value whose square underflows yields 0.0.
        self._norm2 = self.value ** 2

    def __str__(self):
        return str(self.value)
class _BSTVectorNode(_BSTVectorBase):
    """Internal tree node covering ``dim`` consecutive indices.

    Indices ``[0, cutoff)`` are delegated to ``self.left`` and
    ``[cutoff, dim)`` to ``self.right``, re-based so each child counts from
    0. A child is ``None`` while its whole range is zero: ``set`` drops any
    child whose ``norm2`` becomes 0.0, so every attached child has a
    nonzero ``norm2``.

    NOTE: pruning is decided on ``norm2 == 0.0``, so if a nonzero entry's
    squared value underflows to 0.0 it is treated as zero and dropped.
    """

    def __init__(self, dim):
        self._dim = dim
        self._norm2 = 0.0
        self.left = None   # subtree for [0, cutoff), or None if all zero
        self.right = None  # subtree for [cutoff, dim), or None if all zero

    @property
    def cutoff(self):
        """First index handled by the right subtree (``dim // 2``)."""
        return self._dim // 2

    def _check_index(self, index):
        """Raise IndexError unless ``0 <= index < dim``."""
        # Without this, a negative index is routed into the left subtree and
        # an index >= dim into the right one, silently hitting a wrong entry.
        if not 0 <= index < self._dim:
            raise IndexError(index)

    def _route(self, index):
        """Return ``(side, child_dim, child_index)`` for an in-range index."""
        if index < self.cutoff:
            return 'left', self.cutoff, index
        return 'right', self._dim - self.cutoff, index - self.cutoff

    def set(self, index, value):
        """Store *value* at *index*, creating or pruning subtrees as needed.

        Raises:
            IndexError: if *index* is outside ``[0, dim)``.
            Any exception raised by the child's ``set`` (e.g. a failed
            float conversion); this node is left unchanged in that case.
        """
        self._check_index(index)
        side, child_dim, child_index = self._route(index)
        child = getattr(self, side)
        if child is None:
            child = BSTVector(child_dim)
        # Attach the child only after its set() succeeds, so an exception
        # cannot leave an empty subtree hanging off this node.
        child.set(child_index, value)
        setattr(self, side, child if child.norm2 != 0.0 else None)
        self._update_norm2()

    def get(self, index):
        """Return the value at *index*; 0.0 if its subtree is all zero.

        Raises:
            IndexError: if *index* is outside ``[0, dim)``.
        """
        self._check_index(index)
        side, _, child_index = self._route(index)
        child = getattr(self, side)
        if child is None:
            return 0.0
        return child.get(child_index)

    def sample(self, seed=None):
        """Return a random index i with probability proportional to norm2 of i.

        Each level picks the left subtree with probability
        ``left.norm2 / norm2``.

        Args:
            seed: if not None, reseeds the module-level ``random`` generator
                before drawing. Recursive calls pass no seed, so they keep
                consuming the same stream.

        Raises:
            ValueError: if every entry is zero.
        """
        if self.norm2 == 0.0:
            raise ValueError("No nonzero entries")
        left_norm2 = self.left.norm2 if self.left is not None else 0.0
        if seed is not None:
            random.seed(seed)
        # One draw per level, regardless of which children exist.
        draw = random.uniform(0, self.norm2)
        # Decide on the children themselves rather than trusting
        # draw < norm2: random.uniform may return its upper bound, which
        # would otherwise send the walk into a missing right subtree.
        if self.right is None or (self.left is not None and draw < left_norm2):
            return self.left.sample()
        return self.cutoff + self.right.sample()

    def _update_norm2(self):
        """Recompute norm2 as the sum of the attached children's norm2."""
        self._norm2 = sum(child.norm2 for child in (self.left, self.right)
                          if child is not None)

    def __str__(self):
        """Render as ``(left right)``; an all-zero half shows one '-' per entry."""
        if self.left is None:
            left_str = "-" * self.cutoff
        else:
            left_str = str(self.left)
        if self.right is None:
            right_str = "-" * (self._dim - self.cutoff)
        else:
            right_str = str(self.right)
        return "({} {})".format(left_str, right_str)
def BSTVector(dim):
    """Create an all-zero vector with *dim* entries, stored as a binary tree.

    Args:
        dim: number of entries; must be at least 1.

    Returns:
        A ``_BSTVectorLeaf`` when ``dim == 1``, otherwise a
        ``_BSTVectorNode`` of dimension *dim*.

    Raises:
        ValueError: if ``dim < 1``.
    """
    # A dimension below 1 would build a node with an empty index range
    # that can never hold a value.
    if dim < 1:
        raise ValueError("dim must be at least 1, got {!r}".format(dim))
    if dim == 1:
        return _BSTVectorLeaf()
    return _BSTVectorNode(dim)