人生苦长我用jax
为什么选择 JAX?
JAX 是 Google 开发的数值计算库,具有以下特点:
- 函数式编程:所有函数都是纯函数,没有副作用
- 即时编译(JIT):通过
@jit装饰器可以将 Python 代码编译成高效的 XLA 汇编 - 自动微分:通过
grad、value_and_grad等函数自动计算梯度 - 向量化:通过
vmap、pmap轻松实现批量和分布式计算 - GPU/TPU 支持:可以在 GPU 和 TPU 上运行
Flax 基础
什么是 Flax?
Flax 是 JAX 上最流行的神经网络库,它提供了一种声明式的方式来定义神经网络。
核心概念
- nn.Module:所有神经网络的基类
- @nn.compact:装饰器,用于定义网络结构(在
__call__方法内部定义子层) - nn.Dense:全连接层
- nn.Conv:卷积层
- nn.GroupNorm:组归一化
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来源 Attic的博客!
