School of Computer Science and Engineering, Jain (Deemed-to-be University), Bengaluru
Vision Transformers process images as sequences of patch tokens. Self-attention scales quadratically with sequence length: for ViT-B/16 on a 224×224 image, the 196 patch tokens alone produce 38,416 pairwise scores per attention head per layer (38,809 once the CLS token is included). Token pruning reduces this cost, but pruning at the input layer discards tokens before they carry semantic meaning, so the wrong tokens get cut and accuracy drops.
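The quadratic cost is easy to check by hand. The short sketch below counts patches and attention-matrix entries for a 224×224 input with 16×16 patches; the helper names are illustrative and not taken from any library.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    per_side = image_size // patch_size
    return per_side * per_side


def pairwise_scores(seq_len: int) -> int:
    """Entries in one head's attention matrix (seq_len x seq_len)."""
    return seq_len * seq_len


patches = num_patches(224, 16)                      # 14 * 14 patches
print(patches, pairwise_scores(patches))            # 196 38416
print(patches + 1, pairwise_scores(patches + 1))    # 197 38809 (CLS included)
```

Halving the sequence length cuts the score count to a quarter, which is why removing tokens pays off more than linearly.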
Prune progressively at 1/3 and 2/3 of the encoder depth (after blocks 4 and 8 of 12), scoring tokens by the L₂-norm of their post-stage hidden representations. No additional learned parameters are introduced.
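As a minimal sketch of parameter-free norm scoring, the function below keeps the CLS token plus the highest-norm patch tokens. It assumes the CLS token sits at index 0 and that ties are broken by position; `prune_tokens` is a hypothetical helper, not the project's code.

```python
import math


def prune_tokens(tokens, keep):
    """Keep the CLS token (index 0) plus the highest-norm patch tokens.

    tokens: list of hidden vectors (lists of floats), CLS first.
    keep:   total number of tokens to retain, CLS included.
    Returns the retained tokens in their original order.
    """
    if not 1 <= keep <= len(tokens):
        raise ValueError("keep must be between 1 and len(tokens)")
    # Score every patch token by the L2 norm of its hidden vector.
    scores = [(math.sqrt(sum(x * x for x in vec)), idx)
              for idx, vec in enumerate(tokens) if idx != 0]
    # Highest norm first; ties go to the lower index (an assumption).
    scores.sort(key=lambda s: (-s[0], s[1]))
    kept = sorted([0] + [idx for _, idx in scores[:keep - 1]])
    return [tokens[i] for i in kept]


hidden = [[0.1, 0.1],   # CLS: low norm, but always kept
          [3.0, 4.0],   # norm 5.0
          [0.5, 0.0],   # norm 0.5
          [1.0, 1.0]]   # norm ~1.41
print(prune_tokens(hidden, keep=3))  # [[0.1, 0.1], [3.0, 4.0], [1.0, 1.0]]
```

Sorting the kept indices preserves the original token order, so any positional information already mixed into the representations stays aligned.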
Align CLS representations between sparse student and frozen dense teacher exactly at pruning boundaries, where divergence is structurally greatest.
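One way to picture the alignment term is an MSE between unit-normalised CLS vectors, evaluated at each pruning boundary. The sketch below assumes the per-boundary losses are simply averaged; that choice and the function names are illustrative assumptions.

```python
import math


def l2_normalize(vec, eps=1e-12):
    """Scale a vector to unit L2 norm (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


def cls_alignment_loss(student_cls, teacher_cls):
    """Mean squared error between L2-normalised CLS vectors."""
    s = l2_normalize(student_cls)
    t = l2_normalize(teacher_cls)
    return sum((a - b) ** 2 for a, b in zip(s, t)) / len(s)


def boundary_alignment_loss(student_feats, teacher_feats):
    """Average the CLS loss over pruning boundaries (averaging is an assumption)."""
    losses = [cls_alignment_loss(s, t)
              for s, t in zip(student_feats, teacher_feats)]
    return sum(losses) / len(losses)


# Same direction, different magnitude: normalisation removes the gap.
print(cls_alignment_loss([3.0, 4.0], [6.0, 8.0]))
# Orthogonal CLS vectors give the larger penalty.
print(cls_alignment_loss([1.0, 0.0], [0.0, 1.0]))
```

Normalising first means the student is penalised for pointing its CLS vector in a different direction from the teacher, not for a difference in scale.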
A controlled 2×2 ablation isolates pruning timing as the dominant accuracy contributor. Progressive pruning without any distillation already reaches 88.98% — beating the dense teacher by 1.95 points. Distillation adds a further 0.62. The schedule, not the supervision, is what drives the gain.
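The two deltas quoted above imply the remaining accuracies by simple arithmetic, assuming both are absolute percentage points measured on the same test set:

```latex
\begin{aligned}
\text{dense teacher} &= 88.98\% - 1.95 = 87.03\% \\
\text{progressive pruning with distillation} &= 88.98\% + 0.62 = 89.60\%
\end{aligned}
```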
The 73.5% reduction in attention FLOPs translates directly to lower inference latency in edge deployment scenarios — mobile devices, embedded systems, real-time video pipelines — where quadratic attention is the primary throughput bottleneck. Because PTP-FAKD adds no parameters, it applies directly to pretrained ViT-B/16 checkpoints without architectural changes, and the framework partitions naturally to deeper models like ViT-L/16 (24 blocks, stages of 8) and ViT-H/14 (32 blocks, stages of ~11).
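The three-stage split for deeper encoders can be sketched as follows. Placing each boundary at the block index nearest to depth × i / 3 is an assumption made here for illustration; it gives uneven stages when the depth is not divisible by three.

```python
def stage_boundaries(depth, num_stages=3):
    """Block counts after which pruning would occur (rounding is an assumption)."""
    return [round(depth * i / num_stages) for i in range(1, num_stages)]


def stage_sizes(depth, num_stages=3):
    """Number of blocks in each stage implied by the boundaries."""
    edges = [0] + stage_boundaries(depth, num_stages) + [depth]
    return [b - a for a, b in zip(edges, edges[1:])]


for name, depth in [("ViT-B/16", 12), ("ViT-L/16", 24), ("ViT-H/14", 32)]:
    print(name, stage_boundaries(depth), stage_sizes(depth))
# ViT-B/16 [4, 8] [4, 4, 4]
# ViT-L/16 [8, 16] [8, 8, 8]
# ViT-H/14 [11, 21] [11, 10, 11]
```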
An input image becomes 197 tokens (196 patches + CLS). The student processes them through three stages of 4 blocks each, pruning between stages. The frozen teacher processes the full sequence in parallel and supervises the student at logit and feature level.
Naive pruning scores tokens at the embedding layer, where representations carry texture statistics but no semantic content. A high-norm token there often corresponds to a high-frequency edge patch, not a class-relevant region. Pruning after 4 and 8 blocks of attention lets context propagate first — by then, high-norm tokens correspond to object-centric regions. The classification head ends up attending to semantically concentrated tokens, which has an implicit regularization effect that explains why the sparse student outperforms the dense teacher.
The training objective combines three terms: cross-entropy with label smoothing (ε = 0.1), KL-divergence distillation at temperature T = 4.0, and L₂-normalized MSE on CLS features. The weight α anneals from 0.70 to 0.50 over training, while β = 0.3 is fixed.
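The sketch below shows one plausible way these terms could combine. The text does not spell out the exact form, so several pieces are assumptions: the combination (1 − α)·CE + α·KD + β·feature, the T² scaling on the KL term, and the linear shape of the α anneal. The feature term is passed in as a precomputed number.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over temperature-scaled logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def smoothed_cross_entropy(logits, target, eps=0.1):
    """Cross-entropy against a label-smoothed target distribution."""
    probs = softmax(logits)
    k = len(logits)
    loss = 0.0
    for i, p in enumerate(probs):
        q = (1.0 - eps) * (1.0 if i == target else 0.0) + eps / k
        loss -= q * math.log(p)
    return loss


def kd_kl(student_logits, teacher_logits, temperature=4.0):
    """KL(teacher || student) on softened outputs, scaled by T^2 (assumed)."""
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = sum(t * math.log(t / s) for t, s in zip(p_t, p_s))
    return temperature ** 2 * kl


def alpha_at(step, total_steps, start=0.70, end=0.50):
    """Anneal alpha from start to end; the linear shape is an assumption."""
    frac = min(max(step / max(total_steps, 1), 0.0), 1.0)
    return start + (end - start) * frac


def total_loss(student_logits, teacher_logits, target, feat_loss,
               alpha, beta=0.3):
    """Assumed combination: (1 - alpha) * CE + alpha * KD + beta * feature."""
    ce = smoothed_cross_entropy(student_logits, target)
    kd = kd_kl(student_logits, teacher_logits)
    return (1.0 - alpha) * ce + alpha * kd + beta * feat_loss


alpha = alpha_at(step=50, total_steps=100)   # halfway through training
print(total_loss([2.0, 0.5, -1.0], [1.5, 0.7, -0.8], target=0,
                 feat_loss=0.02, alpha=alpha))
```

Under this reading, lowering α over training shifts weight from the teacher's soft targets toward the ground-truth labels.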
κ₁ = 0.85 (keeps 85% of tokens after stage 1, reducing 197 → 167) and κ₂ = 0.50 (keeps half of the original sequence after stage 2, reducing 167 → 98). The CLS token is never pruned. Both ratios were selected via grid search on a held-out 10% validation split of CIFAR-100, kept independent of the test set.
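The stated token counts can be reproduced with a few lines. This sketch assumes both keep ratios are measured against the original 197-token sequence and rounded down, since that is the reading that yields 167 and 98; `token_schedule` is a hypothetical helper.

```python
import math


def token_schedule(seq_len, keep_ratios):
    """Token count entering each stage; ratios apply to the original length."""
    counts = [seq_len]
    for ratio in keep_ratios:
        counts.append(math.floor(ratio * seq_len))  # floor is an assumption
    return counts


counts = token_schedule(197, [0.85, 0.50])
print(counts)                    # [197, 167, 98]
print([n * n for n in counts])   # [38809, 27889, 9604] scores per head per layer
```

The second line of output shows how the per-layer attention matrix shrinks from stage to stage as tokens are removed.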
The four-corner experiment that isolates pruning timing from supervision objective.
Logit-only KD on a sparse ViT-B/16 student transfers cleanly to CIFAR-10, closing 30% of the teacher gap.