
L14 - Parameter-Efficient Fine-Tuning (LoRA) [DRAFT]

Fine-tune 7B models on a single GPU with Low-Rank Adaptation


In L10, we learned about fine-tuning through SFT and RLHF. But full fine-tuning of a 7B model needs on the order of 56-112 GB of VRAM just for weights, gradients, and optimizer states. LoRA (Low-Rank Adaptation) makes fine-tuning possible on consumer hardware by updating only a tiny fraction of parameters.

By the end of this post, you'll understand:

- Why full fine-tuning of a 7B model doesn't fit on consumer GPUs
- The low-rank intuition behind LoRA and the $W + BA$ decomposition
- How to implement and apply LoRA layers from scratch in PyTorch
- How QLoRA combines 4-bit quantization with LoRA adapters
- How to save, load, and swap lightweight adapter files
- When to prefer LoRA over full fine-tuning, and which hyperparameters matter

Part 1: The Memory Problem

Full Fine-Tuning Memory Requirements

For a 7B parameter model:

| Component | Memory (FP32) | Memory (FP16) |
|---|---|---|
| Model weights | 28 GB | 14 GB |
| Gradients | 28 GB | 14 GB |
| Optimizer states (AdamW) | 56 GB | 28 GB |
| **Total** | **112 GB** | **56 GB** |

Even in FP16, this is far beyond any consumer GPU and, once activations and batch size are accounted for, typically calls for 80 GB A100-class hardware.
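
These numbers are just arithmetic over the parameter count. A minimal sketch that reproduces the table, assuming AdamW's two extra states are stored in the same precision as the weights and ignoring activations (`finetune_memory_gb` is an illustrative helper, not a library function):

def finetune_memory_gb(n_params: float, bytes_per_value: int) -> float:
    """Rough estimate: weights + gradients + two AdamW moment buffers."""
    weights = n_params * bytes_per_value
    gradients = n_params * bytes_per_value
    optimizer_states = 2 * n_params * bytes_per_value  # AdamW: m and v
    return (weights + gradients + optimizer_states) / 1e9

n = 7e9  # 7B parameters
print(f"FP32: {finetune_memory_gb(n, 4):.0f} GB")  # ~112 GB
print(f"FP16: {finetune_memory_gb(n, 2):.0f} GB")  # ~56 GB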

The key insight: the weight updates learned during fine-tuning tend to have low intrinsic rank, so we don't need to learn a full-rank update to every weight matrix.


Part 2: Low-Rank Matrix Intuition

Visualizing Matrix Rank

A full-rank matrix has linearly independent rows and columns:

$$W = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 10 \end{bmatrix} \qquad \text{Rank} = 3$$

A low-rank matrix (rank 1) can be written as an outer product of two vectors:

$$W = \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} \begin{bmatrix} 1 & 2 & 3 \end{bmatrix} = \begin{bmatrix} 1 & 2 & 3 \\ 2 & 4 & 6 \\ 3 & 6 & 9 \end{bmatrix} \qquad \text{Rank} = 1$$

Key observation:

For a $d \times d$ matrix with rank $r$:

- Storing the full matrix takes $d^2$ parameters.
- Storing it as the product of a $d \times r$ and an $r \times d$ matrix takes only $2dr$ parameters.

Example: For $d = 4096$ and $r = 8$:

- Full matrix: $4096^2 = 16{,}777{,}216$ parameters.
- Low-rank factors: $2 \times 4096 \times 8 = 65{,}536$ parameters, a 256× reduction (see the quick check below).
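
A quick PyTorch check of both claims:

import torch

d, r = 4096, 8

B = torch.randn(d, r)
A = torch.randn(r, d)
delta_W = B @ A  # d x d matrix built from two thin factors

# The product of a d x r and an r x d matrix has rank at most r
print(torch.linalg.matrix_rank(delta_W).item())              # 8

print(f"Full matrix parameters: {d * d:,}")                   # 16,777,216
print(f"Low-rank parameters:    {A.numel() + B.numel():,}")   # 65,536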


Part 3: How LoRA Works

The LoRA Equation

Instead of learning a full-rank update $\Delta W$ to the frozen weight matrix $W$:

$$W_{\text{new}} = W + \Delta W$$

LoRA represents $\Delta W$ as a low-rank decomposition:

$$W_{\text{new}} = W + BA$$

Where:

- $W \in \mathbb{R}^{d \times d}$ is the pretrained weight matrix, kept frozen during training.
- $B \in \mathbb{R}^{d \times r}$ is initialized to zeros.
- $A \in \mathbb{R}^{r \times d}$ is initialized with small random Gaussian values.
- $r \ll d$ is the rank, typically between 4 and 64.

Because $B$ starts at zero, $BA = 0$ at initialization, so the adapted model initially behaves exactly like the base model.
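
A small numeric sketch of these properties, using a single $d \times d$ weight. It also shows that you never need to materialize the full $\Delta W$: applying the frozen weight and the two thin factors separately gives the same output.

import torch

d, r = 512, 8
W = torch.randn(d, d)          # frozen pretrained weight
A = torch.randn(r, d) * 0.01   # small random Gaussian init
B = torch.zeros(d, r)          # zero init, so BA = 0 at the start

x = torch.randn(4, d)          # a small batch of inputs

# At initialization the adapted layer is identical to the base layer
assert torch.equal(x @ (W + B @ A).T, x @ W.T)

# After training B becomes nonzero; both formulations agree, but the
# factored form never builds the full d x d update matrix
B = torch.randn(d, r) * 0.01
merged   = x @ (W + B @ A).T
factored = x @ W.T + (x @ A.T) @ B.T
print(torch.allclose(merged, factored, atol=1e-5))  # True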

Visualization:

[Figure: four-panel visualization of the LoRA decomposition]

Part 4: Implementing LoRA from Scratch

LoRA Linear Layer

import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank=8, alpha=16):
        super().__init__()
        self.rank = rank
        self.alpha = alpha

        # Initialize A with random Gaussian, B with zeros
        self.lora_A = nn.Parameter(torch.randn(in_features, rank) / rank)
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))

        # Scaling factor (usually set to alpha / rank)
        self.scaling = alpha / rank

    def forward(self, x):
        # This layer computes only the low-rank update (x @ A @ B) * scaling;
        # the frozen base projection W @ x is added by LoRALinear below.
        return (x @ self.lora_A @ self.lora_B) * self.scaling

# Replace a Linear layer with LoRA
class LoRALinear(nn.Module):
    def __init__(self, linear_layer, rank=8, alpha=16):
        super().__init__()
        self.linear = linear_layer
        self.linear.weight.requires_grad = False  # Freeze original weights
        if self.linear.bias is not None:
            self.linear.bias.requires_grad = False  # Freeze the bias as well

        # Add LoRA matrices
        self.lora = LoRALayer(
            linear_layer.in_features,
            linear_layer.out_features,
            rank=rank,
            alpha=alpha
        )

    def forward(self, x):
        # Original output + LoRA adaptation
        return self.linear(x) + self.lora(x)

# Usage
original_layer = nn.Linear(4096, 4096)
lora_layer = LoRALinear(original_layer, rank=8)

# Only LoRA parameters require gradients!
trainable_params = sum(p.numel() for p in lora_layer.parameters() if p.requires_grad)
print(f"Trainable params: {trainable_params:,}")  # 65,536 instead of 16,777,216

Part 5: Applying LoRA to a Transformer

Which Layers to Apply LoRA To?

Common strategy: Apply to attention projection matrices (Q, K, V, O).

def apply_lora_to_model(model, rank=8, alpha=16):
    """Apply LoRA to all attention layers."""

    for name, module in model.named_modules():
        # Apply to Q, K, V, O projections in attention
        if isinstance(module, nn.Linear) and any(x in name for x in ['q_proj', 'k_proj', 'v_proj', 'o_proj']):
            parent_name = '.'.join(name.split('.')[:-1])
            child_name = name.split('.')[-1]

            # Replace with LoRA version
            parent = model.get_submodule(parent_name)
            setattr(parent, child_name, LoRALinear(module, rank=rank, alpha=alpha))

    # Freeze all non-LoRA parameters
    for name, param in model.named_parameters():
        if 'lora' not in name:
            param.requires_grad = False

    return model

# Usage
model = GPT(config)
model = apply_lora_to_model(model, rank=8)

# Count trainable parameters
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())

print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
# Typically 0.1-1% of total parameters!

Part 6: Training with LoRA

Standard Training Loop

import torch.nn.functional as F

# Only optimize LoRA parameters
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4  # Can use a higher LR than full fine-tuning
)

for epoch in range(num_epochs):
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass (uses frozen W + trainable BA)
        logits = model(input_ids)
        loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))

        # Backward pass (only LoRA parameters get gradients)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
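
A quick sanity check that the freezing works: run one backward pass through the standalone `lora_layer` built in Part 4 and see which parameters actually received gradients.

# Dummy forward/backward through the LoRALinear wrapper from Part 4
x = torch.randn(2, 4096)
lora_layer(x).sum().backward()

for name, p in lora_layer.named_parameters():
    print(f"{name:20s} requires_grad={p.requires_grad}  got_grad={p.grad is not None}")
# Only lora.lora_A and lora.lora_B should report got_grad=True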

Memory Comparison

Full Fine-Tuning (7B model):

- Weights: 14 GB (FP16)
- Gradients: 14 GB
- Optimizer states (AdamW): 28 GB
- Total: ~56 GB

LoRA Fine-Tuning (7B model, rank=8):

- Frozen weights: 14 GB (FP16), with no gradients or optimizer states
- LoRA parameters: a few million (tens of MB in FP16)
- LoRA gradients + optimizer states: well under 1 GB
- Total: ~14 GB plus activations

Result: Fits on a single RTX 3090 (24GB)!


Part 7: QLoRA - Quantized LoRA

The Next Level: 4-bit Quantization

QLoRA combines:

  1. 4-bit quantization of the frozen base model

  2. LoRA adapters in FP16

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Load model in 4-bit
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"  # NormalFloat4
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto"
)

# Apply LoRA on top
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
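
After wrapping, peft can report the trainable/total parameter split directly:

# Prints the number of trainable vs. total parameters after applying LoRA
model.print_trainable_parameters()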

Memory savings:

- 4-bit base model weights: ~3.5 GB (vs. 14 GB in FP16)
- LoRA adapters, gradients, and optimizer states: a few hundred MB
- Total: roughly 4-6 GB, which fits on a single consumer GPU


Part 8: Saving and Loading LoRA Adapters

Save Only the Adapters

# Save only LoRA weights (tiny file!)
torch.save({
    'lora_state_dict': {k: v for k, v in model.state_dict().items() if 'lora' in k}
}, 'lora_adapters.pt')

# File size: ~100 MB instead of 14 GB!

Load Adapters onto Base Model

# Load base model (unchanged)
base_model = GPT.from_pretrained('gpt-7b')

# Apply LoRA architecture
base_model = apply_lora_to_model(base_model, rank=8)

# Load trained adapters
checkpoint = torch.load('lora_adapters.pt')
base_model.load_state_dict(checkpoint['lora_state_dict'], strict=False)

Use case: Ship multiple "personalities" as separate adapter files, one per task or style (for example, a coding assistant, a summarizer, and a customer-support persona).

Users download one base model (14 GB) + swap adapters as needed!
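
A sketch of what swapping looks like with the from-scratch setup above (the adapter filenames are hypothetical):

def load_adapter(base_model, adapter_path):
    """Swap in one LoRA adapter on an already LoRA-wrapped base model."""
    checkpoint = torch.load(adapter_path, map_location='cpu')
    base_model.load_state_dict(checkpoint['lora_state_dict'], strict=False)
    return base_model

# Same base model, different behavior -- only the small adapter file changes
model = load_adapter(base_model, 'adapter_coding.pt')      # hypothetical file
# ... run coding-assistant inference ...
model = load_adapter(base_model, 'adapter_summarizer.pt')  # hypothetical file
# ... run summarization inference ...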


Part 9: When to Use LoRA vs. Full Fine-Tuning

| Criterion | Full Fine-Tuning | LoRA |
|---|---|---|
| Memory | High (56+ GB) | Low (~14 GB) |
| Training speed | Slower | Faster (fewer gradients) |
| Final performance | Slightly better | Very close (within 1-2%) |
| Task similarity | Works for very different tasks | Best for similar tasks |
| Adapter swapping | No | Yes (multiple adapters) |

Rule of thumb:

- Default to LoRA: it is cheaper, nearly as accurate, and lets you keep one base model with many adapters.
- Reach for full fine-tuning only when the target task or domain differs substantially from pretraining and you have the memory budget to match.


Part 10: Hyperparameter Tuning

Key Hyperparameters

| Parameter | Typical Range | Effect |
|---|---|---|
| rank (r) | 8-64 | Higher = more capacity, more memory |
| alpha | 16-32 | Scaling factor (usually 2× rank) |
| target_modules | ["q_proj", "v_proj"] | Which layers to adapt |
| dropout | 0.0-0.1 | Regularization |
| learning_rate | 1e-4 to 3e-4 | Higher than full fine-tuning |

Grid search example:

import copy

for rank in [8, 16, 32]:
    for alpha in [16, 32]:
        # Start from a fresh copy so adapters from earlier runs don't stack
        model = apply_lora_to_model(copy.deepcopy(base_model), rank=rank, alpha=alpha)
        score = train_and_evaluate(model)  # your training/eval routine
        print(f"Rank={rank}, Alpha={alpha}: {score:.3f}")

Part 11: Visualizing LoRA Updates

Heatmap: Which Layers Change Most?

Distribution: LoRA Weight Magnitudes
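
The plots themselves aren't reproduced here, but both are straightforward to sketch with matplotlib, assuming the from-scratch `LoRALinear` wrapping from Part 5: measure the size of each layer's learned update $BA$, and look at the distribution of the adapter weights.

import torch
import matplotlib.pyplot as plt

layer_norms, lora_weights = [], []
for name, module in model.named_modules():
    if isinstance(module, LoRALinear):
        # Low-rank update for this layer, detached for plotting
        delta_W = (module.lora.lora_A @ module.lora.lora_B).detach().cpu()
        layer_norms.append(delta_W.norm().item())
        lora_weights.append(module.lora.lora_B.detach().cpu().flatten())

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Which layers change most? (a bar chart stands in for the heatmap here)
ax1.bar(range(len(layer_norms)), layer_norms)
ax1.set_xlabel("LoRA layer index")
ax1.set_ylabel("||BA|| (Frobenius norm)")
ax1.set_title("Update magnitude per layer")

# Distribution of LoRA weight magnitudes
ax2.hist(torch.cat(lora_weights).numpy(), bins=50)
ax2.set_xlabel("lora_B weight value")
ax2.set_ylabel("Count")
ax2.set_title("LoRA weight distribution")

plt.tight_layout()
plt.show()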


Summary

  1. LoRA reduces trainable parameters by 100-1000× using low-rank decomposition

  2. Memory drops from 56 GB → 14 GB for 7B models

  3. QLoRA adds 4-bit quantization → ~4 GB total

  4. Apply to attention projections (Q, K, V, O) for best results

  5. Performance within 1-2% of full fine-tuning

  6. Can swap adapters at inference (multiple tasks, one base model)

Next Up: L15 – Mixed Precision Training. Make training 2× faster with FP16/BF16!