update v2.0-pre (#922)
* Update doc URL. (#821)

* Support indexing 2-axes RaggedTensor, Support slicing for RaggedTensor (#825)

* Support indexing 2-axes RaggedTensor, support slicing for RaggedTensor

* Fix compiling errors

* Fix unit test

* Change RaggedTensor.data to RaggedTensor.values

* Fix style

* Add docs

* Run nightly-cpu when pushing code to nightly-cpu branch
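
For context, a minimal sketch of the indexing/slicing behavior added in #825 (made-up values; assumes the k2.RaggedTensor Python API with the renamed .values attribute):

    import k2

    # A 2-axis ragged tensor built from a string (values are made up).
    r = k2.RaggedTensor('[ [1 2] [3] [4 5 6] ]')
    print(r[0])      # integer index on axis 0 -> a 1-D torch.Tensor: [1, 2]
    print(r[1:3])    # slice on axis 0 -> a new RaggedTensor: [ [3] [4 5 6] ]
    print(r.values)  # formerly RaggedTensor.data, renamed in this PR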

* Prune with max_arcs in IntersectDense (#820)

* Add checking for array constructor

* Prune with max arcs

* Minor fix

* Fix typo

* Fix review comments

* Fix typo
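
A hedged sketch of how the new bound might be used (assumes the arc limit is exposed as a max_arcs keyword argument of k2.intersect_dense; the toy graph and scores below are made up):

    import torch
    import k2

    graph = k2.create_fsa_vec([k2.arc_sort(k2.ctc_topo(max_token=2))])  # toy decoding graph
    log_probs = torch.randn(1, 5, 3).log_softmax(dim=-1)        # (N, T, C) network output
    supervision = torch.tensor([[0, 0, 5]], dtype=torch.int32)  # [fsa_idx, start, num_frames]
    dense = k2.DenseFsaVec(log_probs, supervision)

    lattice = k2.intersect_dense(graph, dense, output_beam=8.0,
                                 max_arcs=2_097_152)  # prune once this many arcs is exceeded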

* Release v1.8

* Create a ragged tensor from a regular tensor. (#827)

* Create a ragged tensor from a regular tensor.

* Add tests for creating ragged tensors from regular tensors.

* Add more tests.

* Print ragged tensors in a way like what PyTorch is doing.

* Fix test cases.
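
A small sketch of #827 (assumes k2.RaggedTensor accepts a 2-D torch.Tensor directly):

    import torch
    import k2

    t = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32)
    r = k2.RaggedTensor(t)  # each row of the regular tensor becomes one sublist
    print(r)                # printed in a PyTorch-like style, per this PR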

* Trigger GitHub actions manually. (#829)

* Run GitHub actions on merging. (#830)

* Support printing ragged tensors in a more compact way. (#831)

* Support printing ragged tensors in a more compact way.

* Disable support for torch 1.3.1

* Fix test failures.

* Add levenshtein alignment (#828)

* Add levenshtein graph

* Construct k2.RaggedTensor in the Python part

* Fix review comments, return aux_labels in ctc_graph

* Fix tests

* Fix bug of accessing symbols

* Fix bug of accessing symbols

* Change argument name, add levenshtein_distance interface

* Fix test error, add tests for levenshtein_distance

* Fix review comments and add unit test for c++ side

* update the interface of levenshtein alignment

* Fix review comments
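
A hedged sketch of the interface added in #828 (symbol sequences are made up; assumes k2.levenshtein_graph and k2.levenshtein_alignment keep the argument names shown here):

    import torch
    import k2

    refs = k2.levenshtein_graph([[1, 2, 3], [4, 5]])  # reference sequences
    hyps = k2.levenshtein_graph([[1, 3], [4, 5, 6]])  # hypothesis sequences
    ali = k2.levenshtein_alignment(
        refs=refs, hyps=hyps,
        hyp_to_ref_map=torch.tensor([0, 1], dtype=torch.int32),
        sorted_match_ref=True)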

* Release v1.9

* Support a[b[i]] where both a and b are ragged tensors. (#833)

* Display import error solution message on MacOS (#837)

* Fix installation doc. (#841)

* Fix installation doc.

Remove Windows support. Will fix it later.

* Fix style issues.

* fix typos in the install instructions (#844)

* make cmake adhere to the modernized way of finding packages outside default dirs (#845)

* import torch first in the smoke tests to prevent SEGFAULT (#846)

* Add doc about how to install a CPU version of k2. (#850)

* Add doc about how to install a CPU version of k2.

* Remove property setter of Fsa.labels

* Update Ubuntu version in GitHub CI since 16.04 reaches end-of-life.

* Support PyTorch 1.10. (#851)

* Fix test cases for k2.union() (#853)

* Fix out-of-boundary access (read). (#859)

* Update all the example codes in the docs (#861)

* Update all the example codes in the docs

I have run all the modified code with the newest version of k2.

* do some changes

* Fix compilation errors with CUB 1.15. (#865)

* Update README. (#873)

* Update README.

* Fix typos.

* Fix ctc graph (make aux_labels of final arcs -1) (#877)

* Fix LICENSE location to k2 folder (#880)

* Release v1.11. (#881)

It contains bugfixes.

* Update documentation for hash.h (#887)

* Update documentation for hash.h

* Typo fix

* Wrap MonotonicLowerBound (#883)

* Wrap MonotonicLowerBound

* Add unit tests

* Support int64; update documents
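
A short sketch of the wrapped function (assumes it is exposed as k2.monotonic_lower_bound and computes a suffix minimum, i.e. the largest non-decreasing sequence that stays elementwise <= the input):

    import torch
    import k2

    src = torch.tensor([2, 1, 3, 6, 5, 8], dtype=torch.int32)
    dst = k2.monotonic_lower_bound(src)
    print(dst)  # expected under the assumption above: [1, 1, 3, 5, 5, 8]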

* Remove extra commas after 'TOPSORTED' property and fix RaggedTensor constructor parameter 'byte_offset' out-of-range bug. (#892)

Co-authored-by: gzchenduisheng <[email protected]>

* Fix small typos (#896)

* Fix k2.ragged.create_ragged_shape2 (#901)

Before the fix, we have to specify both `row_splits` and `row_ids`
while calling `k2.create_ragged_shape2` even if one of them is `None`.

After this fix, we only need to specify one of them.
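
For illustration, a sketch of the fixed call (made-up row_splits; assumes the keyword names row_splits, row_ids and cached_tot_size):

    import torch
    import k2.ragged as k2r

    row_splits = torch.tensor([0, 2, 3, 6], dtype=torch.int32)
    # After the fix, passing row_splits alone is enough; row_ids may stay None.
    shape = k2r.create_ragged_shape2(row_splits=row_splits, row_ids=None, cached_tot_size=6)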

* Add rnnt loss (#891)

* Add cpp code of mutual information

* mutual information working

* Add rnnt loss

* Add pruned rnnt loss

* Minor Fixes

* Minor fixes & fix code style

* Fix cpp style

* Fix code style

* Fix s_begin values in padding positions

* Fix bugs related to boundary; Fix s_begin padding value; Add more tests

* Minor fixes

* Fix comments

* Add boundary to pruned loss tests
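
A heavily hedged sketch of the unpruned loss call (random made-up shapes; assumes k2.rnnt_loss takes joiner logits of shape (B, T, S + 1, C), the reference symbols, the blank/termination symbol, and an optional boundary tensor; the reduction argument lands in #911 below):

    import torch
    import k2

    B, T, S, C = 2, 10, 5, 20                    # batch, frames, max symbols, vocab (made up)
    logits = torch.randn(B, T, S + 1, C)         # joiner output
    symbols = torch.randint(1, C, (B, S))        # reference symbols
    boundary = torch.tensor([[0, 0, S, T]] * B)  # [sym_begin, frame_begin, sym_len, frame_len]
    loss = k2.rnnt_loss(logits=logits, symbols=symbols,
                        termination_symbol=0,    # blank
                        boundary=boundary)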

* Use more efficient way to fix boundaries (#906)

* Release v1.12 (#907)

* Change the sign of the rnnt_loss and add reduction argument (#911)

* Add right boundary constraints for s_begin

* Minor fixes to the interface of rnnt_loss to make it return a positive value

* Fix comments

* Release a new version

* Minor fixes

* Minor fixes to the docs

* Fix building doc. (#908)

* Fix building doc.

* Minor fixes.

* Minor fixes.

* Fix building doc (#912)

* Fix building doc

* Fix flake8

* Support torch 1.10.x (#914)

* Support torch 1.10.x

* Fix installing PyTorch.

* Update INSTALL.rst (#915)

* Update INSTALL.rst

Setting a few additional env variables to enable compilation from source *with CUDA GPU computation support enabled*

* Fix torch/cuda/python versions in the doc. (#918)

* Fix torch/cuda/python versions in the doc.

* Minor fixes.

* Fix building for CUDA 11.6 (#917)

* Fix building for CUDA 11.6

* Minor fixes.

* Implement Unstack (#920)

* Implement unstack

* Remove code that does not relate to this PR

* Remove for loop on output dim; add Unstack ragged

* Add more docs

* Fix comments

* Fix docs & unit tests

* SubsetRagged & PruneRagged (#919)

* Extend interface of SubsampleRagged.

* Add interface for pruning ragged tensor.

* Draft of new RNN-T decoding method

* Implements SubsampleRaggedShape

* Implements PruneRagged

* Rename subsample-> subset

* Minor fixes

* Fix comments

Co-authored-by: Daniel Povey <[email protected]>

Co-authored-by: Fangjun Kuang <[email protected]>
Co-authored-by: Piotr Żelasko <[email protected]>
Co-authored-by: Jan "yenda" Trmal <[email protected]>
Co-authored-by: Mingshuang Luo <[email protected]>
Co-authored-by: Ludwig Kürzinger <[email protected]>
Co-authored-by: Daniel Povey <[email protected]>
Co-authored-by: drawfish <[email protected]>
Co-authored-by: gzchenduisheng <[email protected]>
Co-authored-by: alexei-v-ivanov <[email protected]>
10 people authored Feb 20, 2022
1 parent 9832893 commit 8e4c2e5
Showing 78 changed files with 5,161 additions and 308 deletions.
5 changes: 5 additions & 0 deletions .flake8
@@ -2,6 +2,11 @@
show-source=true
statistics=true
max-line-length=80
per-file-ignores =
# line too long E501
# line break before operator W503
k2/python/k2/rnnt_loss.py: E501, W503
k2/python/tests/rnnt_loss_test.py: W503
exclude =
.git,
setup.py,
4 changes: 2 additions & 2 deletions .github/workflows/build-cpu.yml
@@ -37,8 +37,8 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-18.04, macos-10.15]
torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"]
# Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10
torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"]
# Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10.x
python-version: [3.6, 3.7, 3.8, 3.9]
exclude:
- python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0]
22 changes: 15 additions & 7 deletions .github/workflows/build.yml
@@ -40,7 +40,7 @@ jobs:
# from https://download.pytorch.org/whl/torch_stable.html
# Note: There are no torch versions for CUDA 11.2
#
# 1.10 supports: cuda10.2 (default), 11.1, 11.3
# 1.10.x supports: cuda10.2 (default), 11.1, 11.3
# 1.9.x supports: cuda10.2 (default), 11.1
# PyTorch 1.8.x supports: cuda 10.1, 10.2 (default), 11.1
# PyTorch 1.7.x supports: cuda 10.1, 10.2 (default), 11.0
@@ -50,9 +50,9 @@
# CUDA 11.3 is for torch 1.10
cuda: ["10.1", "10.2", "11.0", "11.1", "11.3"]
gcc: ["7"]
torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"]
torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"]
#
# Python 3.9 is for PyTorch 1.7.1, 1.8.0, 1.8.1, 1.9.x, 1.10
# Python 3.9 is for PyTorch 1.7.1, 1.8.0, 1.8.1, 1.9.x, 1.10.x
python-version: [3.6, 3.7, 3.8, 3.9]
exclude:
- cuda: "11.3" # exclude 11.3 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1]
@@ -73,7 +73,7 @@ jobs:
torch: "1.9.0"
- cuda: "11.3"
torch: "1.9.1"
- cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10]
- cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2]
torch: "1.5.0"
- cuda: "11.0"
torch: "1.5.1"
@@ -88,7 +88,11 @@ jobs:
- cuda: "11.0"
torch: "1.9.1"
- cuda: "11.0"
torch: "1.10"
torch: "1.10.0"
- cuda: "11.0"
torch: "1.10.1"
- cuda: "11.0"
torch: "1.10.2"
- cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1]
torch: "1.5.0"
- cuda: "11.1"
@@ -99,12 +103,16 @@
torch: "1.7.0"
- cuda: "11.1"
torch: "1.7.1"
- cuda: "10.1" # exclude CUDA 10.1 for [1.9.0, 1.9.1, 1.10]
- cuda: "10.1" # exclude CUDA 10.1 for [1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2]
torch: "1.9.0"
- cuda: "10.1"
torch: "1.9.1"
- cuda: "10.1"
torch: "1.10"
torch: "1.10.0"
- cuda: "10.1"
torch: "1.10.1"
- cuda: "10.1"
torch: "1.10.2"
- python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0]
torch: "1.5.0"
- python-version: 3.9
20 changes: 14 additions & 6 deletions .github/workflows/build_conda.yml
@@ -37,7 +37,7 @@ jobs:
cuda: ["10.1", "10.2", "11.0", "11.1", "11.3"]
# from https://download.pytorch.org/whl/torch_stable.html
#
# PyTorch 1.10 supports: 10.2 (default), 11.1, 11.3
# PyTorch 1.10.x supports: 10.2 (default), 11.1, 11.3
# PyTorch 1.9.x supports: 10.2 (default), 11.1
# PyTorch 1.8.1 supports: cuda 10.1, 10.2 (default), 11.1
# PyTorch 1.8.0 supports: cuda 10.1, 10.2 (default), 11.1
@@ -56,9 +56,9 @@
# https://github.com/csukuangfj/k2/runs/2533830771?check_suite_focus=true
# and
# https://github.com/NVIDIA/apex/issues/805
torch: ["1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"]
torch: ["1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"]
exclude:
# - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10]
# - cuda: "11.0" # exclude 11.0 for [1.5.0, 1.5.1, 1.6.0, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2]
# torch: "1.5.0"
# - cuda: "11.0"
# torch: "1.5.1"
@@ -73,7 +73,11 @@ jobs:
- cuda: "11.0"
torch: "1.9.1"
- cuda: "11.0"
torch: "1.10"
torch: "1.10.0"
- cuda: "11.0"
torch: "1.10.1"
- cuda: "11.0"
torch: "1.10.2"
# - cuda: "11.1" # exclude 11.1 for [1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1]
# torch: "1.5.0"
# - cuda: "11.1"
@@ -84,12 +88,16 @@ jobs:
torch: "1.7.0"
- cuda: "11.1"
torch: "1.7.1"
- cuda: "10.1" # exclude 10.1 for [1.9.0, 1.9.1, 1.10]
- cuda: "10.1" # exclude 10.1 for [1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2]
torch: "1.9.0"
- cuda: "10.1"
torch: "1.9.1"
- cuda: "10.1"
torch: "1.10"
torch: "1.10.0"
- cuda: "10.1"
torch: "1.10.1"
- cuda: "10.1"
torch: "1.10.2"
- python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0]
torch: "1.5.0"
- python-version: 3.9
2 changes: 1 addition & 1 deletion .github/workflows/build_conda_cpu.yml
@@ -51,7 +51,7 @@ jobs:
#
# Other PyTorch versions are not tested
#
torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"]
torch: ["1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"]
exclude:
- python-version: 3.9 # exclude Python 3.9 for [1.5.0, 1.5.1, 1.6.0, 1.7.0]
torch: "1.5.0"
2 changes: 1 addition & 1 deletion .github/workflows/nightly-cpu.yml
@@ -41,7 +41,7 @@ jobs:
os: [ubuntu-18.04, macos-10.15]
# Python 3.9 is for PyTorch 1.7.1, 1.8.x, 1.9.x, 1.10
python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.4.0", "1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10"]
torch: ["1.4.0", "1.5.0", "1.5.1", "1.6.0", "1.7.0", "1.7.1", "1.8.0", "1.8.1", "1.9.0", "1.9.1", "1.10.0", "1.10.1", "1.10.2"]
exclude:
- python-version: 3.9 # exclude Python 3.9 for [1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0]
torch: "1.4.0"
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -45,7 +45,7 @@ message(STATUS "Enabled languages: ${languages}")

project(k2 ${languages})

set(K2_VERSION "1.9")
set(K2_VERSION "1.13")

# ----------------- Supported build types for K2 project -----------------
set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel)
12 changes: 12 additions & 0 deletions INSTALL.rst
@@ -44,6 +44,18 @@ From source
git clone https://github.com/k2-fsa/k2.git
cd k2
python3 setup.py install
From source (with CUDA support)
===============================

.. code-block:: bash
git clone https://github.com/k2-fsa/k2.git
cd k2
export K2_CMAKE_ARGS="-DK2_WITH_CUDA=ON -DCMAKE_BUILD_TYPE=Release"
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/lib:$LD_LIBRARY_PATH
export PATH=$PATH:/usr/local/cuda/bin
python3 setup.py install
Read `<https://k2.readthedocs.io/en/latest/installation/from_source.html>`_
to learn more
2 changes: 2 additions & 0 deletions MANIFEST.in
@@ -0,0 +1,2 @@
include LICENSE

45 changes: 10 additions & 35 deletions README.md
@@ -20,7 +20,7 @@ speech recognition system with multiple decoding passes including lattice
rescoring and confidence estimation. We hope k2 will have many other
applications as well.

One of the key algorithms that we want to make efficient in the short term is
One of the key algorithms that we have implemented is
pruned composition of a generic FSA with a "dense" FSA (i.e. one that
corresponds to log-probs of symbols at the output of a neural network). This
can be used as a fast implementation of decoding for ASR, and for CTC and
@@ -78,46 +78,21 @@ general and extensible framework to allow further development of ASR technology.

## Current state of the code

A lot of the code is still unfinished (Sep 11, 2020).
We finished the CPU versions of many algorithms and this code is in `k2/csrc/host/`;
however, after that we figured out how to implement things on the GPU and decided
to change the interfaces so the CPU and GPU code had a more unified interface.
Currently in `k2/csrc/` we have more GPU-oriented implementations (although
these algorithms will also work on CPU). We had almost finished the Python
wrapping for the older code, in the `k2/python/` subdirectory, but we decided not to
release code with that wrapping because it would have had to be reworked to be compatible
with our GPU algorithms. Instead we will use the interfaces drafted in `k2/csrc/`
e.g. the Context object (which encapsulates things like memory managers from external
toolkits) and the Tensor object which can be used to wrap tensors from external toolkits;
and wrap those in Python (using pybind11). The code in host/ will eventually
be either deprecated, rewritten or wrapped with newer-style interfaces.

## Plans for initial release

We hope to get the first version working in early October. The current
short-term aim is to finish the GPU implementation of pruned composition of a
normal FSA with a dense FSA, which is the same as decoder search in speech
recognition and can be used to implement CTC training and lattice-free MMI (LF-MMI) training. The
proof-of-concept that we will release initially is something that's like CTC
but allowing more general supervisions (general FSAs rather than linear
sequences). This will work on GPU. The same underlying code will support
LF-MMI so that would be easy to implement soon after. We plan to put
example code in a separate repository.
We have wrapped all the C++ code to Python with [pybind11](https://github.com/pybind/pybind11)
and have finished the integration with [PyTorch](https://github.com/pytorch/pytorch).

We are currently writing speech recognition recipes using k2, which are hosted in a
separate repository. Please see <https://github.com/k2-fsa/icefall>.

## Plans after initial release

We will then gradually implement more algorithms in a way that's compatible
with the interfaces in `k2/csrc/`. Some of them will be CPU-only to start
with. The idea is to eventually have very rich capabilities for operating on
collections of sequences, including methods to convert from a lattice to a
collection of linear sequences and back again (for purposes of neural language
model rescoring, neural confidence estimation and the like).
We are currently trying to make k2 ready for production use (see the branch
[v2.0-pre](https://github.com/k2-fsa/k2/tree/v2.0-pre)).

## Quick start

Want to try it out without installing anything? We have setup a [Google Colab][1].

Caution: k2 is not nearly ready for actual use! We are still coding the core
algorithms, and hope to have an early version working by early October.
You can find more Colab notebooks using k2 in speech recognition at
<https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html>.

[1]: https://colab.research.google.com/drive/1qbHUhNZUX7AYEpqnZyf29Lrz2IPHBGlX?usp=sharing
5 changes: 3 additions & 2 deletions cmake/cub.cmake
@@ -20,8 +20,9 @@ function(download_cub)

include(FetchContent)

set(cub_URL "https://github.com/NVlabs/cub/archive/1.10.0.tar.gz")
set(cub_HASH "SHA256=8531e09f909aa021125cffa70a250761dfc247f960d7a1a12f65e6651ffb6477")
set(cub_URL "https://github.com/NVlabs/cub/archive/1.15.0.tar.gz")
set(cub_HASH "SHA256=1781ee5eb7f00acfee5bff88e3acfc67378f6b3c24281335e18ae19e1f2ff685")


FetchContent_Declare(cub
URL ${cub_URL}
14 changes: 7 additions & 7 deletions docs/requirements.txt
@@ -1,8 +1,8 @@
dataclasses
graphviz
recommonmark
sphinx
sphinx-autodoc-typehints
sphinx_rtd_theme
sphinxcontrib-bibtex
dataclasses==0.6
graphviz==0.19.1
recommonmark==0.7.1
sphinx==4.3.2
sphinx-autodoc-typehints==1.12.0
sphinx_rtd_theme==1.0.0
sphinxcontrib-bibtex==2.4.1
torch>=1.6.0
9 changes: 7 additions & 2 deletions docs/source/conf.py
@@ -21,7 +21,7 @@
# -- Project information -----------------------------------------------------

project = 'k2'
copyright = '2020-2021, k2 development team'
copyright = '2020-2022, k2 development team'
author = 'k2 development team'


@@ -147,7 +147,12 @@ def find_source():

# Replace key with value in the generated doc
REPLACE_PATTERN = {
'_k2.ragged': 'k2.ragged',
# somehow it results in errors
# Handler <function process_docstring at 0x7f47a290aca0> for event
# 'autodoc-process-docstring' threw an exception (exception:
# <module '_k2.ragged'> is a built-in module)
#
# '_k2.ragged': 'k2.ragged',
'at::Tensor': 'torch.Tensor'
}

14 changes: 7 additions & 7 deletions docs/source/core_concepts/index.rst
@@ -193,13 +193,13 @@ In k2, you would use the following code to compute it:
fsa = k2.Fsa.from_str(s)
fsa.draw('fsa2.svg')
fsa = k2.create_fsa_vec([fsa])
total_scores = k2.get_tot_scores(fsa, log_semiring=False, use_double_scores=False)
total_scores = fsa.get_tot_scores(log_semiring=False, use_double_scores=False)
print(total_scores)
# It prints: tensor([0.2000])
.. HINT::

:func:`k2.get_tot_scores` takes a vector of FSAs as input,
:func:`k2.Fsa.get_tot_scores` takes a vector of FSAs as input,
so we use :func:`k2.create_fsa_vec` to turn an FSA into a vector of FSAs.

Most operations in k2 take a vector of FSAs as input and process them
@@ -230,7 +230,7 @@ The code in k2 looks like:
'''
fsa = k2.Fsa.from_str(s)
fsa = k2.create_fsa_vec([fsa])
total_scores = k2.get_tot_scores(fsa, log_semiring=True, use_double_scores=False)
total_scores = fsa.get_tot_scores(log_semiring=True, use_double_scores=False)
print(total_scores)
# It prints: tensor([0.8444])
@@ -319,7 +319,7 @@ the FSA given in :numref:`autograd example`:
fsa.scores = nnet_output
fsa.draw('autograd_tropical.svg')
fsa_vec = k2.create_fsa_vec([fsa])
total_scores = k2.get_tot_scores(fsa_vec, log_semiring=False, use_double_scores=False)
total_scores = fsa_vec.get_tot_scores(log_semiring=False, use_double_scores=False)
total_scores.backward()
print(nnet_output.grad)
@@ -366,11 +366,11 @@ Example 2: Autograd in log semiring

For the log semiring, we just change::

total_scores = k2.get_tot_scores(fsa_vec, log_semiring=False, use_double_scores=False)
total_scores = fsa_vec.get_tot_scores(log_semiring=False, use_double_scores=False)

to::

total_scores = k2.get_tot_scores(fsa_vec, log_semiring=True, use_double_scores=False)
total_scores = fsa_vec.get_tot_scores(log_semiring=True, use_double_scores=False)

For completeness and ease of reference, we repost the code below.

@@ -392,7 +392,7 @@ For completeness and ease of reference, we repost the code below.
fsa.scores = nnet_output
fsa.draw('autograd_log.svg')
fsa_vec = k2.create_fsa_vec([fsa])
total_scores = k2.get_tot_scores(fsa_vec, log_semiring=True, use_double_scores=False)
total_scores = fsa_vec.get_tot_scores(log_semiring=True, use_double_scores=False)
total_scores.backward()
print(nnet_output.grad)
6 changes: 3 additions & 3 deletions docs/source/installation/conda.rst
@@ -57,13 +57,13 @@ Read the following if you want to learn more.
Supported versions
------------------

.. |conda_python_versions| image:: ./images/python-3.6_3.7_3.8-blue.svg
.. |conda_python_versions| image:: ./images/python_ge_3.6-blue.svg
:alt: Supported python versions

.. |conda_cuda_versions| image:: ./images/cuda-10.1_10.2_11.0_11.1-orange.svg
.. |conda_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg
:alt: Supported cuda versions

.. |conda_pytorch_versions| image:: ./images/pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1-green.svg
.. |conda_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg
:alt: Supported pytorch versions

- |conda_python_versions|
6 changes: 3 additions & 3 deletions docs/source/installation/from_source.rst
@@ -9,13 +9,13 @@ The following versions of Python, CUDA, and PyTorch are known to work.
- |source_cuda_versions|
- |source_pytorch_versions|

.. |source_python_versions| image:: ./images/source_python-3.6_3.7_3.8_3.9-blue.svg
.. |source_python_versions| image:: ./images/python_ge_3.6-blue.svg
:alt: Supported python versions

.. |source_cuda_versions| image:: ./images/source_cuda-10.1_10.2_11.0_11.1_11.2_11.3-orange.svg
.. |source_cuda_versions| image:: ./images/cuda_ge_10.1-orange.svg
:alt: Supported cuda versions

.. |source_pytorch_versions| image:: ./images/source_pytorch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1-green.svg
.. |source_pytorch_versions| image:: ./images/pytorch_ge_1.5.0-green.svg
:alt: Supported pytorch versions

Before compiling k2, some preparation work has to be done: