Learn, Implement, Master

Welcome to my blog focused on reinforcement learning with JAX! This resource is designed to help researchers and practitioners understand both the theoretical foundations of reinforcement learning and how to implement them efficiently using Google's JAX library.

Throughout this blog, you'll find tutorials, code examples, and in-depth explanations of key RL concepts, from basics like Markov Decision Processes and Q-learning to advanced topics such as policy gradient methods and model-based RL. Each article emphasizes practical implementation techniques that leverage JAX's automatic differentiation, vectorization, and GPU/TPU acceleration.
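To give a flavor of what automatic differentiation and vectorization look like in practice, here is a minimal sketch. It assumes a linear value function and a semi-gradient TD(0) loss. The names `td_loss`, `grad_fn`, and `batched_grad`, the toy one-hot states, and the step size are illustrative choices, not code from the blog's articles.

```python
import jax
import jax.numpy as jnp


def td_loss(w, s, r, s_next, gamma=0.9):
    """Squared TD(0) error for a linear value function v(s) = w . s."""
    target = r + gamma * jnp.dot(w, s_next)
    # stop_gradient treats the bootstrapped target as a constant (semi-gradient TD).
    td_error = jax.lax.stop_gradient(target) - jnp.dot(w, s)
    return 0.5 * td_error ** 2


# Automatic differentiation: gradient of the loss with respect to w (first argument).
grad_fn = jax.grad(td_loss)

# Vectorization: map the per-transition gradient over a batch of transitions,
# sharing the same weights w across the batch.
batched_grad = jax.vmap(grad_fn, in_axes=(None, 0, 0, 0))

w = jnp.zeros(3)
states = jnp.eye(3)                         # three one-hot states (toy data)
rewards = jnp.array([1.0, 0.0, 2.0])
next_states = jnp.roll(states, -1, axis=0)  # state i transitions to state i + 1

grads = batched_grad(w, states, rewards, next_states)  # one gradient row per transition
w_new = w - 0.1 * grads.mean(axis=0)                   # a single averaged gradient step
```

The same `batched_grad` function runs unchanged on CPU, GPU, or TPU, depending on which backend JAX finds at runtime.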

15-Week Course Schedule

Week 1

Introduction to RL

Overview of RL paradigm, key concepts, relation to other ML approaches, and introduction to JAX basics.

Week 2

Markov Decision Processes

Mathematical formulation of MDPs, policies, value functions, and the Bellman equations.

Week 3

Dynamic Programming

Policy evaluation, policy improvement, value iteration, and policy iteration, with JAX implementations.

Week 4

Monte Carlo Methods

Monte Carlo prediction and control, first-visit MC, every-visit MC, and exploring starts.

Week 5

Temporal Difference Learning

TD prediction (TD(0)), SARSA, Q-learning, and expected SARSA with JAX acceleration.

Week 6

n-step Bootstrapping

n-step TD prediction, n-step SARSA, n-step Q-learning, and the forward and backward views.

Week 7

Planning and Learning

Model-based RL, Dyna-Q, prioritized sweeping, and Monte Carlo tree search using JAX.

Week 8

Function Approximation

Feature-based representations, gradient TD methods, and linear function approximation with JAX.

Week 9

Deep Q-Networks

Deep Q-learning, experience replay, target networks, and implementing DQN with JAX and Haiku.

Week 10

Policy Gradient Methods

REINFORCE algorithm, policy gradient theorem, and actor-critic methods implemented with JAX.

Week 11

Trust Region Methods

TRPO, PPO, and constrained policy optimization with JAX's automatic differentiation.

Week 12

Model-Based RL

Learning environment models, model-based planning, and imagination-based methods with JAX implementations.

Week 13

Exploration vs. Exploitation

Multi-armed bandits, UCB, Thompson sampling, intrinsic motivation, and curiosity-driven learning.

Week 14

Multi-Agent RL

Markov games, cooperative and competitive settings, and multi-agent learning algorithms with JAX.

Week 15

Future Directions

Meta-learning, hierarchical RL, inverse RL, offline RL, and the latest research developments.

Reinforcement Learning Basics with JAX

Mastering RL algorithms through theory and efficient implementation with Google's high-performance numerical computing library

JAX • PyTree • Neural Networks • Q-Learning • Policy Gradients
Ji Sue Lee, Hanyang University

Key Features

Take advantage of these powerful features to accelerate your learning and implementation

Performance

JAX Acceleration

Learn how to leverage JAX's just-in-time compilation and automatic differentiation to speed up RL algorithm implementation and training.
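As a rough illustration of just-in-time compilation applied to an RL update, the sketch below wraps a single tabular Q-learning step in `jax.jit`. The function name `q_update`, the 2x2 table, and the values of `alpha` and `gamma` are assumptions made for this example rather than code taken from the blog.

```python
import jax
import jax.numpy as jnp


@jax.jit
def q_update(q, s, a, r, s_next, alpha=0.5, gamma=0.9):
    """One tabular Q-learning step; returns a new table instead of mutating q."""
    td_target = r + gamma * jnp.max(q[s_next])
    td_error = td_target - q[s, a]
    # JAX arrays are immutable, so .at[...].add(...) builds the updated copy.
    return q.at[s, a].add(alpha * td_error)


q = jnp.zeros((2, 2))             # 2 states x 2 actions, all values start at zero
q = q_update(q, 0, 1, 1.0, 1)     # first call traces and compiles the function
q = q_update(q, 0, 1, 1.0, 1)     # later calls with the same shapes reuse the compiled code
```

Because the update is a pure function of its inputs, it can be compiled once and then called repeatedly inside a training loop without retracing, provided the argument shapes and types stay the same.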

Learning

Interactive Notebooks

Access complete Jupyter notebooks with step-by-step implementations of RL algorithms, complete with detailed explanations and visualizations.

Implementation

Complete Code Repository

Access a GitHub repository with clean, modular implementations of all the algorithms discussed in the blog, designed for research and practical applications.
