Skip to content

Commit

Permalink
feat: power of 2 div using type system (zkonduit#702)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Feb 4, 2024
1 parent e0d3f4f commit 95d4fd4
Show file tree
Hide file tree
Showing 13 changed files with 223 additions and 103 deletions.
39 changes: 39 additions & 0 deletions examples/onnx/1l_tiny_div/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from torch import nn
import torch
import json

class Circuit(nn.Module):
def __init__(self, inplace=False):
super(Circuit, self).__init__()

def forward(self, x):
return x/ 10000


circuit = Circuit()


x = torch.empty(1, 8).random_(0, 2)

out = circuit(x)

print(out)

torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})


d1 = ((x).detach().numpy()).reshape([-1]).tolist()

data = dict(
input_data=[d1],
)

# Serialize data into file:
json.dump(data, open("input.json", 'w'))
1 change: 1 addition & 0 deletions examples/onnx/1l_tiny_div/input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"input_data": [[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]]}
Binary file added examples/onnx/1l_tiny_div/network.onnx
Binary file not shown.
79 changes: 79 additions & 0 deletions src/circuit/ops/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ use serde::{Deserialize, Serialize};
/// An enum representing the operations that consist of both lookups and arithmetic operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HybridOp {
Div {
denom: utils::F32,
use_range_check_for_int: bool,
},
ReduceMax {
axes: Vec<usize>,
},
Expand Down Expand Up @@ -113,6 +117,21 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
(res.clone(), vec![inter_1, inter_2])
}
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
// if denom is a round number and use_range_check_for_int is true, use range check check
if denom.0.fract() == 0.0 && *use_range_check_for_int {
let divisor = Tensor::from(vec![denom.0 as i128].into_iter());
let res = crate::tensor::ops::div(&[x, divisor.clone()])?;
(res, vec![-divisor.clone(), divisor])
} else {
let res = crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64);
(res, vec![x])
}
}
HybridOp::ReduceArgMax { dim } => {
let res = tensor::ops::argmax_axes(&x, *dim)?;
let indices = Tensor::from(0..x.dims()[*dim] as i128);
Expand Down Expand Up @@ -272,6 +291,13 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {

fn as_string(&self) -> String {
match self {
HybridOp::Div {
denom,
use_range_check_for_int,
} => format!(
"DIV (denom={}, use_range_check_for_int={})",
denom, use_range_check_for_int
),
HybridOp::SumPool {
padding,
stride,
Expand Down Expand Up @@ -335,6 +361,29 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
*kernel_shape,
*normalized,
)?,
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
layouts::div(
config,
region,
values[..].try_into()?,
i128_to_felt(denom.0 as i128),
)?
} else {
layouts::nonlinearity(
config,
region,
values.try_into()?,
&LookupOp::Div {
denom: denom.clone(),
},
)?
}
}
HybridOp::Gather { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather(values[0].get_inner_tensor()?, idx, *dim)?.into()
Expand Down Expand Up @@ -427,11 +476,41 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
Ok(scale)
}

fn required_range_checks(&self) -> Vec<Range> {
match self {
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
vec![(-denom.0 as i128 + 1, denom.0 as i128 - 1)]
} else {
vec![]
}
}
_ => vec![],
}
}

fn required_lookups(&self) -> Vec<LookupOp> {
match self {
HybridOp::ReduceMax { .. }
| HybridOp::ReduceMin { .. }
| HybridOp::MaxPool2d { .. } => Op::<F>::required_lookups(&LookupOp::ReLU),
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
vec![]
} else {
vec![LookupOp::Div {
denom: denom.clone(),
}]
}
}
HybridOp::Softmax { scale, .. } => {
vec![
LookupOp::Exp { scale: *scale },
Expand Down
25 changes: 14 additions & 11 deletions src/circuit/ops/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ pub enum PolyOp {
Sub,
Neg,
Mult,
Identity,
Identity {
out_scale: Option<crate::Scale>,
},
Reshape(Vec<usize>),
MoveAxis {
source: usize,
Expand Down Expand Up @@ -85,7 +87,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Resize { .. } => "RESIZE".into(),
PolyOp::Iff => "IFF".into(),
PolyOp::Einsum { equation, .. } => format!("EINSUM {}", equation),
PolyOp::Identity => "IDENTITY".into(),
PolyOp::Identity { out_scale } => {
format!("IDENTITY (out_scale={:?})", out_scale)
}
PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
PolyOp::Flatten(_) => "FLATTEN".into(),
PolyOp::Pad(_) => "PAD".into(),
Expand Down Expand Up @@ -135,7 +139,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
PolyOp::Identity => Ok(inputs[0].clone()),
PolyOp::Identity { .. } => Ok(inputs[0].clone()),
PolyOp::Reshape(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
Expand Down Expand Up @@ -264,7 +268,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Mult => {
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {
Expand Down Expand Up @@ -322,9 +326,8 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
output_scale
}
PolyOp::Add => {
let mut scale_a = 0;
let scale_b = in_scales[0];
scale_a += in_scales[1];
let scale_a = in_scales[0];
let scale_b = in_scales[1];
assert_eq!(scale_a, scale_b);
scale_a
}
Expand All @@ -336,19 +339,19 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
}
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
_ => in_scales[0],
};
Ok(scale)
}

fn requires_homogenous_input_scales(&self) -> Vec<usize> {
if matches!(
self,
PolyOp::Add { .. } | PolyOp::Sub | PolyOp::Concat { .. }
) {
if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) {
vec![0, 1]
} else if matches!(self, PolyOp::Iff) {
vec![1, 2]
} else if matches!(self, PolyOp::Concat { .. }) {
(0..100).collect()
} else {
vec![]
}
Expand Down
4 changes: 4 additions & 0 deletions src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,10 @@ pub(crate) async fn gen_witness(
if let Some(output_path) = output {
serde_json::to_writer(&File::create(output_path)?, &witness)?;
}

// print the witness in debug
debug!("witness: \n {}", witness.as_json()?.to_colored_json_auto()?);

Ok(witness)
}

Expand Down
4 changes: 2 additions & 2 deletions src/graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -968,8 +968,8 @@ impl GraphCircuit {
lookup_safety_margin * max_lookup_inputs,
);
if lookup_safety_margin == 1 {
margin.0 += 1;
margin.1 += 1;
margin.0 += 4;
margin.1 += 4;
}

margin
Expand Down
2 changes: 2 additions & 0 deletions src/graph/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,8 @@ impl Model {
inputs.iter().map(|x| x.dims()).collect::<Vec<_>>()
);

debug!("input nodes: {:?}", n.inputs());

if n.is_lookup() {
let (mut min, mut max) = (0, 0);
for i in &inputs {
Expand Down
Loading

0 comments on commit 95d4fd4

Please sign in to comment.