
v0.8.3

@cgarciae released this 30 Apr 09:56

What's Changed

  • Add git fetch upstream to contributing doc. by @carlosgmartin in #3757
  • removed getattr/setattr unboxing magic from nnx.Pytree by @chiamp in #3743
  • added Einsum layer to NNX by @chiamp in #3741
  • Make TrainState's step possibly jax.Array. This makes replicate valid for type checking. by @copybara-service in #3763
  • v0.8.3 by @cgarciae in #3758
  • [nnx] fix demo notebook by @cgarciae in #3744
  • added nnx api reference by @chiamp in #3762
  • updated rng docstring for init, apply and make_rng by @chiamp in #3765
  • use note box in make_rng docstring by @cgarciae in #3767
  • [nnx] improved graph update mechanism by @cgarciae in #3759
  • use note box in docstrings by @chiamp in #3769
  • Add reset_gate flag to MGUCell. by @carlosgmartin in #3760
  • Access thread_resources via jax.interpreters.pxla instead of jax.experimental.maps by @copybara-service in #3775
  • Minor doc improvements by @canyon289 in #3588
  • added MGU reset_gate test by @chiamp in #3773
  • [nnx] Pytrees are Trees by @cgarciae in #3768
  • Use short-circuiting access to debug_key_reuse by @copybara-service in #3781
  • fix tabulate on norm wrappers by @chiamp in #3772
  • Add kw_only struct.dataclass test by @chiamp in #3651
  • extended PyTreeNode to take dataclass kwargs by @chiamp in #3785
  • [nnx] Arrays are state by @cgarciae in #3791
  • [nnx] add GraphNode base class by @cgarciae in #3790
  • [nnx] jit accepts many Modules by @cgarciae in #3783
  • Exposing the experimental _split_transpose JAX scan parameter in Flax. by @copybara-service in #3795
  • Expose nnx.GraphNode by @chiamp in #3796
  • [nnx] Rngs and RngStream inherit from GraphNode by @cgarciae in #3793
  • [nnx] TrainState uses struct by @cgarciae in #3788
  • [nnx] split returns graphdef first by @cgarciae in #3794
  • Remove the uninitialized field "embedding" in nn.Embed by @copybara-service in #3801
  • Add nnx.training by @chiamp in #3782
  • [nnx] non-str State keys by @cgarciae in #3802
  • [nnx] allow all jit kwargs in nnx.jit by @cgarciae in #3809
  • [nnx] simplify readme by @cgarciae in #3805
  • [nnx] Fix nnx basics by @cgarciae in #3812
  • [nnx] grad accepts argnums by @cgarciae in #3798
  • [nnx] improve toy examples by @cgarciae in #3813
  • [nnx] expose Sequential by @cgarciae in #3814
  • [nnx] Rng Variable tags by @cgarciae in #3807
  • [nnx] remove copy in graph unflatten by @cgarciae in #3804
  • fixed optax guide links and docstring typos by @chiamp in #3789
  • added dropout broadcast test by @chiamp in #3776
  • relaxed grads kwarg for Optimizer.update by @chiamp in #3818
  • added tree_map deprecation warning filter by @chiamp in #3828
  • updated tree_map by @chiamp in #3823
  • added NNX vs JAX transformations guide by @chiamp in #3819
  • Updated NNX MNIST tutorial by @chiamp in #3810
  • [nnx] add Dropout.rngs by @cgarciae in #3815
  • removed autosummary from linen docs by @chiamp in #3792
  • Fix cloudpickle sentinel cloning by @cgarciae in #3825
  • [nnx] remove pytreelib by @cgarciae in #3816
  • [nnx] fix nnx_basics by @cgarciae in #3839
  • [linen] fix DenseGeneral init by @cgarciae in #3834
  • [nnx] jit constrain object state by @cgarciae in #3817
  • Copybara import of the project: by @copybara-service in #3857
  • Add example of unbox() and replace_boxed() to the jit guide by @IvyZX in #3843
  • RNNCellBase refactor FLIP by @cgarciae in #3099
  • [nnx] Some small documentation suggestions. by @gnecula in #3861
  • updated nnx dropout by @chiamp in #3841
  • Fix LogicalRules type annotation. (Tuple[str] is a tuple with single element string, by @copybara-service in #3877
  • Add option to skip float32 promotion when computing means and variances for normalization. by @copybara-service in #3873
  • added nnx api reference link by @chiamp in #3871
  • option of forcing the input of softmax to be fp32 for better numerical stability in mixed-precision training. by @copybara-service in #3874
  • allow custom dot_general for einsum. by @copybara-service in #3884
  • [NVIDIA] Extend the custom fp8 accumulate dtype in non-jit scenarios by @kaixih in #3827
  • updated robots.txt by @chiamp in #3886
  • fixed autosummary links by @chiamp in #3887
  • Fix jax.tree_util.register_dataclass in older JAX versions. by @copybara-service in #3885
  • [nnx] v0.1 by @cgarciae in #3876
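The fp32-softmax change (#3874) appears here only as a PR title. To illustrate the general mixed-precision technique it names, here is a minimal JAX sketch: upcast low-precision logits to float32 before the softmax, then cast the result back. This is not Flax's implementation, and the helper name softmax_fp32 is hypothetical.

```python
import jax
import jax.numpy as jnp


def softmax_fp32(logits: jax.Array, axis: int = -1) -> jax.Array:
    """Hypothetical helper (not a Flax API): softmax computed in float32.

    Low-precision logits (e.g. bfloat16) are upcast so that the exp and
    the normalizing sum run in float32; the result is cast back to the
    input dtype so the surrounding mixed-precision code is unaffected.
    """
    # Upcast before the numerically sensitive exp/sum.
    probs = jax.nn.softmax(logits.astype(jnp.float32), axis=axis)
    # Return in the caller's dtype (e.g. bfloat16).
    return probs.astype(logits.dtype)


logits = jnp.array([[1.0, 2.0, 3.0]], dtype=jnp.bfloat16)
probs = softmax_fp32(logits)
print(probs.dtype, probs)
```

Because the output keeps the input dtype, a bfloat16 model calling this sees no dtype change. Only the exponentiation and normalization run at higher precision.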

Full Changelog: v0.8.2...v0.8.3