# --- Section 6: pack / unpack ------------------------------------------------
# Concatenate token groups of different lengths along one "*" axis, run a
# shared per-token MLP over the combined sequence, then split it back apart.
section("6) pack unpack")
B, Cemb = 2, 128
class_token = torch.randn(B, 1, Cemb, device=device)
image_tokens = torch.randn(B, 196, Cemb, device=device)
text_tokens = torch.randn(B, 32, Cemb, device=device)
show_shape("class_token", class_token)
show_shape("image_tokens", image_tokens)
show_shape("text_tokens", text_tokens)

# "b * c": the batch and channel axes are shared by all inputs, and the axes
# matched by "*" are concatenated. `ps` is what unpack (below) uses to split the
# result back into the original per-input shapes.
packed, ps = pack([class_token, image_tokens, text_tokens], "b * c")
show_shape("packed", packed)
print("packed_shapes (ps):", ps)
assert packed.shape == (B, 1 + 196 + 32, Cemb)

# Every layer acts on the last (channel) axis only, so the MLP is applied to
# each token independently, regardless of which group it came from.
mixer = nn.Sequential(
    nn.LayerNorm(Cemb),
    nn.Linear(Cemb, 4 * Cemb),
    nn.GELU(),
    nn.Linear(4 * Cemb, Cemb),
).to(device)
mixed = mixer(packed)
show_shape("mixed", mixed)

# Split with the same pattern and the recorded shapes; the round trip must
# restore each group's original shape.
class_out, image_out, text_out = unpack(mixed, ps, "b * c")
show_shape("class_out", class_out)
show_shape("image_out", image_out)
show_shape("text_out", text_out)
assert class_out.shape == class_token.shape
assert image_out.shape == image_tokens.shape
assert text_out.shape == text_tokens.shape
# --- Section 7: einops layers inside nn.Module ---------------------------------
section("7) layers")
class PatchEmbed(nn.Module):
    """Split images into non-overlapping square patches and embed each patch.

    Args:
        in_channels: number of channels in the input image.
        emb_dim: size of each output token embedding.
        patch: side length of a square patch, in pixels.

    Input is ``(b, c, H, W)``. H and W must be divisible by ``patch``,
    otherwise the Rearrange pattern cannot decompose them. Output is
    ``(b, (H // patch) * (W // patch), emb_dim)``.
    """

    def __init__(self, in_channels: int = 3, emb_dim: int = 192, patch: int = 8):
        super().__init__()
        self.patch = patch
        # Flatten each patch into a single vector, laid out as (p1, p2, c).
        self.to_patches = Rearrange(
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch, p2=patch
        )
        # Each flattened patch holds in_channels * patch * patch values.
        self.proj = nn.Linear(in_channels * patch * patch, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return patch-token embeddings for an image batch ``x``."""
        x = self.to_patches(x)
        return self.proj(x)
class SimpleVisionHead(nn.Module):
    """Mean-pool a token sequence and map the pooled vector to class logits.

    Args:
        emb_dim: channel size of the incoming tokens.
        num_classes: number of output logits.

    Input is ``(b, t, emb_dim)``; output is ``(b, num_classes)``.
    """

    def __init__(self, emb_dim: int = 192, num_classes: int = 10):
        super().__init__()
        # Average over the token axis "t", keeping batch and channels.
        self.pool = Reduce("b t c -> b c", reduction="mean")
        self.classifier = nn.Linear(emb_dim, num_classes)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of token sequences."""
        x = self.pool(tokens)
        return self.classifier(x)
# Embed a batch of four 3x32x32 images as 8x8-patch tokens, then classify them.
patch_embed = PatchEmbed(in_channels=3, emb_dim=192, patch=8).to(device)
head = SimpleVisionHead(emb_dim=192, num_classes=10).to(device)
imgs = torch.randn(4, 3, 32, 32, device=device)
tokens = patch_embed(imgs)
logits = head(tokens)
show_shape("tokens", tokens)
show_shape("logits", logits)
# (32 // 8) * (32 // 8) = 16 patch tokens per image.
assert tokens.shape == (4, 16, 192)
assert logits.shape == (4, 10)
section("8) practical")

# Group normalisation by hand: fold the channel groups into the batch axis,
# normalise each (sample, group) slice over (cg, h, w), then unfold.
x = torch.randn(2, 32, 16, 16, device=device)
g = 8
xg = rearrange(x, "b (g cg) h w -> (b g) cg h w", g=g)
show_shape("x", x)
show_shape("xg", xg)
mean = reduce(xg, "bg cg h w -> bg 1 1 1", "mean")
# Biased (population) variance: mean of squared deviations.
var = reduce((xg - mean) ** 2, "bg cg h w -> bg 1 1 1", "mean")
xg_norm = (xg - mean) / torch.sqrt(var + 1e-5)
# Take the batch size from x instead of hard-coding it.
x_norm = rearrange(xg_norm, "(b g) cg h w -> b (g cg) h w", b=x.shape[0], g=g)
show_shape("x_norm", x_norm)
# Cross-check against PyTorch's implementation (no affine weight/bias).
reference = torch.nn.functional.group_norm(x, g, eps=1e-5)
assert torch.allclose(x_norm, reference, atol=1e-5)

# Flatten the spatial axes into one and restore them. This is a pure reshape,
# so the round trip is exact and is compared without a tolerance.
z = torch.randn(3, 64, 20, 30, device=device)
z_flat = rearrange(z, "b c h w -> b c (h w)")
z_unflat = rearrange(z_flat, "b c (h w) -> b c h w", h=z.shape[2], w=z.shape[3])
assert torch.equal(z, z_unflat)
show_shape("z_flat", z_flat)
section("9) views")

# A pure axis permutation should need no copy; rearrange is expected to
# return a strided view of `a`, which is why `b` is not contiguous.
a = torch.randn(2, 3, 4, 5, device=device)
b = rearrange(a, "b c h w -> b h w c")
print("a.is_contiguous():", a.is_contiguous())
print("b.is_contiguous():", b.is_contiguous())
# `_base` is a private attribute; getattr keeps this probe from raising.
print("b._base is a:", getattr(b, "_base", None) is a)
# Public-API check that no data was copied: both tensors start at the same address.
print("b shares memory with a:", b.data_ptr() == a.data_ptr())
# Closing banner.
section("Done ✅ You now have reusable einops patterns for vision, attention, and multimodal token packing")

