Skip to content

Commit

Permalink
display acc rate of the samples from the current sample/warmup
Browse files Browse the repository at this point in the history
  • Loading branch information
chaozg committed Oct 1, 2024
1 parent a56b84d commit db6dc7f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions cuqi/experimental/mcmc/_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def sample(self, Ns, batch_size=0, sample_path='./CUQI_samples/') -> 'Sampler':

# Draw samples
pbar = tqdm(range(Ns), "Sample: ")
for _ in pbar:
for idx in pbar:

# Perform one step of the sampler
acc = self.step()
Expand All @@ -231,7 +231,7 @@ def sample(self, Ns, batch_size=0, sample_path='./CUQI_samples/') -> 'Sampler':
self._samples.append(self.current_point)

# display acc rate at progress bar
pbar.set_postfix_str(f"acc rate: {np.mean(self._acc):.2%}")
pbar.set_postfix_str(f"acc rate: {np.mean(self._acc[-1-idx:]):.2%}")

# Add sample to batch
if batch_size > 0:
Expand Down Expand Up @@ -279,7 +279,7 @@ def warmup(self, Nb, tune_freq=0.1) -> 'Sampler':
self._samples.append(self.current_point)

# display acc rate at progress bar
pbar.set_postfix_str(f"acc rate: {np.mean(self._acc):.2%}")
pbar.set_postfix_str(f"acc rate: {np.mean(self._acc[-1-idx:]):.2%}")

# Call callback function if specified
self._call_callback(self.current_point, len(self._samples)-1)
Expand Down

0 comments on commit db6dc7f

Please sign in to comment.