Skip to content

Commit

Permalink
Revert logup without is_first
Browse files Browse the repository at this point in the history
  • Loading branch information
shaharsamocha7 committed Sep 3, 2024
1 parent 1ee6a70 commit a3a9330
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 46 deletions.
38 changes: 16 additions & 22 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,25 @@ pub struct LogupAtRow<const BATCH_SIZE: usize, E: EvalAtRow> {
pub queue: [(E::EF, E::EF); BATCH_SIZE],
/// Number of fractions in the queue.
pub queue_size: usize,
/// A constant to subtract from each row, to make the totall sum of the last column zero.
/// In other words, claimed_sum / 2^log_size.
/// This is used to make the constraint uniform.
pub cumsum_shift: SecureField,
/// The claimed sum of all the fractions.
pub claimed_sum: SecureField,
/// The evaluation of the last cumulative sum column.
pub prev_col_cumsum: E::EF,
is_finalized: bool,
/// The value of the `is_first` constant column at current row.
/// See [`super::constant_columns::gen_is_first()`].
pub is_first: E::F,
}
impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
pub fn new(interaction: usize, claimed_sum: SecureField, log_size: u32) -> Self {
pub fn new(interaction: usize, claimed_sum: SecureField, is_first: E::F) -> Self {
Self {
interaction,
queue: [(E::EF::zero(), E::EF::zero()); BATCH_SIZE],
queue_size: 0,
cumsum_shift: claimed_sum / BaseField::from_u32_unchecked(1 << log_size),
claimed_sum,
prev_col_cumsum: E::EF::zero(),
is_finalized: false,
is_first,
}
}
pub fn push_lookup<const N: usize>(
Expand Down Expand Up @@ -96,13 +98,11 @@ impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
let [cur_cumsum, prev_row_cumsum] =
eval.next_extension_interaction_mask(self.interaction, [0, -1]);

let diff = cur_cumsum - prev_row_cumsum - self.prev_col_cumsum;
// Instead of checking diff = num / denom, check diff = num / denom - cumsum_shift.
// This makes (num / denom - cumsum_shift) have sum zero, which makes the constraint
// uniform - apply on all rows.
let fixed_diff = diff + self.cumsum_shift;
// Fix `prev_row_cumsum` by subtracting `claimed_sum` if this is the first row.
let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first * self.claimed_sum;
let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum;

eval.add_constraint(fixed_diff * denom - num);
eval.add_constraint(diff * denom - num);

self.is_finalized = true;
}
Expand Down Expand Up @@ -210,21 +210,13 @@ impl LogupTraceGenerator {
SecureField,
) {
// Compute claimed sum.
let mut last_col_coords = self.trace.pop().unwrap().columns;
let last_col_coords = self.trace.pop().unwrap().columns;
let packed_sums: [PackedBaseField; SECURE_EXTENSION_DEGREE] = last_col_coords
.each_ref()
.map(|c| c.data.iter().copied().sum());
let base_sums = packed_sums.map(|s| s.pointwise_sum());
let claimed_sum = SecureField::from_m31_array(base_sums);

// Shift the last column to make the sum zero.
let cumsum_shift = claimed_sum / BaseField::from_u32_unchecked(1 << self.log_size);
last_col_coords.iter_mut().enumerate().for_each(|(i, c)| {
c.data
.iter_mut()
.for_each(|x| *x -= PackedBaseField::broadcast(cumsum_shift.to_m31_array()[i]))
});

// Prefix sum the last column.
let coord_prefix_sum = last_col_coords.map(inclusive_prefix_sum);
self.trace.push(SecureColumnByCoords {
Expand Down Expand Up @@ -300,12 +292,14 @@ mod tests {

use super::LogupAtRow;
use crate::constraint_framework::InfoEvaluator;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;

#[test]
#[should_panic]
fn test_logup_not_finalized_panic() {
let mut logup = LogupAtRow::<2, InfoEvaluator>::new(1, SecureField::one(), 7);
let mut logup =
LogupAtRow::<2, InfoEvaluator>::new(1, SecureField::one(), BaseField::one());
logup.push_frac(
&mut InfoEvaluator::default(),
SecureField::one(),
Expand Down
7 changes: 7 additions & 0 deletions crates/prover/src/examples/blake/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use tracing::{span, Level};
use super::round::{blake_round_info, BlakeRoundComponent, BlakeRoundEval};
use super::scheduler::{BlakeSchedulerComponent, BlakeSchedulerEval};
use super::xor_table::{XorTableComponent, XorTableEval};
use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::TraceLocationAllocator;
use crate::core::air::{Component, ComponentProver};
use crate::core::backend::simd::m31::LOG_N_LANES;
Expand Down Expand Up @@ -363,10 +364,16 @@ where
span.exit();

// Constant trace.
// TODO(ShaharS): share is_first column between components when constant columns support this.
let span = span!(Level::INFO, "Constant Trace").entered();
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(
chain![
vec![gen_is_first(log_size)],
ROUND_LOG_SPLIT
.iter()
.map(|l| gen_is_first(log_size + l))
.collect_vec(),
xor_table::generate_constant_trace::<12, 4>(),
xor_table::generate_constant_trace::<9, 2>(),
xor_table::generate_constant_trace::<8, 2>(),
Expand Down
5 changes: 3 additions & 2 deletions crates/prover/src/examples/blake/round/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ impl FrameworkEval for BlakeRoundEval {
fn max_constraint_log_degree_bound(&self) -> u32 {
self.log_size + 1
}
fn evaluate<E: EvalAtRow>(&self, eval: E) -> E {
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let blake_eval = constraints::BlakeRoundEval {
eval,
xor_lookup_elements: &self.xor_lookup_elements,
round_lookup_elements: &self.round_lookup_elements,
logup: LogupAtRow::new(1, self.claimed_sum, self.log_size),
logup: LogupAtRow::new(1, self.claimed_sum, is_first),
};
blake_eval.eval()
}
Expand Down
6 changes: 4 additions & 2 deletions crates/prover/src/examples/blake/scheduler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ impl FrameworkEval for BlakeSchedulerEval {
self.log_size + 1
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
eval_blake_scheduler_constraints(
&mut eval,
&self.blake_lookup_elements,
&self.round_lookup_elements,
LogupAtRow::new(1, self.claimed_sum, self.log_size),
LogupAtRow::new(1, self.claimed_sum, is_first),
);
eval
}
Expand All @@ -55,6 +56,7 @@ mod tests {

use itertools::Itertools;

use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::FrameworkEval;
use crate::core::poly::circle::CanonicCoset;
use crate::examples::blake::round::RoundElements;
Expand Down Expand Up @@ -86,7 +88,7 @@ mod tests {
&blake_lookup_elements,
);

let trace = TreeVec::new(vec![trace, interaction_trace]);
let trace = TreeVec::new(vec![trace, interaction_trace, vec![gen_is_first(LOG_SIZE)]]);
let trace_polys = trace.map_cols(|c| c.interpolate());

let component = BlakeSchedulerEval {
Expand Down
7 changes: 5 additions & 2 deletions crates/prover/src/examples/blake/xor_table/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use itertools::Itertools;
use tracing::{span, Level};

use super::{column_bits, limb_bits, XorAccumulator, XorElements};
use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
Expand Down Expand Up @@ -157,12 +158,14 @@ pub fn generate_constant_trace<const ELEM_BITS: u32, const EXPAND_BITS: u32>(
})
.collect();

[a_col, b_col, c_col]
let mut constant_trace = [a_col, b_col, c_col]
.map(|x| {
CircleEvaluation::new(
CanonicCoset::new(column_bits::<ELEM_BITS, EXPAND_BITS>()).circle_domain(),
x,
)
})
.to_vec()
.to_vec();
constant_trace.insert(0, gen_is_first(column_bits::<ELEM_BITS, EXPAND_BITS>()));
constant_trace
}
3 changes: 2 additions & 1 deletion crates/prover/src/examples/blake/xor_table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,11 @@ impl<const ELEM_BITS: u32, const EXPAND_BITS: u32> FrameworkEval
column_bits::<ELEM_BITS, EXPAND_BITS>() + 1
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let xor_eval = constraints::XorTableEval::<'_, _, ELEM_BITS, EXPAND_BITS> {
eval,
lookup_elements: &self.lookup_elements,
logup: LogupAtRow::new(1, self.claimed_sum, self.log_size()),
logup: LogupAtRow::new(1, self.claimed_sum, is_first),
};
xor_eval.eval()
}
Expand Down
29 changes: 16 additions & 13 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use itertools::{chain, Itertools};
use itertools::Itertools;
use num_traits::One;
use tracing::{span, Level};

use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements};
use crate::constraint_framework::{
assert_constraints, EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator,
Expand Down Expand Up @@ -43,7 +44,8 @@ impl FrameworkEval for PlonkEval {
}

fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let mut logup = LogupAtRow::<2, _>::new(1, self.claimed_sum, self.log_n_rows);
let [is_first] = eval.next_interaction_mask(2, [0]);
let mut logup = LogupAtRow::<2, _>::new(1, self.claimed_sum, is_first);

let [a_wire] = eval.next_interaction_mask(2, [0]);
let [b_wire] = eval.next_interaction_mask(2, [0]);
Expand Down Expand Up @@ -207,17 +209,18 @@ pub fn prove_fibonacci_plonk(
// Constant trace.
let span = span!(Level::INFO, "Constant").entered();
let mut tree_builder = commitment_scheme.tree_builder();
let constants_trace_location = tree_builder.extend_evals(
chain!([circuit.a_wire, circuit.b_wire, circuit.c_wire, circuit.op]
.into_iter()
.map(|col| {
CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(
CanonicCoset::new(log_n_rows).circle_domain(),
col,
)
}))
.collect_vec(),
);
let is_first = gen_is_first(log_n_rows);
let mut constant_trace = [circuit.a_wire, circuit.b_wire, circuit.c_wire, circuit.op]
.into_iter()
.map(|col| {
CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(
CanonicCoset::new(log_n_rows).circle_domain(),
col,
)
})
.collect_vec();
constant_trace.insert(0, is_first);
let constants_trace_location = tree_builder.extend_evals(constant_trace);
tree_builder.commit(channel);
span.exit();

Expand Down
23 changes: 19 additions & 4 deletions crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use itertools::Itertools;
use num_traits::One;
use tracing::{span, Level};

use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements};
use crate::constraint_framework::{
EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator,
Expand Down Expand Up @@ -59,7 +60,8 @@ impl FrameworkEval for PoseidonEval {
self.log_n_rows + LOG_EXPAND
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let logup = LogupAtRow::new(1, self.claimed_sum, self.log_n_rows);
let [is_first] = eval.next_interaction_mask(2, [0]);
let logup = LogupAtRow::new(1, self.claimed_sum, is_first);
eval_poseidon_constraints(&mut eval, logup, &self.lookup_elements);
eval
}
Expand Down Expand Up @@ -358,6 +360,14 @@ pub fn prove_poseidon(
tree_builder.commit(channel);
span.exit();

// Constant trace.
let span = span!(Level::INFO, "Constant").entered();
let mut tree_builder = commitment_scheme.tree_builder();
let constant_trace = vec![gen_is_first(log_n_rows)];
tree_builder.extend_evals(constant_trace);
tree_builder.commit(channel);
span.exit();

// Prove constraints.
let component = PoseidonComponent::new(
&mut TraceLocationAllocator::default(),
Expand All @@ -378,8 +388,9 @@ mod tests {
use itertools::Itertools;
use num_traits::One;

use crate::constraint_framework::assert_constraints;
use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::{assert_constraints, EvalAtRow};
use crate::core::air::Component;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -454,13 +465,14 @@ mod tests {
let (trace1, claimed_sum) =
gen_interaction_trace(LOG_N_ROWS, interaction_data, &lookup_elements);

let traces = TreeVec::new(vec![trace0, trace1]);
let traces = TreeVec::new(vec![trace0, trace1, vec![gen_is_first(LOG_N_ROWS)]]);
let trace_polys =
traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec());
assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |mut eval| {
let [is_first] = eval.next_interaction_mask(2, [0]);
eval_poseidon_constraints(
&mut eval,
LogupAtRow::new(1, claimed_sum, LOG_N_ROWS),
LogupAtRow::new(1, claimed_sum, is_first),
&lookup_elements,
);
});
Expand Down Expand Up @@ -503,6 +515,9 @@ mod tests {
// Interaction columns.
commitment_scheme.commit(proof.commitments[1], &sizes[1], channel);

// Constant columns.
commitment_scheme.commit(proof.commitments[2], &sizes[2], channel);

verify(&[&component], channel, commitment_scheme, proof).unwrap();
}
}

0 comments on commit a3a9330

Please sign in to comment.