## 关于要替代TensorFlow的JAX，你知道多少？

2 月 11 日 AI前线

AI 前线导读：这个简短的教程将介绍关于 JAX 的基础知识。JAX 是一个 Python 库，它通过函数转换来增强 numpy 和 Python 代码，使运行机器学习程序中常见的操作轻而易举。具体来说，它会使得编写标准 Python / numpy 代码变得简单，并且能够立即执行

- 及时编译函数，通过 XLA 在加速器上高效运行

- 自动矢量化函数，并执行处理“批量”数据等

1 JAX 只是 numpy（大多数情况下）

import random
import itertools
import jax
import jax.numpy as np
# Current convention is to import original numpy as "onp"
import numpy as onp
from __future__ import print_function
2 背景

# Sigmoid nonlinearity
def sigmoid(x):
return 1 / (1 + np.exp(-x))
# Computes our network's output
def net(params, x):
w1, b1, w2, b2 = params
hidden = np.tanh(np.dot(w1, x) + b1)
return sigmoid(np.dot(w2, hidden) + b2)
# Cross-entropy loss
def loss(params, x, y):
out = net(params, x)
cross_entropy = -y * np.log(out) - (1 - y)*np.log(1 - out)
return cross_entropy
# Utility function for testing whether the net produces the correct
# output for all possible inputs
def test_all_inputs(inputs, params):
predictions = [int(net(params, inp) > 0.5) for inp in inputs]
for inp, out in zip(inputs, predictions):
print(inp, '->', out)
return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])

def initial_params():
return [
onp.random.randn(3, 2),  # w1
onp.random.randn(3),  # b1
onp.random.randn(3),  # w2
onp.random.randn(),  #b2
]

# Stochastic gradient descent learning rate
learning_rate = 1.
# All possible inputs
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
# Initialize parameters randomly
params = initial_params()
for n in itertools.count():
# Grab a single random input
x = inputs[onp.random.choice(inputs.shape[0])]
# Compute the target output
y = onp.bitwise_xor(*x)
# Get the gradient of the loss for this input/output pair
# Update parameters via gradient descent
params = [param - learning_rate * grad
# Every 100 iterations, check whether we've solved XOR
if not n % 100:
print('Iteration {}'.format(n))
if test_all_inputs(inputs, params):
break

4 jax.jit

# Time the original gradient function
# Run once to trigger JIT compilation

10 loops, best of 3: 13.1 ms per loop

1000 loops, best of 3: 862 µs per loop

params = initial_params()
for n in itertools.count():
x = inputs[onp.random.choice(inputs.shape[0])]
y = onp.bitwise_xor(*x)
params = [param - learning_rate * grad
if not n % 100:
print('Iteration {}'.format(n))
if test_all_inputs(inputs, params):
break

5 jax.vmap

jax.vmap 还可接受其他参数：

• in_axes 是一个元组或整数，它告诉 JAX 函数参数应该对哪些轴并行化。元组应该与 vmap'd 函数的参数数量相同，或者只有一个参数时为整数。示例中，我们将使用（None，0,0），指“不在第一个参数（params）上并行化，并在第二个和第三个参数（x 和 y）的第一个（第零个）维度上并行化”。

• out_axes 类似于 in_axes，除了它指定了函数输出的哪些轴并行化。我们在例子中使用 0，表示在函数唯一输出的第一个（第零个）维度上进行并行化（损失梯度）。

params = initial_params()

batch_size = 100

for n in itertools.count():
# Generate a batch of inputs
x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)]
y = onp.bitwise_xor(x[:, 0], x[:, 1])
# The call to loss_grad remains the same!
# Note that we now need to average gradients over the batch
params = [param - learning_rate * np.mean(grad, axis=0)
if not n % 100:
print('Iteration {}'.format(n))
if test_all_inputs(inputs, params):
break

6 指南