diff --git a/Cargo.lock b/Cargo.lock
index 88482eb..aa0291c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -112,7 +112,7 @@ dependencies = [
 
 [[package]]
 name = "ndarray_einsum_beta"
-version = "0.2.0"
+version = "0.2.1"
 dependencies = [
  "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "ndarray 0.12.1 (registry+https://github.com/rust-lang/crates.io-index)",
diff --git a/Cargo.toml b/Cargo.toml
index 1772927..578e35c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "ndarray_einsum_beta"
-version = "0.2.0"
+version = "0.2.1"
 authors = ["oracleofnj"]
 edition = "2018"
 license = "Apache-2.0"
diff --git a/src/lib.rs b/src/lib.rs
index 440a60d..38f280c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -12,6 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+//! The `ndarray_einsum` crate implements the `einsum` function, originally
+//! implemented for numpy by Mark Wiebe and subsequently reimplemented for
+//! other tensor libraries such as TensorFlow and PyTorch. einsum (Einstein summation)
+//! implements general multidimensional tensor contraction. Many linear algebra operations
+//! and generalizations of those operations can be expressed as special cases of tensor
+//! contraction. Examples include matrix multiplication, matrix trace, vector dot product,
+//! tensor Hadamard [element-wise] product, axis permutation, outer product, batch
+//! matrix multiplication, bilinear transformations, and many more.
 use std::collections::HashMap;
 
 use ndarray::prelude::*;
@@ -19,14 +27,13 @@ use ndarray::{Data, IxDyn, LinalgScalar};
 
 mod validation;
 pub use validation::{
-    einsum_path, validate, validate_and_optimize_order, validate_and_size, validate_and_size_from_shapes,
-    Contraction, OutputSize, SizedContraction,
+    einsum_path, validate, validate_and_optimize_order, validate_and_size,
+    validate_and_size_from_shapes, Contraction, SizedContraction,
 };
 
 mod optimizers;
 pub use optimizers::{
-    generate_optimized_order, ContractionOrder, OperandNumPair,
-    OptimizationMethod,
+    generate_optimized_order, ContractionOrder, OperandNumPair, OptimizationMethod,
 };
 
 mod contractors;
diff --git a/src/validation.rs b/src/validation.rs
index 11d9c88..4b5d6fb 100644
--- a/src/validation.rs
+++ b/src/validation.rs
@@ -9,21 +9,85 @@ use regex::Regex;
 use serde::Serialize;
 use std::collections::{HashMap, HashSet};
 
+/// The result of running an `einsum`-formatted string through the regex.
 #[derive(Debug)]
 struct EinsumParse {
     operand_indices: Vec<String>,
     output_indices: Option<String>,
 }
 
+/// A `Contraction` contains the result of parsing an `einsum`-formatted string.
+///
+/// ```
+/// # use ndarray_einsum_beta::*;
+/// let contraction = Contraction::new("ij,jk->ik").unwrap();
+/// assert_eq!(contraction.operand_indices, vec![vec!['i', 'j'], vec!['j', 'k']]);
+/// assert_eq!(contraction.output_indices, vec!['i', 'k']);
+/// assert_eq!(contraction.summation_indices, vec!['j']);
+///
+/// let contraction = Contraction::new("ij,jk").unwrap();
+/// assert_eq!(contraction.operand_indices, vec![vec!['i', 'j'], vec!['j', 'k']]);
+/// assert_eq!(contraction.output_indices, vec!['i', 'k']);
+/// assert_eq!(contraction.summation_indices, vec!['j']);
+/// ```
 #[derive(Debug, Clone, Serialize)]
 pub struct Contraction {
+    /// A vector with as many elements as input operands, where each
+    /// member of the vector is a `Vec<char>` with each char representing the label for
+    /// each axis of the operand.
     pub operand_indices: Vec<Vec<char>>,
+
+    /// Specifies which axes the resulting tensor will contain
+    /// (corresponding to axes in one or more of the input operands).
     pub output_indices: Vec<char>,
+
+    /// Contains the axes that will be summed over (a.k.a. contracted) by the operation.
     pub summation_indices: Vec<char>,
 }
 
 impl Contraction {
-    pub fn from_indices(
+    /// Validates and creates a `Contraction` from an `einsum`-formatted string.
+    pub fn new(input_string: &str) -> Result<Contraction, &'static str> {
+        let p = parse_einsum_string(input_string).ok_or("Invalid string")?;
+        Contraction::from_parse(&p)
+    }
+
+    /// If output_indices has been specified in the parse (i.e. explicit case),
+    /// e.g. "ij,jk->ik", simply converts the strings to `Vec<char>`s and passes
+    /// them to Contraction::from_indices. If the output indices haven't been specified,
+    /// e.g. "ij,jk", figures out which ones aren't duplicated and hence summed over,
+    /// sorts them alphabetically, and uses those as the output indices.
+    fn from_parse(parse: &EinsumParse) -> Result<Contraction, &'static str> {
+        let requested_output_indices: Vec<char> = match &parse.output_indices {
+            Some(s) => s.chars().collect(),
+            _ => {
+                // Handle implicit case, e.g. nothing to the right of the arrow
+                let mut input_indices = HashMap::new();
+                for c in parse.operand_indices.iter().flat_map(|s| s.chars()) {
+                    *input_indices.entry(c).or_insert(0) += 1;
+                }
+
+                let mut unique_indices: Vec<char> = input_indices
+                    .iter()
+                    .filter(|(_, &v)| v == 1)
+                    .map(|(&k, _)| k)
+                    .collect();
+                unique_indices.sort();
+                unique_indices
+            }
+        };
+
+        let operand_indices: Vec<Vec<char>> = parse
+            .operand_indices
+            .iter()
+            .map(|x| x.chars().collect::<Vec<char>>())
+            .collect();
+        Contraction::from_indices(&operand_indices, &requested_output_indices)
+    }
+
+    /// Validates and creates a `Contraction` from a slice of `Vec<char>`s containing
+    /// the operand indices, and a slice of `char` containing the desired output indices.
+    fn from_indices(
         operand_indices: &[Vec<char>],
         output_indices: &[char],
     ) -> Result<Contraction, &'static str> {
@@ -66,7 +130,66 @@ impl Contraction {
 }
 
 pub type OutputSize = HashMap<char, usize>;
 
+trait OutputSizeMethods {
+    fn from_contraction_and_shapes(
+        contraction: &Contraction,
+        operand_shapes: &[Vec<usize>],
+    ) -> Result<OutputSize, &'static str>;
+}
+impl OutputSizeMethods for OutputSize {
+    /// Build the HashMap containing the axis lengths
+    fn from_contraction_and_shapes(
+        contraction: &Contraction,
+        operand_shapes: &[Vec<usize>],
+    ) -> Result<OutputSize, &'static str> {
+        // Check that len(operand_indices) == len(operands)
+        if contraction.operand_indices.len() != operand_shapes.len() {
+            return Err(
+                "number of operands in contraction does not match number of operands supplied",
+            );
+        }
+
+        let mut index_lengths: OutputSize = HashMap::new();
+
+        for (indices, operand_shape) in contraction.operand_indices.iter().zip(operand_shapes) {
+            // Check that len(operand_indices[i]) == len(operands[i].shape())
+            if indices.len() != operand_shape.len() {
+                return Err(
+                    "number of indices in one or more operands does not match dimensions of operand",
+                );
+            }
+
+            // Check that whenever there are multiple copies of an index,
+            // operands[i].shape()[m] == operands[j].shape()[n]
+            for (&c, &n) in indices.iter().zip(operand_shape) {
+                let existing_n = index_lengths.entry(c).or_insert(n);
+                if *existing_n != n {
+                    return Err("repeated index with different size");
+                }
+            }
+        }
+
+        Ok(index_lengths)
+    }
+}
+/// A `SizedContraction` contains a `Contraction` as well as a `HashMap`
+/// specifying the axis lengths for each index in the contraction.
+///
+/// Note that output_size is a misnomer (to be changed); it contains all the axis lengths,
+/// including the ones that will be contracted (i.e. not just the ones in
+/// contraction.output_indices).
+///
+/// ```
+/// # use ndarray_einsum_beta::*;
+/// # use ndarray::prelude::*;
+/// let m1: Array2<f64> = Array::zeros((2, 3));
+/// let m2: Array2<f64> = Array::zeros((3, 4));
+/// let sc = validate_and_size("ij,jk->ik", &[&m1, &m2]).unwrap();
+/// assert_eq!(sc.output_size[&'i'], 2);
+/// assert_eq!(sc.output_size[&'k'], 4);
+/// assert_eq!(sc.output_size[&'j'], 3);
+/// ```
 #[derive(Debug, Clone, Serialize)]
 pub struct SizedContraction {
     pub contraction: Contraction,
@@ -74,6 +197,10 @@ pub struct SizedContraction {
 }
 
 impl SizedContraction {
+    /// Creates a new SizedContraction based on a subset of the operand indices and/or output
+    /// indices. Not intended for general use; used internally in the crate when compiling
+    /// a multi-tensor contraction into a set of tensor simplifications (a.k.a. singleton
+    /// contractions) and pairwise contractions.
     pub fn subset(
         &self,
         new_operand_indices: &[Vec<char>],
@@ -92,7 +219,10 @@ impl SizedContraction {
             return Err("Character found in new_operand_indices but not in self.output_size");
         }
 
+        // Validate what they asked for and compute summation_indices
         let new_contraction = Contraction::from_indices(new_operand_indices, new_output_indices)?;
+
+        // Clone output_size, omitting unused characters
         let new_output_size: OutputSize = self
             .output_size
             .iter()
@@ -106,11 +236,11 @@ impl SizedContraction {
         })
     }
 
-    pub fn from_contraction_and_shapes(
+    fn from_contraction_and_shapes(
        contraction: &Contraction,
        operand_shapes: &[Vec<usize>],
     ) -> Result<SizedContraction, &'static str> {
-        let output_size = get_output_size_from_shapes(&contraction, operand_shapes)?;
+        let output_size = OutputSize::from_contraction_and_shapes(&contraction, operand_shapes)?;
 
         Ok(SizedContraction {
             contraction: contraction.clone(),
@@ -118,6 +248,8 @@ impl SizedContraction {
         })
     }
 
+    /// Create a SizedContraction from an already-created Contraction and a list
+    /// of operands.
     pub fn from_contraction_and_operands<A>(
         contraction: &Contraction,
         operands: &[&dyn ArrayLike<A>],
@@ -127,6 +259,38 @@ impl SizedContraction {
         SizedContraction::from_contraction_and_shapes(contraction, &operand_shapes)
     }
 
+    /// Create a SizedContraction from an `einsum`-formatted input string and a slice
+    /// of `Vec<usize>`s containing the shapes of each operand.
+    pub fn from_string_and_shapes(
+        input_string: &str,
+        operand_shapes: &[Vec<usize>],
+    ) -> Result<SizedContraction, &'static str> {
+        let contraction = validate(input_string)?;
+        SizedContraction::from_contraction_and_shapes(&contraction, operand_shapes)
+    }
+
+    /// Create a SizedContraction from an `einsum`-formatted input string and a list
+    /// of operands.
+    pub fn new<A>(
+        input_string: &str,
+        operands: &[&dyn ArrayLike<A>],
+    ) -> Result<SizedContraction, &'static str> {
+        let operand_shapes = get_operand_shapes(operands);
+
+        SizedContraction::from_string_and_shapes(input_string, &operand_shapes)
+    }
+
+    /// Perform the contraction on a set of operands.
+ /// + /// ``` + /// # use ndarray_einsum_beta::*; + /// # use ndarray::prelude::*; + /// let m1: Array2 = Array::zeros((2, 3)); + /// let m2: Array2 = Array::zeros((3, 4)); + /// let out: ArrayD = Array::zeros((2, 4)).into_dyn(); + /// let sc = validate_and_size("ij,jk->ik", &[&m1, &m2]).unwrap(); + /// assert_eq!(sc.contract_operands(&[&m1, &m2]), out); + /// ``` pub fn contract_operands( &self, operands: &[&dyn ArrayLike], @@ -136,40 +300,7 @@ impl SizedContraction { } } -fn generate_contraction(parse: &EinsumParse) -> Result { - let mut input_indices = HashMap::new(); - for c in parse.operand_indices.iter().flat_map(|s| s.chars()) { - *input_indices.entry(c).or_insert(0) += 1; - } - - let mut unique_indices = Vec::new(); - let mut duplicated_indices = Vec::new(); - for (&c, &n) in input_indices.iter() { - if n > 1 { - duplicated_indices.push(c); - } else { - unique_indices.push(c); - }; - } - - // Handle implicit case, e.g. nothing to the right of the arrow - let requested_output_indices: Vec = match &parse.output_indices { - Some(s) => s.chars().collect(), - _ => { - let mut o = unique_indices.clone(); - o.sort(); - o - } - }; - - let operand_indices: Vec> = parse - .operand_indices - .iter() - .map(|x| x.chars().collect::>()) - .collect(); - Contraction::from_indices(&operand_indices, &requested_output_indices) -} - +/// Runs an input string through a regex and convert it to an EinsumParse. fn parse_einsum_string(input_string: &str) -> Option { lazy_static! { // Unwhitespaced version: @@ -197,41 +328,9 @@ fn parse_einsum_string(input_string: &str) -> Option { }) } +/// Wrapper around [Contraction::new()](struct.Contraction.html#method.new). pub fn validate(input_string: &str) -> Result { - let p = parse_einsum_string(input_string).ok_or("Invalid string")?; - generate_contraction(&p) -} - -fn get_output_size_from_shapes( - contraction: &Contraction, - operand_shapes: &[Vec], -) -> Result { - // Check that len(operand_indices) == len(operands) - if contraction.operand_indices.len() != operand_shapes.len() { - return Err("number of operands in contraction does not match number of operands supplied"); - } - - let mut index_lengths: OutputSize = HashMap::new(); - - for (indices, operand_shape) in contraction.operand_indices.iter().zip(operand_shapes) { - // Check that len(operand_indices[i]) == len(operands[i].shape()) - if indices.len() != operand_shape.len() { - return Err( - "number of indices in one or more operands does not match dimensions of operand", - ); - } - - // Check that whenever there are multiple copies of an index, - // operands[i].shape()[m] == operands[j].shape()[n] - for (&c, &n) in indices.iter().zip(operand_shape) { - let existing_n = index_lengths.entry(c).or_insert(n); - if *existing_n != n { - return Err("repeated index with different size"); - } - } - } - - Ok(index_lengths) + Contraction::new(input_string) } fn get_operand_shapes(operands: &[&dyn ArrayLike]) -> Vec> { @@ -241,24 +340,21 @@ fn get_operand_shapes(operands: &[&dyn ArrayLike]) -> Vec> { .collect() } +/// Only included so the function can be called from WASM, i.e. without +/// arguments that are already ndarray ```ArrayBase```s. 
 pub fn validate_and_size_from_shapes(
     input_string: &str,
     operand_shapes: &[Vec<usize>],
 ) -> Result<SizedContraction, &'static str> {
-    let contraction = validate(input_string)?;
-    let output_size = get_output_size_from_shapes(&contraction, operand_shapes)?;
-
-    Ok(SizedContraction {
-        contraction,
-        output_size,
-    })
+    SizedContraction::from_string_and_shapes(input_string, operand_shapes)
 }
 
+/// Wrapper around [SizedContraction::new()](struct.SizedContraction.html#method.new).
 pub fn validate_and_size<A>(
     input_string: &str,
     operands: &[&dyn ArrayLike<A>],
 ) -> Result<SizedContraction, &'static str> {
-    validate_and_size_from_shapes(input_string, &get_operand_shapes(operands))
+    SizedContraction::new(input_string, operands)
 }
 
 pub fn validate_and_optimize_order(
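
For orientation, here is a minimal sketch (not part of the diff, assuming `ndarray_einsum_beta` 0.2.1 as published) of the `Contraction::new` constructor added above. Per the `from_parse` doc comment, the explicit form `"ij,jk->ik"` and the implicit form `"ij,jk"` parse to the same contraction, with the non-duplicated indices sorted alphabetically in the implicit case.

```rust
// Sketch only: exercises Contraction::new as shown in the diff's doctests.
use ndarray_einsum_beta::Contraction;

fn main() {
    let explicit = Contraction::new("ij,jk->ik").unwrap();
    let implicit = Contraction::new("ij,jk").unwrap();

    // 'j' appears in both operands, so it is summed over;
    // 'i' and 'k' survive to the output.
    assert_eq!(explicit.output_indices, vec!['i', 'k']);
    assert_eq!(implicit.output_indices, vec!['i', 'k']);
    assert_eq!(explicit.summation_indices, vec!['j']);
    assert_eq!(implicit.summation_indices, vec!['j']);
}
```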
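A second sketch (again not part of the diff, assuming the ndarray 0.12 version pinned in Cargo.lock) shows the `validate_and_size` wrapper and `SizedContraction::contract_operands` working together:

```rust
use ndarray::prelude::*;
use ndarray_einsum_beta::*;

fn main() {
    let m1: Array2<f64> = Array::ones((2, 3));
    let m2: Array2<f64> = Array::ones((3, 4));

    // output_size records the length of every labeled axis,
    // including the contracted 'j' axis.
    let sc = validate_and_size("ij,jk->ik", &[&m1, &m2]).unwrap();
    assert_eq!(sc.output_size[&'j'], 3);

    // "ij,jk->ik" over two matrices is an ordinary matrix product.
    let result = sc.contract_operands(&[&m1, &m2]);
    assert_eq!(result.shape(), &[2, 4]);
    assert_eq!(result[[0, 0]], 3.0);
}
```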