Your brain runs on about 20 watts. That’s about the same power draw as charging a phone. Meanwhile, training GPT-4 consumed enough energy to power roughly 5,000 American homes for a year.

The efficiency gap looks massive. Your brain processes information on the same power as a phone charger while datacenters measure their power budgets in megawatts.

But this comparison misses the point entirely.

Saying so doesn’t mean AI has matched human intelligence. It means we’ve been asking the wrong question.

A recent paper from Argonne National Laboratory (“Enabling Physical AI Through Biological Principles”) challenges this common framing. The real gap isn’t about raw efficiency. The gap is about deployment environment.

The researchers introduce a framework based on two key metrics: action entropy rate (how many bits of decision per second) and learning cost (energy required to integrate new information). Using this lens, they identify why brains and current AI systems are optimized for fundamentally different environments.

Your brain operates at a single point in space and time, processing continuous sensory input. LLMs operate in datacenters, batching requests from thousands of users across different times and locations.

Your brain compresses roughly 125 megabytes per second of sensory input into 10 bits/second of action. That’s a compression ratio of about 10^8. By this measure, LLMs provide essentially no compression: they take text in and produce text out at roughly the same information density. If you input 100 tokens, you get back something on the order of 100 tokens.
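
The 10^8 figure follows directly from the two rates quoted above. Here’s the arithmetic as a quick sketch (the function and constant names are mine, purely for illustration):

```python
# Figures quoted in the text above.
SENSORY_INPUT_BYTES_PER_S = 125e6  # ~125 megabytes/second of sensory input
ACTION_BITS_PER_S = 10             # ~10 bits/second of action


def compression_ratio(input_bytes_per_s: float, output_bits_per_s: float) -> float:
    """Ratio of input bit rate to output bit rate."""
    input_bits_per_s = input_bytes_per_s * 8  # bytes -> bits
    return input_bits_per_s / output_bits_per_s


ratio = compression_ratio(SENSORY_INPUT_BYTES_PER_S, ACTION_BITS_PER_S)
print(f"{ratio:.0e}")  # 1e+08
```

125 MB/s is 10^9 bits/s, and dividing by 10 bits/s gives the 10^8 ratio.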

This creates fundamentally different constraints. AI thrives in datacenters where power and information concentrate from multiple users. Humans thrive in individual, real-time environments where sparse, low-entropy data must be processed continuously.

Why does this matter for ML engineers? The Argonne researchers identified five architectural principles that biological systems evolved over 500 million years—principles that modern ML systems are now rediscovering:

  1. Sparse activation – only activate what you need
  2. Event-driven processing – compute when something happens
  3. Memory-compute co-location – treat data movement as the enemy
  4. Continual learning – update without full retrains
  5. Precision as a lever – use just enough bits

Each principle maps to practical ML techniques.

Principle 1: Sparse Activation

At any moment, only 1-4% of your neurons fire. Right now, your visual cortex processes this text while your motor cortex controlling toe movement stays silent.

This isn’t accidental or random. Your brain actively prevents most neurons from firing. For every neuron that might fire, feedback from other neurons keeps it in check. The average neuron in the human brain receives roughly 4,200 inputs per second from 6,000 synapses. Yet it only fires about 0.7 times per second!

Research from Levy and Calvert shows that neural communication consumes 35 times more energy than computation itself, making selective activation crucial.

Modern ML has rediscovered this through Mixture of Experts (MoE) architectures. DeepSeek-V3, released in December 2024, epitomizes this approach: it has 671 billion total parameters but activates only 37 billion per token. That’s about 5.5% activation (eerily close to biological sparsity).
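
If you want to check the sparsity numbers yourself, the arithmetic is short. This is a rough sketch using only the figures quoted above (the helper name is hypothetical). Keep in mind that the share of neurons firing and the share of parameters active are only loosely comparable quantities.

```python
def active_fraction(active: float, total: float) -> float:
    """Share of a system that is active at once."""
    return active / total


# DeepSeek-V3: 37B active parameters out of 671B total
deepseek_v3 = active_fraction(37e9, 671e9)
print(f"{deepseek_v3:.1%}")  # 5.5%

# Average neuron: ~4,200 inputs/second in, ~0.7 spikes/second out
inputs_per_spike = round(4200 / 0.7)
print(inputs_per_spike)  # 6000
```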


Here’s how it works: Instead of one large neural network processing everything, you build multiple smaller networks (called “experts”) that each specialize in different patterns. Then you add a “router” (another small network) that looks at each input and decides which 1-2 experts should handle it. The experts are just regular feedforward networks, nothing fancy. The magic is in only activating the ones you need.

Mixtral 8x7B uses 8 experts and activates 2 per token, achieving performance comparable to Llama 2 70B while using 5x fewer active parameters during inference.

Building a Basic MoE Layer

Here’s how you build one 👇🏾.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ExpertNetwork(nn.Module):
    """Single expert: a simple two-layer feedforward network."""
    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, input_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


def load_balancing_loss(router_probs: torch.Tensor, alpha: float = 0.01) -> torch.Tensor:
    """
    Simple load-balancing loss encouraging uniform expert utilization.

    router_probs: [batch, seq, num_experts] probabilities (softmax over experts).
    Minimizing sum(p_e^2) (with sum p_e = 1) pushes toward uniform usage.
    """
    num_experts = router_probs.shape[-1]
    usage = router_probs.mean(dim=(0, 1))  # [num_experts], sums to ~1
    return alpha * num_experts * torch.sum(usage * usage)


class SimpleMoE(nn.Module):
    """
    Basic Mixture of Experts with top-k routing (teaching version).

    - Dense softmax router_probs for analysis / aux losses
    - Top-k experts selected per token
    - Weighted combine of expert outputs
    """
    def __init__(self, input_dim: int, hidden_dim: int, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        if top_k < 1 or top_k > num_experts:
            raise ValueError("top_k must be in [1, num_experts].")

        self.num_experts = num_experts
        self.top_k = top_k

        self.router = nn.Linear(input_dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [ExpertNetwork(input_dim, hidden_dim) for _ in range(num_experts)]
        )

    def forward(self, x: torch.Tensor, return_aux: bool = True):
        """
        x: [batch, seq, dim]
        returns:
          y: [batch, seq, dim]
          aux: dict with router_probs and lb_loss (if return_aux)
        """
        batch_size, seq_len, dim = x.shape
        x_flat = x.reshape(-1, dim)  # [tokens, dim]
        tokens = x_flat.shape[0]

        logits = self.router(x_flat)                 # [tokens, num_experts]
        probs = F.softmax(logits, dim=-1, dtype=torch.float32).to(logits.dtype)  # [tokens, num_experts]

        topk_probs, topk_idx = torch.topk(probs, self.top_k, dim=-1)  # [tokens, k], [tokens, k]
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # normalize over chosen experts

        y_flat = torch.zeros_like(x_flat)  # [tokens, dim]

        # Loop over experts (not over experts * k) and scatter-add weighted outputs.
        for expert_id, expert in enumerate(self.experts):
            mask = (topk_idx == expert_id)  # [tokens, k] boolean
            if not mask.any():
                continue

            token_idx, k_pos = mask.nonzero(as_tuple=True)  # which tokens routed here, and which slot
            expert_in = x_flat.index_select(0, token_idx)   # [n, dim]
            expert_out = expert(expert_in)                  # [n, dim]
            w = topk_probs[token_idx, k_pos].unsqueeze(-1)  # [n, 1]

            y_flat.index_add_(0, token_idx, expert_out * w)

        y = y_flat.reshape(batch_size, seq_len, dim)

        if not return_aux:
            return y

        router_probs = probs.reshape(batch_size, seq_len, self.num_experts)
        lb = load_balancing_loss(router_probs, alpha=0.01)

        aux = {
            "router_probs": router_probs,
            "topk_idx": topk_idx.reshape(batch_size, seq_len, self.top_k),
            "topk_probs": topk_probs.reshape(batch_size, seq_len, self.top_k),
            "lb_loss": lb,
        }
        return y, aux

What this gets you: With the defaults above (8 experts, top-2 routing), the MoE layer holds 8x the parameters of a single expert, but each token only pays for 2 experts’ worth of compute plus a small router. The model still has access to all 8 experts’ knowledge.

The catch: Routers can get lazy during training. Imagine you have 8 specialists on your team, but the coordinator keeps assigning all the work to just 2 people because they’re slightly better. The other 6 specialists sit idle, your team isn’t using its full capacity, and you’re still paying the coordinator to make assignments. That’s what happens without load balancing. You end up with an 8-expert model that only uses 2 experts. This is called “expert collapse”.

Load Balancing Loss

If you want to prevent expert collapse, you’ll need to implement load balancing.

def load_balancing_loss(router_probs, alpha=0.01):
    """
    Encourages balanced expert utilization.

    Without this, routers often collapse to using 1-2 experts.
    router_probs: [batch, seq, num_experts], softmax over the last dim.
    """
    num_experts = router_probs.shape[-1]

    # Average probability assigned to each expert across batch and sequence
    expert_usage = router_probs.mean(dim=[0, 1])

    # Usage sums to 1, so sum(usage^2) is smallest (1/num_experts) when
    # usage is uniform. Scaling by num_experts makes the balanced value alpha.
    return alpha * num_experts * (expert_usage * expert_usage).sum()

During training, you’ll add this loss alongside your main task loss. The load balancing loss penalizes the router when it overuses certain experts, which forces it to distribute work more evenly.

# Use it during training
for inputs, targets in train_loader:
    optimizer.zero_grad()

    outputs, aux = model(inputs, return_aux=True)
    task_loss = criterion(outputs, targets)

    total_loss = task_loss + aux["lb_loss"]
    total_loss.backward()
    optimizer.step()

The alpha=0.01 parameter controls how strongly you penalize imbalance. Start with 0.01 and adjust if you see experts still collapsing during training.
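
To get a feel for the scale of this penalty, here’s a small pure-Python sketch of the same formula (alpha * num_experts * sum of squared usage) applied to three usage patterns. The usage vectors are made up for illustration, and `balance_penalty` is a hypothetical helper, not part of the model code.

```python
def balance_penalty(usage: list[float], alpha: float = 0.01) -> float:
    """alpha * num_experts * sum(p_e^2), for usage fractions p_e that sum to 1."""
    num_experts = len(usage)
    return alpha * num_experts * sum(p * p for p in usage)


uniform = [1 / 8] * 8                                      # every expert used equally
two_experts = [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]     # collapsed to 2 experts
one_expert = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]      # collapsed to 1 expert

for name, usage in [("uniform", uniform), ("two", two_experts), ("one", one_expert)]:
    print(name, f"{balance_penalty(usage):.2f}")
# uniform 0.01
# two 0.04
# one 0.08
```

Balanced routing pays the minimum (alpha itself), and the penalty grows as usage concentrates. If your task loss is much larger than these values, a collapsed router may barely notice the penalty, which is one reason to tune alpha.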

The Lottery Ticket Hypothesis

The Lottery Ticket Hypothesis from Frankle and Carbin (2019) provides another perspective on sparsity. Their research demonstrated that dense networks contain sparse subnetworks (“winning tickets”) that achieve full accuracy with only 10-20% of the original parameters. In some cases, networks pruned to just 3.6% of their original size performed nearly identically to the full network.

This suggests massive overparameterization in typical training, and practical opportunities for efficiency gains through magnitude-based pruning.
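
Magnitude-based pruning itself is simple: keep the largest weights, zero the rest. Below is a minimal pure-Python sketch of that one step. It is not the full lottery-ticket procedure (which prunes iteratively and resets the surviving weights to their initial values), and the function name and sample weights are hypothetical.

```python
def magnitude_prune(weights: list[float], keep_fraction: float) -> list[float]:
    """Zero out all but the largest-magnitude weights (one-shot, unstructured)."""
    if not 0.0 < keep_fraction <= 1.0:
        raise ValueError("keep_fraction must be in (0, 1].")

    num_keep = max(1, round(len(weights) * keep_fraction))

    # Rank weight indices by absolute value, largest first
    ranked = sorted(range(len(weights)), key=lambda i: abs(weights[i]), reverse=True)
    keep = set(ranked[:num_keep])

    return [w if i in keep else 0.0 for i, w in enumerate(weights)]


weights = [0.9, -0.05, 0.3, 0.01, -0.7, 0.02, 0.1, -0.2, 0.04, 0.6]
pruned = magnitude_prune(weights, keep_fraction=0.2)
print(pruned)  # only 0.9 and -0.7 survive
```

In a real model you’d apply this per layer (or globally) to weight tensors and then fine-tune, since accuracy usually dips right after pruning.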

🧠 Sparsity isn’t just an optimization trick. It’s a fundamental architectural principle that allows for efficient, selective processing.

Principle 2: Event-Driven Processing

Your visual cortex doesn’t process 30 frames per second uniformly. It activates strongly when something changes and stays relatively quiet during static scenes. This event-driven processing saves a lot of energy compared to continuous sensing.

Neurons don’t fire on clock cycles. They fire when there’s information worth transmitting. Think of the difference between a motion sensor light and a regular light bulb. Regular bulbs burn power constantly. Motion sensors only turn on when something moves. Neurons work like motion sensors; they activate when there’s a change, not on a fixed timer. CPUs and GPUs work like regular bulbs – always on, always burning power.

ML serving systems have started adopting this motion-sensor approach. The three big wins: batching requests together instead of processing one at a time, stopping early when you have enough confidence, and caching results to avoid recomputing the same thing.

Continuous Batching

GPUs are terrible at processing one request at a time. Most of the time goes to overhead: loading weights, moving data, launching kernels. The actual matrix multiplications are fast.

Batching solves this. The simplest form, dynamic batching, waits a few milliseconds (say 10) for requests to arrive, then processes them together. Continuous batching, used by LLM servers such as vLLM, goes a step further and lets new requests join a running batch as finished ones leave. Either way, your GPU processes 32 requests almost as fast as it processes 1.

Real-world results vary (depending on model size, request patterns, and hardware), but the wins are substantial. Anyscale reported up to 23x throughput improvements in their vLLM benchmarks.

Here’s an example implementation 👇🏾

import asyncio
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class Request:
    """Container for a single inference request."""
    input_tensor: torch.Tensor
    future: asyncio.Future


class ContinuousBatcher:
    """
    Continuous (dynamic) batching for inference.

    Collects requests for up to max_wait_ms or until max_batch_size,
    then runs one batched model call and resolves per-request futures.

    Notes:
    - torch.stack requires same shapes; for NLP you usually pad/pack.
    - model execution is offloaded so we don't block the event loop.
    """
    def __init__(self, model: torch.nn.Module,
                 max_batch_size: int = 32, max_wait_ms: int = 10):
        self.model = model.eval()
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms

        self._q: asyncio.Queue[Request] = asyncio.Queue()
        self._worker: Optional[asyncio.Task] = None

    def start(self) -> None:
        if self._worker is None or self._worker.done():
            self._worker = asyncio.create_task(self._batch_loop())

    async def stop(self) -> None:
        if self._worker is not None:
            self._worker.cancel()
            try:
                await self._worker
            except asyncio.CancelledError:
                pass
            self._worker = None

    async def infer(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Submit an inference request and wait for the result."""
        if self._worker is None:
            self.start()

        loop = asyncio.get_running_loop()
        fut: asyncio.Future = loop.create_future()
        await self._q.put(Request(input_tensor, fut))
        return await fut

    async def _batch_loop(self) -> None:
        timeout_s = self.max_wait_ms / 1000.0

        while True:
            reqs: List[Request] = []

            # Always wait for at least one request
            first = await self._q.get()
            reqs.append(first)

            # Then keep draining until full or timeout
            deadline = asyncio.get_running_loop().time() + timeout_s
            while len(reqs) < self.max_batch_size:
                remaining = deadline - asyncio.get_running_loop().time()
                if remaining <= 0:
                    break
                try:
                    reqs.append(await asyncio.wait_for(self._q.get(), timeout=remaining))
                except asyncio.TimeoutError:
                    break

            await self._run_and_resolve(reqs)

    async def _run_and_resolve(self, reqs: List[Request]) -> None:
        inputs = torch.stack([r.input_tensor for r in reqs])

        def _do_infer(x: torch.Tensor) -> torch.Tensor:
            with torch.inference_mode():
                return self.model(x)

        try:
            outputs = await asyncio.to_thread(_do_infer, inputs)
            for r, out in zip(reqs, outputs):
                if not r.future.done():
                    r.future.set_result(out)
        except Exception as e:
            for r in reqs:
                if not r.future.done():
                    r.future.set_exception(e)

What this gets you: Throughput scales with how well you can fill batches. If requests arrive steadily, you’ll process 20-30 requests per batch instead of 1 at a time. Real systems can see 10-20x throughput improvements over naive one-at-a-time baselines, depending on the workload.
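
Why does batching help so much? A toy cost model makes it visible: assume each model call pays a fixed overhead plus a small per-request cost. The numbers below are hypothetical, chosen only to show the shape of the effect, and the model ignores the extra latency from waiting for a batch to fill.

```python
def throughput(batch_size: int, overhead_ms: float, per_request_ms: float) -> float:
    """Requests/second when one batch costs overhead + batch_size * per_request."""
    batch_time_ms = overhead_ms + batch_size * per_request_ms
    return batch_size / batch_time_ms * 1000.0


# Hypothetical costs: 20 ms fixed overhead per call, 0.5 ms of math per request
OVERHEAD_MS = 20.0
PER_REQUEST_MS = 0.5

single = throughput(1, OVERHEAD_MS, PER_REQUEST_MS)    # roughly 49 requests/second
batched = throughput(32, OVERHEAD_MS, PER_REQUEST_MS)  # roughly 889 requests/second
speedup = batched / single

print(f"{speedup:.1f}x")  # 18.2x
```

The larger the fixed overhead relative to the per-request math, the more batching pays off.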

Early Exit

Most requests don’t need your full model. A clear, easy example might only need 4 layers to classify correctly. A hard, ambiguous example needs all 12 layers. Early-exit networks check confidence at intermediate layers. Confident? Stop computing. Uncertain? Keep going.

import torch
import torch.nn as nn
import torch.nn.functional as F


class EarlyExitTransformer(nn.Module):
  """
    Transformer with intermediate classification heads.
    
    Easy examples can exit early, saving computation.
    Hard examples use the full network depth.
  """
  def __init__(
      self,
      num_layers: int = 12,
      hidden_dim: int = 768,
      num_classes: int = 10,
      confidence_threshold: float = 0.95,
      exit_layers = (4, 8, 12),
  ):
      super().__init__()
      self.confidence_threshold = confidence_threshold
      self.exit_layers = list(exit_layers)

      # Transformer layers
      self.layers = nn.ModuleList([
          nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=12, batch_first=True)
          for _ in range(num_layers)
      ])

      # Exit points after layers 4, 8, and 12
      self.exits = nn.ModuleList([
          nn.Linear(hidden_dim, num_classes)
          for _ in self.exit_layers
      ])

  def forward(self, x: torch.Tensor):
      """
      Forward pass with early exits.

      Returns:
          logits: Predictions from the exit that fired (or the last exit head)
          exit_layer: Which layer produced the output
      """
      # x: [batch, seq, hidden]
      for layer_idx, layer in enumerate(self.layers, start=1):
          x = layer(x)

          # Check if this is an exit point
          if layer_idx in self.exit_layers:
              exit_idx = self.exit_layers.index(layer_idx)
              logits = self.exits[exit_idx](x.mean(dim=1))  # pool over seq

              # Check confidence
              probs = F.softmax(logits, dim=-1)
              max_prob, _ = probs.max(dim=-1)

              # Batch-level decision (teaching simplification): exit when the
              # mean confidence clears the threshold. Per-example exits need
              # masking so confident examples stop while others continue.
              if max_prob.mean() > self.confidence_threshold:
                  return logits, layer_idx

      # No exit fired: classify the final hidden state with the last head
      logits = self.exits[-1](x.mean(dim=1))
      return logits, len(self.layers)
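
How much compute does early exit save? It depends entirely on how many inputs leave at each exit. Here’s a back-of-the-envelope sketch with a hypothetical traffic mix. It assumes per-example exits (the class above decides per batch) and ignores the small cost of the exit heads.

```python
def expected_layers(exit_fractions: dict[int, float]) -> float:
    """Average layers run, given the fraction of inputs exiting at each layer."""
    total = sum(exit_fractions.values())
    if abs(total - 1.0) > 1e-9:
        raise ValueError("exit fractions must sum to 1")
    return sum(layer * frac for layer, frac in exit_fractions.items())


# Hypothetical mix: half the inputs are easy, a fifth need the full network
mix = {4: 0.5, 8: 0.3, 12: 0.2}

avg_layers = expected_layers(mix)
savings = 1 - avg_layers / 12

print(f"{avg_layers:.1f} layers on average")  # 6.8 layers on average
print(f"{savings:.0%} fewer layer evaluations")  # 43% fewer layer evaluations
```

Measure the real exit distribution on your own traffic before counting on numbers like these.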

vLLM and PagedAttention

LLMs have a memory problem. When generating text, they store attention keys and values (KV cache) for every token they’ve processed so far. This cache grows with every token. For a 100-token prompt, you’re storing 100 sets of keys and values. For a 1000-token conversation, you’re storing 1000 sets.
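
To put rough numbers on that growth, here’s a sketch of the usual KV cache size estimate: one key and one value vector per token, per layer, per head. The model dimensions are hypothetical (not any specific model), and the helper name is mine.

```python
def kv_cache_bytes(
    num_tokens: int,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    bytes_per_value: int = 2,  # fp16/bf16
) -> int:
    """Bytes of KV cache for one sequence."""
    # Leading 2 = one key tensor + one value tensor per layer
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value
    return per_token * num_tokens


# Hypothetical model: 32 layers, 32 KV heads, head_dim 128, fp16 cache
config = dict(num_layers=32, num_kv_heads=32, head_dim=128)

for tokens in (100, 1000):
    mib = kv_cache_bytes(tokens, **config) / 2**20
    print(tokens, "tokens ->", mib, "MiB")
# 100 tokens -> 50.0 MiB
# 1000 tokens -> 500.0 MiB
```

At half a mebibyte per token for this made-up configuration, a few dozen long conversations can claim a large share of GPU memory.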

Traditional approaches allocate a fixed block of memory for each request. If the conversation is short, you waste memory. If it’s long, you run out of space. vLLM’s PagedAttention solves this by managing memory in small chunks (like how your operating system manages RAM).

Need more space? Grab another chunk. Done with some tokens? Free that chunk for another request.

Now you can fit way more requests in GPU memory at once. Systems like vLLM can serve 2-5x more concurrent requests on the same hardware.
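
The paging idea fits in a few lines. The toy allocator below is a sketch of the bookkeeping only, not vLLM’s actual implementation. The class name, method names, and block size are assumptions for illustration.

```python
class PagedKVAllocator:
    """Toy block allocator: each request owns a list of fixed-size blocks."""

    def __init__(self, num_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))     # physical block ids
        self.block_tables: dict[str, list[int]] = {}   # request -> its blocks
        self.num_tokens: dict[str, int] = {}           # request -> tokens stored

    def append_token(self, request_id: str) -> int:
        """Reserve space for one more token; returns the physical block used."""
        table = self.block_tables.setdefault(request_id, [])
        count = self.num_tokens.get(request_id, 0)

        # Grab a new block only when the current one is full (or on first token)
        if count % self.block_size == 0:
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            table.append(self.free_blocks.pop())

        self.num_tokens[request_id] = count + 1
        return table[-1]

    def release(self, request_id: str) -> None:
        """Return all of a finished request's blocks to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(request_id, []))
        self.num_tokens.pop(request_id, None)


alloc = PagedKVAllocator(num_blocks=4, block_size=16)
for _ in range(20):
    alloc.append_token("req-a")

print(len(alloc.block_tables["req-a"]), len(alloc.free_blocks))  # 2 2
alloc.release("req-a")
print(len(alloc.free_blocks))  # 4
```

A 20-token request holds two 16-token blocks rather than a worst-case reservation, so the waste is at most one partly filled block per request.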

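Paged memory also makes it practical to share KV blocks between requests that start with the same tokens. Here’s a simplified prefix cache that shows the reuse idea (it is not PagedAttention itself):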

import hashlib
from collections import OrderedDict

class PrefixKVCache:
    """
    Cache KV tensors for prompt prefixes.
    
    When multiple requests share a prompt prefix, reuse the
    computed KV states instead of recomputing them.
    """
    def __init__(self, max_entries=2048):
        self.max_entries = max_entries
        self.cache = OrderedDict()
    
    def _compute_key(self, token_ids):
        """Hash token sequence to create cache key."""
        token_bytes = token_ids.cpu().numpy().tobytes()
        return hashlib.sha256(token_bytes).hexdigest()
    
    def get(self, token_ids):
        """Retrieve cached KV states if available."""
        key = self._compute_key(token_ids)
        
        if key in self.cache:
            self.cache.move_to_end(key)  # Mark as recently used
            return self.cache[key]
        
        return None
    
    def put(self, token_ids, kv_states):
        """Store KV states for this token sequence."""
        key = self._compute_key(token_ids)
        
        self.cache[key] = kv_states
        self.cache.move_to_end(key)
        
        # Evict oldest if cache is full
        if len(self.cache) > self.max_entries:
            self.cache.popitem(last=False)

What this gets you: When requests share common prefixes (system prompts, document headers, repeated questions), you compute once and reuse.

If 40% of requests hit the cache, those requests skip the prefill work for the shared prefix entirely. This is the most “brain-like” optimization! Recognition beats recomputation, and your infrastructure costs drop accordingly.

Principle 3: Memory-Compute Co-Location

Here’s something that surprised me: your GPU spends more energy fetching numbers than multiplying them. Specifically, reading 32 bits from DRAM (your GPU’s main memory storage) costs hundreds to 1000x more energy than adding two 32-bit numbers together!

Stanford CS149 (Fall 2024) lists ballpark numbers like ~1 pJ for an integer operation versus ~1200 pJ to read 64 bits from mobile DRAM (called LPDDR, the low-power memory used in phones and laptops). MIT’s 2024 systems material confirms: DRAM reads are drastically more expensive than arithmetic operations.
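
Plugging those ballpark figures into a quick calculation shows how lopsided the energy split can get. This is a rough sketch: the two energy constants come from the numbers above, while the operations-per-read values are hypothetical workloads picked for illustration.

```python
INT_OP_PJ = 1.0               # ~1 pJ per integer operation
LPDDR_READ_64BIT_PJ = 1200.0  # ~1200 pJ per 64-bit read from mobile DRAM


def memory_energy_share(reads_64bit: float, int_ops: float) -> float:
    """Fraction of total energy spent on DRAM reads."""
    memory_pj = reads_64bit * LPDDR_READ_64BIT_PJ
    compute_pj = int_ops * INT_OP_PJ
    return memory_pj / (memory_pj + compute_pj)


# One operation per value fetched: memory dominates almost completely
print(f"{memory_energy_share(1, 1):.1%}")    # 99.9%

# 300 operations per value fetched: memory is still 80% of the energy
print(f"{memory_energy_share(1, 300):.1%}")  # 80.0%
```

Under these figures, you need hundreds of operations per DRAM read before compute starts to matter, which is why reusing data already on chip is so valuable.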

This creates a strange bottleneck. In common ML setups, data movement can eat ~80% of your energy budget. The actual multiplication and addition operations? Those are cheap. It’s the shuffling data around that gets you.

Think about how artificial neural networks work: you have weights (numbers) stored in memory, and when you need to compute something, you fetch those weights, do the math, then store the results back. Fetch, compute, store. Each step burns energy moving data.

Your brain doesn’t work like that. The connection between two neurons (the synapse) has a certain strength. That strength IS the stored information. When a signal arrives, the synapse transmits it based on that strength. Storage and computation happen in the same place. No fetching data from memory, no moving results back.

We can’t replicate this completely (we’re stuck with von Neumann architecture for now), but we can optimize for memory locality.

Training: Activation Checkpointing

During training, your model creates intermediate values (activations) at every layer. Normally, you save all of these in memory so you can use them during backpropagation. For a 12-layer transformer, that’s a lot of memory.

Checkpointing says: what if we only save activations at certain layers (checkpoints), then recompute the ones we didn’t save when we need them? You trade compute (recomputing) for memory (not storing everything). This cuts memory usage by about 50% but adds 30-40% more compute.

from torch.utils.checkpoint import checkpoint
import torch.nn as nn

class CheckpointedTransformer(nn.Module):
  """
  Transformer with gradient checkpointing.

  Saves activations only at checkpoints, recomputes
  intermediate activations during backward pass.
  """
  def __init__(self, num_layers: int = 12, hidden_dim: int = 768, nhead: int = 12):
      super().__init__()
      self.layers = nn.ModuleList([
          nn.TransformerEncoderLayer(
              d_model=hidden_dim,
              nhead=nhead,
              batch_first=True,   
          )
          for _ in range(num_layers)
      ])

  def forward(self, x):
    """
      Forward pass with checkpointing.
      
      Reduces peak memory by ~50% at the cost of
      ~33% more compute during backpropagation.
    """
    for i, layer in enumerate(self.layers):
        if self.training and (i % 2 == 0):
            x = checkpoint(layer, x, use_reentrant=False)
        else:
            x = layer(x)
    return x

The win: You can train bigger models on the same hardware. There’s a memory-traffic angle too: activations you never store don’t have to be written out and read back during backpropagation. Remember, moving data costs hundreds to 1000x more energy than computing with it, so recomputing a value can be cheaper than fetching it.

Inference: Quantization Reduces Data Movement

Quantization is when you store your model’s numbers using fewer bits. BERT-base has 110 million parameters. At 32 bits per parameter, that’s 440 MB. At 8 bits per parameter, it’s 110 MB.

Every time you run inference, you need to load those parameters from memory. With quantization, you’re loading 110 MB instead of 440 MB. Every inference now moves 330 MB less data from memory to the compute units. Since moving data costs hundreds to 1000x more energy than doing math, that 330 MB savings per inference adds up fast at scale.

import torchao
from torchao.quantization import quantize_, Int8WeightOnlyConfig

def quantize_model_int8_weight_only(model):
  """
  Apply int8 weight-only quantization to Linear layers (in-place).
  """
  model.eval()
  quantize_(model, Int8WeightOnlyConfig())
  return model

The win: Every inference moves 4× less data. Run a million inferences per day? You’ve eliminated 330 TB of data movement. Since moving data costs a lot more energy than computing, these savings add up fast. Plus, smaller models mean you fit more concurrent requests in the same GPU memory!
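
Here’s the arithmetic behind those numbers as a quick sketch. It assumes, as the text does, that every inference reads each weight from memory once (no reuse across a batch). The helper name is hypothetical.

```python
def weight_bytes(num_params: int, bits_per_param: int) -> int:
    """Bytes needed to store the weights at a given precision."""
    return num_params * bits_per_param // 8


BERT_BASE_PARAMS = 110_000_000

fp32_bytes = weight_bytes(BERT_BASE_PARAMS, 32)
int8_bytes = weight_bytes(BERT_BASE_PARAMS, 8)
saved_per_inference = fp32_bytes - int8_bytes
saved_per_day = saved_per_inference * 1_000_000  # one million inferences/day

print(fp32_bytes / 1e6, "MB at fp32")          # 440.0 MB at fp32
print(int8_bytes / 1e6, "MB at int8")          # 110.0 MB at int8
print(saved_per_day / 1e12, "TB saved daily")  # 330.0 TB saved daily
```

Batching several requests per weight load shrinks the absolute savings, but the 4x ratio between fp32 and int8 stays the same.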

The fastest way to waste time in ML systems is to optimize the wrong thing. Many teams spend weeks making their matrix multiplications 20% faster when data movement is eating ~80% of their energy budget. Optimize for memory locality first.

Principle 4: Continual Learning

Your brain has a neat trick: it learns during sleep. While you’re unconscious, your brain replays the day’s experiences and gradually integrates them into long-term memory without erasing what you already know. You learn new skills without forgetting old ones.

Most AI systems can’t do this out of the box. When you fine-tune a model on new data, it often forgets what it learned before. Train a language model to be good at medical questions? It might get worse at coding questions. This phenomenon is called catastrophic forgetting (catastrophic because it’s sudden and severe, not gradual).

For example: you train a model to classify images of cats and dogs. It gets 95% accuracy. Great! Then you train it on birds and fish. Now when you test it on cats and dogs again, accuracy drops to 60%. The model didn’t integrate the new knowledge. It overwrote the old knowledge.

The learning cost stays high because instead of updating what the model knows (like your brain does), we often retrain from scratch or accept the performance hit. This gets expensive fast when models cost millions of dollars to train.

The landscape is changing. The 2024 EU AI Act requires providers of general-purpose AI models to document their energy consumption. So when you’re deciding between ‘retrain from scratch’ (simple but expensive) and ‘fine-tune efficiently’ (complex but cheaper), you’ll need to justify the energy cost either way. Your CFO, your regulators, or both will want to see the numbers. Energy tracking isn’t optional anymore.

Elastic Weight Consolidation

EWC prevents catastrophic forgetting by figuring out which model parameters (the numbers that store what your model knows) were critical for old tasks, then makes those parameters harder to change when training on new tasks.

After training on Task A, EWC identifies which parameters were most important for that task. Then when you train on Task B, it adds a penalty to your loss function that says “don’t change those important parameters too much”. Parameters that mattered for Task A get protected.
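
In symbols, the Task B objective can be sketched like this (the notation is assumed here for illustration):

```latex
\mathcal{L}(\theta) \;=\; \mathcal{L}_{B}(\theta) \;+\; \lambda \sum_{i} F_{i}\,\bigl(\theta_{i} - \theta^{*}_{A,i}\bigr)^{2}
```

Here \mathcal{L}_{B} is the ordinary loss on Task B, \theta^{*}_{A,i} is the value of parameter i after training on Task A, F_{i} is its estimated importance (the diagonal Fisher information), and \lambda sets how strongly old knowledge is protected. The original EWC paper writes the penalty with \lambda/2; the constant factor can be folded into \lambda.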

import torch
import torch.nn as nn

class EWC:
  """
  Elastic Weight Consolidation (diagonal Fisher approximation).

  Tutorial version: estimates importance via accumulated squared grads.
  """
  def __init__(self, model: nn.Module, lambda_ewc: float = 100.0):
      self.model = model
      self.lambda_ewc = lambda_ewc
      self.fisher_info = {}
      self.optimal_params = {}

  @torch.no_grad()
  def _snapshot_params(self):
      self.optimal_params = {
          name: p.detach().clone()
          for name, p in self.model.named_parameters()
          if p.requires_grad
      }

  def compute_fisher_information(self, dataloader, criterion: nn.Module):
      """
      Estimate parameter importance using Fisher information.
      
      Run this on old task data before training on new task.
      """
      device = next(self.model.parameters()).device

      self.fisher_info = {
          name: torch.zeros_like(p, device=device)
          for name, p in self.model.named_parameters()
          if p.requires_grad
      }

      # Typically you want deterministic activations here (no dropout noise)
      self.model.eval()

      for inputs, targets in dataloader:
          inputs = inputs.to(device)
          targets = targets.to(device)

          self.model.zero_grad(set_to_none=True)
          outputs = self.model(inputs)
          loss = criterion(outputs, targets)
          loss.backward()

          # Accumulate squared gradients (Fisher information)
          for name, p in self.model.named_parameters():
              if p.grad is not None and name in self.fisher_info:
                  self.fisher_info[name] += p.grad.detach().pow(2)

      denom = max(1, len(dataloader))
      for name in self.fisher_info:
          self.fisher_info[name] /= denom

      # Save current parameters
      self._snapshot_params()

  def penalty(self) -> torch.Tensor:
      """
      Compute EWC penalty for current parameters.

      Add this to your loss: total_loss = task_loss + ewc.penalty()
      """
      device = next(self.model.parameters()).device
      if not self.fisher_info:
          return torch.tensor(0.0, device=device)

      loss = torch.tensor(0.0, device=device)
      for name, p in self.model.named_parameters():
          if name in self.fisher_info:
              fisher = self.fisher_info[name]
              optimal = self.optimal_params[name]
              # Out-of-place add keeps the autograd graph clean
              loss = loss + (fisher * (p - optimal).pow(2)).sum()

      return self.lambda_ewc * loss

Experience Replay: Learning Like You Sleep

EWC tries to protect important weights while learning new tasks. Experience Replay takes a different approach inspired by what your brain actually does during sleep: it mixes old experiences with new ones.

During slow-wave sleep, your hippocampus replays experiences from the day alongside older memories, gradually integrating everything without erasing what came before. Experience Replay does the same thing: when training on new data, you mix in examples from previous tasks.

The idea is simple: if you’re constantly reminding the model about old tasks while teaching it new ones, it can’t forget them.

import random
from collections import deque

import torch
import torch.nn as nn


class ExperienceReplayBuffer:
  """
  Store examples from old tasks and mix them with new task data.
    
    This mimics how your brain replays old memories during sleep
    to prevent forgetting them while learning new things.
  """
  def __init__(self, capacity: int = 10000):
      self.buffer = deque(maxlen=capacity)

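  def add(self, inputs: torch.Tensor, targets: torch.Tensor):
      # Store individual examples (not whole batches) on CPU, so sampled
      # items can be stacked into a batch and GPU memory doesn't grow
      for x, y in zip(inputs.detach().cpu(), targets.detach().cpu()):
          self.buffer.append((x.clone(), y.clone()))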

  def sample(self, batch_size: int):
      if len(self.buffer) <= batch_size:
          return list(self.buffer)
      return random.sample(self.buffer, batch_size)

  def __len__(self):
      return len(self.buffer)


# Usage: Training on multiple tasks sequentially
model = YourModel()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

replay_buffer = ExperienceReplayBuffer(capacity=10000)

# Train on Task A (e.g., classifying cats and dogs)
for epoch in range(num_epochs):
  model.train()
  for inputs, targets in task_a_dataloader:
      inputs = inputs.to(device)
      targets = targets.to(device)

      optimizer.zero_grad(set_to_none=True)
      outputs = model(inputs)
      loss = criterion(outputs, targets)
      loss.backward()
      optimizer.step()

      replay_buffer.add(inputs, targets)

# Train on Task B (e.g., classifying birds and fish)
# Mix in examples from Task A to prevent forgetting
for epoch in range(num_epochs):
  model.train()
  for inputs_new, targets_new in task_b_dataloader:
      inputs_new = inputs_new.to(device)
      targets_new = targets_new.to(device)

      # Sample old examples from Task A
      if len(replay_buffer) > 0:
          old_batch = replay_buffer.sample(batch_size // 2)
          inputs_old = torch.stack([x for x, _ in old_batch]).to(device)
          targets_old = torch.stack([y for _, y in old_batch]).to(device)
          
          # Mix old and new examples
          inputs_mixed = torch.cat([inputs_old, inputs_new], dim=0)
          targets_mixed = torch.cat([targets_old, targets_new], dim=0)
      else:
          inputs_mixed = inputs_new
          targets_mixed = targets_new

      # Train on mixed batch
      optimizer.zero_grad(set_to_none=True)
      outputs = model(inputs_mixed)
      loss = criterion(outputs, targets_mixed)
      loss.backward()
      optimizer.step()

      # Store new examples for future tasks
      replay_buffer.add(inputs_new, targets_new)

The model largely retains its Task A performance while learning Task B. Instead of accuracy dropping from, say, 95% to 60% (catastrophic forgetting), it might only drop to 90-92%. The exact numbers depend on the tasks, the buffer size, and the mixing ratio.

The trade-off: You need to store old training examples (memory cost) and spend extra compute replaying them. But this is usually far cheaper than full retraining.
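One way to keep that memory cost fixed is reservoir sampling. A first-in-first-out buffer eventually pushes out the oldest task as new data arrives, while a reservoir keeps a uniform random sample of everything seen so far. The sketch below is a swapped-in variant, not the buffer from the example above. `ReservoirBuffer` and its method names are made up for illustration, and items can be any Python object (for example, per-example tensor pairs already moved to CPU).

```python
import random


class ReservoirBuffer:
    """Fixed-size buffer holding a uniform random sample of all items seen.

    Hypothetical helper for illustration only.
    """

    def __init__(self, capacity: int, seed=None):
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        self.capacity = capacity
        self.items = []
        self.num_seen = 0
        self._rng = random.Random(seed)

    def add(self, item) -> None:
        self.num_seen += 1
        if len(self.items) < self.capacity:
            # Buffer not full yet: always keep the item
            self.items.append(item)
            return
        # Reservoir sampling (Algorithm R): the new item is kept with
        # probability capacity / num_seen and overwrites a random slot
        slot = self._rng.randrange(self.num_seen)
        if slot < self.capacity:
            self.items[slot] = item

    def sample(self, k: int) -> list:
        # Never ask for more items than the buffer holds
        k = min(k, len(self.items))
        return self._rng.sample(self.items, k)

    def __len__(self) -> int:
        return len(self.items)


# Usage: stream 10,000 examples through a 100-slot buffer
buffer = ReservoirBuffer(capacity=100, seed=0)
for example_id in range(10_000):
    buffer.add(example_id)
print(len(buffer), buffer.num_seen)
```

Because every example has the same chance of being retained, early tasks stay represented no matter how much new data streams through.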

Continual learning systems in robotics use experience replay to learn new skills without forgetting old ones. A robot learning to grasp new objects replays examples of grasping old objects to maintain that capability.

The win: instead of retraining your entire model from scratch every time you need to update it (which costs thousands to millions of dollars and weeks of compute time), you make small, incremental updates.

Think of it like performing maintenance on your car. You regularly get oil changes and tune-ups to avoid or delay major repair costs. Principle 4 is the preventative maintenance approach.

Principle 5: Precision as a Lever

Your brain operates on noisy signals. Individual neurons misfire, signals get lost, and the whole system is fundamentally unreliable at the microscopic level. Yet somehow, you can recognize your friend’s face, remember your phone number, and read this sentence without errors. The brain compensates for noise through redundancy and clever encoding.

Meanwhile, your GPU does math at far higher precision. By default, every value is a 32-bit floating-point number (FP32), good for about seven significant decimal digits. Yet ML models trained at this precision still make mistakes, hallucinate facts, and fail in unexpected ways.

Here’s the counterintuitive part: you don’t need that much precision. Most ML workloads work fine with much less. Instead of 32 bits per number, you can often use 16 bits (FP16 or BF16) or even 8 bits (FP8). Your model gets slightly less precise, but the predictions barely change.

Why does this matter? Two reasons:

First, smaller numbers mean less data to move (remember, data movement costs hundreds to 1000x more energy than math). Using 16-bit numbers instead of 32-bit numbers cuts your data movement in half.
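To get a feel for what halving the bits costs, here is a small sketch using only Python’s standard `struct` module, whose `e` format is IEEE 754 half precision (FP16). The helper name `round_trip` and the sample weight are illustrative. BF16 is not covered because `struct` has no format code for it.

```python
import struct


def round_trip(value: float, fmt: str) -> float:
    """Store a Python float in the given IEEE 754 format and read it back."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


weight = 0.123456789  # arbitrary example value

for name, fmt in [("FP32", "<f"), ("FP16", "<e")]:
    stored = round_trip(weight, fmt)
    rel_error = abs(stored - weight) / abs(weight)
    size = struct.calcsize(fmt)
    print(f"{name}: {size} bytes, stored={stored!r}, relative error={rel_error:.1e}")
```

The FP16 copy takes half the bytes, and for a value in this range the relative rounding error stays below 0.1%. Whether that error matters depends on the model and the operation, which is why mixed precision keeps a few sensitive steps in FP32.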

Second, modern hardware is designed for this. NVIDIA’s H100 GPU includes a “Transformer Engine” with FP8 support specifically because training transformers with 8-bit precision is faster and uses less energy than 32-bit precision. The hardware makers know precision is a lever, not a requirement.

import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler

def train_with_mixed_precision(model, train_loader, optimizer, criterion, dtype=torch.bfloat16):
    """
    Mixed precision training.
    - BF16: usually no GradScaler needed
    - FP16: GradScaler recommended
    """
    device = next(model.parameters()).device
    use_scaler = (dtype == torch.float16)

    scaler = GradScaler("cuda", enabled=use_scaler)
    model.train()

    for inputs, targets in train_loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type="cuda", dtype=dtype):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        if use_scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

The win: Mixed precision is now standard practice for training large models. You run most operations in 16-bit precision (BF16 or FP16) and keep a few critical operations in 32-bit precision for numerical stability. On GPUs with hardware support for 16-bit math, training time often drops by 30-50% and activation memory is roughly halved. The model quality? Barely changes.

This is why every major framework (PyTorch, TensorFlow, JAX) has built-in mixed precision support. It’s the first thing people enable when training transformers.

Quantization: Going Even Lower

Mixed precision helps during training. Quantization helps during inference (serving your model to users).

Here’s the idea: your trained model has millions or billions of parameters stored as 32-bit numbers. Each parameter takes 4 bytes of memory. For a 7 billion parameter model, that’s 28 GB just to store the weights. Every forward pass has to read all 28 GB from memory.

Quantization says: “What if we stored those parameters as 8-bit integers instead?” Now each parameter takes 1 byte. Your 7B parameter model shrinks from 28 GB to 7 GB. You’re moving 21 GB less data per forward pass. Since data movement costs hundreds to 1000x more energy than math, this saves enormous amounts of power at scale.
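What does “storing parameters as 8-bit integers” look like in practice? Here is a minimal pure-Python sketch of symmetric per-tensor quantization, one of several common schemes. The function names and sample weights are illustrative. Production libraries typically use per-channel or per-group scales and optimized kernels instead.

```python
def quantize_int8(weights):
    """Symmetric per-tensor quantization: map floats to integers in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    # One shared scale for the whole tensor; fall back to 1.0 if all weights are zero
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(w / scale))) for w in weights]
    return quantized, scale


def dequantize_int8(quantized, scale):
    """Recover approximate float values from the stored integers."""
    return [q * scale for q in quantized]


weights = [0.42, -1.27, 0.003, 0.9, -0.55]  # made-up example values
q, scale = quantize_int8(weights)
restored = dequantize_int8(q, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, scale, max_error)
```

Each weight is now one small integer plus a single shared scale. The rounding error per weight is at most half the scale, which is the “slightly less precise” part of the trade.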

The surprising part: model quality barely drops. Recent research (like Kurtic et al.’s “‘Give Me BF16 or Give Me Death’?”) shows you can often quantize models down to 8 bits or even 4 bits with minimal accuracy loss. A 7B parameter model quantized to 4 bits needs about 3.5 GB for its weights and runs on consumer hardware.
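The storage numbers above are simple arithmetic, and it’s worth running them for your own model. A small sketch, assuming 1 GB = 10^9 bytes and ignoring the small overhead of quantization scales and metadata (`weight_storage_gb` is an illustrative helper name):

```python
def weight_storage_gb(num_params: float, bits_per_param: int) -> float:
    """Weight storage in gigabytes (1 GB = 1e9 bytes), ignoring scales and metadata."""
    return num_params * bits_per_param / 8 / 1e9


for bits in (32, 16, 8, 4):
    size_gb = weight_storage_gb(7e9, bits)
    print(f"7B parameters at {bits:>2} bits: {size_gb:.1f} GB")
```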

def quantize_for_inference(model):
    """
    Apply post-training quantization for deployment.
    
    Dynamic quantization stores nn.Linear weights as INT8 and quantizes
    activations on the fly at inference time. It targets CPU inference.
    """
    model.eval()
    
    quantized_model = torch.ao.quantization.quantize_dynamic(
        model,
        {nn.Linear},
        dtype=torch.qint8
    )
    
    return quantized_model

The win: You’re moving 4x less data from memory on every inference. For BERT-base, that’s saving 330 MB of data movement per request. Remember, moving data costs hundreds to 1000x more energy than doing math with it. This compounds fast. Run a million inferences per day? You just eliminated 330 TB of data shuffling. Plus, smaller models mean you can fit more concurrent requests in the same GPU memory.

Here’s what matters: precision is a lever, not a requirement. Your brain operates on noisy signals and still works. Your GPU does 32-bit math and still makes mistakes. You don’t need that precision for most ML tasks. Use just enough bits to get the job done (usually 8-16 bits instead of 32), and spend the savings on moving less data around. That’s where the real energy and cost reductions come from.

Hardware Evolution: Where This Is Going

ML is leaving the datacenter. Not completely (cloud inference isn’t going anywhere), but the next wave of AI systems needs to run on robots, IoT devices, and your laptop without draining the battery in 20 minutes.

This creates a problem: you can’t put a 700-watt GPU in a robot vacuum or a smartphone. You need hardware that’s fundamentally more efficient. And guess what engineers are looking at for inspiration? The brain’s architecture.

Neuromorphic Chips: Event-Driven Hardware

Remember how we talked about neurons firing when there’s information worth transmitting? Neuromorphic chips implement this directly in hardware.

Intel’s Loihi 2 chip doesn’t run on a clock like a normal CPU. Instead, individual computing units (artificial neurons) only activate when they receive input that crosses a threshold. No input? No power consumed. This is the asynchronous, event-driven approach built into silicon.
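To make the threshold idea concrete, here is a sketch of a leaky integrate-and-fire neuron, a textbook spiking model. This is not Loihi 2’s actual neuron model or programming interface. The function name, threshold, and leak values are assumptions chosen for illustration.

```python
def run_lif_neuron(inputs, threshold=1.0, leak=0.9):
    """Simulate one leaky integrate-and-fire neuron over discrete time steps.

    Returns the list of time steps at which the neuron spiked.
    """
    potential = 0.0
    spike_times = []
    for t, current in enumerate(inputs):
        # Membrane potential decays each step, then integrates the new input
        potential = leak * potential + current
        if potential >= threshold:
            spike_times.append(t)
            potential = 0.0  # reset after firing
    return spike_times


# Two weak inputs in a row add up to a spike; one strong input spikes alone
print(run_lif_neuron([0, 0, 0.6, 0.6, 0, 0, 0, 1.2, 0]))
# A silent input produces no spikes, so nothing is sent downstream
print(run_lif_neuron([0] * 10))
```

A software simulation like this still loops over every time step. The hardware advantage is that silent neurons trigger no work at all, so energy use tracks the number of spikes rather than the number of neurons.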

Intel recently scaled this up to Hala Point, which it describes as the world’s largest neuromorphic system. IBM’s NorthPole chip takes a different angle: it puts memory and compute in the same physical location (remember Principle 3 about data movement costs?). For certain vision tasks, IBM reports that NorthPole achieves 25x better energy efficiency than comparable GPUs.

Mobile Neural Processors: ML in Your Pocket

Your smartphone probably already has a dedicated AI chip. Apple calls it the Neural Engine. Qualcomm calls it the Hexagon NPU (Neural Processing Unit, which is just a chip designed specifically for running neural networks).

The latest Snapdragon X2 Elite Extreme has a Hexagon NPU 6 that delivers 80 TOPS. TOPS means Trillions of Operations Per Second (yes, trillions with a T). For context, that’s competitive with Apple’s M4 chip and fast enough to run a 7-billion parameter language model locally on your laptop.

This matters because running models on-device means:

  • No cloud API costs
  • No network latency
  • Your data stays on your device (better privacy)
  • It works offline

The brain’s 20-watt power budget is becoming a design target for consumer hardware.

The Business Reality

This isn’t just academic anymore, and it boils down to three concepts: regulation, costs, and hardware limitations.

Regulation: The EU AI Act requires companies building general-purpose AI models to document their energy consumption. It’s not optional. If you’re deploying models in Europe, you need to measure and report power usage.

Cost: CFOs care about inference costs because they scale linearly with usage. Serve 1 million requests? Pay for 1 million inferences. Serve 100 million requests? Pay 100x more. Energy efficiency directly impacts your unit economics.

Hardware limits: Moore’s Law (the observation that chip transistor counts double every ~2 years) is slowing down. We’re hitting physical limits of how small we can make transistors. You can’t just wait for next year’s GPU to be 2x faster at the same power.

This means algorithmic efficiency isn’t optional anymore. It’s the only lever left to pull.

Your Monday Morning Action Plan

Here’s where to start if you want to make your ML systems more efficient. You don’t need to do all of this. Pick one thing that fits your current project.

1. Measure Your Baseline

Right now, do you know how much energy your model uses per inference? What about throughput or latency? If the answer is “not really,” start there. You can’t optimize what you haven’t measured.

Here’s a simple GPU power tracker:

import time

import torch

try:
  import pynvml
  HAS_NVML = True
except ImportError:
  HAS_NVML = False


class EfficiencyTracker:
  """
  Track GPU power consumption and throughput.
    
  Provides simple metrics to understand where your
  model stands before and after optimizations.
  """
  def __init__(self, device_id: int = 0):
      self.device_id = device_id
      self.handle = None
      if HAS_NVML:
          pynvml.nvmlInit()
          self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)

  def __del__(self):
      if HAS_NVML:
          try:
              pynvml.nvmlShutdown()
          except Exception:
              pass

  def get_power_watts(self) -> float:
      """Get current GPU power draw in watts."""
      if self.handle is None:
          return 0.0
      return pynvml.nvmlDeviceGetPowerUsage(self.handle) / 1000.0

  def measure_inference(self, model, dataloader, num_iters: int = 50, warmup_iters: int = 5):
      """Measure inference efficiency metrics."""
      model.eval()
      device = next(model.parameters()).device

      # Warmup (important for kernel autotune / caching)
      with torch.inference_mode():
          it = iter(dataloader)
          for _ in range(warmup_iters):
              inputs, _ = next(it)
              inputs = inputs.to(device, non_blocking=True)
              _ = model(inputs)

      torch.cuda.synchronize()
      start = time.perf_counter()

      total_energy_j = 0.0
      total_samples = 0
      last_t = start
      last_p = self.get_power_watts()  # initial sample

      with torch.inference_mode():
          for i, (inputs, _) in enumerate(dataloader):
              if i >= num_iters:
                  break

              inputs = inputs.to(device, non_blocking=True)

              _ = model(inputs)
              torch.cuda.synchronize()

              now = time.perf_counter()
              p = self.get_power_watts()

              # Trapezoidal integration of power over time
              dt = now - last_t
              total_energy_j += 0.5 * (last_p + p) * dt

              last_t = now
              last_p = p
              total_samples += inputs.size(0)

      end = time.perf_counter()
      total_time_s = end - start

      return {
          "throughput_samples_per_sec": total_samples / total_time_s if total_time_s > 0 else 0.0,
          "avg_time_per_sample_ms": (total_time_s / total_samples) * 1000 if total_samples > 0 else 0.0,
          "energy_per_sample_joules": total_energy_j / total_samples if total_samples > 0 else 0.0,
          "total_energy_joules": total_energy_j,
      }


# Usage
tracker = EfficiencyTracker(device_id=0)
metrics = tracker.measure_inference(model, test_loader)

print(f"Throughput: {metrics['throughput_samples_per_sec']:.2f} samples/sec")
print(f"Avg time/sample: {metrics['avg_time_per_sample_ms']:.2f} ms")
print(f"Energy per sample: {metrics['energy_per_sample_joules']:.4f} J")

2. Calculate Costs at Scale

You might be serving 1,000 requests per day. That feels manageable. But what happens when you’re successful and that becomes 100,000 requests? Or a million?

A model that costs $0.01 per inference seems cheap until you multiply it by a million. That’s $10,000 per day, or about $3.65 million per year. Suddenly, shaving that cost to $0.005 per inference (through quantization or batching) saves you roughly $1.8 million annually. Your CFO will love you.

Grab a calculator and run the numbers for your actual usage patterns. It’s eye-opening.
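If you’d rather script it than use a calculator, a few lines are enough. This sketch assumes cost scales linearly with request volume and that traffic is flat across 365 days. The function name and the traffic levels are placeholders for your own numbers.

```python
def annual_inference_cost(cost_per_inference: float, requests_per_day: int,
                          days_per_year: int = 365) -> float:
    """Yearly spend, assuming cost scales linearly with request volume."""
    return cost_per_inference * requests_per_day * days_per_year


for requests_per_day in (1_000, 100_000, 1_000_000):
    baseline = annual_inference_cost(0.01, requests_per_day)
    optimized = annual_inference_cost(0.005, requests_per_day)
    print(f"{requests_per_day:>9,} req/day: ${baseline:>12,.0f} per year, "
          f"saves ${baseline - optimized:>12,.0f} at half the unit cost")
```

Swap in your real per-inference cost (GPU-hour price divided by measured throughput) to see where your own break-even points are.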

3. Pick One Optimization (Not All of Them)

Here are three that give you the most bang for your effort:

Quantization – Takes about an hour to implement, cuts your model size by 4× (and memory bandwidth by the same amount). If you’re serving a model right now, start here.

Mixed precision training – Takes about an hour: it is a few extra lines in your training loop. You’ll often see 30-50% faster training and lower GPU memory use. If you’re training models, start here.

Basic batching – Takes maybe a day to implement properly. Gives you 10-20× better throughput if you’re serving requests online. If you’re running inference for multiple users, start here.

Pick whichever one matches your current pain point. Don’t try to do all three this week.
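If batching is your pick, the core of it is a loop that waits briefly to group requests before running the model. Here is a minimal sketch using only the standard library. The function name `collect_batch`, the batch size, and the wait time are all assumptions to tune, and production serving frameworks use more sophisticated schedulers (such as continuous batching).

```python
import queue
import time


def collect_batch(request_queue, max_batch_size=8, max_wait_s=0.01):
    """Block for the first request, then gather more until the batch is full
    or max_wait_s has elapsed since the first request arrived."""
    batch = [request_queue.get()]  # wait for at least one request
    deadline = time.monotonic() + max_wait_s
    while len(batch) < max_batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            batch.append(request_queue.get(timeout=remaining))
        except queue.Empty:
            break
    return batch


# Usage: five queued requests are served as one batch of four plus one leftover
requests = queue.Queue()
for request_id in range(5):
    requests.put(request_id)
print(collect_batch(requests, max_batch_size=4, max_wait_s=0.5))
print(collect_batch(requests, max_batch_size=4, max_wait_s=0.05))
```

The `max_wait_s` setting is the latency you are willing to add in exchange for throughput, so measure both before and after.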

4. Track Energy Like You Track Loss

Add energy consumption to your training logs. Put it right next to your loss curves and accuracy numbers. Make it visible.

Why? Because what gets measured gets optimized. If you’re only looking at accuracy, you’ll only optimize for accuracy. If you’re tracking energy alongside accuracy, you’ll start making different trade-off decisions.
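A lightweight way to do this is to sample power periodically and integrate it between log steps. The sketch below assumes you already collect `(timestamp_seconds, watts)` samples, for example from a power reading like the tracker above. The function names and log format are illustrative.

```python
def energy_joules(power_samples):
    """Integrate (timestamp_seconds, watts) samples with the trapezoidal rule."""
    total = 0.0
    for (t0, p0), (t1, p1) in zip(power_samples, power_samples[1:]):
        total += 0.5 * (p0 + p1) * (t1 - t0)
    return total


def format_log_line(step, loss, power_samples):
    """Build one log line with loss and the energy used since the last log."""
    joules = energy_joules(power_samples)
    return f"step={step} loss={loss:.4f} energy_j={joules:.1f}"


# Usage: a GPU drawing a steady 100 W for 2 seconds between log steps
samples = [(0.0, 100.0), (1.0, 100.0), (2.0, 100.0)]
print(format_log_line(step=10, loss=0.5, power_samples=samples))
```

Once energy sits in the same log line as loss, it shows up in the same dashboards and gets the same scrutiny.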

⚠️ A word of caution: These optimizations compound in unexpected ways. Quantization + batching + mixed precision can interact in ways that aren’t always obvious. Don’t try to implement everything at once. Measure your baseline, pick one optimization, measure again, then move to the next.

If something breaks or your accuracy tanks, you’ll know exactly which change caused it. Trust me on this one 😅.

Takeaway

Evolution spent half a billion years optimizing the brain for efficiency. It arrived at sparse activation (only fire what’s needed), event-driven processing (compute when something happens), co-located memory and compute (no data buses), continuous learning (update without retraining), and noisy-but-good-enough precision (use just enough bits).

These aren’t just biological curiosities you’d learn about in a neuroscience class. They’re practical engineering patterns that show up whenever you’re building systems under tight constraints.

Limited power? You need sparsity. Real-time requirements? You need event-driven architecture. Expensive data movement? You need locality. Frequent updates? You need continual learning. Tight energy budget? You need lower precision.

Modern ML is stumbling back towards these same solutions. MoE models implement sparse activation. Continuous batching and caching implement event-driven processing. Quantization and activation checkpointing reduce data movement. EWC and experience replay enable continual learning. Mixed precision and quantization embrace “good enough” over “perfect”.

The convergence isn’t accidental. When you optimize for the same constraints (power, memory, latency, update frequency), you end up with similar architectures. Your brain figured this out 500 million years ago through trial and error. Now you can apply it to your ML systems.

References

Anyscale. “Achieve 23x LLM Inference Throughput & Reduce p50 Latency.” Anyscale Blog, 2023. https://www.anyscale.com/blog/continuous-batching-llm-inference.

Argonne National Laboratory. “Enabling Physical AI Through Biological Principles.” arXiv, 2025. https://arxiv.org/pdf/2509.24521.

DeepSeek AI. “DeepSeek-V3 Technical Report.” arXiv, December 2024. https://arxiv.org/html/2412.19437v1.

Extreme Networks. “Artificial Intelligence, Real Consequences: Confronting AI’s Growing Energy Appetite.” Extreme Networks Blog, 2024. https://www.extremenetworks.com/resources/blogs/confronting-ai-growing-energy-appetite-part-1.

Frankle, Jonathan, and Michael Carbin. “The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks.” arXiv, March 2019. https://arxiv.org/abs/1803.03635.

IBM Research. “IBM Research’s New NorthPole AI Chip.” IBM Research Blog, 2023. https://research.ibm.com/blog/northpole-ibm-ai-chip.

Intel Corporation. “Intel Builds World’s Largest Neuromorphic System.” Press release, April 2024. https://www.intc.com/news-events/press-releases/detail/1691/intel-builds-worlds-largest-neuromorphic-system-to.

Jiang, Albert Q., et al. “Mixtral of Experts.” arXiv, January 2024. https://arxiv.org/abs/2401.04088.

Kirkpatrick, James, et al. “Overcoming Catastrophic Forgetting in Neural Networks.” Proceedings of the National Academy of Sciences 114, no. 13 (2017): 3521-3526.

Levy, William B., and Robert A. Calvert. “Communication Consumes 35 Times More Energy than Computation in the Human Cortex.” Proceedings of the National Academy of Sciences 118, no. 18 (2021): e2008173118.

Micikevicius, Paulius, et al. “Mixed Precision Training.” arXiv, February 2018. https://arxiv.org/abs/1710.03740.

NIST. “Brain-Inspired Computing Can Help Us Create Faster, More Energy-Efficient Devices.” NIST Taking Measure Blog, 2023. https://www.nist.gov/blogs/taking-measure/brain-inspired-computing-can-help-us-create-faster-more-energy-efficient.

Rusu, Andrei A., et al. “Progressive Neural Networks.” arXiv, June 2016. https://arxiv.org/abs/1606.04671.

Sojasingarayar, Abonia. “vLLM and PagedAttention: A Comprehensive Overview.” Medium, 2024. https://medium.com/@abonia/vllm-and-pagedattention-a-comprehensive-overview-20046d8d0c61.

Kurtic, Eldar, et al. “‘Give Me BF16 or Give Me Death’? Accuracy-Performance Trade-Offs in LLM Quantization.” arXiv, November 2024. https://arxiv.org/abs/2411.02355.