Ce qui rend JAX si génial

Pour la recherche en apprentissage automatique hautes performances, Just After eXceution (JAX) est NumPy sur le CPU, le GPU et le TPU, avec une excellente différenciation automatisée. Il s’agit d’une bibliothèque Python pour le calcul numérique haute performance, en particulier la recherche en apprentissage automatique. Son API numérique est basée sur NumPy, une bibliothèque de fonctions utilisées en calcul scientifique. Python et NumPy sont tous deux des langages de programmation renommés et utilisés, ce qui rend JAX simple, polyvalent et simple à mettre en œuvre. Cet article se concentrera sur les fonctionnalités JAX et leur implémentation pour créer un modèle d’apprentissage en profondeur. Voici les sujets à traiter.

Table des matières

  1. Raison d’utiliser JAX
  2. Qu’est-ce que XLA ?
  3. Qu’y a-t-il dans l’écosystème de JAX ?
  4. Créer un modèle ML avec JAX

JAX n’est pas un produit officiel de Google mais sa popularité augmente, connaissons les raisons de sa popularité.

Raison d’utiliser JAX

Bien que JAX fournisse une API simple et puissante pour développer du code numérique accéléré, travailler efficacement avec JAX nécessite parfois une réflexion supplémentaire. JAX est essentiellement un compilateur Just-In-Time (JIT) qui se concentre sur la génération de code efficace tout en utilisant la simplicité de Python pur. Outre l’API NumPy, JAX contient un ensemble extensible de transformations de fonctions composables qui facilitent la recherche en apprentissage automatique, telles que :

  • différenciation: L’optimisation basée sur les gradients est essentielle à l’apprentissage automatique. JAX permet nativement la différenciation automatisée de fonctions numériques arbitraires en mode avant et arrière à l’aide de transformations de fonctions telles que Gradients, Hessian et Jacobians (jacfwd et jacrev).
  • Vectorisation : Dans la recherche sur l’apprentissage automatique, une fonction unique est fréquemment appliquée à de grandes quantités de données, comme le calcul de la perte sur un lot ou l’évaluation des gradients par exemple pour un apprentissage différentiellement privé. La transformation vmap en JAX permet une vectorisation automatisée, ce qui simplifie ce type de programmation. Lors du développement de nouveaux algorithmes, par exemple, les chercheurs n’ont pas besoin d’envisager le traitement par lots. JAX permet également le parallélisme des données à grande échelle avec la transformation pmap associée, qui distribue élégamment les données trop vastes pour la mémoire d’un seul accélérateur.
  • Compilation juste à temps (JIT) : XLA est utilisé pour compiler JIT et exécuter des applications JAX sur des accélérateurs GPU et Cloud TPU. La compilation JIT, en conjonction avec l’API cohérente NumPy de JAX, permet aux chercheurs sans expérience préalable en informatique haute performance d’évoluer facilement vers un ou plusieurs accélérateurs.

Êtes-vous à la recherche d’un référentiel complet de bibliothèques Python utilisées en science des données, vérifier ici

Qu’est-ce que XLA ?

XLA (Accelerated Linear Algebra) est un compilateur d’algèbre linéaire spécifique à un domaine qui peut accélérer les modèles TensorFlow avec peu de modifications du code source.

Lorsqu’un programme TensorFlow est exécuté, l’exécuteur TensorFlow exécute chaque opération indépendamment. L’exécuteur envoie à une implémentation de noyau GPU précompilée pour chaque opération TensorFlow. XLA offre une manière supplémentaire d’exécuter le modèle en compilant le graphe TensorFlow dans une séquence de noyaux informatiques spécialement conçus pour le modèle spécifié. Étant donné que ces noyaux sont spécifiques au modèle, ils peuvent utiliser des informations spécifiques au modèle pour optimiser.

Architecture de XLA

Le langage d’entrée dans XLA est appelé opérations de haut niveau (HLO). Il est plus pratique de considérer HLO comme une représentation intermédiaire du compilateur. Ainsi, HLO représente un programme “entre” les langues source et cible.

