# Builds a 2x3-grid figure comparing SGD and Adam on per-token weights, saves
# it to "sgd_vs_adam.png", and shows it. Relies on names defined earlier in
# the file: np, plt, gridspec, N_STEPS, TOKENS, FREQ, TRUE_W, LR, sgd_history,
# adam_history, sgd_final, adam_final, effective_lr.
#
# NOTE: the string delimiters, minus operators, "--" line styles, and loop
# indentation below were restored to plain ASCII / valid Python; the previous
# text used typographic quotes and en dashes, which do not parse.

BG = "#fafaf8"
DARK = "#1a1a1a"
# Color ramp: blue for common tokens, red for rare
TOKEN_COLORS = ["#1a5276", "#2471a3", "#5dade2", "#e67e22", "#c0392b", "#7d2a2a"]

steps = np.arange(N_STEPS)

fig = plt.figure(figsize=(16, 11), facecolor=BG)
fig.suptitle(
    "SGD vs. Adam on Rare Tokens — Frequency Bias and Variance Normalization",
    fontsize=14, fontweight="bold", color=DARK, y=0.99,
)
gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)

# ── 1. SGD weight trajectories ────────────────────────────────
ax1 = fig.add_subplot(gs[0, :2])
ax1.set_facecolor(BG)
ax1.axhline(1.0, color=DARK, lw=1, ls="--", alpha=0.3, label="True weight = 1.0")
# zip() stops at the shorter sequence, so at most len(TOKEN_COLORS) tokens are drawn.
for i, (token, color) in enumerate(zip(TOKENS, TOKEN_COLORS)):
    ax1.plot(steps, sgd_history[:, i], color=color, lw=1.8,
             label=f"{token} (freq={FREQ[i]:.3f})")
ax1.set_title("SGD — Weight Trajectories\nRare tokens barely move from zero", fontsize=11, color=DARK)
ax1.set_xlabel("Training Step", fontsize=9)
ax1.set_ylabel("Learned Weight", fontsize=9)
ax1.legend(fontsize=8, loc="right")
ax1.set_ylim(-0.3, 1.6)
ax1.spines[["top", "right"]].set_visible(False)
# Annotate failure zone.
# The arrow targets the final value of the last history column
# (presumably the rarest token — confirm against the TOKENS ordering).
ax1.annotate(
    "Rare tokens stuck\nnear zero",
    xy=(N_STEPS * 0.95, sgd_history[-1, -1]),
    xytext=(N_STEPS * 0.65, -0.15),
    fontsize=8.5, color="#c0392b",
    arrowprops=dict(arrowstyle="->", color="#c0392b", lw=1.2),
    bbox=dict(boxstyle="round,pad=0.3", facecolor="#fff0f0", edgecolor="#c0392b", alpha=0.85),
)

# ── 2. Final weight error bar chart ───────────────────────────
ax2 = fig.add_subplot(gs[0, 2])
ax2.set_facecolor(BG)
# One bar group per token, so positions always line up with the tick labels below.
x = np.arange(len(TOKENS))
w_sgd = sgd_final
w_adam = adam_final
width = 0.35
# Side-by-side bars: SGD shifted half a bar left, Adam half a bar right.
bars_sgd = ax2.bar(x - width / 2, np.abs(w_sgd - TRUE_W), width,
                   color="#c0392b", alpha=0.85, label="SGD error")
bars_adam = ax2.bar(x + width / 2, np.abs(w_adam - TRUE_W), width,
                    color="#2980b9", alpha=0.85, label="Adam error")
ax2.set_xticks(x)
# Tick labels are truncated to 8 characters to limit overlap.
ax2.set_xticklabels([t[:8] for t in TOKENS], rotation=30, ha="right", fontsize=8)
ax2.set_ylabel("|learned w − true w|", fontsize=9)
ax2.set_title("Final Weight Error\n(lower = better)", fontsize=11, color=DARK)
ax2.legend(fontsize=8)
ax2.spines[["top", "right"]].set_visible(False)

# ── 3. Adam weight trajectories ───────────────────────────────
ax3 = fig.add_subplot(gs[1, :2])
ax3.set_facecolor(BG)
ax3.axhline(1.0, color=DARK, lw=1, ls="--", alpha=0.3, label="True weight = 1.0")
for i, (token, color) in enumerate(zip(TOKENS, TOKEN_COLORS)):
    ax3.plot(steps, adam_history[:, i], color=color, lw=1.8,
             label=f"{token} (freq={FREQ[i]:.3f})")
ax3.set_title("Adam — Weight Trajectories\nRare tokens converge via variance normalization", fontsize=11, color=DARK)
ax3.set_xlabel("Training Step", fontsize=9)
ax3.set_ylabel("Learned Weight", fontsize=9)
ax3.legend(fontsize=8, loc="right")
ax3.set_ylim(-0.3, 1.6)  # same y-range as the SGD panel
ax3.spines[["top", "right"]].set_visible(False)
ax3.annotate(
    "Rare tokens converge\ndespite sparse gradients",
    xy=(N_STEPS * 0.95, adam_history[-1, -1]),
    xytext=(N_STEPS * 0.60, 0.3),
    fontsize=8.5, color="#27ae60",
    arrowprops=dict(arrowstyle="->", color="#27ae60", lw=1.2),
    bbox=dict(boxstyle="round,pad=0.3", facecolor="#f0fff4", edgecolor="#27ae60", alpha=0.85),
)

# ── 4. Effective LR vs frequency ─────────────────────────────
ax4 = fig.add_subplot(gs[1, 2])
ax4.set_facecolor(BG)
ax4.scatter(FREQ, effective_lr, c=TOKEN_COLORS, s=120, zorder=5, edgecolors="white", lw=1.5)
for i, token in enumerate(TOKENS):
    ax4.annotate(token, (FREQ[i], effective_lr[i]),
                 textcoords="offset points", xytext=(6, 4),
                 fontsize=7.5, color=TOKEN_COLORS[i])
ax4.axhline(LR, color=DARK, lw=1, ls="--", alpha=0.4)
# NOTE(review): x=0.5 is in data coordinates (frequency axis), so the label
# only lands inside the panel if the plotted frequency range includes 0.5.
ax4.text(0.5, LR * 1.05, f"Nominal LR = {LR}", fontsize=8, color=DARK, alpha=0.6)
ax4.set_xscale("log")
ax4.set_yscale("log")
ax4.set_xlabel("Token Frequency (log scale)", fontsize=9)
ax4.set_ylabel("Adam Effective LR lr/√v̂ (log scale)", fontsize=9)
ax4.set_title("Adam’s Automatic Equalizer\nRare tokens get amplified LR", fontsize=11, color=DARK)
ax4.spines[["top", "right"]].set_visible(False)

# Save before show(): some backends leave the figure empty once the window closes.
plt.savefig("sgd_vs_adam.png", dpi=150, bbox_inches="tight", facecolor=BG)
plt.show()