XLA traduit les graphes décrits dans HLO en instructions machine pour plusieurs plates-formes. XLA est modulaire dans le sens où un backend alternatif peut être facilement inséré pour cibler une architecture matérielle innovante. XLA transfère le calcul HLO à un backend après la phase indépendante de la cible. Le backend peut effectuer des optimisations supplémentaires au niveau HLO, cette fois en tenant compte des données et des exigences spécifiques à la cible.

L’étape suivante consiste à générer du code spécifique à la cible. LLVM est utilisé par les backends CPU et GPU fournis avec XLA pour l’optimisation de la représentation intermédiaire de bas niveau et la création de code. Ces backends produisent le LLVM IR requis pour décrire efficacement le calcul XLA HLO, puis utilisent LLVM pour émettre du code natif à partir de cette représentation intermédiaire LLVM.

Raison d’utiliser XLA

Il y a quatre raisons principales d’utiliser XLA.

  • Car la traduction semble détailler l’analyse et la synthèse par définition. La traduction mot à mot est inefficace.
  • Diviser le défi complexe de la traduction en deux moitiés plus simples et plus gérables.
  • Un nouveau back-end peut être construit pour un front-end existant afin de fournir des compilateurs reciblables et vice versa.
  • Pour effectuer des optimisations indépendantes de la machine.

Qu’y a-t-il dans l’écosystème de JAX ?

L’écosystème se compose de cinq bibliothèques différentes.

Haïku

Le traitement d’objets avec état, tels que des réseaux de neurones avec des paramètres pouvant être entraînés, peut être difficile avec le paradigme de programmation JAX des transformations de fonctions composables. Haiku est une bibliothèque de réseaux de neurones qui permet aux utilisateurs d’utiliser des paradigmes de programmation traditionnels orientés objet tout en utilisant la puissance et la simplicité du paradigme fonctionnel pur de JAX.

Plusieurs projets externes, dont Coax, DeepChem et NumPyro, utilisent activement Haiku. Il étend l’API pour Sonnet, notre modèle de programmation de réseau neuronal basé sur des modules dans TensorFlow.

optaxe

L’optimisation basée sur les gradients est importante pour l’apprentissage automatique. Optax inclut une bibliothèque de transformation de gradient ainsi que des opérateurs de composition (tels que la chaîne) qui permettent le développement de nombreux optimiseurs communs (tels que RMSProp ou Adam) dans une seule ligne de code. La structure de composition d’Optax se prête facilement à la recombinaison des mêmes éléments fondamentaux dans des optimiseurs sur mesure. Il comprend également des utilitaires pour l’estimation de gradient stochastique et l’optimisation de second ordre.

RLax

RLax est une bibliothèque qui fournit des éléments de base importants pour le développement de l’apprentissage par renforcement (RL), également connu sous le nom d’apprentissage par renforcement profond. Les composants de RLax incluent l’apprentissage TD, les gradients de politique, les critiques d’acteurs, le MAP, l’optimisation de politique proximale, la transformation de valeur non linéaire, les fonctions de valeur génériques et de nombreuses approches d’exploration.

RLax n’est pas censé être un cadre pour le développement et le déploiement de systèmes d’agents RL à part entière. Acme est un exemple d’architecture d’agent complète basée sur des composants RLax.

Chex

Les tests sont essentiels pour la fiabilité des logiciels, et le code de recherche ne fait pas exception. Tirer des conclusions scientifiques d’essais de recherche nécessite de croire en la précision de votre code. Chex est une collection d’utilitaires de test utilisés par les auteurs de bibliothèques pour s’assurer que les blocs de construction communs sont corrects et résilients, ainsi que par les utilisateurs finaux pour valider leurs programmes expérimentaux.

Chex comprend un certain nombre d’outils, tels que les tests unitaires compatibles JAX, les assertions sur les attributs de type de données JAX, les simulations et les contrefaçons, et les environnements de test multi-périphériques.

Jraph

Jraph est une petite bibliothèque pour travailler avec les réseaux de neurones Graph GNN dans JAX. Jraph fournit une structure de données standardisée pour les graphes, un ensemble d’outils pour travailler avec des graphes et un ensemble de modèles de réseaux neuronaux de graphes qui sont facilement bifurquables et extensibles. Les autres fonctionnalités majeures incluent le traitement par lots GraphTuple qui tire parti des accélérateurs matériels, la prise en charge de la compilation JIT pour les graphiques de forme variable via le remplissage et le masquage, et les pertes spécifiées sur les partitions d’entrée. Jraph, comme Optax et nos autres bibliothèques, n’a aucune restriction sur le choix de l’utilisateur d’une bibliothèque de réseau de neurones.

Créer un modèle ML avec JAX

Pour cet article, création d’un modèle Generative Adversarial Net sur la plate-forme TensorFlow formée sur l’ensemble de données MNIST dans Jax’s Haiku.

Commençons par installer Haiku et Optax

!pip install dm-haiku
! pip install optax

Importer les bibliothèques nécessaires

import functools
from typing import Any, NamedTuple
 
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

Lecture du jeu de données

mnist_dataset = tfds.load("mnist")
def make_dataset(batch_size, seed=1):
  def _preprocess(sample):
    image = tf.image.convert_image_dtype(sample["image"], tf.float32)
    return 2.0 * image - 1.0
 
  ds = mnist["train"]
  ds = ds.map(map_func=_preprocess, 
              num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.cache()
  ds = ds.shuffle(10 * batch_size, seed=seed).repeat().batch(batch_size)
  return iter(tfds.as_numpy(ds))

Créer un générateur et un discriminateur

Le modèle est utilisé comme générateur pour produire de nouvelles instances plausibles à partir du domaine problématique, tandis que le modèle est utilisé comme discriminateur pour déterminer si un exemple est réel (du domaine) ou généré.

class Generator(hk.Module):
  def __init__(self, output_channels=(32, 1), name=None):
    super().__init__(name=name)
    self.output_channels = output_channels
 
  def __call__(self, x):
    x = hk.Linear(7 * 7 * 64)(x)
    x = jnp.reshape(x, x.shape[:1] + (7, 7, 64)) 
    for output_channels in self.output_channels:
      x = jax.nn.relu(x)
      x = hk.Conv2DTranspose(output_channels=output_channels,
                             kernel_shape=[5, 5],
                             stride=2,
                             padding="SAME")(x)
    return jnp.tanh(x)
class Discriminator(hk.Module):
 
  def __init__(self,
               output_channels=(8, 16, 32, 64, 128),
               strides=(2, 1, 2, 1, 2),
               name=None):   
    super().__init__(name=name)
    self.output_channels = output_channels
    self.strides = strides
 
  def __call__(self, x):
    for output_channels, stride in zip(self.output_channels, self.strides):
      x = hk.Conv2D(output_channels=output_channels,
                    kernel_shape=[5, 5],
                    stride=stride,
                    padding="SAME")(x)
      x = jax.nn.leaky_relu(x, negative_slope=0.2)
    x = hk.Flatten()(x)    
    logits = hk.Linear(2)(x)
    return logits

Création de l’algorithme GAN

import optax
class GAN_algo_basic:
  def __init__(self, num_latents):
    self.num_latents = num_latents
    self.gen_transform = hk.without_apply_rng(
        hk.transform(lambda *args: Generator()(*args)))
    self.disc_transform = hk.without_apply_rng(
        hk.transform(lambda *args: Discriminator()(*args)))
    self.optimizers = GANTuple(gen=optax.adam(1e-4, b1=0.5, b2=0.9),
                               disc=optax.adam(1e-4, b1=0.5, b2=0.9))
 
  @functools.partial(jax.jit, static_argnums=0)
  def initial_state(self, rng, batch):
    dummy_latents = jnp.zeros((batch.shape[0], self.num_latents))
    rng_gen, rng_disc = jax.random.split(rng)
    params = GANTuple(gen=self.gen_transform.init(rng_gen, dummy_latents),
                      disc=self.disc_transform.init(rng_disc, batch))
    print("Generator: nn{}n".format(tree_shape(params.gen)))
    print("Discriminator: nn{}n".format(tree_shape(params.disc)))
    opt_state = GANTuple(gen=self.optimizers.gen.init(params.gen),
                         disc=self.optimizers.disc.init(params.disc))
    
    return GANState(params=params, opt_state=opt_state)
 
  def sample(self, rng, gen_params, num_samples):
    """Generates images from noise latents."""
    latents = jax.random.normal(rng, shape=(num_samples, self.num_latents))
    return self.gen_transform.apply(gen_params, latents)
 
  def gen_loss(self, gen_params, rng, disc_params, batch):
    fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0])
    fake_logits = self.disc_transform.apply(disc_params, fake_batch)
    fake_probs = jax.nn.softmax(fake_logits)[:, 1]
    loss = -jnp.log(fake_probs)
    
    return jnp.mean(loss)
 
  def disc_loss(self, disc_params, rng, gen_params, batch):
    fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0])
    real_and_fake_batch = jnp.concatenate([batch, fake_batch], axis=0)
    real_and_fake_logits = self.disc_transform.apply(disc_params, 
                                                     real_and_fake_batch)
    real_logits, fake_logits = jnp.split(real_and_fake_logits, 2, axis=0)
    real_labels = jnp.ones((batch.shape[0],), dtype=jnp.int32)
    real_loss = sparse_softmax_cross_entropy(real_logits, real_labels)
    fake_labels = jnp.zeros((batch.shape[0],), dtype=jnp.int32)
    fake_loss = sparse_softmax_cross_entropy(fake_logits, fake_labels)
 
    return jnp.mean(real_loss + fake_loss)
  @functools.partial(jax.jit, static_argnums=0)
  def update(self, rng, gan_state, batch):
    rng, rng_gen, rng_disc = jax.random.split(rng, 3)
    disc_loss, disc_grads = jax.value_and_grad(self.disc_loss)(
        gan_state.params.disc,
        rng_disc, 
        gan_state.params.gen,
        batch)
    disc_update, disc_opt_state = self.optimizers.disc.update(
        disc_grads, gan_state.opt_state.disc)
    disc_params = optax.apply_updates(gan_state.params.disc, disc_update)
    gen_loss, gen_grads = jax.value_and_grad(self.gen_loss)(
        gan_state.params.gen,
        rng_gen, 
        gan_state.params.disc,
        batch)
    gen_update, gen_opt_state = self.optimizers.gen.update(
        gen_grads, gan_state.opt_state.gen)
    gen_params = optax.apply_updates(gan_state.params.gen, gen_update)
    
    params = GANTuple(gen=gen_params, disc=disc_params)
    opt_state = GANTuple(gen=gen_opt_state, disc=disc_opt_state)
    gan_state = GANState(params=params, opt_state=opt_state)
    log = {
        "gen_loss": gen_loss,
        "disc_loss": disc_loss,
    }
 
    return rng, gan_state, log

Former le modèle

for step in range(num_steps):
  rng, gan_state, log = model.update(rng, gan_state, next(dataset))
  if step % log_every == 0:   
    log = jax.device_get(log)
    gen_loss = log["gen_loss"]
    disc_loss = log["disc_loss"]
    print(f"Step {step}: "
          f"gen_loss = {gen_loss:.3f}, disc_loss = {disc_loss:.3f}")
    steps.append(step)
    gen_losses.append(gen_loss)
    disc_losses.append(disc_loss)

Le modèle sera formé pour 5000 étapes en raison de contraintes de temps. Cela dépend de l’utilisateur pour sélectionner le nombre d’étapes. Pour 5000 pas, il a fallu environ 60 minutes.

Analytique Inde Magazine

Analyse des pertes pour le générateur et le discriminateur

fig, axes = plt.subplots(1, 2, figsize=(20, 6))
 
# Plot the discriminator loss.
axes[0].plot(steps, disc_losses, "-")
axes[0].set_title("Discriminator loss", fontsize=20)
 
# Plot the generator loss.
axes[1].plot(steps, gen_losses, '-')
axes[1].set_title("Generator loss", fontsize=20);
Analytique Inde Magazine

Nous pouvons observer que la perte du générateur était assez élevée pendant les 2000 étapes initiales et après 3000 étapes, la perte du discriminateur et du générateur est devenue à peu près constante en moyenne.

conclusion

Just After eXceution (JAX) est un calcul numérique performant, notamment dans la recherche en apprentissage automatique. Son API numérique est basée sur NumPy, une bibliothèque de fonctions utilisées en calcul scientifique. Avec cet article, nous avons compris l’écosystème de JAX et l’implémentation d’Optax et Haiku qui font partie de cet écosystème.

Références

Leave a Comment

Your email address will not be published.